| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| # pylint: disable=unidiomatic-typecheck |
| """Utility to lift subgraphs.""" |
| |
| import collections |
| |
| from tensorflow.python.framework import func_graph |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import op_selector |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.util import compat |
| from tensorflow.python.util import object_identity |
| from tensorflow.python.util.tf_export import tf_export |
| |
| |
| UnliftableError = op_selector.UnliftableError |
| |
| |
def _as_operation(op_or_tensor):
  """Returns `op_or_tensor` itself if it is an op, else the tensor's op."""
  if isinstance(op_or_tensor, ops.Tensor):
    return op_or_tensor.op
  return op_or_tensor
| |
| |
def _constant_inputs(op_or_tensor):
  """Returns True if all of the op's inputs are Consts with no control deps."""
  return all(_as_operation(i).type == "Const"
             and not _as_operation(i).control_inputs
             for i in op_selector.graph_inputs(_as_operation(op_or_tensor)))
| |
| |
| # Represents an input to `copied_op` which must be updated once |
| # `old_graph_tensor` has been copied. |
| _InputMutation = collections.namedtuple( |
| "_InputMutation", |
| ["copied_op", "input_index", "old_graph_tensor"]) |
| |
| |
| # Represents a control input to `copied_op` which must be added once |
| # `old_graph_op` has been copied. |
| _ControlMutation = collections.namedtuple( |
| "_ControlMutation", |
| ["copied_op", "old_graph_op"]) |
| |
| |
| def _copy_non_source(op, graph, op_map, base_graph): |
| """Copy an op directly to a given graph. |
| |
| Generally `op`'s inputs should already have been copied. If this is not the |
| case, for example with v1 while_loops, then `_copy_non_source` inserts |
| placeholders for the unavailable Tensors and returns a list of required |
| mutations. |
| |
| Args: |
| op: The op to be copied. |
| graph: The destination graph. |
| op_map: A dict mapping ops and tensors in the old graph to the new one. |
| base_graph: The graph we're copying from, for any necessary functions. |
| Returns: |
| A tuple of (required_inputs, required_control_inputs): |
| required_inputs: |
| A list of `_InputMutation` tuples containing inputs to `copied_op` which |
| must be updated once `old_graph_tensor` has been copied. |
| required_control_inputs: |
| A list of `_ControlMutation` tuples containing control inputs to |
| `copied_op` which must be added once `old_graph_op` has been copied. |
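
  For example (an illustrative case): when copying a v1 while_loop, a Merge
  op may be copied before the NextIteration op that feeds its second input.
  The copy receives a placeholder for that input, and the returned mutations
  tell the caller to rewire it once the NextIteration output has been copied.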
| """ |
| input_mutations = [] |
| control_mutations = [] |
| copied_inputs = [] |
| for input_index, original_input in enumerate(op.inputs): |
| copied_input = op_map.get(original_input, None) |
| if copied_input is None: |
| # An input for this op is missing due to a loop in the graph. We'll insert |
| # a placeholder for now and return information about the required post-hoc |
| # mutation. |
| copied_input = array_ops.placeholder( |
| name="unused_control_flow_input", |
| shape=original_input.shape, |
| dtype=original_input.dtype) |
| input_mutations.append( |
| # `copied_op` is filled in below, after we've created it. |
| _InputMutation(copied_op=None, |
| input_index=input_index, |
| old_graph_tensor=original_input)) |
| copied_inputs.append(copied_input) |
| |
| copied_control_inputs = [] |
| for original_control_input in op.control_inputs: |
| copied_control_input = op_map.get(original_control_input, None) |
| if copied_control_input is None: |
| control_mutations.append( |
| _ControlMutation(copied_op=None, |
| old_graph_op=original_control_input)) |
| else: |
| copied_control_inputs.append(copied_control_input) |
| |
  # Don't copy over the _tpu_replicate attribute. This attribute is used to
  # signal that the op was built inside a tpu_replicate context; if we're
  # lifting the op to another graph we're similarly lifting it into another
  # context.
| with ops.control_dependencies(copied_control_inputs), ops.device(op.device): |
| # pylint: disable=protected-access |
| f = base_graph._functions.get(op.type, None) |
| if f is not None and compat.as_str(f.name) not in graph._functions: |
| f.add_to_graph(graph) |
| # pylint: enable=protected-access |
| |
    # Create a new op in the destination graph if it doesn't already exist.
| copied_op = graph.create_op( |
| op_type=op.type, |
| inputs=copied_inputs, |
| dtypes=[x.dtype for x in op.outputs], |
| attrs={ |
| key: value for key, value in op.node_def.attr.items() |
| if not key.startswith("_class") and |
| not key.startswith("_tpu_replicate") |
| }, # b/128981532. |
| name=op.name) |
| op_map[op] = copied_op |
| for i, o in enumerate(op.outputs): |
| op_map[o] = copied_op.outputs[i] |
| |
| return ([mutation._replace(copied_op=copied_op) |
| for mutation in input_mutations], |
| [mutation._replace(copied_op=copied_op) |
| for mutation in control_mutations]) |
| |
| |
| def _copy_source(s, graph, op_map, handle_captures, inverse_captures, |
| base_graph): |
| """Create a source in a graph based on a Tensor from a different graph. |
| |
| This function creates a placeholder analog of `s` in a graph with the |
| following behavior: |
| |
| 1) If s is a captured Tensor or Variable and handle_captures is set to True, |
| simply capture it in the new graph as well. |
| |
| 2) If s is a PlaceholderWithDefault whose default is a constant, preserve |
| said default in the new graph. |
| |
| 3) When applicable, copy resource variable metadata from `s` to the newly |
| created placeholder. |
| |
| Args: |
| s: The source of interest. |
| graph: The destination graph. |
| op_map: A dict mapping ops and tensors in the old graph to the new one. |
| handle_captures: A boolean indicating whether to re-capture s in the new |
| graph or simply create a vanilla placeholder. |
| inverse_captures: A dict mapping s back to the Tensor or Variable that it |
| captures. |
| base_graph: The graph being copied from. |
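
  Nothing is returned; the created placeholder is recorded in `op_map`, keyed
  by both `s` and `s.op`.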
| """ |
| if handle_captures and s in inverse_captures: |
| copied_placeholder = graph.capture(inverse_captures[s], name=s.op.name) |
| elif s.op.type == "PlaceholderWithDefault" and _constant_inputs(s): |
| # Copy the default value to the graph. |
| default_value = s.op.inputs[0] |
| unavailable_inputs, unavailable_control_inputs = _copy_non_source( |
| op=default_value.op, graph=graph, op_map=op_map, |
| base_graph=base_graph) |
| if unavailable_inputs or unavailable_control_inputs: |
| raise AssertionError( |
| "Could not copy source node {} because it has inputs." |
| .format(default_value)) |
| |
| with ops.device(s.op.device): |
| copied_placeholder = array_ops.placeholder_with_default( |
| input=op_map[default_value], shape=s.shape, name=s.op.name) |
| else: |
| with ops.device(s.op.device): |
| copied_placeholder = array_ops.placeholder( |
| dtype=s.dtype, shape=s.shape, name=s.op.name) |
| |
| base_handle = resource_variable_ops.get_resource_handle_data(s) |
| if base_handle.shape_and_type: |
| resource_variable_ops._set_handle_shapes_and_types( # pylint: disable=protected-access |
| copied_placeholder, |
| base_handle, |
| graph_mode=True) |
| |
| op_map[s] = copied_placeholder |
  # Add an entry for the op of the source tensor so that ops depending on it
  # via control dependencies are remapped correctly.
| op_map[s.op] = copied_placeholder.op |
| |
| |
| @tf_export("__internal__.lift_to_graph", v1=[]) |
| def lift_to_graph(tensors, |
| graph, |
| sources=None, |
| disallowed_placeholders=None, |
| add_sources=False, |
| handle_captures=False, |
| base_graph=None, |
| op_map=None): |
| """Copies the tensor and all its inputs recursively to the outer graph. |
| |
| Args: |
| tensors: The Tensors to lift. |
| graph: The graph to lift to. |
    sources: Optional sequence of nodes to start from. If omitted, the whole
      subgraph which feeds into `tensors` is lifted.
| disallowed_placeholders: An optional set of ops which may not appear in the |
| lifted graph. Defaults to all placeholders. |
    add_sources: A boolean indicating whether placeholders which are not in
      `sources` should be allowed (they are added to `sources` as they are
      encountered).
    handle_captures: A boolean indicating whether to re-capture captured
      tensors in the new graph or simply create vanilla placeholders for them.
| base_graph: The graph from which to lift ops. This will be inferred if not |
| specified. |
    op_map: A map containing all the nodes that have already been lifted to
      the destination graph, so they won't be lifted and copied again.
| |
| Returns: |
    A mapping from ops and tensors in `base_graph` to their copies in `graph`.
| |
| Raises: |
| UnliftableError: If a placeholder blocks lifting. |
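
  Example:

  ```python
  # A minimal sketch of lifting a ConcreteFunction's output into a new graph.
  # Assumes `def_function` (tensorflow.python.eager) and `constant_op`
  # (tensorflow.python.framework) are imported; the names are illustrative.
  fn = def_function.function(lambda: constant_op.constant(1.0) * 2.0)
  inner_graph = fn.get_concrete_function().graph
  outer_graph = ops.Graph()
  op_map = lift_to_graph([inner_graph.outputs[0]], outer_graph)
  lifted = op_map[inner_graph.outputs[0]]  # The copy living in `outer_graph`.
  ```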
| """ |
| variable_init_tensors = [] |
| init_tensors = [] |
| for tensor in tensors: |
| if isinstance(tensor, resource_variable_ops.ResourceVariable): |
| variable_init_tensors.append(tensor) |
| else: |
| init_tensors.append(tensor) |
| base_graph = base_graph or init_tensors[0].graph |
| op_map = op_map or object_identity.ObjectIdentityDictionary() |
| |
  # Check that the tensors to lift do not depend on any disallowed
  # placeholders.
| sources = object_identity.ObjectIdentitySet(sources or []) |
| visited_ops = set(x.op for x in sources) |
| op_outputs = collections.defaultdict(set) |
| |
| # First we extract the subgraph between init_tensors and sources. |
| for init_tensor in init_tensors: |
| sources.update(op_selector.map_subgraph( |
| init_tensor=init_tensor, |
| sources=sources, |
| disallowed_placeholders=disallowed_placeholders, |
| visited_ops=visited_ops, |
| op_outputs=op_outputs, |
| add_sources=add_sources)) |
| |
  # Try to topologically sort the nodes we've extracted, now that we know how
  # many of their outputs are part of this subgraph.
| ops_to_copy = [] |
  marked_ops = set()
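  # Seed the walk with ops that have no consumers inside the extracted
  # subgraph (the sinks of a topological order).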
| ops_to_visit = [_as_operation(t) for t in init_tensors |
| if not op_outputs[_as_operation(t)]] |
| unvisited_ops = set(ops_to_visit) |
| while unvisited_ops: |
| while ops_to_visit: |
| op = ops_to_visit.pop() |
| if op in marked_ops: |
| continue |
| marked_ops.add(op) |
| ops_to_copy.append(op) |
| for inp in op_selector.graph_inputs(op): |
        # Don't lift TPUReplicateMetadata nodes out of the function, because
        # they have no registered kernels.
| if inp.type == "TPUReplicateMetadata": |
| continue |
| unvisited_ops.add(inp) |
| if (all(x in marked_ops for x in op_outputs[inp]) and |
| inp not in sources): |
| ops_to_visit.append(inp) |
| unvisited_ops.difference_update(marked_ops) |
| if unvisited_ops: |
| # `unvisited_ops` should only have elements if the graph has a loop. In |
| # this case we want to keep copying and there's no topological ordering; |
| # we'll do ugly post-hoc mutations instead. |
| ops_to_visit.append(next(iter(unvisited_ops))) |
| |
  # When the topological sort fails due to loops, it can result in exceptions
  # later when copying a node whose inputs haven't been copied yet. We can
  # improve that pseudo-topological order slightly by putting the ops without
  # inputs, such as constants, at the start of the topological order (i.e. at
  # the end of ops_to_copy).
| ops_to_copy.sort(key=(lambda op: len(op_selector.graph_inputs(op)) == 0)) |
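  # `False` sorts before `True` and Python's sort is stable, so the zero-input
  # ops move to the end of `ops_to_copy`; `reversed(ops_to_copy)` below then
  # copies them first.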
| |
| # When lifting from one FuncGraph to another, we will need to capture the |
| # relevant tensors as well. |
| captures = [] |
| inverse_captures = object_identity.ObjectIdentityDictionary() |
| internal_captures = [] |
| if (isinstance(base_graph, func_graph.FuncGraph) and |
| isinstance(graph, func_graph.FuncGraph)): |
| captures = base_graph.captures |
| for external_capture, internal_capture in captures: |
| inverse_captures[internal_capture] = external_capture |
| internal_captures = base_graph.internal_captures |
| |
  # ops_to_copy now holds a reverse topologically sorted list of ops which
  # ends in the ops generating the requested tensors. We copy those to the
  # destination graph and rebuild the lifted tensors there.
| with graph.as_default(): |
| for i in variable_init_tensors: |
| op_map[i] = i |
| source_ops = set() |
    # Add the sources in the same order as the original graph's captures.
| for s in internal_captures: |
| if s in sources: |
| sources.remove(s) |
| source_ops.add(s.op) |
| _copy_source( |
| s=s, |
| graph=graph, |
| op_map=op_map, |
| handle_captures=handle_captures, |
| inverse_captures=inverse_captures, |
| base_graph=base_graph) |
| for s in sources: |
| source_ops.add(s.op) |
| _copy_source( |
| s=s, |
| graph=graph, |
| op_map=op_map, |
| handle_captures=handle_captures, |
| inverse_captures=inverse_captures, |
| base_graph=base_graph) |
| |
| input_mutations = [] |
| control_mutations = [] |
| for op in reversed(ops_to_copy): |
| if op in source_ops or op in op_map: |
| continue |
| new_input_mutations, new_control_mutations = _copy_non_source( |
| op=op, graph=graph, op_map=op_map, base_graph=base_graph) |
| input_mutations.extend(new_input_mutations) |
| control_mutations.extend(new_control_mutations) |
| |
| # Mutate the new graph to insert any loops which existed in the source |
| # graph due to v1 while_loops. |
| # |
| # pylint: disable=protected-access |
| with graph._mutation_lock(): |
| for mutation in input_mutations: |
| mutation.copied_op._update_input( |
| mutation.input_index, op_map[mutation.old_graph_tensor]) |
| for mutation in control_mutations: |
      # Don't lift TPUReplicateMetadata nodes out of the function, because
      # they have no registered kernels.
| if mutation.old_graph_op.type == "TPUReplicateMetadata": |
| continue |
| mutation.copied_op._add_control_input(op_map[mutation.old_graph_op]) |
| # pylint: enable=protected-access |
| |
| return op_map |