# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import collections
import copy
import functools
import itertools
import multiprocessing.pool
import os
import re
import sys
import time
import weakref

from absl.testing import parameterized
import numpy

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.lang import directives
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function as tf_function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.layers import convolutional
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_sendrecv_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.structured import structured_tensor
from tensorflow.python.platform import test
from tensorflow.python.saved_model.load import load
from tensorflow.python.saved_model.save import save
from tensorflow.python.training import training_ops
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect

try:
  import attr  # pylint:disable=g-import-not-at-top
except ImportError:
  attr = None


def total_function_cache(defined):
  """Lists every concrete function currently traced by `defined`.

  Args:
    defined: a `def_function.function` / `function.defun` wrapper.

  Returns:
    The result of `defined._list_all_concrete_functions()`; the tests in this
    file only inspect its length.
  """
  list_concrete = defined._list_all_concrete_functions  # pylint: disable=protected-access
  return list_concrete()


def _example_indexed_slices_with_dense_shape():
  """Builds a small `IndexedSlices` that carries an explicit dense_shape."""
  values = constant_op.constant([1, 2])
  indices = constant_op.constant([0, 1])
  dense_shape = constant_op.constant([2])
  return indexed_slices.IndexedSlices(values, indices, dense_shape)


def _example_indexed_slices_without_dense_shape():
  """Builds a small `IndexedSlices` with no dense_shape argument."""
  values = constant_op.constant([1, 2])
  indices = constant_op.constant([0, 1])
  return indexed_slices.IndexedSlices(values, indices)


def _spec_for_value(value):
  """Returns the (nested) TypeSpec for a value.

  Tensors and composite tensors are replaced by their `TypeSpec`; nested
  structures are mapped element-wise; any other leaf is returned unchanged.
  """
  if not nest.is_nested(value):
    if isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)):
      return type_spec.type_spec_from_value(value)
    return value
  return nest.map_structure(_spec_for_value, value)


# This dummy decorator imitates ordinary decorators utilizing tf_decorator.
def dummy_tf_decorator(method):
  """Wraps `method` in a pass-through wrapper registered via tf_decorator."""

  def wrapper(*args, **kwargs):
    # Forward every call unchanged; only the decorator registration matters.
    return method(*args, **kwargs)

  return tf_decorator.make_decorator(method, wrapper)


# TODO(mdan): Organize these tests.
class FunctionTest(test.TestCase, parameterized.TestCase):
  """Tests for tracing and executing `def_function.function` / `defun`."""

  def setUp(self):
    """Splits the first physical CPU into four logical CPU devices."""
    super().setUp()
    first_cpu = config.list_physical_devices('CPU')[0]
    # Four virtual CPUs, so tests can place ops on e.g. '/cpu:1', '/cpu:2'.
    logical_cpus = [context.LogicalDeviceConfiguration() for _ in range(4)]
    config.set_logical_device_configuration(first_cpu, logical_cpus)

  def testBasic(self):
    """A matmul wrapped directly in `def_function.function` computes correctly."""
    traced_matmul = def_function.function(math_ops.matmul)
    mat = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    first = traced_matmul(mat, mat, transpose_a=True)
    second = traced_matmul(first, mat, transpose_a=True)
    self.assertAllEqual(first.numpy().reshape(-1), [10, 14, 14, 20])
    self.assertAllEqual(second.numpy().reshape(-1), [52, 76, 74, 108])

  def testPythonFunctionNotCallable(self):
    """Wrapping a non-callable object raises a TypeError."""
    not_callable = 1
    with self.assertRaisesRegex(TypeError, 'is not a callable object'):
      def_function.function(not_callable)

  def testOnExitCallback(self):
    """Exit callbacks run once per trace, after tracing, inner function first."""
    recorded = []

    def append_1():
      recorded.append(1)

    def append_2():
      recorded.append(2)

    def g(x):
      snapshot = list(recorded)
      ops.add_exit_callback_to_default_func_graph(append_1)
      # Registering must not run the callback while the graph is being built.
      self.assertEqual(snapshot, recorded)
      return x + 1

    tf_g = def_function.function(g)

    def f(x):
      snapshot = list(recorded)
      ops.add_exit_callback_to_default_func_graph(append_2)
      self.assertEqual(snapshot, recorded)
      return tf_g(x)

    tf_f = def_function.function(f)
    self.assertEmpty(recorded)
    tf_f(constant_op.constant(1.0))
    self.assertEqual(recorded, [1, 2])  # Once for g, once for f.
    tf_f(constant_op.constant([1.0]))  # force a retrace
    self.assertEqual(recorded, [1, 2, 1, 2])  # And again.
