blob: 38a5d8a3442b81a85f1b449eeeb26baf1bf0d81f [file] [log] [blame]
 // Copyright ©2014 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package fd import ( "math" "testing" ) var xSquared = func(x float64) float64 { return x * x } type testPoint struct { f func(float64) float64 loc float64 fofx float64 ans float64 } var testsFirst = []testPoint{ { f: xSquared, loc: 0, fofx: 0, ans: 0, }, { f: xSquared, loc: 5, fofx: 25, ans: 10, }, { f: xSquared, loc: 2, fofx: 4, ans: 4, }, { f: xSquared, loc: -5, fofx: 25, ans: -10, }, } var testsSecond = []testPoint{ { f: xSquared, loc: 0, fofx: 0, ans: 2, }, { f: xSquared, loc: 5, fofx: 25, ans: 2, }, { f: xSquared, loc: 2, fofx: 4, ans: 2, }, { f: xSquared, loc: -5, fofx: 25, ans: 2, }, } func testDerivative(t *testing.T, formula Formula, tol float64, tests []testPoint) { for i, test := range tests { ans := Derivative(test.f, test.loc, &Settings{ Formula: formula, }) if math.Abs(test.ans-ans) > tol { t.Errorf("Case %v: ans mismatch serial: expected %v, found %v", i, test.ans, ans) } ans = Derivative(test.f, test.loc, &Settings{ Formula: formula, OriginKnown: true, OriginValue: test.fofx, }) if math.Abs(test.ans-ans) > tol { t.Errorf("Case %v: ans mismatch serial origin known: expected %v, found %v", i, test.ans, ans) } ans = Derivative(test.f, test.loc, &Settings{ Formula: formula, Concurrent: true, }) if math.Abs(test.ans-ans) > tol { t.Errorf("Case %v: ans mismatch concurrent: expected %v, found %v", i, test.ans, ans) } ans = Derivative(test.f, test.loc, &Settings{ Formula: formula, OriginKnown: true, OriginValue: test.fofx, Concurrent: true, }) if math.Abs(test.ans-ans) > tol { t.Errorf("Case %v: ans mismatch concurrent: expected %v, found %v", i, test.ans, ans) } } } func TestForward(t *testing.T) { t.Parallel() testDerivative(t, Forward, 2e-4, testsFirst) } func TestBackward(t *testing.T) { t.Parallel() testDerivative(t, Backward, 2e-4, testsFirst) } func TestCentral(t *testing.T) { t.Parallel() testDerivative(t, Central, 1e-6, testsFirst) } func TestCentralSecond(t *testing.T) { t.Parallel() testDerivative(t, Central2nd, 1e-3, testsSecond) } // TestDerivativeDefault checks that the derivative works when settings is nil // or zero value. func TestDerivativeDefault(t *testing.T) { t.Parallel() tol := 1e-6 for i, test := range testsFirst { ans := Derivative(test.f, test.loc, nil) if math.Abs(test.ans-ans) > tol { t.Errorf("Case %v: ans mismatch default: expected %v, found %v", i, test.ans, ans) } ans = Derivative(test.f, test.loc, &Settings{}) if math.Abs(test.ans-ans) > tol { t.Errorf("Case %v: ans mismatch zero value: expected %v, found %v", i, test.ans, ans) } } }