| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| |
| """Tests for the currently experimental in-graph batch ops.""" |
| import threading |
| import time |
| import numpy as np |
| |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.python.eager import context |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import function |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.framework.errors import InvalidArgumentError |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import batch_ops |
| from tensorflow.python.ops import gen_batch_ops |
| from tensorflow.python.ops import gen_functional_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import random_ops |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.ops import script_ops |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import test |
| |
| |
| def delayed_plus1(x): |
| """Sleeps for 100ms then returns x+1.""" |
| time.sleep(0.1) |
| return x + 1 |
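

# The tests below all follow the same pattern: feed requests from multiple
# threads, let the batch op concatenate them along dimension 0, run a
# computation on the combined batch, and (optionally) unbatch the results
# back to the callers. The sketch below is illustrative only and is not
# exercised by any test; the parameter values are arbitrary examples.
def _sketch_batch_unbatch_pipeline():
  """Illustrative sketch of the batch -> compute -> unbatch pattern."""
  inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
  batched, index, id_t = batch_ops.batch(
      [inp], num_batch_threads=1, max_batch_size=10,
      batch_timeout_micros=100000, grad_timeout_micros=0,
      batching_queue="")
  computation = batched[0] + 1
  return inp, batch_ops.unbatch(computation, index, id_t,
                                timeout_micros=1000000,
                                shared_name="sketch_unbatch")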
| |
| |
| @test_util.run_all_in_graph_and_eager_modes |
| class BatchOpsTest(test.TestCase): |
| """Tests for batch_ops.{un,}batch.""" |
| |
  # Test only in non-eager mode, since batching as a functionality in an
  # eager context is TBD.
| def testBasicBatch(self): |
| """Tests that a single batched tensor executes together and only once.""" |
| if context.executing_eagerly(): |
| return |
| with self.cached_session() as sess: |
| inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) |
| batched, index, _ = batch_ops.batch( |
| [inp], num_batch_threads=1, max_batch_size=2, |
| batch_timeout_micros=36000000, grad_timeout_micros=0, |
| batching_queue="") |
| thread_results = [] |
| |
| def worker(): |
| thread_results.extend( |
| sess.run([batched, index], feed_dict={inp: [1]})) |
| |
| worker_thread = threading.Thread(target=worker) |
| worker_thread.start() |
| main_results = sess.run([batched, index], feed_dict={inp: [2]}) |
| worker_thread.join() |
| |
      # At this point either the worker thread or the main thread ran the
      # batch, and the other should have empty results.
| if list(thread_results[0][0]): |
| batch_t = thread_results[0][0] |
| index_t = thread_results[1] |
| empty_b = main_results[0][0] |
| empty_m = main_results[1] |
| else: |
| batch_t = main_results[0][0] |
| index_t = main_results[1] |
| empty_b = thread_results[0][0] |
| empty_m = thread_results[1] |
| |
      # Check that both inputs made it out exactly once.
      self.assertAllEqual(sorted(batch_t), [1, 2])
| # Check that we get 2 rows in the index tensor. |
| self.assertEqual(len(index_t), 2) |
| # Check that the other ones are empty. |
| self.assertEqual(len(empty_b), 0) |
| self.assertEqual(len(empty_m), 0) |
| |
| def testBatchWithPadding(self): |
| """Test that batching with padding up to an allowed batch size works.""" |
| if context.executing_eagerly(): |
| return |
| with self.cached_session() as sess: |
| inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) |
| batched, index, _ = batch_ops.batch( |
| [inp], num_batch_threads=1, max_batch_size=10, |
| batch_timeout_micros=100000, # 100ms |
| allowed_batch_sizes=[5, 10], |
| grad_timeout_micros=0, batching_queue="") |
| thread_results = [] |
| |
| def worker(): |
| thread_results.extend( |
| sess.run([batched, index], feed_dict={inp: [1, 3]})) |
| |
| worker_thread = threading.Thread(target=worker) |
| worker_thread.start() |
| main_results = sess.run([batched, index], feed_dict={inp: [2, 4]}) |
| worker_thread.join() |
| |
      # At this point either the worker thread or the main thread ran the
      # batch, and the other should have empty results.
| if list(thread_results[0][0]): |
| batch_t = thread_results[0][0] |
| else: |
| batch_t = main_results[0][0] |
| |
      # Check that the batch tensor incorporates the padding: the two size-2
      # inputs total 4 rows, padded up to the next allowed batch size, 5.
| self.assertEqual(len(batch_t), 5) |
| |
| def testMultipleBatch(self): |
| """Tests that multiple batched tensors execute together.""" |
| if context.executing_eagerly(): |
| return |
| with self.cached_session() as sess: |
| inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) |
| inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) |
| batched, _, _ = batch_ops.batch( |
| [inp0, inp1], |
| num_batch_threads=1, |
| max_batch_size=2, |
| batch_timeout_micros=36000000, |
| grad_timeout_micros=0, |
| batching_queue="") |
| thread_results = [] |
| |
| def worker(): |
| thread_results.extend( |
| sess.run([batched], feed_dict={inp0: [1], |
| inp1: [2]})) |
| |
| worker_thread = threading.Thread(target=worker) |
| worker_thread.start() |
| main_results = sess.run([batched], feed_dict={inp0: [2], inp1: [3]}) |
| worker_thread.join() |
| |
      # At this point either the worker thread or the main thread ran the
      # batch, and the other should have empty results.
| if list(thread_results[0][0]): |
| batch_t = thread_results[0] |
| empty_t = main_results[0] |
| else: |
| batch_t = main_results[0] |
| empty_t = thread_results[0] |
| |
| # Assert that the tensors were batched together. |
| self.assertAllEqual(sorted(batch_t[0]), [1, 2]) |
| self.assertAllEqual(sorted(batch_t[1]), [2, 3]) |
| self.assertAllEqual(empty_t[0], []) |
| self.assertAllEqual(empty_t[1], []) |
| |
| def testIllegalBatchDifferentDim0Sizes(self): |
| """Tests illegally feeding tensors with different dim0 sizes.""" |
| if context.executing_eagerly(): |
| return |
| with self.cached_session() as sess: |
| inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) |
| inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) |
| batched, index, _ = batch_ops.batch( |
| [inp0, inp1], num_batch_threads=1, max_batch_size=2, |
| batch_timeout_micros=0, grad_timeout_micros=0, batching_queue="") |
| with self.assertRaises(Exception) as raised: |
| _ = sess.run([batched, index], feed_dict={inp0: [0], inp1: [1, 2]}) |
        self.assertIn("must have equal 0th-dimension size",
                      str(raised.exception))
| |
| def testBasicUnbatch(self): |
| """Tests that batch and unbatch work together.""" |
| if context.executing_eagerly(): |
| return |
| with self.cached_session() as sess: |
| inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) |
| batched, index, id_t = batch_ops.batch( |
| [inp], num_batch_threads=1, max_batch_size=10, |
| batch_timeout_micros=100000, # 100ms |
| allowed_batch_sizes=[3, 10], |
| grad_timeout_micros=0, batching_queue="") |
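      # `index` records how rows of the batched tensor map back to the
      # original inputs, and `id_t` uniquely identifies this invocation of
      # the batch op; unbatch uses both to route results back to callers.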
| computation = batched[0] + 1 |
| result = batch_ops.unbatch(computation, index, id_t, |
| timeout_micros=1000000, shared_name="unbatch") |
| thread_results = [] |
| |
| def worker(): |
| thread_results.extend(sess.run([result], feed_dict={inp: [1]})) |
| |
| worker_thread = threading.Thread(target=worker) |
| worker_thread.start() |
| main_results = sess.run([result], feed_dict={inp: [2]}) |
| worker_thread.join() |
| self.assertEqual(thread_results[0], [2]) |
| self.assertEqual(main_results[0], [3]) |
| |
| def testBasicUnbatchDecorated(self): |
| """Tests that the batch_function decorator works.""" |
| if context.executing_eagerly(): |
| return |
| with self.cached_session() as sess: |
      # TODO(apassos): Removing this line causes test flakiness; this should
      # be investigated.
| default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable |
| |
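      # The positional arguments to batch_function are num_batch_threads=1,
      # max_batch_size=10, and batch_timeout_micros=100000 (100ms).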
| @batch_ops.batch_function(1, 10, 100000) |
| def computation(in_t): |
        self.assertIsNotNone(in_t.shape)
| return in_t + 1 |
| |
| inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) |
| result = computation(inp) |
| thread_results = [] |
| |
| def worker(): |
| thread_results.extend(sess.run([result], feed_dict={inp: [1]})) |
| |
| worker_thread = threading.Thread(target=worker) |
| worker_thread.start() |
| main_results = sess.run([result], feed_dict={inp: [2]}) |
| worker_thread.join() |
| self.assertEqual(thread_results[0], [2]) |
| self.assertEqual(main_results[0], [3]) |
| |
| def testUnbatchInvalidIdArg(self): |
| """Tests that unbatch work together.""" |
| if context.executing_eagerly(): |
| batched_tensor = constant_op.constant( |
| value=np.random.random(size=(3, 3, 1)), dtype=dtypes.float64) |
| batched_index = constant_op.constant( |
| value=np.random.randint(0, 100, size=(3, 3, 1)), dtype=dtypes.int64) |
| arg_id = constant_op.constant( |
| value=np.random.randint(0, 100, size=(3, 3, 1)), dtype=dtypes.int64) |
| |
| with self.assertRaisesRegex(errors.InvalidArgumentError, |
| "Input id should be scalar;"): |
| batch_ops.unbatch( |
| batched_tensor=batched_tensor, |
| batch_index=batched_index, |
| id=arg_id, |
| timeout_micros=50, |
| container="", |
| shared_name="") |
| |
| def testBatchDecoratedWithCapturedInput(self): |
| """Tests that the batch_function decorator works.""" |
| if context.executing_eagerly(): |
| return |
| with self.cached_session() as sess: |
| captured_inp0 = array_ops.placeholder_with_default(2., shape=[]) |
| captured_inp1 = resource_variable_ops.ResourceVariable(3.) |
| with ops.device("/cpu:0"): |
| captured_inp2 = resource_variable_ops.ResourceVariable(4.) |
| |
| @batch_ops.batch_function(1, 10, 100000) |
| def computation(in_t): |
| return in_t + captured_inp0 + captured_inp1 + captured_inp2 |
| |
| inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1]) |
| result = computation(inp) |
| thread_results = [] |
| |
| def worker(): |
| thread_results.extend(sess.run([result], feed_dict={inp: [1]})) |
| |
| sess.run(variables.global_variables_initializer()) |
| worker_thread = threading.Thread(target=worker) |
| worker_thread.start() |
| main_results = sess.run([result], feed_dict={inp: [2]}) |
| worker_thread.join() |
| self.assertEqual(thread_results[0], [10]) |
| self.assertEqual(main_results[0], [11]) |
| |
| @test_util.disable_xla("DeviceIndex returns sentinel value with XLA") |
| def testBatchDecoratedGpu(self): |
| if context.executing_eagerly(): |
| return |
| with self.cached_session() as sess: |
| |
| @batch_ops.batch_function(1, 10, 100000) |
| def computation(in_t): |
| # index is 0 on CPU and 1 on GPU |
| index = gen_functional_ops.DeviceIndex(device_names=["CPU", "GPU"]) |
| return in_t + math_ops.cast(index, dtypes.float32) |
| |
| inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1]) |
| result = computation(inp) |
| thread_results = [] |
| |
| def worker(): |
| thread_results.extend(sess.run([result], feed_dict={inp: [10.]})) |
| |
| worker_thread = threading.Thread(target=worker) |
| worker_thread.start() |
| main_results = sess.run([result], feed_dict={inp: [20.]}) |
| worker_thread.join() |
| self.assertEqual(thread_results[0], [10 + test_util.is_gpu_available()]) |
| self.assertEqual(main_results[0], [20 + test_util.is_gpu_available()]) |
| |
| def testParallelRunsWithCpuAndGpu(self): |
| # Run multiple instances of a batch function in parallel. This is a |
| # regression test: this used to fail because _Send nodes for one call would |
| # send the tensor to the _Recv node for a different call. |
| if context.executing_eagerly(): |
| return |
| @batch_ops.batch_function(1, 2, 1) |
| def f(x): |
| with ops.device("/GPU:0"): |
| x = x + 1. |
| with ops.device("/CPU:0"): |
| return x + 1 |
| num_calls = 10 |
| placeholders = [array_ops.placeholder(dtypes.float32, shape=(1,)) |
| for _ in range(num_calls)] |
| results = [] |
| for p in placeholders: |
| result = f(p) |
| results.append(result) |
| inputs = [[float(i)] for i in range(num_calls)] |
| expected = [[float(i + 2)] for i in range(num_calls)] |
| with self.session() as sess: |
| outputs = sess.run(results, feed_dict=dict(zip(placeholders, inputs))) |
| self.assertAllEqual(outputs, expected) |
| |
| def testSoftPlacement(self): |
| if context.executing_eagerly(): |
| return |
| |
| @batch_ops.batch_function(1, 10, 100000) |
| def computation(in_t): |
| with ops.device("/GPU:0"): |
| return in_t + 1. |
| |
| inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1]) |
| result = computation(inp) |
| |
| # With soft placement, the function will run even without a GPU |
| config = config_pb2.ConfigProto(allow_soft_placement=True) |
| with self.session(config=config) as sess: |
| sess.run([result], feed_dict={inp: [20.]}) |
| |
| # Without soft placement, the function fails without a GPU due to the |
| # addition explicitly being placed on the GPU |
| config.allow_soft_placement = False |
| with self.session(config=config) as sess: |
| if test_util.is_gpu_available(): |
| sess.run([result], feed_dict={inp: [20.]}) |
| else: |
| with self.assertRaisesRegex(InvalidArgumentError, |
| "Cannot assign a device for operation"): |
| sess.run([result], feed_dict={inp: [20.]}) |
| |
| def testBatchFunctionOp(self): |
| """Tests that the batch_function op works.""" |
| if context.executing_eagerly(): |
| return |
| with self.cached_session() as sess: |
| |
| @function.Defun(dtypes.int32) |
| def computation(in_t): |
| return in_t + 1 |
| |
| inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) |
| result = gen_batch_ops.batch_function( |
| [inp], |
| num_batch_threads=1, |
| max_batch_size=10, |
| batch_timeout_micros=100000, |
| Tout=[dtypes.int32], |
| f=computation, |
| captured_tensors=computation.captured_inputs) |
| thread_results = [] |
| |
| def worker(): |
| thread_results.extend(sess.run([result], feed_dict={inp: [1]})) |
| |
| worker_thread = threading.Thread(target=worker) |
| worker_thread.start() |
| main_results = sess.run([result], feed_dict={inp: [2]}) |
| worker_thread.join() |
| self.assertEqual(thread_results[0], [2]) |
| self.assertEqual(main_results[0], [3]) |
| |
| def testBatchFunctionOpWithCapturedInput(self): |
| """Tests that batch_function op works with captured input.""" |
| if context.executing_eagerly(): |
| return |
| with self.cached_session() as sess: |
| captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) |
| captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) |
| inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) |
| |
| @function.Defun(dtypes.int32) |
| def computation(inp): |
| return inp + captured_inp0 - captured_inp1 |
| |
| result = gen_batch_ops.batch_function( |
| num_batch_threads=1, |
| max_batch_size=10, |
| batch_timeout_micros=100000, # 100ms |
| allowed_batch_sizes=[3, 10], |
| batching_queue="", |
| f=computation, |
| in_tensors=[inp], |
| captured_tensors=computation.captured_inputs, |
| Tout=[o.type for o in computation.definition.signature.output_arg]) |
| |
| thread_results = [] |
| |
| def worker(): |
| thread_results.extend(sess.run([result], feed_dict={inp: [1]})) |
| |
| worker_thread = threading.Thread(target=worker) |
| worker_thread.start() |
| main_results = sess.run([result], feed_dict={inp: [2]}) |
| worker_thread.join() |
| self.assertEqual(thread_results[0], [2]) |
| self.assertEqual(main_results[0], [3]) |
| |
| def testBatchFunctionOpWithInputError(self): |
| """Tests that batch_function op works with error in the inputs.""" |
| if context.executing_eagerly(): |
| return |
| with self.cached_session() as sess: |
| inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) |
| |
| @function.Defun(dtypes.int32, dtypes.int32) |
| def computation(in0, in1): |
| return in0 + in1 |
| |
| result = gen_batch_ops.batch_function( |
| [inp], # computation actually expects 2 inputs. |
| num_batch_threads=1, |
| max_batch_size=10, |
| batch_timeout_micros=100000, # 100ms |
| batching_queue="", |
| f=computation, |
| captured_tensors=computation.captured_inputs, |
| Tout=[o.type for o in computation.definition.signature.output_arg]) |
| |
| with self.assertRaisesRegex( |
| InvalidArgumentError, |
| r"Function takes 2 argument\(s\) but 1 argument\(s\) were passed"): |
| sess.run([result], feed_dict={inp: [2]}) |
| |
| def testBatchFunctionOpWithLargeBatchSplitted(self): |
| """Tests that the batch_function op works with large batch splitted.""" |
| if context.executing_eagerly(): |
| return |
| |
| with self.cached_session() as sess: |
| |
| @function.Defun(dtypes.int32) |
| def computation(in_t): |
| return in_t + 3 |
| |
| inp = array_ops.placeholder(dtype=dtypes.int32) |
| result = gen_batch_ops.batch_function( |
| [inp], |
| num_batch_threads=2, |
| # enable_large_batch_splitting is True, so it's valid as long as |
| # max('allowed_batch_sizes') <= 'max_batch_size'. |
| allowed_batch_sizes=[1, 2], |
| max_batch_size=5, |
| batch_timeout_micros=100000, # 100ms |
| Tout=[dtypes.int32], |
| enable_large_batch_splitting=True, |
| f=computation, |
| captured_tensors=computation.captured_inputs) |
| thread1_results = [] |
| thread2_results = [] |
| |
| # Input sizes of worker1 and main thread are larger than |
| # max(allowed_batch_sizes), while input size of worker2 is smaller. |
| def worker1(): |
| thread1_results.extend( |
| sess.run([result], feed_dict={inp: [5, 6, 7, 8, 9]})) |
| |
| worker_thread1 = threading.Thread(target=worker1) |
| worker_thread1.start() |
| |
| def worker2(): |
| thread2_results.extend(sess.run([result], feed_dict={inp: [10]})) |
| |
| worker_thread2 = threading.Thread(target=worker2) |
| worker_thread2.start() |
| |
| main_results = sess.run([result], feed_dict={inp: [2, 3, 4]}) |
| worker_thread1.join() |
| worker_thread2.join() |
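      # With enable_large_batch_splitting, the size-5 and size-3 inputs are
      # split into chunks no larger than max(allowed_batch_sizes) == 2 before
      # running `computation`, so every element still comes back with 3 added.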
      self.assertAllEqual(thread2_results[0],
                          np.array([13], dtype=np.int32))
      self.assertAllEqual(thread1_results[0],
                          np.array([8, 9, 10, 11, 12], dtype=np.int32))
      self.assertAllEqual(main_results[0],
                          np.array([5, 6, 7], dtype=np.int32))
| |
| def testBasicUnbatchDecoratedWithReshape(self): |
| """Tests that the batch_function decorator works.""" |
| if context.executing_eagerly(): |
| return |
| with self.cached_session() as sess: |
| |
| @batch_ops.batch_function(1, 10, 100000) |
| def computation(in_t): |
| return array_ops.reshape(in_t, [-1]) + 1 |
| |
| inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1]) |
| result = computation(inp) |
| thread_results = [] |
| |
| def worker(): |
| thread_results.extend(sess.run([result], feed_dict={inp: [[1]]})) |
| |
| worker_thread = threading.Thread(target=worker) |
| worker_thread.start() |
| main_results = sess.run([result], feed_dict={inp: [[2]]}) |
| worker_thread.join() |
| self.assertEqual(thread_results[0], [2]) |
| self.assertEqual(main_results[0], [3]) |
| |
| def testUnbatchTimeout(self): |
| """Tests that the unbatch timeout works.""" |
| if context.executing_eagerly(): |
| return |
| with self.cached_session() as sess: |
| inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) |
| batched, index, id_t = batch_ops.batch( |
| [inp], num_batch_threads=1, max_batch_size=2, |
| batch_timeout_micros=36000000, grad_timeout_micros=0, |
| batching_queue="") |
| computation = batched[0] + 1 |
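      # A deliberately tiny timeout, so the non-delayed unbatch below gives up
      # before the delayed pipeline delivers the batched tensor.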
| timeout_micros = 10 |
| result = batch_ops.unbatch(computation, index, id_t, timeout_micros, |
| shared_name="shared_unbatch") |
| # Set up a parallel pipeline that delays the computation, but uses the |
| # same unbatch resource object as the non-delayed pipeline. |
| computation_delayed = script_ops.py_func(delayed_plus1, |
| [batched[0]], |
| dtypes.int32) |
| result_delayed = batch_ops.unbatch(computation_delayed, |
| index, |
| id_t, |
| timeout_micros, |
| shared_name="shared_unbatch") |
| |
| thread_results = [] |
| def worker(): |
| # A first call using the non-delayed pipeline. The batcher will send an |
| # empty tensor along the non-delayed pipeline. |
| thread_results.extend(sess.run([result], feed_dict={inp: [1]})) |
| worker_thread = threading.Thread(target=worker) |
| worker_thread.start() |
| time.sleep(0.1) # Ensure the thread's call starts first. |
| # A second call using the delayed pipeline. The batcher will send the |
| # batched tensor along the delayed pipeline, thus delaying the arrival of |
| # the batched tensor at the unbatch op, relative to the empty tensor. |
| # |
| # TODO(olston, apassos): Avoid relying on the order in which the batch op |
| # emits the empty tensor versus the batched one. |
| _ = sess.run([result_delayed], feed_dict={inp: [2]}) |
| worker_thread.join() |
| # The thread's call should hit the timeout, and thus get 0 results. |
| self.assertEqual(len(thread_results), 0) |
| |
  def testUnbatchGradInvalidId(self):
    """Tests that unbatch_grad raises on an invalid id."""
| with self.assertRaises(errors.InvalidArgumentError): |
| self.evaluate( |
| gen_batch_ops.unbatch_grad( |
| original_input=constant_op.constant([1]), |
              batch_index=constant_op.constant(
                  [[0, 0, 0]], dtype=dtypes.int64),
              grad=constant_op.constant([1]),
              id=constant_op.constant([1, 1], dtype=dtypes.int64)))
| |
  def testUnbatchGradInvalidBatchId(self):
    """Tests that unbatch_grad raises on an invalid batch_index."""
| with self.assertRaises(errors.InvalidArgumentError): |
| self.evaluate( |
| gen_batch_ops.unbatch_grad( |
| original_input=constant_op.constant([1]), |
              batch_index=constant_op.constant([[0, 0]], dtype=dtypes.int64),
              grad=constant_op.constant([1]),
              id=constant_op.constant([1], dtype=dtypes.int64)))
| |
  def testUnbatchGradInvalidArgs(self):
    """Tests that unbatch_grad raises on invalid arguments."""
| original_input = random_ops.random_uniform( |
| shape=(3, 1), dtype=dtypes.float64, maxval=None) |
| batch_index = random_ops.random_uniform( |
| shape=(3, 1), dtype=dtypes.int64, maxval=65536) |
| grad = random_ops.random_uniform( |
| shape=(3, 1), dtype=dtypes.float64, maxval=None) |
| batch_id = random_ops.random_uniform( |
| shape=(3, 1), dtype=dtypes.int64, maxval=65536) |
| with self.assertRaises(errors.InvalidArgumentError): |
| self.evaluate( |
| gen_batch_ops.unbatch_grad( |
| original_input=original_input, |
| batch_index=batch_index, |
| grad=grad, |
| id=batch_id, |
| container="", |
| shared_name="", |
| name="")) |
| |

if __name__ == "__main__":
| test.main() |