blob: a2aacd288cb1b146dc9496ae7adcd8fcf54b2814 [file] [log] [blame]
// Copyright ©2020 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fourier
import (
"bytes"
"fmt"
"math/bits"
"strconv"
"testing"
"unsafe"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/cmplxs"
)
func TestCoefficients(t *testing.T) {
const tol = 1e-8
src := rand.NewSource(1)
for n := 4; n < 1<<20; n <<= 1 {
for i := 0; i < 10; i++ {
t.Run(fmt.Sprintf("Radix2/%d", n), func(t *testing.T) {
d := randComplexes(n, src)
fft := NewCmplxFFT(n)
want := fft.Coefficients(nil, d)
CoefficientsRadix2(d)
got := d
if !cmplxs.EqualApprox(got, want, tol) {
t.Errorf("unexpected result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
}
want = fft.Sequence(nil, got)
scale(1/float64(n), want)
SequenceRadix2(got)
scale(1/float64(n), got)
if !cmplxs.EqualApprox(got, want, tol) {
t.Errorf("unexpected ifft result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
}
})
if bits.Len(uint(n))&0x1 == 0 {
continue
}
t.Run(fmt.Sprintf("Radix4/%d", n), func(t *testing.T) {
d := randComplexes(n, src)
fft := NewCmplxFFT(n)
want := fft.Coefficients(nil, d)
CoefficientsRadix4(d)
got := d
if !cmplxs.EqualApprox(got, want, tol) {
t.Errorf("unexpected fft result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
}
want = fft.Sequence(nil, got)
scale(1/float64(n), want)
SequenceRadix4(got)
scale(1/float64(n), got)
if !cmplxs.EqualApprox(got, want, tol) {
t.Errorf("unexpected ifft result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
}
})
}
}
}
func TestSequence(t *testing.T) {
const tol = 1e-10
src := rand.NewSource(1)
for n := 4; n < 1<<20; n <<= 1 {
for i := 0; i < 10; i++ {
t.Run(fmt.Sprintf("Radix2/%d", n), func(t *testing.T) {
d := randComplexes(n, src)
want := make([]complex128, n)
copy(want, d)
SequenceRadix2(CoefficientsRadix2(d))
got := d
scale(1/float64(n), got)
if !cmplxs.EqualApprox(got, want, tol) {
t.Errorf("unexpected result for ifft(fft(d)) n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
}
})
if bits.Len(uint(n))&0x1 == 0 {
continue
}
t.Run(fmt.Sprintf("Radix4/%d", n), func(t *testing.T) {
d := randComplexes(n, src)
want := make([]complex128, n)
copy(want, d)
SequenceRadix4(CoefficientsRadix4(d))
got := d
scale(1/float64(n), got)
if !cmplxs.EqualApprox(got, want, tol) {
t.Errorf("unexpected result for ifft(fft(d)) n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
}
})
}
}
}
func scale(f float64, c []complex128) {
for i, v := range c {
c[i] = complex(f*real(v), f*imag(v))
}
}
func TestBitReversePermute(t *testing.T) {
for n := 2; n <= 1024; n <<= 1 {
x := make([]complex128, n)
for i := range x {
x[i] = complex(float64(i), float64(i))
}
bitReversePermute(x)
for i, got := range x {
j := bits.Reverse(uint(i)) >> bits.LeadingZeros(uint(n-1))
want := complex(float64(j), float64(j))
if got != want {
t.Errorf("unexpected value at %d: got:%f want:%f", i, got, want)
}
}
}
}
func TestPadRadix2(t *testing.T) {
for n := 1; n <= 1025; n++ {
x := make([]complex128, n)
y := PadRadix2(x)
if bits.OnesCount(uint(len(y))) != 1 {
t.Errorf("unexpected length of padded slice: not a power of 2: %d", len(y))
}
if len(x) == len(y) && &y[0] != &x[0] {
t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n)
}
if len(y) < len(x) {
t.Errorf("unexpected short result: len(y)=%d < len(x)=%d", len(y), len(x))
}
}
}
func TestTrimRadix2(t *testing.T) {
for n := 1; n <= 1025; n++ {
x := make([]complex128, n)
y, r := TrimRadix2(x)
if bits.OnesCount(uint(len(y))) != 1 {
t.Errorf("unexpected length of trimmed slice: not a power of 2: %d", len(y))
}
if len(y)+len(r) != len(x) {
t.Errorf("unexpected total result: len(y)=%d + len(r)%d != len(x)=%d", len(y), len(r), len(x))
}
if len(x) == len(y) && &y[0] != &x[0] {
t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n)
}
if len(y) > len(x) {
t.Errorf("unexpected long result: len(y)=%d > len(x)=%d", len(y), len(x))
}
}
}
func TestBitPairReversePermute(t *testing.T) {
for n := 4; n <= 1024; n <<= 2 {
x := make([]complex128, n)
for i := range x {
x[i] = complex(float64(i), float64(i))
}
bitPairReversePermute(x)
for i, got := range x {
j := reversePairs(uint(i)) >> bits.LeadingZeros(uint(n-1))
want := complex(float64(j), float64(j))
if got != want {
t.Errorf("unexpected value at %d: got:%f want:%f", i, got, want)
}
}
}
}
func TestReversePairs(t *testing.T) {
rnd := rand.New(rand.NewSource(1))
for i := 0; i < 1000; i++ {
x := uint(rnd.Uint64())
got := reversePairs(x)
want := naiveReversePairs(x)
if got != want {
t.Errorf("unexpected bit-pair reversal for 0b%064b:\ngot: 0b%064b\nwant:0b%064b", x, got, want)
}
}
}
// naiveReversePairs does a bit-pair reversal by formatting as a base-4 string,
// reversing the digits of the formatted number and then re-parsing the value.
func naiveReversePairs(x uint) uint {
bits := int(unsafe.Sizeof(uint(0)) * 8)
// Format the number as a quaternary, padded with zeros.
// We avoid the leftpad issue by doing it ourselves.
b := strconv.AppendUint(bytes.Repeat([]byte("0"), bits/2), uint64(x), 4)
b = b[len(b)-bits/2:]
// Reverse the quits.
for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 {
b[i], b[j] = b[j], b[i]
}
// Re-parse the reversed number.
y, err := strconv.ParseUint(string(b), 4, 64)
if err != nil {
panic(fmt.Sprintf("unexpected parse error: %v", err))
}
return uint(y)
}
func TestPadRadix4(t *testing.T) {
for n := 1; n <= 1025; n++ {
x := make([]complex128, n)
y := PadRadix4(x)
if bits.OnesCount(uint(len(y))) != 1 || bits.Len(uint(len(y)))&0x1 == 0 {
t.Errorf("unexpected length of padded slice: not a power of 4: %d", len(y))
}
if len(x) == len(y) && &y[0] != &x[0] {
t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n)
}
if len(y) < len(x) {
t.Errorf("unexpected short result: len(y)=%d < len(x)=%d", len(y), len(x))
}
}
}
func TestTrimRadix4(t *testing.T) {
for n := 1; n <= 1025; n++ {
x := make([]complex128, n)
y, r := TrimRadix4(x)
if bits.OnesCount(uint(len(y))) != 1 || bits.Len(uint(len(y)))&0x1 == 0 {
t.Errorf("unexpected length of trimmed slice: not a power of 4: %d", len(y))
}
if len(y)+len(r) != len(x) {
t.Errorf("unexpected total result: len(y)=%d + len(r)%d != len(x)=%d", len(y), len(r), len(x))
}
if len(x) == len(y) && &y[0] != &x[0] {
t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n)
}
if len(y) > len(x) {
t.Errorf("unexpected long result: len(y)=%d > len(x)=%d", len(y), len(x))
}
}
}
func BenchmarkCoefficients(b *testing.B) {
for n := 16; n < 1<<24; n <<= 3 {
d := randComplexes(n, rand.NewSource(1))
b.Run(fmt.Sprintf("Radix2/%d", n), func(b *testing.B) {
for i := 0; i < b.N; i++ {
CoefficientsRadix2(d)
}
})
if bits.Len(uint(n))&0x1 == 0 {
continue
}
b.Run(fmt.Sprintf("Radix4/%d", n), func(b *testing.B) {
for i := 0; i < b.N; i++ {
CoefficientsRadix4(d)
}
})
}
}
func BenchmarkSequence(b *testing.B) {
for n := 16; n < 1<<24; n <<= 3 {
d := randComplexes(n, rand.NewSource(1))
b.Run(fmt.Sprintf("Radix2/%d", n), func(b *testing.B) {
for i := 0; i < b.N; i++ {
SequenceRadix2(d)
}
})
if bits.Len(uint(n))&0x1 == 0 {
continue
}
b.Run(fmt.Sprintf("Radix4/%d", n), func(b *testing.B) {
for i := 0; i < b.N; i++ {
SequenceRadix4(d)
}
})
}
}