blob: 095d7c04043bb3f1084016acc979f041221abe65 [file] [log] [blame]
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for compiler_opt.rl.feature_ops."""
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from compiler_opt.rl import constant
from compiler_opt.rl import feature_ops
def _get_sqrt_z_score_preprocessing_fn_cross_product():
testcases = []
for sqrt in [True, False]:
for z_score in [True, False]:
for preprocessing_fn in [None, lambda x: x * x]:
test_name = ('sqrt_%s_zscore_%s_preprocessfn_%s' %
(sqrt, z_score, preprocessing_fn))
testcases.append((test_name, sqrt, z_score, preprocessing_fn))
return testcases
class FeatureUtilsTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the feature helpers in `compiler_opt.rl.feature_ops`."""

  def setUp(self):
    super().setUp()
    # Directory holding the quantile test data read by
    # `feature_ops.build_quantile_map`.
    self._quantile_file_dir = os.path.join(constant.BASE_DIR, 'testdata')

  def test_build_quantile_map(self):
    quantile_map = feature_ops.build_quantile_map(self._quantile_file_dir)
    # The testdata directory is expected to yield exactly one feature.
    self.assertLen(quantile_map, 1)
    self.assertIn('edge_count', quantile_map)
    quantile = quantile_map['edge_count']
    # Spot-check the number of loaded 'edge_count' quantile entries and a
    # couple of their values.
    self.assertLen(quantile, 9)
    self.assertEqual(2, quantile[0])
    self.assertEqual(8, quantile[6])

  def test_discard_fn(self):
    # obs in shape of [2, 1].
    obs = tf.constant(value=[[2.0], [8.0]])
    output = feature_ops.discard_fn(obs)
    # A trailing axis of size 0 is expected: no features are kept.
    self.assertAllEqual([2, 1, 0], output.shape)

  def test_identity_fn(self):
    # obs in shape of [2, 1].
    obs = tf.constant(value=[[2.0], [8.0]])
    output = feature_ops.identity_fn(obs)
    # Values are expected unchanged, with a trailing feature axis of size 1.
    expected = np.array([[[2.0]], [[8.0]]])
    self.assertAllEqual([2, 1, 1], output.shape)
    self.assertAllClose(expected, output)

  @parameterized.named_parameters(
      *_get_sqrt_z_score_preprocessing_fn_cross_product())
  def test_normalize_fn_sqrt_z_normalization(self, with_sqrt, with_z_score,
                                             preprocessing_fn):
    quantile_map = feature_ops.build_quantile_map(self._quantile_file_dir)
    quantile = quantile_map['edge_count']
    normalize_fn = feature_ops.get_normalize_fn(
        quantile, with_sqrt, with_z_score, preprocessing_fn=preprocessing_fn)
    obs = tf.constant(value=[[2.0], [8.0]])
    output = normalize_fn(obs)

    # Two features are always expected; per the literals below, the second
    # is the square of the first (0.111111 == 0.333333**2).
    expected_shape = [2, 1, 2]
    expected = np.array([[[0.333333, 0.111111]], [[0.777778, 0.604938]]])
    if with_sqrt:
      # Optional extra feature: the square root of the first feature
      # (0.57735 == sqrt(0.333333)).
      expected_shape[2] += 1
      expected = np.concatenate([expected, [[[0.57735]], [[0.881917]]]],
                                axis=-1)
    if with_z_score:
      # Optional extra feature; in these expectations it is the only one
      # whose value depends on preprocessing_fn.
      expected_shape[2] += 1
      if preprocessing_fn is not None:
        z_score = [[[-0.406244]], [[-0.33180502]]]
      else:
        z_score = [[[-0.555968]], [[-0.155671]]]
      expected = np.concatenate([expected, z_score], axis=-1)
    self.assertAllEqual(expected_shape, output.shape)
    self.assertAllClose(expected, output)
if __name__ == '__main__':
  # Entry point when this file is executed directly: hand control to the
  # TensorFlow test runner for the test cases defined above.
  tf.test.main()