# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
16"""Prototype decorator for defining legacy-graph-mode functions."""
17
18import weakref
19
20from tensorflow.core.protobuf import meta_graph_pb2
21from tensorflow.core.protobuf import struct_pb2
22from tensorflow.python.eager import context
23from tensorflow.python.eager import function
24from tensorflow.python.eager import lift_to_graph
25from tensorflow.python.framework import composite_tensor
26from tensorflow.python.framework import func_graph
27from tensorflow.python.framework import importer
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import sparse_tensor
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.framework import tensor_spec
32from tensorflow.python.framework import tensor_util
33from tensorflow.python.ops import resource_variable_ops
34from tensorflow.python.ops import variable_scope
35from tensorflow.python.platform import tf_logging as logging
36from tensorflow.python.saved_model import nested_structure_coder
37from tensorflow.python.trackable import data_structures
38from tensorflow.python.util import nest
39from tensorflow.python.util.tf_export import tf_export
40
41
class VariableHolder(object):
  """Holds variables for a python function."""

  def __init__(self, fn=None, share_variables=False):
    self._fn = fn

    self._share_variables = share_variables
    self._variables_by_name = data_structures.Mapping()

  @property
  def variables(self):
    return self._variables_by_name

  def variable_creator_scope(self, next_creator, **kwargs):
    """Creates variables & adds them to collections to match legacy code."""
    collections = kwargs.pop("collections", None)
    v = None

    # Get expected variable name.
    with ops.name_scope(
        kwargs.get("name", None), "Variable", skip_on_eager=False) as name:
      variable_name = ops.name_from_scope_name(name)
      kwargs["name"] = name

    if self._share_variables:
      v = self._variables_by_name.get(variable_name, None)

    if v is None:
      v = next_creator(**kwargs)
      self._variables_by_name[variable_name] = v

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if v.trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]

    ops.add_to_collections(collections, v)

    return v

  def __call__(self, *args, **kwargs):
    return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs)

  def call_with_variable_creator_scope(self, fn):

    def wrapped(*args, **kwargs):
      with variable_scope.variable_creator_scope(self.variable_creator_scope):
        return fn(*args, **kwargs)

    return wrapped

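# Usage sketch (illustrative; mirrors how `wrap_function` below drives this
# class): tracing a Python function under the holder's variable creator scope
# records every variable the function creates, keyed by name.
#
#   def fn(x):
#     v = variable_scope.get_variable("v", shape=[], dtype=x.dtype)
#     return v + x
#
#   holder = VariableHolder(fn)
#   graph = func_graph.func_graph_from_py_func(
#       "wrapped", holder, args=None, kwargs=None,
#       signature=[tensor_spec.TensorSpec([], "float32")],
#       add_control_dependencies=False)
#   assert "v" in holder.variables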

def _get_element_from_tensor_info(tensor_info, graph):
  """Simplified copy of the deprecated `get_tensor_from_tensor_info`."""
  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "name":
    # We may get operations here in some cases. TensorInfo is a bit of a
    # misnomer if so.
    return graph.as_graph_element(tensor_info.name)
  elif encoding == "coo_sparse":
    return sparse_tensor.SparseTensor(
        graph.get_tensor_by_name(tensor_info.coo_sparse.indices_tensor_name),
        graph.get_tensor_by_name(tensor_info.coo_sparse.values_tensor_name),
        graph.get_tensor_by_name(
            tensor_info.coo_sparse.dense_shape_tensor_name))
  elif encoding == "composite_tensor":
    spec_proto = struct_pb2.StructuredValue(
        type_spec_value=tensor_info.composite_tensor.type_spec)
    spec = nested_structure_coder.decode_proto(spec_proto)
    components = [graph.get_tensor_by_name(component.name) for component in
                  tensor_info.composite_tensor.components]
    return spec._from_components(components)  # pylint: disable=protected-access
  else:
    raise ValueError(f"Invalid TensorInfo.encoding: {encoding}. Valid "
                     "encodings are 'name', 'coo_sparse', and "
                     "'composite_tensor'.")

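# Example (a minimal sketch): with the plain `name` encoding, the lookup
# reduces to `graph.as_graph_element`:
#
#   info = meta_graph_pb2.TensorInfo(name="x:0")
#   tensor = _get_element_from_tensor_info(info, graph)  # the tensor "x:0"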

def _lift_single_variable(old_variable, graph, variable_holder):
  """Lifts `old_variable` out of the `FuncGraph` `graph`."""
  new_variable = resource_variable_ops.UninitializedVariable(
      shape=old_variable.shape,
      dtype=old_variable.dtype,
      name=old_variable.op.name,
      trainable=old_variable.trainable,
      extra_handle_data=old_variable.handle)
  new_variable._initializer_op = old_variable._initializer_op  # pylint: disable=protected-access
  graph.add_capture(new_variable.handle, old_variable.handle)
  # Now that we've added the new variable to graph.captures,
  # graph.capture will use that cached value and do some post-processing
  # on the capture like recording it on the tape.
  graph.capture(new_variable.handle)
  # pylint: disable=protected-access
  variable_name = new_variable.name.split(":")[0]
  variable_holder._variables_by_name[variable_name] = new_variable
  graph._weak_variables.append(weakref.ref(new_variable))
  # pylint: enable=protected-access
  graph.watch_variable(new_variable)
  return new_variable


def _lift_unlifted_variables(graph, variable_holder):
  """Finds resource variables and lifts them into the outer context.

  When we import a GraphDef inside a wrap_function, no Python graph building
  code runs. This means we get VarHandleOps which create variable resources,
  but no corresponding Python objects. Leaving them like this works but gives
  the user no way to interact with or modify the variables outside the graph.

  This method searches for variables and lifts them out as regular variable
  objects when possible, indicating to the FuncGraph that they are captures.

  Args:
    graph: The FuncGraph to lift variables from.
    variable_holder: A VariableHolder to record the lifted variables in.
  """
  with graph.as_default():
    global_collection_variables = ops.get_collection(
        ops.GraphKeys.GLOBAL_VARIABLES)
    local_collection_variables = ops.get_collection(
        ops.GraphKeys.LOCAL_VARIABLES)
    existing_captures = {id(c) for c in graph.internal_captures}
    lifted_variables = {}

    def _should_lift_variable(v):
      return ((v._in_graph_mode  # pylint: disable=protected-access
               and v.graph.building_function)
              and isinstance(v, resource_variable_ops.BaseResourceVariable)
              and id(v.handle) not in existing_captures)

    for old_variable in global_collection_variables:
      if _should_lift_variable(old_variable):
        new_variable = _lift_single_variable(
            old_variable, graph, variable_holder)
        lifted_variables[id(old_variable)] = new_variable
        existing_captures.add(id(old_variable.handle))

    for old_variable in local_collection_variables:
      if _should_lift_variable(old_variable):
        new_variable = _lift_single_variable(
            old_variable, graph, variable_holder)
        lifted_variables[id(old_variable)] = new_variable
        existing_captures.add(id(old_variable.handle))
        if new_variable._in_graph_mode:  # pylint: disable=protected-access
          outer_graph = new_variable.graph
          # Variables are added to the global collection by default. In this
          # case we only want the variable in the local collection, so we'll pop
          # it out.
          global_collection = outer_graph.get_collection_ref(
              ops.GraphKeys.GLOBAL_VARIABLES)
          global_collection.remove(new_variable)
          outer_graph.add_to_collection(
              ops.GraphKeys.LOCAL_VARIABLES, new_variable)

    # Update the FuncGraph's collections, partly for the user and partly so this
    # function is idempotent when it runs again in prune() calls.
    for collection_name in [
        ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES
    ]:
      mutable_collection = ops.get_collection_ref(collection_name)
      for index, current in enumerate(mutable_collection):
        mutable_collection[index] = lifted_variables.get(id(current), current)
        if not resource_variable_ops.is_resource_variable(
            mutable_collection[index]):
          logging.log_first_n(
              logging.WARN,
              "Unable to create a python object for variable {} because it is "
              "a reference variable. It may not be visible to training APIs. "
              "If this is a problem, consider rebuilding the SavedModel after "
              "running tf.compat.v1.enable_resource_variables().".format(
                  mutable_collection[index]),
              5)

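# Illustrative sketch of the scenario described in the docstring above
# (`graph_def_with_variables` is a hypothetical GraphDef containing
# VarHandleOps): importing it inside `wrap_function` runs
# `_lift_unlifted_variables` via `WrappedFunction.__init__`, so Python
# variable objects become available afterwards:
#
#   def _import():
#     importer.import_graph_def(graph_def_with_variables, name="")
#   wrapped = wrap_function(_import, [])
#   wrapped.graph.variables  # the lifted resource variables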

# TODO(allenl): make this trackable
class WrappedFunction(function.ConcreteFunction):
  """Wraps a TF 1.x piece of code in a function."""

  def __init__(self, fn_graph, variable_holder, attrs=None, signature=None):
    self._variable_holder = variable_holder
    _lift_unlifted_variables(fn_graph, variable_holder)
    # We call __init__ after lifting variables so that the function's signature
    # properly reflects the new captured inputs.
    for f in fn_graph.as_graph_def().library.function:
      context.context().add_function_def(f)
    self._signature = signature
    super(WrappedFunction, self).__init__(fn_graph, attrs=attrs)

  def _call_impl(self, args, kwargs, cancellation_manager=None):
    if self._arg_keywords is None:
      if kwargs:
        raise NotImplementedError(
            "Keyword arguments are not supported when calling a "
            f"wrap_function-decorated function. Got {kwargs}.")
      if self._signature is not None:
        args = list(args)
        for i, arg in enumerate(args):
          if isinstance(self._signature[i], tensor_spec.DenseSpec):
            args[i] = ops.convert_to_tensor(arg, self._signature[i].dtype)
      return self._call_flat(args, self.captured_inputs)
    else:
      return super(WrappedFunction, self)._call_impl(
          args, kwargs, cancellation_manager)

  def prune(self, feeds, fetches, name=None, input_signature=None):
    """Extract a subgraph of this function's underlying graph.

    Wraps the subgraph in a new `WrappedFunction` object.

    Args:
      feeds: Input tensors to the subgraph to extract, as `Tensor` objects.
      fetches: Possibly-nested Python data structure containing information
        about outputs of the target subgraph. Each entry can either be a
        `Tensor` object (for data outputs), an `Operation` object (for control
        outputs), or a `TensorInfo` proto. Any additional shape/dtype
        information provided in a `TensorInfo` and not present in the original
        graph will be added to the returned subgraph.
      name: (optional) Name to give to the underlying `FuncGraph` of the
        returned object. If no name is provided, the graph's name will be
        `"pruned"`.
      input_signature: (optional) possibly-nested Python data structure
        containing `TensorSpec` objects, with which to populate the returned
        function's `FuncGraph`'s `structured_input_signature` field.

    Returns:
      A new `WrappedFunction` object containing a copy of the portion of this
        object's graph that goes from `feeds` to `fetches`.
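
    Example (a minimal sketch; assumes `wrapped_fn` is a `WrappedFunction`
    whose graph contains tensors named "x:0" and "y:0"):

    ```python
    pruned = wrapped_fn.prune(
        feeds=wrapped_fn.graph.get_tensor_by_name("x:0"),
        fetches=wrapped_fn.graph.get_tensor_by_name("y:0"))
    result = pruned(tf.constant(1.0))
    ```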
269    """
270    # TODO(b/129646028): Add support for CompositeTensors.
271    name = name or "pruned"
272    flat_feeds = nest.flatten(feeds, expand_composites=True)
273    flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds]
274    for f in flat_feeds:
275      if not isinstance(f, ops.Tensor):
276        raise ValueError("All memebers of argument `feeds` must be tensors. "
277                         f"Got {f} with type {type(f)}.")
278
279    # Ignoring all feeds that are captures allows prune to be called
280    # using wrapped_func.inputs even when it uses variables
281    internal_captures = {id(c) for c in self.graph.internal_captures}
282    flat_feeds = [f for f in flat_feeds if id(f) not in internal_captures]
283
284    operation_fetches = []
285    tensor_fetches = []
286    tensor_infos = []
287
288    def _fetch_preprocessing_callback(fetch):
289      """Extract out lists of ops, tensors, and tensor type info.
290
291      Turns TensorInfos into Tensors in the original `fetches` structure.
292      Also extracts ops from `fetches`.
293
294      Args:
295        fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or
296          string identifying a Tensor or Operation.
297
298      Returns:
299        `fetch` converted to a Tensor.
300      """
301      if isinstance(fetch, ops.Operation):
302        operation_fetches.append(fetch)
303        return fetch
304      elif isinstance(fetch, meta_graph_pb2.TensorInfo):
305        tensor_infos.append(fetch)
306        decoded = _get_element_from_tensor_info(fetch, self._func_graph)
307        if (tensor_util.is_tf_type(decoded) or
308            isinstance(decoded, composite_tensor.CompositeTensor)):
309          tensor_fetches.append(decoded)
310        else:
311          operation_fetches.append(decoded)
312        return decoded
313      elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)):
314        tensor_fetches.append(fetch)
315        return fetch
316      else:
317        graph_element = self.graph.as_graph_element(fetch)
318        return _fetch_preprocessing_callback(graph_element)
319
320    fetches = nest.map_structure(_fetch_preprocessing_callback, fetches)
321
322    # Expand composite tensors into their component dense Tensors.
323    tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)
324
325    for f in flat_feeds + tensor_fetches + operation_fetches:
326      if f.graph is not self._func_graph:
327        raise ValueError("Can only prune function whose feeds and fetches "
328                         f"from graph {self._func_graph}. Input "
329                         f"{f} is from a different graph {f.graph}.")
    with self._func_graph.as_default():
      pruned_graph = func_graph.FuncGraph(name)
    lift_map = lift_to_graph.lift_to_graph(
        operation_fetches + tensor_fetches,
        pruned_graph,
        sources=flat_feeds + self.graph.internal_captures,
        base_graph=self._func_graph)

    # Note that we add the component tensors of any composite tensors to the
    # returned function's outputs list; the list must contain these component
    # tensors, or the function's sparse outputs won't work properly.
    pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
    pruned_graph.control_outputs.extend(
        [lift_map[operation] for operation in operation_fetches])
    pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
    for external_capture, internal_capture in self.graph.captures:
      pruned_graph.add_capture(external_capture, lift_map[internal_capture])
    for ti in tensor_infos:
      if ti.WhichOneof("encoding") == "name":  # Dense tensors only
        t = pruned_graph.as_graph_element(ti.name)
        if tensor_util.is_tf_type(t):
          t.set_shape(tensor_shape.TensorShape(ti.tensor_shape))
    # pylint: disable=protected-access
    for f in self.graph._functions.values():
      pruned_graph._add_function(f)
    # pylint: enable=protected-access

    pruned_graph.variables = self.graph.variables

    def _structured_output_mapping(fetched):
      """Callback for `nest.map_structure()`."""
      lifted = lift_map[fetched]
      if isinstance(lifted, ops.Operation):
        return None
      return lifted

    # expand_composites=True here causes composite tensors to be expanded
    # into their component dense Tensors, mapped to the new graph, and then
    # reconstituted into their original composite form.
    pruned_graph.structured_outputs = nest.map_structure(
        _structured_output_mapping, fetches, expand_composites=True)
    pruned_graph.structured_input_signature = input_signature
    pruned_fn = WrappedFunction(
        pruned_graph, variable_holder=self._variable_holder)
    pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
    # TODO(kathywu): Enable keyword arguments if an input signature is specified
    pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds]  # pylint: disable=protected-access
    return pruned_fn


def _filter_returned_ops(fn):
  """Filters out any ops returned by the function.

  Args:
    fn: a function.

  Returns:
    A tuple of (
      Wrapped function that returns `None` in place of any ops,
      dict that maps the index in the flat output structure to the returned op
    )
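
  Example (illustrative sketch; `state_var` is a hypothetical variable and
  the function is meant to be traced in graph mode):

  ```
  def fn(x):
    update = state_var.assign_add(x).op
    return x + 1, update

  wrapped, returned_ops = _filter_returned_ops(fn)
  # Tracing wrapped(x) yields (x + 1, None); returned_ops == {1: update}.
  ```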
391  """
392  returned_ops = {}
393
394  def wrap_and_filter_returned_ops(*args, **kwargs):
395    outputs = fn(*args, **kwargs)
396    flat_outputs = nest.flatten(outputs)
397    for n in range(len(flat_outputs)):
398      output = flat_outputs[n]
399      if isinstance(output, ops.Operation):
400        returned_ops[n] = output
401        flat_outputs[n] = None
402    return nest.pack_sequence_as(outputs, flat_outputs)
403
404  return wrap_and_filter_returned_ops, returned_ops
405
406
407class WrappedGraph(object):
408  """Class for wrapping multiple TF 1.X functions in a single graph.
409
410  Maintains a dictionary mapping names to wrapped functions. See
411  `tf.compat.v1.wrap_function` to learn more about wrapping V1 functions.
412
413  Functions wrapped using this class have access to variables and collections
414  created in other wrapped functions, using the standard TF 1.X API (
415  `tf.compat.v1.get_variable` or
416  `tf.compat.v1.get_default_graph().get_collection(...)`)
417
418  Outside a function, variables and collections may be accessed using the
419  `variables` and `graph` properties.
420
421  Example:
422
423  ```
424  def add_v1(x):
425    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
426      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
427    return v + x
428
429  def increment_var_v1(x):
430    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
431      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
432    return v.assign_add(x)
433
434  g = WrappedGraph()
435  add = g.wrap_function(add_v1, [tf.TensorSpec([], tf.int32)])
436  increment_var = g.wrap_function(increment_var_v1,
437                                  [tf.TensorSpec([], tf.int32)])
438
439  assert len(g.variables) == 1
440  assert g.variables[0].numpy() == 0
441  increment_var(tf.constant(5))
442  assert g.variables[0].numpy() == 5
443
444  ```
445  """

  def __init__(self, variable_holder=None, **kwargs):
    self._variable_holder = (
        variable_holder or VariableHolder(share_variables=True))

    name = kwargs.pop("name", "wrapped_function_graph")
    # Always start with empty collections, unless otherwise specified. Setting
    # `collections=None` will copy the collections from the outer graph.
    collections = kwargs.pop("collections", {})
    self.graph = func_graph.FuncGraph(name, collections=collections, **kwargs)

    self._wrapped_function = WrappedFunction(self.graph, self._variable_holder)
    self._functions = {}

  @property
  def functions(self):
    return self._functions

  @property
  def variables(self):
    return self._variable_holder.variables

  def wrap_function(self, fn, signature, name=None):
    """Wraps a TF 1.X function and returns an eager-compatible function.

    All functions wrapped in the same `WrappedGraph` will have access to the
    same graph (`tf.compat.v1.get_default_graph` to get the graph object
    within a function, or `WrappedGraph.graph` to get the graph outside a
    function). Variables created within the function will be added to the
    `variables` list.

    Function inputs: All inputs to the function must be tensors (nested ok),
    with their shapes and dtypes defined in the `signature` argument.

    Function outputs:

      * The 1.X function may return tensors, variables, and ops. The wrapped
        eager-compatible function will always return tensors in the same nested
        structure.
      * Variables are replaced with a tensor containing the latest read values.
      * Returned ops are executed, and replaced with None.
      * The order of op execution and variable reads in the return is
        nondeterministic. For example:

        ```
        def update_var(x):
          v = tf.Variable(0)
          op = tf.compat.v1.assign(v, x).op
          return v, op

        g = WrappedGraph()
        fn = g.wrap_function(update_var, [tf.TensorSpec([], tf.int32)])
        read_value, _ = fn(tf.constant(3))
        print(read_value.numpy())  # could be 0 or 3
        print(g.variables[0].numpy())  # always 3
        ```

    To ensure that ops in the function are executed (e.g. ops added to the
    `tf.GraphKeys.UPDATE_OPS` collection), include them in the function returns.

    Args:
      fn: a 1.X TensorFlow function.
      signature: a possibly nested sequence of `TensorSpecs` specifying the
        shapes and dtypes of the arguments.
      name: an optional string name for the function. The function will be
        saved with key `name` in the `functions` dictionary.

    Returns:
      An eager-compatible function.
    """
    return self._wrap_function(fn, signature=signature, name=name)

  def _wrap_function(self,
                     fn,
                     args=None,
                     kwargs=None,
                     signature=None,
                     name=None):
    """Internal wrap function method with extended func_graph arguments."""
    fn_with_filter_and_scope, returned_ops = _filter_returned_ops(
        self._variable_holder.call_with_variable_creator_scope(fn))

    func_graph.func_graph_from_py_func(
        None,  # Name is unused.
        fn_with_filter_and_scope,
        args=args,
        kwargs=kwargs,
        signature=signature,
        add_control_dependencies=False,
        func_graph=self.graph)

    # This code relies on questionable behavior from `func_graph_from_py_func`.
    # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
    # and structured outputs are overwritten. This is probably a bug, because
    # the structured outputs don't match up with the outputs...
    fn_inputs = self.graph.inputs[:-len(self.graph.captures)]

    # Return filtered ops to the flattened outputs.
    flat_fn_outputs = nest.flatten(self.graph.structured_outputs)
    for index, op in returned_ops.items():
      flat_fn_outputs[index] = op
    fn_outputs = nest.pack_sequence_as(self.graph.structured_outputs,
                                       flat_fn_outputs)

    name = name or fn.__name__
    wrapped_function = self._wrapped_function.prune(
        fn_inputs, fn_outputs, name, self.graph.structured_input_signature)
    self._functions[name] = wrapped_function
    return wrapped_function


@tf_export(v1=["wrap_function"])
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function `fn` into a graph function.

  The Python function `fn` will be called once with symbolic arguments
  specified in the `signature`, traced, and turned into a graph function. Any
  variables created by `fn` will be owned by the object returned by
  `wrap_function`. The resulting graph function can be called with tensors
  which match the signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Both `tf.compat.v1.wrap_function` and `tf.function` create a callable
  TensorFlow graph. But while `tf.function` runs all stateful operations
  (e.g. `tf.print`) and sequences operations to provide the same semantics as
  eager execution, `wrap_function` is closer to the behavior of `session.run`
  in TensorFlow 1.x. It will not run any operations unless they are required
  to compute the function's outputs, either through a data dependency or a
  control dependency. Nor will it sequence operations.

  Unlike `tf.function`, `wrap_function` will only trace the Python function
  once. As with placeholders in TF 1.x, shapes and dtypes must be provided to
  `wrap_function`'s `signature` argument.

  Since it is only traced once, variables and state may be created inside the
  function and owned by the function wrapper object.

  Args:
    fn: Python function to be wrapped.
    signature: the placeholder and Python arguments to be passed to the
      wrapped function.
    name: Optional. The name of the function.

  Returns:
    The wrapped graph function.
  """
  holder = VariableHolder(fn)
  func_graph_name = "wrapped_function"
  if name is not None:
    func_graph_name = "wrapped_function_" + name
  return WrappedFunction(
      func_graph.func_graph_from_py_func(
          func_graph_name,
          holder,
          args=None,
          kwargs=None,
          signature=signature,
          add_control_dependencies=False,
          collections={}),
      variable_holder=holder,
      signature=signature)


def function_from_graph_def(graph_def, inputs, outputs, captures=None):
  """Creates a ConcreteFunction from a GraphDef.

  Args:
    graph_def: A GraphDef to make a function out of.
    inputs: A Tensor name or nested structure of names in `graph_def` which
      should be inputs to the function.
    outputs: A Tensor name or nested structure of names in `graph_def` which
      should be outputs of the function.
    captures: (Optional) A dictionary mapping names of nodes in `graph_def`
      that should be captured as inputs, to tensors containing the values of
      those captured inputs.

  Returns:
    A ConcreteFunction.
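
  Example (a minimal sketch; assumes `gdef` is a GraphDef containing tensors
  named "x:0" and "out:0"):

  ```
  fn = function_from_graph_def(gdef, inputs="x:0", outputs="out:0")
  result = fn(tf.constant(2.0))
  ```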
646  """
647
648  def _imports_graph_def():
649    importer.import_graph_def(graph_def, name="")
650    graph = ops.get_default_graph()
651    if captures is not None:
652      for c in captures:
653        graph.add_capture(captures[c], graph.get_tensor_by_name(str(c) + ":0"))
654
655  wrapped_import = wrap_function(_imports_graph_def, [])
656  import_graph = wrapped_import.graph
657  return wrapped_import.prune(
658      nest.map_structure(import_graph.as_graph_element, inputs),
659      nest.map_structure(import_graph.as_graph_element, outputs))
660