#!/usr/bin/env python
"""Tests for ``dct``, ``idct``, ``dst`` and ``idst`` from scipy.fftpack.

The transforms are compared against stored reference results that are loaded
at import time from ``.npz`` files located next to this module:

* ``test.npz`` -- Matlab reference data (arrays ``x0``..``x7``, ``y0``..``y7``)
* ``fftw_double_ref.npz`` / ``fftw_single_ref.npz`` -- FFTW reference data

A final test class checks the ``overwrite_x`` behavior of the four routines.
"""
from os.path import join, dirname

import numpy as np
from numpy.testing import assert_array_almost_equal, assert_equal, TestCase

from scipy.fftpack.realtransforms import dct, idct, dst, idst

# Matlab reference data
MDATA = np.load(join(dirname(__file__), 'test.npz'))
X = [MDATA['x%d' % i] for i in range(8)]
Y = [MDATA['y%d' % i] for i in range(8)]

# FFTW reference data: the data are organized as follows:
#    * SIZES is an array containing all available sizes
#    * for every type (1, 2, 3, 4) and every size, the array dct_type_size
#      (resp. dst_type_size) contains the output of the DCT (resp. DST)
#      applied to the input np.linspace(0, size-1, size)
FFTWDATA_DOUBLE = np.load(join(dirname(__file__), 'fftw_double_ref.npz'))
FFTWDATA_SINGLE = np.load(join(dirname(__file__), 'fftw_single_ref.npz'))
FFTWDATA_SIZES = FFTWDATA_DOUBLE['sizes']


def _fftw_ref(prefix, type, size, dt):
    """Return the FFTW reference pair ``(x, y)`` for one transform.

    Parameters
    ----------
    prefix : str
        Name prefix of the stored arrays, ``'dct'`` or ``'dst'``.
    type : int
        Transform type, used to build the array name.
    size : int
        Length of the input signal.
    dt : dtype
        ``np.double`` or ``np.float32``; selects the reference file and is
        the dtype both returned arrays are cast to.

    Returns
    -------
    x : ndarray
        The input ``np.linspace(0, size-1, size)`` cast to `dt`.
    y : ndarray
        The stored array ``<prefix>_<type>_<size>`` cast to `dt`.

    Raises
    ------
    ValueError
        If `dt` is neither ``np.double`` nor ``np.float32``.
    """
    x = np.linspace(0, size-1, size).astype(dt)
    if dt == np.double:
        data = FFTWDATA_DOUBLE
    elif dt == np.float32:
        data = FFTWDATA_SINGLE
    else:
        raise ValueError("no FFTW reference data for dtype %r "
                         "(expected np.double or np.float32)" % (dt,))
    y = (data['%s_%d_%d' % (prefix, type, size)]).astype(dt)
    return x, y


def fftw_dct_ref(type, size, dt):
    """Return ``(x, y)``: the DCT reference input and its FFTW output.

    See `_fftw_ref` for the meaning of the parameters and the error raised
    for an unsupported `dt`.
    """
    return _fftw_ref('dct', type, size, dt)


def fftw_dst_ref(type, size, dt):
    """Return ``(x, y)``: the DST reference input and its FFTW output.

    See `_fftw_ref` for the meaning of the parameters and the error raised
    for an unsupported `dt`.
    """
    return _fftw_ref('dst', type, size, dt)


