| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| # pylint: disable=unidiomatic-typecheck |
| """Prototype decorator for defining legacy-graph-mode functions.""" |
| |
| import weakref |
| |
| from tensorflow.core.function.polymorphism import function_type as function_type_lib |
| from tensorflow.core.protobuf import meta_graph_pb2 |
| from tensorflow.core.protobuf import struct_pb2 |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import function |
| from tensorflow.python.eager import lift_to_graph |
| from tensorflow.python.framework import composite_tensor |
| from tensorflow.python.framework import func_graph |
| from tensorflow.python.framework import importer |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import sparse_tensor |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import tensor_spec |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.saved_model import nested_structure_coder |
| from tensorflow.python.trackable import data_structures |
| from tensorflow.python.util import nest |
| from tensorflow.python.util.tf_export import tf_export |
| |
| |
| class VariableHolder(object): |
| """Holds variables for a python function.""" |
| |
| def __init__(self, fn=None, share_variables=False): |
| self._fn = fn |
| |
| self._share_variables = share_variables |
| self._variables_by_name = data_structures.Mapping() |
| |
| @property |
| def variables(self): |
| return self._variables_by_name |
| |
| def variable_creator_scope(self, next_creator, **kwargs): |
| """Creates variables & adds them to collections to match legacy code.""" |
| collections = kwargs.pop("collections", None) |
| v = None |
| |
| # Get expected variable name. |
| with ops.name_scope( |
| kwargs.get("name", None), "Variable", skip_on_eager=False) as name: |
| variable_name = ops.name_from_scope_name(name) |
| kwargs["name"] = name |
| |
| if self._share_variables: |
| v = self._variables_by_name.get(variable_name, None) |
| |
| if v is None: |
| v = next_creator(**kwargs) |
| self._variables_by_name[variable_name] = v |
| |
| if collections is None: |
| collections = [ops.GraphKeys.GLOBAL_VARIABLES] |
| if v.trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: |
| collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] |
| |
| ops.add_to_collections(collections, v) |
| |
| return v |
| |
| def __call__(self, *args, **kwargs): |
| return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs) |
| |
| def call_with_variable_creator_scope(self, fn): |
| |
| def wrapped(*args, **kwargs): |
| with variable_scope.variable_creator_scope(self.variable_creator_scope): |
| return fn(*args, **kwargs) |
| |
| return wrapped |
| |
| |
| def _get_element_from_tensor_info(tensor_info, graph): |
| """Simplified copy of the deprecated `get_tensor_from_tensor_info`.""" |
| encoding = tensor_info.WhichOneof("encoding") |
| if encoding == "name": |
| # We may get operations here in some cases. TensorInfo is a bit of a |
| # misnomer if so. |
| return graph.as_graph_element(tensor_info.name) |
| elif encoding == "coo_sparse": |
| return sparse_tensor.SparseTensor( |
| graph.get_tensor_by_name(tensor_info.coo_sparse.indices_tensor_name), |
| graph.get_tensor_by_name(tensor_info.coo_sparse.values_tensor_name), |
| graph.get_tensor_by_name( |
| tensor_info.coo_sparse.dense_shape_tensor_name)) |
| elif encoding == "composite_tensor": |
| spec_proto = struct_pb2.StructuredValue( |
| type_spec_value=tensor_info.composite_tensor.type_spec) |
| spec = nested_structure_coder.decode_proto(spec_proto) |
| components = [graph.get_tensor_by_name(component.name) for component in |
| tensor_info.composite_tensor.components] |
| return spec._from_components(components) # pylint: disable=protected-access |
| else: |
| raise ValueError(f"Invalid TensorInfo.encoding: {encoding}. Valid " |
| "encodings are 'name', 'coo_sparse', and " |
| "'composite_tensor'.") |
| |
| |
| def _lift_single_variable(old_variable, graph, variable_holder): |
| """Lifts `old_variable` out of the `FuncGraph` `graph`.""" |
| new_variable = resource_variable_ops.UninitializedVariable( |
| shape=old_variable.shape, |
| dtype=old_variable.dtype, |
| name=old_variable.op.name, |
| trainable=old_variable.trainable, |
| extra_handle_data=old_variable.handle) |
| new_variable._initializer_op = old_variable._initializer_op # pylint: disable=protected-access |
| graph.add_capture(new_variable.handle, old_variable.handle) |
| # Now that we've added the new variable to graph.captures, |
| # graph.capture will use that cached value and do some post-processing |
| # on the capture like recording it on the tape. |
| graph.capture(new_variable.handle) |
| # pylint: disable=protected-access |
| variable_name = new_variable.name.split(":")[0] |
| variable_holder._variables_by_name[variable_name] = new_variable |
| graph._weak_variables.append(weakref.ref(new_variable)) |
| # pylint: enable=protected-access |
| graph.watch_variable(new_variable) |
| return new_variable |
| |
| |
| def _lift_unlifted_variables(graph, variable_holder): |
| """Finds resource variables and lifts them into the outer context. |
| |
| When we import a GraphDef inside a wrap_function, no Python graph building |
| code runs. This means we get VarHandleOps which create variable resources, |
| but no corresponding Python objects. Leaving them like this works but gives |
| the user no way to interact with or modify the variables outside the graph. |
| |
| This method searches for variables and lifts them out as regular variable |
| objects when possible, indicating to the FuncGraph that they are captures. |
| |
| Args: |
| graph: The FuncGraph to lift variables from. |
| variable_holder: A VariableHolder to record the lifted variables in. |
| """ |
| with graph.as_default(): |
| global_collection_variables = ops.get_collection( |
| ops.GraphKeys.GLOBAL_VARIABLES) |
| local_collection_variables = ops.get_collection( |
| ops.GraphKeys.LOCAL_VARIABLES) |
| existing_captures = {id(c) for c in graph.internal_captures} |
| lifted_variables = {} |
| |
| def _should_lift_variable(v): |
| return ((v._in_graph_mode # pylint: disable=protected-access |
| and v.graph.building_function) |
| and isinstance(v, resource_variable_ops.BaseResourceVariable) |
| and id(v.handle) not in existing_captures) |
| |
| for old_variable in global_collection_variables: |
| if _should_lift_variable(old_variable): |
| new_variable = _lift_single_variable( |
| old_variable, graph, variable_holder) |
| lifted_variables[id(old_variable)] = new_variable |
| existing_captures.add(id(old_variable.handle)) |
| |
| for old_variable in local_collection_variables: |
| if _should_lift_variable(old_variable): |
| new_variable = _lift_single_variable( |
| old_variable, graph, variable_holder) |
| lifted_variables[id(old_variable)] = new_variable |
| existing_captures.add(id(old_variable.handle)) |
| if new_variable._in_graph_mode: # pylint: disable=protected-access |
| outer_graph = new_variable.graph |
| # Variables are added to the global collection by default. In this |
| # case we only want the variable in the local collection, so we'll pop |
| # it out. |
| global_collection = outer_graph.get_collection_ref( |
| ops.GraphKeys.GLOBAL_VARIABLES) |
| global_collection.remove(new_variable) |
| outer_graph.add_to_collection( |
| ops.GraphKeys.LOCAL_VARIABLES, new_variable) |
| |
| # Update the FuncGraph's collections, partly for the user and partly so this |
| # function is idempotent when it runs again in prune() calls. |
| for collection_name in [ |
| ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES |
| ]: |
| mutable_collection = ops.get_collection_ref(collection_name) |
| for index, current in enumerate(mutable_collection): |
| mutable_collection[index] = lifted_variables.get(id(current), current) |
| if not resource_variable_ops.is_resource_variable( |
| mutable_collection[index]): |
| logging.log_first_n( |
| logging.WARN, |
| "Unable to create a python object for variable {} because it is " |
| "a reference variable. It may not be visible to training APIs. " |
| "If this is a problem, consider rebuilding the SavedModel after " |
| "running tf.compat.v1.enable_resource_variables().".format( |
| mutable_collection[index]), |
| 5) |
| |
| |
| # TODO(allenl): make this trackable |
| class WrappedFunction(function.ConcreteFunction): |
| """Wraps a tf V1 piece of code in a function.""" |
| |
| def __init__(self, fn_graph, variable_holder, attrs=None, signature=None): |
| self._variable_holder = variable_holder |
| _lift_unlifted_variables(fn_graph, variable_holder) |
| # We call __init__ after lifting variables so that the function's signature |
| # properly reflects the new captured inputs. |
| for f in fn_graph.as_graph_def().library.function: |
| context.context().add_function_def(f) |
| self._signature = signature |
| function_type = function_type_lib.from_structured_signature( |
| fn_graph.structured_input_signature, |
| fn_graph.structured_outputs, |
| fn_graph.function_captures.capture_types, |
| ) |
| super().__init__( |
| fn_graph, attrs=attrs, function_type=function_type |
| ) |
| |
| def _call_impl(self, args, kwargs): |
| if self._arg_keywords is None: |
| if kwargs: |
| raise NotImplementedError( |
| "Keyword arguments are not supported when calling a " |
| f"wrap_function-decorated function. Got {kwargs}.") |
| if self._signature is not None: |
| args = list(args) |
| for i, arg in enumerate(args): |
| if isinstance(self._signature[i], tensor_spec.DenseSpec): |
| args[i] = ops.convert_to_tensor(arg, self._signature[i].dtype) |
| return self._call_flat(args, self.captured_inputs) |
| else: |
| return super()._call_impl(args, kwargs) |
| |
| def prune(self, feeds, fetches, name=None, input_signature=None): |
| """Extract a subgraph of this function's underlying graph. |
| |
| Wraps the subgraph in a new `WrappedFunction` object. |
| |
| Args: |
| feeds: Input tensors to the subgraph to extract, as `Tensor` objects. |
| fetches: Possibly-nested Python data structure containing information |
| about outputs of the target subgraph. Each entry can either be a |
| `Tensor` object (for data outputs), an `Operation` object (for control |
| outputs), or a `TensorInfo` proto. Any additional shape/dtype |
| information provided in a `TensorInfo` and not present in the original |
| graph will be added to the returned subgraph. |
| name: (optional) Name to give to the underlying `FuncGraph` of the |
| returned object. If no name is provided, the graph's name will be |
| `"pruned"`. |
| input_signature: (optional) possibly-nested Python data structure |
| containing `TensorSpec` objects, with which to populate the returned |
| functions's `FuncGraph`'s `structured_input_signature` field. |
| |
| Returns: |
| A new `WrappedFunction` object containing a copy of the portion of this |
| object's graph that goes from `feeds` to `fetches`. |
| """ |
| # TODO(b/129646028): Add support for CompositeTensors. |
| name = name or "pruned" |
| flat_feeds = nest.flatten(feeds, expand_composites=True) |
| flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds] |
| for f in flat_feeds: |
| if not isinstance(f, ops.Tensor): |
| raise ValueError("All memebers of argument `feeds` must be tensors. " |
| f"Got {f} with type {type(f)}.") |
| |
| # Ignoring all feeds that are captures allows prune to be called |
| # using wrapped_func.inputs even when it uses variables |
| internal_captures = {id(c) for c in self.graph.internal_captures} |
| flat_feeds = [f for f in flat_feeds if id(f) not in internal_captures] |
| |
| operation_fetches = [] |
| tensor_fetches = [] |
| tensor_infos = [] |
| |
| def _fetch_preprocessing_callback(fetch): |
| """Extract out lists of ops, tensors, and tensor type info. |
| |
| Turns TensorInfos into Tensors in the original `fetches` structure. |
| Also extracts ops from `fetches`. |
| |
| Args: |
| fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or |
| string identifying a Tensor or Operation. |
| |
| Returns: |
| `fetch` converted to a Tensor. |
| """ |
| if isinstance(fetch, ops.Operation): |
| operation_fetches.append(fetch) |
| return fetch |
| elif isinstance(fetch, meta_graph_pb2.TensorInfo): |
| tensor_infos.append(fetch) |
| decoded = _get_element_from_tensor_info(fetch, self._func_graph) |
| if (tensor_util.is_tf_type(decoded) or |
| isinstance(decoded, composite_tensor.CompositeTensor)): |
| tensor_fetches.append(decoded) |
| else: |
| operation_fetches.append(decoded) |
| return decoded |
| elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)): |
| tensor_fetches.append(fetch) |
| return fetch |
| else: |
| graph_element = self.graph.as_graph_element(fetch) |
| return _fetch_preprocessing_callback(graph_element) |
| |
| fetches = nest.map_structure(_fetch_preprocessing_callback, fetches) |
| |
| # Expand composite tensors into their component dense Tensors. |
| tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True) |
| |
| for f in flat_feeds + tensor_fetches + operation_fetches: |
| if f.graph is not self._func_graph: |
| raise ValueError("Can only prune function whose feeds and fetches " |
| f"from graph {self._func_graph}. Input " |
| f"{f} is from a different graph {f.graph}.") |
| with self._func_graph.as_default(): |
| pruned_graph = func_graph.FuncGraph(name) |
| lift_map = lift_to_graph.lift_to_graph( |
| operation_fetches + tensor_fetches, |
| pruned_graph, |
| sources=flat_feeds + self.graph.internal_captures, |
| base_graph=self._func_graph) |
| |
| # Note that we add the component tensors of any composite tensors to the |
| # returned function's outputs list; the list must contain these component |
| # tensors, or the function's sparse outputs won't work properly. |
| pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches) |
| pruned_graph.control_outputs.extend( |
| [lift_map[operation] for operation in operation_fetches]) |
| pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds) |
| for external_capture, internal_capture in self.graph.captures: |
| pruned_graph.add_capture(external_capture, lift_map[internal_capture]) |
| for ti in tensor_infos: |
| if ti.WhichOneof("encoding") == "name": # Dense tensors only |
| t = pruned_graph.as_graph_element(ti.name) |
| if tensor_util.is_tf_type(t): |
| t.set_shape(tensor_shape.TensorShape(ti.tensor_shape)) |
| # pylint: disable=protected-access |
| for f in self.graph._functions.values(): |
| pruned_graph._add_function(f) |
| # pylint: enable=protected-access |
| |
| pruned_graph.variables = self.graph.variables |
| |
| def _structured_output_mapping(fetched): |
| """callback for `nest.map_structure()`""" |
| lifted = lift_map[fetched] |
| if isinstance(lifted, ops.Operation): |
| return None |
| return lifted |
| |
| # expand_composites=True here causes composite tensors to be expanded |
| # into their component dense Tensors, mapped to the new graph, and then |
| # reconstituted into their original composite form. |
| pruned_graph.structured_outputs = nest.map_structure( |
| _structured_output_mapping, fetches, expand_composites=True) |
| |
| if input_signature: |
| # canonicalize the signature before setting |
| args, kwargs = input_signature |
| args = () if args is None else args |
| input_signature = (args, kwargs) |
| |
| pruned_graph.structured_input_signature = input_signature |
| pruned_fn = WrappedFunction( |
| pruned_graph, variable_holder=self._variable_holder) |
| pruned_fn._num_positional_args = len(flat_feeds) # pylint: disable=protected-access |
| # TODO(kathywu): Enable keyword arguments if an input signature is specified |
| pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds] # pylint: disable=protected-access |
| return pruned_fn |
| |
| |
| def _filter_returned_ops(fn): |
| """Filtering out any ops returned by function. |
| |
| Args: |
| fn: a function |
| |
| Returns: |
| A tuple of ( |
| Wrapped function that returns `None` in place of any ops, |
| dict that maps the index in the flat output structure to the returned op |
| ) |
| """ |
| returned_ops = {} |
| |
| def wrap_and_filter_returned_ops(*args, **kwargs): |
| outputs = fn(*args, **kwargs) |
| flat_outputs = nest.flatten(outputs) |
| for n in range(len(flat_outputs)): |
| output = flat_outputs[n] |
| if isinstance(output, ops.Operation): |
| returned_ops[n] = output |
| flat_outputs[n] = None |
| return nest.pack_sequence_as(outputs, flat_outputs) |
| |
| return wrap_and_filter_returned_ops, returned_ops |
| |
| |
| class WrappedGraph(object): |
| """Class for wrapping multiple TF 1.X functions in a single graph. |
| |
| Maintains a dictionary mapping names to wrapped functions. See |
| `tf.compat.v1.wrap_function` to learn more about wrapping V1 functions. |
| |
| Functions wrapped using this class have access to variables and collections |
| created in other wrapped functions, using the standard TF 1.X API ( |
| `tf.compat.v1.get_variable` or |
| `tf.compat.v1.get_default_graph().get_collection(...)`) |
| |
| Outside a function, variables and collections may be accessed using the |
| `variables` and `graph` properties. |
| |
| Example: |
| |
| ``` |
| def add_v1(x): |
| with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE): |
| v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32) |
| return v + x |
| |
| def increment_var_v1(x): |
| with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE): |
| v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32) |
| return v.assign_add(x) |
| |
| g = WrappedGraph() |
| add = g.wrap_function(add_v1, [tf.TensorSpec([], tf.int32)]) |
| increment_var = g.wrap_function(increment_var_v1, |
| [tf.TensorSpec([], tf.int32)]) |
| |
| assert len(g.variables) == 1 |
| assert g.variables[0].numpy() == 0 |
| increment_var(tf.constant(5)) |
| assert g.variables[0].numpy() == 5 |
| |
| ``` |
| """ |
| |
| def __init__(self, variable_holder=None, **kwargs): |
| self._variable_holder = ( |
| variable_holder or VariableHolder(share_variables=True)) |
| |
| name = kwargs.pop("name", "wrapped_function_graph") |
| # Always start with empty collections, unless otherwise specified. Setting |
| # `collections=None` will copy the collections from the outer graph. |
| collections = kwargs.pop("collections", {}) |
| self.graph = func_graph.FuncGraph(name, collections=collections, **kwargs) |
| |
| self._wrapped_function = WrappedFunction(self.graph, self._variable_holder) |
| self._functions = {} |
| |
| @property |
| def functions(self): |
| return self._functions |
| |
| @property |
| def variables(self): |
| return self._variable_holder.variables |
| |
| def wrap_function(self, fn, signature, name=None): |
| """Wraps a TF 1.X function and returns an eager-compatible function. |
| |
| All functions wrapped in the same `WrappedGraph` will have access to the |
| same graph (`tf.compat.v1.get_default_graph` to get the graph object |
| within a function, or `WrappedGraph.graph` to get the graph outside a |
| function). Variables created within the function will be added to the |
| `variables` list. |
| |
| Function inputs: All inputs to the function must be tensors (nested ok), |
| with their shapes and dtypes defined in the `signature` argument. |
| |
| Function outputs: |
| |
| * The 1.X function may return tensors, variables, and ops. The wrapped |
| eager-compatible function will always return tensors in the same nested |
| structure. |
| * Variables are replaced with a tensor containing the latest read values. |
| * Returned ops are executed, and replaced with None. |
| * The order of op execution and variable reads in the return is |
| nondeterministic. For example: |
| |
| ``` |
| def update_var(x): |
| v = tf.Variable(0) |
| op = tf.compat.v1.assign(v, x).op |
| return v, op |
| |
| g = WrappedGraph() |
| fn = g.wrap_function(update_var) |
| read_value, _ = fn(tf.constant(3)) |
| print(read_value.numpy()) # could be 0 or 3 |
| print(g.variables[0].numpy()) # always 3 |
| ``` |
| |
| To ensure that ops in the function are executed (e.g. ops added to the |
| `tf.GraphKeys.UPDATE_OPS` collection), include them in the function returns. |
| |
| Args: |
| fn: a 1.X tensorflow function. |
| signature: a possibly nested sequence of `TensorSpecs` specifying the |
| shapes and dtypes of the arguments. |
| name: an optional string name for the function. The function will be saved |
| with key `name` in the `functions` dictionary. |
| |
| Returns: |
| An eager-compatible function. |
| """ |
| return self._wrap_function(fn, signature=signature, name=name) |
| |
| def _wrap_function(self, |
| fn, |
| args=None, |
| kwargs=None, |
| signature=None, |
| name=None): |
| """Internal wrap function method with extended func_graph arguments.""" |
| fn_with_filter_and_scope, returned_ops = _filter_returned_ops( |
| self._variable_holder.call_with_variable_creator_scope(fn)) |
| |
| func_graph.func_graph_from_py_func( |
| None, # Name is unused. |
| fn_with_filter_and_scope, |
| args=args, |
| kwargs=kwargs, |
| signature=signature, |
| add_control_dependencies=False, |
| func_graph=self.graph) |
| |
| # This code relies on questional behavior from `func_graph_from_py_func`. |
| # If an existing FuncGraph is passed into the `func_graph` arg, the inputs |
| # and structured outputs are overwritten. Pretty sure this is a bug, |
| # because structured outputs doesn't match up with the outputs... |
| fn_inputs = self.graph.inputs[:-len(self.graph.captures)] |
| |
| # Return filtered ops to the flattened outputs. |
| flat_fn_outputs = nest.flatten(self.graph.structured_outputs) |
| for index, op in returned_ops.items(): |
| flat_fn_outputs[index] = op |
| fn_outputs = nest.pack_sequence_as(self.graph.structured_outputs, |
| flat_fn_outputs) |
| |
| name = name or fn.__name__ |
| wrapped_function = self._wrapped_function.prune( |
| fn_inputs, fn_outputs, name, self.graph.structured_input_signature) |
| self._functions[name] = wrapped_function |
| return wrapped_function |
| |
| |
| @tf_export(v1=["wrap_function"]) |
| def wrap_function(fn, signature, name=None): |
| """Wraps the TF 1.x function fn into a graph function. |
| |
| The python function `fn` will be called once with symbolic arguments specified |
| in the `signature`, traced, and turned into a graph function. Any variables |
| created by `fn` will be owned by the object returned by `wrap_function`. The |
| resulting graph function can be called with tensors which match the |
| signature. |
| |
| ```python |
| def f(x, do_add): |
| v = tf.Variable(5.0) |
| if do_add: |
| op = v.assign_add(x) |
| else: |
| op = v.assign_sub(x) |
| with tf.control_dependencies([op]): |
| return v.read_value() |
| |
| f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True]) |
| |
| assert float(f_add(1.0)) == 6.0 |
| assert float(f_add(1.0)) == 7.0 |
| |
| # Can call tf.compat.v1.wrap_function again to get a new trace, a new set |
| # of variables, and possibly different non-template arguments. |
| f_sub= tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False]) |
| |
| assert float(f_sub(1.0)) == 4.0 |
| assert float(f_sub(1.0)) == 3.0 |
| ``` |
| |
| Both `tf.compat.v1.wrap_function` and `tf.function` create a callable |
| TensorFlow graph. But while `tf.function` runs all stateful operations |
| (e.g. `tf.print`) and sequences operations to provide the same semantics as |
| eager execution, `wrap_function` is closer to the behavior of `session.run` in |
| TensorFlow 1.x. It will not run any operations unless they are required to |
| compute the function's outputs, either through a data dependency or a control |
| dependency. Nor will it sequence operations. |
| |
| Unlike `tf.function`, `wrap_function` will only trace the Python function |
| once. As with placeholders in TF 1.x, shapes and dtypes must be provided to |
| `wrap_function`'s `signature` argument. |
| |
| Since it is only traced once, variables and state may be created inside the |
| function and owned by the function wrapper object. |
| |
| Args: |
| fn: python function to be wrapped |
| signature: the placeholder and python arguments to be passed to the wrapped |
| function |
| name: Optional. The name of the function. |
| |
| Returns: |
| the wrapped graph function. |
| """ |
| holder = VariableHolder(fn) |
| func_graph_name = "wrapped_function" |
| if name is not None: |
| func_graph_name = "wrapped_function_" + name |
| return WrappedFunction( |
| func_graph.func_graph_from_py_func( |
| func_graph_name, |
| holder, |
| args=None, |
| kwargs=None, |
| signature=signature, |
| add_control_dependencies=False, |
| collections={}), |
| variable_holder=holder, |
| signature=signature) |
| |
| |
| def function_from_graph_def(graph_def, inputs, outputs, captures=None): |
| """Creates a ConcreteFunction from a GraphDef. |
| |
| Args: |
| graph_def: A GraphDef to make a function out of. |
| inputs: A Tensor name or nested structure of names in `graph_def` which |
| should be inputs to the function. |
| outputs: A Tensor name or nested structure of names in `graph_def` which |
| should be outputs of the function. |
| captures: (Optional) A dictionary mapping node names in `graph_def` that |
| should be captured as inputs to tensors containing the value of the |
| captured inputs. |
| |
| Returns: |
| A ConcreteFunction. |
| """ |
| |
| def _imports_graph_def(): |
| importer.import_graph_def(graph_def, name="") |
| graph = ops.get_default_graph() |
| if captures is not None: |
| for c in captures: |
| graph.add_capture(captures[c], graph.get_tensor_by_name(str(c) + ":0")) |
| |
| wrapped_import = wrap_function(_imports_graph_def, []) |
| import_graph = wrapped_import.graph |
| return wrapped_import.prune( |
| nest.map_structure(import_graph.as_graph_element, inputs), |
| nest.map_structure(import_graph.as_graph_element, outputs)) |