# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Prototype decorator for defining legacy-graph-mode functions."""

import weakref

from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.trackable import data_structures
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


class VariableHolder(object):
  """Holds variables for a python function.

  `variable_creator_scope` is installed as a variable creator while the held
  function runs; every variable created through it is recorded by name in
  `variables` and added to the legacy graph collections.
  """

  def __init__(self, fn=None, share_variables=False):
    """Creates the holder.

    Args:
      fn: (optional) The Python function invoked by `__call__`.
      share_variables: If True, a creation request whose scoped name matches
        an already-held variable returns that variable instead of creating a
        new one.
    """
    self._fn = fn

    self._share_variables = share_variables
    self._variables_by_name = data_structures.Mapping()

  @property
  def variables(self):
    """Mapping from variable name to the variables held by this object."""
    return self._variables_by_name

  def variable_creator_scope(self, next_creator, **kwargs):
    """Creates variables & adds them to collections to match legacy code.

    Args:
      next_creator: The next variable creator; called with `**kwargs` when a
        new variable has to be created.
      **kwargs: Variable creation arguments. `collections` is consumed here
        rather than forwarded, and `name` is replaced by the scoped name.

    Returns:
      The newly created variable, or the existing one when variables are
      shared and one with the same name is already held.
    """
    collections = kwargs.pop("collections", None)
    v = None

    # Get expected variable name.
    with ops.name_scope(
        kwargs.get("name", None), "Variable", skip_on_eager=False) as name:
      variable_name = ops.name_from_scope_name(name)
      kwargs["name"] = name

    if self._share_variables:
      v = self._variables_by_name.get(variable_name, None)

    if v is None:
      v = next_creator(**kwargs)
      self._variables_by_name[variable_name] = v

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    # Mirror TF1 behavior: trainable variables also join TRAINABLE_VARIABLES.
    # Build a new list so a caller-supplied `collections` is never mutated.
    if v.trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]

    ops.add_to_collections(collections, v)

    return v

  def __call__(self, *args, **kwargs):
    """Calls the held `fn` with this holder's creator scope active."""
    return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs)

  def call_with_variable_creator_scope(self, fn):
    """Returns a wrapper that runs `fn` under `variable_creator_scope`."""

    def wrapped(*args, **kwargs):
      with variable_scope.variable_creator_scope(self.variable_creator_scope):
        return fn(*args, **kwargs)

    return wrapped


def _get_element_from_tensor_info(tensor_info, graph):
  """Simplified copy of the deprecated `get_tensor_from_tensor_info`.

  Args:
    tensor_info: A `TensorInfo` proto.
    graph: The graph in which to look up the referenced elements.

  Returns:
    For the "name" encoding, the graph element with that name (a `Tensor`, or
    possibly an `Operation`); for "coo_sparse", a `SparseTensor` built from the
    three named tensors; for "composite_tensor", the composite rebuilt from
    its decoded type spec and component tensors.

  Raises:
    ValueError: If the encoding is not one of the three above.
  """
  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "name":
    # We may get operations here in some cases. TensorInfo is a bit of a
    # misnomer if so.
    return graph.as_graph_element(tensor_info.name)
  elif encoding == "coo_sparse":
    return sparse_tensor.SparseTensor(
        graph.get_tensor_by_name(tensor_info.coo_sparse.indices_tensor_name),
        graph.get_tensor_by_name(tensor_info.coo_sparse.values_tensor_name),
        graph.get_tensor_by_name(
            tensor_info.coo_sparse.dense_shape_tensor_name))
  elif encoding == "composite_tensor":
    spec_proto = struct_pb2.StructuredValue(
        type_spec_value=tensor_info.composite_tensor.type_spec)
    spec = nested_structure_coder.decode_proto(spec_proto)
    components = [graph.get_tensor_by_name(component.name) for component in
                  tensor_info.composite_tensor.components]
    return spec._from_components(components)  # pylint: disable=protected-access
  else:
    raise ValueError(f"Invalid TensorInfo.encoding: {encoding}. Valid "
                     "encodings are 'name', 'coo_sparse', and "
                     "'composite_tensor'.")