186 187 def testCannotAddExitCallbackWhenNotInFunctionScope(self): 188 with self.assertRaisesRegex(RuntimeError, 'when not building a function.'): 189 ops.add_exit_callback_to_default_func_graph(lambda: None) 190 191 def testVariable(self): 192 v1 = variables.Variable(1.0) 193 add = def_function.function(lambda x, v: x + v1 + v) 194 v2 = variables.Variable(1.0) 195 x = constant_op.constant(1.0) 196 r = add(x, v2) 197 self.assertEqual(3.0, self.evaluate(r)) 198 199 def testVariableOnly(self): 200 v = variables.Variable(1.0) 201 add = def_function.function(lambda x: x.assign_add(1.0)) 202 r1 = add(v) 203 self.assertEqual(2.0, self.evaluate(r1)) 204 c = constant_op.constant(1.0) 205 with self.assertRaisesRegex(AttributeError, 'no attribute'): 206 add(c) 207 208 def testVariableMultiFunction(self): 209 @def_function.function 210 def second(dup_var, dup_var_2, some_const): 211 return dup_var + dup_var_2 + some_const 212 213 @def_function.function 214 def first(dup_var, some_const): 215 return second(dup_var, dup_var, some_const) 216 217 my_const = constant_op.constant(1) 218 my_var = variables.Variable(2, dtype=dtypes.int32) 219 self.assertEqual(second(my_var, my_var, my_const).numpy(), 5) 220 self.assertEqual(first(my_var, my_const).numpy(), 5) 221 222 @test_util.disable_tfrt('Packed tensor is not supported in tfrt yet.') 223 def testPackedVariable(self): 224 with ops.device('/cpu:0'): 225 v0_0 = resource_variable_ops.ResourceVariable(1.0) 226 with ops.device('/cpu:1'): 227 v0_1 = resource_variable_ops.ResourceVariable(2.0) 228 v1_0 = resource_variable_ops.ResourceVariable(3.0) 229 with ops.device('/cpu:2'): 230 v1_1 = resource_variable_ops.ResourceVariable(4.0) 231 232 packed_var_0 = ops.pack_eager_tensors([v0_0.handle, v0_1.handle]) 233 packed_var_1 = ops.pack_eager_tensors([v1_0.handle, v1_1.handle]) 234 235 # TODO(b/145922293): use ResourceVariable.assign_add and 236 # ResourceVariable.read_value directly once we support packing multiple 237 # ResourceVariable into 
one ResourceVariable. 238 @def_function.function 239 def read_var(): 240 resource_variable_ops.assign_add_variable_op( 241 packed_var_0, constant_op.constant(5.0)) 242 resource_variable_ops.assign_add_variable_op( 243 packed_var_1, constant_op.constant(6.0)) 244 with ops.device('/cpu:0'): 245 read0 = resource_variable_ops.read_variable_op( 246 packed_var_0, dtype=dtypes.float32) 247 with ops.device('/cpu:1'): 248 read1 = resource_variable_ops.read_variable_op( 249 packed_var_0, dtype=dtypes.float32) 250 read2 = resource_variable_ops.read_variable_op( 251 packed_var_1, dtype=dtypes.float32) 252 with ops.device('/cpu:2'): 253 read3 = resource_variable_ops.read_variable_op( 254 packed_var_1, dtype=dtypes.float32) 255 256 return read0, read1, read2, read3 257 258 arg_attrs = read_var.get_concrete_function().function_def.arg_attr 259 self.assertLen(arg_attrs, 2) 260 self.assertEqual(arg_attrs[0].attr['_composite_device'].s, 261 compat.as_bytes(packed_var_0.device)) 262 self.assertEqual(arg_attrs[1].attr['_composite_device'].s, 263 compat.as_bytes(packed_var_1.device)) 264 265 self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6)) 266 267 def testImplementsAttributeBasic(self): 268 v = def_function.function( 269 experimental_implements='func')(lambda x, y: x + y) 270 with context.graph_mode(), self.cached_session(): 271 a = array_ops.placeholder(dtypes.float32, ()) 272 b = array_ops.placeholder(dtypes.float32, ()) 273 v(a, b) 274 gradients_impl.gradients(v(a, b), [a, b]) 275 fdefs = ops.get_default_graph().as_graph_def().library.function 276 self.assertLen(fdefs, 3) 277 not_present = 0 278 present = 0 279 for f in fdefs: 280 name = f.signature.name 281 if 'forward' in name or 'backward' in name: 282 not_present += 1 283 self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f) 284 else: 285 present += 1 286 self.assertEqual(f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME].s, 287 'func'.encode('ascii'), f) 288 self.assertEqual(not_present, 2, fdefs) 289 
self.assertEqual(present, 1, fdefs) 290 291 def testImplementsAttributeAssertsOnSideInput(self): 292 with context.graph_mode(), self.cached_session(): 293 z = array_ops.zeros(0) 294 v = def_function.function( 295 experimental_implements='func')(lambda x, y: x + y + z) 296 a = array_ops.ones((1,)) 297 b = array_ops.ones((1,)) 298 with self.assertRaisesRegex(AssertionError, 299 'variables are always captured'): 300 v(a, b) 301 functions = ops.get_default_graph().as_graph_def().library.function 302 self.assertEmpty(functions) 303 304 def testImplementsAttributeWorksWithGradientTape(self): 305 add = lambda x, y: x + y ** 2 306 add = def_function.function(experimental_implements='MyFunc')(add) 307 x = variables.Variable(3.0) 308 y = variables.Variable(2.0) 309 310 with backprop.GradientTape() as tape: 311 g = add(x, y) 312 313 dg_dy, dg_dx = tape.gradient(g, [y, x]) 314 self.assertEqual(dg_dy.numpy(), 4.0) 315 self.assertEqual(dg_dx.numpy(), 1.0) 316 317 def testImplementsAttributeWorksOnVariables(self): 318 with context.graph_mode(), self.cached_session(): 319 v = def_function.function( 320 experimental_implements='func')(lambda x, y: x + y) 321 a = variables.Variable((1.0,)) 322 b = variables.Variable((1.0,)) 323 r1 = v(a, b) 324 _ = v(a, a) 325 functions = ops.get_default_graph().as_graph_def().library.function 326 # Verify that we created only one function 327 self.assertLen(functions, 1) 328 # Verify that eval() reads the current values. 329 a.initializer.run() 330 b.initializer.run() 331 self.assertEqual(r1.eval(), 2) 332 333 a.assign_add([1]).eval() 334 self.assertEqual(r1.eval(), 3) 335 336 def testImplementsAttributeWorksOnConstants(self): 337 with context.graph_mode(), self.cached_session(): 338 v = def_function.function( 339 experimental_implements='func')(lambda x, y: x + y) 340 a = variables.Variable(1.0) 341 r1 = v(a, 2.) 
342 r2 = v(2., a) 343 functions = ops.get_default_graph().as_graph_def().library.function 344 self.assertLen(functions, 1) 345 self.assertLen(functions[0].signature.input_arg, 2) 346 # Verify that eval() reads the current values. 347 a.initializer.run() 348 self.assertEqual(r1.eval(), 3) 349 self.assertEqual(r2.eval(), 3) 350 351 def testImplementsAttributeSpecializes(self): 352 with context.graph_mode(), self.cached_session(): 353 v = def_function.function( 354 experimental_implements='func')(lambda x, y: x + y) 355 a = variables.Variable(1.0) 356 r1 = v(a, [2.]) 357 r2 = v([2., 2], a) 358 functions = ops.get_default_graph().as_graph_def().library.function 359 self.assertLen(functions, 2) 360 # Ensure that all parameters are still there and haven't been inlined! 361 362 self.assertLen(functions[0].signature.input_arg, 2) 363 self.assertLen(functions[1].signature.input_arg, 2) 364 # Verify that eval() reads the current values. 365 a.initializer.run() 366 numpy.testing.assert_equal(r1.eval(), [3.]) 367 numpy.testing.assert_equal(r2.eval(), [3., 3.]) 368 369 def testImplementsWorksWithTensorSpec(self): 370 v = def_function.function( 371 experimental_implements='func')(lambda x, y: x + y) 372 v = v.get_concrete_function( 373 tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32), 374 tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)) 375 x = v(1., 2.) 376 self.assertEqual(x.numpy(), 3.) 
377 378 def testImplementsAttributeAsNameAttrList(self): 379 implements_attr = ( 380 'name: "embedding_matmul" attr { key: "key1" value { i: 2 } ' 381 '} attr { key: "key2" value { b: false } }') 382 v = def_function.function( 383 experimental_implements=implements_attr)(lambda x, y: x + y) 384 with context.graph_mode(), self.cached_session(): 385 a = array_ops.placeholder(dtypes.float32, ()) 386 b = array_ops.placeholder(dtypes.float32, ()) 387 v(a, b) 388 gradients_impl.gradients(v(a, b), [a, b]) 389 fdefs = ops.get_default_graph().as_graph_def().library.function 390 self.assertLen(fdefs, 3) 391 not_present = 0 392 present = 0 393 for f in fdefs: 394 name = f.signature.name 395 if 'forward' in name or 'backward' in name: 396 not_present += 1 397 self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f) 398 else: 399 present += 1 400 attr_value = f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME] 401 self.assertIsNotNone(attr_value.func, f) 402 self.assertEqual(attr_value.func.name, 'embedding_matmul') 403 name_attrs = attr_value.func.attr 404 self.assertLen(name_attrs, 2) 405 self.assertEqual(not_present, 2, fdefs) 406 self.assertEqual(present, 1, fdefs) 407 408 def testExternalControlDependency(self): 409 with ops.Graph().as_default(), self.test_session(): 410 v = variables.Variable(1.0) 411 v.initializer.run() 412 413 op = v.assign_add(1.0) 414 415 @function.defun 416 def f(): 417 with ops.control_dependencies([op]): 418 return 1.0 419 420 self.evaluate(f()) 421 self.assertAllEqual(self.evaluate(v), 2.0) 422 423 def testInputShapeFunctionRelaxation(self): 424 unknown_dim = [False] 425 426 @function.defun(reduce_retracing=True) 427 def func(a): 428 if a._shape_tuple()[0] is None: 429 unknown_dim[0] = True 430 return a + 1 431 432 func(constant_op.constant([])) 433 self.assertFalse(unknown_dim[0]) 434 self.assertLen(total_function_cache(func), 1) 435 436 func(constant_op.constant([1.0])) 437 self.assertTrue(unknown_dim[0]) 438 
self.assertLen(total_function_cache(func), 2) 439 440 func(constant_op.constant([1.0, 2.0])) 441 self.assertTrue(unknown_dim[0]) 442 self.assertLen(total_function_cache(func), 2) 443 444 def testInputShapeRelaxationOnInstanceMethod(self): 445 # Test that reduce_retracing is passed during 446 # instance method bounding. 447 unknown_dim = [False] 448 449 class Foo: 450 451 @def_function.function(reduce_retracing=True) 452 def func(self, a): 453 if a._shape_tuple()[0] is None: 454 unknown_dim[0] = True 455 return a + 1 456 457 foo = Foo() 458 foo.func(constant_op.constant([])) 459 self.assertFalse(unknown_dim[0]) 460 461 foo.func(constant_op.constant([1.0])) 462 self.assertTrue(unknown_dim[0]) 463 464 foo.func(constant_op.constant([1.0, 2.0])) 465 self.assertTrue(unknown_dim[0]) 466 467 def testInputShapeFunctionRelaxationWithRaggedTensors(self): 468 traced_type_spec = [None] 469 470 @def_function.function(reduce_retracing=True) 471 def func(x): 472 traced_type_spec[0] = x._type_spec 473 return x 474 475 def check_trace(x, expected_trace): 476 traced_type_spec[0] = None 477 func(x) 478 self.assertEqual(traced_type_spec[0], expected_trace) 479 480 check_trace( # Initial call gets traced. 481 ragged_factory_ops.constant([[1], [2, 3, 4]]), 482 ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)) 483 check_trace( # Input TypeSpec is the same -> no retrace. 484 ragged_factory_ops.constant([[1, 2], [3, 4]]), None) 485 check_trace( # Even if component tensor shapes change -> no retrace. 
486 ragged_factory_ops.constant([[1, 2], [3, 4, 5, 6]]), None) 487 check_trace( # Different TypeSpec shape (nrows): relax & retrace 488 ragged_factory_ops.constant([[1], [2], [3]]), 489 ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)) 490 check_trace( # Different nrows again: relax & retrace 491 ragged_factory_ops.constant([[1], [2], [3], [4]]), None) 492 check_trace( # Different nrows yet again: not retrace 493 ragged_factory_ops.constant([[1]]), None) 494 check_trace( # Different ragged_rank: retrace 495 ragged_factory_ops.constant([[[1]]]), 496 ragged_tensor.RaggedTensorSpec([1, None, None], dtypes.int32)) 497 check_trace( # Different ragged_rank again: retrace & relax 498 ragged_factory_ops.constant([[[1]], [[2]]]), 499 ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32)) 500 501 def testInputShapeFunctionRelaxationWithStructuredTensors(self): 502 traced_type_spec = [None] 503 504 @def_function.function(reduce_retracing=True) 505 def func(x): 506 traced_type_spec[0] = x._type_spec 507 return x 508 509 def check_trace(x, expected_trace): 510 traced_type_spec[0] = None 511 func(x) 512 self.assertEqual(traced_type_spec[0], expected_trace) 513 514 # If we have TypeSpecs that differ in ways other than just their shape, 515 # then retrace each time. 
516 check_trace( 517 structured_tensor.StructuredTensor.from_pyval({'a': [1]}), 518 structured_tensor.StructuredTensor.Spec._from_fields_and_rank( 519 fields={'a': tensor_spec.TensorSpec((1,), dtypes.int32)}, 520 rank=0)) 521 check_trace( 522 structured_tensor.StructuredTensor.from_pyval({'b': [1]}), 523 structured_tensor.StructuredTensor.Spec._from_fields_and_rank( 524 fields={'b': tensor_spec.TensorSpec((1,), dtypes.int32)}, 525 rank=0)) 526 check_trace( 527 structured_tensor.StructuredTensor.from_pyval({'c': [1]}), 528 structured_tensor.StructuredTensor.Spec._from_fields_and_rank( 529 fields={'c': tensor_spec.TensorSpec((1,), dtypes.int32)}, 530 rank=0)) 531 532 # But if we call again with only shape different, then do relax: 533 check_trace( # relax & retrace 534 structured_tensor.StructuredTensor.from_pyval({'a': [1, 2]}), 535 structured_tensor.StructuredTensor.Spec._from_fields_and_rank( 536 fields={'a': tensor_spec.TensorSpec((None,), dtypes.int32)}, 537 rank=0)) 538 check_trace( # use relaxed graph 539 structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3]}), 540 None) 541 check_trace( # use relaxed graph 542 structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3, 4]}), 543 None) 544 545 def testInputShapeFunctionRelaxationWithDatasetIterators(self): 546 # For dataset iterators, the TypeSpec includes type information that's 547 # not derivable from the component tensors. Make sure that the TypeSpec 548 # shapes get relaxed as appropriate. 
549 550 traced_type_spec = [None] 551 552 @def_function.function(reduce_retracing=True) 553 def func(x): 554 traced_type_spec[0] = x._type_spec 555 return x 556 557 def check_trace(x, expected_trace): 558 traced_type_spec[0] = None 559 func(x) 560 self.assertEqual(traced_type_spec[0], expected_trace) 561 562 ds_1_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([1, 2])) 563 ds_2_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 2])) 564 ds_3_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([3, 2])) 565 ds_4_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([4, 2])) 566 ds_2_1 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 1])) 567 check_trace( # shape=[1, 2]: retrace 568 dataset_ops.make_one_shot_iterator(ds_1_2), 569 iterator_ops.IteratorSpec( 570 tensor_spec.TensorSpec([1, 2], dtypes.float32))) 571 check_trace( # shape=[1, 2]: no retrace (use the [1, 2] graph) 572 dataset_ops.make_one_shot_iterator(ds_1_2), None) 573 check_trace( # shape=[2, 2]: relax to [None, 2] and retrace 574 dataset_ops.make_one_shot_iterator(ds_2_2), 575 iterator_ops.IteratorSpec( 576 tensor_spec.TensorSpec([None, 2], dtypes.float32))) 577 check_trace( # shape=[3, 2]: no retrace (use the [None, 2] graph) 578 dataset_ops.make_one_shot_iterator(ds_3_2), None) 579 check_trace( # shape=[4, 2]: no retrace (use the [None, 2] graph) 580 dataset_ops.make_one_shot_iterator(ds_4_2), None) 581 check_trace( # shape=[2, 1]: relax to [None, None] and retrace 582 dataset_ops.make_one_shot_iterator(ds_2_1), 583 iterator_ops.IteratorSpec( 584 tensor_spec.TensorSpec([None, None], dtypes.float32))) 585 586 def testCapturesVariables(self): 587 a = variables.Variable(1.0, trainable=False) 588 b = variables.Variable(1.0) 589 cc = [None] 590 591 @def_function.function 592 def f(): 593 c = cc[0] 594 if c is None: 595 c = cc[0] = variables.Variable(1.) 
596 return a + b + c + 1 597 598 cf = f.get_concrete_function() 599 c = cc[0] 600 601 captured_variables = {v.ref() for v in (a, b, c)} 602 trainable_variables = {v.ref() for v in (b, c)} 603 self.assertEqual({v.ref() for v in cf.variables}, captured_variables) 604 self.assertEqual({v.ref() for v in cf.trainable_variables}, 605 trainable_variables) 606 self.assertEqual(cf.variables, cf.graph.variables) 607 self.assertEqual(cf.trainable_variables, cf.graph.trainable_variables) 608 609 def testNestedInputShapeFunctionRelaxation(self): 610 unknown_dim = [False] 611 612 @function.defun(reduce_retracing=True) 613 def func(a_, b_=None): 614 del a_ # Only used to check which cache is used. 615 self.assertEqual(b_[0]._shape_tuple(), ()) 616 if b_[1]._shape_tuple()[0] is None: 617 unknown_dim[0] = True 618 return b_[0] + 1 619 620 a = 'hi' 621 b0 = constant_op.constant(1.0) 622 func(a, b_=[b0, constant_op.constant([])]) 623 self.assertFalse(unknown_dim[0]) 624 self.assertLen(total_function_cache(func), 1) 625 626 func(a, b_=[b0, constant_op.constant([1.0])]) 627 self.assertTrue(unknown_dim[0]) 628 self.assertLen(total_function_cache(func), 2) 629 630 func(a, b_=[b0, constant_op.constant([1.0, 1.0])]) 631 self.assertTrue(unknown_dim[0]) 632 self.assertLen(total_function_cache(func), 2) 633 634 unknown_dim[0] = False 635 636 # Now do the same except with a new a which is not a tensor; this should 637 # change the cache key. 638 a = 'bye' 639 func(a, b_=[b0, constant_op.constant([])]) 640 self.assertFalse(unknown_dim[0]) 641 self.assertLen(total_function_cache(func), 3) 642 643 # We relax the type traced previously. 644 func(a, b_=[b0, constant_op.constant([1.0])]) 645 self.assertTrue(unknown_dim[0]) 646 self.assertLen(total_function_cache(func), 4) 647 648 def testNestedShapeFunctionRelaxation(self): 649 traced_shape = None 650 # The inner function will go through shape relaxation because the shapes it 651 # receives will be [1], [2], [3], ... 
652 @def_function.function(reduce_retracing=True) 653 def bar(x_shape): 654 nonlocal traced_shape 655 traced_shape = x_shape._shape_tuple() 656 return x_shape 657 658 # The outer function will not go through shape relaxation because the shapes 659 # it receives will be [1], [[1]], [[[1]]], ... 660 @def_function.function(reduce_retracing=True) 661 def foo(ones): 662 return bar(array_ops.shape(ones)) 663 664 self.assertAllEqual(self.evaluate(foo(array_ops.ones([1]))), [1]) 665 self.assertEqual(traced_shape, (1,)) 666 667 for rank in range(2, 6): 668 x_shape = self.evaluate(foo(array_ops.ones([1] * rank))) 669 self.assertAllEqual(x_shape, [1] * rank) 670 self.assertEqual(traced_shape, (None,)) 671 672 def testNoHash(self): 673 674 @def_function.function() 675 def f(_): 676 return 1.0 677 678 with self.assertRaisesRegex( 679 TypeError, 680 r'could not be represented through the generic tracing'): 681 f(set([])) 682 683 def testFuncName(self): 684 685 @function.defun_with_attributes(attributes={'func_name': 'multiply'}) 686 def add(x, y): 687 _ = x * y 688 return x + y 689 690 @function.defun 691 def add_2(x, y): 692 _ = x * y 693 return x + y 694 695 self.assertEqual(add._name, 'multiply') 696 self.assertEqual(add_2._name, 'add_2') 697 698 def testBasicGraphMode(self): 699 matmul = def_function.function(math_ops.matmul) 700 701 @def_function.function 702 def sq(a): 703 return matmul(a, a) 704 705 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 706 out = sq(t) 707 self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) 708 709 def testNestedInputsGraphMode(self): 710 matmul = def_function.function(math_ops.matmul) 711 712 pair = collections.namedtuple('pair', ['a', 'b']) 713 714 @def_function.function 715 def a_times_b(inputs): 716 return matmul(inputs.a['a'], inputs.b['b']) 717 718 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 719 720 out = a_times_b(pair({'a': t}, {'b': t})) 721 self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) 722 723 def 
testNestedOutputsGraphMode(self): 724 matmul = def_function.function(math_ops.matmul) 725 726 pair = collections.namedtuple('pair', ['a', 'b']) 727 728 @def_function.function() 729 def pairs_mul(pair_a, pair_b): 730 return pair(matmul(pair_a.a, pair_b.a), matmul(pair_a.b, pair_b.b)) 731 732 a = constant_op.constant([[1.0, 2.0], [1.0, 2.0]]) 733 b = constant_op.constant([[3.0, 4.0], [3.0, 4.0]]) 734 735 out = pairs_mul(pair(a, b), pair(b, a)) 736 expected = pair(math_ops.matmul(a, b).numpy(), 737 math_ops.matmul(b, a).numpy()) 738 self.assertAllClose(out, expected) 739 740 @parameterized.named_parameters( 741 dict(testcase_name='Defun', 742 function_decorator=function.defun), 743 dict(testcase_name='DefFunction', 744 function_decorator=def_function.function)) 745 def testNestedFunctionGraphNotOutOfDate(self, function_decorator): 746 @function_decorator 747 def f(): 748 return constant_op.constant(1.) 749 750 class _Model(object): 751 752 @function_decorator 753 def g(self): 754 self.f = f.get_concrete_function() 755 756 model = _Model() 757 model.g() 758 concrete = model.f 759 weak_g_graph = weakref.ref(model.g.get_concrete_function().graph) 760 self.assertIs(weak_g_graph(), concrete.graph.outer_graph) 761 weak_g = weakref.ref(model.g) 762 del model 763 self.assertIsNone(weak_g()) 764 self.assertIsNone(weak_g_graph()) 765 self.assertIsNotNone(concrete.graph.outer_graph) 766 self.assertIs(ops.get_default_graph(), concrete.graph.outer_graph) 767 768 def testGraphEagerIsolation(self): 769 770 @function.defun 771 def f(): 772 self.v = variables.Variable(1.0) 773 return self.v.read_value() 774 775 self.assertAllEqual(f(), 1.0) 776 777 with ops.Graph().as_default(): 778 self.assertEqual(f().shape, ()) 779 780 def testBasicGraphFunction(self): 781 matmul = def_function.function(math_ops.matmul) 782 783 @def_function.function 784 def sq(a): 785 return matmul(a, a) 786 787 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 788 789 sq_op = sq.get_concrete_function(t) 790 
self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2])) 791 out = sq_op(t) 792 self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) 793 794 def testGetConcreteFunctionThreadSafety(self): 795 796 @def_function.function 797 def sq(): 798 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 799 return math_ops.matmul(t, t) 800 801 concrete_functions = [] 802 803 def thread_func(_): 804 cf = sq.get_concrete_function() 805 concrete_functions.append(cf) 806 807 num_threads = 100 808 pool = multiprocessing.pool.ThreadPool(num_threads) 809 _ = pool.map(thread_func, list(range(num_threads))) 810 811 self.assertLen(set(concrete_functions), 1) 812 813 def testGetConcreteFunctionThreadSafetyWithArgs(self): 814 @def_function.function 815 def add_100(*args): 816 return math_ops.add_n(args) 817 818 p = multiprocessing.pool.ThreadPool(2) 819 args = (constant_op.constant(1.),) * 100 820 f1, f2 = p.map(add_100.get_concrete_function, [args] * 2) 821 # I see about len(args) + max(0, len(args) - 3) arguments expected. 
822 f1(*args) 823 del f2 824 825 def testInputSpecGraphFunction(self): 826 matmul = def_function.function(math_ops.matmul) 827 828 @def_function.function 829 def sq(a): 830 return matmul(a, a) 831 832 sq_op = sq.get_concrete_function( 833 tensor_spec.TensorSpec((None, None), dtypes.float32)) 834 self.assertEqual([None, None], sq_op.output_shapes.as_list()) 835 836 t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 837 out1 = sq_op(t1) 838 self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy()) 839 840 t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 841 out2 = sq_op(t2) 842 self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy()) 843 844 def testNestedInputSpecGraphFunction(self): 845 matmul = def_function.function(math_ops.matmul) 846 847 @def_function.function 848 def sq(mats): 849 ((a, b),) = mats 850 return matmul(a, b) 851 852 sq_op_autonamed = sq.get_concrete_function( 853 [(tensor_spec.TensorSpec((None, None), dtypes.float32), 854 tensor_spec.TensorSpec((None, None), dtypes.float32))]) 855 self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list()) 856 857 sq_op = sq.get_concrete_function( 858 [(tensor_spec.TensorSpec((None, None), dtypes.float32, 859 name='first_mat'), 860 tensor_spec.TensorSpec((None, None), dtypes.float32, 861 name='second_mat'))]) 862 self.assertEqual([None, None], sq_op.output_shapes.as_list()) 863 864 t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 865 t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]]) 866 out = sq_op(first_mat=t1, second_mat=t2) 867 self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy()) 868 self.assertAllEqual(sq_op_autonamed(t1, t2), 869 math_ops.matmul(t1, t2).numpy()) 870 871 def testExecutingStatelessDefunConcurrently(self): 872 873 @def_function.function 874 def stateless(x): 875 return math_ops.multiply(2.0, x) 876 877 pool = multiprocessing.pool.ThreadPool() 878 inputs = [constant_op.constant(1.0 * x) for x in range(100)] 879 outputs = [float(out) for out in pool.map(stateless, 
inputs)] 880 expected = [float(2.0 * x) for x in inputs] 881 self.assertSequenceEqual(outputs, expected) 882 883 def testExecutingManyStatelessDefunsConcurrently(self): 884 885 @def_function.function 886 def stateless(x): 887 del x 888 return math_ops.multiply(2.0, 2.0) 889 890 pool = multiprocessing.pool.ThreadPool() 891 # `pool.map` below instantiates 100 functions, one for each object. 892 objects = [object() for _ in range(100)] 893 outputs = [float(out) for out in pool.map(stateless, objects)] 894 expected = [4.0] * 100 895 self.assertSequenceEqual(outputs, expected) 896 897 @test_util.disable_tfrt('b/169431085: This test is flaky on tfrt') 898 def testExecutingStatefulDefunConcurrently(self): 899 900 v = resource_variable_ops.ResourceVariable(1.0) 901 902 @def_function.function 903 def stateful(x): 904 v.assign(x) 905 906 pool = multiprocessing.pool.ThreadPool() 907 inputs = [constant_op.constant(0.0)] * 100 908 pool.map(stateful, inputs) 909 self.assertEqual(float(v.read_value()), 0.0) 910 911 def testExecutingManyStatefulDefunsConcurrently(self): 912 913 v = resource_variable_ops.ResourceVariable(1.0) 914 915 @def_function.function 916 def stateful(x): 917 del x 918 return v.assign(0.0) 919 920 pool = multiprocessing.pool.ThreadPool() 921 # `pool.map` below instantiates 100 functions, one for each object. 922 pool.map(stateful, [object() for _ in range(100)]) 923 self.assertEqual(float(v.read_value()), 0.0) 924 925 def testShareRendezvous(self): 926 927 # Disable grappler from inlining the functions. Note we run the send & recv 928 # in graph mode since with eager mode the function should automatically be 929 # inlined. 
930 context.context().set_optimizer_experimental_options( 931 {'disable_meta_optimizer': True}) 932 933 cpu = '/device:CPU:0' 934 935 signature = [tensor_spec.TensorSpec([], dtypes.int32)] 936 937 @def_function.function 938 def send(): 939 x = constant_op.constant(1) 940 gen_sendrecv_ops.send(x, 'x', cpu, 0, cpu) 941 return x 942 943 send._shared_rendezvous = True # pylint: disable=protected-access 944 945 @def_function.function(input_signature=signature) 946 def send_body(n): 947 send() 948 return n - 1 949 950 @def_function.function 951 def recv(): 952 return gen_sendrecv_ops.recv(dtypes.int32, 'x', cpu, 0, cpu) 953 954 recv._shared_rendezvous = True # pylint: disable=protected-access 955 956 @def_function.function(input_signature=signature) 957 def recv_body(n): 958 recv() 959 return n - 1 960 961 @def_function.function(input_signature=signature) 962 def cond(n): 963 return n > 0 964 965 # Instead of calling the send & recv functions directly we want to call them 966 # through a functional while to ensure the rendezvous is shared across the 967 # while boundary. 
968 @def_function.function 969 def fn(n): 970 functional_ops.While([n], cond.get_concrete_function(), 971 send_body.get_concrete_function()) 972 return functional_ops.While([n], cond.get_concrete_function(), 973 recv_body.get_concrete_function()) 974 975 # Use a graph context since functions will not be automatically inlined 976 with context.graph_mode(), self.cached_session(): 977 self.evaluate(fn(2)) 978 979 def disabled_testRandomSeed(self): 980 981 @def_function.function 982 def f(): 983 return random_ops.random_normal(()) 984 985 random_seed.set_random_seed(1) 986 x = f() 987 self.assertNotEqual(x, f()) 988 random_seed.set_random_seed(1) 989 self.assertAllEqual(f(), x) 990 991 def testNestedInputsGraphFunction(self): 992 matmul = def_function.function(math_ops.matmul) 993 994 pair = collections.namedtuple('pair', ['a', 'b']) 995 996 @def_function.function 997 def a_times_b(inputs): 998 return matmul(inputs.a['a'], inputs.b['b']) 999 1000 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 1001 sq_op = a_times_b.get_concrete_function( 1002 pair(dict(a=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'a')), 1003 dict(b=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'b')))) 1004 self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2])) 1005 out = sq_op(a=t, b=t) 1006 self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) 1007 1008 def testNestedOutputGraphFunction(self): 1009 matmul = def_function.function(math_ops.matmul) 1010 1011 @def_function.function 1012 def sq(a): 1013 return (matmul(a, a), {'b': constant_op.constant(1.0)}) 1014 1015 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 1016 1017 sq_op = sq.get_concrete_function(t) 1018 self.assertEqual(sq_op.output_shapes, 1019 (tensor_shape.TensorShape([2, 2]), 1020 {'b': tensor_shape.TensorShape([])})) 1021 self.assertEqual(sq_op.output_dtypes, 1022 (dtypes.float32, {'b': dtypes.float32})) 1023 (a, b) = sq_op(t) 1024 self.assertAllEqual(a, math_ops.matmul(t, t).numpy()) 1025 
self.assertAllEqual(b['b'].numpy(), 1.0) 1026 1027 def testGraphFunctionNoneOutput(self): 1028 @def_function.function 1029 def fn(unused_a, unused_b): 1030 return None 1031 1032 x = constant_op.constant(1) 1033 fn_op = fn.get_concrete_function(x, x) 1034 self.assertEqual(fn_op.output_dtypes, None) 1035 self.assertEqual(fn_op.output_shapes, None) 1036 self.assertAllEqual(fn_op(x, x), None) 1037 1038 def testDefunNumpyArraysConvertedToTensors(self): 1039 1040 def f(x): 1041 self.assertIsInstance(x, ops.Tensor) 1042 return x 1043 1044 x = random_ops.random_uniform([2, 2]).numpy() 1045 defined = function.defun(f) 1046 defined(x) 1047 self.assertLen(total_function_cache(defined), 1) 1048 1049 x = random_ops.random_uniform([2, 2]).numpy() 1050 defined(x) 1051 # A NumPy array with different values but the same shape and dtype 1052 # shouldn't trigger another function definition. 1053 self.assertLen(total_function_cache(defined), 1) 1054 1055 np_ones = numpy.ones([], numpy.float32) 1056 np_zeros = numpy.zeros([], numpy.float32) 1057 tf_ones = array_ops.ones([]) 1058 tf_zeros = array_ops.zeros([]) 1059 1060 # Test that the numpy array is properly an argument to the graph function. 1061 self.assertEqual(1., defined(np_ones).numpy()) 1062 self.assertLen(total_function_cache(defined), 2) 1063 self.assertEqual(0., defined(np_zeros).numpy()) 1064 self.assertEqual(1., defined(tf_ones).numpy()) 1065 self.assertEqual(0., defined(tf_zeros).numpy()) 1066 self.assertLen(total_function_cache(defined), 2) 1067 1068 # Test that mutable inputs are supported. 1069 mutable = numpy.ones([], numpy.float32) 1070 self.assertEqual(1., defined(mutable).numpy()) 1071 mutable.fill(0) 1072 self.assertEqual(0., defined(mutable).numpy()) 1073 1074 class MyNdarray(numpy.ndarray): 1075 pass 1076 1077 # Test that the subclasses of ndarray are converted too. 
1078 self.assertEqual(1., defined(np_ones.view(MyNdarray)).numpy()) 1079 self.assertEqual(0., defined(np_zeros.view(MyNdarray)).numpy()) 1080 1081 # We should not have triggered any re-tracing of the python function. 1082 self.assertLen(total_function_cache(defined), 2) 1083 1084 def testNumpyDtypeInputSupported(self): 1085 @function.defun 1086 def f(x, dtype): 1087 return constant_op.constant(dtype(x)) 1088 1089 self.assertEqual(f(1, numpy.float32).numpy(), numpy.float32(1)) 1090 self.assertEqual(f(2, numpy.float32).numpy(), numpy.float32(2)) 1091 self.assertEqual(f(1, numpy.int32).numpy(), numpy.int32(1)) 1092 self.assertEqual(f(2, numpy.int32).numpy(), numpy.int32(2)) 1093 1094 def testDefunNumpyArraysConvertedToTensorsInKwargs(self): 1095 1096 def f(**kwargs): 1097 x = kwargs.pop('x') 1098 self.assertIsInstance(x, ops.Tensor) 1099 return x 1100 1101 x = random_ops.random_uniform([2, 2]).numpy() 1102 defined = function.defun(f) 1103 defined(x=x) 1104 self.assertLen(total_function_cache(defined), 1) 1105 1106 x = random_ops.random_uniform([2, 2]).numpy() 1107 defined(x=x) 1108 # A NumPy array with different values but the same shape and dtype 1109 # shouldn't trigger another function definition. 1110 self.assertLen(total_function_cache(defined), 1) 1111 1112 # Test that the numpy array is properly an argument to the graph function. 
1113 self.assertEqual(1., defined(x=numpy.ones([])).numpy()) 1114 self.assertEqual(0., defined(x=numpy.zeros([])).numpy()) 1115 self.assertEqual(1., defined(x=array_ops.ones([])).numpy()) 1116 self.assertEqual(0., defined(x=array_ops.zeros([])).numpy()) 1117 1118 def testDefunCapturedInt32(self): 1119 x = constant_op.constant(1, dtype=dtypes.int32) 1120 1121 @def_function.function 1122 def add_int32s(): 1123 return x + x 1124 1125 self.assertEqual(2, int(add_int32s())) 1126 1127 def testDefunReadVariable(self): 1128 v = resource_variable_ops.ResourceVariable(1.0) 1129 1130 @def_function.function 1131 def f(): 1132 return v.read_value() 1133 1134 self.assertEqual(1.0, float(f())) 1135 1136 def testDefunAssignAddVariable(self): 1137 v = resource_variable_ops.ResourceVariable(1.0) 1138 x = constant_op.constant(2.0) 1139 1140 @def_function.function 1141 def test_assign_add(): 1142 v.assign_add(x) 1143 return v.read_value() 1144 1145 self.assertEqual(3.0, float(test_assign_add())) 1146 1147 @test_util.run_in_graph_and_eager_modes 1148 def testTensorInitializationInFunctionRaisesError(self): 1149 1150 @def_function.function 1151 def tensor_init(): 1152 with self.assertRaisesRegex(ValueError, 'could not be lifted out'): 1153 resource_variable_ops.ResourceVariable(constant_op.constant(2.0)) 1154 1155 tensor_init() 1156 1157 @test_util.run_in_graph_and_eager_modes 1158 def testCallableTensorInitializationInFunction(self): 1159 1160 @def_function.function 1161 def tensor_init(): 1162 self.v = resource_variable_ops.ResourceVariable( 1163 lambda: constant_op.constant(2.0)) 1164 return self.v.read_value() 1165 1166 value = tensor_init() 1167 if not context.executing_eagerly(): 1168 self.evaluate(variables.global_variables_initializer()) 1169 self.assertEqual(self.evaluate(value), 2.0) 1170 1171 @test_util.also_run_as_tf_function 1172 def testInitScopeTensorInitializationInFunction(self): 1173 1174 @def_function.function 1175 def tensor_init(): 1176 with ops.init_scope(): 1177 
const = constant_op.constant(2.0) 1178 # Note: this variable bypasses tf.function's variable creation 1179 # requirements by bypassing variable_creator_scope by using 1180 # ResourceVariable instead of Variable. 1181 self.v = resource_variable_ops.ResourceVariable(const) 1182 return self.v.read_value() 1183 1184 value = tensor_init() 1185 self.assertAllEqual(value, 2.0) 1186 1187 @test_util.run_in_graph_and_eager_modes 1188 def testGetConcreteFunctionCreatesVariables(self): 1189 1190 v_holder = [] 1191 1192 @def_function.function 1193 def tensor_init(): 1194 if not v_holder: 1195 v_holder.append(variables.Variable(5.)) 1196 return v_holder[0].read_value() 1197 1198 concrete = tensor_init.get_concrete_function() 1199 self.evaluate(variables.global_variables_initializer()) 1200 self.assertAllEqual(5., self.evaluate(concrete())) 1201 self.assertAllEqual(5., self.evaluate(tensor_init())) 1202 1203 def testFuncGraphCaptureByValue(self): 1204 v = variables.Variable(1.0) 1205 1206 def trivial_function(): 1207 return v.read_value() 1208 1209 graph_function = function.Function( 1210 trivial_function, 'test', capture_by_value=True) 1211 1212 self.assertAllEqual(graph_function(), 1.0) 1213 v.assign(2.0) 1214 self.assertAllEqual(graph_function(), 1.0) 1215 1216 def testFuncGraphCaptureByValueNested(self): 1217 v = variables.Variable(1.0) 1218 1219 def trivial_function(): 1220 return control_flow_ops.cond( 1221 array_ops.placeholder_with_default(True, ()), 1222 v.read_value, v.read_value) 1223 1224 graph_function = function.Function( 1225 trivial_function, 'test', capture_by_value=True) 1226 1227 self.assertAllEqual(graph_function(), 1.0) 1228 v.assign(2.0) 1229 self.assertAllEqual(graph_function(), 1.0) 1230 1231 def testDefunShapeInferenceWithCapturedResourceVariable(self): 1232 v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) 1233 1234 def f(): 1235 x = constant_op.constant([[1, 2], [3, 4]]) 1236 out = math_ops.matmul(v, x) 1237 self.assertEqual(out.shape, 
tensor_shape.TensorShape([2, 2])) 1238 # We do not return v directly since the tensor conversion function of 1239 # ResourceVariable returns the read value and not the resource itself. 1240 return v._handle 1241 1242 compiled = def_function.function(f) 1243 var_handle = compiled() 1244 self.assertEqual(var_handle.dtype, dtypes.resource) 1245 self.assertEqual(var_handle.shape, tensor_shape.TensorShape([])) 1246 var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) 1247 self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) 1248 1249 def testShapeInferenceForMoreSpecificInput(self): 1250 1251 def f(a): 1252 return array_ops.reshape(a, [-1, 3]) 1253 1254 signature = [tensor_spec.TensorSpec(None, dtypes.float32)] 1255 compiled = def_function.function(f, input_signature=signature) 1256 1257 @def_function.function 1258 def use_f(): 1259 inputs = array_ops.zeros([10, 10, 3]) 1260 self.assertAllEqual(f(inputs).shape, compiled(inputs).shape) 1261 1262 use_f() 1263 1264 def testFuncListAttr(self): 1265 1266 @function.defun 1267 def test_function(val): 1268 1269 def fn1(): 1270 return array_ops.ones([10]) 1271 1272 fn2 = lambda: array_ops.ones([10]) * 2 1273 1274 def fn3(x=3): 1275 return array_ops.ones([10]) * x 1276 fn4 = functools.partial(fn3, x=4) 1277 fn5 = functools.partial(fn3, 5) 1278 1279 return gen_functional_ops.case(val, [], [dtypes.float32], 1280 [function.defun(f).get_concrete_function() 1281 for f in (fn1, fn2, fn3, fn4, fn5)]) 1282 1283 ones = array_ops.ones([10]) 1284 self.assertAllEqual([ones], test_function(0)) 1285 self.assertAllEqual([ones * 2], test_function(1)) 1286 self.assertAllEqual([ones * 3], test_function(2)) 1287 self.assertAllEqual([ones * 4], test_function(3)) 1288 self.assertAllEqual([ones * 5], test_function(4)) 1289 self.assertAllEqual([ones * 5], test_function(22)) # default branch 1290 1291 @test_util.enable_control_flow_v2 1292 def testVariableInLoopInFunction(self): 1293 1294 @function.defun 1295 def 
test_function(): 1296 1297 def loop_test(_): 1298 return False 1299 1300 def loop_body(_): 1301 return variable_scope.get_variable('a', shape=()) 1302 1303 return control_flow_ops.while_loop(loop_test, loop_body, [0.0]) 1304 1305 self.assertEqual(test_function().shape, []) 1306 1307 def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self): 1308 with context.graph_mode(): 1309 v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) 1310 1311 def f(): 1312 x = constant_op.constant([[1, 2], [3, 4]]) 1313 out = math_ops.matmul(v, x) 1314 self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) 1315 # We do not return v directly since the tensor conversion function of 1316 # ResourceVariable returns the read value and not the resource itself. 1317 return v._handle 1318 1319 compiled = def_function.function(f) 1320 var_handle = compiled() 1321 self.assertEqual(var_handle.dtype, dtypes.resource) 1322 self.assertEqual(var_handle.shape, tensor_shape.TensorShape([])) 1323 var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) 1324 self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) 1325 1326 def testDefunShapeInferenceWithCapturedVariableInGraphMode(self): 1327 with context.graph_mode(): 1328 v = variables.Variable([[1, 2], [3, 4]]) 1329 1330 def f(): 1331 x = constant_op.constant([[1, 2], [3, 4]]) 1332 out = math_ops.matmul(v, x) 1333 self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) 1334 1335 # Check that shape inference works while creating the defun 1336 compiled = def_function.function(f) 1337 compiled() 1338 1339 def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self): 1340 with context.graph_mode(): 1341 tensor_list = list_ops.empty_tensor_list( 1342 element_dtype=dtypes.float32, 1343 element_shape=ops.convert_to_tensor([], dtype=dtypes.int32)) 1344 tensor_list = list_ops.tensor_list_push_back(tensor_list, 1345 constant_op.constant(1.0)) 1346 tensor_list = 
list_ops.tensor_list_push_back(tensor_list, 1347 constant_op.constant(2.0)) 1348 1349 def f(): 1350 tl, value = list_ops.tensor_list_pop_back( 1351 tensor_list, element_dtype=dtypes.float32) 1352 self.assertEqual(value.shape, tensor_shape.TensorShape([])) 1353 return tl 1354 1355 compiled = def_function.function(f) 1356 output_tensor_list = compiled() 1357 _, value = list_ops.tensor_list_pop_back( 1358 output_tensor_list, element_dtype=dtypes.float32) 1359 self.assertEqual(value.shape, tensor_shape.TensorShape([])) 1360 1361 @test_util.run_in_graph_and_eager_modes 1362 def testDefunForcesResourceVariables(self): 1363 1364 def variable_creator(): 1365 self.v = variables.Variable(0.0) 1366 return self.v.read_value() 1367 1368 self.v = None 1369 defined = function.defun(variable_creator) 1370 defined() # Create the variable. 1371 self.assertIsInstance( 1372 self.v, resource_variable_ops.ResourceVariable) 1373 1374 def testRunMetadata(self): 1375 1376 @def_function.function 1377 def f(x): 1378 return x * x 1379 1380 with ops.device('cpu:0'): 1381 context.enable_run_metadata() 1382 f(constant_op.constant(1.0)) 1383 run_metadata = context.export_run_metadata() 1384 context.disable_run_metadata() 1385 self.assertLen(run_metadata.partition_graphs, 1) 1386 1387 def testGraphModeCaptureVariable(self): 1388 with context.graph_mode(), self.cached_session(): 1389 1390 class HasAVar: 1391 1392 def __init__(self): 1393 self.v = resource_variable_ops.ResourceVariable(1.0) 1394 1395 def call(self): 1396 return self.v * 2 1397 1398 o = HasAVar() 1399 self.evaluate(variables.global_variables_initializer()) 1400 call = def_function.function(o.call) 1401 op = call() 1402 self.assertAllEqual(self.evaluate(op), 2.0) 1403 1404 def testGraphModeManyFunctions(self): 1405 with ops.Graph().as_default(), self.cached_session(): 1406 1407 @def_function.function 1408 def f(x): 1409 return x * x 1410 1411 @def_function.function 1412 def g(x): 1413 return f(x) + 1 1414 1415 
self.assertAllEqual(g(constant_op.constant(2.0)), 5.0) 1416 1417 def testDict(self): 1418 1419 @def_function.function 1420 def f(x): 1421 return {'name': x + 1} 1422 1423 self.assertAllEqual(f(constant_op.constant(1.0))['name'], 2.0) 1424 1425 def testWeakrefInputsRejected(self): 1426 1427 @def_function.function 1428 def f(x): 1429 return x 1430 1431 class Dummy: 1432 pass 1433 o = Dummy() 1434 wr = weakref.ref(o) 1435 1436 with self.assertRaisesRegex(ValueError, 'weakref'): 1437 f(wr) 1438 1439 def testTensorConversionWithDefun(self): 1440 1441 @def_function.function 1442 def f(x): 1443 return math_ops.add(x, constant_op.constant(3)) 1444 1445 self.assertAllEqual(5, f(constant_op.constant(2))) 1446 1447 def testTensorConversionCall(self): 1448 1449 @def_function.function 1450 def f(x): 1451 return math_ops.add(x, constant_op.constant(3)) 1452 1453 @def_function.function 1454 def g(x): 1455 return f(f(x)) 1456 1457 self.assertAllEqual(8, g(constant_op.constant(2))) 1458 1459 def testCallShape(self): 1460 1461 @def_function.function 1462 def f(x): 1463 return x + 1 1464 1465 @def_function.function 1466 def g(x): 1467 x = f(x) 1468 self.assertEqual(x.shape.as_list(), []) 1469 return None 1470 1471 g(constant_op.constant(1.0)) 1472 1473 def testNestedDefunWithNoOutputAndTapedInput(self): 1474 three = resource_variable_ops.ResourceVariable(3.0, name='v') 1475 1476 @def_function.function 1477 def f(x): 1478 # This function intentionally takes a taped variable as input, 1479 # but does not return any values 1480 math_ops.add(x, three) 1481 1482 @def_function.function 1483 def g(x): 1484 y = math_ops.add(x, three) 1485 f(y) 1486 1487 g(three) 1488 1489 def testGatherResourceWithDefun(self): 1490 with ops.device('cpu:0'): 1491 v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) 1492 1493 def sum_gather(): 1494 return math_ops.reduce_sum(array_ops.gather(v, [1, 2])) 1495 1496 defined = def_function.function(sum_gather) 1497 self.assertAllEqual(sum_gather(), 
defined()) 1498 1499 @parameterized.named_parameters([ 1500 ('IndexedSlicesWithDenseShape', 1501 _example_indexed_slices_with_dense_shape,), 1502 ('IndexedSlicesWithoutDenseShape', 1503 _example_indexed_slices_without_dense_shape,), 1504 ('RaggedTensorRaggedRank1', ragged_tensor.RaggedTensor.from_row_lengths, 1505 {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}), 1506 ('RaggedTensorRaggedRank2', 1507 ragged_tensor.RaggedTensor.from_nested_row_lengths, 1508 {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}), 1509 ('SparseTensor', sparse_tensor.SparseTensor, 1510 {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}), 1511 ]) # pyformat: disable 1512 def testReturnCompositeTensorWithDefun(self, 1513 factory_fn, 1514 factory_kwargs={}, 1515 input_signature=None): 1516 input_ct = factory_fn(**factory_kwargs) 1517 1518 @def_function.function(input_signature=input_signature) 1519 def f(): 1520 return input_ct 1521 1522 output_ct = f() 1523 self.assertIsInstance(output_ct, type(input_ct)) 1524 nest.assert_same_structure(input_ct, output_ct, expand_composites=True) 1525 1526 input_flat = nest.flatten(input_ct, expand_composites=True) 1527 output_flat = nest.flatten(output_ct, expand_composites=True) 1528 for (input_component, output_component) in zip(input_flat, output_flat): 1529 self.assertAllEqual(input_component, output_component) 1530 1531 @parameterized.named_parameters([ 1532 ('IndexedSlicesWithDenseShape', 1533 _example_indexed_slices_with_dense_shape,), 1534 ('IndexedSlicesWithoutDenseShape', 1535 _example_indexed_slices_without_dense_shape,), 1536 ('RaggedTensorRaggedRank1', 1537 ragged_tensor.RaggedTensor.from_row_lengths, 1538 {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}), 1539 ('RaggedTensorRaggedRank2', 1540 ragged_tensor.RaggedTensor.from_nested_row_lengths, 1541 {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}), 1542 ('SparseTensor', 1543 sparse_tensor.SparseTensor, 1544 {'values': [1, 2, 3], 
'indices': [[0], [8], [10]], 'dense_shape': [20]}), 1545 ('RaggedTensorRaggedRank1WithSignature', 1546 ragged_tensor.RaggedTensor.from_row_lengths, 1547 {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}, 1548 [ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)]), 1549 ('RaggedTensorRaggedRank2WithSignature', 1550 ragged_tensor.RaggedTensor.from_nested_row_lengths, 1551 {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}, 1552 [ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32)]), 1553 ('SparseTensorWithSignature', 1554 sparse_tensor.SparseTensor, 1555 {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}, 1556 [sparse_tensor.SparseTensorSpec([None], dtypes.int32)]), 1557 ]) # pyformat: disable 1558 def testCompositeAsArgumentTensorWithDefun(self, 1559 factory_fn, 1560 factory_kwargs={}, 1561 input_signature=None): 1562 input_ct = factory_fn(**factory_kwargs) 1563 1564 @def_function.function(input_signature=input_signature) 1565 def f(x): 1566 return x 1567 1568 output_ct = f(input_ct) 1569 self.assertIsInstance(output_ct, type(input_ct)) 1570 nest.assert_same_structure(input_ct, output_ct, expand_composites=True) 1571 1572 input_flat = nest.flatten(input_ct, expand_composites=True) 1573 output_flat = nest.flatten(output_ct, expand_composites=True) 1574 for (input_component, output_component) in zip(input_flat, output_flat): 1575 self.assertAllEqual(input_component, output_component) 1576 1577 def testTracedCompositeDiscardsShapeInfo(self): 1578 # SparseTensorSpec intentionally excludes info about the number of elements 1579 # that are in a sparse tensor (which is recorded as st.indices.shape[0] and 1580 # st.values.shape[0]). Similarly, RaggedTensorSpec intentionally excludes 1581 # info about the total number of values in a RaggedTensor (stored as 1582 # rt.values.shape[0]). This test checks that the placeholders created by 1583 # tf.function() properly mask this shape info. 
1584 @def_function.function 1585 def f(rt, st): 1586 self.assertEqual(st.indices.shape.as_list()[:1], [None]) 1587 self.assertEqual(st.values.shape.as_list(), [None]) 1588 return (rt, st) 1589 1590 rt = ragged_factory_ops.constant([[1, 2], [3]]) 1591 st = sparse_tensor.SparseTensor([[0]], [0], [10]) 1592 f(rt, st) 1593 1594 @test_util.run_gpu_only 1595 def testFunctionOnDevice(self): 1596 x = constant_op.constant([1.]).gpu() 1597 f = def_function.function(math_ops.add) 1598 y = f(x, x).cpu() 1599 self.assertAllEqual(y, [2.]) 1600 1601 @test_util.run_gpu_only 1602 @test_util.run_in_graph_and_eager_modes 1603 def testFunctionWithResourcesOnDifferentDevices(self): 1604 with ops.device('/cpu:0'): 1605 v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) 1606 1607 with ops.device('/gpu:0'): 1608 v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) 1609 1610 def sum_gather(): 1611 cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2])) 1612 gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2])) 1613 return cpu_result, gpu_result 1614 1615 defined = function.defun(sum_gather) 1616 if not context.executing_eagerly(): 1617 self.evaluate(variables.global_variables_initializer()) 1618 expected = self.evaluate(sum_gather()) 1619 self.assertAllEqual(expected, self.evaluate(defined())) 1620 1621 @test_util.run_gpu_only 1622 @test_util.run_in_graph_and_eager_modes 1623 def testOpInFunctionWithConflictingResourceInputs(self): 1624 with ops.device('/cpu:0'): 1625 v_cpu = resource_variable_ops.ResourceVariable( 1626 [0.0, 1.0, 2.0], name='cpu') 1627 v_also_cpu = resource_variable_ops.ResourceVariable( 1628 [0.0, 1.0, 2.0], name='also_cpu') 1629 1630 with ops.device('/gpu:0'): 1631 v_gpu = resource_variable_ops.ResourceVariable( 1632 [0.0, 1.0, 2.0], name='gpu') 1633 1634 @def_function.function 1635 def resource_apply_adam(): 1636 training_ops.resource_apply_adam( 1637 v_cpu.handle, 1638 v_gpu.handle, 1639 v_also_cpu.handle, 1640 1.0, # 
beta1_power 1641 1.0, # beta2_power 1642 1.0, # learning_rate 1643 1.0, # beta1 1644 1.0, # beta2 1645 1.0, # epsilon, 1646 [1.0, 1.0, 1.0], # grad 1647 False) # use_locking 1648 return None 1649 1650 with self.assertRaisesRegex( 1651 errors.InvalidArgumentError, 1652 'Cannot place the graph because a reference or resource edge connects ' 1653 'colocation groups with incompatible assigned devices'): 1654 if not context.executing_eagerly(): 1655 self.evaluate(variables.global_variables_initializer()) 1656 self.evaluate(resource_apply_adam()) 1657 1658 @test_util.run_gpu_only 1659 def testFunctionHandlesInputsOnDifferentDevices(self): 1660 # The Reshape op requires the shape tensor to be placed in host memory. 1661 reshape = def_function.function(array_ops.reshape) 1662 value = constant_op.constant([1., 2.]).gpu() 1663 shape = constant_op.constant([2, 1]) 1664 reshaped = reshape(value, shape).cpu() 1665 self.assertAllEqual(reshaped, [[1], [2]]) 1666 1667 @test_util.run_gpu_only 1668 def testFunctionHandlesInputsPlacedOnTheWrongDeviceGracefully(self): 1669 # The Reshape op requires the shape tensor to be placed in host memory. 1670 reshape = def_function.function(array_ops.reshape) 1671 value = constant_op.constant([1., 2.]) 1672 shape = constant_op.constant([2, 1]).gpu() 1673 reshape(value, shape) # No error is raised 1674 1675 def testNoneOutput(self): 1676 1677 @def_function.function 1678 def my_function(_): 1679 return None 1680 1681 self.assertAllEqual(my_function(1), None) 1682 1683 def testNestedFunctions(self): 1684 # TensorFlow function (which is what would be used in TensorFlow graph 1685 # construction). 
1686 @tf_function.Defun(dtypes.int32, dtypes.int32) 1687 def add(a, b): 1688 return math_ops.add(a, b) 1689 1690 @def_function.function 1691 def add_one(x): 1692 return add(x, 1) 1693 1694 self.assertAllEqual(3, add_one(constant_op.constant(2))) 1695 1696 def testVariableCaptureInNestedFunctions(self): 1697 v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32) 1698 1699 @def_function.function 1700 def inner_read(): 1701 return v.read_value() 1702 1703 @def_function.function 1704 def outer(): 1705 return inner_read() 1706 1707 self.assertEqual(1, int(outer())) 1708 1709 def testReturnCapturedEagerTensor(self): 1710 t = constant_op.constant(1) 1711 1712 @def_function.function 1713 def read(): 1714 return t 1715 1716 self.assertEqual(1, int(read())) 1717 1718 def testReturnCapturedGraphTensor(self): 1719 with context.graph_mode(), self.cached_session(): 1720 t = constant_op.constant(1) 1721 1722 @def_function.function 1723 def read(): 1724 return t 1725 1726 self.assertEqual(1, int(self.evaluate(read()))) 1727 1728 def testSequenceInputs(self): 1729 clip_by_global_norm = def_function.function(clip_ops.clip_by_global_norm) 1730 t_list = [constant_op.constant(1.0), constant_op.constant(2.0)] 1731 clipped_list, global_norm = clip_by_global_norm(t_list, 1732 constant_op.constant(.2)) 1733 for t in clipped_list: 1734 self.assertIsInstance(t, ops.Tensor) 1735 self.assertIsInstance(global_norm, ops.Tensor) 1736 1737 def testNestedSequenceInputs(self): 1738 1739 def my_op(inputs): 1740 a, b, c = inputs 1741 e, f = b 1742 g, h = e 1743 return [a + a, [tuple([f + f, g + g]), h + h], c + c], a + f + g + h + c 1744 1745 my_eager_op = def_function.function(my_op) 1746 ret = my_eager_op([ 1747 constant_op.constant(1), [(constant_op.constant(2), 1748 constant_op.constant(3)), 1749 constant_op.constant(4)], 1750 constant_op.constant(5) 1751 ]) 1752 self.assertLen(ret, 2) 1753 self.assertAllEqual(ret[0][0], 2) 1754 self.assertAllEqual(ret[0][1][0][0], 8) 1755 
self.assertAllEqual(ret[0][1][0][1], 4) 1756 self.assertIsInstance(ret[0][1][0], tuple) 1757 self.assertAllEqual(ret[0][1][1], 6) 1758 self.assertAllEqual(ret[0][2], 10) 1759 self.assertAllEqual(ret[1], 15) 1760 1761 def testVariableNamesRespectNameScopesWithDefun(self): 1762 @def_function.function 1763 def create_variable(): 1764 with ops.name_scope('foo', skip_on_eager=False): 1765 v = resource_variable_ops.ResourceVariable(0.0, name='bar') 1766 self.assertEqual(v.name, 'foo/bar:0') 1767 1768 create_variable() 1769 1770 def testVariableNamesRespectNameScopesWithDefunInGraph(self): 1771 with context.graph_mode(): 1772 @def_function.function 1773 def create_variable(): 1774 with ops.name_scope('foo', skip_on_eager=False): 1775 v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar') 1776 self.assertEqual(v.name, 'foo/bar:0') 1777 1778 with ops.get_default_graph().as_default(): 1779 create_variable() 1780 1781 @test_util.assert_no_new_pyobjects_executing_eagerly 1782 def testCallOptionsMemory(self): 1783 1784 @function.defun 1785 def model(x): 1786 return x + constant_op.constant(1.) 1787 1788 # This happens with a lot of option toggles, e.g. soft device placement 1789 context.context().function_call_options = None 1790 model(constant_op.constant(2.)) 1791 1792 @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) 1793 def testLayerInDefun(self): 1794 conv = convolutional.Conv2D( 1795 filters=1, 1796 kernel_size=2, 1797 kernel_initializer=init_ops.ones_initializer(), 1798 bias_initializer=init_ops.zeros_initializer()) 1799 1800 @function.defun 1801 def model(x): 1802 return conv(x) 1803 1804 x = array_ops.ones([1, 2, 2, 1]) 1805 y = model(x) 1806 1807 if not context.executing_eagerly(): 1808 self.evaluate(variables.global_variables_initializer()) 1809 1810 self.assertAllClose([[[[4.0]]]], self.evaluate(y)) 1811 1812 # Variable lifting is somewhat different between defun/tf.function, so testing 1813 # device placement on both makes sense. 
1814 @parameterized.named_parameters( 1815 dict(testcase_name='Defun', 1816 function_decorator=function.defun), 1817 dict(testcase_name='DefFunction', 1818 function_decorator=def_function.function)) 1819 @test_util.run_in_graph_and_eager_modes 1820 def testVariablesPlacedOnOutsideDevice(self, function_decorator): 1821 1822 class _Obj(object): 1823 1824 def __init__(self): 1825 self.v = None 1826 1827 @function_decorator 1828 def f(self): 1829 if self.v is None: 1830 self.v = variables.Variable(1.) 1831 return self.v + 1. 1832 1833 has_device = _Obj() 1834 with ops.device('cpu:0'): 1835 has_device.f() 1836 self.assertIn('CPU', has_device.v.device) 1837 1838 @test_util.run_in_graph_and_eager_modes 1839 def testMultipleDeviceCheck(self): 1840 1841 def f(): 1842 with ops.device('cpu'): 1843 return test_ops.device_placement_op() 1844 1845 func = function.defun(f) 1846 with ops.device('cpu:0'): 1847 output = self.evaluate(func()) 1848 self.assertIn(compat.as_bytes('CPU:0'), output) 1849 1850 @test_util.run_in_graph_and_eager_modes 1851 def testDeviceAnnotationsRespected(self): 1852 1853 def multi_device_fn(): 1854 with ops.device('/cpu:0'): 1855 s0 = test_ops.device_placement_op() 1856 with ops.device('/cpu:1'): 1857 s1 = test_ops.device_placement_op() 1858 with ops.device('/cpu:2'): 1859 s2 = test_ops.device_placement_op() 1860 s3 = test_ops.device_placement_op() 1861 return s0, s1, s2, s3 1862 1863 defined = function.defun(multi_device_fn) 1864 outputs = self.evaluate(defined()) 1865 self.assertLen(total_function_cache(defined), 1) 1866 self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) 1867 self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) 1868 self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) 1869 1870 with ops.device('/cpu:3'): 1871 outputs = self.evaluate(defined()) 1872 # All function definitions are agnostic to call site devices. 
1873 self.assertLen(total_function_cache(defined), 1) 1874 self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) 1875 self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) 1876 self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) 1877 self.assertIn(compat.as_bytes('CPU:3'), outputs[3]) 1878 1879 with ops.device('/cpu:0'): 1880 outputs = self.evaluate(defined()) 1881 self.assertLen(total_function_cache(defined), 1) 1882 self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) 1883 self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) 1884 self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) 1885 self.assertIn(compat.as_bytes('CPU:0'), outputs[3]) 1886 1887 @test_util.run_in_graph_and_eager_modes 1888 def testCallingGraphFunctionOnDifferentDevice(self): 1889 1890 def func(): 1891 return constant_op.constant(0) 1892 1893 defined = def_function.function(func) 1894 with ops.device('cpu:0'): 1895 cpu_graph_function = defined.get_concrete_function() 1896 1897 with ops.device('cpu:0'): 1898 self.assertEqual( 1899 self.evaluate(cpu_graph_function()), self.evaluate(func())) 1900 1901 with ops.device('cpu:1'): 1902 self.assertEqual(0., self.evaluate(cpu_graph_function())) 1903 1904 with ops.device(None): 1905 self.assertEqual(0., self.evaluate(cpu_graph_function())) 1906 1907 default_graph_function = defined.get_concrete_function() 1908 self.assertEqual( 1909 self.evaluate(default_graph_function()), self.evaluate(func())) 1910 1911 with ops.device('cpu:1'): 1912 self.assertEqual(0., self.evaluate(default_graph_function())) 1913 1914 @test_util.run_gpu_only 1915 @test_util.run_in_graph_and_eager_modes 1916 def testColocateWithRespected(self): 1917 # TODO(b/113291792): Use multiple CPUs instead of a GPU. 
1918 with ops.device('cpu:0'): 1919 x = array_ops.identity(1.0) 1920 1921 with ops.device('gpu:0'): 1922 y = array_ops.identity(1.0) 1923 1924 @def_function.function 1925 def foo(): 1926 return test_ops.device_placement_op() 1927 1928 with ops.colocate_with(x): 1929 self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo())) 1930 1931 with ops.colocate_with(y): 1932 self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo())) 1933 1934 def testVariablesAreTracked(self): 1935 v = resource_variable_ops.ResourceVariable(1.0) 1936 1937 def foo(x): 1938 return v * x 1939 1940 defined = def_function.function(foo) 1941 1942 x = constant_op.constant([1.0]) 1943 self.assertEqual(1., self.evaluate(defined(x))) 1944 v.assign(2.) 1945 1946 x = constant_op.constant([1.0, 2.0]) 1947 self.assertAllEqual([2., 4.], self.evaluate(defined(x))) 1948 1949 def testCacheObjectHashCollisions(self): 1950 1951 class Foo: 1952 1953 def __hash__(self): 1954 return 42 1955 1956 def func(foo): 1957 return constant_op.constant([id(foo)]) 1958 1959 defined = function.defun(func) 1960 foo_1 = Foo() 1961 defined(foo_1) 1962 self.assertLen(total_function_cache(defined), 1) 1963 1964 foo_2 = Foo() 1965 defined(foo_2) 1966 self.assertLen(total_function_cache(defined), 2) 1967 1968 def testCacheTensorDtypeCollision(self): 1969 1970 def func(t): 1971 return t + t 1972 1973 defined = function.defun(func) 1974 t = constant_op.constant([[1.0]], dtype=dtypes.complex64) 1975 defined(t) 1976 self.assertLen(total_function_cache(defined), 1) 1977 1978 t = constant_op.constant([[1.0]], dtype=dtypes.complex128) 1979 defined(t) 1980 self.assertLen(total_function_cache(defined), 2) 1981 1982 def testCacheTensorShapeCollision(self): 1983 1984 def func(t): 1985 return t + t 1986 1987 defined = function.defun(func) 1988 t = constant_op.constant([[1.0]], dtype=dtypes.complex64) 1989 defined(t) 1990 self.assertLen(total_function_cache(defined), 1) 1991 1992 t = constant_op.constant([1.0], dtype=dtypes.complex64) 1993 
defined(t) 1994 self.assertLen(total_function_cache(defined), 2) 1995 1996 def testCacheTensorShapeDtypeCollision(self): 1997 1998 def func(t): 1999 return t + t 2000 2001 defined = function.defun(func) 2002 t = constant_op.constant([[1.0]], dtype=dtypes.complex64) 2003 defined(t) 2004 self.assertLen(total_function_cache(defined), 1) 2005 2006 t = constant_op.constant([1.0], dtype=dtypes.complex128) 2007 defined(t) 2008 self.assertLen(total_function_cache(defined), 2) 2009 2010 def testCacheTensorUnknownShapesCollisionRelaxedShapes(self): 2011 2012 def func(t): 2013 return t + t 2014 2015 with context.graph_mode(), self.cached_session(): 2016 defined = function.defun(func, reduce_retracing=True) 2017 2018 p = array_ops.placeholder(dtype=dtypes.float32, shape=[]) 2019 defined(p) 2020 self.assertLen(total_function_cache(defined), 1) 2021 2022 p = array_ops.placeholder(dtype=dtypes.float32, shape=[1]) 2023 defined(p) 2024 self.assertLen(total_function_cache(defined), 2) 2025 2026 p = array_ops.placeholder(dtype=dtypes.float32, shape=[2]) 2027 defined(p) 2028 # Gradual shape relaxation is performed; and the common shape between 2029 # [1] and [2] is one containing unknown dimensions. 2030 self.assertLen(total_function_cache(defined), 2) 2031 2032 t = constant_op.constant([1.0, 1.0, 1.0], dtype=dtypes.float32) 2033 defined(t) 2034 # Shape (3,) matches the relaxed shape TensorShape([None]) 2035 self.assertLen(total_function_cache(defined), 2) 2036 2037 def testPythonFunctionWithDefaultArgs(self): 2038 2039 def func(foo, bar=1, baz=2): 2040 del foo 2041 del bar 2042 del baz 2043 return 2044 2045 defined = function.defun(func) 2046 defined(0, baz=20) 2047 self.assertLen(total_function_cache(defined), 1) 2048 2049 defined(1) # bar=1, baz=2 2050 self.assertLen(total_function_cache(defined), 2) 2051 2052 # This matches the previous call. 
2053 defined(foo=1) 2054 self.assertLen(total_function_cache(defined), 2) 2055 2056 defined(1, 2, 3) 2057 self.assertLen(total_function_cache(defined), 3) 2058 2059 # This matches the previous call. 2060 defined(1, bar=2, baz=3) 2061 self.assertLen(total_function_cache(defined), 3) 2062 2063 # This matches the previous call. 2064 defined(1, baz=3, bar=2) 2065 self.assertLen(total_function_cache(defined), 3) 2066 2067 def testDatasetIteratorCaching(self): 2068 def func(it1, it2): 2069 next(it1) 2070 next(it2) 2071 return 0 2072 2073 defined = function.defun(func) 2074 2075 d = dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]) 2076 it1 = iter(d) 2077 it2 = iter(d) 2078 _ = defined(it1, it2) # The two iterators are different 2079 self.assertLen(total_function_cache(defined), 1) 2080 2081 it3 = iter(d) 2082 it4 = iter(d) 2083 _ = defined(it3, it4) # The two iterators are different, should not retrace 2084 self.assertLen(total_function_cache(defined), 1) 2085 2086 it5 = iter(d) 2087 _ = defined(it5, it5) # The two iterators are the same, should retrace 2088 self.assertLen(total_function_cache(defined), 2) 2089 2090 it6 = iter(d) 2091 _ = defined(it6, it6) # The two iterators are the same, should not retrace 2092 self.assertLen(total_function_cache(defined), 2) 2093 2094 def testFunctoolsPartialUnwrappedCorrectly(self): 2095 2096 def full_function(a, b, c=3): 2097 return a, b, c 2098 2099 partial = functools.partial(full_function, 1, c=4) 2100 a, b, c = partial(2) 2101 2102 defined = function.defun(partial) 2103 func_a, func_b, func_c = defined(2) 2104 self.assertEqual(func_a.numpy(), a) 2105 self.assertEqual(func_b.numpy(), b) 2106 self.assertEqual(func_c.numpy(), c) 2107 2108 def testInputSignatureWithMatchingInputs(self): 2109 2110 def foo(a): 2111 self.assertEqual(a.shape, (2,)) 2112 return a 2113 2114 signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)] 2115 defined = function.defun(foo, input_signature=signature) 2116 a = array_ops.ones([2]) 
2117 self.assertAllEqual(a, defined(a)) 2118 self.assertLen(total_function_cache(defined), 1) 2119 self.assertAllEqual(a, defined.get_concrete_function()(a)) 2120 self.assertAllEqual(a, defined.get_concrete_function(a)(a)) 2121 self.assertAllEqual(a, defined.get_concrete_function( 2122 tensor_spec.TensorSpec((2,), dtype=dtypes.float32))(a)) 2123 self.assertLen(total_function_cache(defined), 1) 2124 2125 def bar(a): 2126 self.assertEqual(a._shape_tuple(), (2, None)) 2127 return a 2128 2129 signature = [tensor_spec.TensorSpec((2, None), dtypes.float32)] 2130 defined = function.defun(bar, input_signature=signature) 2131 a = array_ops.ones([2, 1]) 2132 out = defined(a) 2133 self.assertLen(total_function_cache(defined), 1) 2134 self.assertAllEqual(out, a) 2135 2136 # Changing the second dimension shouldn't create a new function. 2137 b = array_ops.ones([2, 3]) 2138 out = defined(b) 2139 self.assertLen(total_function_cache(defined), 1) 2140 self.assertAllEqual(out, b) 2141 2142 def testInputSignatureWithDictInPositionalArgs(self): 2143 2144 @function.defun 2145 def f(*_args, **_kwargs): 2146 return None 2147 2148 f(1, x=2) 2149 self.assertLen(total_function_cache(f), 1) 2150 f(1, x=2) 2151 self.assertLen(total_function_cache(f), 1) 2152 f(1, {'x': 2}) 2153 self.assertLen(total_function_cache(f), 2) 2154 2155 def testInputSignatureWithCompatibleInputs(self): 2156 2157 rank2_spec = tensor_spec.TensorSpec(shape=(None, None), 2158 dtype=dtypes.float32) 2159 2160 @function.defun(input_signature=[rank2_spec]) 2161 def func(a): 2162 self.assertEqual([None, None], a.shape.as_list()) 2163 return array_ops.shape(a) 2164 2165 self.assertAllEqual([3, 1], func([[0], [1.0], [1]])) 2166 self.assertAllEqual([2, 2], func(numpy.array([[1, 1], [2, 2]]))) 2167 2168 with self.assertRaisesRegex(ValueError, 'incompatible'): 2169 func([0.0, 1.0, 2.0]) # Wrong shape. 
2170 2171 with self.assertRaisesRegex(ValueError, 'incompatible'): 2172 func([['wrong dtype']]) 2173 2174 def testNestedInputSignatures(self): 2175 2176 def expected_foo(a, b): 2177 return [a, b] 2178 2179 @function.defun(input_signature=[ 2180 [tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2, 2181 tensor_spec.TensorSpec((1,), dtypes.float32), 2182 ]) 2183 def foo(a, b): 2184 self.assertEqual(a[0]._shape_tuple(), (2, None)) 2185 self.assertEqual(a[1]._shape_tuple(), (2, None)) 2186 self.assertEqual(b._shape_tuple(), (1,)) 2187 return [a, b] 2188 2189 a = array_ops.ones([2, 1]) 2190 b = array_ops.ones([1]) 2191 expected = expected_foo([a, a], b) 2192 out = foo([a, a], b) 2193 self.assertLen(total_function_cache(foo), 1) 2194 nest.assert_same_structure(out, expected) 2195 self.assertAllEqual(out[0][0], a) 2196 self.assertAllEqual(out[0][1], a) 2197 self.assertAllEqual(out[1], b) 2198 2199 # Changing the unspecified dimensions shouldn't create a new function. 2200 a = array_ops.ones([2, 3]) 2201 b = array_ops.ones([2, 5]) 2202 c = array_ops.ones([1]) 2203 expected = expected_foo([a, b], c) 2204 out = foo([a, b], c) 2205 self.assertLen(total_function_cache(foo), 1) 2206 nest.assert_same_structure(out, expected) 2207 self.assertAllEqual(out[0][0], a) 2208 self.assertAllEqual(out[0][1], b) 2209 self.assertAllEqual(out[1], c) 2210 2211 # Passing compatible inputs should work. 
2212 a = a.numpy().tolist() 2213 b = b.numpy().tolist() 2214 c = c.numpy().tolist() 2215 out = foo([a, b], c) 2216 self.assertLen(total_function_cache(foo), 1) 2217 nest.assert_same_structure(out, expected) 2218 self.assertAllEqual(out[0][0], a) 2219 self.assertAllEqual(out[0][1], b) 2220 self.assertAllEqual(out[1], c) 2221 2222 def testNestedInputSignaturesWithDict(self): 2223 def expected_bar(a): 2224 return a 2225 2226 @function.defun(input_signature=[{ 2227 'a': tensor_spec.TensorSpec((2, None), dtypes.float32), 2228 'b': tensor_spec.TensorSpec((2, None), dtypes.float32), 2229 'c': tensor_spec.TensorSpec((1,), dtypes.float32)}]) 2230 def bar(a): 2231 self.assertEqual(a['a']._shape_tuple(), (2, None)) 2232 self.assertEqual(a['b']._shape_tuple(), (2, None)) 2233 self.assertEqual(a['c']._shape_tuple(), (1,)) 2234 return a 2235 2236 a = array_ops.ones([2, 3]) 2237 b = array_ops.ones([1]) 2238 inputs = {'a': a, 'b': a, 'c': b} 2239 expected = expected_bar(inputs) 2240 out = bar(inputs) 2241 nest.assert_same_structure(out, expected) 2242 self.assertAllEqual(out['a'], expected['a']) 2243 self.assertAllEqual(out['b'], expected['b']) 2244 self.assertAllEqual(out['c'], expected['c']) 2245 2246 # Passing compatible inputs should work. 2247 a = a.numpy().tolist() 2248 b = b.numpy().tolist() 2249 inputs = {'a': a, 'b': a, 'c': b} 2250 out = bar(inputs) 2251 nest.assert_same_structure(out, expected) 2252 self.assertAllEqual(out['a'], expected['a']) 2253 self.assertAllEqual(out['b'], expected['b']) 2254 self.assertAllEqual(out['c'], expected['c']) 2255 2256 def testInputSignatureMustBeSequenceOfTensorSpecs(self): 2257 2258 def foo(a, b): 2259 del a 2260 del b 2261 2262 # Signatures must consist exclusively of `TensorSpec` objects. 
2263 signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)] 2264 with self.assertRaisesRegex(TypeError, 'input_signature.*nested sequence'): 2265 def_function.function(foo, input_signature=signature) 2266 2267 # Signatures must be either lists or tuples on their outermost levels. 2268 signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)} 2269 with self.assertRaisesRegex( 2270 TypeError, 'input_signature must be either a ' 2271 'tuple or a list.*'): 2272 function.defun(foo, input_signature=signature) 2273 2274 @test_util.run_in_graph_and_eager_modes 2275 def testInputsIncompatibleWithSignatureRaisesError(self): 2276 2277 def foo(a): 2278 return a 2279 2280 signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)] 2281 defined = def_function.function(foo, input_signature=signature) 2282 2283 # Invalid shapes. 2284 with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'): 2285 defined(array_ops.ones([3])) 2286 2287 with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'): 2288 defined(array_ops.ones([2, 1])) 2289 2290 # Wrong number of arguments. 
2291 with self.assertRaisesRegex(TypeError, 'specifies 1 .* got 2'): 2292 defined(array_ops.ones([2]), array_ops.ones([2])) 2293 with self.assertRaisesRegex(ValueError, 2294 'Structure of Python function inputs.*'): 2295 defined() 2296 2297 with self.assertRaisesRegex(ValueError, 2298 'inputs incompatible with input_signature'): 2299 defined.get_concrete_function( 2300 tensor_spec.TensorSpec(shape=(3,), dtype=dtypes.float32)) 2301 2302 def testMismatchedConcreteSignatureRaisesError(self): 2303 2304 @def_function.function 2305 def run_test(): 2306 @def_function.function 2307 def f(x): 2308 return x 2309 2310 with self.assertRaisesRegex( 2311 TypeError, 'ConcreteFunction .* was constructed .* but was called'): 2312 f.get_concrete_function(1)(constant_op.constant(1)) 2313 2314 with self.assertRaisesRegex(TypeError, r'f\(x\) expected .* but got .*'): 2315 f.get_concrete_function(constant_op.constant(1))(1) 2316 2317 with self.assertRaisesRegex( 2318 TypeError, 'ConcreteFunction .* was constructed .* but was called'): 2319 f.get_concrete_function(1)(2) 2320 2321 run_test() 2322 2323 def testInputsIncompatibleWithNestedSignatureRaisesError(self): 2324 2325 def foo(a, b): 2326 return [a, b] 2327 2328 signature = [[tensor_spec.TensorSpec((1,), dtypes.float32)] * 2, 2329 [tensor_spec.TensorSpec((1,), dtypes.float32)] * 2] 2330 defined = function.defun(foo, input_signature=signature) 2331 a = array_ops.ones([1]) 2332 2333 with self.assertRaisesRegex(ValueError, 2334 'Structure of Python function inputs.*'): 2335 defined([a, a, a], [a]) 2336 2337 with self.assertRaisesRegex(ValueError, 2338 'Structure of Python function inputs.*'): 2339 defined([a], [a, a, a]) 2340 defined([a, a], [a, a]) 2341 2342 def testUnderspecifiedInputSignature(self): 2343 @function.defun(input_signature=[ 2344 tensor_spec.TensorSpec([], dtypes.float32), 2345 ]) 2346 def foo(a, training=True): 2347 if training: 2348 return a 2349 else: 2350 return -1.0 * a 2351 2352 x = constant_op.constant(1.0) 2353 
with self.assertRaisesRegex( 2354 TypeError, 'got keyword argument `training` ' 2355 'that was not included in input_signature'): 2356 foo(x, training=True) 2357 2358 with self.assertRaisesRegex( 2359 TypeError, 'got keyword argument `training` ' 2360 'that was not included in input_signature'): 2361 foo(x, training=False) 2362 2363 self.assertAllEqual(x.numpy(), foo(x).numpy()) 2364 2365 def testInputSignatureWithPartialFunction(self): 2366 def full_function(a, b, c=3.0): 2367 return a, b, c 2368 2369 partial = functools.partial(full_function, 1, c=4) 2370 a, b, c = partial(2.0) 2371 signature = [tensor_spec.TensorSpec([], dtypes.float32)] 2372 defined = function.defun(partial, input_signature=signature) 2373 x = constant_op.constant(2.0) 2374 func_a, func_b, func_c = defined(x) 2375 self.assertEqual(func_a.numpy(), a) 2376 self.assertEqual(func_b.numpy(), b) 2377 self.assertEqual(func_c.numpy(), c) 2378 2379 def testInputSignatureConversionWithDefaultArg(self): 2380 2381 def foo(a, training=True): 2382 if training: 2383 return a 2384 else: 2385 return -1.0 * a 2386 2387 signature = [ 2388 tensor_spec.TensorSpec([], dtypes.float32), 2389 tensor_spec.TensorSpec([], dtypes.bool), 2390 ] 2391 defined = def_function.function(foo, input_signature=signature) 2392 a = constant_op.constant(1.0) 2393 self.assertAllEqual(a.numpy(), defined(a)) 2394 self.assertAllEqual(a.numpy(), defined(a, training=True)) 2395 self.assertAllEqual(-a.numpy(), defined(a, training=False)) 2396 2397 def testInputSignatureWithKeywordPositionalArgs(self): 2398 2399 @function.defun(input_signature=[ 2400 tensor_spec.TensorSpec([], dtypes.float32), 2401 tensor_spec.TensorSpec([], dtypes.int64) 2402 ]) 2403 def foo(flt, integer): 2404 return flt, integer 2405 2406 flt = constant_op.constant(1.0) 2407 integer = constant_op.constant(2, dtypes.int64) 2408 2409 out1, out2 = foo(flt, integer) 2410 self.assertLen(total_function_cache(foo), 1) 2411 self.assertEqual(out1.numpy(), 1.0) 2412 
self.assertEqual(out2.numpy(), 2) 2413 2414 out1, out2 = foo(flt=flt, integer=integer) 2415 self.assertLen(total_function_cache(foo), 1) 2416 self.assertEqual(out1.numpy(), 1.0) 2417 self.assertEqual(out2.numpy(), 2) 2418 2419 out1, out2 = foo(integer=integer, flt=flt) 2420 self.assertLen(total_function_cache(foo), 1) 2421 self.assertEqual(out1.numpy(), 1.0) 2422 self.assertEqual(out2.numpy(), 2) 2423 2424 out1, out2 = foo(flt, integer=integer) 2425 self.assertLen(total_function_cache(foo), 1) 2426 self.assertEqual(out1.numpy(), 1.0) 2427 self.assertEqual(out2.numpy(), 2) 2428 2429 def testInputSignatureWithKeywordArgs(self): 2430 def foo(a, b, **kwargs): 2431 del kwargs 2432 return a, b 2433 2434 x = function.defun( 2435 foo, 2436 input_signature=[ 2437 tensor_spec.TensorSpec([], dtypes.float32), 2438 tensor_spec.TensorSpec([], dtypes.int32) 2439 ]).get_concrete_function() 2440 result = x(constant_op.constant(5.0), constant_op.constant(5)) 2441 self.assertAllEqual(result, [5.0, 5]) 2442 2443 def testInputSignatureWithCompositeTensors(self): 2444 def f(rt): 2445 self.assertEqual(rt.values.shape.as_list(), [None]) 2446 self.assertEqual(rt.row_splits.shape.as_list(), [4]) 2447 return rt 2448 2449 signature = [ragged_tensor.RaggedTensorSpec( 2450 shape=[3, None], dtype=dtypes.int32)] 2451 defined = function.defun(f, input_signature=signature) 2452 rt1 = ragged_factory_ops.constant([[1], [], [2, 3, 4]]) 2453 out1 = defined(rt1) 2454 self.assertLen(total_function_cache(defined), 1) 2455 self.assertAllEqual(out1.values, rt1.values) 2456 self.assertAllEqual(out1.row_splits, rt1.row_splits) 2457 2458 # Changing the row lengths shouldn't create a new function. 
2459 rt2 = ragged_factory_ops.constant([[1, 2], [3, 4], [5]]) 2460 out2 = defined(rt2) 2461 self.assertLen(total_function_cache(defined), 1) 2462 self.assertAllEqual(out2.values, rt2.values) 2463 self.assertAllEqual(out2.row_splits, rt2.row_splits) 2464 2465 # Different number of rows 2466 rt3 = ragged_factory_ops.constant([[1, 2], [3, 4], [5], [6]]) 2467 with self.assertRaisesRegex(ValueError, 'incompatible'): 2468 defined(rt3) 2469 2470 # Different dtype 2471 rt4 = ragged_factory_ops.constant([[1.0, 2.0], [], [3.0]]) 2472 with self.assertRaisesRegex(ValueError, 'Structure .* does not match'): 2473 defined(rt4) 2474 2475 # Different rank 2476 rt5 = ragged_factory_ops.constant([[[1]], [[2]], [[3]]]) 2477 with self.assertRaisesRegex(ValueError, 'does not match'): 2478 defined(rt5) 2479 2480 def testInputSignatureWithVariableArgs(self): 2481 2482 def f(v): 2483 v.assign_add(1) 2484 2485 signature = [ 2486 resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32) 2487 ] 2488 defined = function.defun(f, input_signature=signature) 2489 2490 v1 = variables.Variable(0) 2491 v2 = variables.Variable(0) 2492 2493 defined(v1) 2494 self.assertEqual(v1.numpy(), 1) 2495 self.assertEqual(v2.numpy(), 0) 2496 2497 defined(v=v2) 2498 self.assertEqual(v1.numpy(), 1) 2499 self.assertEqual(v2.numpy(), 1) 2500 2501 def testInputSignatureWithKeywordOnlyArgs(self): 2502 2503 def f(a, b, c=3, *, d=4): 2504 self.assertIsInstance(a, ops.Tensor) 2505 self.assertIsInstance(b, ops.Tensor) 2506 self.assertIsInstance(c, int) 2507 self.assertIsInstance(d, (int, ops.Tensor)) 2508 return a + b + c + d 2509 2510 signature = [ 2511 tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32), 2512 tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32), 2513 ] 2514 defined = function.defun(f, input_signature=signature) 2515 self.assertEqual(defined(1, 2).numpy(), 10) 2516 2517 defined = function.defun( 2518 functools.partial(f, c=4), input_signature=signature) 2519 self.assertEqual(defined(1, 2).numpy(), 
11) 2520 2521 defined = function.defun( 2522 functools.partial(f, d=5), input_signature=signature) 2523 self.assertEqual(defined(1, 2).numpy(), 11) 2524 2525 defined = function.defun( 2526 functools.partial(f, d=array_ops.constant(5)), 2527 input_signature=signature) 2528 self.assertEqual(defined(1, 2).numpy(), 11) 2529 2530 mod = module.Module() 2531 save(mod, '/tmp/kwonlyf', defined.get_concrete_function(*signature)) 2532 loaded = load('/tmp/kwonlyf') 2533 result = loaded.signatures['serving_default']( 2534 a=array_ops.constant(1), b=array_ops.constant(2)) 2535 self.assertEqual(result['output_0'].numpy(), 11) 2536 2537 def testInputSignatureWithKeywordOnlyArgsNoDefaults(self): 2538 signature = [ 2539 tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32), 2540 tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32), 2541 ] 2542 2543 def test_func(a, *, b): 2544 return a + b 2545 2546 with self.assertRaisesRegex( 2547 ValueError, "keyword-only arguments must have default values.*'b'"): 2548 function.defun(test_func, input_signature=signature) 2549 2550 test_func_lambda = lambda a, *, b: a + b 2551 with self.assertRaisesRegex( 2552 ValueError, "keyword-only arguments must have default values.*'b'"): 2553 function.defun(test_func_lambda, input_signature=signature) 2554 2555 def testTensorKeywordArguments(self): 2556 2557 def foo(a, b): 2558 del a 2559 return b 2560 2561 defined = function.defun(foo) 2562 a = constant_op.constant(2.0) 2563 b = constant_op.constant([1.0, 2.0]) 2564 one = defined(a, b) 2565 self.assertLen(total_function_cache(defined), 1) 2566 2567 two = defined(a=a, b=b) 2568 self.assertLen(total_function_cache(defined), 1) 2569 2570 three = defined(b=b, a=a) 2571 self.assertLen(total_function_cache(defined), 1) 2572 2573 four = defined(a, b=b) 2574 self.assertLen(total_function_cache(defined), 1) 2575 2576 # The next call corresponds to a new input signature, hence 2577 # we expect another function to be defined. 
2578 five = defined(b, a) 2579 self.assertLen(total_function_cache(defined), 2) 2580 2581 six = defined(a=b, b=a) 2582 self.assertLen(total_function_cache(defined), 2) 2583 2584 seven = defined(b=a, a=b) 2585 self.assertLen(total_function_cache(defined), 2) 2586 2587 self.assertAllEqual(one, [1.0, 2.0]) 2588 self.assertAllEqual(two, [1.0, 2.0]) 2589 self.assertAllEqual(three, [1.0, 2.0]) 2590 self.assertAllEqual(four, [1.0, 2.0]) 2591 self.assertAllEqual(five, 2.0) 2592 self.assertAllEqual(six, 2.0) 2593 self.assertAllEqual(seven, 2.0) 2594 2595 def testDefuningInstanceMethod(self): 2596 2597 integer = constant_op.constant(2, dtypes.int64) 2598 2599 class Foo: 2600 2601 def one(self, tensor): 2602 return tensor 2603 2604 @def_function.function 2605 def two(self, tensor, other=integer): 2606 return self.one(tensor), other 2607 2608 foo = Foo() 2609 t = constant_op.constant(1.0) 2610 one, two = foo.two(t) 2611 self.assertEqual(one.numpy(), 1.0) 2612 self.assertEqual(two.numpy(), 2) 2613 2614 def testDefuningInstanceMethodWithDefaultArgument(self): 2615 2616 integer = constant_op.constant(2, dtypes.int64) 2617 2618 class Foo: 2619 2620 @def_function.function 2621 def func(self, other=integer): 2622 return other 2623 2624 foo = Foo() 2625 self.assertEqual(foo.func().numpy(), int(integer)) 2626 2627 def testPythonCallWithSideEffects(self): 2628 state = [] 2629 2630 @def_function.function 2631 def side_effecting_function(): 2632 state.append(0) 2633 2634 side_effecting_function() 2635 self.assertAllEqual(state, [0]) 2636 2637 # The second invocation should call the graph function, which shouldn't 2638 # trigger the list append. 2639 side_effecting_function() 2640 self.assertAllEqual(state, [0]) 2641 2642 # Whereas calling the python function directly should create a side-effect. 
2643 side_effecting_function.python_function() 2644 self.assertAllEqual(state, [0, 0]) 2645 2646 def testFunctionWithNestedFunctionCallAndSideEffects(self): 2647 v1 = variables.Variable(1.0) 2648 v2 = variables.Variable(1.0) 2649 2650 @def_function.function 2651 def add_one(a): 2652 a.assign_add(1.0) 2653 2654 # Grappler will inline calls to `add_one` into the function body, we check 2655 # that all side-effects were executed. 2656 @def_function.function 2657 def side_effecting_function(a, b): 2658 add_one(a) 2659 add_one(b) 2660 return a + b 2661 2662 result = side_effecting_function(v1, v2) 2663 self.assertEqual(result.numpy(), 4.0) 2664 2665 def testFunctionWithExtraAttributes(self): 2666 @function.defun_with_attributes(attributes={'experimental_1': 'value1', 2667 'experimental_2': 2}) 2668 def matmul(x, y): 2669 return math_ops.matmul(x, y) 2670 2671 def add(x, y): 2672 return math_ops.add(x, y) 2673 defun_add = function.defun_with_attributes( 2674 add, attributes={'experimental_3': True, 'experimental_4': 1.0}) 2675 2676 with context.graph_mode(), self.cached_session(): 2677 with ops.get_default_graph().as_default(): 2678 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 2679 sq = matmul(t, t) 2680 double = defun_add(t, t) 2681 self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22]) 2682 self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8]) 2683 2684 graph = ops.get_default_graph() 2685 # pylint: disable=protected-access 2686 self.assertLen(graph._functions, 2) 2687 functions = list(graph._functions.values()) 2688 self.assertRegex(functions[0].definition.signature.name, '.*matmul.*') 2689 attrs = functions[0].definition.attr 2690 self.assertLen(attrs, 2) 2691 self.assertEqual(attrs['experimental_1'].s, b'value1') 2692 self.assertEqual(attrs['experimental_2'].i, 2) 2693 2694 self.assertRegex(functions[1].definition.signature.name, '.*add.*') 2695 attrs = functions[1].definition.attr 2696 self.assertLen(attrs, 2) 2697 
self.assertEqual(attrs['experimental_3'].b, True) 2698 self.assertEqual(attrs['experimental_4'].f, 1.0) 2699 # pylint: enable=protected-access 2700 2701 def testFunctionWithInvalidAttribute(self): 2702 @function.defun_with_attributes(attributes={'experimental_1': ['value1']}) 2703 def add(x, y): 2704 return math_ops.add(x, y) 2705 2706 with self.assertRaisesRegex(ValueError, 2707 'Attribute experimental_1 must be .* Got .*'): 2708 with context.graph_mode(), self.cached_session(): 2709 with ops.get_default_graph().as_default(): 2710 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 2711 add(t, t) 2712 2713 def testRegisterFunction(self): 2714 2715 @function.defun 2716 def add(x, y): 2717 return math_ops.add(x, y) 2718 2719 def matmul(x, y): 2720 return math_ops.matmul(x, y) 2721 defun_matmul = function.defun(matmul) 2722 2723 with context.graph_mode(), self.cached_session(): 2724 with ops.get_default_graph().as_default(): 2725 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 2726 function.register(defun_matmul, t, t) 2727 function.register(add, t, t) 2728 2729 graph = ops.get_default_graph() 2730 # pylint: disable=protected-access 2731 self.assertLen(graph._functions, 6) 2732 # two sets of functions, each of them are (inference, forward, backward) 2733 functions = list(graph._functions.values()) 2734 captured_function_names = [ 2735 f.definition.signature.name for f in functions 2736 ] 2737 expected_func_name_regex = [ 2738 '.*inference.*matmul.*', 2739 '.*forward.*matmul.*', 2740 '.*inference.*backward.*matmul.*', 2741 '.*inference.*add.*', 2742 '.*forward.*add.*', 2743 '.*inference.*backward.*add.*', 2744 ] 2745 for i in range(len(functions)): 2746 self.assertRegex(captured_function_names[i], 2747 expected_func_name_regex[i]) 2748 2749 # Check the forward and backward function has the correct attributes. 
2750 self.assertEqual( 2751 functions[1].definition.attr['backward_function_name'].s, 2752 functions[2].name) 2753 self.assertEqual( 2754 functions[2].definition.attr['forward_function_name'].s, 2755 functions[1].name) 2756 2757 self.assertEqual( 2758 functions[4].definition.attr['backward_function_name'].s, 2759 functions[5].name) 2760 self.assertEqual( 2761 functions[5].definition.attr['forward_function_name'].s, 2762 functions[4].name) 2763 2764 sq = defun_matmul(t, t) 2765 double = add(t, t) 2766 self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22]) 2767 self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8]) 2768 # Make sure the pre registered function is used, and no other function 2769 # is added. 2770 self.assertLen(graph._functions, 6) 2771 functions = list(graph._functions.values()) 2772 for i in range(len(functions)): 2773 self.assertEqual(captured_function_names[i], 2774 functions[i].definition.signature.name) 2775 2776 @parameterized.named_parameters( 2777 dict(testcase_name='Defun', 2778 function_decorator=function.defun), 2779 dict(testcase_name='DefFunction', 2780 function_decorator=def_function.function)) 2781 def testRegisterConcreteFunction(self, function_decorator): 2782 @function_decorator 2783 def py_add(x, y): 2784 return math_ops.add(x, y) 2785 2786 py_add(array_ops.ones([]), array_ops.ones([])) 2787 add = py_add.get_concrete_function( 2788 tensor_spec.TensorSpec(None, dtypes.float32), 2789 tensor_spec.TensorSpec(None, dtypes.float32)) 2790 2791 @function_decorator 2792 def py_composite(x, y): 2793 return x, add(x, y) 2794 2795 py_composite(array_ops.ones([]), array_ops.ones([])) 2796 composite = py_composite.get_concrete_function( 2797 tensor_spec.TensorSpec(None, dtypes.float32), 2798 tensor_spec.TensorSpec(None, dtypes.float32)) 2799 2800 with context.graph_mode(), self.cached_session(): 2801 with ops.get_default_graph().as_default(): 2802 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 2803 composite.add_to_graph() 2804 
composite.add_gradient_functions_to_graph() 2805 2806 graph = ops.get_default_graph() 2807 # pylint: disable=protected-access 2808 self.assertLen(graph._functions, 6) 2809 # two sets of functions, each of them are (inference, forward, backward) 2810 functions = list(graph._functions.values()) 2811 captured_function_names = [ 2812 f.definition.signature.name for f in functions 2813 ] 2814 expected_func_name_regex = [ 2815 '.*inference.*py_composite.*', 2816 '.*inference.*py_add.*', 2817 '.*forward.*py_composite.*', 2818 '.*forward.*py_add.*', 2819 '.*inference.*backward.*py_composite.*', 2820 '.*inference.*backward.*py_add.*', 2821 ] 2822 for expected, found in zip( 2823 expected_func_name_regex, 2824 captured_function_names): 2825 self.assertRegex(found, expected) 2826 2827 composite_t, composite_double = composite(t, t) 2828 double = add(t, t) 2829 self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(double)) 2830 self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(composite_double)) 2831 self.assertAllEqual([[1, 2], [3, 4]], self.evaluate(composite_t)) 2832 # Make sure the pre registered function is used, and no other function 2833 # is added. 
# NOTE(review): the first statement below is the closing assertion of a test
# method that begins before this span; its nesting depth (inside that test's
# `with` blocks) is assumed here and should be confirmed against the file.
        self.assertLen(graph._functions, 6)

  @parameterized.named_parameters(
      dict(testcase_name='Defun',
           function_decorator=function.defun),
      dict(testcase_name='DefFunction',
           function_decorator=def_function.function))
  def testEagerCaptures(self, function_decorator):
    """Checks which op type each eagerly-created captured value becomes.

    The threshold assertions establish that the 256-element tensor is above
    `func_graph._EAGER_CONST_THRESHOLD` and the 4-element tensor is not. The
    large tensor and the resource variable are expected to be captured as a
    `Placeholder`, while the small tensor is expected to become a `Const`.
    """
    with context.eager_mode():
      large_tensor = array_ops.ones(shape=(256,))
      self.assertGreater(256, func_graph._EAGER_CONST_THRESHOLD)

      small_tensor = array_ops.ones(shape=(4,))
      self.assertLessEqual(4, func_graph._EAGER_CONST_THRESHOLD)

      v = resource_variable_ops.ResourceVariable(0.0)

      for captured, op_type in [(large_tensor, 'Placeholder'),
                                (small_tensor, 'Const'), (v, 'Placeholder')]:
        # A fresh function is built on every iteration and closes over the
        # current `captured` value, so each trace has exactly one capture.
        @function_decorator
        def test_fn():
          return captured + 1  # pylint: disable=cell-var-from-loop

        g = test_fn.get_concrete_function().graph
        internal_captures = g.internal_captures
        self.assertLen(internal_captures, 1)
        self.assertEqual(internal_captures[0].op.type, op_type)

  def testRegisterFunctionWithInputSignature(self):
    """Registers a defun that has an `input_signature` into a graph.

    Registering the same defun a second time, with no inputs, is expected to
    leave the graph's function count unchanged.
    """
    def matmul(x, y):
      return math_ops.matmul(x, y)
    defun_matmul = function.defun(
        matmul,
        input_signature=[
            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
        ])
    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        function.register(defun_matmul, t, t)

        graph = ops.get_default_graph()
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 3)

        # Test register function with cache; note inputs are ignored
        # (presumably the defun's `input_signature` is used instead), so the
        # function count must not grow.
        function.register(defun_matmul)
        graph = ops.get_default_graph()
        self.assertLen(graph._functions, 3)

  def testRegisterFunctionWithCache(self):
    """Registering twice with same-shaped, same-dtype inputs reuses the trace."""
    def matmul(x, y):
      return math_ops.matmul(x, y)
    defun_matmul = function.defun(matmul)

    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]])
        function.register(defun_matmul, t, t)
        function.register(defun_matmul, t2, t2)

        graph = ops.get_default_graph()
        # Only one set of functions is registered, since both input pairs have
        # the same type (2x2 float constants).
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 3)

  def testCallingFunctionWithDifferentVariables(self):
    """A concrete function traced with one variable accepts other variables.

    The variable argument is traced as an explicit input (one input, no
    captures), so the same concrete function can be called with a different
    variable.
    """

    @function.defun
    def foo(v):
      v.assign_add(1.0)
      return v.read_value()

    v = resource_variable_ops.ResourceVariable(0.0)
    graph_function = foo.get_concrete_function(v)
    self.assertLen(graph_function.inputs, 1)
    self.assertEmpty(graph_function.captured_inputs)

    # Each call increments the variable that is passed in.
    self.assertEqual(float(graph_function(v)), 1.0)
    self.assertEqual(float(graph_function(v)), 2.0)

    w = resource_variable_ops.ResourceVariable(0.0)

    @function.defun
    def bar(v):
      del v
      return constant_op.constant(1.0)

    # Traced with `v`, but also callable with the unrelated variable `w`.
    graph_function = bar.get_concrete_function(v)
    self.assertEqual(float(graph_function(v)), 1.0)
    self.assertEqual(float(graph_function(w)), 1.0)

  def testCallingFunctionWithNonTensorsFails(self):
    """Calling a tensor-traced concrete function with a string must fail."""

    @function.defun
    def foo(x):
      return x

    graph_function = foo.get_concrete_function(constant_op.constant(1.0))
    with self.assertRaises((TypeError, ValueError)):
      graph_function('Not a Tensor.')

  def testSwapImplementationWithGrapplerPlugin(self):
    """Grappler's implementation selector picks an `api_implements` variant.

    Two defuns share `api_implements='random_boost'` and differ only in their
    preferred device; the value computed by calling the GPU variant reveals
    which implementation actually ran.
    """
    # Set the min_graph_nodes to -1 since the graph in this test is too small,
    # and would otherwise be ignored by grappler.
    rewrites = rewriter_config_pb2.RewriterConfig()
    rewrites.implementation_selector = rewriter_config_pb2.RewriterConfig.ON
    rewrites.min_graph_nodes = -1
    graph_options = config_pb2.GraphOptions(
        rewrite_options=rewrites, build_cost_model=1)
    config_proto = config_pb2.ConfigProto(graph_options=graph_options)

    with context.graph_mode(), self.cached_session(
        config=config_proto, graph=ops.Graph(), use_gpu=True):

      @function.defun_with_attributes(
          attributes={
              'api_implements': 'random_boost',
              'api_preferred_device': 'CPU'
          })
      def cpu_boost(x):
        return math_ops.add(x, 2.0)

      @function.defun_with_attributes(
          attributes={
              'api_implements': 'random_boost',
              'api_preferred_device': 'GPU'
          })
      def gpu_boost(x):
        return math_ops.add(x, 4.0)

      x = constant_op.constant(1.0)

      # Only the CPU variant is registered explicitly; the GPU one is called.
      function.register(cpu_boost, x)
      y = gpu_boost(x)
      y_value = self.evaluate(y)

      if test.is_gpu_available():
        # GPU variant ran: 1.0 + 4.0.
        self.assertEqual(y_value, 5.0)
      else:
        # Grappler falls back to the CPU impl (1.0 + 2.0) even though the GPU
        # function was called.
        self.assertEqual(y_value, 3.0)

  @test_util.disable_tfrt('b/174712583: TFRT doesn\'t support behavior '
                          'equivalent to implementation_selector for function')
  def testSwapImplementationInEager(self):
    """In eager mode, a registered CPU variant replaces a GPU call on CPU.

    `run_on_cpu` registers `on_cpu` and then calls `on_gpu` under a CPU device
    scope; the expected result (1 + 2) shows the CPU variant was executed.
    """
    if not context.executing_eagerly():
      self.skipTest('eager only')

    # testSharedRendezvous sets the disable_meta_optimizer flag to True. If
    # that subtest runs before this one, the flag would make this subtest
    # fail, so explicitly set disable_meta_optimizer to False here.
    # NOTE(review): these options are set on the global context and are not
    # restored when this test ends, so they may affect later tests.
    context.context().set_optimizer_experimental_options({
        'min_graph_nodes': -1,
        'implementation_selector': True,
        'disable_meta_optimizer': False
    })

    @function.defun_with_attributes(
        attributes={'api_implements': 'foo',
                    'api_preferred_device': 'CPU'})
    def on_cpu(x):
      return x + 2

    @function.defun_with_attributes(
        attributes={'api_implements': 'foo',
                    'api_preferred_device': 'GPU'})
    def on_gpu(x):
      return x + 4

    @function.defun
    def run_on_cpu(t):
      function.register(on_cpu, t)
      with ops.device('CPU:0'):
        return on_gpu(t)

    # Expect to run the on_cpu branch, regardless of whether a GPU is
    # available.
    self.assertEqual(run_on_cpu(constant_op.constant(1)).numpy(), 3)

  def testDefunFunctionSeparateGraphs(self):
    """Defun caches keep separate entries for each outer graph.

    Within one graph, `maybe_add(x, False)` adds a `maybe_add` trace without
    retracing `add`; repeating the first call in a new graph retraces both.
    """
    with context.graph_mode():

      @function.defun
      def add(x):
        return x + 5

      @function.defun
      def maybe_add(x, should_add):
        if should_add:
          return add(x)
        else:
          return x

      with ops.Graph().as_default():
        x = constant_op.constant(11)
        maybe_add(x, True)
        self.assertLen(total_function_cache(maybe_add), 1)
        self.assertLen(total_function_cache(add), 1)

        # Python-bool argument changed: new `maybe_add` trace, `add` unused.
        maybe_add(x, False)
        self.assertLen(total_function_cache(maybe_add), 2)
        self.assertLen(total_function_cache(add), 1)

      # Same call as the first one, but in a different outer graph.
      with ops.Graph().as_default():
        x = constant_op.constant(11)
        maybe_add(x, True)
        self.assertLen(total_function_cache(maybe_add), 3)
        self.assertLen(total_function_cache(add), 2)

  def testCacheKeyOverlappingShapes(self):
    """Shapes [12, 1] and [1, 21] must map to different cache entries."""
    # Presumably chosen because writing both shapes' dimensions back to back
    # yields the same digits ("121"), which a naive cache key could conflate.
    @function.defun
    def defined(t):
      return t

    defined(array_ops.zeros([12, 1]))
    self.assertLen(total_function_cache(defined), 1)
    defined(array_ops.zeros([1, 21]))
    self.assertLen(total_function_cache(defined), 2)

    # Same check through `get_concrete_function` on a function that wraps
    # `defined`.
    @function.defun
    def defined_again(t):
      return defined(t)

    defined_again.get_concrete_function(array_ops.zeros([12, 1]))
    self.assertLen(total_function_cache(defined_again), 1)
    defined_again.get_concrete_function(array_ops.zeros([1, 21]))
    self.assertLen(total_function_cache(defined_again), 2)

  def testCacheTensorSpecIdenticalToTensor(self):
    """A TensorSpec built from a tensor yields the tensor's concrete function."""
    @function.defun
    def defined(t):
      return t

    z = array_ops.zeros([2, 2])
    z_spec = tensor_spec.TensorSpec.from_tensor(z)
    self.assertIs(
        defined.get_concrete_function(z_spec), defined.get_concrete_function(z))

  def testCacheKeyNestedLists(self):
    """The same tensors in a different list nesting trigger a retrace."""
    @function.defun
    def defined(l):
      return l

    a = constant_op.constant(1.)
    b = constant_op.constant(2.)
    c = constant_op.constant(3.)
    defined([[a], b, c])
    self.assertLen(total_function_cache(defined), 1)

    defined([[a, b], c])
    self.assertLen(total_function_cache(defined), 2)

  def testCacheKeyAttrsClass(self):
    """Instances of an `attr.s` class are cached by their nested structure."""
    # `attr` is presumably an optional module-level import that may be None.
    if attr is None:
      self.skipTest('attr module is unavailable.')

    @attr.s
    class TestClass:
      a = attr.ib()
      b = attr.ib()

    @function.defun
    def defined(l):
      return l

    defined(
        TestClass(
            constant_op.constant(1.),
            [constant_op.constant(2.),
             constant_op.constant(3.)]))
    self.assertLen(total_function_cache(defined), 1)
    # Same (tensor, [tensor, tensor]) structure with new tensors: cache hit.
    defined(
        TestClass(
            constant_op.constant(1.),
            [constant_op.constant(2.),
             constant_op.constant(3.)]))
    self.assertLen(total_function_cache(defined), 1)

    # Structure swapped to ([tensor, tensor], tensor): new trace.
    defined(
        TestClass([constant_op.constant(1.),
                   constant_op.constant(2.)], constant_op.constant(3.)))
    self.assertLen(total_function_cache(defined), 2)

  def testDistinctVariablesNoRetracing(self):
    """Permuting distinct variables among the arguments does not retrace."""
    @function.defun
    def defined(a, b, c):
      return a + b + c

    x = resource_variable_ops.ResourceVariable(0.0)
    y = resource_variable_ops.ResourceVariable(0.0)
    z = resource_variable_ops.ResourceVariable(0.0)

    # We generate cache keys based on unique combinations of resource ids.
    defined(x, y, z)
    self.assertLen(total_function_cache(defined), 1)

    # Re-arranging arguments should not cause a cache miss
    # because the three inputs are still distinct.
    defined(z, y, x)
    self.assertLen(total_function_cache(defined), 1)

  # NOTE(review): "Varible" in this method name is a typo; left unchanged so
  # the test id stays stable.
  def testRetracingOnDifferentVaribleCombinationPatterns(self):
    """Retracing depends on which variable arguments alias each other."""
    @function.defun
    def defined(a, b, c):
      return a + b + c

    x = resource_variable_ops.ResourceVariable(0.0)
    y = resource_variable_ops.ResourceVariable(0.0)
    z = resource_variable_ops.ResourceVariable(0.0)

    defined(x, y, z)
    self.assertLen(total_function_cache(defined), 1)

    # Retracing because the first two arguments are the same
    defined(x, x, z)
    self.assertLen(total_function_cache(defined), 2)

    # Replacing x with y does not cause a cache miss
    # because the combination stays the same as (x, x, z)
    defined(y, y, z)
    self.assertLen(total_function_cache(defined), 2)

    # A different combination pattern causes a cache miss
    defined(z, y, y)
    self.assertLen(total_function_cache(defined), 3)
    defined(z, y, y)
    self.assertLen(total_function_cache(defined), 3)

  def testDeepcopyVariableNoRetracing(self):
    """Passing a deep copy of a variable reuses the existing trace."""
    @function.defun
    def defined(a, b, c):
      return a + b + c

    x = resource_variable_ops.ResourceVariable(0.0)
    y = resource_variable_ops.ResourceVariable(0.0)
    z = resource_variable_ops.ResourceVariable(0.0)
    defined(x, y, z)
    self.assertLen(total_function_cache(defined), 1)

    x_copy = copy.deepcopy(x)
    defined(x_copy, y, z)
    self.assertLen(total_function_cache(defined), 1)

  def _total_function_cache_def_func(self, defined):
    """Returns `defined._list_all_concrete_functions()`; used to count traces."""
    return defined._list_all_concrete_functions()  # pylint: disable=protected-access

  def testVariableRetracingOnDtypeChanges(self):
    """Variables holding integer instead of float values cause a retrace."""

    @def_function.function
    def defined(a, b):
      return a + b

    x1 = resource_variable_ops.ResourceVariable(0.0)
    x2 = resource_variable_ops.ResourceVariable(0.0)

    defined(x1, x2)
    self.assertLen(self._total_function_cache_def_func(defined), 1)

    # Should expect retracing for new dtypes
    y1 = resource_variable_ops.ResourceVariable(0)
    y2 = resource_variable_ops.ResourceVariable(1)
    defined(y1, y2)
    self.assertLen(self._total_function_cache_def_func(defined), 2)

  def testVariableRetracingDtypeShape(self):
    """Variables with a new shape (scalar, [2], [1, 2]) cause a retrace."""

    @def_function.function
    def defined(a, b):
      return a + b

    x1 = resource_variable_ops.ResourceVariable(0.0)
    x2 = resource_variable_ops.ResourceVariable(0.0)

    defined(x1, x2)
    self.assertLen(self._total_function_cache_def_func(defined), 1)

    y1 = resource_variable_ops.ResourceVariable([0.0, 1.0])
    y2 = resource_variable_ops.ResourceVariable([0.0, 1.0])

    defined(y1, y2)
    self.assertLen(self._total_function_cache_def_func(defined), 2)

    z1 = resource_variable_ops.ResourceVariable([[0.0, 1.0]])
    z2 = resource_variable_ops.ResourceVariable([[0.0, 1.0]])
    defined(z1, z2)
    self.assertLen(self._total_function_cache_def_func(defined), 3)

  def testDecoratedMethodInspect(self):
    """`tf_inspect.getfullargspec` sees the args of a defun-decorated method."""

    class DefunnedMiniModel:

      @function.defun
      def call(self, inputs, training=True):
        pass

    m = DefunnedMiniModel()
    fullargspec = tf_inspect.getfullargspec(m.call)
    self.assertIn('training', fullargspec.args)

  def testFunctionModifiesInputList(self):
    """In-place mutation of an input list inside a function raises ValueError."""
    # Tests on `list` methods that do in-place modification, except
    # `list.sort`, since it cannot even be "defunned" in the first place.

    def get_list():
      return [constant_op.constant(0.), constant_op.constant(1.)]

    # NOTE(review): `()` in this pattern is an empty regex group, not literal
    # parentheses, so it behaves like '.* should not modify'.
    expected_msg = '.*() should not modify'

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def append(l):
        l.append(constant_op.constant(0.))

      append(get_list())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def extend(l):
        l.extend([constant_op.constant(0.)])

      extend(get_list())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def insert(l):
        l.insert(0, constant_op.constant(0.))

      insert(get_list())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def pop(l):
        l.pop()

      pop(get_list())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def reverse(l):
        l.reverse()

      reverse(get_list())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def remove(l):
        l.remove(l[0])

      remove(get_list())

    # `list.clear` is a method that is in Py3 but not Py2
    if sys.version.startswith('3'):

      with self.assertRaisesRegex(ValueError, expected_msg):

        @def_function.function
        def clear(l):
          l.clear()

        clear(get_list())

    # One last test for keyword arguments
    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def kwdappend(**kwargs):
        l = kwargs['l']
        l.append(constant_op.constant(0.))

      kwdappend(l=get_list())

  def testFunctionModifiesInputDict(self):
    """In-place mutation of an input dict inside a function raises ValueError."""

    def get_dict():
      return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)}

    expected_msg = '.* should not modify'

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def clear(m):
        m.clear()

      clear(get_dict())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def pop(m):
        m.pop('t1')

      pop(get_dict())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def popitem(m):
        m.popitem()

      popitem(get_dict())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def update(m):
        m.update({'t1': constant_op.constant(3.)})

      update(get_dict())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def setdefault(m):
        m.setdefault('t3', constant_op.constant(3.))

      setdefault(get_dict())

  def testFunctionModifiesInputNest(self):
    """Mutating nested input structures raises ValueError.

    This includes a mutation that changes the nesting while leaving the
    flattened sequence of tensors unchanged.
    """
    with self.assertRaisesRegex(ValueError, 'modify.* should not modify'):

      @def_function.function
      def modify(n):
        n[0]['t1'].append(constant_op.constant(1.))

      nested_input = [{
          't1': [constant_op.constant(0.),
                 constant_op.constant(1.)],
      }, constant_op.constant(2.)]

      modify(nested_input)

    with self.assertRaisesRegex(ValueError,
                                'modify_same_flat.* should not modify'):

      # The flat list doesn't change, whereas the true structure changes
      # ([[0], [1, 2]] -> [[0, 1], [2]]).
      @def_function.function
      def modify_same_flat(n):
        n[0].append(n[1].pop(0))

      nested_input = [[constant_op.constant(0.)],
                      [constant_op.constant(1.),
                       constant_op.constant(2.)]]

      modify_same_flat(nested_input)

  @test_util.disable_tfrt('b/173429686')
  def testExecutorType(self):
    """`context.function_executor_type` selects the executor for calls.

    An unknown executor name raises NotFoundError; '', 'DEFAULT' and None all
    run the function normally.
    """
    @function.defun
    def add_five(x):
      return x + 5

    self.assertEqual(
        5,
        add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())

    with self.assertRaisesRegex(errors.NotFoundError, 'NON_EXISTENT_EXECUTOR'):
      with context.function_executor_type('NON_EXISTENT_EXECUTOR'):
        add_five(constant_op.constant(0, dtype=dtypes.int32))

    for executor_type in ('', 'DEFAULT', None):
      with context.function_executor_type(executor_type):
        self.assertAllEqual(
            5,
            add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())

  @test_util.assert_no_garbage_created
  def testReferenceCycles(self):
    """A called defun is freed once its last strong reference is dropped."""

    fn = function.defun(lambda x: 2. * x)

    fn(constant_op.constant(4.0))
    weak_fn = weakref.ref(fn)
    del fn
    # Tests that the weak reference we made to the function is now dead, which
    # means the object has been deleted. This should be true as long as the
    # function itself is not involved in a reference cycle.
    self.assertIs(None, weak_fn())

  def testFunctionStackInErrorMessage(self):
    """Errors raised inside nested functions report the calling stack.

    `fn(2)` -> `fn2` -> `assert_equal(fn3(2), 3)` fails because 2 + 2 != 3;
    the message must name the `fn -> fn2` stack but not `fn3`.
    """
    if context.executing_eagerly():
      # TODO(b/122736651): Remove this skipTest once fixed.
      self.skipTest('Error interpolation is not working when function is '
                    'invoked without PartitionedCallOp.')

    @def_function.function()
    def fn3(x):
      return x + 2

    @def_function.function()
    def fn2(x):
      check_ops.assert_equal(fn3(x), 3)
      return 2

    @def_function.function()
    def fn(x):
      return fn2(x)

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      fn(2)
    e = cm.exception
    self.assertIn('fn -> fn2', e.message)
    self.assertIn('node assert_equal/Assert/Assert (defined at', e.message)
    self.assertNotIn('fn3', e.message)

  @test_util.run_gpu_only
  def testFunctionIsNotPinned(self):
    """Tests that functions aren't pinned to the CPU by the eager runtime."""
    seed1, seed2 = 79, 25
    shape = constant_op.constant([4, 7])
    dtype = dtypes.float32

    @def_function.function
    def func():
      with ops.device('GPU:0'):
        return gen_random_ops.random_standard_normal(
            shape, dtype=dtype, seed=seed1, seed2=seed2)

    with ops.device('GPU:0'):
      x = func()
      self.assertRegex(x.device, 'GPU')
