# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Python front-end support for functions.

NOTE: At this time, functions are experimental and subject to change! Proceed
with caution.
"""

import collections
import hashlib

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.eager import context
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect


# TODO(b/136040013): Drop support for Defun.
class Defun(object):
  """Obsolete. Slated for deletion. Please use tf.function instead.

  Known feature gaps while migrating to tf.function (could be outdated):
  - tf.function doesn't support Send/Recv capability since it doesn't share
    rendezvous with the main graph but always creates a new one.
  - tf.function doesn't support custom gradient functions directly; instead,
    you need to define the function inside a tf.custom_gradient wrapper
    together with the gradient function.
  - Unlike Defun, Keras layers used inside a tf.function need to be created
    only once to avoid variable recreation.
  - Defun respects the device assignments and applies them to the function
    body, but tf.function needs it to be done manually.
  - Defun might prune out unused ops automatically, but tf.function doesn't.

  Limitations of Defun:
  - Original source locations are not preserved, so errors do not include
    full/valid stack traces.
  - Only supports a linear sequence of arguments and return values, putting
    the burden on the caller to pack/unpack everything across a Defun boundary
    into tuples (as opposed to passing list- and dict-like structures
    directly).
  - Does not support overloading or late-bound specializations.
  - Has its own way of defining gradient overrides which does not follow
    current conventions.
  - Cannot support imperative control flow or automatic control dependencies.
  - Does not reflect statefulness in the graph and has a calling convention
    that differs from how more modern tools interact.
  - Is only compatible with graph building mode.

  Decorator used to define TensorFlow functions.

  Use this decorator to make a Python function usable directly as a TensorFlow
  function.

  The decorated function must add ops to the default graph and return zero or
  more `Tensor` objects. Call the decorator with named arguments, one for each
  argument of the function to decorate, with the expected type of the argument
  as value.

  For example, if the function to decorate accepts two `tf.float32` arguments
  named `x` and `y`, call the decorator with:

      @Defun(tf.float32, tf.float32)
      def foo(x, y):
        ...

  When you call the decorated function, it adds the `call` ops to the
  default graph. In addition, it adds the definition of the function into the
  default graph. Because the addition of the function into the graph
  is deferred, the decorator can be used anywhere in the program.

  Any variables created inside of the function are hoisted into the outer
  graph. Note that the variables are created in the variable scope that was
  active during the first call to the function. Subsequent function calls will
  refer to the same set of variables.

  Definitions of functions in a graph are frozen as soon as the graph is used
  to create a session. However, new functions and new calls to existing
  functions may be added to the graph, with the new functions themselves
  becoming immediately frozen.

  Example, but also see the [How To on functions](link_needed).

  ```python
  # Defining the function.
  @tf.Defun(tf.float32, tf.float32)
  def MyFunc(x, y):
    return x + y, x - y

  # Building the graph.
  a = tf.constant([1.0])
  b = tf.constant([2.0])
  c, d = MyFunc(a, b, name='mycall')
  ```
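
  A Python-side gradient override can be attached with `python_grad_func`. A
  minimal sketch (the names `MyRelu` and `_MyReluGrad` are illustrative, not
  part of this API):

  ```python
  # The gradient function receives the forward op and the gradient w.r.t. its
  # output, and must return the gradient w.r.t. the input (the interface
  # expected by tf.RegisterGradient).
  def _MyReluGrad(op, grad):
    return grad * tf.cast(op.inputs[0] > 0.0, tf.float32)

  @tf.Defun(tf.float32, python_grad_func=_MyReluGrad)
  def MyRelu(x):
    return tf.maximum(x, 0.0)
  ```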
  """

  def __init__(self, *input_types, **kwargs):
    """Create a `Defun` decorator.

    Args:
      *input_types: A list of `tf.DType`
      **kwargs: Optional keyword arguments, including
        func_name - (optional). A python string, the name to use to
          declare this `Function` in the graph.

        grad_func - (optional). A function implementing the gradient
          of the function-to-register. This must be a
          `_DefinedFunction` object. The gradient
          function must satisfy the criterion defined in
          function.proto:GradientDef.

        python_grad_func - (optional). A function implementing the
          gradient of the function python-side. This function must
          take the current op and the gradients w.r.t. its outputs,
          and return the gradients w.r.t. the inputs. That is, it must
          implement the interface expected by `tf.RegisterGradient`.
          This will be called by tf.gradients to add the gradient ops
          to the graph. At most one of grad_func and python_grad_func
          can be specified.

        out_names - (optional). A list of strings, one per output
          tensor.

        shape_func - (optional). A function taking the op and returning a list
          of static shapes to set for the function's outputs.
    """
    self._input_types = input_types
    self._func_name = kwargs.pop("func_name", None)
    self._grad_func = kwargs.pop("grad_func", None)
    self._python_grad_func = kwargs.pop("python_grad_func", None)
    self._out_names = kwargs.pop("out_names", None)
    self._extra_kwargs = kwargs

  def __call__(self, func):
    # Various sanity checks on the callable func.
    if not callable(func):
      raise ValueError(f"Function {func} must be a callable.")

    # Func should not use kwargs and defaults.
    argspec = tf_inspect.getargspec(func)
    if argspec.keywords or argspec.defaults:
      raise ValueError(
          "Functions with argument defaults or keyword arguments are not "
          f"supported. {func} has defaults {argspec.defaults} and keywords "
          f"{argspec.keywords}.")

    # Computes how many arguments 'func' has.
    min_args = len(argspec.args)
    max_args = min_args
    if argspec.varargs:
      max_args = 1000000
    argnames = argspec.args
    if tf_inspect.ismethod(func):
      # 1st argument is the "class" type.
      min_args -= 1
      argnames = argnames[1:]

    if self._input_types:
      # If Defun is given a list of types for the inputs, the number
      # of input types should be compatible with 'func'.
      num = len(self._input_types)
      if num < min_args or num > max_args:
        raise ValueError(
            "The number of tf.function input types is not compatible with "
            f"the allowed arguments of {func}. The tf.function has {num} "
            f"input types, while the python function allows a minimum of "
            f"{min_args} and a maximum of {max_args} arguments.")
      return _DefinedFunction(
          func,
          argnames,
          self._input_types,
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # 'func' expects no arguments and input types is an empty list.
    if min_args == 0 and max_args == 0:
      return _DefinedFunction(
          func, [], [],
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # Input types are unknown. It's an overloaded function and hence
    # its definition needs to be deferred until it's called.
    return _OverloadedFunction(
        func,
        argnames,
        self._func_name,
        self._grad_func,
        self._python_grad_func,
        out_names=self._out_names,
        **self._extra_kwargs)
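

# Example of the overloaded case handled above (a sketch, kept as comments
# since this is library code): with no declared input types, the same Python
# function is specialized once per distinct input signature on first call.
#
#   @Defun()
#   def Square(x):
#     return x * x
#
#   a = Square(tf.constant(2.0, dtype=tf.float32))  # instantiates an f32 copy
#   b = Square(tf.constant(2, dtype=tf.int64))      # instantiates an i64 copy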


class _DefinedFunctionDeleter(object):
  """Unregister function from eager context."""

  __slots__ = ["name"]

  def __init__(self, name):
    self.name = name

  def __del__(self):
    try:
      context.remove_function(self.name)
    except TypeError:
      # Suppress some exceptions, mainly for the case when we're running on
      # module deletion. Things that can go wrong include the context module
      # already being unloaded, self._handle._handle_data no longer being
      # valid, and so on. Printing warnings in these cases is silly
      # (exceptions raised from __del__ are printed as warnings to stderr).
      pass  # 'NoneType' object is not callable when the handle has been
      # partially unloaded.
    except AttributeError:
      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
      # been unloaded. Will catch other module unloads as well.


class _DefinedFunction(object):
  """_DefinedFunction encapsulates a function definition and its properties.

  Attributes:
    name: The function name.
    definition: The definition of this function. A FunctionDef proto.
    grad_func_name: If not None, the name of this function's gradient
      function.
    python_grad_func: A python callable implementing the gradient of
      the function python-side.
  """

  def __init__(self,
               func,
               argnames,
               input_types,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               shape_func=None,
               capture_by_value=False,
               allowlisted_stateful_ops=None,
               capture_resource_var_by_value=True,
               **kwargs):
    """Creates _DefinedFunction.

    Args:
      func: A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      input_types: The function's argument types. Can be a tuple or list of
        tf data types.
      func_name: The function name. Defaults to None, in which case it is
        derived from 'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: An optional list of strings for the function return value
        names.
      shape_func: An optional function mapping an op to a list of static
        output shapes.
      capture_by_value: Boolean (defaults to False). If True, captured values
        will be copied into the function body.
      allowlisted_stateful_ops: A set of ops that, if stateful, we ignore and
        copy into the function body, when `capture_by_value` is True.
      capture_resource_var_by_value: Boolean (defaults to True). If False, a
        captured resource variable returns its handle instead of its value.
      **kwargs: The keyword arguments. **kwargs is passed to every call
        site of this function.

    Raises:
      ValueError: The function definition is invalid.
    """
    self._func = func
    self._input_types = input_types
    self._func_name = func_name
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._shape_func = shape_func
    self._capture_by_value = capture_by_value
    self._allowlisted_stateful_ops = allowlisted_stateful_ops
    if self._allowlisted_stateful_ops is None:
      self._allowlisted_stateful_ops = set()
    self._capture_resource_var_by_value = capture_resource_var_by_value
    self._extra_kwargs = kwargs
    # Constructed only when C API is disabled, lazily.
    self._definition = None
    # Constructed only when C API is enabled, lazily.
    self._c_func = None
    self._function_deleter = None
    self._sub_functions = {}  # Constructed with _definition or _c_func.
    # pylint: disable=protected-access
    device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
    # pylint: enable=protected-access

    # Get the innermost device if possible.
    self._caller_device = device_funcs[-1] if device_funcs else None

    # Cached OpDef for this function. When C API is enabled, this is
    # the only part of FunctionDef that we cache in Python. When C API
    # is disabled, the whole _definition is available and this is simply
    # another reference to _definition.signature.
    self._op_def = None

    assert isinstance(input_types, (list, tuple))
    self._arg_types = input_types
    self._arg_names = [argnames[i] if i < len(argnames) else ("arg%d" % i)
                       for i in range(len(input_types))]

  @property
  def name(self):
    """Function name."""
    self._create_definition_if_needed()
    return self._func_name

  @property
  def definition(self):
    """Function definition proto."""
    self._create_definition_if_needed()
    if self._c_func:
      with c_api_util.tf_buffer() as buf:
        with self._c_func.get() as func:
          c_api.TF_FunctionToFunctionDef(func, buf)
          fdef = function_pb2.FunctionDef()
          proto_data = c_api.TF_GetBuffer(buf)
          fdef.ParseFromString(compat.as_bytes(proto_data))
          with ops.init_scope():
            if context.executing_eagerly():
              context.add_function(func)
              self._function_deleter = _DefinedFunctionDeleter(
                  fdef.signature.name)
      return fdef
    return self._definition

  @property
  def _signature(self):
    self._create_definition_if_needed()
    return self._op_def

  def set_grad_func(self, grad_func):
    """Specifies the gradient function of this function."""
    assert not self._grad_func
    assert isinstance(grad_func, _DefinedFunction)
    self._grad_func = grad_func

  @property
  def grad_func_name(self):
    """Returns the name of the gradient function."""
    return self._grad_func.name if self._grad_func else None

  @property
  def python_grad_func(self):
    """Python gradient function callable."""
    return self._python_grad_func

  @property
  def declared_input_types(self):
    """Returns the list of data types of explicitly declared inputs."""
    return self._input_types

  @property
  def captured_inputs(self):
    """Returns the list of implicitly captured inputs."""
    self._create_definition_if_needed()
    return self._extra_inputs

  @property
  def stateful_ops(self):
    """Returns the list of stateful ops in function definition.

    Returns:
      A list of (op.name, op.type) pairs.
    """
    self._create_definition_if_needed()
    return self._stateful_ops

  def _create_definition_if_needed(self):
    """Creates the function definition if it's not created yet."""
    with context.graph_mode():
      self._create_definition_if_needed_impl()

  def _create_definition_if_needed_impl(self):
    """This is not what you want, see _create_definition_if_needed."""
    if self._definition is not None or self._c_func is not None:
      return

    # Copy variable collections (by reference) from the parent graph such that
    # name-based variable sharing (e.g. via tf.make_template) works between
    # the func graph and parent graph.
    variable_keys = []
    variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
    variable_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access

    parent_graph = ops.get_default_graph()
    collections_ref = {
        key: parent_graph.get_collection_ref(key) for key in variable_keys}

    temp_graph = func_graph_from_py_func(
        self._func,
        self._arg_names,
        self._arg_types,
        self._func_name,
        self._capture_by_value,
        self._caller_device,
        collections_ref=collections_ref,
        allowlisted_stateful_ops=self._allowlisted_stateful_ops,
        capture_resource_var_by_value=self._capture_resource_var_by_value)

    self._extra_inputs = temp_graph.extra_inputs
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    if self._func_name:
      base_func_name = self._func_name
    else:
      base_func_name = function_utils.get_func_name(self._func)
      if self._grad_func:
        base_func_name += ("_%s" % self._grad_func.name)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    # FIXME(feyu): C API is always enabled now. The if-true branch never runs.
    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef.
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          temp_graph.inputs,
          temp_graph.outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use. If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      with temp_graph._c_graph.get() as c_graph:
        c_func = c_api.TF_GraphToFunction_wrapper(
            c_graph,
            base_func_name,
            self._func_name is None,  # append_hash_to_fn_name
            None,  # opers
            [t._as_tf_output() for t in temp_graph.inputs],
            [t._as_tf_output() for t in temp_graph.outputs],
            output_names,
            [],  # control_outputs
            [],  # control_output_names
            None,  # opts
            description)
      self._c_func = c_api_util.ScopedTFFunction(c_func, base_func_name)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

    # Set cached fields: _op_def and _func_name (if not already set).
    self._op_def = self.definition.signature
    if self._func_name:
      assert self._func_name == self._op_def.name
    else:
      self._func_name = compat.as_str(self._op_def.name)

    self._stateful_ops = [(op.name, op.type)
                          for op in temp_graph.get_operations()
                          if op._is_stateful]  # pylint: disable=protected-access

  def _set_c_attrs(self, attrs):
    """Sets `attrs` as attributes of self._c_func.

    Requires that self._c_func is not None.

    Args:
      attrs: a dictionary from attribute name to attribute proto value.
    """
    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use the same status.
      with self._c_func.get() as func:
        c_api.TF_FunctionSetAttrValueProto(func, compat.as_str(name),
                                           serialized)

  def _create_hash_str(self, input_arg, output_arg, node_def):
    """Creates an 8-character string unique to this input.

    Args:
      input_arg: the input_arg field of an OpDef
        (e.g. self._definition.signature.input_arg)
      output_arg: the output_arg field of an OpDef
        (e.g. self._definition.signature.output_arg)
      node_def: the node_def field of a FunctionDef
        (e.g. self._definition.node_def)

    Returns:
      The unique string for this input.
    """
    hasher = hashlib.sha1()

    def update_num(n):
      hasher.update(compat.as_bytes("%x" % n))

    def update_str(s):
      update_num(len(s))
      hasher.update(compat.as_bytes(s))

    def update_strs(slist):
      update_num(len(slist))
      for s in slist:
        update_str(s)

    for adef in input_arg:
      update_str(adef.SerializeToString())

    for adef in output_arg:
      update_str(adef.SerializeToString())

    for n in sorted(node_def, key=lambda n: n.name):
      update_str(n.name)
      update_str(n.op)
      update_strs(n.input)
      update_num(len(n.attr))
      # NOTE: protobuf map serialization does not guarantee ordering.
      for k in sorted(n.attr):
        update_str(k)
        update_str(n.attr[k].SerializeToString())

    return hasher.hexdigest()[:8]

  def add_to_graph(self, g):
    """Adds this function into the graph g."""
    self._create_definition_if_needed()

    # Adds this function into 'g'.
    # pylint: disable=protected-access
    if context.executing_eagerly():
      context.context().add_function_def(self.definition)
    else:
      g._add_function(self)
    # pylint: enable=protected-access

    # Ensures related sub-routines are defined in 'g', too.
    for f in self._sub_functions.values():
      f.add_to_graph(g)

    # Adds its gradient function, too.
    if self._grad_func:
      self._grad_func.add_to_graph(g)

  def __call__(self, *args, **kwargs):
    self.add_to_graph(ops.get_default_graph())
    args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs
    ret, op = _call(self._signature, *args, **kwargs)

    # Set a hidden attr in 'op' so that gradients_impl can refer back
    # to this _DefinedFunction instance to access python_grad_func.
    assert isinstance(op, ops.Operation)
    setattr(op, "__defun", self)

    if self._shape_func is not None:
      shapes = self._shape_func(op)
      if len(shapes) != len(op.outputs):
        raise ValueError(f"shape_func {self._shape_func} produced "
                         f"{len(shapes):d} shapes, which does not match "
                         f"{len(op.outputs)} outputs.")
      for (t, shape) in zip(op.outputs, shapes):
        t.set_shape(shape)
    return ret
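

# Example of the `shape_func` hook used in __call__ above (a sketch, kept as
# comments since this is library code): declare that the single output of a
# Defun has the same static shape as its first input.
#
#   @Defun(tf.float32, shape_func=lambda op: [op.inputs[0].get_shape()])
#   def Plus1(x):
#     return x + 1.0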


class _OverloadedFunction(object):
  """_OverloadedFunction encapsulates an overloaded function.

  _OverloadedFunction maintains a mapping from input types to
  instantiated _DefinedFunction in self._overload.
  """

  def __init__(self,
               func,
               argnames,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               **kwargs):
    """Creates an _OverloadedFunction.

    Args:
      func: A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      func_name: The function name. Defaults to None, in which case it is
        derived from 'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: A list of strings for the function return value names.
      **kwargs: The keyword arguments. **kwargs is passed to every call
        site of this function.

    Raises:
      ValueError: The function definition is invalid.
    """
    self._func = func
    self._argnames = argnames
    self._func_name = func_name
    assert grad_func is None or isinstance(grad_func, _OverloadedFunction)
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._extra_kwargs = kwargs
    self._overload = {}

  def instantiate(self, input_types):
    """Instantiate this function given input argument types.

    Args:
      input_types: A list of data types for the inputs.

    Returns:
      _DefinedFunction for the given input types.
    """
    # Stringify the type list.
    key = _type_list_to_str(input_types)
    defined = self._overload.get(key)
    if not defined:
      # If not defined yet, define the function given the input types.
      name = self._func_name
      if name is not None:
        name = "_".join([name, key])
      defined = _DefinedFunction(
          self._func,
          self._argnames,
          input_types,
          name,
          None,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)
      _ = defined.name  # Fully instantiate the function definition.
      if self._grad_func:
        # If _grad_func is given, it is another
        # _OverloadedFunction. We need to instantiate it with the
        # right input types.
        output_types = [
            dtypes.DType(_.type) for _ in defined._signature.output_arg  # pylint: disable=protected-access
        ]
        # pylint: disable=protected-access
        defined._grad_func = self._grad_func.instantiate(input_types +
                                                         output_types)
        # pylint: enable=protected-access
      self._overload[key] = defined
    return defined

  def __call__(self, *args, **kwargs):
    input_types = []
    args = list(args)
    for (i, x) in enumerate(args):
      x = ops.convert_to_tensor(x)
      if not isinstance(x, ops.Tensor):
        raise ValueError(f"Expected a Tensor but got {x} with type {type(x)}.")
      input_types.append(x.dtype)
      args[i] = x
    return self.instantiate(input_types)(*args, **kwargs)


class _FuncGraph(ops.Graph):
  """A helper for constructing a function.

  _FuncGraph overrides ops.Graph's create_op() so that we can keep
  track of all inputs into every op created inside the function. If
  any input is from another graph, we keep track of it in self.capture
  and substitute the input with a placeholder.

  Each captured input's corresponding placeholder is converted into a
  function argument and the caller passes in the captured tensor.
  """

  def __init__(self, name, capture_by_value, allowlisted_stateful_ops,
               capture_resource_var_by_value, *args, **kwargs):
    super(_FuncGraph, self).__init__(*args, **kwargs)
    self._capture_by_value = capture_by_value
    self._allowlisted_stateful_ops = allowlisted_stateful_ops
    self._capture_resource_var_by_value = capture_resource_var_by_value
    self._building_function = True
    self._outer_graph = ops.get_default_graph()
    self._vscope = vs.get_variable_scope()
    self._old_custom_getter = self._vscope.custom_getter

    # The name of the function.
    self.name = name
    # Placeholder tensors representing the inputs to this function. The
    # tensors are in this _FuncGraph.
    self.inputs = []
    # Tensors that will be returned by this function. The tensors are in this
    # _FuncGraph.
    self.outputs = []
    # Maps external tensor -> internal tensor (e.g. input placeholder).
    self._captured = {}
    # The external tensors that have been captured as inputs and must be
    # passed to this function (empty if capturing by value, otherwise these
    # are the keys of _captured).
    self.extra_inputs = []
    # Input placeholders that have been added for captured values (empty if
    # capturing by value).
    self.extra_args = []
    # Captured variables.
    # TODO(skyewm): is this needed?
    self.extra_vars = []

  # pylint: disable=g-doc-return-or-yield

  @property
  def outer_graph(self):
    """The graph active when this _FuncGraph was created."""
    return self._outer_graph

  @tf_contextlib.contextmanager
  def container(self, container_name):
    """Returns a context manager that specifies the resource container to use.

    Overridden from `tf.Graph` to update both the init_scope container
    and the present inner container. This is necessary to make sure setting
    containers applies correctly both to created variables and to stateful
    ops.

    Args:
      container_name: container name string.

    Returns:
      A context manager for defining resource containers for stateful ops,
      yields the container name.
    """
    original_container = self._container
    # pylint: disable=protected-access
    with ops.init_scope():
      original_init_container = ops.get_default_graph()._container
    try:
      self._container = container_name
      with ops.init_scope():
        ops.get_default_graph()._container = container_name
      yield self._container
    finally:
      self._container = original_container
      with ops.init_scope():
        ops.get_default_graph()._container = original_init_container
    # pylint: enable=protected-access

  # pylint: enable=g-doc-return-or-yield
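
  # Illustration of the variable hoisting performed by `getvar` below (a
  # sketch, TF1-style API; the names are illustrative):
  #
  #   @Defun(tf.float32)
  #   def Dense(x):
  #     w = tf.get_variable("w", shape=[], dtype=tf.float32)
  #     return x * w   # `w` is created in the outer graph's variable scope,
  #                    # not inside the function body.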

  def getvar(
      self,
      getter,
      name,
      shape=None,
      dtype=None,
      initializer=None,
      reuse=None,
      trainable=True,
      collections=None,  # pylint: disable=redefined-outer-name
      use_resource=None,
      **kwargs):
    """A custom variable getter."""
    # Here, we switch the default graph to the outer graph and ask the
    # variable scope in which the function is defined to give us the
    # variable. The variable is stashed in extra_vars and returned to
    # the caller.
    #
    # We capture these variables so that the variable definition is
    # hoisted upward to the outermost graph.
    with self._outer_graph.as_default():
      # pylint: disable=protected-access
      var = self._vscope.get_variable(
          vs._get_default_variable_store(),
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          use_resource=use_resource)
      self.extra_vars.append(var)
      if (isinstance(var, resource_variable_ops.BaseResourceVariable) and
          self._capture_resource_var_by_value):
        # For resource-based variables, read the variable outside the function
        # and pass in the value. This ensures that the function is pure and
        # differentiable. TODO(apassos) this may have performance problems if
        # the function will only do embedding lookups on the variable.
        return var.value()
      return var

  def _create_op_internal(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    for i, x in enumerate(inputs):
      if isinstance(x, ops.EagerTensor) or x.graph is not self:
        inputs[i] = self.capture(x)
    return super(_FuncGraph, self)._create_op_internal(
        op_type,
        inputs,
        dtypes=dtypes,
        input_types=input_types,
        name=name,
        attrs=attrs,
        op_def=op_def,
        compute_device=compute_device)

  def capture(self, tensor, name=None):
    """Adds the given tensor to this graph and returns the captured tensor."""
    if tensor.ref() in self._captured:
      # Captured already.
      return self._captured[tensor.ref()]
    elif self._capture_by_value:
      return self._add_tensor_and_parents(tensor)
    else:
      return self._capture_tensor_as_extra_input(tensor, name)
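
  # Illustration of implicit capture via `capture` above (a sketch): when a
  # Defun body reads a tensor from the outer graph, a placeholder is created
  # here and the outer tensor is recorded in extra_inputs so that callers
  # pass it at each call site.
  #
  #   x = tf.constant(10.0)   # lives in the outer graph
  #
  #   @Defun(tf.float32)
  #   def AddX(y):
  #     return x + y          # `x` is captured as an extra input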

  @property
  def captures(self):
    """Pairs of tensors and captured tensor."""
    return [(k.deref(), v) for k, v in self._captured.items()]

  def _capture_tensor_as_extra_input(self, tensor, name=None):
    # Substitute with a placeholder.
    self.extra_inputs.append(tensor)
    # Hoist the new input placeholder out of any control flow context
    # we're currently in.
    with ops.control_dependencies(None):
      ph = array_ops.placeholder(
          tensor.dtype, shape=tensor.get_shape(), name=name)
    # pylint: disable=protected-access
    if isinstance(tensor, ops.EagerTensor):
      handle_data = tensor._handle_data
      if handle_data:
        handle_data = handle_data.SerializeToString()
    else:
      with tensor.graph._c_graph.get() as c_graph:
        handle_data = c_api.GetHandleShapeAndType(c_graph,
                                                  tensor._as_tf_output())

    if handle_data:
      with ph.graph._c_graph.get() as c_graph:
        c_api.SetHandleShapeAndType(c_graph, ph._as_tf_output(),
                                    compat.as_bytes(handle_data))
    # pylint: enable=protected-access
    self.inputs.append(ph)
    self._captured[tensor.ref()] = ph
    self.extra_args.append(ph)
    if _is_guaranteed_const(tensor):
      with ops.control_dependencies(None):
        return array_ops.guarantee_const(ph)
    else:
      return ph

  def _add_tensor_and_parents(self, tensor):
    op = self._add_op_and_parents(tensor.op)
    return op.outputs[tensor.value_index]

  def _add_op_and_parents(self, op):
    # pylint: disable=protected-access
    op_def = graph_to_function_def._get_op_def(op)
    if op._is_stateful and op not in self._allowlisted_stateful_ops:
      raise ValueError(f"Cannot capture a stateful node (name:{op.name}, "
                       f"type:{op.type}) by value.")
    elif op.type in ("Placeholder", "PlaceholderV2"):
      raise ValueError(f"Cannot capture a placeholder (name:{op.name}, "
                       f"type:{op.type}) by value.")
    # pylint: enable=protected-access

    captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]

    captured_op = self._create_op_internal(
        op.type,
        captured_inputs, [o.dtype for o in op.outputs],
        name=op.name,
        attrs=op.node_def.attr,
        op_def=op_def)

    for t, captured_t in zip(op.outputs, captured_op.outputs):
      self._captured[t.ref()] = captured_t

    return captured_op


def func_graph_from_py_func(func,
                            arg_names,
                            arg_types,
                            name=None,
                            capture_by_value=False,
                            device=None,
                            colocation_stack=None,
                            container=None,
                            collections_ref=None,
                            arg_shapes=None,
                            allowlisted_stateful_ops=None,
                            capture_resource_var_by_value=True):
  """Returns a _FuncGraph generated from `func`.

  Args:
    func: A Python callable which constructs a TF function body. The arguments
      must correspond to `arg_types`. Returns a value or list/tuple of values.
      No returned value can be None.
    arg_names: A sequence of strings for the function argument names.
    arg_types: A sequence of the function's argument types.
    name: The function name. If None, the name is derived from `func`.
    capture_by_value: boolean. If True, captured values will be copied into
      the function body.
    device: device name or function.
    colocation_stack: A colocation stack (list) the _FuncGraph should use.
    container: A container name the _FuncGraph should start with.
    collections_ref: A reference to a collections dict the _FuncGraph should
      use internally.
    arg_shapes: A sequence of the function's argument shapes.
    allowlisted_stateful_ops: A set of ops that, if stateful, we ignore and
      re-create.
    capture_resource_var_by_value: Boolean (defaults to True). If False, a
      captured resource variable returns its handle instead of its value.

  Returns:
    A _FuncGraph.

  Raises:
    ValueError: if func returns None.
  """
  if not name:
    name = function_utils.get_func_name(func)
  func_graph = _FuncGraph(name, capture_by_value, allowlisted_stateful_ops,
                          capture_resource_var_by_value)

  with func_graph.as_default(), ops.device(device):
    # pylint: disable=protected-access
    if collections_ref is not None:
      func_graph._collections = collections_ref
    if container is not None:
      func_graph._container = container
    if colocation_stack is not None:
      func_graph._colocation_stack = colocation_stack
    # pylint: enable=protected-access

    if arg_shapes is None:
      arg_shapes = [None] * len(arg_types)

    # Create placeholders for the function arguments.
    for (argname, argtype, argshape) in zip(arg_names, arg_types, arg_shapes):
      argholder = array_ops.placeholder(argtype, shape=argshape, name=argname)
      func_graph.inputs.append(argholder)
    # Call func and gather the output tensors.
    with vs.variable_scope("", custom_getter=func_graph.getvar):
      outputs = func(*func_graph.inputs)

    # There is no way of distinguishing between a function not returning
    # anything and a function returning None in Python.
    # We need to allow the former and ideally want to forbid the latter as
    # it is most likely user error.
    # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
    # allow users to explicitly mark the function as not returning anything.
    # For now, we allow a single None return and interpret it as a function
    # with no output.
    if outputs is None:
      outputs = []
    else:
      # If func only returned one value, make it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)
      if any(_ is None for _ in outputs):
        raise ValueError(f"Function {name} cannot return None.")
    # Ensures each output is a Tensor in the function graph.
    outputs = [ops.convert_to_tensor(t) for t in outputs]
    outputs = [func_graph.capture(t) if t.graph is not func_graph else t
               for t in outputs]
    func_graph.outputs = outputs
  return func_graph
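

# Sketch of how the helper above is driven internally by _DefinedFunction
# (illustrative, not a public API):
#
#   fg = func_graph_from_py_func(
#       lambda x, y: x + y, ["x", "y"], [dtypes.float32, dtypes.float32])
#   fg.inputs    # two float32 placeholders living in `fg`
#   fg.outputs   # [the x + y tensor]
#   fg.captures  # pairs of (outer tensor, inner placeholder), if any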


def _is_guaranteed_const(tensor):
  """Determines whether `tensor` is guaranteed to be a constant.

  A tensor is guaranteed to be a constant if either it was produced by
  a `GuaranteeConst` op or if all of its inputs are guaranteed to be
  constants.

  Args:
    tensor: The tensor for which to determine const-ness.

  Returns:
    True if `tensor` is guaranteed to be a constant, False otherwise.
  """

  if isinstance(tensor, ops.EagerTensor):
    return False

  class Work(object):

    def __init__(self, op, leaving):
      self.op = op
      self.leaving = leaving

  is_guaranteed_const = lambda op: op.node_def.op == "GuaranteeConst"
  constants = set([])

  def all_inputs_const(op):
    # If all inputs of an op are guaranteed constants, then we can infer that
    # the op produces a constant as well.
    return op.inputs and all(inp.op in constants for inp in op.inputs)

  visited = set([])
  stack = [Work(tensor.op, leaving=False)]
  while stack:
    work = stack.pop()
    if work.leaving:
      if all_inputs_const(work.op):
        constants.add(work.op)
      continue
    visited.add(work.op)
    if is_guaranteed_const(work.op):
      constants.add(work.op)
      continue

    # This op will be revisited after all its inputs are checked for
    # const-ness.
    stack.append(Work(work.op, leaving=True))
    for inp in work.op.inputs:
      if inp.op not in visited:
        stack.append(Work(inp.op, leaving=False))
  return tensor.op in constants
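

# Example of the walk above (a sketch): the post-order pass marks an op
# constant when it is a GuaranteeConst or when every one of its input ops is
# already marked constant.
#
#   c = array_ops.guarantee_const(constant_op.constant(1.0))
#   d = c + c                    # every input of `d` is guaranteed const
#   _is_guaranteed_const(d)      # True
#   _is_guaranteed_const(array_ops.placeholder(dtypes.float32))  # False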


def _call(sig, *inputs, **kwargs):
  """Adds a node calling a function.

  This adds a `call` op to the default graph that calls the function
  of signature `sig`, passing the tensors in `inputs` as arguments.
  It returns the outputs of the call, which are one or more tensors.

  `sig` is the `OpDef` signature of a `_DefinedFunction` object.

  You can pass an optional keyword parameter `name=string` to name the
  added operation.

  You can pass an optional keyword parameter `noinline=True|False` to
  instruct the runtime not to inline the function body into the call
  site.

  Args:
    sig: OpDef. The signature of the function.
    *inputs: arguments to the function.
    **kwargs: Optional keyword arguments. Can only contain 'name' or
      'noinline'.

  Returns:
    A 2-element tuple. First element: a Tensor if the function returns a
    single value; a list of Tensors if the function returns multiple values;
    the Operation if the function returns no values. Second element: the
    Operation.

  Raises:
    ValueError: if the arguments are invalid.
  """
  if len(inputs) != len(sig.input_arg):
    raise ValueError(f"Expected {len(sig.input_arg):d} arguments, got "
                     f"{len(inputs):d}.")
  name = kwargs.pop("name", None)
  g = ops.get_default_graph()
  func_name = sig.name
  if name is None:
    name = func_name
  attrs = _parse_kwargs_as_attrs(func_name, **kwargs)
  output_types = [dtypes.DType(x.type) for x in sig.output_arg]
  op = g._create_op_internal(  # pylint: disable=protected-access
      func_name, list(inputs), output_types, name=name, attrs=attrs,
      op_def=sig)
  if op.outputs:
    if len(op.outputs) == 1:
      ret = op.outputs[0]
    else:
      ret = tuple(op.outputs)
  else:
    ret = op
  return ret, op
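

# Sketch of what _DefinedFunction.__call__ produces via _call (illustrative,
# reusing the MyFunc example from the Defun docstring):
#
#   c, d = MyFunc(a, b, name="mycall", noinline=True)
#
# adds one op named "mycall" whose op type is MyFunc's signature name, with
# the attr _noinline=true set, and returns its two output tensors.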


def _from_definition(fdef, grad_func=None):
  """Creates a _DefinedFunction initialized from a FunctionDef proto.

  Args:
    fdef: a FunctionDef
    grad_func: a _DefinedFunction or None

  Returns:
    A _DefinedFunction representing fdef
  """
  # TODO(iga): This method does major surgery on _DefinedFunction.
  # Make it a named constructor using @classmethod of _DefinedFunction.

  # The Python callable is only needed to create a FunctionDef. Since we have
  # the FunctionDef here, we don't need to set _DefinedFunction._func (nor do
  # we have access to such a callable here).
  func = None
  argnames = [arg.name for arg in fdef.signature.input_arg]
  input_types = tuple(
      dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg)
  func_name = fdef.signature.name
  # Note: FunctionDefs do not include python gradient functions, so if the
  # original _DefinedFunction included one it will not be reflected here.
  python_grad_func = None
  out_names = [arg.name for arg in fdef.signature.output_arg]
  result = _DefinedFunction(func, argnames, input_types, func_name, grad_func,
                            python_grad_func, out_names)
  # pylint: disable=protected-access
  serialized = fdef.SerializeToString()
  c_func = c_api.TF_FunctionImportFunctionDef(serialized)
  result._c_func = c_api_util.ScopedTFFunction(c_func, func_name)
  result._extra_inputs = []
  result._op_def = fdef.signature
  # pylint: enable=protected-access

  return result


def from_library(lib):
  """Creates _DefinedFunctions initialized from a FunctionDefLibrary proto.

  This method handles assigning the correct gradient functions to each
  function.

  Args:
    lib: a FunctionDefLibrary

  Returns:
    A list of _DefinedFunctions

  Raises:
    ValueError: `lib` is invalid
  """
  if not lib.function and not lib.gradient:
    return []

  # function name -> FunctionDef proto
  funcs = {fdef.signature.name: fdef for fdef in lib.function}

  # Validate that all referenced function names have function defs.
  for g in lib.gradient:
    if g.function_name not in funcs:
      raise ValueError(f"FunctionDefLibrary missing '{g.function_name}' "
                       f"FunctionDef\n{lib}")
    if g.gradient_func not in funcs:
      raise ValueError(f"FunctionDefLibrary missing '{g.gradient_func}' "
                       f"FunctionDef\n{lib}")

  # function name -> gradient function name
  func_to_grad = collections.defaultdict(lambda: None)
  # gradient function name -> names of functions having that grad function
  grad_to_funcs = collections.defaultdict(list)

  for gdef in lib.gradient:
    func_to_grad[gdef.function_name] = gdef.gradient_func
    grad_to_funcs[gdef.gradient_func].append(gdef.function_name)

  # Start with functions without gradients.
  ready = [
      fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
  ]
  if not ready:
    raise ValueError(
        f"FunctionDefLibrary contains cyclic gradient functions!\n{lib}")
  # function name -> _DefinedFunction
  initialized = {}

  while ready:
    fdef = ready.pop()
    name = fdef.signature.name

    grad = initialized.get(func_to_grad[name])
    if func_to_grad[name]:
      assert grad
    defined_func = _from_definition(fdef, grad_func=grad)
    initialized[name] = defined_func

    ready.extend(funcs[f] for f in grad_to_funcs[name])

  return initialized.values()
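

# Sketch of the gradient wiring performed by from_library above
# (illustrative): given a library with FunctionDefs "F" and "FGrad" and a
# GradientDef mapping F -> FGrad, "FGrad" (having no gradient of its own) is
# initialized first, then "F" is initialized with grad_func set to the
# _DefinedFunction for "FGrad".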


def _get_experimental_kwarg_as_attr(attr_name, value):
  """Creates an AttrValue for a python object."""
  if isinstance(value, bool):
    return attr_value_pb2.AttrValue(b=value)
  elif isinstance(value, int):
    return attr_value_pb2.AttrValue(i=value)
  elif isinstance(value, float):
    return attr_value_pb2.AttrValue(f=value)
  elif isinstance(value, str):
    return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
  else:
    raise ValueError(f"Attribute {attr_name} must be bool, int, float, or "
                     f"str. Got {type(value)}.")


def _get_kwarg_as_str_attr(attr_name, value):
  """Creates an AttrValue for a python object."""
  if isinstance(value, str):
    return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
  else:
    raise ValueError(f"Attribute {attr_name} must be str. Got {type(value)}.")


def _parse_kwargs_as_attrs(func_name, **kwargs):
  """Parses **kwargs into a node's attributes."""
  attrs = {}

  noinline = kwargs.pop("noinline", None)
  if noinline is not None:
    attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline))

  # For compatibility with previous behavior, Defun does not perform shape
  # inference through its function call operations.
  attrs["_disable_call_shape_inference"] = attr_value_pb2.AttrValue(b=True)

  compiled = kwargs.pop("compiled", None)
  separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None)
  if compiled is not None:
    attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled))
    attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue(
        b=bool(separate_compiled_gradients))
    # Forward _XlaScope from enclosing context (if set), otherwise create new.
    # pylint: disable=protected-access
    if "_XlaScope" in ops.get_default_graph()._attr_scope_map:
      attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"]
    else:
      attrs["_XlaScope"] = attr_value_pb2.AttrValue(
          s=("function_%s" % func_name).encode())
    # pylint: enable=protected-access

  kwargs_keys = list(kwargs.keys())
  for key in kwargs_keys:
    if key.startswith("experimental_"):
      attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs[key])
      del kwargs[key]
    # Support for https://github.com/tensorflow/community/pull/113/files.
    elif key == "_implements" or key == "_reference":
      attrs[key] = _get_kwarg_as_str_attr(key, kwargs[key])
      del kwargs[key]
  if kwargs:
    raise ValueError(f"Unknown keyword arguments: {kwargs.keys()}.")
  return attrs
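

# Example of the attr parsing above (a sketch):
#
#   attrs = _parse_kwargs_as_attrs("MyFunc", noinline=True,
#                                  experimental_tag="v1")
#   attrs["_noinline"].b                       # True
#   attrs["experimental_tag"].s                # b"v1"
#   attrs["_disable_call_shape_inference"].b   # True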


def get_extra_vars():
  """Returns the variables captured by the function.

  Returns:
    If the default graph is being used to define a function, the
    returned list of variables are those created inside the function
    body so far. Otherwise, returns an empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_vars
  else:
    return []


def get_extra_inputs():
  """Returns the input tensors captured by the function.

  Returns:
    If the default graph is being used to define a function, the
    returned list of tensors are those accessed inside the function body
    but defined outside the function body so far. Otherwise, returns an
    empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_inputs
  else:
    return []


def get_extra_args():
  """Returns the corresponding function arguments for the captured inputs.

  Returns:
    If the default graph is being used to define a function, the
    returned list of placeholders are those used inside the function
    body corresponding to those returned by get_extra_inputs(). Otherwise,
    returns an empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_args
  else:
    return []


def _type_list_to_str(types):
  if any(_ not in _DTYPE_TO_STR for _ in types):
    unsupported_types = [type_ for type_ in types if type_ not in _DTYPE_TO_STR]
    raise ValueError(f"Unsupported dtypes {unsupported_types} in "
                     "`types`. Supported dtypes are "
                     f"{_DTYPE_TO_STR.keys()}.")
  return "".join(_DTYPE_TO_STR[_] for _ in types)


# NOTE: The list needs to be extended when more data types are added.
_DTYPE_TO_STR = {
    dtypes.float16: "f16",
    dtypes.float32: "f32",
    dtypes.float64: "f64",
    dtypes.int32: "i32",
    dtypes.uint8: "u8",  # was "i8", which collided with dtypes.int8 below
    dtypes.uint16: "u16",
    dtypes.uint32: "u32",
    dtypes.uint64: "u64",
    dtypes.int16: "i16",
    dtypes.int8: "i8",
    dtypes.string: "s",
    dtypes.complex64: "c64",
    dtypes.complex128: "c128",
    dtypes.int64: "i64",
    dtypes.bool: "b",
    dtypes.qint8: "qi8",
    dtypes.quint8: "qu8",
    dtypes.qint16: "qi16",
    dtypes.quint16: "qu16",
    dtypes.qint32: "qi32",
    dtypes.bfloat16: "b16"
}
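
# Example of the key encoding above (illustrative):
#
#   _type_list_to_str([dtypes.float32, dtypes.int32])  # -> "f32i32"
#
# _OverloadedFunction.instantiate uses this string both to key its overload
# cache and as the suffix of the specialized function's name.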