xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/convert_to_constants_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for convert_to_constants.py."""
16
17import os
18import re
19
20import numpy as np
21
22from google.protobuf import text_format
23from tensorflow.core.framework import attr_value_pb2
24from tensorflow.core.framework import function_pb2
25from tensorflow.core.framework import graph_pb2
26from tensorflow.core.framework import node_def_pb2
27from tensorflow.core.framework import op_def_pb2
28from tensorflow.core.protobuf import config_pb2
29from tensorflow.core.protobuf import meta_graph_pb2
30from tensorflow.core.protobuf import saved_model_pb2
31from tensorflow.python.client import session as session_lib
32from tensorflow.python.eager import def_function
33from tensorflow.python.framework import constant_op
34from tensorflow.python.framework import convert_to_constants
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import function
37from tensorflow.python.framework import importer
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import tensor_spec
40from tensorflow.python.framework import test_util
41from tensorflow.python.grappler import tf_optimizer
42from tensorflow.python.lib.io import file_io
43from tensorflow.python.ops import array_ops
44from tensorflow.python.ops import cond_v2
45from tensorflow.python.ops import control_flow_ops
46from tensorflow.python.ops import control_flow_v2_toggles
47from tensorflow.python.ops import gen_math_ops
48from tensorflow.python.ops import math_ops
49from tensorflow.python.ops import rnn
50from tensorflow.python.ops import rnn_cell_impl
51from tensorflow.python.ops import variable_scope
52from tensorflow.python.ops import variables
53from tensorflow.python.ops import while_v2
54from tensorflow.python.platform import test
55from tensorflow.python.saved_model import constants
56from tensorflow.python.saved_model import loader_impl
57from tensorflow.python.saved_model import simple_save
58from tensorflow.python.saved_model.load import load
59from tensorflow.python.saved_model.save import save
60from tensorflow.python.trackable import autotrackable
61from tensorflow.python.training.saver import export_meta_graph
62from tensorflow.python.util import compat
63from tensorflow.python.util import nest
64
65
66class _GraphMerger(object):
67  """GraphDef merging methods for testing purposes."""
68
69  @staticmethod
70  def merge_any(x1, x2, empty_fn):
71    """Merges two values using the message's CopyFrom/MergeFrom methods."""
72    merged = empty_fn()
73    merged.CopyFrom(x1)
74    merged.MergeFrom(x2)
75    return merged
76
77  @staticmethod
78  def merge_nodes(node1, node2):
79    """Merges two NodeDef messages."""
80    merged = _GraphMerger.merge_any(node1, node2, node_def_pb2.NodeDef)
81    merged_inputs = node1.input[:]
82    merged_inputs.extend([i for i in node2.input[:] if i not in merged_inputs])
83    merged.input[:] = merged_inputs
84    return merged
85
86  @staticmethod
87  def merge_lists(repeated1, repeated2, empty_fn, key_fn, merge_fn):
88    """Merges two lists representing maps."""
89    merged = {}
90    xs1 = {key_fn(x): x for x in repeated1}
91    xs2 = {key_fn(x): x for x in repeated2}
92    for name in set().union(xs1.keys(), xs2.keys()):
93      x1 = empty_fn() if name not in xs1 else xs1[name]
94      x2 = empty_fn() if name not in xs2 else xs2[name]
95      merged[name] = merge_fn(x1, x2)
96    return sorted(merged.values(), key=key_fn)
97
98  @staticmethod
99  def merge_node_lists(repeated_nodes1, repeated_nodes2):
100    """Merges two repeated node fields."""
101    return _GraphMerger.merge_lists(repeated_nodes1, repeated_nodes2,
102                                    node_def_pb2.NodeDef, lambda n: n.name,
103                                    _GraphMerger.merge_nodes)
104
105  @staticmethod
106  def merge_functions(fn1, fn2):
107    """Merges two FunctionDefs."""
108    merged = _GraphMerger.merge_any(fn1, fn2, function_pb2.FunctionDef)
109
110    del merged.signature.input_arg[:]
111    merged.signature.input_arg.extend(
112        _GraphMerger.merge_lists(
113            fn1.signature.input_arg[:], fn2.signature.input_arg[:],
114            op_def_pb2.OpDef.ArgDef, lambda a: a.name,
115            lambda x, y: _GraphMerger.merge_any(x, y, op_def_pb2.OpDef.ArgDef)))
116
117    del merged.signature.output_arg[:]
118    merged.signature.output_arg.extend(
119        _GraphMerger.merge_lists(
120            fn1.signature.output_arg[:], fn2.signature.output_arg[:],
121            op_def_pb2.OpDef.ArgDef, lambda a: a.name,
122            lambda x, y: _GraphMerger.merge_any(x, y, op_def_pb2.OpDef.ArgDef)))
123
124    del merged.node_def[:]
125    merged.node_def.extend(
126        _GraphMerger.merge_node_lists(fn1.node_def[:], fn2.node_def[:]))
127
128    return merged
129
130  @staticmethod
131  def merge_graphs(graph1, graph2):
132    """Merges two GraphDef messages."""
133    merged = graph_pb2.GraphDef()
134    merged.node.extend(
135        _GraphMerger.merge_node_lists(graph1.node[:], graph2.node[:]))
136
137    merged.library.function.extend(
138        _GraphMerger.merge_lists(graph1.library.function,
139                                 graph2.library.function,
140                                 function_pb2.FunctionDef,
141                                 lambda f: f.signature.name,
142                                 _GraphMerger.merge_functions))
143
144    return merged
145
146
def has_stateful_partitioned_call_op(graph_def):
  """Checks for a StatefulPartitionedCall node at the top level of a graph.

  Only `graph_def.node` is inspected; function bodies stored in
  `graph_def.library` are not searched.
  """
  return any(node.op == "StatefulPartitionedCall" for node in graph_def.node)
153
154
def get_num_variables(graph_def):
  """Counts the ReadVariableOp nodes at the top level of `graph_def`.

  Nodes inside function bodies (`graph_def.library`) are not counted.
  """
  reads = [node for node in graph_def.node if node.op == "ReadVariableOp"]
  return len(reads)
