# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A utility function for importing TensorFlow graphs."""
import contextlib

from tensorflow.core.framework import graph_pb2
from tensorflow.python import tf2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export


def _IsControlInput(input_name):
  """Returns True if `input_name` names a control input ('^op_name')."""
  # Expected format: '^operation_name' (control input).
  return input_name.startswith('^')


def _ParseTensorName(tensor_name):
  """Parses a tensor name into an operation name and output index.

  This function will canonicalize tensor names as follows:

  * "foo:0"       -> ("foo", 0)
  * "foo:7"       -> ("foo", 7)
  * "foo"         -> ("foo", 0)
  * "foo:bar:baz" -> ValueError

  Args:
    tensor_name: The name of a tensor.

  Returns:
    A tuple containing the operation name, and the output index.

  Raises:
    ValueError: If `tensor_name` cannot be interpreted as the name of a tensor.
  """
  components = tensor_name.split(':')
  if len(components) == 2:
    # Expected format: 'operation_name:output_index'.
    try:
      output_index = int(components[1])
    except ValueError:
      # The message already names the offending component, so the original
      # int() error adds nothing but a confusing chained traceback.
      raise ValueError(f'Cannot convert {tensor_name!r} to a tensor name. '
                       'Second component of the name following the `:` should '
                       f'be an int. Got {components[1]}.') from None
    return components[0], output_index
  elif len(components) == 1:
    # Expected format: 'operation_name' (implicit 0th output).
    return components[0], 0
  else:
    raise ValueError(f"Cannot convert '{tensor_name}' to a tensor name. Tensor "
                     'names should not contain more than 1 `:`. Obtained '
                     f'{len(components) - 1}')


@contextlib.contextmanager
def _MaybeDevice(device):
  """Applies the given device only if device is not None or empty."""
  if device:
    with ops.device(device):
      yield
  else:
    yield


def _ProcessGraphDefParam(graph_def):
  """Type-checks and possibly canonicalizes `graph_def`.

  Args:
    graph_def: A `GraphDef` proto, or a message that can be merged into one.

  Returns:
    A `graph_pb2.GraphDef`. If `graph_def` is already a `GraphDef`, the same
    object is returned after default attr values have been filled into its
    NodeDefs in place (visible to the caller). Otherwise a new `GraphDef`
    built with `MergeFrom` is returned.

  Raises:
    TypeError: If `graph_def` cannot be merged into a `GraphDef`.
  """
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError as e:
      raise TypeError('Argument `graph_def` must be a GraphDef proto.') from e
  else:
    # If we're using the graph_def provided by the caller, modify graph_def
    # in-place to add attr defaults to the NodeDefs (this is visible to the
    # caller).
    # NOTE(skyewm): this is undocumented behavior that at least meta_graph.py
    # depends on. It might make sense to move this to meta_graph.py and have
    # import_graph_def not modify the graph_def argument (we'd have to make sure
    # this doesn't break anything else.)
    for node in graph_def.node:
      op_def = op_def_registry.get(node.op)
      if op_def is None:
        # Assume unrecognized ops are functions for now. TF_ImportGraphDef will
        # report an error if the op is actually missing.
        continue
      _SetDefaultAttrValues(node, op_def)

  return graph_def


def _ProcessInputMapParam(input_map):
  """Type-checks and possibly canonicalizes `input_map`.

  Args:
    input_map: None, or a dict whose keys are str/bytes input names.

  Returns:
    `input_map`, or a new empty dict if it was None.

  Raises:
    TypeError: If `input_map` is not a dict or has non-string keys.
  """
  if input_map is None:
    input_map = {}
  else:
    if not isinstance(input_map, dict):
      raise TypeError('Argument `input_map` must be a dictionary. Obtained '
                      f'{type(input_map).__name__}')
    if not all(
        isinstance(k, compat.bytes_or_text_types) for k in input_map.keys()):
      raise TypeError('All keys for argument `input_map` must be strings. '
                      f'Obtained keys: {list(input_map.keys())}')
  return input_map


