# blob: a3b66fcaafd9138d42a0939a159bfeb1be4cc623 [file] [log] [blame]
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for op_selector.py."""
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import op_selector
from tensorflow.python.platform import test
class SelectTest(test.TestCase):
  """Exercises the graph-selection helpers exposed by op_selector."""

  def setUp(self):
    """Builds the fixture graph consumed by test_backward_walk_ops.

    Data flow: c = a + b, e = c + d, f = c + d, g = c + a and h = f + g,
    where h is created under a control dependency on c's op.
    """
    self.graph = ops_lib.Graph()
    with self.graph.as_default():
      self.a = constant_op.constant([1., 1.], shape=[2], name="a")
      with ops_lib.name_scope("foo"):
        self.b = constant_op.constant([2., 2.], shape=[2], name="b")
        self.c = math_ops.add(self.a, self.b, name="c")
        self.d = constant_op.constant([3., 3.], shape=[2], name="d")
        with ops_lib.name_scope("bar"):
          self.e = math_ops.add(self.c, self.d, name="e")
          self.f = math_ops.add(self.c, self.d, name="f")
          self.g = math_ops.add(self.c, self.a, name="g")
          with ops_lib.control_dependencies([self.c.op]):
            self.h = math_ops.add(self.f, self.g, name="h")

  def test_is_iterable(self):
    """Checks is_iterable on a list and on a plain integer."""
    # A list is reported as iterable ...
    self.assertTrue(op_selector.is_iterable([0, 1, 2]))
    # ... whereas a bare int is not.
    self.assertFalse(op_selector.is_iterable(3))

  def test_unique_graph(self):
    """Checks check_graphs and get_unique_graph with one and two graphs."""
    first_graph = ops_lib.Graph()
    with first_graph.as_default():
      first_a = constant_op.constant(1)
      first_b = constant_op.constant(2)

    second_graph = ops_lib.Graph()
    with second_graph.as_default():
      second_a = constant_op.constant(1)
      second_b = constant_op.constant(2)

    # Tensors from a single graph pass the check, which returns None.
    self.assertIsNone(op_selector.check_graphs(first_a, first_b))

    # Mixing tensors from two different graphs must raise.
    with self.assertRaises(ValueError):
      op_selector.check_graphs(first_a, first_b, second_a, second_b)

    # Both tensors belong to the same graph, so that graph is returned.
    same_graph_tensors = [first_a, first_b]
    self.assertEqual(
        op_selector.get_unique_graph(same_graph_tensors), first_graph)

    # Tensors spread over different graphs have no unique graph.
    mixed_tensors = [first_a, first_b, second_a, second_b]
    with self.assertRaises(ValueError):
      op_selector.get_unique_graph(mixed_tensors)

  def test_unique_graph_func_graph(self):
    """Checks get_unique_graph when a FuncGraph shares the outer graph key."""
    outer = ops_lib.Graph()
    with outer.as_default():
      k1 = constant_op.constant(1)
      inner = func_graph.FuncGraph("inner")
      # Give the function graph the same key as the graph that encloses it.
      inner._graph_key = outer._graph_key
      with inner.as_default():
        k2 = constant_op.constant(2)

    unique_graph = op_selector.get_unique_graph([k1, k2])
    # The comparison is made on graph keys, not on graph identity.
    self.assertEqual(unique_graph._graph_key, inner._graph_key)

  def test_make_list_of_op(self):
    """Checks make_list_of_op with a graph and with a tuple of ops."""
    graph = ops_lib.Graph()
    with graph.as_default():
      one = constant_op.constant(1)
      two = constant_op.constant(2)

    # Given the graph itself, both constant ops are extracted.
    ops_from_graph = op_selector.make_list_of_op(graph)
    self.assertEqual(len(ops_from_graph), 2)

    # Given a tuple of ops, both are kept.
    ops_from_tuple = op_selector.make_list_of_op((one.op, two.op))
    self.assertEqual(len(ops_from_tuple), 2)

  def test_make_list_of_t(self):
    """Checks make_list_of_t with a graph, a tuple and mixed ops/tensors."""
    graph = ops_lib.Graph()
    with graph.as_default():
      one = constant_op.constant(1)
      two = constant_op.constant(2)
      total = math_ops.add(one, two)  # pylint: disable=unused-variable

    # Given the graph itself, all three tensors are extracted.
    tensors_from_graph = op_selector.make_list_of_t(graph)
    self.assertEqual(len(tensors_from_graph), 3)

    # Given a tuple of tensors, both are kept.
    tensors_from_tuple = op_selector.make_list_of_t((one, two))
    self.assertEqual(len(tensors_from_tuple), 2)

    # With ignore_ops=True the op in the tuple is skipped, leaving the tensors.
    tensors_only = op_selector.make_list_of_t(
        (one, one.op, two), ignore_ops=True)
    self.assertEqual(len(tensors_only), 2)

  def test_get_generating_consuming(self):
    """Checks get_generating_ops and get_consuming_ops on a tiny add graph."""
    graph = ops_lib.Graph()
    with graph.as_default():
      one = constant_op.constant(1)
      two = constant_op.constant(2)
      total = math_ops.add(one, two)

    inputs = [one, two]
    # Each constant comes from its own op; both feed the single add op.
    self.assertEqual(len(op_selector.get_generating_ops(inputs)), 2)
    self.assertEqual(len(op_selector.get_consuming_ops(inputs)), 1)

    # The sum is generated by one op and is consumed by nothing.
    self.assertEqual(len(op_selector.get_generating_ops([total])), 1)
    self.assertEqual(op_selector.get_consuming_ops([total]), [])

  def test_backward_walk_ops(self):
    """Checks get_backward_walk_ops while filters are removed one by one."""
    seeds = [self.h.op]

    # The explicit allow-list holds every op except self.g.op.
    allowed_tensors = (self.a, self.b, self.c, self.d, self.e, self.f, self.h)
    allowed_ops = [tensor.op for tensor in allowed_tensors]

    def not_c_op(op):
      # The predicate additionally excludes self.c.op.
      return op not in (self.c.op,)

    stop_tensors = (self.f,)

    def walk(**options):
      # Run the backward walk from the seed ops and return the result as a set.
      return set(op_selector.get_backward_walk_ops(seeds, **options))

    with self.graph.as_default():
      # Backward walk only includes h since we stop at f and g is not within.
      self.assertEqual(
          walk(
              inclusive=True,
              within_ops=allowed_ops,
              within_ops_fn=not_c_op,
              stop_at_ts=stop_tensors),
          {self.h.op})

      # With inclusive=False the seed is dropped too, so the result is empty.
      self.assertEqual(
          walk(
              inclusive=False,
              within_ops=allowed_ops,
              within_ops_fn=not_c_op,
              stop_at_ts=stop_tensors),
          set())

      # Removing stop_at_ts adds f.op and d.op.
      self.assertEqual(
          walk(
              inclusive=True,
              within_ops=allowed_ops,
              within_ops_fn=not_c_op),
          {self.d.op, self.f.op, self.h.op})

      # Not using within_ops_fn adds back the ops for a, b and c.
      self.assertEqual(
          walk(inclusive=True, within_ops=allowed_ops),
          {self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.h.op})

      # Vanilla backward search via self.h.op includes everything except e.op.
      self.assertEqual(
          walk(inclusive=True),
          {
              self.a.op, self.b.op, self.c.op, self.d.op, self.f.op,
              self.g.op, self.h.op
          })
if __name__ == "__main__":
  # Run the tests through the TensorFlow test entry point when executed
  # directly as a script.
  test.main()