158
159
class VariablesToConstantsTest(test.TestCase):
  """Tests for `convert_variables_to_constants_v2`."""

  def _freezeModel(self, func):
    """Freezes the function.

    Args:
      func: Function. `get_concrete_function()` is called on it without
        arguments, so it presumably needs an `input_signature`.

    Returns:
      root: AutoTrackable object with original ConcreteFunction.
      output_func: frozen ConcreteFunction.
    """
    root = autotrackable.AutoTrackable()
    root.f = func
    input_func = root.f.get_concrete_function()

    # Control-flow ops are deliberately not lowered (lower_control_flow=False).
    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func, lower_control_flow=False)
    return root, output_func

  def _testConvertedFunction(self, obj, func, converted_concrete_func,
                             input_data):
    """Checks a frozen function against the original, directly and reloaded.

    Args:
      obj: Unused; kept so existing call sites keep working.
      func: The original (unfrozen) callable.
      converted_concrete_func: ConcreteFunction produced by the conversion.
      input_data: Dict mapping input names (without the ":0" suffix) to
        tensors; passed to every function as keyword arguments.
    """
    # Ensure the converted graph has no variables and no function calls.
    constant_graph_def = converted_concrete_func.graph.as_graph_def()
    self.assertEqual(0, get_num_variables(constant_graph_def))
    self.assertFalse(has_stateful_partitioned_call_op(constant_graph_def))

    # Check that the converted ConcreteFunction produces the same result as the
    # original Function.
    expected_value = nest.flatten(func(**input_data))
    actual_value = nest.flatten(converted_concrete_func(**input_data))
    # zip() silently truncates, so check the number of outputs explicitly.
    self.assertEqual(len(expected_value), len(actual_value))

    for expected, actual in zip(expected_value, actual_value):
      np.testing.assert_almost_equal(expected.numpy(), actual.numpy())

    # Ensure the shape is retained.
    for tensor in converted_concrete_func.inputs:
      actual_shape = input_data[tensor.name.split(":")[0]].shape
      self.assertEqual(tensor.shape, actual_shape)

    # Save the converted ConcreteFunction as a signature.
    save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model")
    root = autotrackable.AutoTrackable()
    root.f = converted_concrete_func
    save(root, save_dir, {"mykey": converted_concrete_func})

    # Load it back and make sure it works.
    loaded_obj = load(save_dir)
    actual_value = nest.flatten(loaded_obj.signatures["mykey"](**input_data))
    self.assertEqual(len(expected_value), len(actual_value))
    for expected, actual in zip(expected_value, actual_value):
      np.testing.assert_almost_equal(expected.numpy(), actual.numpy())

  @test_util.run_v2_only
  def testConstSavedModel(self):
    """Test a basic model with constants while saving/loading the SavedModel."""
    input_data = {"x": constant_op.constant(1., shape=[1])}
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(lambda x: 2. * x)
    to_save = root.f.get_concrete_function(input_data["x"])

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, save_dir, to_save)
    saved_model = load(save_dir)
    input_func = saved_model.signatures["serving_default"]

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(0, get_num_variables(variable_graph_def))
    self.assertTrue(variable_graph_def.library.function)

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testVariableModel(self):
    """Test a basic model with Variables."""
    input_data = {"x": constant_op.constant(1., shape=[1])}
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    input_func = root.f.get_concrete_function(input_data["x"])

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(2, get_num_variables(variable_graph_def))

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testScalarModel(self):
    """Test a basic model with Variables and a scalar (rank-0) input."""
    input_data = {"x": constant_op.constant(1., shape=[])}
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    input_func = root.f.get_concrete_function(input_data["x"])

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(2, get_num_variables(variable_graph_def))

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testVariableSavedModel(self):
    """Test a basic model with Variables with saving/loading the SavedModel."""
    input_data = {"x": constant_op.constant(1., shape=[1])}
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    to_save = root.f.get_concrete_function(input_data["x"])

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, save_dir, to_save)
    saved_model = load(save_dir)
    input_func = saved_model.signatures["serving_default"]

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertTrue(has_stateful_partitioned_call_op(variable_graph_def))

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testMultiFunctionModel(self):
    """Test a basic model with multiple tf.functions."""

    class BasicModel(autotrackable.AutoTrackable):

      def __init__(self):
        self.y = None
        self.z = None

      @def_function.function
      def add(self, x):
        if self.y is None:
          self.y = variables.Variable(2.)
        return x + self.y

      @def_function.function
      def sub(self, x):
        if self.z is None:
          self.z = variables.Variable(3.)
        return x - self.z

    input_data = {"x": constant_op.constant(1., shape=[1])}
    root = BasicModel()
    input_func = root.add.get_concrete_function(input_data["x"])

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(1, get_num_variables(variable_graph_def))

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    self._testConvertedFunction(root, root.add, output_func, input_data)

  def _singleMetaGraphSavedModel(self):
    """Exports a TF1-style SavedModel with `simple_save` and returns its path.

    The graph holds two RefVariables (`distractor` and `v`) and a resource
    local variable. `output` is `start * v * local_variable`; `distractor` is
    initialized and saved but does not feed `output`.
    """
    export_graph = ops.Graph()
    with export_graph.as_default():
      start = array_ops.placeholder(
          shape=[1, 1], dtype=dtypes.float32, name="start")
      distractor = variables.RefVariable(-1., name="distractor")
      v = variables.RefVariable(3., name="v")
      local_variable = variables.VariableV1(
          1.,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          trainable=False,
          use_resource=True)
      output = array_ops.identity(start * v * local_variable, name="output")
      with session_lib.Session() as session:
        session.run([v.initializer, distractor.initializer,
                     local_variable.initializer])
        # ops.uid() makes the export directory unique per call.
        path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
        simple_save.simple_save(
            session,
            path,
            inputs={"start": start},
            outputs={"output": output},
            legacy_init_op=local_variable.initializer)
    return path

  @test_util.run_v2_only
  def testRefVariableImport(self):
    """Test a model with 1.X ReferenceVariables."""
    input_data = {"start": constant_op.constant(1., shape=[1, 1])}

    saved = self._singleMetaGraphSavedModel()
    imported = load(saved)
    fn = imported.signatures["serving_default"]

    output_func = convert_to_constants.convert_variables_to_constants_v2(fn)
    root = autotrackable.AutoTrackable()
    self._testConvertedFunction(root, fn, output_func, input_data)

  @test_util.run_v2_only
  def testIf(self):
    """Test a model with the If op."""
    input_data = {
        "x": constant_op.constant([1., 2.], shape=[1, 2]),
        "b": constant_op.constant(True)
    }

    weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=dtypes.float32)

    def true_fn(x):
      return math_ops.matmul(x, weights)

    def false_fn(x):
      return math_ops.add(x, weights)

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[1, 2], dtype=dtypes.float32),
        tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)
    ])
    def model(x, b):
      return control_flow_ops.cond(
          b, true_fn=lambda: true_fn(x), false_fn=lambda: false_fn(x))

    root, output_func = self._freezeModel(model)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testStatelessIf(self):
    """Test a model with the StatelessIf op."""
    input_data = {"b": constant_op.constant(True)}

    x = constant_op.constant([1., 2.], shape=[1, 2], name="x")

    def true_fn():
      return x

    def false_fn():
      return x + 2

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)])
    def model(b):
      return cond_v2.cond_v2(b, true_fn, false_fn)

    root, output_func = self._freezeModel(model)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testStaticRnn(self):
    """Test a StaticRnn containing If ops."""
    input_data = {
        "x":
            constant_op.constant(
                np.array(np.random.random_sample((3, 10)), dtype=np.float32))
    }

    cell = rnn_cell_impl.LSTMCell(10)

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[3, 10], dtype=dtypes.float32)
    ])
    def model(x):
      seq = array_ops.split(x, 3, 0)
      return rnn.static_rnn(
          cell, seq, dtype=dtypes.float32, sequence_length=[1])

    root, output_func = self._freezeModel(model)

    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testWhile(self):
    """Test a While loop."""
    input_data = {"x": constant_op.constant([1., 2., 3., 4.], shape=[2, 2])}

    weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=dtypes.float32)

    def condition(x):
      return math_ops.reduce_sum(x) < 100

    def body(x):
      return math_ops.add(x, weights)

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[2, 2], dtype=dtypes.float32)
    ])
    def model(x):
      return control_flow_ops.while_loop(condition, body, [x])

    root, output_func = self._freezeModel(model)

    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testStatelessWhile(self):
    """Test a StatelessWhile loop."""
    input_data = {"x": constant_op.constant(2.)}

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
    ])
    def model(x):
      return while_v2.while_loop(
          lambda v: v < 4.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_1")  # x**2

    root, output_func = self._freezeModel(model)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  def testDynamicRnn(self):
    """Test a DynamicRnn containing While loops."""
    input_data = {
        "x":
            constant_op.constant(
                np.array(
                    np.random.random_sample((3, 10, 10)), dtype=np.float32))
    }

    cell = rnn_cell_impl.LSTMCell(10)

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[3, 10, 10], dtype=dtypes.float32)
    ])
    def model(x):
      return rnn.dynamic_rnn(cell, x, dtype=dtypes.float32)

    root, output_func = self._freezeModel(model)
    self._testConvertedFunction(root, root.f, output_func, input_data)

  @test_util.run_v2_only
  @test_util.disable_tfrt("b/180451239")
  def testSwitchCase(self):
    """Test a switch_case statement."""
    input_data = {
        "i": constant_op.constant(np.random.randint(0, 3, dtype=np.int32)),
        "x": constant_op.constant(
            np.asarray(np.random.random_sample((10, 3)), dtype=np.float32)),
    }

    w0 = variables.Variable(np.random.random_sample((3, 4)), dtype=np.float32)
    w1 = variables.Variable(np.random.random_sample((3, 4)), dtype=np.float32)
    w2 = variables.Variable(np.random.random_sample((4,)), dtype=np.float32)

    def branch0(x):
      return math_ops.matmul(x, w0)

    def branch1(x):
      return math_ops.matmul(x, w1)

    def branch2(x):
      # Pad x from (10, 3) to (10, 4) so it broadcasts against w2's shape (4,).
      x = array_ops.pad(x, [[0, 0], [0, 1]])
      return x + w2

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32),
        tensor_spec.TensorSpec(shape=[10, 3], dtype=dtypes.float32),
    ])
    def model(i, x):
      return control_flow_ops.switch_case(i, [
          lambda: branch0(x), lambda: branch1(x), lambda: branch2(x)])

    root, output_func = self._freezeModel(model)
    self._testConvertedFunction(root, root.f, output_func, input_data)
