blob: 36f5878308679bb6ebce506f2f2959790210bd2e [file] [log] [blame]
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mat
import (
"fmt"
"testing"
"golang.org/x/exp/rand"
)
// dims records the number of rows (r) and columns (c) of a matrix.
type dims struct {
	r int
	c int
}
// productTests are the cases run by TestProduct. Each case gives the number
// of factors n, optionally the explicit factor dimensions, the dimensions of
// the expected product (also used to size the receiver), and whether Product
// is expected to panic. When factors is nil, TestProduct generates a random
// conformable chain of n factors whose outer dimensions match product.
var productTests = []struct {
n int
factors []dims
product dims
panics bool
}{
// A single factor whose shape matches the product.
{
n: 1,
factors: []dims{{3, 4}},
product: dims{3, 4},
panics: false,
},
// A single 2×4 factor does not match the 3×4 product.
{
n: 1,
factors: []dims{{2, 4}},
product: dims{3, 4},
panics: true,
},
// Three conformable factors.
{
n: 3,
factors: []dims{{10, 30}, {30, 5}, {5, 60}},
product: dims{10, 60},
panics: false,
},
// The first factor has 100 rows but the product has 10.
{
n: 3,
factors: []dims{{100, 30}, {30, 5}, {5, 60}},
product: dims{10, 60},
panics: true,
},
// Seven conformable factors.
{
n: 7,
factors: []dims{{60, 5}, {5, 5}, {5, 4}, {4, 10}, {10, 22}, {22, 45}, {45, 10}},
product: dims{60, 10},
panics: false,
},
// The third factor (5×400) does not conform with the fourth (4×10).
{
n: 7,
factors: []dims{{60, 5}, {5, 5}, {5, 400}, {4, 10}, {10, 22}, {22, 45}, {45, 10}},
product: dims{60, 10},
panics: true,
},
// A long inner dimension: (AB)C costs 2004 while A(BC) costs 6000.
{
n: 3,
factors: []dims{{1, 1000}, {1000, 2}, {2, 2}},
product: dims{1, 2},
panics: false,
},
// Random chains: factors is nil, so TestProduct generates the dimensions.
// The n == 0 case has no factors and an empty product.
{
n: 0,
product: dims{0, 0},
panics: false,
},
{
n: 2,
product: dims{60, 10},
panics: false,
},
{
n: 3,
product: dims{60, 10},
panics: false,
},
{
n: 4,
product: dims{60, 10},
panics: false,
},
{
n: 10,
product: dims{60, 10},
panics: false,
},
}
// TestProduct checks Dense.Product for every case in productTests. For each
// case it:
//   - builds a random conformable chain when no explicit factors are given,
//   - checks that Product panics exactly when the case says it should,
//   - checks that the multiplier's optimized chain cost equals the lowest
//     cost found by brute-force search over all parenthesizations, and
//   - checks that the product agrees, within 1e-14, with a reference computed
//     by multiplying the factors left to right with Mul.
func TestProduct(t *testing.T) {
	t.Parallel()
	rnd := rand.New(rand.NewSource(1))
	for _, test := range productTests {
		dimensions := test.factors
		if dimensions == nil && test.n > 0 {
			// Generate a random chain in which adjacent factors conform and
			// the outer dimensions match the expected product.
			dimensions = make([]dims, test.n)
			for i := range dimensions {
				if i != 0 {
					dimensions[i].r = dimensions[i-1].c
				}
				dimensions[i].c = rnd.Intn(50) + 1
			}
			dimensions[0].r = test.product.r
			dimensions[test.n-1].c = test.product.c
		}

		factors := make([]Matrix, test.n)
		for i, d := range dimensions {
			data := make([]float64, d.r*d.c)
			for j := range data {
				data[j] = rnd.Float64()
			}
			factors[i] = NewDense(d.r, d.c, data)
		}

		// Reference result: multiply the factors strictly left to right.
		want := &Dense{}
		if !test.panics {
			var a *Dense
			for i, b := range factors {
				if i == 0 {
					want.CloneFrom(b)
					continue
				}
				a, want = want, &Dense{}
				want.Mul(a, b)
			}
		}

		// When either product dimension is zero the receiver is left empty;
		// otherwise it is pre-sized to the expected product shape.
		got := &Dense{}
		if test.product.r != 0 && test.product.c != 0 {
			got = NewDense(test.product.r, test.product.c, nil)
		}
		panicked, message := panics(func() {
			got.Product(factors...)
		})
		if test.panics {
			if !panicked {
				t.Errorf("fail to panic with product chain dimensions: %+v result dimension: %+v",
					dimensions, test.product)
			}
			continue
		}
		if panicked {
			t.Errorf("unexpected panic %q with product chain dimensions: %+v result dimension: %+v",
				message, dimensions, test.product)
			continue
		}

		if len(factors) > 0 {
			// The optimized multiplication order must be exactly as cheap as
			// the best order found by exhaustive search.
			p := newMultiplier(NewDense(test.product.r, test.product.c, nil), factors)
			p.optimize()
			gotCost := p.table.at(0, len(factors)-1).cost
			expr, wantCost, ok := bestExpressionFor(dimensions)
			if !ok {
				t.Fatal("unexpected number of expressions in brute force expression search")
			}
			if gotCost != wantCost {
				// Report the costs being compared, not the matrices.
				t.Errorf("unexpected cost for chain dimensions: %+v got: %v want: %v\n%s",
					dimensions, gotCost, wantCost, expr)
			}
		}

		if !EqualApprox(got, want, 1e-14) {
			t.Errorf("unexpected result from product chain dimensions: %+v", dimensions)
		}
	}
}
// node is one vertex of a parenthesization tree for a matrix chain. A leaf
// (a node without both children) carries the dimensions of a single factor
// in its embedded dims; a node with both children stands for the product of
// its left and right subexpressions.
type node struct {
	dims
	left  *node
	right *node
}
// String renders the subexpression rooted at n. A leaf is written as
// "[r×c]"; a node with both children is written as "(left * right):[r×c]",
// where r×c is the shape of the resulting product.
func (n *node) String() string {
	rows, cols := n.shape()
	if n.left != nil && n.right != nil {
		return fmt.Sprintf("(%s * %s):[%d×%d]", n.left, n.right, rows, cols)
	}
	return fmt.Sprintf("[%d×%d]", rows, cols)
}
// shape reports the number of rows and columns of the matrix produced by the
// subexpression rooted at n. A leaf reports its own dims; a node with both
// children takes its rows from the left subexpression and its columns from
// the right subexpression.
func (n *node) shape() (rows, cols int) {
	if n.left != nil && n.right != nil {
		r, _ := n.left.shape()
		_, c := n.right.shape()
		return r, c
	}
	return n.r, n.c
}
// cost reports the cost of evaluating the subexpression rooted at n. A leaf
// costs nothing. For a node with both children, the cost is the sum of both
// children's costs plus the left operand's rows times its columns times the
// right operand's columns.
func (n *node) cost() int {
	if n.left != nil && n.right != nil {
		rows, inner := n.left.shape()
		_, cols := n.right.shape()
		return n.left.cost() + n.right.cost() + rows*inner*cols
	}
	return 0
}
// expressionsFor enumerates every full parenthesization of a product of
// factors with the given dimensions, sending each expression tree on the
// returned channel and closing it when done.
//
// A single factor yields one leaf on a buffered channel. No factors yields
// nothing. Otherwise the enumeration runs in a goroutine that sends on an
// unbuffered channel, so the caller must drain the channel or that goroutine
// blocks forever. Subtrees may be shared between the yielded expressions, so
// they must not be mutated.
func expressionsFor(factors []dims) chan *node {
	if len(factors) == 1 {
		single := make(chan *node, 1)
		single <- &node{dims: factors[0]}
		close(single)
		return single
	}

	out := make(chan *node)
	go func() {
		defer close(out)
		// Split the chain at every position and combine each left
		// subexpression with each right subexpression.
		for split := 1; split < len(factors); split++ {
			lhs, rhs := factors[:split], factors[split:]
			for l := range expressionsFor(lhs) {
				for r := range expressionsFor(rhs) {
					out <- &node{left: l, right: r}
				}
			}
		}
	}()
	return out
}
// catalan returns the nth 0-based Catalan number, C(n) = (2n)!/((n+1)!·n!),
// which is the number of ways to fully parenthesize a product of n+1
// factors. It returns 1 for n <= 0.
//
// The value is built with the recurrence C(k+1) = C(k)·2(2k+1)/(k+2). The
// division at each step is exact, and the intermediate product stays within
// int64 for n <= 33. A direct ratio of factorial products would overflow
// int64 from n = 15 and silently return a wrong value.
func catalan(n int) int {
	c := int64(1)
	for k := int64(0); k < int64(n); k++ {
		c = c * 2 * (2*k + 1) / (k + 2)
	}
	return int(c)
}
// bestExpressionFor searches every parenthesization of a product of factors
// with the given dimensions. It returns the lowest-cost expression, its cost,
// and whether the number of expressions searched equals the Catalan number
// for the number of factors, which checks that the brute-force enumeration
// was exhaustive. Among equally cheap expressions, the first one enumerated
// is returned. When no expression is enumerated, exp is nil and cost is the
// maximum int.
func bestExpressionFor(factors []dims) (exp *node, cost int, ok bool) {
	const maxInt = int(^uint(0) >> 1)
	cost = maxInt
	var seen int
	for e := range expressionsFor(factors) {
		seen++
		if c := e.cost(); c < cost {
			cost = c
			exp = e
		}
	}
	return exp, cost, seen == catalan(len(factors)-1)
}