# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `wrap_function`, `WrappedGraph` and `VariableHolder`."""

import os


from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import importer as graph_def_importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib


class WrapFunctionTest(test.TestCase):
  """Tests for `wrap_function.wrap_function` and pruning of its result."""

  def testDocString(self):
    """Wraps `f` with `do_add` True, then False, and checks each trace."""

    def f(x, do_add):
      v = variables.Variable(5.0)
      if do_add:
        op = v.assign_add(x)
      else:
        op = v.assign_sub(x)
      with ops.control_dependencies([op]):
        return v.read_value()

    f_add = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32), True])

    self.assertAllEqual(f_add(1.0), 6.0)
    self.assertAllEqual(f_add(1.0), 7.0)

    # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
    # of variables, and possibly different non-template arguments.
    f_sub = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32), False])

    self.assertAllEqual(f_sub(1.0), 4.0)
    self.assertAllEqual(f_sub(1.0), 3.0)

  def testPrune(self):
    """Prunes a two-argument wrapped function down to its `x * x` part."""

    # Filled in while `f` is traced so the symbolic tensors can be used as the
    # feeds/fetches of `prune` afterwards.
    x_in = []
    x_out = []

    def f(x, y):
      x_in.append(x)
      xx = x * x
      x_out.append(xx)
      return xx, 2 * y*y

    f_wrapped = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32)] * 2)

    f_pruned = f_wrapped.prune(x_in[0], [x_out[0]])
    self.assertAllEqual(f_pruned(ops.convert_to_tensor(2.0)), [4.0])

  def testPruneRagged(self):
    """Prunes a wrapped function whose first input uses a RaggedTensorSpec."""

    # Filled in while `f` is traced; see testPrune.
    x_in = []
    x_out = []

    def f(x, y):
      x_in.append(x)
      xx = x * x
      x_out.append(xx)
      return xx, y * y

    x_spec = ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32)
    y_spec = tensor_spec.TensorSpec((), dtypes.float32)

    f_wrapped = wrap_function.wrap_function(f, [x_spec, y_spec])

    f_pruned = f_wrapped.prune(x_in[0], x_out[0])
    rt = ragged_factory_ops.constant([[1.0, 2.0], [3.0]])
    expected = ragged_factory_ops.constant_value([[1.0, 4.0], [9.0]])

    # Note: when we call f_pruned, we must pass the RaggedTensor in using
    # its components, since that's the current convention for how concrete
    # functions handle structured inputs.
    self.assertAllEqual(f_pruned(rt.values, rt.row_splits), expected)

  def _assert_single_captured_variable_argument(self, graph_def):
    """Asserts `graph_def` has one FunctionDef taking one resource argument.

    Args:
      graph_def: A GraphDef whose function library is inspected.
    """
    # The single FunctionDef should have one argument, a captured variable
    function_def, = graph_def.library.function
    self.assertLen(function_def.signature.input_arg, 1)
    function_arg, = function_def.signature.input_arg
    self.assertEqual(dtypes.resource, dtypes.as_dtype(function_arg.type))

  def testVariableLifting(self):
    """Imports a saved meta graph inside wrap_function and checks its variable."""
    save_prefix = os.path.join(self.get_temp_dir(), 'meta_graph_test')

    export_graph = ops.Graph()
    with export_graph.as_default():
      v = variables.Variable(1.)
      array_ops.identity(v + 1., name='output')
      saver = saver_lib.Saver([v])
      with self.test_session() as session:
        session.run(v.initializer)
        saver.save(session, save_prefix)

    def importer():
      saver_lib.import_meta_graph(save_prefix + '.meta')
      return ops.get_default_graph().as_graph_element('output:0')

    wrapped = wrap_function.wrap_function(importer, [])
    lifted_variables = list(wrapped.graph.variables)
    self.assertLen(lifted_variables, 1)
    initializer = wrapped.prune(
        [], wrapped.graph.as_graph_element(v.initializer.name))
    self.assertEqual(lifted_variables, list(initializer.graph.variables))
    self.assertEqual(initializer.graph.external_captures,
                     wrapped.graph.external_captures)

    @def_function.function
    def wraps_initializer():
      initializer()

    wraps_initializer()
    self.assertEqual(1., lifted_variables[0].numpy())
    wrapped_initializer_graphdef = (
        wraps_initializer.get_concrete_function().graph.as_graph_def())
    self._assert_single_captured_variable_argument(wrapped_initializer_graphdef)

    @def_function.function
    def wraps_wrapped():
      return wrapped()

    # Verify that the original graph also has the correct signature.
    wrapped_wrapped_graphdef = (
        wraps_wrapped.get_concrete_function().graph.as_graph_def())
    self._assert_single_captured_variable_argument(wrapped_wrapped_graphdef)
    # Now check that the graph runs wrapped, from eager, and when pruned.
    self.assertAllEqual(wraps_wrapped().numpy(),
                        lifted_variables[0].numpy() + 1.)
    self.assertAllEqual(wrapped().numpy(), lifted_variables[0].numpy() + 1.)
    pruned = wrapped.prune([], wrapped.graph.as_graph_element('output:0'))
    self.assertAllEqual(wrapped().numpy(), pruned().numpy())

  def testNoArguments(self):
    """Wraps and calls a function that takes no arguments."""

    def f():
      return constant_op.constant(1.)

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertAllEqual(1.0, f_wrapped())

  def testPruneCaptures(self):
    """Prunes a function using an outer variable and one it creates itself."""

    v1 = variables.Variable(2.)

    def f():
      v2 = variables.Variable(3.)
      return array_ops.identity(v1 * v2 * constant_op.constant(1.), 'fetch')

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertAllEqual(6.0, f_wrapped())

    # Test pruning directly on the inputs
    pruned = f_wrapped.prune(
        feeds=f_wrapped.inputs,
        fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))
    self.assertAllEqual(6.0, pruned())

    # Test pruning with no inputs
    pruned = f_wrapped.prune(
        feeds=(),
        fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))
    self.assertAllEqual(6.0, pruned())

  def testCollectionsIsolation(self):
    """Checks that separately wrapped functions keep separate collections."""

    v1 = variables.Variable(2.)
    v2_holder = []

    def f():
      v2 = variables.Variable(3.)
      v2_holder.append(v2)
      ops.add_to_collection(ops.GraphKeys.LOSSES, v2 * constant_op.constant(3.))
      return array_ops.identity(v1 * v2 * constant_op.constant(1.), 'fetch')

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertAllEqual(6.0, f_wrapped())
    self.assertLen(f_wrapped.graph.get_collection(ops.GraphKeys.LOSSES), 1)
    f_var_collection = f_wrapped.graph.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES)
    self.assertLen(f_var_collection, 1)
    self.assertIs(f_var_collection[0], v2_holder[0])

    v3_holder = []

    def g():
      v3 = variables.Variable(4.)
      v3_holder.append(v3)
      ops.add_to_collection(ops.GraphKeys.LOSSES, v3 * constant_op.constant(3.))
      return array_ops.identity(v1 * v3 * constant_op.constant(1.), 'fetch')

    g_wrapped = wrap_function.wrap_function(g, [])
    self.assertAllEqual(8.0, g_wrapped())
    self.assertLen(g_wrapped.graph.get_collection(ops.GraphKeys.LOSSES), 1)
    g_var_collection = g_wrapped.graph.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES)
    self.assertLen(g_var_collection, 1)
    self.assertIs(g_var_collection[0], v3_holder[0])

    # Both have only one value, and those values are different objects. So no
    # sharing. The collections are re-read here, after both functions have been
    # wrapped, and the `[0]` must index the returned lists (not the key).
    f_losses = f_wrapped.graph.get_collection(ops.GraphKeys.LOSSES)
    g_losses = g_wrapped.graph.get_collection(ops.GraphKeys.LOSSES)
    self.assertIsNot(g_losses[0], f_losses[0])

  def testGradientsOfPrune(self):
    """Compares tape gradients through the wrapped and the pruned function."""

    v1 = variables.Variable(2.)
    v2_holder = []

    def f(z):
      v2 = variables.Variable(3.)
      v2_holder.append(v2)
      return array_ops.identity(v1 * v2 * z, 'fetch')

    f_wrapped = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtype=dtypes.float32)])

    x = constant_op.constant(1.)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      out = f_wrapped(x)
    grads = tape.gradient(out, [x, v1, v2_holder[0]])

    self.assertAllEqual(6.0, out)
    self.assertAllEqual([6.0, 3.0, 2.0], grads)

    pruned = f_wrapped.prune(
        feeds=f_wrapped.inputs,
        fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))

    x = constant_op.constant(1.)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      out = pruned(x)
    grads = tape.gradient(out, [x, v1, v2_holder[0]])

    self.assertAllEqual(6.0, out)
    self.assertAllEqual([6.0, 3.0, 2.0], grads)

  def testPruneOperations(self):
    """Prunes with an Operation among the fetches and checks which ops run."""

    v = variables.Variable(0)

    def f():
      v.assign_add(1, name='increment', read_value=False)

    f_wrapped = wrap_function.wrap_function(f, [])
    pruned = f_wrapped.prune(
        feeds=(),
        fetches=(f_wrapped.graph.get_operation_by_name('increment'),))
    self.assertEqual((None,), pruned())
    self.assertEqual(1, self.evaluate(v))

    del f, f_wrapped

    def f1():
      v.assign_add(
          array_ops.placeholder(shape=[], dtype=dtypes.int32, name='step'),
          name='increment', read_value=False)
      return constant_op.constant(1, name='other')

    f_wrapped = wrap_function.wrap_function(f1, [])
    increments = f_wrapped.prune(
        feeds=(f_wrapped.graph.get_tensor_by_name('step:0')),
        fetches=(f_wrapped.graph.get_operation_by_name('increment'),
                 f_wrapped.graph.get_tensor_by_name('other:0')))
    first_output, second_output = increments(constant_op.constant(2))
    self.assertEqual(['step:0', 'increment/resource:0'],
                     [t.name for t in increments.inputs])
    self.assertIsNone(first_output)
    self.assertEqual(1, second_output.numpy())
    self.assertEqual(3, v.numpy())
    does_not_increment = f_wrapped.prune(
        feeds=(f_wrapped.graph.get_tensor_by_name('step:0')),
        fetches=f_wrapped.graph.get_tensor_by_name('other:0'))
    self.assertEqual(1, does_not_increment(constant_op.constant(3)).numpy())
    self.assertEqual(3, v.numpy())

  def testPruneStatefulOpsFromWrappedFunc(self):
    """Checks that assign_add ops not feeding the return value never run."""

    v0 = variables.Variable(0)
    v1 = variables.Variable(0)

    # When we wrap a function, we expect it to be executed with `tf.Graph`
    # rules: it's allowed to prune all ops that are not in transitive fanin of
    # the fetches.
    def f(x):
      v0.assign_add(1, name='increment_v0')
      v1.assign_add(1, name='increment_v1')
      return x

    f_wrapped = wrap_function.wrap_function(f, [1])

    self.assertEqual(1, f_wrapped().numpy())
    self.assertEqual(0, v0.numpy())
    self.assertEqual(0, v1.numpy())

    f_wrapped_with_name = wrap_function.wrap_function(f, [2], name='func')

    self.assertEqual(2, f_wrapped_with_name().numpy())
    self.assertEqual(0, v0.numpy())
    self.assertEqual(0, v1.numpy())

  def test_operation_returned(self):
    """Fetches an Operation by object, by name string, and by TensorInfo."""

    v = variables.Variable(0)

    def f():
      v.assign(1, read_value=False, name='assign_to_v')

    f_wrapped = wrap_function.wrap_function(f, [])
    operation_to_fetch = f_wrapped.graph.get_operation_by_name('assign_to_v')
    f_pruned = f_wrapped.prune(
        [], operation_to_fetch)
    self.assertEqual(
        ['assign_to_v'],
        [operation.name for operation in f_pruned.graph.control_outputs])
    self.assertEqual(0, v.numpy())
    f_pruned()
    self.assertEqual(1, v.numpy())
    f_wrapped.prune([], 'assign_to_v')()
    f_wrapped.prune([], meta_graph_pb2.TensorInfo(name='assign_to_v'))()

  def test_function_from_graph_def(self):
    """Rebuilds a callable from the GraphDef of a traced `x + 1.` function."""

    @def_function.function
    def make_graph_def(x):
      return x + 1.

    original_func_graph = make_graph_def.get_concrete_function(
        tensor_spec.TensorSpec([None, 2], dtypes.float32)).graph
    graph_def = original_func_graph.as_graph_def()
    revived_function = wrap_function.function_from_graph_def(
        graph_def, inputs=original_func_graph.inputs[0].name,
        outputs=original_func_graph.outputs[0].name)
    self.assertEqual(2., revived_function(constant_op.constant(1.)).numpy())

  def test_create_variables_with_same_name(self):
    """Checks naming of two variables that are both created as 'v'."""

    def f():
      v1 = variables.Variable(0, name='v')
      v2 = variables.Variable(1, name='v')
      return v1, v2

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertDictEqual(
        {'v:0': 0, 'v_1:0': 1},  # assert that variable names are uniquified
        {v.name: v.numpy()
         for v in f_wrapped._variable_holder.variables.values()})

    # Uniquification should reset in separate calls to wrap_function.
    def f2():
      v1 = variables.Variable(3, name='v')
      v2 = variables.Variable(4, name='v')
      return v1, v2

    f_wrapped_2 = wrap_function.wrap_function(f2, [])
    self.assertDictEqual(
        {'v:0': 3, 'v_1:0': 4},
        {v.name: v.numpy()
         for v in f_wrapped_2._variable_holder.variables.values()})