class _TestDCTBase(TestCase):
    """Checks shared by all DCT test cases.

    Concrete subclasses override ``setUp`` to choose the dtype (``rdt``),
    the number of decimals to match (``dec``) and the DCT ``type``.
    """

    def setUp(self):
        self.rdt = None     # dtype
        self.dec = 14       # number of decimals to match
        self.type = None    # dct type

    def test_definition(self):
        # Compare dct() with the FFTW reference output for every stored size.
        for i in FFTWDATA_SIZES:
            x, yr = fftw_dct_ref(self.type, i, self.rdt)
            y = dct(x, type=self.type)
            self.assertTrue(y.dtype == self.rdt,
                    "Output dtype is %s, expected %s" % (y.dtype, self.rdt))
            # XXX: we divide by np.max(y) because the tests fail otherwise. We
            # should really use something like assert_array_approx_equal. The
            # difference is due to fftw using a better algorithm w.r.t error
            # propagation compared to the ones from fftpack.
            scale = np.max(y)
            assert_array_almost_equal(y / scale, yr / scale,
                    decimal=self.dec, err_msg="Size %d failed" % i)

    def test_axis(self):
        # Transforming a 2-D array along an axis must match transforming
        # each 1-D slice along that axis on its own.
        nt = 2
        for i in [7, 8, 9, 16, 32, 64]:
            x = np.random.randn(nt, i)
            y = dct(x, type=self.type)
            for j in range(nt):
                assert_array_almost_equal(y[j], dct(x[j], type=self.type),
                        decimal=self.dec)

            x = x.T
            y = dct(x, axis=0, type=self.type)
            for j in range(nt):
                assert_array_almost_equal(y[:, j],
                        dct(x[:, j], type=self.type),
                        decimal=self.dec)


class _TestDCTIIBase(_TestDCTBase):
    """DCT checks plus a comparison with the Matlab reference data."""

    def test_definition_matlab(self):
        """Test correspondence with matlab (orthonormal mode)."""
        for x_ref, yr in zip(X, Y):
            x = np.array(x_ref, dtype=self.rdt)
            y = dct(x, norm="ortho", type=2)
            self.assertTrue(y.dtype == self.rdt,
                    "Output dtype is %s, expected %s" % (y.dtype, self.rdt))
            assert_array_almost_equal(y, yr, decimal=self.dec)


class _TestDCTIIIBase(_TestDCTBase):
    """DCT checks plus a type-2 / type-3 orthonormal round trip."""

    def test_definition_ortho(self):
        """Test orthonormal mode."""
        for x_ref in X:
            x = np.array(x_ref, dtype=self.rdt)
            y = dct(x, norm='ortho', type=2)
            # Applying the type-3 transform to the type-2 output is expected
            # to give back the input.
            xi = dct(y, norm="ortho", type=3)
            self.assertTrue(xi.dtype == self.rdt,
                    "Output dtype is %s, expected %s" % (xi.dtype, self.rdt))
            assert_array_almost_equal(xi, x, decimal=self.dec)


class TestDCTIDouble(_TestDCTBase):
    def setUp(self):
        self.rdt = np.double
        self.dec = 10
        self.type = 1


class TestDCTIFloat(_TestDCTBase):
    def setUp(self):
        self.rdt = np.float32
        self.dec = 5
        self.type = 1


class TestDCTIIDouble(_TestDCTIIBase):
    def setUp(self):
        self.rdt = np.double
        self.dec = 10
        self.type = 2


class TestDCTIIFloat(_TestDCTIIBase):
    def setUp(self):
        self.rdt = np.float32
        self.dec = 5
        self.type = 2


class TestDCTIIIDouble(_TestDCTIIIBase):
    def setUp(self):
        self.rdt = np.double
        self.dec = 14
        self.type = 3


class TestDCTIIIFloat(_TestDCTIIIBase):
    def setUp(self):
        self.rdt = np.float32
        self.dec = 5
        self.type = 3


class _TestIDCTBase(TestCase):
    """Checks shared by all inverse-DCT test cases.

    Concrete subclasses override ``setUp`` to choose the dtype (``rdt``),
    the number of decimals to match (``dec``) and the DCT ``type``.
    """

    def setUp(self):
        self.rdt = None     # dtype
        self.dec = 14       # number of decimals to match
        self.type = None    # dct type

    def test_definition(self):
        # idct() applied to the FFTW reference output must give back the
        # reference input once rescaled.
        for i in FFTWDATA_SIZES:
            xr, yr = fftw_dct_ref(self.type, i, self.rdt)
            x = idct(yr, type=self.type)
            # Rescale in place (keeps the dtype) before comparing with xr.
            if self.type == 1:
                x /= 2 * (i-1)
            else:
                x /= 2 * i
            self.assertTrue(x.dtype == self.rdt,
                    "Output dtype is %s, expected %s" % (x.dtype, self.rdt))
            # XXX: we divide by np.max(x) because the tests fail otherwise. We
            # should really use something like assert_array_approx_equal. The
            # difference is due to fftw using a better algorithm w.r.t error
            # propagation compared to the ones from fftpack.
            scale = np.max(x)
            assert_array_almost_equal(x / scale, xr / scale,
                    decimal=self.dec, err_msg="Size %d failed" % i)


