# blob: 88b24f4db7666bd8c55b6bd12b1191bc7ac85c39 [file] [log] [blame]
# python/ops/memory_tests package
load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test")
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_strict_test")
load("//tensorflow/core/platform:build_config_root.bzl", "tf_additional_xla_deps_py")
# Package-level defaults applied to every target declared in this BUILD file.
package(
# NOTE(review): the `copybara:uncomment` directive below is presumably
# rewritten by Copybara when the code is exported; keep that line verbatim.
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
licenses = ["notice"],
)
# Dependencies shared by the GPU and TPU variants of
# custom_gradient_memory_test. Both targets build the same source file, so
# they need exactly the same deps; keeping one list prevents the two copies
# from drifting apart when a dependency is added or removed.
_CUSTOM_GRADIENT_MEMORY_TEST_DEPS = [
    "//tensorflow/compiler/xla/service:hlo_proto_py",
    "//tensorflow/python:array_ops",
    "//tensorflow/python:gradients",
    "//tensorflow/python:math_ops",
    "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
    "//tensorflow/python/eager:backprop",
    "//tensorflow/python/eager:def_function",
    "//tensorflow/python/framework:config",
    "//tensorflow/python/framework:dtypes",
    "//tensorflow/python/framework:ops",
    "//tensorflow/python/framework:test_lib",
    "//tensorflow/python/platform:client_testlib",
    "//tensorflow/python/platform:test",
    "@absl_py//absl/testing:parameterized",
] + tf_additional_xla_deps_py()

# CUDA/GPU variant of the custom-gradient memory test.
cuda_py_strict_test(
    name = "custom_gradient_memory_test",
    size = "medium",
    srcs = ["custom_gradient_memory_test.py"],
    # Strict auto-jit is disabled because these memory tests enable XLA
    # explicitly themselves.
    xla_enable_strict_auto_jit = False,
    deps = _CUSTOM_GRADIENT_MEMORY_TEST_DEPS,
)

# TPU variant of the same test source. `main` is set explicitly because the
# target name does not match the source file name.
tpu_py_strict_test(
    name = "custom_gradient_memory_test_tpu",
    size = "medium",
    srcs = ["custom_gradient_memory_test.py"],
    # TODO(b/238830423): This test uses experimental_get_compiler_ir, which is
    # not supported with TFRT (Failed getting HLO text: 'GetCompilerIr is not
    # supported on this device.').
    disable_tfrt = True,
    main = "custom_gradient_memory_test.py",
    deps = _CUSTOM_GRADIENT_MEMORY_TEST_DEPS,
)