def _ProcessReturnElementsParam(return_elements):
  """Type-checks and possibly canonicalizes `return_elements`.

  Args:
    return_elements: None, or an iterable of str/bytes op or tensor names.

  Returns:
    None if `return_elements` is None, otherwise a tuple of `str` names.

  Raises:
    TypeError: If `return_elements` is a bare string or contains non-strings.
  """
  if return_elements is None:
    return None
  # A bare string is itself an iterable of one-character strings, so without
  # this check 'foo:0' would be silently split into ('f', 'o', 'o', ':', '0').
  if isinstance(return_elements, compat.bytes_or_text_types) or not all(
      isinstance(x, compat.bytes_or_text_types) for x in return_elements):
    raise TypeError('Argument `return_elements` must be a list of strings. '
                    f'Obtained {return_elements}.')
  return tuple(compat.as_str(x) for x in return_elements)


def _FindAttrInOpDef(attr_name, op_def):
  """Returns the AttrDef named `attr_name` in `op_def`, or None if absent."""
  for attr_def in op_def.attr:
    if attr_name == attr_def.name:
      return attr_def
  return None


def _RemoveDefaultAttrs(producer_op_list, graph_def):
  """Removes unknown default attrs according to `producer_op_list`.

  Removes any unknown attrs in `graph_def` (i.e. attrs that do not appear in
  registered OpDefs) that have a default value in `producer_op_list`.

  Args:
    producer_op_list: OpList proto.
    graph_def: GraphDef proto, modified in place.
  """
  producer_op_dict = {op.name: op for op in producer_op_list.op}
  for node in graph_def.node:
    # Remove any default attr values that aren't in op_def.
    if node.op in producer_op_dict:
      op_def = op_def_registry.get(node.op)
      if op_def is None:
        # Some custom op registrations won't show up here. That's OK, attribute
        # stripping just won't be available.
        continue
      producer_op_def = producer_op_dict[node.op]
      # We make a copy of node.attr to iterate through since we may modify
      # node.attr inside the loop.
      for key in list(node.attr):
        if _FindAttrInOpDef(key, op_def) is None:
          # No attr_def in consumer, look in producer.
          attr_def = _FindAttrInOpDef(key, producer_op_def)
          if (attr_def and attr_def.HasField('default_value') and
              node.attr[key] == attr_def.default_value):
            # Unknown attr had default value in producer, delete it so it can be
            # understood by consumer.
            del node.attr[key]


def _ConvertInputMapValues(name, input_map):
  """Ensures all input map values are tensors.

  This should be called from inside the import name scope.

  Args:
    name: the `name` argument passed to import_graph_def
    input_map: the `input_map` argument passed to import_graph_def.

  Returns:
    A possibly-updated version of `input_map`.

  Raises:
    ValueError: if input map values cannot be converted due to empty name scope.
  """
  if not all(isinstance(v, ops.Tensor) for v in input_map.values()):
    if name == '':  # pylint: disable=g-explicit-bool-comparison
      raise ValueError(
          'tf.import_graph_def() requires a non-empty `name` if `input_map` '
          'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
          '`input_map` values before calling tf.import_graph_def().')
    with ops.name_scope('_inputs'):
      input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}
  return input_map


def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements,
                                     validate_colocation_constraints):
  """Populates the TF_ImportGraphDefOptions `options`.

  Args:
    options: The TF_ImportGraphDefOptions to populate.
    prefix: Name prefix applied to imported nodes.
    input_map: Dict mapping input names to Tensors. Keys starting with '^'
      remap control dependencies; other keys remap tensor outputs.
    return_elements: None or a sequence of names. Names containing ':' are
      requested as outputs, other names as operations.
    validate_colocation_constraints: Whether the importer should validate
      colocation constraints.
  """
  c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
  c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)

  for input_src, input_dst in input_map.items():
    input_src = compat.as_str(input_src)
    if _IsControlInput(input_src):
      # Remap a control dependency on op `src_name` to `input_dst`'s op.
      src_name = compat.as_str(input_src[1:])
      dst_op = input_dst._as_tf_output().oper  # pylint: disable=protected-access
      c_api.TF_ImportGraphDefOptionsRemapControlDependency(
          options, src_name, dst_op)
    else:
      src_name, src_idx = _ParseTensorName(input_src)
      src_name = compat.as_str(src_name)
      dst_output = input_dst._as_tf_output()  # pylint: disable=protected-access
      c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_idx,
                                                    dst_output)
  for name in return_elements or []:
    if ':' in name:
      op_name, index = _ParseTensorName(name)
      op_name = compat.as_str(op_name)
      c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index)
    else:
      c_api.TF_ImportGraphDefOptionsAddReturnOperation(options,
                                                       compat.as_str(name))

  c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(
      options, validate_colocation_constraints)