class WrappedGraphTest(test.TestCase):
  """Tests for `wrap_function.WrappedGraph` and `VariableHolder`."""

  def testAddFunction(self):
    """Compares a function wrapped by WrappedGraph with a session run of it."""

    def fn(x):
      v = variables.Variable(3, name='v')
      v2 = variable_scope.get_variable(
          'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
      return v + v2 + x

    with self.cached_session() as sess:
      result = fn(constant_op.constant(5))
      sess.run(variables.global_variables_initializer())
      expected = sess.run(result)

    g = wrap_function.WrappedGraph()
    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    wrapped_fn = g.wrap_function(fn, signature)
    self.assertEqual(expected, wrapped_fn(constant_op.constant(5)).numpy())

  def testCollections(self):
    """Checks collection sizes on the graph and inside a later wrapped fn."""

    def fn(x):
      v = variables.VariableV1(3, name='v', trainable=False, collections=['a'])
      v2 = variable_scope.get_variable(
          'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32,
          collections=['a', 'b'])
      return v + v2 + x

    def assert_collections(graph):
      self.assertLen(graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES), 1)
      self.assertLen(graph.get_collection('a'), 2)
      self.assertLen(graph.get_collection('b'), 1)

    g = wrap_function.WrappedGraph()
    g.wrap_function(fn, [tensor_spec.TensorSpec([], dtypes.int32)])
    assert_collections(g.graph)

    def assert_fn():
      assert_collections(ops.get_default_graph())
      return 1  # Return is required

    # Assert that collections are accessible within a wrapped function.
    g.wrap_function(assert_fn, [])

  def testShareVariablesSameGraph(self):
    """Wraps several functions in one WrappedGraph and checks shared state."""

    def add_v1(x):
      with variable_scope.variable_scope(
          'reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(3), shape=[], dtype=dtypes.int32)
      return v + x

    def subtract_v1(x):
      with variable_scope.variable_scope(
          'reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
      return v - x

    def different_variable_fn_v1(x):
      with variable_scope.variable_scope(
          'no_reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(5), shape=[], dtype=dtypes.int32)
      return v * x

    def increment_variable_v1(x):
      with variable_scope.variable_scope(
          'reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(6), shape=[], dtype=dtypes.int32)
      return v.assign_add(x)

    g = wrap_function.WrappedGraph()
    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    add = g.wrap_function(add_v1, signature)
    subtract = g.wrap_function(subtract_v1, signature)
    different_variable_fn = g.wrap_function(different_variable_fn_v1, signature)
    increment_variable = g.wrap_function(increment_variable_v1, signature)

    self.assertEqual(10, add(constant_op.constant(7)).numpy())
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    # The shared variable has a starting value of 3 because add_v1 was wrapped
    # first.
    self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
    self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())

    # Check that variable updates
    self.assertEqual(17, add(constant_op.constant(7)).numpy())
    self.assertEqual(3, subtract(constant_op.constant(7)).numpy())

    # Sanity check - result from this function shouldn't change.
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    self.assertAllEqual({'reuse/v', 'no_reuse/v'}, set(g.variables.keys()))

  def testShareVariablesDifferentGraphs(self):
    """Shares one VariableHolder between several WrappedGraph instances."""

    def add_v1(x):
      v = variables.Variable(3, name='v')
      return v + x

    def subtract_v1(x):
      v = variables.Variable(4, name='v')
      return v - x

    def different_variable_fn_v1(x):
      with ops.name_scope('different_scope'):
        v = variables.Variable(5, name='v')
      return v * x

    def increment_variable_v1(x):
      v = variables.Variable(6, name='v')
      return v.assign_add(x)

    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    vh = wrap_function.VariableHolder(share_variables=True)
    # Every call builds a fresh WrappedGraph backed by the same holder `vh`.
    new_graph = lambda: wrap_function.WrappedGraph(variable_holder=vh)

    add = new_graph().wrap_function(add_v1, signature)
    subtract = new_graph().wrap_function(subtract_v1, signature)
    different_variable_fn = new_graph().wrap_function(
        different_variable_fn_v1, signature)
    increment_variable = new_graph().wrap_function(
        increment_variable_v1, signature)

    self.assertEqual(10, add(constant_op.constant(7)).numpy())
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    # Because the variable in add_v1 was created first, its starting value is 3
    # instead of the values defined in subtract_v1 or increment_variable_v1.
    self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
    self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())

    # Check that variable updates
    self.assertEqual(17, add(constant_op.constant(7)).numpy())
    self.assertEqual(3, subtract(constant_op.constant(7)).numpy())

    # Sanity check - result from this function shouldn't change.
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    self.assertAllEqual({'v', 'different_scope/v'}, set(vh.variables.keys()))

  @test_util.run_in_graph_and_eager_modes
  def testImportedFunctionsRegistered(self):
    """Wraps a function that imports a GraphDef containing a dataset reduce."""
    if test_util.is_gpu_available():
      self.skipTest('not a GPU test')
    with ops.Graph().as_default() as graph:
      x = array_ops.placeholder(dtypes.variant, shape=[], name='foo')
      ds = dataset_ops.from_variant(x, structure=(
          tensor_spec.TensorSpec([], dtypes.int32)))
      y = ds.reduce(array_ops.zeros([], dtype=dtypes.int32), lambda p, q: p + q)

    graph_def = graph.as_graph_def()

    def fn_to_wrap(a):
      returned_elements = graph_def_importer.import_graph_def(
          graph_def, input_map={x.name: a}, return_elements=[y.name])
      return returned_elements[0]

    wrapped_fn = wrap_function.wrap_function(
        fn_to_wrap, [tensor_spec.TensorSpec((), dtypes.variant)])
    ds = dataset_ops.Dataset.from_tensor_slices([10, 20])
    v = dataset_ops.to_variant(ds)
    self.evaluate(wrapped_fn(v))

  def testReturnOp(self):
    """Wraps a function that returns an Operation instead of a Tensor."""

    def update_var_v1(x):
      v = variables.Variable(3, name='v')
      update_op = state_ops.assign(v, x).op
      return update_op

    g = wrap_function.WrappedGraph()
    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    update_var = g.wrap_function(update_var_v1, signature)

    self.assertEqual(g.variables['v'].numpy(), 3)
    update_var(constant_op.constant(12))
    self.assertEqual(g.variables['v'].numpy(), 12)


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()