527
528
529class ConvertVariablesToConstantsV2SessionTest(test.TestCase):
530
531  def _freezeModel(self, func):
532    """Freezes the function.
533
534    Args:
535      func: Function.
536
537    Returns:
538      root: AutoTrackable object with original ConcreteFunction.
539      output_func: frozen ConcreteFunction.
540    """
541    root = autotrackable.AutoTrackable()
542    root.f = func
543    input_func = root.f.get_concrete_function()
544
545    output_func = convert_to_constants.convert_var_to_const_function_in_v1(
546        input_func, lower_control_flow=False)
547    return root, output_func
548
549  def _testConvertedFunction(self, sess, obj, func, converted_concrete_func,
550                             input_data):
551    # Ensure the converted graph has no variables and no function calls.
552    constant_graph_def = converted_concrete_func.graph.as_graph_def()
553    self.assertEqual(0, get_num_variables(constant_graph_def))
554    self.assertFalse(has_stateful_partitioned_call_op(constant_graph_def))
555
556    # Check that the converted ConcreteFunction produces the same result as the
557    # original Function.
558    expected_value = nest.flatten(func(**input_data))
559    actual_value = nest.flatten(converted_concrete_func(**input_data))
560
561    for expected, actual in zip(expected_value, actual_value):
562      np.testing.assert_almost_equal(sess.run(expected), sess.run(actual))
563
564    # Ensure the shape is retained.
565    for tensor in converted_concrete_func.inputs:
566      actual_shape = input_data[tensor.name.split(":")[0]].shape
567      self.assertEqual(tensor.shape, actual_shape)
568
569    # Save the converted ConcreteFunction as a signature.
570    save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model")
571    root = autotrackable.AutoTrackable()
572    root.f = converted_concrete_func
573    save(root, save_dir, {"mykey": converted_concrete_func})
574
575    # Load it back and make sure it works.
576    loaded_obj = load(save_dir)
577    actual_value = nest.flatten(loaded_obj.signatures["mykey"](**input_data))
578    for expected, actual in zip(expected_value, actual_value):
579      np.testing.assert_almost_equal(sess.run(expected), sess.run(actual))
580
581  def testRaiseErrorInEagerMode(self):
582    """Test the raised exception in Eager mode."""
583    input_data = {"x": constant_op.constant(1., shape=[1])}
584    root = autotrackable.AutoTrackable()
585    root.v1 = variables.Variable(3.)
586    root.v2 = variables.Variable(2.)
587    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
588    input_func = root.f.get_concrete_function(input_data["x"])
589
590    with self.assertRaisesRegex(RuntimeError,
591                                "must be carried out in a Session"):
592      convert_to_constants.convert_var_to_const_function_in_v1(
593          input_func)
594
595  def testConvertVariables(self):
596    """Test a basic model with Variables."""
597    with ops.Graph().as_default():
598      with session_lib.Session() as sess:
599        input_data = {"x": constant_op.constant(1., shape=[1])}
600        root = autotrackable.AutoTrackable()
601        root.v1 = variables.Variable(3.)
602        root.v2 = variables.Variable(2.)
603        root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
604        input_func = root.f.get_concrete_function(input_data["x"])
605
606        variable_graph_def = input_func.graph.as_graph_def()
607        self.assertEqual(2, get_num_variables(variable_graph_def))
608
609        output_func = convert_to_constants.convert_var_to_const_function_in_v1(
610            input_func)
611
612        self._testConvertedFunction(sess, root, root.f, output_func, input_data)
613
614  def testConvertVariablesWithAssignments(self):
615    """Test a basic model with Variables and assignment ops."""
616    with ops.Graph().as_default():
617      with session_lib.Session() as sess:
618        input_data = {"x": constant_op.constant(1., shape=[1])}
619        root = autotrackable.AutoTrackable()
620        root.v1 = variables.Variable(3.)
621        root.v2 = variables.Variable(2.)
622        root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
623        input_func = root.f.get_concrete_function(input_data["x"])
624
625        variable_graph_def = input_func.graph.as_graph_def()
626        self.assertEqual(2, get_num_variables(variable_graph_def))
627
628        assign_op_1 = root.v1.assign(1.5)
629        assign_op_2 = root.v2.assign(3.0)
630        assign_op_3 = root.v1.assign(4.0)
631        ops.get_default_graph().add_to_collection(
632            convert_to_constants.VAR_ASSIGN_COLLECTION, assign_op_1)
633        ops.get_default_graph().add_to_collection(
634            convert_to_constants.VAR_ASSIGN_COLLECTION, assign_op_2)
635        ops.get_default_graph().add_to_collection(
636            convert_to_constants.VAR_ASSIGN_COLLECTION, assign_op_3)
637
638        output_func = convert_to_constants.convert_var_to_const_function_in_v1(
639            input_func)
640        self._testConvertedFunction(sess, root, root.f, output_func, input_data)
641
642  def testConstSavedModel(self):
643    """Test a basic model with constants while saving/loading the SavedModel."""
644    with ops.Graph().as_default():
645      with session_lib.Session() as sess:
646        input_data = {"x": constant_op.constant(1., shape=[1])}
647        root = autotrackable.AutoTrackable()
648        root.f = def_function.function(lambda x: 2. * x)
649        to_save = root.f.get_concrete_function(input_data["x"])
650
651        save_dir = os.path.join(self.get_temp_dir(), "saved_model")
652        save(root, save_dir, to_save)
653        saved_model = load(save_dir)
654        input_func = saved_model.signatures["serving_default"]
655
656        variable_graph_def = input_func.graph.as_graph_def()
657        self.assertEqual(0, get_num_variables(variable_graph_def))
658        self.assertTrue(variable_graph_def.library.function)
659
660        output_func = convert_to_constants.convert_var_to_const_function_in_v1(
661            input_func)
662        self._testConvertedFunction(sess, root, root.f, output_func, input_data)
663
664  def testVariableSavedModel(self):
665    """Test a basic model with Variables with saving/loading the SavedModel."""
666    with ops.Graph().as_default():
667      with session_lib.Session() as sess:
668        input_data = {"x": constant_op.constant(1., shape=[1])}
669        root = autotrackable.AutoTrackable()
670        root.v1 = variables.Variable(3.)
671        root.v2 = variables.Variable(2.)
672        root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
673        to_save = root.f.get_concrete_function(input_data["x"])
674        sess.run(variables.global_variables_initializer())
675
676        save_dir = os.path.join(self.get_temp_dir(), "saved_model")
677        save(root, save_dir, to_save)
678        saved_model = load(save_dir)
679        input_func = saved_model.signatures["serving_default"]
680
681        variable_graph_def = input_func.graph.as_graph_def()
682        self.assertTrue(has_stateful_partitioned_call_op(variable_graph_def))
683
684        output_func = convert_to_constants.convert_var_to_const_function_in_v1(
685            input_func)
686        self._testConvertedFunction(sess, root, root.f, output_func, input_data)
687
688  def testMultiFunctionModel(self):
689    """Test a basic model with multiple tf.functions."""
690
691    class BasicModel(autotrackable.AutoTrackable):
692
693      def __init__(self):
694        self.y = None
695        self.z = None
696
697      @def_function.function
698      def add(self, x):
699        if self.y is None:
700          self.y = variables.Variable(2.)
701        return x + self.y
702
703      @def_function.function
704      def sub(self, x):
705        if self.z is None:
706          self.z = variables.Variable(3.)
707        return x - self.z
708
709    with ops.Graph().as_default():
710      with session_lib.Session() as sess:
711        input_data = {"x": constant_op.constant(1., shape=[1])}
712        root = BasicModel()
713        input_func = root.add.get_concrete_function(input_data["x"])
714
715        variable_graph_def = input_func.graph.as_graph_def()
716        self.assertEqual(1, get_num_variables(variable_graph_def))
717
718        output_func = convert_to_constants.convert_var_to_const_function_in_v1(
719            input_func)
720        self._testConvertedFunction(sess, root, root.add, output_func,
721                                    input_data)
722
723  def testIf(self):
724    """Test a model with the If op."""
725    with ops.Graph().as_default():
726      with session_lib.Session() as sess:
727        input_data = {
728            "x": constant_op.constant([1., 2.], shape=[1, 2]),
729            "b": constant_op.constant(True)
730        }
731
732        weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]],
733                                     dtype=dtypes.float32)
734
735        def true_fn(x):
736          return math_ops.matmul(x, weights)
737
738        def false_fn(x):
739          return math_ops.add(x, weights)
740
741        @def_function.function(input_signature=[
742            tensor_spec.TensorSpec(shape=[1, 2], dtype=dtypes.float32),
743            tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)
744        ])
745        def model(x, b):
746          return control_flow_ops.cond(
747              b, true_fn=lambda: true_fn(x), false_fn=lambda: false_fn(x))
748
749        root, output_func = self._freezeModel(model)
750        self._testConvertedFunction(sess, root, root.f, output_func, input_data)
751
752  def testStatelessIf(self):
753    """Test a model with the StatelessIf op."""
754    with ops.Graph().as_default():
755      with session_lib.Session() as sess:
756        input_data = {"b": constant_op.constant(True)}
757
758        x = constant_op.constant([1., 2.], shape=[1, 2], name="x")
759
760        def true_fn():
761          return x
762
763        def false_fn():
764          return x + 2
765
766        @def_function.function(input_signature=[
767            tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)
768        ])
769        def model(b):
770          return cond_v2.cond_v2(b, true_fn, false_fn)
771
772        root, output_func = self._freezeModel(model)
773        self._testConvertedFunction(sess, root, root.f, output_func, input_data)
774
  def testStaticRnn(self):
    """Test a StaticRnn containing If ops.

    Wraps an LSTM cell unrolled with `rnn.static_rnn` in a `tf.function`,
    freezes it with `self._freezeModel`, and passes the result to
    `self._testConvertedFunction` together with a random [3, 10] input.
    """
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {
            "x":
                constant_op.constant(
                    np.array(
                        np.random.random_sample((3, 10)), dtype=np.float32))
        }

        cell = rnn_cell_impl.LSTMCell(10)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[3, 10], dtype=dtypes.float32)
        ])
        def model(x):
          # Splitting the [3, 10] input into 3 pieces along axis 0 yields three
          # [1, 10] time steps, i.e. a single sequence; its length is given as
          # 1 through `sequence_length`.
          seq = array_ops.split(x, 3, 0)
          return rnn.static_rnn(
              cell, seq, dtype=dtypes.float32, sequence_length=[1])

        root, output_func = self._freezeModel(model)

        self._testConvertedFunction(sess, root, root.f, output_func, input_data)