def _ProcessNewOps(graph):
  """Processes the newly-added TF_Operations in `graph`.

  Clears each new op's device, then either re-applies the graph's device
  functions (under the op's original device), or, for ops with colocation
  constraints, copies the device of a colocated op if one has a device.
  """
  # Maps from a node to the names of the ops it's colocated with, if colocation
  # is specified in the attributes.
  colocation_pairs = {}

  for new_op in graph._add_new_tf_operations(compute_devices=False):  # pylint: disable=protected-access
    original_device = new_op.device
    new_op._set_device('')  # pylint: disable=protected-access
    colocation_names = _GetColocationNames(new_op)
    if colocation_names:
      colocation_pairs[new_op] = colocation_names
      # Don't set a device for this op, since colocation constraints override
      # device functions and the original device. Note that this op's device may
      # still be set by the loop below.
      # TODO(skyewm): why does it override the original device?
    else:
      with _MaybeDevice(original_device):
        graph._apply_device_functions(new_op)  # pylint: disable=protected-access

  # The following loop populates the device field of ops that are colocated
  # with another op.  This is implied by the colocation attribute, but we
  # propagate the device field for completeness.
  for op, coloc_op_list in colocation_pairs.items():
    coloc_device = None
    # Find any device in the list of colocated ops that have a device, if it
    # exists.  We assume that if multiple ops have devices, they refer to the
    # same device.  Otherwise, a runtime error will occur since the colocation
    # property cannot be guaranteed.  Note in TF2 colocations have been removed
    # from the public API and will be considered a hint, so there is no runtime
    # error.
    #
    # One possible improvement is to try to check for compatibility of all
    # devices in this list at import time here, which would require
    # implementing a compatibility function for device specs in python.
    for coloc_op_name in coloc_op_list:
      try:
        coloc_op = graph._get_operation_by_name_unsafe(coloc_op_name)  # pylint: disable=protected-access
      except KeyError:
        # Do not error in TF2 if the colocation cannot be guaranteed
        if tf2.enabled() or control_flow_util.EnableControlFlowV2(graph):
          continue

        raise ValueError(f'Specified colocation to an op: {coloc_op_name} that '
                         f'does not exist during import for op: {op.name}')
      if coloc_op.device:
        coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
        break
    if coloc_device:
      op._set_device(coloc_device)  # pylint: disable=protected-access


def _GetColocationNames(op):
  """Returns names of the ops that `op` should be colocated with.

  Colocation is read from the op's `_class` attr, whose relevant entries have
  the form 'loc:@<op_name>'. Entries naming `op` itself are skipped.

  Returns:
    A list of op names; empty if `op` has no `_class` attr.
  """
  try:
    class_values = op.get_attr('_class')
  except ValueError:
    # No _class attr. Return an empty list (not None) to match the documented
    # return type.
    return []
  colocation_names = []
  for val in class_values:
    val = compat.as_str(val)
    if val.startswith('loc:@'):
      colocation_node_name = val[len('loc:@'):]
      if colocation_node_name != op.name:
        colocation_names.append(colocation_node_name)
  return colocation_names


def _GatherReturnElements(requested_return_elements, graph, results):
  """Returns the requested return elements from results.

  Args:
    requested_return_elements: list of strings of operation and tensor names
    graph: Graph
    results: wrapped TF_ImportGraphDefResults

  Returns:
    list of `Operation` and/or `Tensor` objects
  """
  return_outputs = c_api.TF_ImportGraphDefResultsReturnOutputs(results)
  return_opers = c_api.TF_ImportGraphDefResultsReturnOperations(results)

  # Outputs and operations are returned by the C API as two separate lists in
  # request order; interleave them back according to the original request.
  combined_return_elements = []
  outputs_idx = 0
  opers_idx = 0
  for name in requested_return_elements:
    if ':' in name:
      combined_return_elements.append(
          graph._get_tensor_by_tf_output(return_outputs[outputs_idx]))  # pylint: disable=protected-access
      outputs_idx += 1
    else:
      combined_return_elements.append(
          graph._get_operation_by_tf_operation(return_opers[opers_idx]))  # pylint: disable=protected-access
      opers_idx += 1
  return combined_return_elements


