# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
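"""Tests for the eager execution context (tensorflow.python.eager.context)."""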
import weakref
from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.xla.service import hlo_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class ContextTest(test.TestCase, parameterized.TestCase):
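  """Tests for context.Context: seeds, lifetime, graph collection, and devices."""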

  def testSetGlobalSeed(self):
    # The global seed setter should accept Python ints as well as NumPy
    # scalars, NumPy arrays, and tensors of each supported integer dtype.
    c = context.Context()
    c._set_global_seed(123)
    for t in [np.int32, np.int64, np.uint32, np.uint64]:
      c._set_global_seed(t(123))
      c._set_global_seed(np.array(123, dtype=t))
      c._set_global_seed(ops.convert_to_tensor(123, dtype=t))

  def testContextIsDestroyedAfterTensors(self):
    # Create a new context.
    new_context = context.Context()
    weak_c = weakref.ref(new_context)
    new_context.ensure_initialized()

    # Create tensors with the new context as default. Make sure to restore
    # the original context afterwards.
    original_context = context.context()
    try:
      context._set_context(new_context)
      # Use a 2D tensor so that it is not cached.
      tensor1 = constant_op.constant([[3.]])
      # Produce a tensor as an operation output. This uses a different code
      # path from tensors created from Python.
      tensor2 = tensor1 * tensor1
    finally:
      context._set_context(original_context)

    # Deleting our context reference should not delete the underlying object
    # while tensors created under it are still alive.
    del new_context
    self.assertIsNot(weak_c(), None)

    # Deleting the first tensor should not delete the context since there is
    # another tensor.
    del tensor1
    self.assertIsNot(weak_c(), None)

    # Deleting the last tensor should result in deleting its context.
    del tensor2
    self.assertIs(weak_c(), None)

  def testSimpleGraphCollection(self):

    @def_function.function
    def f(x):
      with ops.device('CPU:0'):
        return x + constant_op.constant(1.)

    with context.collect_graphs() as graphs:
      with ops.device('CPU:0'):
        x = constant_op.constant(1.)
      f(x)

    self.assertLen(graphs, 1)
    graph, = graphs
    self.assertIn('CPU:0', graph.node[1].device)

  @test_util.disable_tfrt(
      'b/171600738: tfrt does not support exporting post-optimization graph')
  def testGraphCollectionAfterDevicePlacement(self):

    @def_function.function
    def f(x):
      return x + constant_op.constant(1.)

    with context.collect_graphs() as graphs:
      with ops.device('CPU:0'):
        f(constant_op.constant(1.))

    self.assertLen(graphs, 1)
    graph, = graphs
    self.assertIn('CPU:0', graph.node[0].device)

  def testGetFunctionDef(self):

    @def_function.function
    def f():
      return constant_op.constant(1.)

    concrete = f.get_concrete_function()
    function_def = context.get_function_def(concrete.name)
    self.assertIsNot(function_def, None)

    found_const_node = False
    for node_def in function_def.node_def:
      if node_def.op == 'Const':
        found_const_node = True
        break
    self.assertTrue(found_const_node)

    with self.assertRaises(errors.NotFoundError):
      _ = context.get_function_def('this_should_not_be_found')

  @test_util.run_gpu_only
  @test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
  def testGetMemoryInfo(self):
    array_ops.zeros([10])  # Allocate some memory on the GPU.
    self.assertGreater(context.context().get_memory_info('GPU:0')['current'], 0)

  @test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
  def testGetMemoryInfoCPU(self):
    if test_util.IsMklEnabled():
      # TODO(gzmkl) work with Google team to address design issue in allocator.h
      self.skipTest('MklCPUAllocator does not raise an exception, so skip this test.')
    with self.assertRaisesRegex(ValueError, 'Allocator stats not available'):
      context.context().get_memory_info('CPU:0')

  @test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
  def testGetMemoryInfoUnknownDevice(self):
    with self.assertRaisesRegex(ValueError, 'No matching devices found'):
      context.context().get_memory_info('unknown_device:0')

  @test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
  def testGetMemoryInfoUnparsableDevice(self):
    with self.assertRaisesRegex(ValueError, 'Failed parsing device name'):
      context.context().get_memory_info('GPU')
    with self.assertRaisesRegex(ValueError, 'Failed parsing device name'):
      context.context().get_memory_info('GPU:')
    with self.assertRaisesRegex(ValueError, 'Failed parsing device name'):
      context.context().get_memory_info('GPU:CPU')

  def testListFunctionNames(self):

    @def_function.function
    def f():
      return constant_op.constant(1.)

    concrete = f.get_concrete_function()
    self.assertIn(concrete.name.decode(),
                  context.context().list_function_names())

  def testSetLogicalDeviceAfterContextInitialization(self):
    ctx = context.Context()
    ctx.set_logical_cpu_devices(4)
    self.assertLen(ctx.list_logical_devices('CPU'), 4)

    # Logical CPU devices cannot be set a second time.
    with self.assertRaisesRegex(RuntimeError, 'Virtual CPUs already set'):
      ctx.set_logical_cpu_devices(8)

  def testSetLogicalLocalCpuDevice(self):
    ctx = context.Context()
    # Manually add a remote CPU device to the logical device list.
    ctx._logical_devices = []  # pylint: disable=protected-access
    dev = context.LogicalDevice(name='/job:worker/replica:0/task:1',
                                device_type='CPU')
    ctx._logical_devices.append(dev)  # pylint: disable=protected-access
    self.assertLen(ctx.list_logical_devices('CPU'), 1)

    # This passes the check because the previously added device is not local.
    ctx.set_logical_cpu_devices(4)

    # The logical device list is overwritten after initialization.
    self.assertLen(ctx.list_logical_devices('CPU'), 4)

  @parameterized.named_parameters([(f'_{stage}', stage) for stage in [
      'hlo', 'hlo_serialized', 'optimized_hlo', 'optimized_hlo_serialized',
      'optimized_hlo_proto_serialized', 'optimized_hlo_dot'
  ]])
  def testGetCompilerIr(self, stage):

    @def_function.function(jit_compile=True)
    def test_func(x):
      return 2 * x

    a = array_ops.ones((1000, 1000))  # 4 * 1000 * 1000 in bytes.
    result = test_func.experimental_get_compiler_ir(a)(stage=stage)
    self.assertNotEmpty(result)

    if stage == 'optimized_hlo_proto_serialized':
      hlo_proto = hlo_pb2.HloProto.FromString(result)
      allocations = hlo_proto.buffer_assignment.buffer_allocations
      buffer_size = sum(
          getattr(allocation, 'size') for allocation in allocations)
      # The sizes of input and output are both 4 * 1000 * 1000 in bytes.
      self.assertGreaterEqual(buffer_size, 2 * 4 * 1000 * 1000)
      self.assertLess(buffer_size, 4 * 4 * 1000 * 1000)


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()