799
  def testWhile(self):
    """Test a While loop.

    The loop body reads a variable (`weights`), so freezing must turn that
    captured variable into a constant inside the loop.
    """
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant([1., 2., 3., 4.], shape=[2, 2])}

        weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]],
                                     dtype=dtypes.float32)

        # Keep iterating while the sum of all elements is below 100.
        def condition(x):
          return math_ops.reduce_sum(x) < 100

        # Each iteration adds the (variable) weights element-wise.
        def body(x):
          return math_ops.add(x, weights)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[2, 2], dtype=dtypes.float32)
        ])
        def model(x):
          return control_flow_ops.while_loop(condition, body, [x])

        root, output_func = self._freezeModel(model)

        self._testConvertedFunction(sess, root, root.f, output_func, input_data)
824
  def testStatelessWhile(self):
    """Test a StatelessWhile loop built directly with while_v2."""
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {"x": constant_op.constant(2.)}

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
        ])
        def model(x):
          # Square `v` until it reaches 4. For the input 2.0 used here the body
          # runs once, so the result is 4.0.
          return while_v2.while_loop(
              lambda v: v < 4.,
              lambda v: v * v, [x],
              return_same_structure=False,
              name="while_1")  # x**2

        root, output_func = self._freezeModel(model)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)
843
844  def testDynamicRnn(self):
845    """Test a DynamicRnn containing While loops."""
846    with ops.Graph().as_default():
847      with session_lib.Session() as sess:
848        input_data = {
849            "x":
850                constant_op.constant(
851                    np.array(
852                        np.random.random_sample((3, 10, 10)), dtype=np.float32))
853        }
854
855        cell = rnn_cell_impl.LSTMCell(10)
856
857        @def_function.function(input_signature=[
858            tensor_spec.TensorSpec(shape=[3, 10, 10], dtype=dtypes.float32)
859        ])
860        def model(x):
861          return rnn.dynamic_rnn(cell, x, dtype=dtypes.float32)
862
863        root, output_func = self._freezeModel(model)
864        self._testConvertedFunction(sess, root, root.f, output_func, input_data)
865
  @test_util.disable_tfrt("b/180451239")
  def testSwitchCase(self):
    """Test a switch_case statement.

    A random branch index `i` in [0, 3) selects one of three branches, each of
    which reads a different variable; freezing must convert all of them.
    """
    with ops.Graph().as_default():
      with session_lib.Session() as sess:
        input_data = {
            "i":
                constant_op.constant(np.random.randint(0, 3, dtype=np.int32)),
            "x":
                constant_op.constant(
                    np.asarray(
                        np.random.random_sample((10, 3)), dtype=np.float32)),
        }

        w0 = variables.Variable(
            np.random.random_sample((3, 4)), dtype=np.float32)
        w1 = variables.Variable(
            np.random.random_sample((3, 4)), dtype=np.float32)
        w2 = variables.Variable(np.random.random_sample((4,)), dtype=np.float32)

        # [10, 3] x [3, 4] -> [10, 4].
        def branch0(x):
          return math_ops.matmul(x, w0)

        # [10, 3] x [3, 4] -> [10, 4].
        def branch1(x):
          return math_ops.matmul(x, w1)

        # Pad one zero column onto x ([10, 3] -> [10, 4]) so that adding the
        # broadcast [4] vector w2 yields the same [10, 4] shape as the other
        # branches.
        def branch2(x):
          x = array_ops.pad(x, [[0, 0], [0, 1]])
          return x + w2

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32),
            tensor_spec.TensorSpec(shape=[10, 3], dtype=dtypes.float32),
        ])
        def model(i, x):
          return control_flow_ops.switch_case(
              i, [lambda: branch0(x), lambda: branch1(x), lambda: branch2(x)])

        root, output_func = self._freezeModel(model)
        self._testConvertedFunction(sess, root, root.f, output_func, input_data)