class TestIDCTIDouble(_TestIDCTBase):
    def setUp(self):
        self.rdt = np.double
        self.dec = 10
        self.type = 1


class TestIDCTIFloat(_TestIDCTBase):
    def setUp(self):
        self.rdt = np.float32
        self.dec = 4
        self.type = 1


class TestIDCTIIDouble(_TestIDCTBase):
    def setUp(self):
        self.rdt = np.double
        self.dec = 10
        self.type = 2


class TestIDCTIIFloat(_TestIDCTBase):
    def setUp(self):
        self.rdt = np.float32
        self.dec = 5
        self.type = 2


class TestIDCTIIIDouble(_TestIDCTBase):
    def setUp(self):
        self.rdt = np.double
        self.dec = 14
        self.type = 3


class TestIDCTIIIFloat(_TestIDCTBase):
    def setUp(self):
        self.rdt = np.float32
        self.dec = 5
        self.type = 3


class _TestDSTBase(TestCase):
    """Checks shared by all DST test cases.

    Concrete subclasses override ``setUp`` to choose the dtype (``rdt``),
    the number of decimals to match (``dec``) and the DST ``type``.
    """

    def setUp(self):
        self.rdt = None     # dtype
        self.dec = None     # number of decimals to match
        self.type = None    # dst type

    def test_definition(self):
        # Compare dst() with the FFTW reference output for every stored size.
        for i in FFTWDATA_SIZES:
            xr, yr = fftw_dst_ref(self.type, i, self.rdt)
            y = dst(xr, type=self.type)
            self.assertTrue(y.dtype == self.rdt,
                    "Output dtype is %s, expected %s" % (y.dtype, self.rdt))
            # XXX: we divide by np.max(y) because the tests fail otherwise. We
            # should really use something like assert_array_approx_equal. The
            # difference is due to fftw using a better algorithm w.r.t error
            # propagation compared to the ones from fftpack.
            scale = np.max(y)
            assert_array_almost_equal(y / scale, yr / scale,
                    decimal=self.dec, err_msg="Size %d failed" % i)


class TestDSTIDouble(_TestDSTBase):
    def setUp(self):
        self.rdt = np.double
        self.dec = 14
        self.type = 1


class TestDSTIFloat(_TestDSTBase):
    def setUp(self):
        self.rdt = np.float32
        self.dec = 5
        self.type = 1


class TestDSTIIDouble(_TestDSTBase):
    def setUp(self):
        self.rdt = np.double
        self.dec = 14
        self.type = 2


class TestDSTIIFloat(_TestDSTBase):
    def setUp(self):
        self.rdt = np.float32
        self.dec = 6
        self.type = 2


class TestDSTIIIDouble(_TestDSTBase):
    def setUp(self):
        self.rdt = np.double
        self.dec = 14
        self.type = 3


class TestDSTIIIFloat(_TestDSTBase):
    def setUp(self):
        self.rdt = np.float32
        self.dec = 7
        self.type = 3


class _TestIDSTBase(TestCase):
    """Checks shared by all inverse-DST test cases.

    Concrete subclasses override ``setUp`` to choose the dtype (``rdt``),
    the number of decimals to match (``dec``) and the DST ``type``.
    """

    def setUp(self):
        self.rdt = None     # dtype
        self.dec = None     # number of decimals to match
        self.type = None    # dst type

    def test_definition(self):
        # idst() applied to the FFTW reference output must give back the
        # reference input once rescaled.
        for i in FFTWDATA_SIZES:
            xr, yr = fftw_dst_ref(self.type, i, self.rdt)
            x = idst(yr, type=self.type)
            # Rescale in place (keeps the dtype) before comparing with xr.
            if self.type == 1:
                x /= 2 * (i+1)
            else:
                x /= 2 * i
            self.assertTrue(x.dtype == self.rdt,
                    "Output dtype is %s, expected %s" % (x.dtype, self.rdt))
            # XXX: we divide by np.max(x) because the tests fail otherwise. We
            # should really use something like assert_array_approx_equal. The
            # difference is due to fftw using a better algorithm w.r.t error
            # propagation compared to the ones from fftpack.
            scale = np.max(x)
            assert_array_almost_equal(x / scale, xr / scale,
                    decimal=self.dec, err_msg="Size %d failed" % i)


