# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Memory tests for tensorflow.ops.custom_gradient."""
import functools
from absl.testing import parameterized
from tensorflow.compiler.xla.service import hlo_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_strategy_util
class RecomputeGradMemoryTest(test.TestCase, parameterized.TestCase):
  """Checks that `custom_gradient.recompute_grad` reduces backprop memory.

  Each test builds the same matmul chain twice, once plainly and once with the
  inner function wrapped in `recompute_grad`, and compares memory usage of the
  gradient computation with expected lower/upper bounds.
  """

  def _get_device_type(self):
    """Returns "GPU" or "TPU" if such a device is present, else "CPU".

    GPU is checked before TPU, so GPU wins when both are available.
    """
    for accelerator in ["GPU", "TPU"]:
      if config.list_physical_devices(accelerator):
        return accelerator
    return "CPU"

  def _grad(self, test_func, argnums=0):
    """Wraps `test_func` into a function that returns its gradient.

    Args:
      test_func: Callable whose output is differentiated.
      argnums: Index of the positional argument to differentiate with respect
        to.

    Returns:
      A function `_f(*params)` that evaluates `test_func(*params)` under a
      `GradientTape` and returns the gradient of the output with respect to
      `params[argnums]`.
    """

    def _f(*params):
      with backprop.GradientTape() as tape:
        # The inputs here are plain tensors, not variables, so the tape must
        # be told to watch them; otherwise `tape.gradient` returns None.
        tape.watch(params[argnums])
        output = test_func(*params)
      return tape.gradient(output, params[argnums])

    return _f

  @parameterized.named_parameters(
      (f"_{mode}", mode) for mode in ["eager", "graph"])
  @test_util.run_v2_only
  def testRecomputeGradNonXla(self, mode):
    """Measures runtime peak device memory with and without recompute_grad.

    Args:
      mode: "eager" runs the gradient function directly; "graph" wraps it in
        `def_function.function` first.
    """
    device_type = self._get_device_type()
    device_name = f"{device_type}:0"
    if device_type == "TPU":
      self.skipTest("XLA is required for TPU.")
    if device_type == "CPU":
      self.skipTest("b/185371422: get_memory_info does't support CPU yet.")

    # Clear stats left over from earlier tests so "peak" only reflects the
    # work done below.
    config.reset_memory_stats(device_name)
    base_memory = config.get_memory_info(device_name)["current"]
    n = 500
    with ops.device(device_name):
      a = array_ops.ones((n, n), dtype=dtypes.float16)

    def f(x):
      for _ in range(5):
        x = math_ops.matmul(x, x)
      return x

    def g(f, x):
      for _ in range(5):
        x = f(x)
      return x[0][0]

    def run(test_func):
      with ops.device(device_name):
        if mode == "eager":
          return self._grad(test_func)(a)
        return def_function.function(self._grad(test_func))(a)

    f_no_recompute = functools.partial(g, f)
    f_recompute = functools.partial(g, custom_gradient.recompute_grad(f))

    # The result is not saved so the base memory will stay the same.
    run(f_no_recompute)
    peak_memory_no_recompute = (
        config.get_memory_info(device_name)["peak"] - base_memory)
    # Reset the peak so the recompute measurement does not inherit the
    # (larger) peak of the previous run.
    config.reset_memory_stats(device_name)
    run(f_recompute)
    peak_memory_recompute = (
        config.get_memory_info(device_name)["peak"] - base_memory)

    # 2 * n * n (size of `a`) * 5 (loop of f) * 5 (loop of g)
    self.assertGreaterEqual(peak_memory_no_recompute, 2 * n * n * 5 * 5)
    # 2 * n * n (size of `a`) * (5 (loop of g) + 5 (recompute in f))
    self.assertGreaterEqual(peak_memory_recompute, 2 * n * n * 5 * 2)
    # peak_memory_recompute should be less than peak_memory_no_recompute.
    self.assertLess(peak_memory_recompute, 2 * n * n * 5 * 3)

    # Recomputation must not change the numerical result.
    res_no_recompute = run(f_no_recompute)
    res_recompute = run(f_recompute)
    self.assertAllClose(res_no_recompute, res_recompute)

  @test_util.run_v2_only
  def testRecomputeGradXla(self):
    """Compares XLA buffer-assignment sizes with and without recompute_grad."""
    device_type = self._get_device_type()
    device_name = f"{device_type}:0"
    # Necessary for TFRT tests.
    if device_type == "TPU":
      tpu_strategy_util.initialize_tpu_system()

    n = 500
    with ops.device(device_name):
      # XLA:TPU converts f32 matmuls to bf16, and XLA:CPU converts bf16/f16
      # matmuls to f32 after cl/461262189. Use a type that doesn't get
      # converted.
      if device_type == "TPU":
        dtype = dtypes.bfloat16
        elem_size = 2
      else:
        dtype = dtypes.float32
        elem_size = 4
      a = array_ops.zeros((n, n), dtype=dtype)  # elem_size * n * n bytes

    def f(x):
      for _ in range(5):
        # matmul can not be fused by XLA.
        x = math_ops.matmul(x, x)
      return x

    def g(f, x):
      for _ in range(5):
        x = f(x)
      return x[0][0]

    def get_peak_memory(test_func):
      """Returns the total size of all buffers XLA assigns for the gradient.

      The sizes come from the buffer assignment of the optimized HLO, i.e.
      statically allocated memory, not a runtime measurement.
      """
      test_func = def_function.function(self._grad(test_func), jit_compile=True)
      # The hlo_proto contains statically allocated memory info of HLO values.
      hlo_proto_serialized = test_func.experimental_get_compiler_ir(a)(
          stage="optimized_hlo_proto_serialized", device_name=device_name)
      hlo_proto = hlo_pb2.HloProto.FromString(hlo_proto_serialized)
      allocations = hlo_proto.buffer_assignment.buffer_allocations
      return sum(allocation.size for allocation in allocations)

    f_no_recompute = functools.partial(g, f)
    f_recompute = functools.partial(g, custom_gradient.recompute_grad(f))
    peak_memory_no_recompute = get_peak_memory(f_no_recompute)
    peak_memory_recompute = get_peak_memory(f_recompute)

    # elem_size * n * n (size of `a`) * 5 (loop of g) * 5 (loop of f)
    self.assertGreaterEqual(peak_memory_no_recompute, elem_size * n * n * 5 * 5)
    # elem_size * n * n (size of `a`) * (5 (loop of g) + 5 (recompute in f))
    self.assertGreaterEqual(peak_memory_recompute, elem_size * n * n * 5 * 2)
    # peak_memory_recompute should be less than peak_memory_no_recompute.
    self.assertLess(peak_memory_recompute, elem_size * n * n * 5 * 3)

    # Recomputation must not change the numerical result.
    with ops.device(device_name):
      res_recompute = f_recompute(a)
      res_no_recompute = f_no_recompute(a)
    self.assertAllClose(res_recompute, res_no_recompute)
if __name__ == "__main__":