906
907
908class ConvertVariablesToConstantsSessionTest(test.TestCase):
909
910  def _assertGraphContains(self, graph, subgraph):
911    """Asserts that the given subgraph is contained within the given graph."""
912
913    def normalize_uids(msg):
914      """Replace auto-id function names with something consistent."""
915      # These functions have non-deterministic names, the non-determinism coming
916      # from having an ops.uid() suffix in their names. We're replacing these
917      # with new sequential IDs starting from 0 for each prefix, which is
918      # is sufficient for tests.
919      if isinstance(msg, graph_pb2.GraphDef):
920        msg = text_format.MessageToString(msg)
921      name_prefixes = ["case_cond_true.*", "case_cond_false.*"]
922      name_regex = r"\b(" + "|".join(name_prefixes) + r")_([0-9]+)\b"
923      names = {}
924      for (name, index) in re.findall(name_regex, msg):
925        names.setdefault(name, set()).add(int(index))
926      for name, indices in names.items():
927        for new_index, old_index in enumerate(sorted(list(indices))):
928          msg = re.sub(r"\b" + name + "_" + str(old_index) + r"\b",
929                       name + "_" + str(new_index), msg)
930      return msg
931
932    norm_graph = text_format.Parse(normalize_uids(graph), graph_pb2.GraphDef())
933    norm_subgraph = text_format.Parse(
934        normalize_uids(subgraph), graph_pb2.GraphDef())
935
936    # Graph S is contained in C if and only if merge(C,S) == C.
937    # We merge the input graph with an empty graph to normalize repeated fields:
938    # assertProtoEquals is sensitive to ordering.
939    norm_graph = _GraphMerger.merge_graphs(norm_graph, graph_pb2.GraphDef())
940    merged_graph = _GraphMerger.merge_graphs(norm_graph, norm_subgraph)
941    self.assertProtoEquals(norm_graph, merged_graph)
942
943  def _ensure_no_variables_in_graph(self, graph_def):
944    """Ensures there are no variables in the graph."""
945    for node in graph_def.node:
946      self.assertNotIn(
947          node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
948
  def _test_variable_to_const_conversion(self, use_resource):
    """Freezes `variable_node * 2.0` and checks the frozen graph still runs.

    Args:
      use_resource: Forwarded to `variable_scope.variable_scope`; selects
        resource (True) or reference (False) variables.
    """
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=use_resource):
        variable_node = variable_scope.get_variable(
            "variable_node", initializer=1.0)
        # Not reachable from output_node, so it should not survive freezing.
        variable_scope.get_variable("unused_variable_node", initializer=1.0)
        output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
        with session_lib.Session() as sess:
          self.evaluate(variable_node.initializer)
          output = self.evaluate(output_node)
          self.assertNear(2.0, output, 0.00001)
          variable_graph_def = sess.graph.as_graph_def()
          constant_graph_def = (
              convert_to_constants
              .convert_variables_to_constants_from_session_graph(
                  session=sess,
                  graph_def=variable_graph_def,
                  output_node_names=["output_node"]))

          self._ensure_no_variables_in_graph(constant_graph_def)

    # Now we make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    with ops.Graph().as_default():
      _ = importer.import_graph_def(constant_graph_def, name="")
      # NOTE(review): presumably the frozen variable, its read op, the 2.0
      # constant and output_node; the unused variable must have been pruned.
      self.assertEqual(4, len(constant_graph_def.node))
      self._ensure_no_variables_in_graph(constant_graph_def)
      with session_lib.Session() as sess:
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = self.evaluate(output_node)
        self.assertNear(2.0, output, 0.00001)
980
981  def test_resource_variable_can_be_written_after_denylisting(self):
982    with ops.Graph().as_default():
983      with variable_scope.variable_scope("", use_resource=True):
984        variable_node = variable_scope.get_variable(
985            "variable_node", initializer=1.0)
986        another_variable = variable_scope.get_variable(
987            "unused_variable_node", initializer=2.0)
988        with ops.control_dependencies(
989            [variable_node.assign(another_variable + variable_node)]):
990          output_node = array_ops.identity(variable_node, name="output_node")
991        initializer_name = variable_node.initializer.name
992        with session_lib.Session() as sess:
993          self.evaluate(variable_node.initializer)
994          self.evaluate(another_variable.initializer)
995          output = self.evaluate(output_node)
996          self.assertNear(3.0, output, 0.00001)
997          variable_graph_def = sess.graph.as_graph_def()
998
999          # Test variable name black list. This should result in the variable
1000          # not being a const.  Furthermore, the paths that read from and assign
1001          # to the denylisted variable should continue to be valid.
1002          constant_graph_def_with_denylist = (
1003              convert_to_constants
1004              .convert_variables_to_constants_from_session_graph(
1005                  session=sess,
1006                  graph_def=variable_graph_def,
1007                  output_node_names=["output_node", initializer_name],
1008                  variable_names_denylist=set(["variable_node"])))
1009
1010          variable_node = None
1011          for node in constant_graph_def_with_denylist.node:
1012            if node.name == "variable_node":
1013              variable_node = node
1014          self.assertIsNotNone(variable_node)
1015          self.assertEqual(variable_node.op, "VarHandleOp")
1016
1017    # Now we make sure another_variable is now a constant, but the original
1018    # variable is not, and that the graph can be executed and update the
1019    # variable can be updated with each execution.
1020    with ops.Graph().as_default():
1021      _ = importer.import_graph_def(constant_graph_def_with_denylist, name="")
1022      with session_lib.Session() as sess:
1023        output_node = sess.graph.get_tensor_by_name("output_node:0")
1024        self.evaluate(sess.graph.get_operation_by_name(initializer_name))
1025        output = self.evaluate(output_node)
1026        self.assertNear(3.0, output, 0.00001)
1027        output = self.evaluate(output_node)
1028        self.assertNear(5.0, output, 0.00001)
1029
1030  def _inline_functions(self, graph_def, arrays):
1031    meta_graph = export_meta_graph(graph_def=graph_def)
1032    fetch_collection = meta_graph_pb2.CollectionDef()
1033    for name in arrays:
1034      fetch_collection.node_list.value.append(name)
1035    meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
1036
1037    # Initialize RewriterConfig with everything disabled except function
1038    # inlining.
1039    config = config_pb2.ConfigProto()
1040    rewrite_options = config.graph_options.rewrite_options
1041    rewrite_options.optimizers.append("function")
1042    return tf_optimizer.OptimizeGraph(config, meta_graph)
1043
  def _test_convert_variables_with_functions(self, inline_functions):
    """Freezes a graph with functions.

    Args:
      inline_functions: If True, inline the `Defun` call with Grappler before
        freezing; otherwise freeze the graph with the function call intact.
    """

    @function.Defun(dtypes.float32)
    def plus_one(x):
      return x + 1.0

    with ops.Graph().as_default():
      variable_node = variables.Variable(1.0, name="variable_node")
      _ = variables.Variable(1.0, name="unused_variable_node")
      defun_node = plus_one(variable_node)
      _ = math_ops.multiply(defun_node, 2.0, name="output_node")

      with session_lib.Session() as sess:
        # Only the variable feeding output_node is initialized.
        self.evaluate(variables.variables_initializer([variable_node]))
        variable_graph_def = sess.graph.as_graph_def()

        if inline_functions:
          # Run Grappler to create the VarHandleOp --> Placeholder -->
          # ResourceVariable pattern.
          variable_graph_def = self._inline_functions(
              variable_graph_def, ["variable_node", "output_node"])

        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                session=sess,
                graph_def=variable_graph_def,
                output_node_names=["output_node"]))

    self._ensure_no_variables_in_graph(constant_graph_def)