class TestIDSTIDouble(_TestIDSTBase):
    def setUp(self):
        self.rdt = np.double
        self.dec = 12
        self.type = 1


class TestIDSTIFloat(_TestIDSTBase):
    def setUp(self):
        self.rdt = np.float32
        self.dec = 4
        self.type = 1


class TestIDSTIIDouble(_TestIDSTBase):
    def setUp(self):
        self.rdt = np.double
        self.dec = 14
        self.type = 2


class TestIDSTIIFloat(_TestIDSTBase):
    def setUp(self):
        self.rdt = np.float32
        self.dec = 6
        self.type = 2


class TestIDSTIIIDouble(_TestIDSTBase):
    def setUp(self):
        self.rdt = np.double
        self.dec = 14
        self.type = 3


class TestIDSTIIIFloat(_TestIDSTBase):
    def setUp(self):
        self.rdt = np.float32
        self.dec = 6
        self.type = 3


class TestOverwrite(object):
    """
    Check input overwrite behavior
    """

    real_dtypes = [np.float32, np.float64]

    def _check(self, x, routine, type, fftsize, axis, norm, overwrite_x,
               should_overwrite, **kw):
        # Run `routine` on a copy of `x` and verify that the copy was
        # modified exactly when `should_overwrite` says it should be.
        # Extra keyword arguments are accepted but not used.
        x2 = x.copy()
        routine(x2, type, fftsize, axis, norm, overwrite_x=overwrite_x)

        sig = "%s(%s%r, %r, axis=%r, overwrite_x=%r)" % (
            routine.__name__, x.dtype, x.shape, fftsize, axis, overwrite_x)
        if not should_overwrite:
            assert_equal(x2, x, err_msg="spurious overwrite in %s" % sig)
        else:
            if (x2 == x).all():
                raise AssertionError("no overwrite in %s" % sig)

    def _check_1d(self, routine, dtype, shape, axis, overwritable_dtypes):
        # Exercise `routine` on seeded random data of the given dtype/shape
        # for every transform type, norm and overwrite_x combination.
        np.random.seed(1234)
        if np.issubdtype(dtype, np.complexfloating):
            data = np.random.randn(*shape) + 1j*np.random.randn(*shape)
        else:
            data = np.random.randn(*shape)
        data = data.astype(dtype)

        for transform_type in [1, 2, 3]:
            for overwrite_x in [True, False]:
                for norm in [None, 'ortho']:
                    # This combination is skipped by the test.
                    if transform_type == 1 and norm == 'ortho':
                        continue

                    # An overwrite is only expected when it was requested,
                    # the dtype is overwritable, and the transform runs over
                    # a 1-D array or along the last axis.
                    should_overwrite = (overwrite_x
                                        and dtype in overwritable_dtypes
                                        and (len(shape) == 1 or
                                             (axis % len(shape) == len(shape)-1
                                              )))
                    self._check(data, routine, transform_type, None, axis,
                                norm, overwrite_x, should_overwrite)

    def _check_routine(self, routine):
        # Run the overwrite checks for `routine` on every real dtype, using
        # a 1-D input and a 2-D input transformed along each of its axes.
        overwritable = self.real_dtypes
        for dtype in self.real_dtypes:
            self._check_1d(routine, dtype, (16,), -1, overwritable)
            self._check_1d(routine, dtype, (16, 2), 0, overwritable)
            self._check_1d(routine, dtype, (2, 16), 1, overwritable)

    def test_dct(self):
        self._check_routine(dct)

    def test_idct(self):
        self._check_routine(idct)

    def test_dst(self):
        self._check_routine(dst)

    def test_idst(self):
        self._check_routine(idst)


if __name__ == "__main__":
    np.testing.run_module_suite()