# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for giant const op compilation."""
import os
import numpy as np
from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import flags
FLAGS = flags.FLAGS
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")
def get_tpu_cluster_resolver():
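  """Builds a TPUClusterResolver from the --tpu/--zone/--project flags."""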
  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu,
      zone=FLAGS.zone,
      project=FLAGS.project,
  )
  return resolver


def get_tpu_strategy():
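  """Connects to the TPU, initializes the TPU system, returns a strategy."""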
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  tpu_cluster_resolver.initialize_tpu_system(resolver)
  return tpu_lib.TPUStrategyV2(resolver)


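# Usage sketch: the flags defined above are passed on the command line when
# invoking this test directly. The script name and TPU details below are
# hypothetical examples, not values taken from this file:
#   python giant_const_op_test.py --tpu=my-tpu --zone=us-central1-b \
#       --project=my-project

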
# This test doesn't use XLATestCase like the other tests in this directory.
# The Const op's XLA kernel is compilation-only and therefore is not executed
# by XLA in on-demand compilation mode. Also, here we want to feed the full
# program to XLA to verify its handling of programs with giant constant
# tensors.
class GiantConstOp(test.TestCase):

  # Verifies that graphs containing giant const tensors that won't fit in
  # memory are compiled correctly to HLO.
  def testGiantConst(self):
    strategy = get_tpu_strategy()

    types = {
        dtypes.bool,
        dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
        dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64,
        dtypes.float16, dtypes.bfloat16,
        dtypes.float32, dtypes.float64,
    }

    for dtype in types:
      values = [True if dtype is dtypes.bool else 1]
      if dtype is dtypes.bool:
        values.append(False)
      elif dtype is not dtypes.float64:
        # TPUs don't implement the IEEE 754 standard for 64-bit floats, so
        # extreme float64 values could come back altered even through pure
        # data-movement ops with no arithmetic; min/max are therefore skipped
        # for float64.
        values.extend([dtype.min, dtype.max])
      for value in values:

        @def_function.function
        def train_step():
          # pylint: disable=cell-var-from-loop
          def computation():
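            # The constant below has 1024**4 (~1.1e12) elements, roughly
            # 4 TiB as float32, so it cannot be materialized; the test relies
            # on XLA compiling the program without building the full tensor.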
            const = constant_op.constant(value, dtype=dtype, shape=[1024] * 4)
            return const[:1, :1, :1, :1]

          return strategy.run(computation, args=())

        output = strategy.experimental_local_results(train_step())[0]
        expected = np.full((1, 1, 1, 1), value)

        self.assertAllEqual(output, expected)


if __name__ == "__main__":
# Make sure TF_XLA_FLAGS is not already set to avoid dropping the existing
# value silently.
assert "TF_XLA_FLAGS" not in os.environ
# Disable tfxla constant folding that always creates full Tensors and will
# fail for giant tensors.
os.environ["TF_XLA_FLAGS"] = "--tf_xla_disable_constant_folding=true"
  test.main()