1075
1076  def testReferenceVariables(self):
1077    """Freezes a graph with reference variables."""
1078    self._test_variable_to_const_conversion(use_resource=False)
1079
1080  def testResourceVariables(self):
1081    """Freezes a graph with resource variables."""
1082    self._test_variable_to_const_conversion(use_resource=True)
1083
1084  def testWithFunctions(self):
1085    """Freezes a graph with functions."""
1086    self._test_convert_variables_with_functions(inline_functions=False)
1087
1088  def testWithInlinedFunctions(self):
1089    """Freezes a graph with functions that have been inlined using Grappler."""
1090    self._test_convert_variables_with_functions(inline_functions=True)
1091
1092  def testGraphWithSwitch(self):
1093    """Freezes a graph which contains a Switch with type RESOURCE_DT."""
1094    with ops.Graph().as_default():
1095      with variable_scope.variable_scope("", use_resource=True):
1096        x = variable_scope.get_variable("var_x", initializer=1.0)
1097        y = variable_scope.get_variable("var_y", initializer=2.0)
1098        f1 = lambda: variable_scope.get_variable("var_f1", initializer=17.0)
1099        f2 = lambda: variable_scope.get_variable("var_f2", initializer=23.0)
1100        cond_node = control_flow_ops.case([(gen_math_ops.less(x, y), f1)],
1101                                          default=f2)
1102        _ = math_ops.multiply(cond_node, 2.0, name="output_node")
1103
1104        with session_lib.Session() as sess:
1105          sess.run(variables.global_variables_initializer())
1106          variable_graph_def = sess.graph.as_graph_def()
1107
1108          constant_graph_def = (
1109              convert_to_constants
1110              .convert_variables_to_constants_from_session_graph(
1111                  session=sess,
1112                  graph_def=variable_graph_def,
1113                  output_node_names=["output_node"]))
1114
1115    self._ensure_no_variables_in_graph(constant_graph_def)
1116
  def testConvertSingleVariable(self):
    """Tests that a single variable is properly converted to a constant."""

    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=False):
        _ = variable_scope.get_variable("x", initializer=1.0)
      with session_lib.Session() as sess:
        sess.run(variables.global_variables_initializer())
        variable_graph_def = sess.graph.as_graph_def()
        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                sess, variable_graph_def, ["x/read"]))
        # The reference variable `x` becomes a Const holding its value 1.0,
        # while its `x/read` Identity is kept and now reads the Const.
        self._assertGraphContains(
            constant_graph_def, """
            node {
              name: "x" op: "Const"
              attr { key: "dtype" value { type: DT_FLOAT } }
              attr {
                key: "value"
                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
            }
            node {
              name: "x/read" op: "Identity" input: "x"
              attr { key: "T" value { type: DT_FLOAT } }
            }""")
1143
  def testConvertSingleResourceVariable(self):
    """Tests that a resource variable is properly converted to a constant."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=True):
        _ = variable_scope.get_variable("x", initializer=1.0)
      with session_lib.Session() as sess:
        sess.run(variables.global_variables_initializer())
        variable_graph_def = sess.graph.as_graph_def()
        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                sess, variable_graph_def, ["x/Read/ReadVariableOp"]))
        # The resource handle `x` becomes a Const, and the ReadVariableOp that
        # read it is rewritten into an Identity with the same name.
        self._assertGraphContains(
            constant_graph_def, """
            node {
              name: "x" op: "Const"
              attr { key: "dtype" value { type: DT_FLOAT } }
              attr {
                key: "value"
                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
            }
            node {
              name: "x/Read/ReadVariableOp" op: "Identity" input: "x"
              attr { key: "T" value { type: DT_FLOAT } }
            }""")
1169
  def testConvertOneVariableOfTwo(self):
    """Tests that one variable can be kept unconverted."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=False):
        x = variable_scope.get_variable("x", initializer=1.0)
        y = variable_scope.get_variable("y", initializer=1.0)
        _ = math_ops.multiply(x, y, name="out")
      with session_lib.Session() as sess:
        sess.run(variables.global_variables_initializer())
        variable_graph_def = sess.graph.as_graph_def()
        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                sess,
                variable_graph_def, ["out"],
                variable_names_denylist=["y"]))
        # `x` is frozen into a Const while the denylisted `y` stays a
        # VariableV2; both read Identities and the Mul are unchanged.
        self._assertGraphContains(
            constant_graph_def, """
            node {
              name: "x" op: "Const"
              attr { key: "dtype" value { type: DT_FLOAT } }
              attr {
                key: "value"
                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
            }
            node {
              name: "x/read" op: "Identity" input: "x"
              attr { key: "T" value { type: DT_FLOAT } }
            }
            node {
              name: "y" op: "VariableV2"
              attr { key: "dtype" value { type: DT_FLOAT } }
            }
            node {
              name: "y/read" op: "Identity" input: "y"
              attr { key: "T" value { type: DT_FLOAT } }
            }
            node {
              name: "out" op: "Mul" input: "x/read" input: "y/read"
              attr {key: "T" value {type: DT_FLOAT}}
            }""")