def _lift_single_variable(old_variable, graph, variable_holder):
  """Lifts `old_variable` out of the `FuncGraph` `graph`.

  Creates an `UninitializedVariable` with the same shape, dtype, name and
  trainability as `old_variable`, and registers its handle as the external
  capture for `old_variable.handle` inside `graph`.

  Args:
    old_variable: A resource variable created inside `graph`.
    graph: The `FuncGraph` to lift the variable out of.
    variable_holder: The `VariableHolder` in which to record the new variable.

  Returns:
    The new (lifted) variable.
  """
  new_variable = resource_variable_ops.UninitializedVariable(
      shape=old_variable.shape,
      dtype=old_variable.dtype,
      name=old_variable.op.name,
      trainable=old_variable.trainable,
      extra_handle_data=old_variable.handle)
  new_variable._initializer_op = old_variable._initializer_op  # pylint: disable=protected-access
  # External value: the new handle; internal placeholder: the old handle.
  graph.add_capture(new_variable.handle, old_variable.handle)
  # Now that we've added the new variable to graph.captures,
  # graph.capture will use that cached value and do some post-processing
  # on the capture like recording it on the tape.
  graph.capture(new_variable.handle)
  # pylint: disable=protected-access
  # Key the holder by the variable name without its ":<index>" suffix.
  variable_name = new_variable.name.split(":")[0]
  variable_holder._variables_by_name[variable_name] = new_variable
  graph._weak_variables.append(weakref.ref(new_variable))
  # pylint: enable=protected-access
  graph.watch_variable(new_variable)
  return new_variable


def _lift_unlifted_variables(graph, variable_holder):
  """Finds resource variables and lifts them into the outer context.

  When we import a GraphDef inside a wrap_function, no Python graph building
  code runs. This means we get VarHandleOps which create variable resources,
  but no corresponding Python objects. Leaving them like this works but gives
  the user no way to interact with or modify the variables outside the graph.

  This method searches for variables and lifts them out as regular variable
  objects when possible, indicating to the FuncGraph that they are captures.

  Args:
    graph: The FuncGraph to lift variables from.
    variable_holder: A VariableHolder to record the lifted variables in.
  """
  with graph.as_default():
    global_collection_variables = ops.get_collection(
        ops.GraphKeys.GLOBAL_VARIABLES)
    local_collection_variables = ops.get_collection(
        ops.GraphKeys.LOCAL_VARIABLES)
    existing_captures = {id(c) for c in graph.internal_captures}
    lifted_variables = {}

    def _should_lift_variable(v):
      # Only graph-mode resource variables built inside a function whose
      # handles are not already captured are eligible.
      return ((v._in_graph_mode  # pylint: disable=protected-access
               and v.graph.building_function)
              and isinstance(v, resource_variable_ops.BaseResourceVariable)
              and id(v.handle) not in existing_captures)

    for old_variable in global_collection_variables:
      if _should_lift_variable(old_variable):
        new_variable = _lift_single_variable(
            old_variable, graph, variable_holder)
        lifted_variables[id(old_variable)] = new_variable
        existing_captures.add(id(old_variable.handle))

    for old_variable in local_collection_variables:
      if _should_lift_variable(old_variable):
        new_variable = _lift_single_variable(
            old_variable, graph, variable_holder)
        lifted_variables[id(old_variable)] = new_variable
        existing_captures.add(id(old_variable.handle))
        if new_variable._in_graph_mode:  # pylint: disable=protected-access
          outer_graph = new_variable.graph
          # Variables are added to the global collection by default. In this
          # case we only want the variable in the local collection, so we'll pop
          # it out.
          global_collection = outer_graph.get_collection_ref(
              ops.GraphKeys.GLOBAL_VARIABLES)
          global_collection.remove(new_variable)
          outer_graph.add_to_collection(
              ops.GraphKeys.LOCAL_VARIABLES, new_variable)

    # Update the FuncGraph's collections, partly for the user and partly so this
    # function is idempotent when it runs again in prune() calls.
    for collection_name in [
        ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES
    ]:
      mutable_collection = ops.get_collection_ref(collection_name)
      for index, current in enumerate(mutable_collection):
        mutable_collection[index] = lifted_variables.get(id(current), current)
        if not resource_variable_ops.is_resource_variable(
            mutable_collection[index]):
          logging.log_first_n(
              logging.WARN,
              "Unable to create a python object for variable {} because it is "
              "a reference variable. It may not be visible to training APIs. "
              "If this is a problem, consider rebuilding the SavedModel after "
              "running tf.compat.v1.enable_resource_variables().".format(
                  mutable_collection[index]),
              5)


