# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os


from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import importer as graph_def_importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib


class WrapFunctionTest(test.TestCase):

  def testDocString(self):

    def f(x, do_add):
      v = variables.Variable(5.0)
      if do_add:
        op = v.assign_add(x)
      else:
        op = v.assign_sub(x)
      with ops.control_dependencies([op]):
        return v.read_value()

    f_add = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32), True])

    self.assertAllEqual(f_add(1.0), 6.0)
    self.assertAllEqual(f_add(1.0), 7.0)

    # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
    # of variables, and possibly different non-template arguments.
    f_sub = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32), False])

    self.assertAllEqual(f_sub(1.0), 4.0)
    self.assertAllEqual(f_sub(1.0), 3.0)

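  def testEachWrapCreatesNewState_sketch(self):
    # Illustrative sketch, not part of the original suite: as the comment in
    # testDocString notes, each call to wrap_function re-traces `f` with a
    # fresh set of variables, so two wrapped functions keep independent
    # counters.
    def f(x):
      v = variables.Variable(5.0)
      op = v.assign_add(x)
      with ops.control_dependencies([op]):
        return v.read_value()

    f_a = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32)])
    f_b = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32)])
    self.assertAllEqual(f_a(1.0), 6.0)
    # f_b's variable is untouched by the call to f_a.
    self.assertAllEqual(f_b(1.0), 6.0)
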
  def testPrune(self):

    x_in = []
    x_out = []

    def f(x, y):
      x_in.append(x)
      xx = x * x
      x_out.append(xx)
      return xx, 2 * y * y

    f_wrapped = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32)] * 2)

    f_pruned = f_wrapped.prune(x_in[0], [x_out[0]])
    self.assertAllEqual(f_pruned(ops.convert_to_tensor(2.0)), [4.0])

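  def testPruneByName_sketch(self):
    # Illustrative sketch, not part of the original suite: prune() also
    # resolves string names against the wrapped graph (compare the string
    # fetch in test_operation_returned below), so a named tensor can be
    # fetched without first calling get_tensor_by_name.
    def f(x):
      return array_ops.identity(x * x, name='out')

    f_wrapped = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtypes.float32)])
    f_pruned = f_wrapped.prune(feeds=f_wrapped.inputs, fetches='out:0')
    self.assertAllEqual(f_pruned(constant_op.constant(3.0)), 9.0)
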
  def testPruneRagged(self):

    x_in = []
    x_out = []

    def f(x, y):
      x_in.append(x)
      xx = x * x
      x_out.append(xx)
      return xx, y * y

    x_spec = ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32)
    y_spec = tensor_spec.TensorSpec((), dtypes.float32)

    f_wrapped = wrap_function.wrap_function(f, [x_spec, y_spec])

    f_pruned = f_wrapped.prune(x_in[0], x_out[0])
    rt = ragged_factory_ops.constant([[1.0, 2.0], [3.0]])
    expected = ragged_factory_ops.constant_value([[1.0, 4.0], [9.0]])

    # Note: when we call f_pruned, we must pass the RaggedTensor in using
    # its components, since that's the current convention for how concrete
    # functions handle structured inputs.
    self.assertAllEqual(f_pruned(rt.values, rt.row_splits), expected)

  def _assert_single_captured_variable_argument(self, graph_def):
    # The single FunctionDef should have one argument, a captured variable
    function_def, = graph_def.library.function
    self.assertLen(function_def.signature.input_arg, 1)
    function_arg, = function_def.signature.input_arg
    self.assertEqual(dtypes.resource, dtypes.as_dtype(function_arg.type))

  def testVariableLifting(self):
    save_prefix = os.path.join(self.get_temp_dir(), 'meta_graph_test')

    export_graph = ops.Graph()
    with export_graph.as_default():
      v = variables.Variable(1.)
      array_ops.identity(v + 1., name='output')
      saver = saver_lib.Saver([v])
      with self.test_session() as session:
        session.run(v.initializer)
        saver.save(session, save_prefix)

    def importer():
      saver_lib.import_meta_graph(save_prefix + '.meta')
      return ops.get_default_graph().as_graph_element('output:0')

    wrapped = wrap_function.wrap_function(importer, [])
    lifted_variables = list(wrapped.graph.variables)
    self.assertLen(lifted_variables, 1)
    initializer = wrapped.prune(
        [], wrapped.graph.as_graph_element(v.initializer.name))
    self.assertEqual(lifted_variables, list(initializer.graph.variables))
    self.assertEqual(initializer.graph.external_captures,
                     wrapped.graph.external_captures)

    @def_function.function
    def wraps_initializer():
      initializer()

    wraps_initializer()
    self.assertEqual(1., lifted_variables[0].numpy())
    wrapped_initializer_graphdef = (
        wraps_initializer.get_concrete_function().graph.as_graph_def())
    self._assert_single_captured_variable_argument(wrapped_initializer_graphdef)

    @def_function.function
    def wraps_wrapped():
      return wrapped()

    # Verify that the original graph also has the correct signature.
    wrapped_wrapped_graphdef = (
        wraps_wrapped.get_concrete_function().graph.as_graph_def())
    self._assert_single_captured_variable_argument(wrapped_wrapped_graphdef)
    # Now check that the graph runs wrapped, from eager, and when pruned.
    self.assertAllEqual(wraps_wrapped().numpy(),
                        lifted_variables[0].numpy() + 1.)
    self.assertAllEqual(wrapped().numpy(), lifted_variables[0].numpy() + 1.)
    pruned = wrapped.prune([], wrapped.graph.as_graph_element('output:0'))
    self.assertAllEqual(wrapped().numpy(), pruned().numpy())

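  def testLiftedVariableIsEager_sketch(self):
    # Illustrative sketch, not part of the original suite: variables created
    # while tracing are lifted out of the wrapped graph into eager variables,
    # reachable through the wrapped graph's `variables` (as testVariableLifting
    # above uses for imported meta graphs).
    def f():
      v = variables.Variable(2.)
      return v * 3.

    wrapped = wrap_function.wrap_function(f, [])
    lifted, = wrapped.graph.variables
    self.assertEqual(2., lifted.numpy())
    self.assertAllEqual(6., wrapped())
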
  def testNoArguments(self):

    def f():
      return constant_op.constant(1.)

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertAllEqual(1.0, f_wrapped())

  def testPruneCaptures(self):

    v1 = variables.Variable(2.)

    def f():
      v2 = variables.Variable(3.)
      return array_ops.identity(v1 * v2 * constant_op.constant(1.), 'fetch')

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertAllEqual(6.0, f_wrapped())

    # Test pruning directly on the inputs
    pruned = f_wrapped.prune(
        feeds=f_wrapped.inputs,
        fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))
    self.assertAllEqual(6.0, pruned())

    # Test pruning with no inputs
    pruned = f_wrapped.prune(
        feeds=(),
        fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))
    self.assertAllEqual(6.0, pruned())

  def testCollectionsIsolation(self):

    v1 = variables.Variable(2.)
    v2_holder = []
    def f():
      v2 = variables.Variable(3.)
      v2_holder.append(v2)
      ops.add_to_collection(ops.GraphKeys.LOSSES, v2 * constant_op.constant(3.))
      return array_ops.identity(v1 * v2 * constant_op.constant(1.), 'fetch')

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertAllEqual(6.0, f_wrapped())
    self.assertEqual(
        len(f_wrapped.graph.get_collection(ops.GraphKeys.LOSSES)), 1)
    f_var_collection = f_wrapped.graph.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES)
    self.assertEqual(len(f_var_collection), 1)
    self.assertIs(f_var_collection[0], v2_holder[0])

    v3_holder = []
    def g():
      v3 = variables.Variable(4.)
      v3_holder.append(v3)
      ops.add_to_collection(ops.GraphKeys.LOSSES, v3 * constant_op.constant(3.))
      return array_ops.identity(v1 * v3 * constant_op.constant(1.), 'fetch')

    g_wrapped = wrap_function.wrap_function(g, [])
    self.assertAllEqual(8.0, g_wrapped())
    self.assertEqual(
        len(g_wrapped.graph.get_collection(ops.GraphKeys.LOSSES)), 1)
    g_var_collection = g_wrapped.graph.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES)
    self.assertEqual(len(g_var_collection), 1)
    self.assertIs(g_var_collection[0], v3_holder[0])

    # Each collection holds a single value, and those values are distinct
    # tensors, so nothing is shared between the two graphs.
    self.assertIsNot(g_wrapped.graph.get_collection(ops.GraphKeys.LOSSES)[0],
                     f_wrapped.graph.get_collection(ops.GraphKeys.LOSSES)[0])

  def testGradientsOfPrune(self):

    v1 = variables.Variable(2.)
    v2_holder = []

    def f(z):
      v2 = variables.Variable(3.)
      v2_holder.append(v2)
      return array_ops.identity(v1 * v2 * z, 'fetch')

    f_wrapped = wrap_function.wrap_function(
        f, [tensor_spec.TensorSpec((), dtype=dtypes.float32)])

    x = constant_op.constant(1.)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      out = f_wrapped(x)
    grads = tape.gradient(out, [x, v1, v2_holder[0]])

    self.assertAllEqual(6.0, out)
    self.assertAllEqual([6.0, 3.0, 2.0], grads)

    pruned = f_wrapped.prune(
        feeds=f_wrapped.inputs,
        fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))

    x = constant_op.constant(1.)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      out = pruned(x)
    grads = tape.gradient(out, [x, v1, v2_holder[0]])

    self.assertAllEqual(6.0, out)
    self.assertAllEqual([6.0, 3.0, 2.0], grads)

  def testPruneOperations(self):

    v = variables.Variable(0)

    def f():
      v.assign_add(1, name='increment', read_value=False)

    f_wrapped = wrap_function.wrap_function(f, [])
    pruned = f_wrapped.prune(
        feeds=(),
        fetches=(f_wrapped.graph.get_operation_by_name('increment'),))
    self.assertEqual((None,), pruned())
    self.assertEqual(1, self.evaluate(v))

    del f, f_wrapped

    def f1():
      v.assign_add(
          array_ops.placeholder(shape=[], dtype=dtypes.int32, name='step'),
          name='increment', read_value=False)
      return constant_op.constant(1, name='other')

    f_wrapped = wrap_function.wrap_function(f1, [])
    increments = f_wrapped.prune(
        feeds=(f_wrapped.graph.get_tensor_by_name('step:0')),
        fetches=(f_wrapped.graph.get_operation_by_name('increment'),
                 f_wrapped.graph.get_tensor_by_name('other:0')))
    first_output, second_output = increments(constant_op.constant(2))
    self.assertEqual(['step:0', 'increment/resource:0'],
                     [t.name for t in increments.inputs])
    self.assertIs(None, first_output)
    self.assertEqual(1, second_output.numpy())
    self.assertEqual(3, v.numpy())
    does_not_increment = f_wrapped.prune(
        feeds=(f_wrapped.graph.get_tensor_by_name('step:0')),
        fetches=f_wrapped.graph.get_tensor_by_name('other:0'))
    self.assertEqual(1, does_not_increment(constant_op.constant(3)).numpy())
    self.assertEqual(3, v.numpy())

  def testPruneStatefulOpsFromWrappedFunc(self):

    v0 = variables.Variable(0)
    v1 = variables.Variable(0)

    # When we wrap a function, we expect it to be executed with `tf.Graph`
    # rules: ops that are not in the transitive fanin of the fetches may be
    # pruned away.
    def f(x):
      v0.assign_add(1, name='increment_v0')
      v1.assign_add(1, name='increment_v1')
      return x

    f_wrapped = wrap_function.wrap_function(f, [1])

    self.assertEqual(1, f_wrapped().numpy())
    self.assertEqual(0, v0.numpy())
    self.assertEqual(0, v1.numpy())

    f_wrapped_with_name = wrap_function.wrap_function(f, [2], name='func')

    self.assertEqual(2, f_wrapped_with_name().numpy())
    self.assertEqual(0, v0.numpy())
    self.assertEqual(0, v1.numpy())

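  def testKeepStatefulOpsViaPrune_sketch(self):
    # Illustrative sketch, not part of the original suite: to keep an assign
    # op that graph-style pruning would otherwise drop, fetch it explicitly
    # with prune(), as testPruneOperations above does.
    v = variables.Variable(0)

    def f(x):
      v.assign_add(1, name='increment_v')
      return x

    f_wrapped = wrap_function.wrap_function(f, [1])
    with_side_effect = f_wrapped.prune(
        feeds=(),
        fetches=(f_wrapped.graph.get_operation_by_name('increment_v'),
                 f_wrapped.outputs[0]))
    # Op fetches come back as None; the assign now actually runs.
    _, out = with_side_effect()
    self.assertEqual(1, out.numpy())
    self.assertEqual(1, v.numpy())
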
  def test_operation_returned(self):

    v = variables.Variable(0)

    def f():
      v.assign(1, read_value=False, name='assign_to_v')

    f_wrapped = wrap_function.wrap_function(f, [])
    operation_to_fetch = f_wrapped.graph.get_operation_by_name('assign_to_v')
    f_pruned = f_wrapped.prune(
        [], operation_to_fetch)
    self.assertEqual(
        ['assign_to_v'],
        [operation.name for operation in f_pruned.graph.control_outputs])
    self.assertEqual(0, v.numpy())
    f_pruned()
    self.assertEqual(1, v.numpy())
    f_wrapped.prune([], 'assign_to_v')()
    f_wrapped.prune([], meta_graph_pb2.TensorInfo(name='assign_to_v'))()

  def test_function_from_graph_def(self):
    @def_function.function
    def make_graph_def(x):
      return x + 1.

    original_func_graph = make_graph_def.get_concrete_function(
        tensor_spec.TensorSpec([None, 2], dtypes.float32)).graph
    graph_def = original_func_graph.as_graph_def()
    revived_function = wrap_function.function_from_graph_def(
        graph_def, inputs=original_func_graph.inputs[0].name,
        outputs=original_func_graph.outputs[0].name)
    self.assertEqual(2., revived_function(constant_op.constant(1.)).numpy())

  def test_create_variables_with_same_name(self):
    def f():
      v1 = variables.Variable(0, name='v')
      v2 = variables.Variable(1, name='v')
      return v1, v2

    f_wrapped = wrap_function.wrap_function(f, [])
    self.assertDictEqual(
        {'v:0': 0, 'v_1:0': 1},  # assert that variable names are uniquified
        {v.name: v.numpy()
         for v in f_wrapped._variable_holder.variables.values()})

    # Uniquification should reset in separate calls to wrap_function.
    def f2():
      v1 = variables.Variable(3, name='v')
      v2 = variables.Variable(4, name='v')
      return v1, v2

    f_wrapped_2 = wrap_function.wrap_function(f2, [])
    self.assertDictEqual(
        {'v:0': 3, 'v_1:0': 4},
        {v.name: v.numpy()
         for v in f_wrapped_2._variable_holder.variables.values()})


class WrappedGraphTest(test.TestCase):

  def testAddFunction(self):

    def fn(x):
      v = variables.Variable(3, name='v')
      v2 = variable_scope.get_variable(
          'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
      return v + v2 + x

    with self.cached_session() as sess:
      result = fn(constant_op.constant(5))
      sess.run(variables.global_variables_initializer())
      expected = sess.run(result)

    g = wrap_function.WrappedGraph()
    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    wrapped_fn = g.wrap_function(fn, signature)
    self.assertEqual(expected, wrapped_fn(constant_op.constant(5)).numpy())

  def testCollections(self):

    def fn(x):
      v = variables.VariableV1(3, name='v', trainable=False, collections=['a'])
      v2 = variable_scope.get_variable(
          'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32,
          collections=['a', 'b'])
      return v + v2 + x

    def assert_collections(graph):
      self.assertLen(graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES), 1)
      self.assertLen(graph.get_collection('a'), 2)
      self.assertLen(graph.get_collection('b'), 1)

    g = wrap_function.WrappedGraph()
    g.wrap_function(fn, [tensor_spec.TensorSpec([], dtypes.int32)])
    assert_collections(g.graph)

    def assert_fn():
      assert_collections(ops.get_default_graph())
      return 1  # Return is required

    # Assert that collections are accessible within a wrapped function.
    g.wrap_function(assert_fn, [])

  def testShareVariablesSameGraph(self):

    def add_v1(x):
      with variable_scope.variable_scope(
          'reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(3), shape=[], dtype=dtypes.int32)
      return v + x

    def subtract_v1(x):
      with variable_scope.variable_scope(
          'reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
      return v - x

    def different_variable_fn_v1(x):
      with variable_scope.variable_scope(
          'no_reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(5), shape=[], dtype=dtypes.int32)
      return v * x

    def increment_variable_v1(x):
      with variable_scope.variable_scope(
          'reuse', reuse=variable_scope.AUTO_REUSE):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.Constant(6), shape=[], dtype=dtypes.int32)
      return v.assign_add(x)

    g = wrap_function.WrappedGraph()
    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    add = g.wrap_function(add_v1, signature)
    subtract = g.wrap_function(subtract_v1, signature)
    different_variable_fn = g.wrap_function(different_variable_fn_v1, signature)
    increment_variable = g.wrap_function(increment_variable_v1, signature)

    self.assertEqual(10, add(constant_op.constant(7)).numpy())
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    # The shared variable has a starting value of 3 because add_v1 was wrapped
    # first.
    self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
    self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())

    # Check that the update from increment_variable is visible to the other
    # functions sharing the variable.
    self.assertEqual(17, add(constant_op.constant(7)).numpy())
    self.assertEqual(3, subtract(constant_op.constant(7)).numpy())

    # Sanity check - result from this function shouldn't change.
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    self.assertAllEqual({'reuse/v', 'no_reuse/v'}, set(g.variables.keys()))

  def testShareVariablesDifferentGraphs(self):

    def add_v1(x):
      v = variables.Variable(3, name='v')
      return v + x

    def subtract_v1(x):
      v = variables.Variable(4, name='v')
      return v - x

    def different_variable_fn_v1(x):
      with ops.name_scope('different_scope'):
        v = variables.Variable(5, name='v')
      return v * x

    def increment_variable_v1(x):
      v = variables.Variable(6, name='v')
      return v.assign_add(x)

    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    vh = wrap_function.VariableHolder(share_variables=True)
    new_graph = lambda: wrap_function.WrappedGraph(variable_holder=vh)

    add = new_graph().wrap_function(add_v1, signature)
    subtract = new_graph().wrap_function(subtract_v1, signature)
    different_variable_fn = new_graph().wrap_function(
        different_variable_fn_v1, signature)
    increment_variable = new_graph().wrap_function(
        increment_variable_v1, signature)

    self.assertEqual(10, add(constant_op.constant(7)).numpy())
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    # Because the variable in add_v1 was created first, its starting value is 3
    # instead of the values defined in subtract_v1 or increment_variable_v1.
    self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
    self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())

    # Check that the update from increment_variable is visible to the other
    # functions sharing the variable.
    self.assertEqual(17, add(constant_op.constant(7)).numpy())
    self.assertEqual(3, subtract(constant_op.constant(7)).numpy())

    # Sanity check - result from this function shouldn't change.
    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())

    self.assertAllEqual({'v', 'different_scope/v'}, set(vh.variables.keys()))

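  def testSharedVariableVisibleThroughHolder_sketch(self):
    # Illustrative sketch, not part of the original suite: the VariableHolder
    # exposes the shared variables by name (as the assertion above uses), so
    # an eager assignment through the holder should be visible to a wrapped
    # function that captures the same variable.
    def read_v1(x):
      v = variables.Variable(3, name='v')
      return v + x

    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    vh = wrap_function.VariableHolder(share_variables=True)
    read = wrap_function.WrappedGraph(variable_holder=vh).wrap_function(
        read_v1, signature)
    self.assertEqual(10, read(constant_op.constant(7)).numpy())
    vh.variables['v'].assign(100)
    self.assertEqual(107, read(constant_op.constant(7)).numpy())
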
  @test_util.run_in_graph_and_eager_modes
  def testImportedFunctionsRegistered(self):
    if test_util.is_gpu_available():
      self.skipTest('not a GPU test')
    with ops.Graph().as_default() as graph:
      x = array_ops.placeholder(dtypes.variant, shape=[], name='foo')
      ds = dataset_ops.from_variant(x, structure=(
          tensor_spec.TensorSpec([], dtypes.int32)))
      y = ds.reduce(array_ops.zeros([], dtype=dtypes.int32), lambda p, q: p + q)

    graph_def = graph.as_graph_def()

    def fn_to_wrap(a):
      returned_elements = graph_def_importer.import_graph_def(
          graph_def, input_map={x.name: a}, return_elements=[y.name])
      return returned_elements[0]

    wrapped_fn = wrap_function.wrap_function(
        fn_to_wrap, [tensor_spec.TensorSpec((), dtypes.variant)])
    ds = dataset_ops.Dataset.from_tensor_slices([10, 20])
    v = dataset_ops.to_variant(ds)
    self.evaluate(wrapped_fn(v))

  def testReturnOp(self):

    def update_var_v1(x):
      v = variables.Variable(3, name='v')
      update_op = state_ops.assign(v, x).op
      return update_op

    g = wrap_function.WrappedGraph()
    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
    update_var = g.wrap_function(update_var_v1, signature)

    self.assertEqual(g.variables['v'].numpy(), 3)
    update_var(constant_op.constant(12))
    self.assertEqual(g.variables['v'].numpy(), 12)


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()