1211
  def testConvertOneResourceVariableOfTwo(self):
    """Tests that one resource variable can be kept unconverted."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=True):
        x = variable_scope.get_variable("x", initializer=1.0)
        y = variable_scope.get_variable("y", initializer=1.0)
        _ = math_ops.multiply(x, y, name="out")
      with session_lib.Session() as sess:
        sess.run(variables.global_variables_initializer())
        variable_graph_def = sess.graph.as_graph_def()
        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                sess,
                variable_graph_def, ["out"],
                variable_names_denylist=["y"]))
        # `x` becomes a Const and its read inside `out` becomes an Identity;
        # the denylisted `y` stays a VarHandleOp read by a ReadVariableOp.
        self._assertGraphContains(
            constant_graph_def, """
            node {
              name: "x" op: "Const"
              attr { key: "dtype" value { type: DT_FLOAT } }
              attr {
                key: "value"
                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
            }
            node {
              name: "y" op: "VarHandleOp"
              attr { key: "dtype" value { type: DT_FLOAT } }
            }
            node {
              name: "out/ReadVariableOp" op: "Identity" input: "x"
              attr { key: "T" value { type: DT_FLOAT } }
            }
            node {
              name: "out/ReadVariableOp_1" op: "ReadVariableOp" input: "y"
              attr { key: "dtype" value { type: DT_FLOAT } }
            }
            node {
              name: "out" op: "Mul"
              input: "out/ReadVariableOp" input: "out/ReadVariableOp_1"
              attr {key: "T" value {type: DT_FLOAT}}
            }""")
1254
  def testConvertIdentityChain(self):
    """Tests that a chain of Identity ops is converted properly."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=True):
        x = variable_scope.get_variable("x", initializer=1.0)
        y = array_ops.identity(x, name="y")
        _ = array_ops.identity(y, name="z")
      with session_lib.Session() as sess:
        sess.run(variables.global_variables_initializer())
        variable_graph_def = sess.graph.as_graph_def()
        constant_graph_def = (
            convert_to_constants
            .convert_variables_to_constants_from_session_graph(
                sess, variable_graph_def, ["z"]))
        # Only the ReadVariableOp created by identity(x) is rewritten into an
        # Identity on the frozen Const; the user Identity ops `y` and `z` keep
        # their names and inputs.
        self._assertGraphContains(
            constant_graph_def, """
            node {
              name: "x" op: "Const"
              attr { key: "dtype" value { type: DT_FLOAT } }
              attr {
                key: "value"
                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
            }
            node {
              name: "y/ReadVariableOp" op: "Identity" input: "x"
              attr { key: "T" value { type: DT_FLOAT } }
            }
            node {
              name: "y" op: "Identity" input: "y/ReadVariableOp"
              attr { key: "T" value { type: DT_FLOAT } }
            }
            node {
              name: "z" op: "Identity" input: "y"
              attr { key: "T" value { type: DT_FLOAT } }
            }""")
1290
1291  def testConvertCase(self):
1292    """Tests that a v1 case() construction converts properly."""
1293    with ops.Graph().as_default():
1294      with variable_scope.variable_scope("", use_resource=False):
1295        control_flow_v2_toggles.disable_control_flow_v2()
1296        x = variable_scope.get_variable("x", initializer=1.0)
1297        y = variable_scope.get_variable("y", initializer=2.0)
1298        _ = control_flow_ops.case([(gen_math_ops.less(x, y), lambda: x)],
1299                                  default=lambda: y)
1300      with session_lib.Session() as sess:
1301        sess.run(variables.global_variables_initializer())
1302        variable_graph_def = sess.graph.as_graph_def()
1303        constant_graph_def = (
1304            convert_to_constants
1305            .convert_variables_to_constants_from_session_graph(
1306                sess, variable_graph_def, ["case/cond/Merge"]))
1307        self._assertGraphContains(
1308            constant_graph_def, """
1309            node {
1310              name: "x" op: "Const"
1311              attr { key: "dtype" value { type: DT_FLOAT } }
1312              attr {
1313                key: "value"
1314                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
1315            }
1316            node {
1317              name: "y" op: "Const"
1318              attr { key: "dtype" value { type: DT_FLOAT } }
1319              attr {
1320                key: "value"
1321                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 2 }}}
1322            }
1323            node {name: "x/read" op: "Identity" input: "x"}
1324            node {name: "y/read" op: "Identity" input: "y"}
1325            node {name: "Less" op: "Less" input: "x/read" input: "y/read"}
1326            node {name: "case/cond/pred_id" op: "Identity" input: "Less"}
1327            node {
1328              name: "case/cond/Switch_1" op: "Switch"
1329              input: "case/cond/pred_id" input: "x/read"
1330            }
1331            node {
1332              name: "case/cond/Switch_2" op: "Switch"
1333              input: "case/cond/pred_id" input: "y/read"
1334            }
1335            node {
1336              name: "case/cond/Merge" op: "Merge"
1337              input: "case/cond/Switch_2" input: "case/cond/Switch_1:1"
1338              attr {key: "T" value {type: DT_FLOAT}}
1339            }""")
1340
1341  def testConvertV2Case(self):
1342    """Tests that a v2 case() converts properly."""
1343    with ops.Graph().as_default():
1344      with variable_scope.variable_scope("", use_resource=False):
1345        control_flow_v2_toggles.enable_control_flow_v2()
1346        a = variable_scope.get_variable("a", initializer=2.0)
1347        x = variable_scope.get_variable("x", initializer=1.0)
1348        y = variable_scope.get_variable("y", initializer=2.0)
1349        _ = control_flow_ops.case([(gen_math_ops.less(x, y), lambda: a)],
1350                                  default=lambda: y)
1351        control_flow_v2_toggles.disable_control_flow_v2()
1352      with session_lib.Session() as sess:
1353        sess.run(variables.global_variables_initializer())
1354        variable_graph_def = sess.graph.as_graph_def()
1355        constant_graph_def = (
1356            convert_to_constants
1357            .convert_variables_to_constants_from_session_graph(
1358                sess, variable_graph_def, ["case/cond"]))
1359        self._assertGraphContains(
1360            constant_graph_def, """
1361            node {
1362              name: "x" op: "Const"
1363              attr { key: "dtype" value { type: DT_FLOAT } }
1364              attr {
1365                key: "value"
1366                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
1367            }
1368            node {
1369              name: "y" op: "Const"
1370              attr { key: "dtype" value { type: DT_FLOAT } }
1371              attr {
1372                key: "value"
1373                value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 2 }}}
1374            }
1375            node {name: "x/read" op: "Identity" input: "x"}
1376            node {name: "y/read" op: "Identity" input: "y"}
1377            node {name: "Less" op: "Less" input: "x/read" input: "y/read"}
1378            node {
1379              name: "case/cond" op: "StatelessIf"
1380              input: "Less" input: "a/read" input: "y/read"
1381              attr {key: "Tcond" value {type: DT_BOOL}}
1382              attr {key: "Tin" value {list {type: DT_FLOAT type: DT_FLOAT}}}
1383              attr {key: "Tout" value {list {type: DT_FLOAT}}}
1384            }
1385            library {
1386              function {
1387                signature {
1388                  name: "case_cond_false_frozen_0"
1389                  input_arg {name: "placeholder" type: DT_FLOAT}
1390                  input_arg {name: "y_read_0" type: DT_FLOAT}
1391                  output_arg {name: "y_read" type: DT_FLOAT}
1392                }
1393              }
1394              function {
1395                signature {
1396                  name: "case_cond_true_frozen_0"
1397                  input_arg {name: "a_read_0" type: DT_FLOAT}
1398                  input_arg {name: "placeholder" type: DT_FLOAT}
1399                  output_arg {name: "a_read" type: DT_FLOAT}
1400                }
1401              }
1402            }""")
1403
1404  def testConvertV2ResourceCase(self):
1405    """Tests that a v2 case() with resource variables converts properly."""
1406    with ops.Graph().as_default():
1407      with variable_scope.variable_scope("", use_resource=True):
1408        control_flow_v2_toggles.enable_control_flow_v2()
1409        x = variable_scope.get_variable("x", initializer=1.0)
1410        y = variable_scope.get_variable("y", initializer=2.0)
1411        _ = control_flow_ops.case([(gen_math_ops.less(x, y), lambda: x)],
1412                                  default=lambda: y)
1413        control_flow_v2_toggles.disable_control_flow_v2()
1414      with session_lib.Session() as sess:
1415        sess.run(variables.global_variables_initializer())
1416        variable_graph_def = sess.graph.as_graph_def()
1417        constant_graph_def = (
1418            convert_to_constants
1419            .convert_variables_to_constants_from_session_graph(
1420                sess, variable_graph_def, ["case/cond"]))
1421        self._assertGraphContains(
1422            constant_graph_def, """
1423            node {name: "x" op: "Const"}
1424            node {name: "y" op: "Const"}
1425            node {
1426              name: "case/cond" op: "If" input: "Less" input: "x" input: "y"
1427              attr {key: "Tcond" value {type: DT_BOOL}}
1428              attr {key: "Tin" value {list {type: DT_FLOAT type: DT_FLOAT}}}
1429              attr {key: "Tout" value {list {type: DT_FLOAT}}}
1430            }
1431            library {
1432              function {
1433                signature {
1434                  name: "case_cond_false_frozen_0"
1435                  input_arg {name: "placeholder" type: DT_FLOAT}
1436                  input_arg {name: "readvariableop_y" type: DT_FLOAT}
1437                  output_arg {name: "readvariableop" type: DT_FLOAT}
1438                }
1439              }
1440              function {
1441                signature {
1442                  name: "case_cond_true_frozen_0"
1443                  input_arg {name: "placeholder" type: DT_FLOAT}
1444                  input_arg {name: "readvariableop_x" type: DT_FLOAT}
1445                  output_arg {name: "readvariableop" type: DT_FLOAT}
1446                }
1447              }
1448            }""")
1449
1450  def testConvertV2UnconvertedResourceNestedCase(self):
1451    """Tests unconverted variable propagation through nested functions."""
1452    with ops.Graph().as_default():
1453      with variable_scope.variable_scope("", use_resource=True):
1454        control_flow_v2_toggles.enable_control_flow_v2()
1455        x = variable_scope.get_variable("x", initializer=1.0)
1456        y = variable_scope.get_variable("y", initializer=2.0)
1457        z = variable_scope.get_variable("z", initializer=3.0)
1458        # pylint: disable=g-long-lambda
1459        _ = control_flow_ops.case(
1460            [(gen_math_ops.less(x, y), lambda: x)],
1461            default=lambda: control_flow_ops.case(
1462                [(gen_math_ops.less(z, y), lambda: z)], default=lambda: y))
1463        # pylint: enable=g-long-lambda
1464        control_flow_v2_toggles.disable_control_flow_v2()
1465      with session_lib.Session() as sess:
1466        sess.run(variables.global_variables_initializer())
1467        variable_graph_def = sess.graph.as_graph_def()
1468        constant_graph_def = (
1469            convert_to_constants
1470            .convert_variables_to_constants_from_session_graph(
1471                sess,
1472                variable_graph_def, ["case/cond"],
1473                variable_names_denylist=["y"]))
1474        self._assertGraphContains(
1475            constant_graph_def, """
1476            node {name: "x" op: "Const"}
1477            node {name: "y" op: "VarHandleOp"}
1478            node {name: "z" op: "Const"}
1479
1480            node {name: "Less/ReadVariableOp" op: "Identity" input: "x"}
1481            node {name: "Less/ReadVariableOp_1" op: "ReadVariableOp" input: "y"}
1482
1483            node {
1484              name: "case/cond" op: "If"
1485              input: "x" input: "z" input: "y"
1486              attr {
1487                key: "Tin"
1488                value {list
1489                  {type: DT_FLOAT type: DT_FLOAT type: DT_RESOURCE}}}
1490              attr {
1491                key: "_read_only_resource_inputs"
1492                value {list {i: 1 i: 2 i: 3}}}
1493              attr {key: "then_branch"
1494                    value {func {name: "case_cond_true_frozen_0"}}}
1495              attr {key: "else_branch"
1496                    value {func {name: "case_cond_false_frozen_0"}}}
1497              attr {key: "output_shapes" value {list {shape {}}}}
1498            }
1499            library {
1500              function {
1501                signature {
1502                  name: "case_cond_true_frozen_0"
1503                  input_arg {name: "placeholder" type: DT_FLOAT}
1504                  input_arg {name: "placeholder_1" type: DT_RESOURCE}
1505                  input_arg {name: "readvariableop_x" type: DT_FLOAT}
1506                  output_arg {name: "readvariableop" type: DT_FLOAT}
1507                  is_stateful: true
1508                }
1509
1510                node_def {name: "ReadVariableOp" op: "Identity"
1511                  input: "readvariableop_x"}}
1512
1513              function {
1514                signature {
1515                  name: "case_cond_false_frozen_0"
1516                  input_arg {name: "placeholder" type: DT_FLOAT}
1517                  input_arg {name: "less_readvariableop_1_y" type: DT_RESOURCE}
1518                  input_arg {name: "less_readvariableop_z" type: DT_FLOAT}
1519                  output_arg {name: "case_cond_identity" type: DT_FLOAT}
1520                  is_stateful: true
1521                }
1522
1523                node_def {name: "Less/ReadVariableOp_1" op: "ReadVariableOp"
1524                  input: "less_readvariableop_1_y"}
1525
1526                node_def {name: "Less/ReadVariableOp" op: "Identity"
1527                  input: "less_readvariableop_z"}
1528
1529                node_def {name: "case/cond" op: "If"
1530                  input: "less_readvariableop_z"
1531                  input: "less_readvariableop_1_y"
1532                  attr {
1533                    key: "Tin"
1534                    value {list {type: DT_FLOAT type: DT_RESOURCE}}}
1535                  attr {key: "then_branch"
1536                        value {func {name: "case_cond_true_frozen_1"}}}
1537                  attr {key: "else_branch"
1538                        value {func {name: "case_cond_false_frozen_1"}}}
1539                  attr {
1540                    key: "_read_only_resource_inputs"
1541                    value {list {i: 1 i: 2}}}}}
1542
1543              function {
1544                signature {
1545                  name: "case_cond_false_frozen_1"
1546                  input_arg {name: "placeholder" type: DT_FLOAT}
1547                  input_arg {name: "readvariableop_y" type: DT_RESOURCE}
1548                  output_arg {name: "readvariableop" type: DT_FLOAT}
1549                  is_stateful: true
1550                }
1551
1552                node_def {name: "ReadVariableOp" op: "ReadVariableOp"
1553                  input: "readvariableop_y"}}
1554
1555              function {
1556                signature {
1557                  name: "case_cond_true_frozen_1"
1558                  input_arg {name: "placeholder" type: DT_RESOURCE}
1559                  input_arg {name: "readvariableop_z" type: DT_FLOAT}
1560                  output_arg {name: "readvariableop" type: DT_FLOAT}
1561                  is_stateful: true
1562                }
1563
1564                node_def {name: "ReadVariableOp" op: "Identity"
1565                  input: "readvariableop_z"}}}""")
1566
1567  def _addNoinlineAttributeToFunction(self, saved_model_dir, func_name):
1568    saved_model_proto = loader_impl.parse_saved_model(saved_model_dir)
1569    new_saved_model = saved_model_pb2.SavedModel()
1570    new_saved_model.CopyFrom(saved_model_proto)
1571    new_meta_graph_def = new_saved_model.meta_graphs[0]
1572    prefix_len = len("__inference_")
1573    for func_def in new_meta_graph_def.graph_def.library.function:
1574      func_name_without_prefix = func_def.signature.name[prefix_len:]
1575      if func_name_without_prefix.startswith(func_name):
1576        func_def.attr["_noinline"].CopyFrom(attr_value_pb2.AttrValue(b=True))
1577    old_saved_model_file = os.path.join(saved_model_dir,
1578                                        constants.SAVED_MODEL_FILENAME_PB)
1579    if os.path.exists(old_saved_model_file):
1580      os.remove(old_saved_model_file)
1581    path = os.path.join(
1582        compat.as_bytes(saved_model_dir),
1583        compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
1584    file_io.write_string_to_file(
1585        path, new_saved_model.SerializeToString(deterministic=True))
1586
1587  @test_util.run_v2_only
1588  def testVariableModelWithFunctionAndFunctionInliningDisabled(self):
1589    """Test a model with Variables and disable function inlining."""
1590
1591    class BasicModel:
1592
1593      def __init__(self):
1594        self.v1 = None
1595        self.v2 = variables.Variable(2.)
1596
1597      @def_function.function(input_signature=[
1598          tensor_spec.TensorSpec(shape=[1], dtype=dtypes.float32)
1599      ])
1600      def add_all(self, x):
1601        if self.v1 is None:
1602          self.v1 = variables.Variable(3.)
1603        return x + self.v1 + self.v2
1604
1605      def run(self, x):
1606        y = self.add_all(x)
1607        return y
1608
1609    save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model")
1610    with ops.Graph().as_default():
1611      model = BasicModel()
1612      a = array_ops.placeholder(dtypes.float32, shape=[1])
1613      b = model.run(a)
1614      with session_lib.Session() as sess:
1615        sess.run(variables.global_variables_initializer())
1616        simple_save.simple_save(sess, save_dir, {"myinput": a}, {"myoutput": b})
1617
1618    # Add _noinline to the SavedModel.
1619    self._addNoinlineAttributeToFunction(
1620        saved_model_dir=save_dir, func_name="add_all")
1621
1622    saved_model = load(save_dir)
1623    func = saved_model.signatures["serving_default"]
1624    frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
1625    constant_graph_def = frozen_func.graph.as_graph_def()
1626    self._ensure_no_variables_in_graph(constant_graph_def)
1627
1628
if __name__ == "__main__":
  # Executed directly (not imported): hand off to the `test` module's runner.
  test.main()
1631