mat: add PivotedCholesky
diff --git a/mat/cholesky.go b/mat/cholesky.go
index 0f957cd..56b35de 100644
--- a/mat/cholesky.go
+++ b/mat/cholesky.go
@@ -25,6 +25,9 @@
_ Symmetric = (*BandCholesky)(nil)
_ Banded = (*BandCholesky)(nil)
_ SymBanded = (*BandCholesky)(nil)
+
+ _ Matrix = (*PivotedCholesky)(nil)
+ _ Symmetric = (*PivotedCholesky)(nil)
)
// Cholesky is a symmetric positive definite matrix represented by its
@@ -936,3 +939,204 @@
func (ch *BandCholesky) valid() bool {
return ch.chol != nil && !ch.chol.IsEmpty()
}
+
+// PivotedCholesky is a symmetric positive semi-definite matrix represented by
+// its Cholesky factorization with complete pivoting.
+//
+// The factorization has the form
+//
+//	A = P * Uᵀ * U * Pᵀ
+//
+// where U is an upper triangular matrix and P is a permutation matrix.
+//
+// PivotedCholesky methods may only be called on a receiver that has been
+// initialized by a call to Factorize. SolveTo, SolveVecTo and Cond may only be
+// called if Factorize has returned true.
+//
+// If the matrix A is certainly positive definite, then the unpivoted Cholesky
+// could be more efficient, especially for smaller matrices.
+type PivotedCholesky struct {
+	chol          *TriDense // The factor U
+	piv, pivTrans []int     // The permutation P and its inverse Pᵀ
+	rank          int       // The computed rank of A
+
+	ok   bool    // Whether the factorization can be used for solving linear systems
+	cond float64 // The condition number when ok is true
+}
+
+// Factorize computes the pivoted Cholesky factorization of the symmetric
+// positive semi-definite matrix A and returns whether A is positive definite.
+// If Factorize returns false, SolveTo, SolveVecTo and Cond must not be called.
+//
+// tol is the tolerance used to determine the computed rank of A. If it is
+// negative, a default value is used.
+func (c *PivotedCholesky) Factorize(a Symmetric, tol float64) (ok bool) {
+	n := a.SymmetricDim()
+	c.reset(n)
+	copySymIntoTriangle(c.chol, a)
+	// Scratch space shared by Lansy, Pstrf and Pocon below.
+	work := getFloat64s(3*n, false)
+	defer putFloat64s(work)
+	sym := c.chol.asSymBlas()
+	// The norm of A is needed by Pocon and must be taken before Pstrf overwrites sym.
+	aNorm := lapack64.Lansy(CondNorm, sym, work)
+	_, c.rank, c.ok = lapack64.Pstrf(sym, c.piv, tol, work)
+	for i, p := range c.piv {
+		c.pivTrans[p] = i
+	}
+	if !c.ok {
+		return false
+	}
+	iwork := getInts(n, false)
+	defer putInts(iwork)
+	c.cond = 1 / lapack64.Pocon(sym, aNorm, work, iwork)
+	return true
+}
+
+// reset prepares the receiver for the factorization of an n×n matrix. An
+// existing factor is Reset and reused instead of being replaced, and the rank,
+// status and condition number of any previous factorization are cleared.
+func (c *PivotedCholesky) reset(n int) {
+	if c.chol != nil {
+		c.chol.Reset()
+		c.chol.reuseAsNonZeroed(n, Upper)
+	} else {
+		c.chol = NewTriDense(n, Upper, nil)
+	}
+	c.piv, c.pivTrans = useInt(c.piv, n), useInt(c.pivTrans, n)
+	c.rank, c.ok = 0, false
+	c.cond = math.Inf(1)
+}
+
+// Dims returns the dimensions of the matrix A. It panics if the receiver has
+// not been initialized by Factorize.
+func (c *PivotedCholesky) Dims() (rows, cols int) {
+	if c.chol == nil {
+		panic(badCholesky)
+	}
+	return c.chol.Dims()
+}
+
+// At returns the element of A at row i, column j. It panics if the receiver
+// has not been initialized by Factorize or if i or j is out of range.
+func (c *PivotedCholesky) At(i, j int) float64 {
+	n := c.SymmetricDim()
+	switch {
+	case uint(i) >= uint(n):
+		panic(ErrRowAccess)
+	case uint(j) >= uint(n):
+		panic(ErrColAccess)
+	}
+
+	// A = P * Uᵀ * U * Pᵀ, so A[i,j] = (Uᵀ*U)[pivTrans[i], pivTrans[j]], the dot
+	// product of two columns of U over rows 0..min(pi, pj), further limited to
+	// the first rank rows.
+	pi, pj := c.pivTrans[i], c.pivTrans[j]
+	end := min(min(pi, pj)+1, c.rank)
+	var sum float64
+	for k := 0; k < end; k++ {
+		sum += c.chol.at(k, pi) * c.chol.at(k, pj)
+	}
+	return sum
+}
+
+// T returns the transpose of A. A is symmetric, so this is the receiver itself.
+func (c *PivotedCholesky) T() Matrix {
+	return c
+}
+
+// SymmetricDim implements the Symmetric interface and returns the number of
+// rows (or columns) in the matrix. It panics if Factorize has not been called.
+func (c *PivotedCholesky) SymmetricDim() int {
+	if c.chol == nil {
+		panic(badCholesky)
+	}
+	n, _ := c.chol.Dims()
+	return n
+}
+
+// Rank returns the computed rank of A. It panics if Factorize was never called.
+func (c *PivotedCholesky) Rank() int {
+	if c.chol != nil {
+		return c.rank
+	}
+	panic(badCholesky)
+}
+
+// Cond returns the condition number of A. It panics unless Factorize returned true.
+func (c *PivotedCholesky) Cond() float64 {
+	if c.ok {
+		return c.cond
+	}
+	panic(badCholesky)
+}
+
+// SolveTo finds the matrix X that solves A * X = B, where A is represented by
+// its pivoted Cholesky factorization, and stores X in-place into dst; dst may
+// be b itself. If A is singular or near-singular, X is still stored and a
+// Condition error is returned. See the documentation for Condition for more
+// information.
+//
+// SolveTo panics if Factorize returned false or if B does not have as many
+// rows as A.
+func (c *PivotedCholesky) SolveTo(dst *Dense, b Matrix) error {
+	if !c.ok {
+		panic(badCholesky)
+	}
+	rows, cols := b.Dims()
+	if rows != c.chol.mat.N {
+		panic(ErrShape)
+	}
+
+	dst.reuseAsNonZeroed(rows, cols)
+	if dst != b {
+		dst.Copy(b)
+	}
+
+	// A = P * Uᵀ * U * Pᵀ, so X is obtained in three in-place steps on dst.
+	x := dst.mat
+	lapack64.Lapmr(true, x, c.piv)  // D = Pᵀ * B
+	lapack64.Potrs(c.chol.mat, x)   // Solve Uᵀ * U * Y = D.
+	lapack64.Lapmr(false, x, c.piv) // X = P * Y
+
+	if c.cond > ConditionTolerance {
+		return Condition(c.cond)
+	}
+	return nil
+}
+
+// SolveVecTo finds the vector x that solves A * x = b, where A is represented
+// by its pivoted Cholesky factorization, and stores x in-place into dst; dst
+// may be b itself. If A is singular or near-singular, x is still stored and a
+// Condition error is returned. See the documentation for Condition for more
+// information.
+//
+// SolveVecTo panics if Factorize returned false or if b is not a column vector
+// with as many rows as A.
+func (c *PivotedCholesky) SolveVecTo(dst *VecDense, b Vector) error {
+	if !c.ok {
+		panic(badCholesky)
+	}
+	n := c.chol.mat.N
+	if rows, cols := b.Dims(); rows != n || cols != 1 {
+		panic(ErrShape)
+	}
+	if rv, isRaw := b.(RawVectorer); isRaw && dst != rv {
+		dst.checkOverlap(rv.RawVector())
+	}
+
+	dst.reuseAsNonZeroed(n)
+	if dst != b {
+		dst.CopyVec(b)
+	}
+	// A = P * Uᵀ * U * Pᵀ; solve in place on dst viewed as an n×1 matrix.
+	x := dst.asGeneral()
+	lapack64.Lapmr(true, x, c.piv)  // d = Pᵀ * b
+	lapack64.Potrs(c.chol.mat, x)   // Solve Uᵀ * U * y = d.
+	lapack64.Lapmr(false, x, c.piv) // x = P * y
+
+	if c.cond > ConditionTolerance {
+		return Condition(c.cond)
+	}
+	return nil
+}
diff --git a/mat/cholesky_test.go b/mat/cholesky_test.go
index de10a4b..559733b 100644
--- a/mat/cholesky_test.go
+++ b/mat/cholesky_test.go
@@ -924,3 +924,187 @@
}
}
}
+
+func TestPivotedCholesky(t *testing.T) {
+	t.Parallel()
+
+	const tol = 1e-14
+	src := rand.NewSource(1)
+	for _, n := range []int{1, 2, 3, 4, 5, 10} {
+		for _, rank := range []int{int(0.3 * float64(n)), int(0.7 * float64(n)), n} {
+			name := fmt.Sprintf("n=%d, rank=%d", n, rank)
+
+			// Generate a random symmetric semi-definite matrix A with the given rank.
+			a := NewSymDense(n, nil)
+			for i := 0; i < rank; i++ {
+				x := randVecDense(n, 1, 1, src)
+				a.SymRankOne(a, 1, x)
+			}
+
+			// Compute the pivoted Cholesky factorization of A.
+			var chol PivotedCholesky
+			ok := chol.Factorize(a, -1)
+
+			// Check that the ok return matches the rank of A.
+			if !ok && rank == n {
+				t.Errorf("%s: unexpected factorization failure with full rank", name)
+			}
+			if ok && rank != n {
+				t.Errorf("%s: unexpected factorization success with deficient rank", name)
+			}
+
+			// Check that the computed rank matches the rank of A.
+			if chol.Rank() != rank {
+				t.Errorf("%s: unexpected computed rank: got %d, want %d", name, chol.Rank(), rank)
+			}
+
+			// Check the size.
+			r, c := chol.Dims()
+			if r != n || c != n {
+				t.Errorf("%s: unexpected dims: r=%d, c=%d", name, r, c)
+			}
+			if chol.SymmetricDim() != n {
+				t.Errorf("%s: unexpected symmetric dim: dim=%d", name, chol.SymmetricDim())
+			}
+
+			// Compute the 1-norm of the difference |P*Uᵀ*U*Pᵀ - A|.
+			diff := NewDense(n, n, nil)
+			for i := 0; i < n; i++ {
+				for j := 0; j < n; j++ {
+					diff.Set(i, j, chol.At(i, j)-a.At(i, j))
+				}
+			}
+			res := Norm(diff, 1)
+			if res > tol {
+				t.Errorf("%s: unexpected result (|diff|=%v)\ndiff = %.4g", name, res, Formatted(diff, Prefix("       ")))
+			}
+		}
+	}
+}
+
+// TestPivotedCholeskySolveTo checks that SolveTo recovers a known solution X
+// from B = A * X for well-conditioned positive definite matrices A, both into
+// a separate destination and in place (dst == b).
+func TestPivotedCholeskySolveTo(t *testing.T) {
+	t.Parallel()
+
+	const (
+		nrhs = 4
+		tol  = 1e-14
+	)
+	rnd := rand.New(rand.NewSource(1))
+	for _, n := range []int{1, 2, 3, 5, 10} {
+		// Off-diagonal entries lie in [0, 1) and each diagonal entry is at least
+		// n, so A is strictly diagonally dominant and hence positive definite.
+		a := NewSymDense(n, nil)
+		for i := 0; i < n; i++ {
+			a.SetSym(i, i, rnd.Float64()+float64(n))
+			for j := i + 1; j < n; j++ {
+				a.SetSym(i, j, rnd.Float64())
+			}
+		}
+
+		// Known solution X and the corresponding right-hand side B = A * X.
+		want := NewDense(n, nrhs, nil)
+		for i := 0; i < n; i++ {
+			for j := 0; j < nrhs; j++ {
+				want.Set(i, j, rnd.NormFloat64())
+			}
+		}
+		var b Dense
+		b.Mul(a, want)
+
+		for _, typ := range []Symmetric{a, asBasicSymmetric(a)} {
+			name := fmt.Sprintf("Case n=%d,type=%T,nrhs=%d", n, typ, nrhs)
+
+			var chol PivotedCholesky
+			if !chol.Factorize(typ, -1) {
+				t.Fatalf("%v: matrix not positive definite", name)
+			}
+
+			// Solve into a separate destination.
+			var got Dense
+			if err := chol.SolveTo(&got, &b); err != nil {
+				t.Errorf("%v: unexpected error from SolveTo: %v", name, err)
+				continue
+			}
+			var resid Dense
+			resid.Sub(want, &got)
+			if diff := Norm(&resid, math.Inf(1)); diff > tol {
+				t.Errorf("%v: unexpected solution; diff=%v", name, diff)
+			}
+
+			// Solve in place, with dst and b being the same matrix.
+			got.Copy(&b)
+			if err := chol.SolveTo(&got, &got); err != nil {
+				t.Errorf("%v: unexpected error from SolveTo when dst==b: %v", name, err)
+				continue
+			}
+			resid.Sub(want, &got)
+			if diff := Norm(&resid, math.Inf(1)); diff > tol {
+				t.Errorf("%v: unexpected solution when dst==b; diff=%v", name, diff)
+			}
+		}
+	}
+}
+
+// TestPivotedCholeskySolveVecTo checks that SolveVecTo recovers a known
+// solution x from b = A * x for well-conditioned positive definite matrices A,
+// both into a separate destination and in place (dst == b).
+func TestPivotedCholeskySolveVecTo(t *testing.T) {
+	t.Parallel()
+
+	const tol = 1e-14
+	rnd := rand.New(rand.NewSource(1))
+	for _, n := range []int{1, 2, 3, 5, 10} {
+		// Off-diagonal entries lie in [0, 1) and each diagonal entry is at least
+		// n, so A is strictly diagonally dominant and hence positive definite.
+		a := NewSymDense(n, nil)
+		for i := 0; i < n; i++ {
+			a.SetSym(i, i, rnd.Float64()+float64(n))
+			for j := i + 1; j < n; j++ {
+				a.SetSym(i, j, rnd.Float64())
+			}
+		}
+
+		// Known solution x and the corresponding right-hand side b = A * x.
+		want := NewVecDense(n, nil)
+		for i := 0; i < n; i++ {
+			want.SetVec(i, rnd.NormFloat64())
+		}
+		var b VecDense
+		b.MulVec(a, want)
+
+		for _, typ := range []Symmetric{a, asBasicSymmetric(a)} {
+			name := fmt.Sprintf("Case n=%d,type=%T", n, typ)
+
+			var chol PivotedCholesky
+			if !chol.Factorize(typ, -1) {
+				t.Fatalf("%v: matrix not positive definite", name)
+			}
+
+			// Solve into a separate destination.
+			var got VecDense
+			if err := chol.SolveVecTo(&got, &b); err != nil {
+				t.Errorf("%v: unexpected error from SolveVecTo: %v", name, err)
+				continue
+			}
+			var resid VecDense
+			resid.SubVec(want, &got)
+			if diff := Norm(&resid, math.Inf(1)); diff > tol {
+				t.Errorf("%v: unexpected solution; diff=%v", name, diff)
+			}
+
+			// Solve in place, with dst and b being the same vector.
+			got.CopyVec(&b)
+			if err := chol.SolveVecTo(&got, &got); err != nil {
+				t.Errorf("%v: unexpected error from SolveVecTo when dst==b: %v", name, err)
+				continue
+			}
+			resid.SubVec(want, &got)
+			if diff := Norm(&resid, math.Inf(1)); diff > tol {
+				t.Errorf("%v: unexpected solution when dst==b; diff=%v", name, diff)
+			}
+		}
+	}
+}