@test_util.run_in_graph_and_eager_modes 3467 def testShapeCaching(self): 3468 3469 @function.defun 3470 def func(x): 3471 return array_ops.shape(x) 3472 3473 @function.defun( 3474 input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)]) 3475 def calls_func(x): 3476 return func(x) 3477 3478 self.assertAllEqual([1, 1], self.evaluate(func(array_ops.zeros([1, 1])))) 3479 self.assertAllEqual([2, 2], self.evaluate(func(array_ops.zeros([2, 2])))) 3480 self.assertAllEqual( 3481 [3, 3], 3482 self.evaluate(calls_func(array_ops.zeros([3, 3])))) 3483 3484 def testLimitedRetracing(self): 3485 trace_count = [0] 3486 @function.defun 3487 def func(x): 3488 trace_count[0] += 1 3489 return x 3490 3491 for _ in range(50): 3492 func(constant_op.constant(3.)) 3493 func(constant_op.constant(4.)) 3494 func(constant_op.constant([[1., 2.]])) 3495 func(constant_op.constant([[]])) 3496 func(constant_op.constant([[3., 4.], [5., 6.]])) 3497 func(constant_op.constant([[3., 4.], [5., 6.], [7., 8.]])) 3498 # Tracing more than twice per input doesn't make sense. 3499 self.assertLess(trace_count[0], 13) 3500 3501 def testLimitedRetracingWithCompositeTensors(self): 3502 trace_count = [0] 3503 3504 @def_function.function 3505 def f(x): 3506 trace_count[0] += 1 3507 return x 3508 3509 for i in range(10): 3510 f(ragged_factory_ops.constant([[1, 2], [i]])) 3511 f(ragged_factory_ops.constant([[1, 2], [], [3, 4, 5]])) 3512 f(ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]])) 3513 self.assertEqual(trace_count[0], 3) 3514 3515 def test_concrete_function_shape_mismatch(self): 3516 3517 @def_function.function 3518 def f(argument_name): 3519 return argument_name + 1. 3520 3521 f_concrete = f.get_concrete_function(constant_op.constant([1.])) 3522 3523 # Calling a function from eager doesn't do any shape checking above what 3524 # kernels do while executing. 
3525 self.assertAllEqual( 3526 [2., 3.], 3527 f_concrete(constant_op.constant([1., 2.])).numpy()) 3528 3529 @def_function.function 3530 def g(): 3531 f_concrete(constant_op.constant([1., 2.])) 3532 3533 with self.assertRaisesRegex(ValueError, 'is not compatible with the shape'): 3534 g() 3535 3536 @test_util.run_in_graph_and_eager_modes 3537 def test_shape_inference_with_symbolic_shapes(self): 3538 3539 @def_function.function 3540 def _uses_symbolic_shapes(w, x, y): 3541 x = array_ops.identity(x, name='name_collision') 3542 x = array_ops.transpose(x, [1, 0, 2]) 3543 x_batch = array_ops.shape(x)[0] 3544 y_batch = array_ops.shape(y)[0] 3545 y *= w 3546 n = y_batch // x_batch 3547 return array_ops.reshape(y, [n, x_batch, -1]) 3548 3549 conc = _uses_symbolic_shapes.get_concrete_function( 3550 tensor_spec.TensorSpec(None, dtypes.float32), 3551 tensor_spec.TensorSpec(None, dtypes.float32), 3552 tensor_spec.TensorSpec(None, dtypes.float32)) 3553 3554 @def_function.function 3555 def _call_concrete(): 3556 c = constant_op.constant(1.) 
3557 array_ops.identity(c, name='name_collision') 3558 output1 = conc(array_ops.ones([2]), 3559 array_ops.ones([5, 4, 2]), 3560 array_ops.ones([20, 2])) 3561 self.assertEqual([5, 4, 2], output1.shape) 3562 output2 = conc(array_ops.ones([3]), 3563 array_ops.ones([5, 4, 3]), 3564 array_ops.ones([40, 3])) 3565 self.assertEqual([10, 4, 3], output2.shape) 3566 return output1, output2 3567 3568 output1, output2 = _call_concrete() 3569 self.assertEqual((5, 4, 2), self.evaluate(output1).shape) 3570 self.assertEqual((10, 4, 3), self.evaluate(output2).shape) 3571 3572 def testAutoGraphContext(self): 3573 3574 @def_function.function 3575 def test_fn(): 3576 self.assertEqual( 3577 ag_ctx.control_status_ctx().status, ag_ctx.Status.ENABLED) 3578 3579 prev_status = ag_ctx.control_status_ctx().status 3580 test_fn() 3581 self.assertEqual(ag_ctx.control_status_ctx().status, prev_status) 3582 3583 @test_util.disable_tfrt('b/170435618') 3584 def testCancelBeforeFunctionExecution(self): 3585 if not context.executing_eagerly(): 3586 self.skipTest('eager only') 3587 3588 q = data_flow_ops.FIFOQueue(1, dtypes.int32) 3589 3590 @def_function.function 3591 def f(): 3592 return q.dequeue() 3593 3594 c_mgr = cancellation.CancellationManager() 3595 cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function()) 3596 3597 c_mgr.start_cancel() 3598 with self.assertRaises(errors.CancelledError): 3599 cancelable_func() 3600 3601 @test_util.disable_tfrt('b/170435618') 3602 def testCancelBlockedFunctionExecution(self): 3603 if not context.executing_eagerly(): 3604 self.skipTest('eager only') 3605 3606 q = data_flow_ops.FIFOQueue(1, dtypes.int32) 3607 3608 @def_function.function 3609 def f(): 3610 return q.dequeue() 3611 3612 c_mgr = cancellation.CancellationManager() 3613 cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function()) 3614 3615 def cancel_thread(): 3616 time.sleep(0.5) 3617 c_mgr.start_cancel() 3618 3619 t = self.checkedThread(cancel_thread) 3620 t.start() 3621 
with self.assertRaises(errors.CancelledError): 3622 cancelable_func() 3623 t.join() 3624 3625 @test_util.disable_tfrt('b/170435618') 3626 def testCancelAfterFunctionExecution(self): 3627 if not context.executing_eagerly(): 3628 self.skipTest('eager only') 3629 3630 q = data_flow_ops.FIFOQueue(1, dtypes.int32) 3631 q.enqueue(37) 3632 3633 @def_function.function 3634 def f(): 3635 return q.dequeue() 3636 3637 c_mgr = cancellation.CancellationManager() 3638 cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function()) 3639 3640 self.assertAllEqual(37, cancelable_func().numpy()) 3641 3642 # Cancellation after the function executes is a no-op. 3643 c_mgr.start_cancel() 3644 3645 def testAddFunctionCallback(self): 3646 functions = [] 3647 def function_callback(f, name, graph, inputs, outputs): 3648 del name, graph, inputs, outputs 3649 functions.append(f) 3650 3651 @def_function.function 3652 def plus_one(x): 3653 return x + 1 3654 3655 try: 3656 function.add_function_callback(function_callback) 3657 x_float32 = numpy.array(3.0, dtype=numpy.float32) 3658 self.assertAllClose(plus_one(x_float32), 4.0) 3659 self.assertLen(functions, 1) 3660 # Function is already created. Executing it again should not invoke the 3661 # function callback. 3662 self.assertAllClose(plus_one(x_float32), 4.0) 3663 self.assertLen(functions, 1) 3664 # Signature change leads to a new Function being built. 
3665 x_float64 = numpy.array(3.0, dtype=numpy.float64) 3666 self.assertAllClose(plus_one(x_float64), 4.0) 3667 self.assertLen(functions, 2) 3668 finally: 3669 function.clear_function_callbacks() 3670 3671 def testFunctionCallbackAddOps(self): 3672 file_name = os.path.join(self.get_temp_dir(), 'test') 3673 3674 def function_callback(f, name, graph, inputs, outputs): 3675 del f, name, inputs 3676 3677 with graph.as_default(): 3678 printer = logging_ops.print_v2( 3679 'hello', 3680 output_stream='file://' + file_name 3681 ) 3682 outputs[0].op._add_control_input(printer) 3683 3684 @def_function.function 3685 def plus_one(x): 3686 return x + 1 3687 3688 self.addCleanup(function.clear_function_callbacks) 3689 function.add_function_callback(function_callback) 3690 x_float32 = numpy.array(3.0, dtype=numpy.float32) 3691 3692 self.assertAllClose(plus_one(x_float32), 4.0) 3693 3694 with open(file_name, 'r') as f: 3695 self.assertEqual(f.read().strip(), 'hello') 3696 3697 def testRemoveFunctionCallback(self): 3698 functions_1 = [] 3699 def function_callback_1(f, name, graph, inputs, outputs): 3700 del name, graph, inputs, outputs 3701 functions_1.append(f) 3702 3703 functions_2 = [] 3704 def function_callback_2(f, name, graph, inputs, outputs): 3705 del name, graph, inputs, outputs 3706 functions_2.append(f) 3707 3708 @def_function.function 3709 def plus_one(x): 3710 return x + 1 3711 3712 try: 3713 function.add_function_callback(function_callback_1) 3714 function.add_function_callback(function_callback_2) 3715 self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float32)), 4.0) 3716 self.assertLen(functions_1, 1) 3717 self.assertLen(functions_2, 1) 3718 function.remove_function_callback(function_callback_1) 3719 # The 1st callback should not be invokved after remove_function_callback() 3720 # is called. 
3721 self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float64)), 4.0) 3722 self.assertLen(functions_1, 1) 3723 self.assertLen(functions_2, 2) 3724 finally: 3725 function.clear_function_callbacks() 3726 3727 def testClearFunctionCallbacks(self): 3728 function.add_function_callback(lambda f: None) 3729 function.add_function_callback(lambda f: None) 3730 self.assertLen(function._function_callbacks, 2) 3731 function.clear_function_callbacks() 3732 self.assertEmpty(function._function_callbacks) # pylint:disable=protected-access 3733 3734 @test_util.run_in_graph_and_eager_modes 3735 def testConcreteFunctionWithNestedTensorInputs(self): 3736 3737 @def_function.function 3738 def f(x, y): 3739 return (x['a'] + x['b'], y[0] + y[1]) 3740 3741 a = constant_op.constant(1000) 3742 b = constant_op.constant(200) 3743 c = constant_op.constant(30) 3744 d = {'a': a, 'b': b} 3745 e = (c, 4) 3746 3747 # Test different argument signatures when constructing the concrete func. 3748 for cf in [ 3749 f.get_concrete_function(d, e), 3750 f.get_concrete_function(d, y=e), 3751 f.get_concrete_function(y=e, x=d), 3752 f.get_concrete_function(_spec_for_value(d), _spec_for_value(e)), 3753 f.get_concrete_function(_spec_for_value(d), y=_spec_for_value(e)), 3754 f.get_concrete_function(y=_spec_for_value(e), x=_spec_for_value(d)) 3755 ]: 3756 # Test different calling conventions when calling the concrete func. 
3757 for output in [ 3758 cf(d, e), # structured signature 3759 cf(d, y=e), # structured signature w/ kwarg 3760 cf(y=e, x=d), # structured signature w/ 2 kwargs 3761 cf(a, b, c), # flat signature 3762 ]: 3763 self.assertIsInstance(output, tuple) 3764 self.assertLen(output, 2) 3765 self.assertAllEqual(output[0], 1200) 3766 self.assertAllEqual(output[1], 34) 3767 3768 @test_util.run_in_graph_and_eager_modes 3769 def testConcreteFunctionWithNestedNonTensorInputs(self): 3770 3771 @def_function.function 3772 def f(x, y): 3773 return (x['a'] + x['b'], y[0] + y[1]) 3774 3775 a = {'a': constant_op.constant(1000), 'b': constant_op.constant(200)} 3776 b = (50, 3) 3777 3778 for cf in [ # argument y is bound to non-Tensor value (50, 3). 3779 f.get_concrete_function(a, b), 3780 f.get_concrete_function(a, y=b), 3781 f.get_concrete_function(x=a, y=b) 3782 ]: 3783 for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]: 3784 self.assertAllEqual(output[0] + output[1], 1253) 3785 3786 @test_util.run_in_graph_and_eager_modes 3787 def testConcreteFunctionWithNonTensorStringInputs(self): 3788 3789 @def_function.function 3790 def f(x, y): 3791 return string_ops.string_join([x, y]) 3792 3793 a = constant_op.constant('a') 3794 b = 'b' 3795 3796 cf = f.get_concrete_function(a, b) 3797 for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]: 3798 self.assertAllEqual(output, b'ab') 3799 3800 @test_util.run_in_graph_and_eager_modes 3801 def testConcreteFunctionWithBoundNestedNonTensorInputs(self): 3802 3803 @def_function.function 3804 def f(x, y): 3805 return (x['a'] + x['b'], y[0] + y[1]) 3806 3807 a = {'a': 3000, 'b': 200, 'c': 9000} 3808 b = (constant_op.constant(30), 4) 3809 3810 for cf in [ # argument x is bound to non-tensor value `a` 3811 f.get_concrete_function(a, b), 3812 f.get_concrete_function(a, y=b), 3813 f.get_concrete_function(x=a, y=b) 3814 ]: 3815 for output in [cf(a, b), cf(a, y=b), cf(y=b), cf(x=a, y=b)]: 3816 self.assertAllEqual(output[0] + output[1], 3234) 3817 3818 
@test_util.run_in_graph_and_eager_modes 3819 def testConcreteFunctionWithAllBoundNestedNonTensorInputs(self): 3820 3821 @def_function.function 3822 def f(x, y): 3823 return (x['a'] + x['b'], y[0] + y[1]) 3824 3825 a = {'a': 5000, 'b': 500} 3826 b = (50, 5) 3827 3828 cf = f.get_concrete_function(a, b) 3829 for output in [cf(), cf(a), cf(y=b)]: 3830 self.assertAllEqual(output[0] + output[1], 5555) 3831 3832 @test_util.run_in_graph_and_eager_modes 3833 def testConcreteFunctionMethodWithVarargs(self): 3834 float32_scalar = tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32) 3835 3836 class MyModel(module.Module): 3837 3838 @def_function.function(input_signature=[float32_scalar, float32_scalar]) 3839 def add(self, *arg): 3840 return math_ops.add(*arg) 3841 3842 m = MyModel() 3843 cf = m.add.get_concrete_function() 3844 cf(-12.0, 3.0) 3845 3846 @test_util.run_in_graph_and_eager_modes 3847 def testConcreteFunctionStructuredSignatureKeywordOrder(self): 3848 # Check that keyword-only arguments are sorted appropriately, so that they 3849 # feed the right tensor into each input. 
3850 @def_function.function 3851 def g(**kwargs): 3852 return string_ops.reduce_join( 3853 string_ops.reduce_join( 3854 ops.convert_to_tensor(sorted(kwargs.items())), 3855 axis=1, 3856 separator='='), 3857 axis=0, 3858 separator=', ') 3859 3860 s = constant_op.constant('s') 3861 g.get_concrete_function(q=s, a=s, p=s, r=s, v=s, m=s, l=s) 3862 self.assertAllEqual( 3863 g(m='a', r='b', v='c', q='d', l='e', a='f', p='g'), 3864 b'a=f, l=e, m=a, p=g, q=d, r=b, v=c') 3865 self.assertAllEqual( 3866 g(q='d', a='f', p='g', r='b', v='c', m='a', l='e'), 3867 b'a=f, l=e, m=a, p=g, q=d, r=b, v=c') 3868 self.assertAllEqual( 3869 g(a='f', l='e', m='a', p='g', q='d', r='b', v='c'), 3870 b'a=f, l=e, m=a, p=g, q=d, r=b, v=c') 3871 3872 # pylint: disable=g-long-lambda 3873 @parameterized.named_parameters([ 3874 dict( 3875 testcase_name='MissingArg', 3876 conc_args=lambda: (1, constant_op.constant(2)), 3877 call_args=lambda: (1,), 3878 error=r'func\(x, y\) missing required arguments: y'), 3879 dict( 3880 testcase_name='MissingVararg', 3881 conc_args=lambda: (1, 2, constant_op.constant(1.0)), 3882 call_args=lambda: (1, 2), 3883 error=r'func\(x, y, <arg3>\) missing required arguments: <arg3>'), 3884 dict( 3885 testcase_name='ExtraPositionalArg', 3886 conc_args=lambda: (1, 2), 3887 call_args=lambda: (1, 2, 3), 3888 error=r'func\(x, y\) takes 2 .* got 3'), 3889 dict( 3890 testcase_name='MissingKeywordOnlyArg', 3891 conc_args=lambda: (1, 2), 3892 conc_kwargs=lambda: {'c': constant_op.constant(1.0)}, 3893 call_args=lambda: (1, 2), 3894 error=r'func\(x, y, \*, c\) missing required arguments: c'), 3895 dict( 3896 testcase_name='ExtraKeywordArg', 3897 conc_args=lambda: (1, 2), 3898 call_args=lambda: (1, 2), 3899 call_kwargs=lambda: {'c': constant_op.constant(1.0)}, 3900 error=r'func\(x, y\) got unexpected keyword arguments: c'), 3901 dict( 3902 testcase_name='ExpectedRaggedGotNest', 3903 conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),), 3904 call_args=lambda: ({ 3905 'a': 
constant_op.constant([1, 2, 3]) 3906 },), 3907 error=r'func\(x, y\): argument x had incorrect type\n' 3908 r' expected: RaggedTensor\n' 3909 r" got: {'a': (Eager)?Tensor}"), 3910 dict( 3911 testcase_name='WrongRaggedRank', 3912 conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),), 3913 call_args=lambda: (ragged_factory_ops.constant([[[1]]]),), 3914 error=r'func\(x, y\): argument x had incorrect type\n'), 3915 dict( 3916 testcase_name='WrongRaggedDType', 3917 conc_args=lambda: (ragged_factory_ops.constant([[1]]),), 3918 call_args=lambda: (ragged_factory_ops.constant([[1.0]]),), 3919 error=r'func\(x, y\): argument x had incorrect type\n'), 3920 dict( 3921 testcase_name='ExpectedDictGotTensor', 3922 conc_args=lambda: ({ 3923 'a': constant_op.constant(1), 3924 'b': constant_op.constant(1) 3925 },), 3926 call_args=lambda: (constant_op.constant(1),), 3927 error=r'func\(x, y\): argument x had incorrect type\n'), 3928 dict( 3929 testcase_name='ExpectedTupleGotTensor', 3930 conc_args=lambda: 3931 ((constant_op.constant(1), constant_op.constant(2)),), 3932 call_args=lambda: (constant_op.constant(1),), 3933 error=r'func\(x, y\): argument x had incorrect type\n'), 3934 dict( 3935 testcase_name='WrongDType', 3936 conc_args=lambda: (constant_op.constant(1),), 3937 call_args=lambda: (constant_op.constant(1.0),), 3938 exception=(ValueError, errors.InvalidArgumentError, 3939 # on xla_gpu, we get InternalError instead. 
3940 errors.InternalError)), 3941 dict( 3942 testcase_name='ExpectedTensorGotInt', 3943 conc_args=lambda: (constant_op.constant(1),), 3944 call_args=lambda: (5,), 3945 error=r'func\(x, y\) expected a Tensor in x, but got int value 5'), 3946 dict( 3947 testcase_name='ExpectedIntGotDifferentInt', 3948 conc_args=lambda: (5,), 3949 call_args=lambda: (8,), 3950 error=r'ConcreteFunction func\(x, y\) was constructed with int ' 3951 r'value 5 in x, but was called with int value 8'), 3952 dict( 3953 testcase_name='ExpectedIntGotTensor', 3954 conc_args=lambda: (5,), 3955 call_args=lambda: (constant_op.constant(6),), 3956 error=r'ConcreteFunction func\(x, y\) was constructed with int ' 3957 'value 5 in x, but was called with (Eager)?Tensor value .*'), 3958 dict( 3959 testcase_name='TwoValuesForArgument', 3960 conc_args=lambda: (1, 2), 3961 call_args=lambda: (1, 2), 3962 call_kwargs=lambda: {'x': 3}, 3963 error=r"func\(x, y\) got two values for 'x'"), 3964 ]) 3965 # pylint: enable=g-long-lambda 3966 @test_util.run_in_graph_and_eager_modes 3967 def testConcreteFunctionStructuredSignatureError(self, 3968 conc_args=(), 3969 conc_kwargs=None, 3970 call_args=(), 3971 call_kwargs=None, 3972 error='.*', 3973 exception=TypeError): 3974 """Tests for errors in the structrued signature. 3975 3976 Args: 3977 conc_args: Positional arguments used for get_concrete_function. 3978 conc_kwargs: Keyword arguments used for get_concrete_function. 3979 call_args: Positional arguments used to call the function. 3980 call_kwargs: Keyword arguments used to call the function. 3981 error: Expected exception message. 3982 exception: Expected exception type. 
3983 """ 3984 conc_args = conc_args() if callable(conc_args) else conc_args 3985 conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {} 3986 call_args = call_args() if callable(call_args) else call_args 3987 call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {} 3988 self.assertIsInstance(conc_args, tuple) 3989 self.assertIsInstance(call_args, tuple) 3990 self.assertIsInstance(conc_kwargs, dict) 3991 self.assertIsInstance(call_kwargs, dict) 3992 3993 @def_function.function 3994 def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg 3995 del y, varargs, kwargs 3996 return x 3997 3998 conc = func.get_concrete_function(*conc_args, **conc_kwargs) 3999 with self.assertRaisesRegex(exception, error): 4000 self.evaluate(conc(*call_args, **call_kwargs)) 4001 4002 # pylint: disable=g-long-lambda 4003 @parameterized.named_parameters([ 4004 dict( 4005 testcase_name='MissingArg', 4006 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 4007 call_args=lambda: (constant_op.constant(1),), 4008 error=r'func\(x, y\) missing required arguments: y'), 4009 dict( 4010 testcase_name='TwoValuesForArg', 4011 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 4012 call_args=lambda: (constant_op.constant(1),), 4013 call_kwargs=lambda: { 4014 'x': constant_op.constant(1), 4015 'y': constant_op.constant(1) 4016 }, 4017 error=r"func\(x, y\) got two values for 'x'"), 4018 dict( 4019 testcase_name='ExtraPositionalArg', 4020 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 4021 call_args=lambda: (constant_op.constant(1), constant_op.constant(2), 4022 constant_op.constant(3)), 4023 error=r'func\(x, y\) takes 2 .* got 3'), 4024 dict( 4025 testcase_name='UnexpectedKeywordArg', 4026 conc_args=lambda: (constant_op.constant(1),), 4027 call_args=lambda: (constant_op.constant(1),), 4028 call_kwargs=lambda: {'c': constant_op.constant(1)}, 4029 error=r'func\(x\) got unexpected 
keyword arguments: c'), 4030 dict( 4031 testcase_name='MissingVararg', 4032 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2), 4033 constant_op.constant(3)), 4034 call_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 4035 error=r'func\(x, y, varargs_0\) missing required ' 4036 r'arguments: varargs_0'), 4037 dict( 4038 testcase_name='MissingKeywordArg', 4039 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 4040 conc_kwargs=lambda: {'c': constant_op.constant(1)}, 4041 call_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 4042 error=r'func\(x, y, c\) missing required arguments: c'), 4043 dict( 4044 testcase_name='ExpectedTensorGotInt', 4045 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 4046 call_args=lambda: (5, constant_op.constant(2)), 4047 error=r'func\(x, y\): expected argument #0\(zero-based\) to be ' 4048 r'a Tensor; got int \(5\)'), 4049 dict( 4050 testcase_name='WrongDType', 4051 conc_args=lambda: (constant_op.constant(1),), 4052 call_args=lambda: (constant_op.constant(1.0),), 4053 exception=(ValueError, errors.InvalidArgumentError, 4054 # on xla_gpu, we get InternalError instead. 4055 errors.InternalError)), 4056 dict( 4057 testcase_name='MissingKeywordArgNestPiece', 4058 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 4059 conc_kwargs=lambda: {'c': ragged_factory_ops.constant([[1]])}, 4060 call_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 4061 call_kwargs=lambda: {'c': constant_op.constant(1)}, 4062 error=r'func\(x, y, c, c_1\) missing required arguments: c_1'), 4063 ]) 4064 # pylint: enable=g-long-lambda 4065 @test_util.run_in_graph_and_eager_modes 4066 def testConcreteFunctionFlatSignatureError(self, 4067 conc_args=(), 4068 conc_kwargs=None, 4069 call_args=(), 4070 call_kwargs=None, 4071 error='.*', 4072 exception=TypeError): 4073 """Tests for errors in the flat signature. 
4074 4075 Args: 4076 conc_args: Positional arguments used for get_concrete_function. 4077 conc_kwargs: Keyword arguments used for get_concrete_function. 4078 call_args: Positional arguments used to call the function. 4079 call_kwargs: Keyword arguments used to call the function. 4080 error: Expected exception message. 4081 exception: Expected exception type. 4082 """ 4083 conc_args = conc_args() if callable(conc_args) else conc_args 4084 conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {} 4085 call_args = call_args() if callable(call_args) else call_args 4086 call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {} 4087 self.assertIsInstance(conc_args, tuple) 4088 self.assertIsInstance(call_args, tuple) 4089 self.assertIsInstance(conc_kwargs, dict) 4090 self.assertIsInstance(call_kwargs, dict) 4091 4092 @def_function.function 4093 def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg 4094 del y, varargs, kwargs 4095 return x 4096 4097 conc = func.get_concrete_function(*conc_args, **conc_kwargs) 4098 4099 # Remove _function_spec, to disable the structured signature. 4100 conc._set_function_spec(None) # pylint: disable=protected-access 4101 4102 with self.assertRaisesRegex(exception, error): 4103 self.evaluate(conc(*call_args, **call_kwargs)) 4104 4105 @test_util.run_in_graph_and_eager_modes 4106 def testConcreteFunctionAmbiguousSignature(self): 4107 # When both the flat & structured signatures are applicable, but they 4108 # give different results, we use the structured signature. Note: we expect 4109 # this to be extremely rare. 
4110 @def_function.function 4111 def f(x, y): 4112 return x * 10 + y 4113 4114 conc = f.get_concrete_function( 4115 x=tensor_spec.TensorSpec(None, dtypes.int32, name='y'), 4116 y=tensor_spec.TensorSpec(None, dtypes.int32, name='x')) 4117 4118 result = conc(x=constant_op.constant(5), y=constant_op.constant(6)) 4119 self.assertAllEqual(result, 56) 4120 4121 def testPrettyPrintedSignature(self): 4122 4123 @def_function.function 4124 def func(x, kangaroo=None, octopus=7): 4125 del octopus, kangaroo 4126 return x 4127 4128 scalar = constant_op.constant(5) 4129 vector = constant_op.constant([10, 10, 20]) 4130 ragged = ragged_factory_ops.constant([[10, 20], [40]]) 4131 4132 c1 = func.get_concrete_function(scalar, vector) 4133 c1_summary = r'func\(x, kangaroo, octopus=7\)' 4134 c1_details = (r' Args:\n' 4135 r' x: int32 Tensor, shape=\(\)\n' 4136 r' kangaroo: int32 Tensor, shape=\(3,\)\n' 4137 r' Returns:\n' 4138 r' int32 Tensor, shape=\(\)') 4139 self.assertRegex(c1.pretty_printed_signature(verbose=False), c1_summary) 4140 self.assertRegex( 4141 c1.pretty_printed_signature(verbose=True), 4142 c1_summary + '\n' + c1_details) 4143 self.assertRegex( 4144 repr(c1), r'<ConcreteFunction func\(x, kangaroo, octopus=7\) at .*>') 4145 self.assertRegex( 4146 str(c1), 'ConcreteFunction {}\n{}'.format(c1_summary, c1_details)) 4147 4148 c2 = func.get_concrete_function(scalar, ragged, 3) 4149 c2_summary = r'func\(x, kangaroo, octopus=3\)' 4150 c2_details = (r' Args:\n' 4151 r' x: int32 Tensor, shape=\(\)\n' 4152 r' kangaroo: RaggedTensorSpec\(.*\)\n' 4153 r' Returns:\n' 4154 r' int32 Tensor, shape=\(\)') 4155 self.assertRegex(c2.pretty_printed_signature(), 4156 c2_summary + '\n' + c2_details) 4157 4158 c3 = func.get_concrete_function({'a': scalar, 'b': [ragged, ragged]}) 4159 c3_summary = r'func\(x, kangaroo=None, octopus=7\)' 4160 c3_details = (r' Args:\n' 4161 r" x: {'a': <1>, 'b': \[<2>, <3>\]}\n" 4162 r' <1>: int32 Tensor, shape=\(\)\n' 4163 r' <2>: RaggedTensorSpec\(.*\)\n' 4164 r' 
<3>: RaggedTensorSpec\(.*\)\n' 4165 r' Returns:\n' 4166 r" {'a': <1>, 'b': \[<2>, <3>\]}\n" 4167 r' <1>: int32 Tensor, shape=\(\)\n' 4168 r' <2>: RaggedTensorSpec\(.*\)\n' 4169 r' <3>: RaggedTensorSpec\(.*\)') 4170 4171 # python 3.5 does not gurantee deterministic iteration of dict contents 4172 # which can lead mismatch on pretty_printed_signature output for "Args" 4173 if sys.version_info >= (3, 6): 4174 self.assertRegex(c3.pretty_printed_signature(), 4175 c3_summary + '\n' + c3_details) 4176 4177 # pylint: disable=keyword-arg-before-vararg 4178 @def_function.function 4179 def func2(x, y=3, *args, **kwargs): 4180 return (x, y, args, kwargs) 4181 4182 c4 = func2.get_concrete_function(scalar, 4, 5, a=scalar) 4183 c4_summary = 'func2(x, y=4, <arg3>=5, *, a)' 4184 self.assertEqual(c4.pretty_printed_signature(verbose=False), c4_summary) 4185 4186 c5 = func2.get_concrete_function(8, vector) 4187 c5_summary = 'func2(x=8, y)' 4188 self.assertEqual(c5.pretty_printed_signature(verbose=False), c5_summary) 4189 4190 def testPrettyPrintedExplicitSignatureWithKeywordArg(self): # b/159639913 4191 4192 @def_function.function(input_signature=[tensor_spec.TensorSpec(None)]) 4193 def fn(a, b=1): 4194 return a + b 4195 4196 concrete_fn = fn.get_concrete_function() 4197 self.assertEqual(concrete_fn.pretty_printed_signature(False), 'fn(a)') 4198 self.assertEqual( 4199 concrete_fn.pretty_printed_signature(True), 'fn(a)\n' 4200 ' Args:\n' 4201 ' a: float32 Tensor, shape=<unknown>\n' 4202 ' Returns:\n' 4203 ' float32 Tensor, shape=<unknown>') 4204 4205 def testPrettyPrintedSignatureLoadedNamedTuple(self): 4206 Point = collections.namedtuple('Point', ['x', 'y']) 4207 4208 @def_function.function 4209 def fn(b, a): # pylint: disable=unused-argument 4210 return 1. 
4211 4212 b = Point( 4213 x=constant_op.constant(1., dtype=dtypes.float32), 4214 y=constant_op.constant(1., dtype=dtypes.float32)) 4215 a = Point( 4216 x=constant_op.constant(1, dtype=dtypes.int32), 4217 y=constant_op.constant(1, dtype=dtypes.int32)) 4218 4219 mod = module.Module() 4220 f = fn.get_concrete_function(b, a) 4221 save(mod, '/tmp/f', signatures=f) 4222 loaded = load('/tmp/f') 4223 4224 printed = loaded.signatures['serving_default'].pretty_printed_signature() 4225 self.assertIn('a: int32 Tensor, shape=()', printed) 4226 self.assertIn('a_1: int32 Tensor, shape=()', printed) 4227 self.assertIn('b: float32 Tensor, shape=()', printed) 4228 self.assertIn('b_1: float32 Tensor, shape=()', printed) 4229 4230 @test_util.run_in_graph_and_eager_modes 4231 def testIndexedSlicesAsGradientsForConcreteFunctions(self): 4232 4233 @def_function.function 4234 def summing_rnn(inputs): 4235 return math_ops.reduce_sum(inputs, axis=1) 4236 4237 @def_function.function 4238 def gradients(inputs): 4239 with backprop.GradientTape() as tape: 4240 tape.watch(inputs) 4241 hidden = summing_rnn(inputs) 4242 hidden = array_ops.gather(hidden, constant_op.constant([0])) 4243 loss = math_ops.reduce_mean(hidden) 4244 return tape.gradient(loss, inputs) 4245 4246 gradients(constant_op.constant([[[1.0], [2.0]]])) # No error is raised 4247 4248 def testFollowTypeHintsTraceBasic(self): 4249 trace_count = [0] 4250 4251 def func(x: ops.Tensor): 4252 trace_count[0] += 1 4253 return x 4254 4255 enabled = def_function.function(func, experimental_follow_type_hints=True) 4256 disabled = def_function.function(func, experimental_follow_type_hints=False) 4257 4258 enabled(1) # Initial call gets traced 4259 enabled(2) 4260 enabled(3) 4261 self.assertEqual(trace_count[0], 1) 4262 4263 trace_count = [0] 4264 disabled(1) 4265 disabled(2) # Retrace 4266 disabled(3) # Retrace 4267 self.assertEqual(trace_count[0], 3) 4268 4269 def testFollowTypeHintsTraceWithArgs(self): 4270 trace_count = [0] 4271 4272 def 
func(*args: ops.Tensor): 4273 trace_count[0] += 1 4274 return args 4275 4276 enabled = def_function.function(func, experimental_follow_type_hints=True) 4277 disabled = def_function.function(func, experimental_follow_type_hints=False) 4278 4279 args = ( 4280 'abc', 4281 'def', 4282 ) * 20 4283 args2 = ( 4284 'def', 4285 'abc', 4286 ) * 20 4287 4288 enabled(args) 4289 enabled(args2) 4290 self.assertEqual(trace_count[0], 1) 4291 4292 trace_count = [0] 4293 disabled(args) 4294 disabled(args2) # Retrace 4295 self.assertEqual(trace_count[0], 2) 4296 4297 def testFollowTypeHintsTraceWithKwargs(self): 4298 trace_count = [0] 4299 4300 def func(t: ops.Tensor, **kwargs: ops.Tensor): 4301 del kwargs 4302 trace_count[0] += 1 4303 return t 4304 4305 enabled = def_function.function(func, experimental_follow_type_hints=True) 4306 disabled = def_function.function(func, experimental_follow_type_hints=False) 4307 4308 enabled(1, x=1, y=1.0, z='one') 4309 enabled(2, x=2, y=2.0, z='two') 4310 self.assertEqual(trace_count[0], 1) 4311 4312 trace_count = [0] 4313 disabled(1, x=1, y=1.0, z='one') 4314 disabled(2, x=2, y=2.0, z='two') # Retrace 4315 self.assertEqual(trace_count[0], 2) 4316 4317 def testFollowTypeHintsTraceWithMultipleInputTypes(self): 4318 trace_count = [0] 4319 4320 def func(t: ops.Tensor, *args: ops.Tensor, **kwargs: ops.Tensor): 4321 del args, kwargs 4322 trace_count[0] += 1 4323 return t 4324 4325 enabled = def_function.function(func, experimental_follow_type_hints=True) 4326 disabled = def_function.function(func, experimental_follow_type_hints=False) 4327 4328 enabled(1, constant_op.constant(1), 'str', x=4.0) 4329 enabled(2, constant_op.constant(2), 'str2', x=5.0) 4330 self.assertEqual(trace_count[0], 1) 4331 4332 trace_count = [0] 4333 disabled(1, constant_op.constant(1), 'str', x=4.0) 4334 disabled(2, constant_op.constant(2), 'str2', x=5.0) # Retrace 4335 self.assertEqual(trace_count[0], 2) 4336 4337 def testFollowTypeHintsTraceWithOnlyArgNamed(self): 4338 
trace_count = [0] 4339 4340 def func(t: ops.Tensor, i: int = 1, **kwargs): # pylint: disable=bad-whitespace 4341 del i, kwargs 4342 trace_count[0] += 1 4343 return t 4344 4345 enabled = def_function.function(func, experimental_follow_type_hints=True) 4346 4347 enabled(1, 3, x=4.0, y='str') 4348 enabled(2, 4, x=4.0, y='str') # Retrace 4349 self.assertEqual(trace_count[0], 2) 4350 4351 def testFollowTypeHintsTraceWithNotAllNamed(self): 4352 trace_count = [0] 4353 4354 def func(x, y: ops.Tensor, z: int): 4355 del y, z 4356 trace_count[0] += 1 4357 return x 4358 4359 enabled = def_function.function(func, experimental_follow_type_hints=True) 4360 4361 enabled(1, 2, 3) 4362 enabled(1, 20, 3) # No retrace - change in ops.Tensor typed arg 4363 enabled(2, 2, 3) # Retrace - change in untyped arg 4364 enabled(2, 2, 4) # Retrace - change in typed arg 4365 self.assertEqual(trace_count[0], 3) 4366 4367 def testFollowTypeHintsTraceWithOnlyArgsNamed(self): 4368 trace_count = [0] 4369 4370 def func(x, y, *args: ops.Tensor): 4371 del y, args 4372 trace_count[0] += 1 4373 return x 4374 4375 enabled = def_function.function(func, experimental_follow_type_hints=True) 4376 4377 enabled(1, 20, 3, 4, 5, 6) 4378 enabled(1, 20, 3, 4, 5, 60) # No retrace - change in *args 4379 enabled(1, 30, 7, 8, 9, 10) # Retrace - change in args 4380 self.assertEqual(trace_count[0], 2) 4381 4382 def testFollowTypeHintsTraceWithOnlyKwargsNamed(self): 4383 trace_count = [0] 4384 4385 def func(x, y, *args, **kwargs: ops.Tensor): 4386 del y, args, kwargs 4387 trace_count[0] += 1 4388 return x 4389 4390 enabled = def_function.function(func, experimental_follow_type_hints=True) 4391 4392 enabled(1, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0) 4393 enabled( 4394 1, 2, 3, 4, 5, 6, a=1.5, b=2.5, 4395 c=3.5) # No retrace - change in **kwargs 4396 enabled(100, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0) # Retrace - change in args 4397 enabled( 4398 1, 2, 3, 4, 5, 100, a=1.0, b=2.0, c=3.0) # Retrace - change in *args 4399 
self.assertEqual(trace_count[0], 3) 4400 4401 def testFollowTypeHintsTraceWithArgsEquals(self): 4402 trace_count = [0] 4403 4404 def func( 4405 x: ops.Tensor = 0, # pylint:disable=bad-whitespace 4406 y: int = 1, # pylint:disable=bad-whitespace 4407 **kwargs: ops.Tensor): 4408 del y, kwargs 4409 trace_count[0] += 1 4410 return x 4411 4412 enabled = def_function.function(func, experimental_follow_type_hints=True) 4413 4414 enabled(x=1, y=2, z=3) 4415 enabled(x=1, y=3, z=3) # Retrace - change in args 4416 enabled(x=2, y=2, z=4) # No retrace - change in args and **kwargs 4417 enabled(x=2, y=2, z=4, u=5) # Retrace - change in **kwargs 4418 self.assertEqual(trace_count[0], 3) 4419 4420 def testFollowTypeHintsWithTensorSpec(self): 4421 def func(x: ops.Tensor, y): 4422 return x + y 4423 v = def_function.function(experimental_follow_type_hints=True)(func) 4424 v = v.get_concrete_function( 4425 tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32), 3) 4426 x = v(constant_op.constant(1.), 3) 4427 self.assertEqual(x.numpy(), 4.) 4428 4429 def testFollowTypeHintsTraceWithKwArgsAndNoVarKws(self): 4430 trace_count = [0] 4431 4432 def func(a: int, b: ops.Tensor, 4433 x: ops.Tensor = 0, y: int = 1): 4434 del a, b, y 4435 trace_count[0] += 1 4436 return x 4437 4438 enabled = def_function.function(func, experimental_follow_type_hints=True) 4439 4440 enabled(0, 0, x=1, y=2) 4441 enabled(0, 0, x=2, y=2,) # No retrace, since only tensor changed 4442 self.assertEqual(trace_count[0], 1) 4443 4444 # Pass args as keyword args. 
4445 enabled(a=0, b=0, x=2, y=2,) # No retrace, args are the same 4446 self.assertEqual(trace_count[0], 1) 4447 4448 enabled(a=1, b=0, x=2, y=2,) # Retrace, since non-tensor arg changed 4449 self.assertEqual(trace_count[0], 2) 4450 4451 enabled(a=1, b=2, x=2, y=2) # No retrace, since only tensor changed 4452 self.assertEqual(trace_count[0], 2) 4453 4454 trace_count[0] = 0 4455 disabled = def_function.function(func, experimental_follow_type_hints=False) 4456 disabled(0, 0, x=1, y=2) 4457 disabled(0, 0, x=2, y=2,) # Retrace 4458 self.assertEqual(trace_count[0], 2) 4459 4460 def testFollowTypeHintsTraceWithArgsEqualsTypedKwargs(self): 4461 trace_count = [0] 4462 4463 def func(x, y, **kwargs: ops.Tensor): 4464 del y, kwargs 4465 trace_count[0] += 1 4466 return x 4467 4468 enabled = def_function.function(func, experimental_follow_type_hints=True) 4469 4470 enabled(x=1, y=2, z=3) 4471 enabled(x=1, y=3, z=3) # Retrace 4472 enabled(x=1, y=2, z=4) # No retrace 4473 enabled(x=2, y=2, z=4) # Retrace 4474 enabled(x=2, y=2, z=4, u=5) # Retrace 4475 self.assertEqual(trace_count[0], 4) 4476 4477 def testFollowTypeHintsTraceWithArgsEqualsTypedArgs(self): 4478 trace_count = [0] 4479 4480 def func(x: ops.Tensor, y: int, **kwargs): 4481 del y, kwargs 4482 trace_count[0] += 1 4483 return x 4484 4485 enabled = def_function.function(func, experimental_follow_type_hints=True) 4486 4487 enabled(x=1, y=2, z=3) 4488 enabled(x=1, y=3, z=3) # Retrace 4489 enabled(x=1, y=2, z=4) # Retrace 4490 enabled(x=2, y=2, z=3) # No retrace 4491 enabled(x=2, y=2, z=4, u=5) # Retrace 4492 self.assertEqual(trace_count[0], 4) 4493 4494 def testFollowTypeHintsTraceWithKwOnlyArgsBasic(self): 4495 trace_count = [0] 4496 4497 def func(*, a: ops.Tensor = None, b=1): # pylint: disable=bad-whitespace 4498 del b 4499 trace_count[0] += 1 4500 return a 4501 4502 enabled = def_function.function(func, experimental_follow_type_hints=True) 4503 4504 enabled(a=1, b=2) 4505 enabled(a=2, b=2) # No retrace 4506 enabled(a=1, 
b=1) # Retrace 4507 self.assertEqual(trace_count[0], 2) 4508 4509 def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArg(self): 4510 trace_count = [0] 4511 4512 def func(arg: ops.Tensor, *args, kwonly, **kwargs): 4513 del args, kwonly, kwargs 4514 trace_count[0] += 1 4515 return arg 4516 4517 enabled = def_function.function(func, experimental_follow_type_hints=True) 4518 4519 enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) 4520 enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # No retrace 4521 enabled(1000, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # No retrace 4522 enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4523 enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace 4524 enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace 4525 self.assertEqual(trace_count[0], 4) 4526 4527 def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArgs(self): 4528 trace_count = [0] 4529 4530 def func(arg, *args: ops.Tensor, kwonly, **kwargs): 4531 del args, kwonly, kwargs 4532 trace_count[0] += 1 4533 return arg 4534 4535 enabled = def_function.function(func, experimental_follow_type_hints=True) 4536 4537 enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) 4538 enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4539 enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # No retrace 4540 enabled(1, 200, 300, 400, kwonly=5, kwarg1=6, kwarg2=7) # No retrace 4541 enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace 4542 enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace 4543 self.assertEqual(trace_count[0], 4) 4544 4545 def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwOnlyArg(self): 4546 trace_count = [0] 4547 4548 def func(arg, *args, kwonly: ops.Tensor, **kwargs): 4549 del args, kwonly, kwargs 4550 trace_count[0] += 1 4551 return arg 4552 4553 enabled = def_function.function(func, experimental_follow_type_hints=True) 4554 4555 enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) 4556 enabled(100, 2, 
3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4557 enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4558 enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # No retrace 4559 enabled(1, 2, 3, 4, kwonly=500, kwarg1=6, kwarg2=7) # No retrace 4560 enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace 4561 self.assertEqual(trace_count[0], 4) 4562 4563 def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwargs(self): 4564 trace_count = [0] 4565 4566 def func(arg, *args, kwonly, **kwargs: ops.Tensor): 4567 del args, kwonly, kwargs 4568 trace_count[0] += 1 4569 return arg 4570 4571 enabled = def_function.function(func, experimental_follow_type_hints=True) 4572 4573 enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) 4574 enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4575 enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4576 enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace 4577 enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # No retrace 4578 enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700) # No retrace 4579 self.assertEqual(trace_count[0], 4) 4580 4581 def testWithExtraWrapper(self): 4582 4583 class Foo(module.Module): 4584 4585 def __init__(self): 4586 super().__init__() 4587 self.var = None 4588 4589 @def_function.function 4590 @dummy_tf_decorator 4591 def add(self, x, y, z=1): 4592 if self.var is None: 4593 return x + y + z 4594 4595 foo = Foo() 4596 self.assertEqual(foo.add(2, 3).numpy(), 6) 4597 4598 @parameterized.parameters([(def_function.function, dummy_tf_decorator), 4599 (dummy_tf_decorator, def_function.function), 4600 (def_function.function, def_function.function)]) 4601 def testWithExtraWrapperRedundantArgs(self, decorator1, decorator2): 4602 4603 class Foo(module.Module): 4604 4605 def __init__(self): 4606 super().__init__() 4607 self.var = None 4608 4609 @decorator1 4610 @decorator2 4611 def add1(self, x, y): 4612 if self.var is None: 4613 return x + y 4614 4615 foo = Foo() 4616 
    with self.assertRaisesRegex(TypeError, 'got two values'):
      foo.add1(2, x=3)  # pylint: disable=redundant-keyword-arg,no-value-for-parameter

  def testWithExtraWrapperMissingArgs(self):
    """Missing-argument errors survive extra decorators on tf.function methods.

    `add1` and `add2` wrap `dummy_tf_decorator` in `def_function.function`;
    `add3` stacks `def_function.function` twice. Calling any of them without
    one of its arguments must raise a TypeError that names that argument.
    """

    class Foo(module.Module):

      def __init__(self):
        super().__init__()
        self.var = None

      @def_function.function
      @dummy_tf_decorator
      def add1(self, x, y):
        if self.var is None:
          return x + y

      @def_function.function
      @dummy_tf_decorator
      def add2(self, x, y):
        if self.var is None:
          return x + y

      @def_function.function
      @def_function.function
      def add3(self, x, y):
        if self.var is None:
          return x + y

    foo = Foo()
    with self.assertRaisesRegex(
        TypeError, 'missing 1 required positional argument: \'y\''):
      foo.add1(2)  # pylint: disable=no-value-for-parameter

    with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
      foo.add1(y=2)  # pylint: disable=no-value-for-parameter

    with self.assertRaisesRegex(
        TypeError, 'missing 1 required positional argument: \'y\''):
      foo.add2(2)  # pylint: disable=no-value-for-parameter

    with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
      foo.add2(y=2)  # pylint: disable=no-value-for-parameter

    with self.assertRaisesRegex(
        TypeError, 'missing 1 required positional argument: \'y\''):
      foo.add3(2)  # pylint: disable=no-value-for-parameter

    with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
      foo.add3(y=2)  # pylint: disable=no-value-for-parameter

  def testMissingArgsTfFunctionedMethod(self):
    """A tf.function over a bound method reports the missing argument by name.

    Covers both a plain method wrapped after binding and a method that is
    already decorated with `def_function.function`.
    """

    class A:

      def func(self, position_arg1, position_arg2):
        return position_arg1, position_arg2

      @def_function.function
      def decorated_method(self, position_arg1, position_arg2):
        return position_arg1, position_arg2

    a_instance = A()
    tf_method_pos = def_function.function(a_instance.func)
    with self.assertRaisesRegex(
        TypeError, '.* missing 1 required argument: position_arg1'):
      tf_method_pos(position_arg2='foo')

    # tf.function-decorated instance methods need to be tested because of
    # the __get__ method implementation.
    tf_func_decorated_method = def_function.function(
        a_instance.decorated_method)
    tf_func_decorated_method(position_arg1='foo', position_arg2='bar')
    with self.assertRaisesRegex(
        TypeError, '.* missing 1 required argument: position_arg1'):
      tf_func_decorated_method(position_arg2='bar')

  def testMissingArgsTfFunctionedObject(self):
    """A tf.function over a callable object reports the missing argument."""

    class A:

      def __call__(self, position_arg1, position_arg2):
        return position_arg1, position_arg2

    a_instance = A()

    # A tf.function-decorated callable object needs to be tested because of
    # the special inspect results.
    tf_func_obj = def_function.function(a_instance)
    tf_func_obj(position_arg1=1, position_arg2=2)
    with self.assertRaisesRegex(
        TypeError, '.* missing 1 required argument: position_arg1'):
      tf_func_obj(position_arg2='bar')

  def testMissingArgsTfFunctionedFunctions(self):
    """A tf.function over a plain function reports missing positional args.

    Covers one missing argument, a function with a defaulted argument, and
    two missing arguments reported together in one message.
    """

    def func_pos(position_arg1, position_arg2):
      return position_arg1, position_arg2

    def func_with_default(position_arg, named_arg=None):
      return position_arg, named_arg

    def func_pos_3args(position_arg1, position_arg2, position_arg3):
      return position_arg1, position_arg2, position_arg3

    tf_func_pos = def_function.function(func_pos)
    with self.assertRaisesRegex(
        TypeError, '.* missing 1 required argument: position_arg1'):
      tf_func_pos(position_arg2='foo')

    tf_func_with_default = def_function.function(func_with_default)
    tf_func_with_default(position_arg='bar')
    with self.assertRaisesRegex(TypeError,
                                '.* missing 1 required argument: position_arg'):
      tf_func_with_default(named_arg='foo')

    tf_func_pos_3args = def_function.function(func_pos_3args)
    with self.assertRaisesRegex(
        TypeError,
        '.* missing required arguments: position_arg1, position_arg3'):
      tf_func_pos_3args(position_arg2='foo')

  def testShapeInferencePropagateConstNestedStack(self):
    """A shape built with `stack` in a nested tf.function is statically known.

    `g` asserts at trace time that the output of `f` has the static shape
    [3, 5], taken from `x`'s static leading dimension and the constant `s`.
    """

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
    ])
    def f(x, s):
      old_shape = array_ops.shape(x)
      new_shape = array_ops.stack([old_shape[0], s], axis=0)
      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
      return y

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
    ])
    def g(x):
      y = f(x, s=5)
      # Runs at trace time: checks the statically inferred shape.
      assert y.shape.as_list() == [3, 5], y.shape.as_list()
      return y

    self.assertAllEqual(
        g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))

  def testShapeInferencePropagateConstNestedUnstackStack(self):
    """Like the `stack` case, but the leading dim comes from `unstack`."""

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
    ])
    def f(x, s):
      s0, _ = array_ops.unstack(array_ops.shape(x), axis=0)
      new_shape = array_ops.stack([s0, s], axis=0)
      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
      return y

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
    ])
    def g(x):
      y = f(x, s=5)
      assert y.shape.as_list() == [3, 5], y.shape.as_list()
      return y

    self.assertAllEqual(
        g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))

  def testShapeInferencePropagateConstNestedConcat(self):
    """A shape built with `concat` from constant args is statically known."""

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
    ])
    def f(d1, d2, d3):
      new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
      return y

    @def_function.function()
    def g():
      y = f(1, 2, 3)
      assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
      return y

    self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))

  def testShapeInferencePropagateConstDoubleNested(self):
    """Same as the `concat` case, with `f` wrapped in another tf.function."""

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
    ])
    def f(d1, d2, d3):
      new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
      return y

    @def_function.function()
    def g():
      y = def_function.function(f)(1, 2, 3)
      assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
      return y

    self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))

  @test_util.run_v2_only
  def testControlDependencyAfterInline(self):
    """Stateful calls to nested functions keep their order after inlining.

    `f` expects `assign()` to observe 1.0 and the following `assign_add()` to
    observe 2.0.
    """
    v = variables.Variable(0.)

    @def_function.function
    def assign():
      return v.assign(1.)

    @def_function.function
    def assign_add():
      return v.assign_add(1.)

    @def_function.function
    def f():
      check_ops.assert_equal_v2(assign(), 1.)
      check_ops.assert_equal_v2(assign_add(), 2.)

    # We don't have a way to inspect the inlined graph in Python, so we run it
    # multiple times to have more confidence the dependency is correct.
    for _ in range(30):
      f()

  @test_util.run_v2_only
  def testReadInFuncWriteOutside(self):
    """A read in a nested function is not reordered after a later write.

    `add_one` reads `v` (1.0) before the caller adds 2.0, so the returned value
    must be 2.0.
    """
    # Run many times since we are testing for a potential race condition.
    for _ in range(30):
      # pylint: disable=cell-var-from-loop
      v = variables.Variable(1.)

      @def_function.function
      def add_one():
        return v + 1.

      @def_function.function
      def get_v_plus_one():
        v_plus_one = add_one()
        v.assign_add(2.0)
        return v_plus_one

      self.assertAllEqual(get_v_plus_one(), 2.0)

  def testOpExpandErrorMessage(self):
    """A run-time InvalidArgumentError mentions the failing call's source.

    The expected message must match `Graph execution error.*func=lambda`.
    """
    @def_function.function
    def test_fn():
      if array_ops.constant(False):
        return array_ops.constant(1)
      else:
        # NOTE(review): presumably fails because the lambda yields float32
        # while Tout declares int32 -- confirm against the py_func op.
        return script_ops.eager_py_func(
            func=lambda: array_ops.constant([2.]), inp=(), Tout=dtypes.int32)

    error_pattern = re.compile(r'Graph execution error.*func=lambda', re.DOTALL)
    with self.assertRaisesRegex(errors.InvalidArgumentError, error_pattern):
      test_fn()


class MultiDeviceTest(test.TestCase, parameterized.TestCase):
  """Function tests, mostly about op and input/output placement on devices."""

  @test_util.run_gpu_only
  def testMultiDeviceOutput(self):
    """Tests that functions can produce outputs on multiple devices."""
    @function.defun
    def func(a, b, transpose_a):
      with ops.device('/device:CPU:0'):
        m1 = math_ops.matmul(a, b, transpose_a=transpose_a)
      with ops.device('/device:GPU:0'):
        m2 = math_ops.matmul(a, b, transpose_a=transpose_a)
      return m1, m2

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    m1, m2 = func(t, t, transpose_a=True)
    self.assertAllEqual(m1.numpy(), [[10, 14], [14, 20]])
    self.assertRegex(m1.backing_device, 'CPU')
    self.assertAllEqual(m2.numpy(), [[10, 14], [14, 20]])
    self.assertRegex(m2.backing_device, 'GPU')

  @test_util.run_gpu_only
  def testEmptyBody(self):
    """A function that only forwards its inputs keeps their devices.

    The outputs are the swapped inputs, and each output stays on the device
    of the input it came from.
    """
    @function.defun
    def func(a, b):
      return b, a

    with ops.device('/device:CPU:0'):
      a = array_ops.identity(3.0)
    with ops.device('/device:GPU:0'):
      b = array_ops.identity(5.0)

    m1, m2 = func(a, b)
    self.assertAllEqual(m1.numpy(), 5.0)
    self.assertRegex(m1.backing_device, 'GPU')
    self.assertAllEqual(m2.numpy(), 3.0)
    self.assertRegex(m2.backing_device, 'CPU')

  @test_util.run_gpu_only
  def testMultiDeviceInt32(self):
    """Tests that multi-device functions can take and output INT32s.

    When an INT32 device tensor is fed into a function, it is copied to CPU
    by the eager runtime. The function sees all INT32 inputs on CPU.

    We set allocator attribute 'on_host' for INT32 outputs. They can be
    partitioned into the GPU component function, but will be allocated on
    CPU nevertheless.

    There is experimental support for `ints_on_device` in
    FunctionLibraryRuntime now. We can try that.
    """
    with ops.device('/device:CPU:0'):
      int_cpu = constant_op.constant(3, dtype=dtypes.int32)
      resource = resource_variable_ops.ResourceVariable(5, dtype=dtypes.int32)
    with ops.device('/device:GPU:0'):
      int_gpu = constant_op.constant(7, dtype=dtypes.int32)

    @function.defun
    def func(int_cpu, resource, int_gpu):
      with ops.device('/device:CPU:0'):
        m1 = int_cpu * resource + int_gpu
      with ops.device('/device:GPU:0'):
        # This computation will happen on GPU but m2 will be copied to CPU.
        m2 = int_gpu * resource + int_cpu + 1
      return m1, m2

    m1, m2 = func(int_cpu, resource, int_gpu)
    self.assertAllEqual(m1.numpy(), 22)
    self.assertRegex(m1.backing_device, 'CPU')
    self.assertAllEqual(m2.numpy(), 39)
    self.assertRegex(m2.backing_device, 'CPU')

    # Flip the CPU and GPU int arguments; both outputs must still be on CPU.
    m1, m2 = func(int_gpu, resource, int_cpu)
    self.assertAllEqual(m1.numpy(), 38)
    self.assertRegex(m1.backing_device, 'CPU')
    self.assertAllEqual(m2.numpy(), 23)
    self.assertRegex(m2.backing_device, 'CPU')

  @test_util.run_gpu_only
  def testMultiDeviceColocateWith(self):
    """Tests that a function's outputs respect colocation constraints."""
    @function.defun
    def func(a, b):
      with ops.colocate_with(a):
        ra = 2 * a
      with ops.colocate_with(b):
        rb = 3 * b
      return ra, rb

    devices = ['/device:CPU:0', '/device:GPU:0']
    # Try every (device of a, device of b) combination.
    for dev1, dev2 in itertools.product(devices, devices):
      with ops.device(dev1):
        a = array_ops.identity(1.0)
      with ops.device(dev2):
        b = array_ops.identity(10.0)

      ra, rb = func(a, b)
      self.assertEqual(ra.numpy(), 2.0)
      self.assertRegex(ra.backing_device, dev1)
      self.assertEqual(rb.numpy(), 30.0)
      self.assertRegex(rb.backing_device, dev2)

  @test_util.run_gpu_only
  def testMultiDeviceResources(self):
    """Outputs follow the device scope they are computed in.

    This holds whichever device the resource arguments live on.
    """
    with ops.device('/device:CPU:0'):
      c1 = resource_variable_ops.ResourceVariable(2.0)
      c2 = resource_variable_ops.ResourceVariable(7.0)
    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)
      g2 = resource_variable_ops.ResourceVariable(5.0)

    @function.defun
    def func(resource1, resource2):
      with ops.device('/device:CPU:0'):
        result1 = resource1 * g2
      with ops.device('/device:GPU:0'):
        result2 = resource2 * c2
      return result1, result2

    r1, r2 = func(c1, g1)
    self.assertEqual(r1.numpy(), 10.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 21.0)
    self.assertRegex(r2.backing_device, 'GPU')

    # Call with flipped inputs. Check that we look at each resource's
    # device and reinstantiate the function when the inputs' devices change.
    r1, r2 = func(g1, c1)
    self.assertEqual(r1.numpy(), 15.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 14.0)
    self.assertRegex(r2.backing_device, 'GPU')

  @test_util.run_gpu_only
  def testOutputResources(self):
    """Resource handles returned from a function are usable by later ops.

    The returned handles are expected on CPU, and reading through them must
    yield the backing variable's value.
    """
    with ops.device('/device:CPU:0'):
      c1 = resource_variable_ops.ResourceVariable(2.0)
    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)

    @function.defun
    def func(resource1, resource2):
      with ops.device('/device:CPU:0'):
        result1 = resource1 * 5
      with ops.device('/device:GPU:0'):
        result2 = resource2 * 7
      return result1, resource1.handle, result2, resource2.handle

    r1, res1, r2, res2 = func(c1, g1)
    self.assertEqual(r1.numpy(), 10.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 21.0)
    self.assertRegex(r2.backing_device, 'GPU')

    def check_handle(handle, expected_value):
      self.assertRegex(handle.backing_device, 'CPU')
      tensor = gen_resource_variable_ops.read_variable_op(
          handle, dtypes.float32)
      self.assertEqual(tensor.numpy(), expected_value)

    # Check that handles returned from functions are on CPU and an op using
    # the resource handle is correctly placed on the device backing the
    # resource.
    check_handle(res1, 2.0)
    check_handle(res2, 3.0)

    # Call with flipped inputs to make sure the function is reinstantiated
    # and the eager runtime does not mess up the device assignment for ops
    # consuming handles returned from defuns.
    r1, res1, r2, res2 = func(g1, c1)
    self.assertEqual(r1.numpy(), 15.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 14.0)
    self.assertRegex(r2.backing_device, 'GPU')
    check_handle(res1, 3.0)
    check_handle(res2, 2.0)

  @test_util.run_gpu_only
  def testPassResourceThroughNestedFunctionCall(self):
    """Test passing GPU resource to noinline function call placed on CPU.

    PartitionedCallOp must not enforce any particular device assignment for the
    resource output. The inner function is marked as `_nospecialize`, so
    Grappler will not prune the unused function output.
    """

    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)

    @function.defun_with_attributes(attributes={
        '_noinline': True,
        '_nospecialize': True
    })
    def inner(resource1):
      return resource1 * 2, resource1.handle

    @function.defun
    def outer(resource1):
      with ops.device('/device:CPU:0'):
        r1, _ = inner(resource1)
      return r1

    r1 = outer(g1)

    self.assertEqual(r1.numpy(), 6.0)
    self.assertRegex(r1.backing_device, 'CPU')

  @test_util.run_gpu_only
  def testReturnResourceFromNestedFunctionCall(self):
    """Test returning GPU resource from noinline function call placed on CPU.

    When inferring output devices for the return value, do not set a device for
    returns of DT_RESOURCE data type based on the device assignment of the node
    that produced that resource. For example, a function call placed on CPU can
    return resources on GPU.
    """

    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)

    @function.defun_with_attributes(attributes={
        '_noinline': True
    })
    def inner(resource1):
      resource1.assign_add(2.0)
      return resource1 * 2, resource1.handle

    @function.defun
    def outer(resource1):
      with ops.device('/device:CPU:0'):
        r1, res1 = inner(resource1)
      return r1, res1

    r1, res1 = outer(g1)

    # (3.0 + 2.0) * 2
    self.assertEqual(r1.numpy(), 10.0)
    self.assertRegex(r1.backing_device, 'CPU')

    def check_handle(handle, expected_value):
      self.assertRegex(handle.backing_device, 'CPU')
      tensor = gen_resource_variable_ops.read_variable_op(
          handle, dtypes.float32)
      self.assertEqual(tensor.numpy(), expected_value)

    # Check that handles returned from functions are on CPU and an op using
    # the resource handle is correctly placed on the device backing the
    # resource.
    check_handle(res1, 5.0)

  @test_util.run_gpu_only
  def testComplexInputOutputDevicePattern(self):
    """Tests input/output mapping logic in partitioning."""
    with ops.device('/device:CPU:0'):
      rc0 = resource_variable_ops.ResourceVariable(2.0)
      rc1 = resource_variable_ops.ResourceVariable(3.0)
      cc0 = array_ops.identity(5.0)
      cc1 = array_ops.identity(7.0)
    with ops.device('/device:GPU:0'):
      rg0 = resource_variable_ops.ResourceVariable(11.0)
      rg1 = resource_variable_ops.ResourceVariable(13.0)
      cg0 = array_ops.identity(17.0)
      cg1 = array_ops.identity(19.0)

    # Make sure tensors are on expected devices.
    for tensor in [cc0, cc1]:
      self.assertRegex(tensor.backing_device, 'CPU:0')
    for tensor in [cg0, cg1]:
      self.assertRegex(tensor.backing_device, 'GPU:0')

    # Arguments are deliberately interleaved across devices, and outputs are
    # returned in a different order than they are computed.
    @function.defun
    def func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1):
      with ops.device('/device:CPU:0'):
        m1 = rc0 * cg0
      with ops.device('/device:GPU:0'):
        m2 = rg0 * cc0

      with ops.device('/device:CPU:0'):
        r1 = 1000.0 * m2 + rc1 * cg1
      with ops.device('/device:GPU:0'):
        r2 = 1000.0 * m1 + rg1 * cc1

      return r1, r2, m2, m1

    r1, r2, m2, m1 = func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1)
    self.assertRegex(m1.backing_device, 'CPU')
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertRegex(m2.backing_device, 'GPU')
    self.assertRegex(r2.backing_device, 'GPU')
    self.assertEqual(m1.numpy(), 34.0)
    self.assertEqual(r1.numpy(), 55000.0 + 3.0 * 19.0)
    self.assertEqual(m2.numpy(), 55.0)
    self.assertEqual(r2.numpy(), 34000.0 + 13.0 * 7.0)

  @test_util.run_gpu_only
  def testArgumentPruning(self):
    """Tests functions taking unnecessary arguments."""
    with ops.device('/device:CPU:0'):
      c1 = constant_op.constant(5.0)
      c2 = constant_op.constant(7.0)

    with ops.device('/device:GPU:0'):
      g1 = constant_op.constant(11.0)
      g2 = constant_op.constant(13.0)
      g3 = constant_op.constant(17.0)

    @function.defun
    def func(g1, g2, c1, g3, c2):  # pylint: disable=unused-argument
      # arguments g1 and g2 are unused and can be pruned by grappler.
      return c1 * g3 * c2

    result = func(g1, g2, c1, g3, c2)
    self.assertEqual(result.numpy(), 5.0 * 7.0 * 17.0)

  def testNestedCallWatchedVariables(self):
    """Variables read by (nested) tf.functions are watched by an outer tape."""

    v = variables.Variable(4.)

    @def_function.function
    def f():
      return v ** 2.

    with backprop.GradientTape() as tape:
      f()

    self.assertEqual((v,), tape.watched_variables())

    @def_function.function
    def g():
      return f()

    with backprop.GradientTape() as tape:
      g()

    self.assertEqual((v,), tape.watched_variables())

    # f() can rely on the variable being read during its trace. g() checks that
    # variables from a function which knows about them are recorded on the
    # tape. h() tests that functions forward knowledge of variables to callers.

    @def_function.function
    def h():
      return g()

    with backprop.GradientTape() as tape:
      h()

    self.assertEqual((v,), tape.watched_variables())

  def testReplaceCaptureWithDeferred(self):
    """A captured tensor can be replaced by a closure read at call time."""

    x = constant_op.constant(1.0)
    y = constant_op.constant(2.0)
    z = constant_op.constant(3.0)

    @def_function.function
    def fn():
      a = x + y
      b = a + z
      return b

    concrete_fn = fn.get_concrete_function()
    self.assertAllEqual(concrete_fn(), 6.0)

    value = constant_op.constant(4.0)

    def closure():
      return value

    # The expected results below (1 + 4 + 3, then 1 + 5 + 3) show that
    # captured_inputs[1] is the capture of `y`.
    concrete_fn.replace_capture_with_deferred_capture(
        concrete_fn.captured_inputs[1],
        closure,
        spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
        placeholder=concrete_fn.inputs[1])

    self.assertAllEqual(concrete_fn(), 8.0)

    # Rebinding `value` is picked up on the next call.
    value = constant_op.constant(5.0)
    self.assertAllEqual(concrete_fn(), 9.0)

  def testRaiseReplaceCaptureWithDeferredTypeSpecMismatch(self):
    """Replacing a capture with a mismatched-spec closure raises ValueError.

    A replacement whose spec matches the original capture succeeds, even
    without an explicit placeholder.
    """
    bool_captured_tensor = constant_op.constant(True)
    float_captured_tensor = constant_op.constant([3.], dtype=dtypes.float32)
    value = constant_op.constant([2.], dtype=dtypes.float32)

    @def_function.function
    def fn():
      deferred_tensor = ops.get_default_graph().capture_call_time_value(
          lambda: value,
          tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32))
      if bool_captured_tensor:
        return deferred_tensor
      else:
        return deferred_tensor + float_captured_tensor

    concrete_fn = fn.get_concrete_function()
    self.assertAllEqual(concrete_fn(), [2.])

    new_bool_captured_tensor = constant_op.constant(False)
    def bool_closure():
      return new_bool_captured_tensor

    # Expect an error when replacing a bool capture with a closure whose
    # output type is float32.
    new_float_captured_tensor = constant_op.constant([3.], dtype=dtypes.float32)
    def float_closure():
      return new_float_captured_tensor

    # NOTE: in the regex, `*` applies only to the final 'c', so this is
    # effectively a prefix match on the message.
    with self.assertRaisesRegex(ValueError,
                                'Attempting to substitute closure with spec*'):
      concrete_fn.replace_capture_with_deferred_capture(
          bool_captured_tensor,
          float_closure,
          spec=tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32))

    # Test replace without a placeholder
    concrete_fn.replace_capture_with_deferred_capture(
        bool_captured_tensor,
        bool_closure,
        spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool))

    # The condition is now False, so the else branch yields [2.] + [3.].
    self.assertAllEqual(concrete_fn(), [5.])

  def testConcreteFunctionSetExternalCapture(self):
    """`set_external_captures` can swap a regular and a deferred capture."""
    captured_tensor = constant_op.constant([1.])
    value = constant_op.constant([2.])

    @def_function.function
    def fn():
      deferred_tensor = ops.get_default_graph().capture_call_time_value(
          lambda: value,
          tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32))
      return deferred_tensor + captured_tensor

    cf = fn.get_concrete_function()
    self.assertLen(cf._captured_inputs, 2)
    # First input is a plain captured tensor; second is the deferred closure.
    self.assertEqual(list(map(callable, cf._captured_inputs)), [False, True])
    self.assertAllEqual(cf(), [3.])

    # Reset capture to a deferred one, reset deferred capture to a capture.
    cf.set_external_captures([cf._captured_inputs[1], cf._captured_inputs[0]])

    value = constant_op.constant([3.])
    self.assertAllEqual(cf(), [4.])

  def testGraphReplaceCaptureAndSetExternalCapture(self):
    """Replace a capture on the graph, then set external captures explicitly.

    The replacement is done through `concrete_fn.graph`, after which the
    ConcreteFunction's external captures are set by hand.
    """
    bool_captured_tensor = constant_op.constant(True)
    float_captured_tensor = constant_op.constant([3.], dtype=dtypes.float32)
    value = constant_op.constant([2.], dtype=dtypes.float32)

    @def_function.function
    def fn():
      deferred_tensor = ops.get_default_graph().capture_call_time_value(
          lambda: value,
          tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32))
      if bool_captured_tensor:
        return deferred_tensor
      else:
        return deferred_tensor + float_captured_tensor

    concrete_fn = fn.get_concrete_function()
    self.assertAllEqual(concrete_fn(), [2.])

    new_bool_captured_tensor = constant_op.constant(False)

    def closure():
      return new_bool_captured_tensor

    concrete_fn.graph.replace_capture_with_deferred_capture(
        concrete_fn.captured_inputs[0],
        closure,
        spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool),
        placeholder=concrete_fn.inputs[1])

    concrete_fn.set_external_captures([
        closure, concrete_fn._captured_inputs[1],
        concrete_fn._captured_inputs[2]
    ])
    # The condition is now False, so the result is [2.] + [3.].
    self.assertAllEqual(concrete_fn(), [5.])

  def testDeferredCapture(self):
    """A deferred capture re-reads its Python value on every call."""
    value = 1.0

    @def_function.function
    def lazy_capture(x):
      y = ops.get_default_graph().capture_call_time_value(
          lambda: value, tensor_spec.TensorSpec(None))
      return x + y

    self.assertAllEqual(lazy_capture(2.0), 3.0)
    # After changing the value of `value` the function call should return a
    # different result.
    value = 2.0
    self.assertAllEqual(lazy_capture(2.0), 4.0)

  def testNestedDeferredCapture(self):
    """A deferred capture in an inner function is re-read via the outer one."""
    value = 1.0

    @def_function.function
    def inner(x):
      y = ops.get_default_graph().capture_call_time_value(
          lambda: value, tensor_spec.TensorSpec(None))
      return x + y

    @def_function.function
    def outer(x):
      return inner(x)

    self.assertAllEqual(outer(2.0), 3.0)
    # After changing the value of `value` the function call should return a
    # different result.
    value = 2.0
    self.assertAllEqual(outer(2.0), 4.0)

  def testNestedDeferredCaptureInTFWhileLoop(self):
    """A nested deferred capture works inside a tensor-conditioned loop.

    The inner call sits in a `while` loop whose condition is a tensor. The
    captured value changes from a Python float to tensors between calls.
    """

    value = 1.

    @def_function.function
    def inner(x):
      y = ops.get_default_graph().capture_call_time_value(
          lambda: value, tensor_spec.TensorSpec(None))
      return x + y

    @def_function.function
    def outer():
      dummy = constant_op.constant(True)
      sums = constant_op.constant(0.)
      # The loop body runs once: `dummy` is set to False inside it.
      while dummy:
        directives.set_loop_options(
            shape_invariants=[(sums, tensor_shape.TensorShape(None))])
        sums += inner(2.)
        dummy = constant_op.constant(False)
      return sums

    self.assertAllEqual(outer(), 3.)

    value = constant_op.constant(2.)
    self.assertAllEqual(outer(), 4.)

    value = constant_op.constant(3.)
    self.assertAllEqual(outer(), 5.)