def _SetDefaultAttrValues(node_def, op_def):
  """Set any default attr values in `node_def` that aren't present.

  Args:
    node_def: A `NodeDef` proto, modified in place.
    op_def: The `OpDef` for `node_def.op`.
  """
  assert node_def.op == op_def.name
  for attr_def in op_def.attr:
    if not attr_def.HasField('default_value'):
      continue
    key = attr_def.name
    # An attr is unset when it is absent from the map or present with no
    # `value` oneof populated. Check membership explicitly: indexing a proto
    # map with [] inserts an empty entry rather than returning None.
    if (key not in node_def.attr or
        node_def.attr[key].WhichOneof('value') is None):
      node_def.attr[key].CopyFrom(attr_def.default_value)


@tf_export('graph_util.import_graph_def', 'import_graph_def')
@deprecated_args(None, 'Please file an issue at '
                 'https://github.com/tensorflow/tensorflow/issues if you depend'
                 ' on this feature.', 'op_dict')
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  `tf.Tensor` and `tf.Operation` objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`,
    and None if `return_elements` is None.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  del op_dict
  return _import_graph_def_internal(
      graph_def,
      input_map=input_map,
      return_elements=return_elements,
      name=name,
      producer_op_list=producer_op_list)


def import_graph_def_for_function(  # pylint: disable=invalid-name
    graph_def, name=None):
  """Like import_graph_def but does not validate colocation constraints."""
  return _import_graph_def_internal(
      graph_def, validate_colocation_constraints=False, name=name)


def _import_graph_def_internal(  # pylint: disable=invalid-name
    graph_def,
    input_map=None,
    return_elements=None,
    validate_colocation_constraints=True,
    name=None,
    producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  `tf.Tensor` and `tf.Operation` objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into the
      default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in `graph_def`
      that will be returned as `Operation` objects; and/or tensor names in
      `graph_def` that will be returned as `Tensor` objects.
    validate_colocation_constraints: Whether to validate colocation constraints.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`,
    and None if `return_elements` is None.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(producer_op_list, graph_def)

  graph = ops.get_default_graph()
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # Save unique prefix generated by name_scope
    if scope:
      assert scope.endswith('/')
      prefix = scope[:-1]
    else:
      prefix = ''

    # Generate any input map tensors inside name scope
    input_map = _ConvertInputMapValues(name, input_map)

  scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
  options = scoped_options.options
  _PopulateTFImportGraphDefOptions(options, prefix, input_map, return_elements,
                                   validate_colocation_constraints)

  # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
  # Session.run call cannot occur between creating the TF_Operations in the
  # TF_GraphImportGraphDefWithResults call and mutating them in
  # _ProcessNewOps.
  with graph._mutation_lock():  # pylint: disable=protected-access
    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        with graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
          results = c_api.TF_GraphImportGraphDefWithResults(
              c_graph, serialized, options)
        results = c_api_util.ScopedTFImportGraphDefResults(results)
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e)) from e

    _ProcessNewOps(graph)

  # Create _DefinedFunctions for any imported functions.
  #
  # We do this by creating _DefinedFunctions directly from `graph_def`, and
  # adding them to `graph`. Adding an existing function to a TF_Graph is a
  # no-op, so this only has the effect of updating the Python state (usually
  # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
  #
  # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
  # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
  if graph_def.library and graph_def.library.function:
    functions = function.from_library(graph_def.library)
    for f in functions:
      f.add_to_graph(graph)

  # Treat input mappings that don't appear in the graph as an error, because
  # they are likely to be due to a typo.
  missing_unused_input_keys = (
      c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
          results.results))
  if missing_unused_input_keys:
    missing_unused_input_keys = [
        compat.as_str(s) for s in missing_unused_input_keys
    ]
    missing_keys = ', '.join(missing_unused_input_keys)
    raise ValueError(
        'Attempted to map inputs that were not found in graph_def: '
        f'[{missing_keys}]')

  if return_elements is None:
    return None
  else:
    return _GatherReturnElements(return_elements, graph, results.results)