| // Copyright ©2015 The Gonum Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package mat |
| |
| import ( |
| "fmt" |
| "testing" |
| |
| "golang.org/x/exp/rand" |
| ) |
| |
// dims is the shape of a matrix factor: r rows by c columns.
type dims struct {
	r int // number of rows
	c int // number of columns
}
| |
| var productTests = []struct { |
| n int |
| factors []dims |
| product dims |
| panics bool |
| }{ |
| { |
| n: 1, |
| factors: []dims{{3, 4}}, |
| product: dims{3, 4}, |
| panics: false, |
| }, |
| { |
| n: 1, |
| factors: []dims{{2, 4}}, |
| product: dims{3, 4}, |
| panics: true, |
| }, |
| { |
| n: 3, |
| factors: []dims{{10, 30}, {30, 5}, {5, 60}}, |
| product: dims{10, 60}, |
| panics: false, |
| }, |
| { |
| n: 3, |
| factors: []dims{{100, 30}, {30, 5}, {5, 60}}, |
| product: dims{10, 60}, |
| panics: true, |
| }, |
| { |
| n: 7, |
| factors: []dims{{60, 5}, {5, 5}, {5, 4}, {4, 10}, {10, 22}, {22, 45}, {45, 10}}, |
| product: dims{60, 10}, |
| panics: false, |
| }, |
| { |
| n: 7, |
| factors: []dims{{60, 5}, {5, 5}, {5, 400}, {4, 10}, {10, 22}, {22, 45}, {45, 10}}, |
| product: dims{60, 10}, |
| panics: true, |
| }, |
| { |
| n: 3, |
| factors: []dims{{1, 1000}, {1000, 2}, {2, 2}}, |
| product: dims{1, 2}, |
| panics: false, |
| }, |
| |
| // Random chains. |
| { |
| n: 0, |
| product: dims{0, 0}, |
| panics: false, |
| }, |
| { |
| n: 2, |
| product: dims{60, 10}, |
| panics: false, |
| }, |
| { |
| n: 3, |
| product: dims{60, 10}, |
| panics: false, |
| }, |
| { |
| n: 4, |
| product: dims{60, 10}, |
| panics: false, |
| }, |
| { |
| n: 10, |
| product: dims{60, 10}, |
| panics: false, |
| }, |
| } |
| |
| func TestProduct(t *testing.T) { |
| t.Parallel() |
| rnd := rand.New(rand.NewSource(1)) |
| for _, test := range productTests { |
| dimensions := test.factors |
| if dimensions == nil && test.n > 0 { |
| dimensions = make([]dims, test.n) |
| for i := range dimensions { |
| if i != 0 { |
| dimensions[i].r = dimensions[i-1].c |
| } |
| dimensions[i].c = rnd.Intn(50) + 1 |
| } |
| dimensions[0].r = test.product.r |
| dimensions[test.n-1].c = test.product.c |
| } |
| factors := make([]Matrix, test.n) |
| for i, d := range dimensions { |
| data := make([]float64, d.r*d.c) |
| for i := range data { |
| data[i] = rnd.Float64() |
| } |
| factors[i] = NewDense(d.r, d.c, data) |
| } |
| |
| want := &Dense{} |
| if !test.panics { |
| var a *Dense |
| for i, b := range factors { |
| if i == 0 { |
| want.CloneFrom(b) |
| continue |
| } |
| a, want = want, &Dense{} |
| want.Mul(a, b) |
| } |
| } |
| |
| got := &Dense{} |
| if test.product.r != 0 && test.product.c != 0 { |
| got = NewDense(test.product.r, test.product.c, nil) |
| } |
| panicked, message := panics(func() { |
| got.Product(factors...) |
| }) |
| if test.panics { |
| if !panicked { |
| t.Errorf("fail to panic with product chain dimensions: %+v result dimension: %+v", |
| dimensions, test.product) |
| } |
| continue |
| } else if panicked { |
| t.Errorf("unexpected panic %q with product chain dimensions: %+v result dimension: %+v", |
| message, dimensions, test.product) |
| continue |
| } |
| |
| if len(factors) > 0 { |
| p := newMultiplier(NewDense(test.product.r, test.product.c, nil), factors) |
| p.optimize() |
| gotCost := p.table.at(0, len(factors)-1).cost |
| expr, wantCost, ok := bestExpressionFor(dimensions) |
| if !ok { |
| t.Fatal("unexpected number of expressions in brute force expression search") |
| } |
| if gotCost != wantCost { |
| t.Errorf("unexpected cost for chain dimensions: %+v got: %v want: %v\n%s", |
| dimensions, got, want, expr) |
| } |
| } |
| |
| if !EqualApprox(got, want, 1e-14) { |
| t.Errorf("unexpected result from product chain dimensions: %+v", dimensions) |
| } |
| } |
| } |
| |
| // node is a subexpression node. |
| type node struct { |
| dims |
| left, right *node |
| } |
| |
| func (n *node) String() string { |
| if n.left == nil || n.right == nil { |
| rows, cols := n.shape() |
| return fmt.Sprintf("[%d×%d]", rows, cols) |
| } |
| rows, cols := n.shape() |
| return fmt.Sprintf("(%s * %s):[%d×%d]", n.left, n.right, rows, cols) |
| } |
| |
| // shape returns the dimensions of the result of the subexpression. |
| func (n *node) shape() (rows, cols int) { |
| if n.left == nil || n.right == nil { |
| return n.r, n.c |
| } |
| rows, _ = n.left.shape() |
| _, cols = n.right.shape() |
| return rows, cols |
| } |
| |
| // cost returns the cost to evaluate the subexpression. |
| func (n *node) cost() int { |
| if n.left == nil || n.right == nil { |
| return 0 |
| } |
| lr, lc := n.left.shape() |
| _, rc := n.right.shape() |
| return lr*lc*rc + n.left.cost() + n.right.cost() |
| } |
| |
| // expressionsFor returns a channel that can be used to iterate over all |
| // expressions of the given factor dimensions. |
| func expressionsFor(factors []dims) chan *node { |
| if len(factors) == 1 { |
| c := make(chan *node, 1) |
| c <- &node{dims: factors[0]} |
| close(c) |
| return c |
| } |
| c := make(chan *node) |
| go func() { |
| for i := 1; i < len(factors); i++ { |
| for left := range expressionsFor(factors[:i]) { |
| for right := range expressionsFor(factors[i:]) { |
| c <- &node{left: left, right: right} |
| } |
| } |
| } |
| close(c) |
| }() |
| return c |
| } |
| |
// catalan returns the nth 0-based Catalan number, C_n = (2n)! / ((n+1)! n!).
// C_n is the number of distinct full parenthesizations of a chain of n+1
// factors. For n <= 0 it returns 1.
//
// The value is built with the recurrence C_{k+1} = C_k * 2(2k+1) / (k+2),
// whose division is always exact. Working in 64-bit integers, the intermediate
// product stays in range for every n up to 33. Forming (2n)!/n! before
// dividing would already overflow int64 at n = 15.
func catalan(n int) int {
	p := int64(1)
	for k := int64(0); k < int64(n); k++ {
		p = p * 2 * (2*k + 1) / (k + 2)
	}
	return int(p)
}
| |
| // bestExpressonFor returns the lowest cost expression for the given expression |
| // factor dimensions, the cost of the expression and whether the number of |
| // expressions searched matches the Catalan number for the number of factors. |
| func bestExpressionFor(factors []dims) (exp *node, cost int, ok bool) { |
| const maxInt = int(^uint(0) >> 1) |
| min := maxInt |
| var best *node |
| var n int |
| for exp := range expressionsFor(factors) { |
| n++ |
| cost := exp.cost() |
| if cost < min { |
| min = cost |
| best = exp |
| } |
| } |
| return best, min, n == catalan(len(factors)-1) |
| } |