# TODO(allenl): make this trackable
class WrappedFunction(function.ConcreteFunction):
  """Wraps a tf V1 piece of code in a function."""

  def __init__(self, fn_graph, variable_holder, attrs=None, signature=None):
    """Creates the function, lifting variables out of `fn_graph` first.

    Args:
      fn_graph: The `FuncGraph` to wrap.
      variable_holder: `VariableHolder` that records lifted variables.
      attrs: (optional) Forwarded to `ConcreteFunction.__init__`.
      signature: (optional) Signature used when calling without keyword
        names; positional arguments whose entry is a `DenseSpec` are
        converted to tensors of that spec's dtype.
    """
    self._variable_holder = variable_holder
    _lift_unlifted_variables(fn_graph, variable_holder)
    # We call __init__ after lifting variables so that the function's signature
    # properly reflects the new captured inputs.
    for f in fn_graph.as_graph_def().library.function:
      context.context().add_function_def(f)
    self._signature = signature
    super(WrappedFunction, self).__init__(fn_graph, attrs=attrs)

  def _call_impl(self, args, kwargs, cancellation_manager=None):
    """Calls the function.

    Without `_arg_keywords` (presumably the case for functions built directly
    by `wrap_function`; `prune` sets it explicitly) only positional arguments
    are accepted and are fed, together with the captured inputs, straight to
    `_call_flat`. Otherwise the parent `ConcreteFunction` implementation is
    used.
    """
    if self._arg_keywords is None:
      if kwargs:
        raise NotImplementedError(
            "Keyword arguments are not supported when calling a "
            f"wrap_function-decorated function. Got {kwargs}.")
      if self._signature is not None:
        args = list(args)
        for i, arg in enumerate(args):
          # Non-DenseSpec entries (e.g. Python values) pass through unchanged.
          if isinstance(self._signature[i], tensor_spec.DenseSpec):
            args[i] = ops.convert_to_tensor(arg, self._signature[i].dtype)
      return self._call_flat(args, self.captured_inputs)
    else:
      return super(WrappedFunction, self)._call_impl(
          args, kwargs, cancellation_manager)

  def prune(self, feeds, fetches, name=None, input_signature=None):
    """Extract a subgraph of this function's underlying graph.

    Wraps the subgraph in a new `WrappedFunction` object.

    Args:
      feeds: Input tensors to the subgraph to extract, as `Tensor` objects.
      fetches: Possibly-nested Python data structure containing information
        about outputs of the target subgraph. Each entry can either be a
        `Tensor` object (for data outputs), an `Operation` object (for control
        outputs), or a `TensorInfo` proto. Any additional shape/dtype
        information provided in a `TensorInfo` and not present in the original
        graph will be added to the returned subgraph.
      name: (optional) Name to give to the underlying `FuncGraph` of the
        returned object. If no name is provided, the graph's name will be
        `"pruned"`.
      input_signature: (optional) possibly-nested Python data structure
        containing `TensorSpec` objects, with which to populate the returned
        function's `FuncGraph`'s `structured_input_signature` field.

    Returns:
      A new `WrappedFunction` object containing a copy of the portion of this
      object's graph that goes from `feeds` to `fetches`.

    Raises:
      ValueError: If a feed is not a tensor, or if a feed or fetch comes from
        a graph other than this function's graph.
    """
    # TODO(b/129646028): Add support for CompositeTensors.
    name = name or "pruned"
    flat_feeds = nest.flatten(feeds, expand_composites=True)
    flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds]
    for f in flat_feeds:
      if not isinstance(f, ops.Tensor):
        raise ValueError("All members of argument `feeds` must be tensors. "
                         f"Got {f} with type {type(f)}.")

    # Ignoring all feeds that are captures allows prune to be called
    # using wrapped_func.inputs even when it uses variables
    internal_captures = {id(c) for c in self.graph.internal_captures}
    flat_feeds = [f for f in flat_feeds if id(f) not in internal_captures]

    operation_fetches = []
    tensor_fetches = []
    tensor_infos = []

    def _fetch_preprocessing_callback(fetch):
      """Extract out lists of ops, tensors, and tensor type info.

      Turns TensorInfos into Tensors in the original `fetches` structure.
      Also extracts ops from `fetches`.

      Args:
        fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or
          string identifying a Tensor or Operation.

      Returns:
        `fetch` converted to a graph element: a Tensor, CompositeTensor or
        Operation.
      """
      if isinstance(fetch, ops.Operation):
        operation_fetches.append(fetch)
        return fetch
      elif isinstance(fetch, meta_graph_pb2.TensorInfo):
        tensor_infos.append(fetch)
        decoded = _get_element_from_tensor_info(fetch, self._func_graph)
        if (tensor_util.is_tf_type(decoded) or
            isinstance(decoded, composite_tensor.CompositeTensor)):
          tensor_fetches.append(decoded)
        else:
          operation_fetches.append(decoded)
        return decoded
      elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)):
        tensor_fetches.append(fetch)
        return fetch
      else:
        graph_element = self.graph.as_graph_element(fetch)
        return _fetch_preprocessing_callback(graph_element)

    fetches = nest.map_structure(_fetch_preprocessing_callback, fetches)

    # Expand composite tensors into their component dense Tensors.
    tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)

    for f in flat_feeds + tensor_fetches + operation_fetches:
      if f.graph is not self._func_graph:
        raise ValueError("Can only prune function whose feeds and fetches "
                         f"from graph {self._func_graph}. Input "
                         f"{f} is from a different graph {f.graph}.")
    with self._func_graph.as_default():
      pruned_graph = func_graph.FuncGraph(name)
    lift_map = lift_to_graph.lift_to_graph(
        operation_fetches + tensor_fetches,
        pruned_graph,
        sources=flat_feeds + self.graph.internal_captures,
        base_graph=self._func_graph)

    # Note that we add the component tensors of any composite tensors to the
    # returned function's outputs list; the list must contain these component
    # tensors, or the function's sparse outputs won't work properly.
    pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
    pruned_graph.control_outputs.extend(
        [lift_map[operation] for operation in operation_fetches])
    pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
    for external_capture, internal_capture in self.graph.captures:
      pruned_graph.add_capture(external_capture, lift_map[internal_capture])
    for ti in tensor_infos:
      if ti.WhichOneof("encoding") == "name":  # Dense tensors only
        t = pruned_graph.as_graph_element(ti.name)
        if tensor_util.is_tf_type(t):
          t.set_shape(tensor_shape.TensorShape(ti.tensor_shape))
    # pylint: disable=protected-access
    for f in self.graph._functions.values():
      pruned_graph._add_function(f)
    # pylint: enable=protected-access

    pruned_graph.variables = self.graph.variables

    def _structured_output_mapping(fetched):
      """callback for `nest.map_structure()`"""
      lifted = lift_map[fetched]
      if isinstance(lifted, ops.Operation):
        return None
      return lifted

    # expand_composites=True here causes composite tensors to be expanded
    # into their component dense Tensors, mapped to the new graph, and then
    # reconstituted into their original composite form.
    pruned_graph.structured_outputs = nest.map_structure(
        _structured_output_mapping, fetches, expand_composites=True)
    pruned_graph.structured_input_signature = input_signature
    pruned_fn = WrappedFunction(
        pruned_graph, variable_holder=self._variable_holder)
    pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
    # TODO(kathywu): Enable keyword arguments if an input signature is specified
    pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds]  # pylint: disable=protected-access
    return pruned_fn


