blob: 365ecea97e89c7829393802050cab87bb3f86cfb [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for multinomial generation ops in the XLA JIT compiler."""
import collections
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
# TODO(srvasude): Merge this with
# third_party/tensorflow/python/kernel_tests/random/multinomial_op_test.py.
class CategoricalTest(xla_test.XLATestCase):
"""Test cases for random-number generating operators."""
def output_dtypes(self):
return set(self.int_types).intersection([np.int32, np.int64])
def _chi2(self, expected, actual):
"""Returns Chi2 GOF statistic."""
actual = np.asarray(actual)
expected = np.asarray(expected)
diff = actual - expected
chi2 = np.sum(diff * diff / expected)
return chi2
def _do_sampling(self, logits, num_samples):
"""Categorical samples from given input.
Args:
logits: Numpy ndarray of shape [batch_size, num_classes].
num_samples: Int; number of samples to draw.
Returns:
Frequencies from sampled classes; shape [batch_size, num_classes].
"""
with self.session(), self.test_scope():
random_seed.set_random_seed(1618)
op = random_ops.multinomial(logits, num_samples,
output_dtype=dtypes.int32)
d = self.evaluate(op)
batch_size, num_classes = logits.shape
freqs_mat = []
for i in range(batch_size):
cnts = dict(collections.Counter(d[i, :]))
# Requires drawn class labels be in range.
self.assertLess(max(cnts.keys()), num_classes)
self.assertGreaterEqual(min(cnts.keys()), 0)
freqs = [(cnts[k] * 1. / num_samples if k in cnts else 0)
for k in range(num_classes)]
freqs_mat.append(freqs)
return freqs_mat
def _testRngIsNotConstant(self, rng, dtype, output_dtype):
# Tests that 'rng' does not always return the same value.
with self.session():
with self.test_scope():
x = rng(dtype, output_dtype)
# The random-number generator, if working correctly, should produce the
# same output multiple times with low probability.
y = self.evaluate(x)
z = self.evaluate(x)
w = self.evaluate(x)
# We use exact equality here. If the random-number generator is producing
# deterministic output, all three outputs will be bitwise identical.
self.assertTrue((not np.array_equal(y, z)) or
(not np.array_equal(z, w)) or
(not np.array_equal(y, w)))
def testCategoricalIsNotConstant(self):
def rng(dtype, output_dtype):
return random_ops.multinomial(np.array([[1., 1., 1.]], dtype=dtype), 10,
output_dtype=output_dtype)
dtype = np.float32
for output_dtype in self.output_dtypes():
self._testRngIsNotConstant(rng, dtype, output_dtype)
@test.disable_with_predicate(
pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.")
def testCategoricalIsInRange(self):
for dtype in self.float_types:
for output_dtype in self.output_dtypes():
with self.session():
with self.test_scope():
x = random_ops.multinomial(
array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
output_dtype=output_dtype)
y = self.evaluate(x)
self.assertTrue((y >= 0).sum() == 1000)
self.assertTrue((y < 20).sum() == 1000)
def testSamplingCorrectness(self):
np.random.seed(1618) # Make it reproducible.
num_samples = 40000
rand_probs = np.random.dirichlet([1., 1., 2., 3.])
rand_probs2 = np.random.dirichlet([1., 4., 5.], size=3) # batched
for probs in [[.5, .5], [.85, .05, .1], rand_probs, rand_probs2]:
probs = np.asarray(probs)
if len(probs.shape) == 1:
probs = probs.reshape(1, probs.size) # singleton batch
logits = np.log(probs).astype(np.float32)
freqs = self._do_sampling(logits, num_samples)
# the test here is similar to
# python/kernel_tests/random/multinomial_op_test.py
# Note that df >= 1 in all these cases. Choosing a cutoff of 1e-3
# corresponds to an alpha value of 2.5% for df = 1, and smaller for larger
# df.
chi2 = self._chi2(probs, freqs)
self.assertLess(chi2, 1e-3)
def testStatelessMultinomialIsInRange(self):
for dtype in self.float_types.intersection(
[dtypes.float32, dtypes.bfloat16]):
for output_dtype in self.output_dtypes():
with self.session() as sess:
with self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
x = stateless_random_ops.stateless_multinomial(
array_ops.ones(shape=[1, 20], dtype=dtype),
1000,
seed_t,
output_dtype=output_dtype)
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
self.assertTrue((y >= 0).sum() == 1000)
self.assertTrue((y < 20).sum() == 1000)
def testDeterminismMultinomial(self):
# Stateless values should be equal iff the seeds are equal (roughly)
num_samples = 10
with self.session(), self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
seeds = [(x, y) for x in range(5) for y in range(5)] * 3
for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
[0.25, 0.75]]):
pure = stateless_random_ops.stateless_multinomial(
logits, num_samples, seed=seed_t)
values = [(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds]
for s0, v0 in values:
for s1, v1 in values:
self.assertEqual(s0 == s1, np.all(v0 == v1))
def testEmpty(self):
with self.session():
with self.test_scope():
x = random_ops.multinomial(
array_ops.zeros([42, 40]), 0, output_dtype=dtypes.int32)
y = self.evaluate(x)
self.assertEqual(y.shape, (42, 0))
def testEmptyStateless(self):
with self.session() as sess:
with self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
x = stateless_random_ops.stateless_multinomial(
array_ops.zeros([42, 40]),
0,
seed=seed_t,
output_dtype=dtypes.int32)
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
self.assertEqual(y.shape, (42, 0))
if __name__ == '__main__':
googletest.main()