blob: d8e0302758b9d3277fb8dd6e1355f19c63e061d1 [file] [log] [blame]
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mat
import "fmt"
// Product computes the product of the given factors and stores it in the
// receiver. For three or more factors the order in which the pairwise
// multiplications are performed is chosen to minimize the number of floating
// point operations, assuming every multiplication is a general one.
//
// With no factors the receiver must be empty, otherwise Product panics with
// ErrShape. A single factor is copied into the receiver, and two factors are
// multiplied directly with Mul.
func (m *Dense) Product(factors ...Matrix) {
	// The operation order optimisation is the naive O(n^3) dynamic
	// programming approach and does not take into consideration
	// finer-grained optimisations that might be available.
	//
	// TODO(kortschak) Consider using the O(nlogn) or O(mlogn)
	// algorithms that are available, e.g.
	//
	// http://www.jofcis.com/publishedpapers/2014_10_10_4299_4306.pdf
	//
	// In the case that this is replaced, retain this code in
	// tests to compare against.
	switch len(factors) {
	case 0:
		// An empty product can only be stored into an empty receiver.
		if rows, cols := m.Dims(); rows != 0 || cols != 0 {
			panic(ErrShape)
		}
	case 1:
		m.reuseAsNonZeroed(factors[0].Dims())
		m.Copy(factors[0])
	case 2:
		// Two factors admit only one ordering, so there is nothing to optimise.
		m.Mul(factors[0], factors[1])
	default:
		chain := newMultiplier(m, factors)
		chain.optimize()
		prod := chain.multiply()
		m.reuseAsNonZeroed(prod.Dims())
		m.Copy(prod)
		// prod came from the workspace pool; give it back now it has been copied.
		putWorkspace(prod)
	}
}
// debugProductWalk switches on tracing for Product. When true, the
// dimension chain and each multiplication step of the chosen evaluation
// order are printed with fmt.Printf. Change the constant and rebuild to
// enable it.
const debugProductWalk bool = false
// multiplier holds the state used to pick a multiplication order for a
// chain of matrices (optimize) and then to evaluate the chain in that
// order (multiply).
type multiplier struct {
	factors []Matrix // factors are the matrices to multiply, in chain order.
	dims    []int    // dims is the dimension chain; optimize treats factor i as dims[i]×dims[i+1].
	table   table    // table records per-subchain DP costs and split indices.
}
// newMultiplier checks that the factors form a conformable chain and, when m
// is not empty, that m has the shape of the chain product. It then returns a
// multiplier ready for optimize. It panics with ErrShape on any mismatch. No
// storage is allocated for m here.
//
// newMultiplier is only called with len(factors) > 2.
func newMultiplier(m *Dense, factors []Matrix) *multiplier {
	// Check size early, but don't yet
	// allocate data for m.
	r, c := m.Dims()
	fr, fc := factors[0].Dims()
	_, lc := factors[len(factors)-1].Dims()
	if !m.IsEmpty() {
		if fr != r {
			panic(ErrShape)
		}
		if lc != c {
			panic(ErrShape)
		}
	}

	// dims is the dimension chain: factor i is dims[i]×dims[i+1].
	// The outer dimensions must come from the factors themselves and not
	// from m. An empty receiver reports 0×0, and using that would make the
	// cost of every multiplication touching either end of the chain zero.
	// The ordering optimisation would then choose arbitrarily.
	dims := make([]int, len(factors)+1)
	dims[0] = fr
	dims[len(dims)-1] = lc
	pc := fc
	for i, f := range factors[1:] {
		cr, cc := f.Dims()
		if pc != cr {
			panic(ErrShape)
		}
		dims[i+1] = cr
		pc = cc
	}

	return &multiplier{
		factors: factors,
		dims:    dims,
		table:   newTable(len(factors)),
	}
}
// optimize fills p.table using the standard O(n^3) matrix-chain dynamic
// programme. For every contiguous subchain lo..hi it records two things: the
// lowest total cost of evaluating that subchain, and the split index achieving
// it. Joining lo..split with split+1..hi is charged
// dims[lo]*dims[split+1]*dims[hi+1]. When several splits tie, the earliest
// split is kept.
func (p *multiplier) optimize() {
	if debugProductWalk {
		fmt.Printf("chain dims: %v\n", p.dims)
	}
	const maxInt = int(^uint(0) >> 1)
	n := len(p.factors)
	// Process subchains in order of increasing length so that every
	// shorter subchain consulted below has already been solved.
	for span := 1; span < n; span++ {
		for lo := 0; lo+span < n; lo++ {
			hi := lo + span
			best := entry{cost: maxInt}
			for split := lo; split < hi; split++ {
				left := p.table.at(lo, split).cost
				right := p.table.at(split+1, hi).cost
				join := p.dims[lo] * p.dims[split+1] * p.dims[hi+1]
				if total := left + right + join; total < best.cost {
					best = entry{cost: total, k: split}
				}
			}
			p.table.set(lo, hi, best)
		}
	}
}
// multiply evaluates the whole chain in the order recorded by optimize and
// returns the product. The returned matrix is the workspace matrix allocated
// by multiplySubchain for the outermost multiplication. Callers may copy from
// it but should return it to the workspace pool with putWorkspace.
func (p *multiplier) multiply() *Dense {
	last := len(p.factors) - 1
	prod, _ := p.multiplySubchain(0, last)
	if debugProductWalk {
		rows, cols := prod.Dims()
		fmt.Printf("\tpop result (%d×%d) cost=%d\n", rows, cols, p.table.at(0, last).cost)
	}
	// With more than one factor the outermost subchain is always an
	// intermediate obtained from getWorkspace, never an original factor.
	return prod.(*Dense)
}
// multiplySubchain recursively evaluates the product of factors i through j
// inclusive, splitting at the index that optimize stored in p.table.
//
// It returns the product together with a flag. The flag is false when the
// product is one of the original factors (i == j). It is true when the product
// is an intermediate matrix taken from the workspace pool. Intermediate
// operands are returned to the pool as soon as they have been consumed.
func (p *multiplier) multiplySubchain(i, j int) (m Matrix, intermediate bool) {
	if i == j {
		return p.factors[i], false
	}

	split := p.table.at(i, j).k
	left, leftTmp := p.multiplySubchain(i, split)
	right, rightTmp := p.multiplySubchain(split+1, j)

	lr, lc := left.Dims()
	rr, rc := right.Dims()
	if lc != rr {
		// Panic with a string since this
		// is not a user-facing panic.
		panic(ErrShape.Error())
	}
	if debugProductWalk {
		fmt.Printf("\tpush f[%d] (%d×%d)%s * f[%d] (%d×%d)%s\n",
			i, lr, lc, result(leftTmp), j, rr, rc, result(rightTmp))
	}

	prod := getWorkspace(lr, rc, false)
	prod.Mul(left, right)

	// Release operands we own; original factors belong to the caller.
	if leftTmp {
		putWorkspace(left.(*Dense))
	}
	if rightTmp {
		putWorkspace(right.(*Dense))
	}
	return prod, true
}
// entry is one cell of the matrix-chain dynamic programming table,
// describing the best evaluation found for a subchain i..j.
type entry struct {
	// k is the chain subdivision index: factors i..k and k+1..j are
	// evaluated separately and their results multiplied together.
	k int
	// cost is the cost of evaluating the subchain with that split.
	cost int
}

// table is a row major n×n dynamic programming table. The cell for
// subchain i..j lives at entries[i*n+j].
type table struct {
	n       int
	entries []entry
}

// newTable returns an n×n table whose cells are all zero.
func newTable(n int) table {
	return table{
		n:       n,
		entries: make([]entry, n*n),
	}
}

// index maps cell (i, j) to its position in the row-major backing slice.
func (t table) index(i, j int) int { return i*t.n + j }

// at returns the cell for subchain i..j.
func (t table) at(i, j int) entry { return t.entries[t.index(i, j)] }

// set stores e in the cell for subchain i..j.
func (t table) set(i, j int, e entry) { t.entries[t.index(i, j)] = e }
// result tags an operand in multiplySubchain's debug trace. The value is true
// when the operand is an intermediate product rather than an original factor.
type result bool

// String implements fmt.Stringer. It returns " (popped result)" for
// intermediate operands and the empty string otherwise.
func (r result) String() string {
	if !r {
		return ""
	}
	return " (popped result)"
}