def _filter_returned_ops(fn):
  """Filtering out any ops returned by function.

  Args:
    fn: a function

  Returns:
    A tuple of (
      Wrapped function that returns `None` in place of any ops,
      dict that maps the index in the flat output structure to the returned op
    )
    The dict is filled in when the wrapped function is called.
  """
  returned_ops = {}

  def wrap_and_filter_returned_ops(*args, **kwargs):
    outputs = fn(*args, **kwargs)
    flat_outputs = nest.flatten(outputs)
    for n, output in enumerate(flat_outputs):
      if isinstance(output, ops.Operation):
        returned_ops[n] = output
        flat_outputs[n] = None
    return nest.pack_sequence_as(outputs, flat_outputs)

  return wrap_and_filter_returned_ops, returned_ops


class WrappedGraph(object):
  """Class for wrapping multiple TF 1.X functions in a single graph.

  Maintains a dictionary mapping names to wrapped functions. See
  `tf.compat.v1.wrap_function` to learn more about wrapping V1 functions.

  Functions wrapped using this class have access to variables and collections
  created in other wrapped functions, using the standard TF 1.X API (
  `tf.compat.v1.get_variable` or
  `tf.compat.v1.get_default_graph().get_collection(...)`)

  Outside a function, variables and collections may be accessed using the
  `variables` and `graph` properties. `variables` is a mapping from variable
  name to variable.

  Example:

  ```
  def add_v1(x):
    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
    return v + x

  def increment_var_v1(x):
    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
    return v.assign_add(x)

  g = WrappedGraph()
  add = g.wrap_function(add_v1, [tf.TensorSpec([], tf.int32)])
  increment_var = g.wrap_function(increment_var_v1,
                                  [tf.TensorSpec([], tf.int32)])

  assert len(g.variables) == 1
  v = list(g.variables.values())[0]
  assert v.numpy() == 0
  increment_var(tf.constant(5))
  assert v.numpy() == 5

  ```
  """

  def __init__(self, variable_holder=None, **kwargs):
    """Creates the shared graph.

    Args:
      variable_holder: (optional) `VariableHolder` to use; defaults to a new
        holder with `share_variables=True`.
      **kwargs: Forwarded to `FuncGraph`. `name` defaults to
        "wrapped_function_graph" and `collections` defaults to `{}`.
    """
    self._variable_holder = (
        variable_holder or VariableHolder(share_variables=True))

    name = kwargs.pop("name", "wrapped_function_graph")
    # Always start with empty collections, unless otherwise specified. Setting
    # `collections=None` will copy the collections from the outer graph.
    collections = kwargs.pop("collections", {})
    self.graph = func_graph.FuncGraph(name, collections=collections, **kwargs)

    self._wrapped_function = WrappedFunction(self.graph, self._variable_holder)
    self._functions = {}

  @property
  def functions(self):
    """Dict mapping names to the functions wrapped so far."""
    return self._functions

  @property
  def variables(self):
    """Mapping from variable name to the variables created in this graph."""
    return self._variable_holder.variables

  def wrap_function(self, fn, signature, name=None):
    """Wraps a TF 1.X function and returns an eager-compatible function.

    All functions wrapped in the same `WrappedGraph` will have access to the
    same graph (`tf.compat.v1.get_default_graph` to get the graph object
    within a function, or `WrappedGraph.graph` to get the graph outside a
    function). Variables created within the function will be added to the
    `variables` mapping.

    Function inputs: All inputs to the function must be tensors (nested ok),
    with their shapes and dtypes defined in the `signature` argument.

    Function outputs:

    * The 1.X function may return tensors, variables, and ops. The wrapped
      eager-compatible function will always return tensors in the same nested
      structure.
    * Variables are replaced with a tensor containing the latest read values.
    * Returned ops are executed, and replaced with None.
    * The order of op execution and variable reads in the return is
      nondeterministic. For example:

    ```
    def update_var(x):
      v = tf.Variable(0)
      op = tf.compat.v1.assign(v, x).op
      return v, op

    g = WrappedGraph()
    fn = g.wrap_function(update_var, [tf.TensorSpec([], tf.int32)])
    read_value, _ = fn(tf.constant(3))
    print(read_value.numpy())  # could be 0 or 3
    print(list(g.variables.values())[0].numpy())  # always 3
    ```

    To ensure that ops in the function are executed (e.g. ops added to the
    `tf.GraphKeys.UPDATE_OPS` collection), include them in the function returns.

    Args:
      fn: a 1.X tensorflow function.
      signature: a possibly nested sequence of `TensorSpecs` specifying the
        shapes and dtypes of the arguments.
      name: an optional string name for the function. The function will be saved
        with key `name` in the `functions` dictionary.

    Returns:
      An eager-compatible function.
    """
    return self._wrap_function(fn, signature=signature, name=name)

  def _wrap_function(self,
                     fn,
                     args=None,
                     kwargs=None,
                     signature=None,
                     name=None):
    """Internal wrap function method with extended func_graph arguments."""
    fn_with_filter_and_scope, returned_ops = _filter_returned_ops(
        self._variable_holder.call_with_variable_creator_scope(fn))

    func_graph.func_graph_from_py_func(
        None,  # Name is unused.
        fn_with_filter_and_scope,
        args=args,
        kwargs=kwargs,
        signature=signature,
        add_control_dependencies=False,
        func_graph=self.graph)

    # This code relies on questionable behavior from `func_graph_from_py_func`.
    # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
    # and structured outputs are overwritten. Pretty sure this is a bug,
    # because structured outputs doesn't match up with the outputs...
    #
    # The trailing `len(captures)` inputs are the graph's captured inputs;
    # strip them to keep only this function's own inputs. `inputs[:-0]` would
    # be an empty list, so the no-capture case must keep every input.
    num_captures = len(self.graph.captures)
    if num_captures:
      fn_inputs = self.graph.inputs[:-num_captures]
    else:
      fn_inputs = list(self.graph.inputs)

    # Return filtered ops to the flattened outputs.
    flat_fn_outputs = nest.flatten(self.graph.structured_outputs)
    for index, op in returned_ops.items():
      flat_fn_outputs[index] = op
    fn_outputs = nest.pack_sequence_as(self.graph.structured_outputs,
                                       flat_fn_outputs)

    name = name or fn.__name__
    wrapped_function = self._wrapped_function.prune(
        fn_inputs, fn_outputs, name, self.graph.structured_input_signature)
    self._functions[name] = wrapped_function
    return wrapped_function