5422 5423 def testDeferredCaptureWithKey(self): 5424 value0 = 1.0 5425 value1 = 2.0 5426 5427 @def_function.function 5428 def lazy_capture(x): 5429 w = ops.get_default_graph().capture_call_time_value( 5430 lambda: value0, tensor_spec.TensorSpec(None), key=0) 5431 y = ops.get_default_graph().capture_call_time_value( 5432 lambda: value1, tensor_spec.TensorSpec(None), key=1) 5433 def bad_closure(): 5434 raise ValueError('Should not run') 5435 z = ops.get_default_graph().capture_call_time_value( 5436 bad_closure, tensor_spec.TensorSpec(None), key=1) 5437 return x + y + w + z 5438 5439 self.assertAllEqual(lazy_capture(2.0), 7.0) 5440 value0 = 2.0 5441 value1 = 3.0 5442 self.assertAllEqual(lazy_capture(2.0), 10.0) 5443 5444 def testDeferredCaptureTypeError(self): 5445 value = constant_op.constant(1.0) 5446 5447 @def_function.function 5448 def lazy_capture(x): 5449 y = ops.get_default_graph().capture_call_time_value( 5450 lambda: value, tensor_spec.TensorSpec(())) 5451 return x + y 5452 5453 self.assertAllEqual(lazy_capture(2.0), 3.0) 5454 5455 # dtype mismatch 5456 value = constant_op.constant(1) 5457 with self.assertRaisesRegex(ValueError, 'Value .* to a tensor with dtype'): 5458 lazy_capture(2.0) 5459 5460 # shape mismatch 5461 value = constant_op.constant([1.0]) 5462 with self.assertRaisesRegex(ValueError, 'Value .* shape'): 5463 lazy_capture(2.0) 5464 5465 def testDeferredCaptureReturnNestWithCompositeTensor(self): 5466 i_s = indexed_slices.IndexedSlices( 5467 constant_op.constant([1, 2]), 5468 constant_op.constant([0, 1], dtype=dtypes.int64), 5469 constant_op.constant([2])) 5470 r_t = ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]]) 5471 s_t = sparse_tensor.SparseTensor( 5472 values=[1, 2, 3], indices=[[0], [8], [10]], dense_shape=[20]) 5473 5474 @def_function.function 5475 def lazy_capture(): 5476 y = ops.get_default_graph().capture_call_time_value( 5477 lambda: {'i': i_s, 't': (r_t, s_t)}, 5478 {'i': indexed_slices.IndexedSlicesSpec( 5479 
dtype=dtypes.int32, dense_shape_dtype=dtypes.int32), 5480 't': (ragged_tensor.RaggedTensorSpec([2, None, None], dtypes.int32), 5481 sparse_tensor.SparseTensorSpec([None], dtypes.int32))}) 5482 return y['i'], y['t'] 5483 5484 i, (r, s) = lazy_capture() 5485 self.assertAllEqual(i_s.values, i.values) 5486 self.assertAllEqual(i_s.indices, i.indices) 5487 self.assertAllEqual(i_s.dense_shape, i.dense_shape) 5488 self.assertAllEqual(r_t, r) 5489 self.assertAllEqual(s_t.indices, s.indices) 5490 self.assertAllEqual(s_t.values, s.values) 5491 self.assertAllEqual(s_t.dense_shape, s.dense_shape) 5492 5493 def testDeferredCaptureCompositeTensorSpecTypeMismatch(self): 5494 value = indexed_slices.IndexedSlices( 5495 constant_op.constant([1, 2]), 5496 constant_op.constant([0, 1], dtype=dtypes.int64)) 5497 5498 @def_function.function 5499 def lazy_capture(): 5500 return ops.get_default_graph().capture_call_time_value( 5501 lambda: value, 5502 indexed_slices.IndexedSlicesSpec(dtype=dtypes.int32)) 5503 5504 # Type matches spec. 5505 lazy_capture() 5506 5507 # Extra dense shape component. 5508 value = indexed_slices.IndexedSlices( 5509 constant_op.constant([1, 2]), 5510 constant_op.constant([0, 1], dtype=dtypes.int64), 5511 constant_op.constant([2])) 5512 with self.assertRaises(ValueError): 5513 lazy_capture() 5514 5515 # Index dtype mismatch int32 vs. int64. 
5516 value = indexed_slices.IndexedSlices( 5517 constant_op.constant([1, 2]), 5518 constant_op.constant([0, 1])) 5519 with self.assertRaises(ValueError): 5520 lazy_capture() 5521 5522 def testMaybeCreateCapturePlaceholderWithValidCapture(self): 5523 @def_function.function 5524 def f(): 5525 func = lambda: x 5526 return ops.get_default_graph()._maybe_create_capture_placeholder(func) 5527 5528 x = { 5529 'tensor': constant_op.constant(0), 5530 'list': [constant_op.constant(1), 2], 5531 'dict': { 5532 'float': constant_op.constant(0.5) 5533 } 5534 } 5535 5536 out = f() 5537 # tf.function output should have same structure/values with the side input 5538 self.assertEqual(x['tensor'].numpy(), out['tensor'].numpy()) 5539 self.assertEqual(x['list'][0].numpy(), out['list'][0].numpy()) 5540 self.assertEqual(x['list'][1], out['list'][1].numpy()) 5541 self.assertEqual(x['dict']['float'].numpy(), out['dict']['float'].numpy()) 5542 5543 def testMaybeCreateCapturePlaceholderWithInvalidCapture(self): 5544 @def_function.function 5545 def f(): 5546 func = lambda: x 5547 return ops.get_default_graph()._maybe_create_capture_placeholder(func) 5548 5549 # Set is not supported 5550 x = set([1, 2]) 5551 with self.assertRaises(NotImplementedError): 5552 f() 5553 5554 # TODO(panzf): remove this test after exposing manual API, as the integration 5555 # testcase can be turned on at that time. 
5556 def test_inner_nested_tf_function_raise_error(self): 5557 @def_function.function 5558 def tf_f(): 5559 5560 @def_function.function 5561 def tf_g(): 5562 cx = ops.get_default_graph()._experimental_capture_side_input_by_ref( # pylint: disable=protected-access 5563 'lambda: x', lambda: x) 5564 return cx 5565 5566 return tf_g() 5567 5568 x = constant_op.constant(0) # pylint: disable=unused-variable 5569 with self.assertRaisesRegex( 5570 NotImplementedError, 'Manual side input usage for inner nested'): 5571 tf_f() 5572 5573 @parameterized.parameters( 5574 (1, int, 2, int, 2), 5575 (1, constant_op.constant, 2, constant_op.constant, 1)) 5576 def testRetraceLogicWithSideInputs(self, val_before, type_before, val_after, 5577 type_after, expected_len): 5578 @def_function.function 5579 def f(): 5580 func = lambda: x 5581 return ops.get_default_graph()._experimental_capture_side_input_by_ref( # pylint: disable=protected-access 5582 'lambda: x', func) 5583 5584 x = type_before(val_before) 5585 _ = f() 5586 x = type_after(val_after) 5587 _ = f() 5588 self.assertLen(total_function_cache(f), expected_len) 5589 5590 def testFunctoolsLruCache(self): 5591 self.skipTest( 5592 "b/194845243: inspect.getfullargspec doesn't unwrap Python decorators.") 5593 5594 @def_function.function 5595 @functools.lru_cache(maxsize=2) 5596 def f(a): 5597 return 2 * a 5598 5599 self.assertAllEqual(f(1), array_ops.constant(2)) 5600 5601if __name__ == '__main__': 5602 ops.enable_eager_execution() 5603 test.main() 5604