| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for multinomial generation ops in the XLA JIT compiler.""" |
| |
| import collections |
| |
| import numpy as np |
| |
| from tensorflow.compiler.tests import xla_test |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import random_seed |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import random_ops |
| from tensorflow.python.ops import stateless_random_ops |
| from tensorflow.python.platform import googletest |
| from tensorflow.python.platform import test |
| |
| |
# TODO(srvasude): Merge this with
# third_party/tensorflow/python/kernel_tests/random/multinomial_op_test.py.
class CategoricalTest(xla_test.XLATestCase):
  """Test cases for random-number generating operators."""

  def output_dtypes(self):
    """Returns the members of `self.int_types` that are int32 or int64."""
    return set(self.int_types).intersection([np.int32, np.int64])

  def _chi2(self, expected, actual):
    """Returns Chi2 GOF statistic.

    Args:
      expected: Array-like of expected values.
      actual: Array-like of observed values, combined elementwise with
        `expected`.

    Returns:
      The sum over all elements of `(actual - expected)**2 / expected`.
    """
    actual = np.asarray(actual)
    expected = np.asarray(expected)
    diff = actual - expected
    return np.sum(diff * diff / expected)

  def _do_sampling(self, logits, num_samples):
    """Categorical samples from given input.

    Args:
      logits: Numpy ndarray of shape [batch_size, num_classes].
      num_samples: Int; number of samples to draw.

    Returns:
      Frequencies from sampled classes; shape [batch_size, num_classes].
    """
    with self.session(), self.test_scope():
      random_seed.set_random_seed(1618)
      op = random_ops.multinomial(logits, num_samples,
                                  output_dtype=dtypes.int32)
      d = self.evaluate(op)

    batch_size, num_classes = logits.shape
    freqs_mat = []
    for i in range(batch_size):
      # Counter yields 0 for classes that were never drawn, so no explicit
      # membership check is needed when building the frequency row below.
      cnts = collections.Counter(d[i, :])

      # Requires drawn class labels be in range.
      self.assertLess(max(cnts), num_classes)
      self.assertGreaterEqual(min(cnts), 0)

      freqs_mat.append([cnts[k] / num_samples for k in range(num_classes)])

    return freqs_mat

  def _testRngIsNotConstant(self, rng, dtype, output_dtype):
    """Tests that 'rng' does not always return the same value.

    Args:
      rng: Callable taking `(dtype, output_dtype)` and returning the op to
        evaluate.
      dtype: Forwarded to `rng` as its first argument.
      output_dtype: Forwarded to `rng` as its second argument.
    """
    with self.session():
      with self.test_scope():
        x = rng(dtype, output_dtype)

      # The random-number generator, if working correctly, should produce the
      # same output multiple times with low probability.
      y = self.evaluate(x)
      z = self.evaluate(x)
      w = self.evaluate(x)

      # We use exact equality here. If the random-number generator is producing
      # deterministic output, all three outputs will be bitwise identical.
      self.assertTrue((not np.array_equal(y, z)) or
                      (not np.array_equal(z, w)) or
                      (not np.array_equal(y, w)))

  def testCategoricalIsNotConstant(self):
    """Checks that repeated multinomial evaluations are not all identical."""

    def rng(dtype, output_dtype):
      return random_ops.multinomial(np.array([[1., 1., 1.]], dtype=dtype), 10,
                                    output_dtype=output_dtype)

    dtype = np.float32
    for output_dtype in self.output_dtypes():
      self._testRngIsNotConstant(rng, dtype, output_dtype)

  @test.disable_with_predicate(
      pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.")
  def testCategoricalIsInRange(self):
    """Checks that every drawn class index lies in [0, num_classes)."""
    num_classes = 20
    num_samples = 1000
    for dtype in self.float_types:
      for output_dtype in self.output_dtypes():
        with self.session():
          with self.test_scope():
            x = random_ops.multinomial(
                array_ops.ones(shape=[1, num_classes], dtype=dtype),
                num_samples,
                output_dtype=output_dtype)
          y = self.evaluate(x)
          # assertEqual (rather than assertTrue on a comparison) reports the
          # actual in-range count when the check fails.
          self.assertEqual((y >= 0).sum(), num_samples)
          self.assertEqual((y < num_classes).sum(), num_samples)

  def testSamplingCorrectness(self):
    """Compares sampled class frequencies against the input probabilities."""
    np.random.seed(1618)  # Make it reproducible.
    num_samples = 40000

    rand_probs = np.random.dirichlet([1., 1., 2., 3.])
    rand_probs2 = np.random.dirichlet([1., 4., 5.], size=3)  # batched
    for probs in [[.5, .5], [.85, .05, .1], rand_probs, rand_probs2]:
      probs = np.asarray(probs)
      if len(probs.shape) == 1:
        probs = probs.reshape(1, probs.size)  # singleton batch

      logits = np.log(probs).astype(np.float32)
      freqs = self._do_sampling(logits, num_samples)

      # the test here is similar to
      # python/kernel_tests/random/multinomial_op_test.py
      # Note that df >= 1 in all these cases. Choosing a cutoff of 1e-3
      # corresponds to an alpha value of 2.5% for df = 1, and smaller for larger
      # df.
      chi2 = self._chi2(probs, freqs)
      self.assertLess(chi2, 1e-3)

  def testStatelessMultinomialIsInRange(self):
    """Checks that stateless draws lie in [0, num_classes)."""
    num_classes = 20
    num_samples = 1000
    for dtype in self.float_types.intersection(
        [dtypes.float32, dtypes.bfloat16]):
      for output_dtype in self.output_dtypes():
        with self.session() as sess:
          with self.test_scope():
            seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
            x = stateless_random_ops.stateless_multinomial(
                array_ops.ones(shape=[1, num_classes], dtype=dtype),
                num_samples,
                seed_t,
                output_dtype=output_dtype)
          # Both seed words must fit in int32 because they are fed to an int32
          # placeholder; 0xabcdef12 (2882400018) exceeds the int32 maximum, so
          # the same in-range value as testEmptyStateless is used instead.
          y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
          self.assertEqual((y >= 0).sum(), num_samples)
          self.assertEqual((y < num_classes).sum(), num_samples)

  def testDeterminismMultinomial(self):
    """Checks stateless draws match exactly when the seeds match."""
    # Stateless values should be equal iff the seeds are equal (roughly)
    num_samples = 10
    with self.session(), self.test_scope():
      seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
      # Each of the 25 distinct seed pairs appears three times, so equal seeds
      # are compared across separate evaluations as well as unequal ones.
      seeds = [(x, y) for x in range(5) for y in range(5)] * 3
      for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
                                                [0.25, 0.75]]):
        pure = stateless_random_ops.stateless_multinomial(
            logits, num_samples, seed=seed_t)
        values = [(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds]
        for s0, v0 in values:
          for s1, v1 in values:
            self.assertEqual(s0 == s1, np.all(v0 == v1))

  def testEmpty(self):
    """Checks that requesting zero samples yields a [batch, 0] result."""
    with self.session():
      with self.test_scope():
        x = random_ops.multinomial(
            array_ops.zeros([42, 40]), 0, output_dtype=dtypes.int32)
      y = self.evaluate(x)
      self.assertEqual(y.shape, (42, 0))

  def testEmptyStateless(self):
    """Checks that zero stateless samples yield a [batch, 0] result."""
    with self.session() as sess:
      with self.test_scope():
        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
        x = stateless_random_ops.stateless_multinomial(
            array_ops.zeros([42, 40]),
            0,
            seed=seed_t,
            output_dtype=dtypes.int32)
      y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
      self.assertEqual(y.shape, (42, 0))
| |
| |
| |
if __name__ == '__main__':
  # Hand control to the test runner only when executed as a script, so that
  # importing this module does not start the tests.
  googletest.main()