@tf_export(v1=["wrap_function"])
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function fn into a graph function.

  The python function `fn` will be called once with symbolic arguments specified
  in the `signature`, traced, and turned into a graph function. Any variables
  created by `fn` will be owned by the object returned by `wrap_function`. The
  resulting graph function can be called with tensors which match the
  signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Both `tf.compat.v1.wrap_function` and `tf.function` create a callable
  TensorFlow graph. But while `tf.function` runs all stateful operations
  (e.g. `tf.print`) and sequences operations to provide the same semantics as
  eager execution, `wrap_function` is closer to the behavior of `session.run` in
  TensorFlow 1.x. It will not run any operations unless they are required to
  compute the function's outputs, either through a data dependency or a control
  dependency. Nor will it sequence operations.

  Unlike `tf.function`, `wrap_function` will only trace the Python function
  once. As with placeholders in TF 1.x, shapes and dtypes must be provided to
  `wrap_function`'s `signature` argument.

  Since it is only traced once, variables and state may be created inside the
  function and owned by the function wrapper object.

  Args:
    fn: python function to be wrapped
    signature: the placeholder and python arguments to be passed to the wrapped
      function
    name: Optional. The name of the function.

  Returns:
    the wrapped graph function.
  """
  holder = VariableHolder(fn)
  func_graph_name = "wrapped_function"
  if name is not None:
    func_graph_name = "wrapped_function_" + name
  return WrappedFunction(
      func_graph.func_graph_from_py_func(
          func_graph_name,
          holder,
          args=None,
          kwargs=None,
          signature=signature,
          add_control_dependencies=False,
          collections={}),
      variable_holder=holder,
      signature=signature)


def function_from_graph_def(graph_def, inputs, outputs, captures=None):
  """Creates a ConcreteFunction from a GraphDef.

  Args:
    graph_def: A GraphDef to make a function out of.
    inputs: A Tensor name or nested structure of names in `graph_def` which
      should be inputs to the function.
    outputs: A Tensor name or nested structure of names in `graph_def` which
      should be outputs of the function.
    captures: (Optional) A dictionary mapping node names in `graph_def` that
      should be captured as inputs to tensors containing the value of the
      captured inputs.

  Returns:
    A ConcreteFunction.
  """

  def _imports_graph_def():
    importer.import_graph_def(graph_def, name="")
    graph = ops.get_default_graph()
    if captures is not None:
      for c in captures:
        # Capture keys are node names; ":0" selects the node's first output.
        graph.add_capture(captures[c], graph.get_tensor_by_name(str(c) + ":0"))

  wrapped_import = wrap_function(_imports_graph_def, [])
  import_graph = wrapped_import.graph
  return wrapped_import.prune(
      nest.map_structure(import_graph.as_graph_element, inputs),
      nest.map_structure(import_graph.as_graph_element, outputs))