| // Copyright ©2015 The Gonum Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package mat |
| |
| import "fmt" |
| |
| // Product calculates the product of the given factors and places the result in |
| // the receiver. The order of multiplication operations is optimized to minimize |
| // the number of floating point operations on the basis that all matrix |
| // multiplications are general. |
| func (m *Dense) Product(factors ...Matrix) { |
| // The operation order optimisation is the naive O(n^3) dynamic |
| // programming approach and does not take into consideration |
| // finer-grained optimisations that might be available. |
| // |
| // TODO(kortschak) Consider using the O(nlogn) or O(mlogn) |
| // algorithms that are available. e.g. |
| // |
| // e.g. http://www.jofcis.com/publishedpapers/2014_10_10_4299_4306.pdf |
| // |
| // In the case that this is replaced, retain this code in |
| // tests to compare against. |
| |
| r, c := m.Dims() |
| switch len(factors) { |
| case 0: |
| if r != 0 || c != 0 { |
| panic(ErrShape) |
| } |
| return |
| case 1: |
| m.reuseAsNonZeroed(factors[0].Dims()) |
| m.Copy(factors[0]) |
| return |
| case 2: |
| // Don't do work that we know the answer to. |
| m.Mul(factors[0], factors[1]) |
| return |
| } |
| |
| p := newMultiplier(m, factors) |
| p.optimize() |
| result := p.multiply() |
| m.reuseAsNonZeroed(result.Dims()) |
| m.Copy(result) |
| putWorkspace(result) |
| } |
| |
const (
	// debugProductWalk, when true, makes Product print the chain
	// dimensions and each multiplication step as it is performed.
	debugProductWalk = false
)
| |
| // multiplier performs operation order optimisation and tree traversal. |
| type multiplier struct { |
| // factors is the ordered set of |
| // factors to multiply. |
| factors []Matrix |
| // dims is the chain of factor |
| // dimensions. |
| dims []int |
| |
| // table contains the dynamic |
| // programming costs and subchain |
| // division indices. |
| table table |
| } |
| |
| func newMultiplier(m *Dense, factors []Matrix) *multiplier { |
| // Check size early, but don't yet |
| // allocate data for m. |
| r, c := m.Dims() |
| fr, fc := factors[0].Dims() // newMultiplier is only called with len(factors) > 2. |
| if !m.IsEmpty() { |
| if fr != r { |
| panic(ErrShape) |
| } |
| if _, lc := factors[len(factors)-1].Dims(); lc != c { |
| panic(ErrShape) |
| } |
| } |
| |
| dims := make([]int, len(factors)+1) |
| dims[0] = r |
| dims[len(dims)-1] = c |
| pc := fc |
| for i, f := range factors[1:] { |
| cr, cc := f.Dims() |
| dims[i+1] = cr |
| if pc != cr { |
| panic(ErrShape) |
| } |
| pc = cc |
| } |
| |
| return &multiplier{ |
| factors: factors, |
| dims: dims, |
| table: newTable(len(factors)), |
| } |
| } |
| |
| // optimize determines an optimal matrix multiply operation order. |
| func (p *multiplier) optimize() { |
| if debugProductWalk { |
| fmt.Printf("chain dims: %v\n", p.dims) |
| } |
| const maxInt = int(^uint(0) >> 1) |
| for f := 1; f < len(p.factors); f++ { |
| for i := 0; i < len(p.factors)-f; i++ { |
| j := i + f |
| p.table.set(i, j, entry{cost: maxInt}) |
| for k := i; k < j; k++ { |
| cost := p.table.at(i, k).cost + p.table.at(k+1, j).cost + p.dims[i]*p.dims[k+1]*p.dims[j+1] |
| if cost < p.table.at(i, j).cost { |
| p.table.set(i, j, entry{cost: cost, k: k}) |
| } |
| } |
| } |
| } |
| } |
| |
| // multiply walks the optimal operation tree found by optimize, |
| // leaving the final result in the stack. It returns the |
| // product, which may be copied but should be returned to |
| // the workspace pool. |
| func (p *multiplier) multiply() *Dense { |
| result, _ := p.multiplySubchain(0, len(p.factors)-1) |
| if debugProductWalk { |
| r, c := result.Dims() |
| fmt.Printf("\tpop result (%d×%d) cost=%d\n", r, c, p.table.at(0, len(p.factors)-1).cost) |
| } |
| return result.(*Dense) |
| } |
| |
| func (p *multiplier) multiplySubchain(i, j int) (m Matrix, intermediate bool) { |
| if i == j { |
| return p.factors[i], false |
| } |
| |
| a, aTmp := p.multiplySubchain(i, p.table.at(i, j).k) |
| b, bTmp := p.multiplySubchain(p.table.at(i, j).k+1, j) |
| |
| ar, ac := a.Dims() |
| br, bc := b.Dims() |
| if ac != br { |
| // Panic with a string since this |
| // is not a user-facing panic. |
| panic(ErrShape.Error()) |
| } |
| |
| if debugProductWalk { |
| fmt.Printf("\tpush f[%d] (%d×%d)%s * f[%d] (%d×%d)%s\n", |
| i, ar, ac, result(aTmp), j, br, bc, result(bTmp)) |
| } |
| |
| r := getWorkspace(ar, bc, false) |
| r.Mul(a, b) |
| if aTmp { |
| putWorkspace(a.(*Dense)) |
| } |
| if bTmp { |
| putWorkspace(b.(*Dense)) |
| } |
| return r, true |
| } |
| |
// entry is a single cell of the dynamic programming table.
type entry struct {
	k    int // k is the chain subdivision index.
	cost int // cost is the cost of the operation.
}

// table is a row major n×n dynamic programming table.
type table struct {
	n       int     // n is the number of rows and of columns.
	entries []entry // entries holds the n*n cells, row after row.
}

// newTable returns an n×n table whose cells are all the zero entry.
func newTable(n int) table {
	cells := make([]entry, n*n)
	return table{n: n, entries: cells}
}

// at returns the cell in row i, column j.
func (t table) at(i, j int) entry {
	return t.entries[t.n*i+j]
}

// set stores e in the cell in row i, column j.
func (t table) set(i, j int, e entry) {
	t.entries[t.n*i+j] = e
}
| |
// result annotates an operand in the debugging output of Product: a true
// value marks the operand as an intermediate product.
type result bool

// String returns " (popped result)" when r is true and the empty
// string otherwise.
func (r result) String() string {
	if !r {
		return ""
	}
	return " (popped result)"
}