xref: /aosp_15_r20/external/tensorflow/tensorflow/python/saved_model/save_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for trackable object SavedModel save."""
16
17import os
18
19from absl.testing import parameterized
20
21from google.protobuf import text_format
22
23from tensorflow.core.config import flags
24from tensorflow.core.framework import graph_pb2
25from tensorflow.core.protobuf import graph_debug_info_pb2
26from tensorflow.python.checkpoint import checkpoint
27from tensorflow.python.client import session as session_lib
28from tensorflow.python.data.ops import dataset_ops
29from tensorflow.python.distribute import mirrored_strategy
30from tensorflow.python.eager import backprop
31from tensorflow.python.eager import context
32from tensorflow.python.eager import def_function
33from tensorflow.python.eager import test
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import meta_graph
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_spec
39from tensorflow.python.framework import test_util
40from tensorflow.python.framework import versions
41from tensorflow.python.lib.io import file_io
42from tensorflow.python.module import module
43from tensorflow.python.ops import array_ops
44from tensorflow.python.ops import control_flow_ops
45from tensorflow.python.ops import lookup_ops
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import resource_variable_ops
48from tensorflow.python.ops import variables
49from tensorflow.python.ops.ragged import ragged_factory_ops
50from tensorflow.python.ops.ragged import ragged_tensor
51from tensorflow.python.saved_model import load
52from tensorflow.python.saved_model import loader
53from tensorflow.python.saved_model import loader_impl
54from tensorflow.python.saved_model import save
55from tensorflow.python.saved_model import save_options
56from tensorflow.python.saved_model import signature_constants
57from tensorflow.python.saved_model import tag_constants
58from tensorflow.python.trackable import asset
59from tensorflow.python.trackable import autotrackable
60from tensorflow.python.training import saver
61from tensorflow.python.util import compat
62
63
def _run_signature(session, meta_graph_def, inputs, signature_key):
  """Feeds `inputs` into one signature of an imported graph and runs it.

  Args:
    session: Session whose graph contains the imported MetaGraph.
    meta_graph_def: MetaGraphDef whose `signature_def` map is consulted.
    inputs: Dict from signature input name to the value to feed. Its key set
      must equal the signature's input key set exactly.
    signature_key: Key into `meta_graph_def.signature_def`.

  Returns:
    Whatever `session.run` returns for a fetch dict that maps each signature
    output name to its tensor in `session.graph`.
  """
  signature = meta_graph_def.signature_def[signature_key]
  assert set(inputs) == set(signature.inputs)
  tensor_by_name = session.graph.get_tensor_by_name
  # Resolve every signature input to its graph tensor and pair it with the
  # caller-supplied value.
  feeds = {
      tensor_by_name(signature.inputs[input_name].name): value
      for input_name, value in inputs.items()
  }
  # Fetch every signature output, keyed by its signature output name.
  fetches = {
      output_name: tensor_by_name(tensor_info.name)
      for output_name, tensor_info in signature.outputs.items()
  }
  return session.run(fetches, feed_dict=feeds)
77
78
def _import_and_infer(
    save_dir,
    inputs,
    signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
  """Loads `save_dir` with the TF1 loader and evaluates one signature.

  The SavedModel is imported into a brand-new graph, so state from the calling
  test's default graph does not leak into the import.

  Args:
    save_dir: Directory of a SavedModel containing a MetaGraph tagged with
      `tag_constants.SERVING`.
    inputs: Dict from signature input name to the value to feed.
    signature_key: Which signature to run; defaults to the serving default.

  Returns:
    Dict from signature output name to the computed value.
  """
  import_graph = ops.Graph()
  with import_graph.as_default():
    with session_lib.Session() as sess:
      loaded = loader.load(sess, [tag_constants.SERVING], save_dir)
      return _run_signature(sess, loaded, inputs, signature_key)
88
89
class SaveTest(test.TestCase, parameterized.TestCase):
  """Tests for `save.save` and `save.export_meta_graph` on trackable objects."""

  def _configure_two_logical_cpus(self):
    """Resets the eager context so that devices "CPU:0" and "CPU:1" exist.

    When only one physical CPU is visible it is split into two logical CPUs,
    then the context is initialized with that configuration.
    """
    context._reset_context()
    cpus = context.context().list_physical_devices("CPU")
    if len(cpus) == 1:
      context.context().set_logical_device_configuration(
          cpus[0], [
              context.LogicalDeviceConfiguration(),
              context.LogicalDeviceConfiguration()
          ])
    context.ensure_initialized()

  def _assert_devices_follow_policy(self, v0, v1, save_devices):
    """Checks the `device` fields of two saved variables against a policy.

    Args:
      v0: Saved form of the variable placed on "CPU:0", or None if not found.
      v1: Saved form of the variable placed on "CPU:1", or None if not found.
      save_devices: The `save_options.VariablePolicy` used when saving.
    """
    self.assertIsNotNone(v0)
    self.assertIsNotNone(v1)
    if save_devices == save_options.VariablePolicy.SAVE_VARIABLE_DEVICES:
      self.assertIn("CPU:0", v0.device)
      self.assertIn("CPU:1", v1.device)
    else:
      self.assertEmpty(v0.device)
      self.assertEmpty(v1.device)

  def test_method_save_signature(self):
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(
        lambda x: 2. * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    root.f(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir, root.f)
    self.assertEqual({"output_0": 2.}, _import_and_infer(save_dir, {"x": 1.}))

  def test_method_save_list_func(self):
    root = autotrackable.AutoTrackable()

    @def_function.function
    def case_fn(x):
      branch_index = constant_op.constant(1)
      branches = [lambda: x, lambda: x + 1]
      return control_flow_ops.switch_case(branch_index, branches)

    root.f = def_function.function(
        lambda x: 2. * case_fn(x),
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    root.f(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir, root.f)
    self.assertEqual({"output_0": 4.}, _import_and_infer(save_dir, {"x": 1.}))

  def test_method_save_concrete(self):
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(lambda z: {"out": 2. * z})
    root.f(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(
        root, save_dir, {
            "non_default_key":
                root.f.get_concrete_function(
                    tensor_spec.TensorSpec(None, dtypes.float32))
        })
    self.assertEqual({"out": 2.},
                     _import_and_infer(
                         save_dir, {"z": 1.}, signature_key="non_default_key"))

  def test_method_save_annotated_function(self):
    # Saving must succeed when the function's argument and return annotations
    # refer to a type TensorFlow knows nothing about.
    root = autotrackable.AutoTrackable()

    class UnknownType:
      pass

    def annotated_function(z):
      return {"out": 2. * z}

    # Equivalent to declaring
    #   def annotated_function(z: UnknownType) -> UnknownType:
    # The annotations are attached explicitly, a workaround dating from when
    # this file also had to parse under Python 2 (no annotation syntax).
    annotated_function.__annotations__ = {
        "z": UnknownType,
        "return": UnknownType
    }

    root.f = def_function.function(annotated_function)
    root.f(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(
        root, save_dir, {
            "non_default_key":
                root.f.get_concrete_function(
                    tensor_spec.TensorSpec(None, dtypes.float32))
        })
    self.assertEqual({"out": 2.},
                     _import_and_infer(
                         save_dir, {"z": 1.}, signature_key="non_default_key"))

  def test_unsaveable_func_graph(self):
    root = module.Module()

    @def_function.function(input_signature=[])
    def nested_f():
      ops.get_default_graph().mark_as_unsaveable("ERROR MSG")
      return 1

    @def_function.function(input_signature=[])
    def f():
      return nested_f()

    root.f = f
    with self.assertRaisesRegex(ValueError, "ERROR MSG"):
      save.save(root, os.path.join(self.get_temp_dir(), "saved_model"))

  def test_untracked_variable_useful_message(self):
    root = module.Module()
    v = variables.Variable(1., name="some_unique_name")

    @def_function.function(input_signature=[])
    def f():
      return v.read_value()

    root.f = f
    with self.assertRaisesRegex(
        AssertionError, "Trackable referencing this tensor.*some_unique_name"):
      save.save(root, os.path.join(self.get_temp_dir(), "saved_model"))

  def test_version_information_included(self):
    root = autotrackable.AutoTrackable()
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir)
    saved_model_proto = loader_impl.parse_saved_model(save_dir)
    meta_info_def = saved_model_proto.meta_graphs[0].meta_info_def
    self.assertEqual(versions.__version__, meta_info_def.tensorflow_version)
    self.assertEqual(versions.__git_version__,
                     meta_info_def.tensorflow_git_version)

  def test_non_concrete_error(self):
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(lambda x: 2. * x)
    root.f(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegex(ValueError, "Expected a TensorFlow function"):
      save.save(root, save_dir, root.f)

  def test_captures_unreachable_variable(self):
    root = autotrackable.AutoTrackable()
    unreachable_variable = variables.Variable([5.0, 2.0])
    root.reachable_variable = variables.Variable([1.0, 3.0])

    @def_function.function
    def increase_variable(x):
      return 2 * unreachable_variable * x + root.reachable_variable

    root.f = increase_variable

    self.assertAllEqual([101.0, 83.0],
                        root.f(constant_op.constant([10.0, 20.0])).numpy())

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")

    with self.assertRaisesRegex(KeyError, "not reachable from root"):
      save.save(root, save_dir)

  def test_nested_inputs(self):
    # NOTE(review): this only traces `root.f` with a nested input signature;
    # it neither saves nor loads, so it is effectively a tracing smoke test.
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(
        lambda x: 2. * x[0],
        input_signature=([
            tensor_spec.TensorSpec(None, dtypes.float32),
            tensor_spec.TensorSpec(None, dtypes.float32)
        ],))
    root.f([constant_op.constant(1.), constant_op.constant(1.)])

  def test_nested_outputs(self):
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(lambda x: (2. * x, (3. * x, 4. * x)))
    root.f(constant_op.constant(1.))
    to_save = root.f.get_concrete_function(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegex(ValueError, "non-Tensor value"):
      save.save(root, save_dir, to_save)

  def test_nested_dict_outputs(self):
    root = checkpoint.Checkpoint(
        f=def_function.function(lambda x: {  # pylint: disable=g-long-lambda
            "a": 2. * x,
            "b": (3. * x, 4. * x)
        }))
    root.f(constant_op.constant(1.))
    to_save = root.f.get_concrete_function(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegex(ValueError, "non-Tensor value"):
      save.save(root, save_dir, to_save)

  def test_variable(self):
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    root.f(constant_op.constant(1.))
    to_save = root.f.get_concrete_function(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir, to_save)
    self.assertAllEqual({"output_0": 12.},
                        _import_and_infer(save_dir, {"x": 2.}))

  def test_single_function_default_signature(self):
    model = autotrackable.AutoTrackable()
    model.f = def_function.function(lambda: 3., input_signature=())
    model.f()
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(model, save_dir)
    self.assertAllClose({"output_0": 3.}, _import_and_infer(save_dir, {}))

  def test_single_function_no_signature(self):
    # Saving an object whose only function was never traced must not raise.
    model = autotrackable.AutoTrackable()
    model.f = def_function.function(lambda: 3.)
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(model, save_dir)

  def test_save_function_no_trace(self):

    class ObjWithFunction(module.Module):

      @def_function.function
      def foo(self, a):
        return a

      @def_function.function
      def bar(self, a):
        return a + 1

    root = ObjWithFunction()
    root.bar(1)
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertLogs(level="WARNING") as logs:
      save.save(root, save_dir)

    expected_message = (
        "WARNING:absl:Found untraced functions such as foo while saving "
        "(showing 1 of 1). These functions will not be directly callable after "
        "loading.")
    self.assertIn(expected_message, logs.output)

  def test_find_default_save_function(self):

    class ObjWithDefaultSignature(checkpoint.Checkpoint):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)
      ])
      def _default_save_signature(self, x):
        return x + x + 1

    obj = ObjWithDefaultSignature()
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(obj, save_dir)
    self.assertAllClose({"output_0": 7.},
                        _import_and_infer(save_dir, {"x": 3.}))

  def test_docstring(self):

    class Adder(module.Module):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)
      ])
      def add(self, x):
        return x + x + 1.

    to_save = Adder()
    to_save.add(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(to_save, save_dir)
    self.assertAllClose({"output_0": 7.},
                        _import_and_infer(save_dir, {"x": 3.}))

  def test_datastructures(self):

    class HasDatastructures(checkpoint.Checkpoint):

      def __init__(self):
        self.a = [1.]
        self.a.append(variables.Variable(2.))
        self.b = {"a": variables.Variable(3.)}

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)
      ])
      def add(self, x):
        return x + math_ops.add_n(self.a) + self.b["a"]

    to_save = HasDatastructures()
    to_save.add(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(to_save, save_dir)
    self.assertAllClose({"output_0": 10.},
                        _import_and_infer(save_dir, {"x": 4.}))

  def test_default_attr_stripping(self):

    class Complex(checkpoint.Checkpoint):

      @def_function.function(input_signature=[])
      def __call__(self):
        return math_ops.complex(
            constant_op.constant(1.), constant_op.constant(2.), name="complex")

    to_save = Complex()
    to_save()
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(to_save, save_dir)
    graph = ops.Graph()
    with graph.as_default(), self.session(graph) as session:
      loader.load(session, [tag_constants.SERVING], save_dir)
      func, = [f for name, f in graph._functions.items() if "call" in name]
      complex_node, = [
          node for node in func.definition.node_def if node.op == "Complex"
      ]
      self.assertNotIn("T", complex_node.attr)
      self.assertNotIn("Tout", complex_node.attr)

  def test_signature_attribute_reserved(self):
    root = checkpoint.Checkpoint(signatures=variables.Variable(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegex(ValueError, "del obj.signatures"):
      save.save(root, save_dir)
    del root.signatures
    save.save(root, save_dir)

  def test_function_with_captured_dataset(self):
    if test_util.is_gpu_available():
      self.skipTest("Currently broken when a GPU is available.")

    class HasDataset(module.Module):

      def __init__(self):
        super().__init__()
        self.dataset = dataset_ops.Dataset.range(5).map(lambda x: x**2)

      @def_function.function
      def __call__(self, x):
        current_sum = array_ops.zeros([], dtype=dtypes.int64)
        for element in self.dataset:
          current_sum += x * element
        return current_sum

    root = HasDataset()
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(
        root,
        save_dir,
        signatures=root.__call__.get_concrete_function(
            tensor_spec.TensorSpec(None, dtypes.int64)))
    self.assertAllClose({"output_0": 3 * (1 + 4 + 9 + 16)},
                        _import_and_infer(save_dir, {"x": 3}))

  def test_variable_args_cannot_be_used_as_signature(self):

    @def_function.function(input_signature=[
        resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)
    ])
    def f(unused_v):
      return 1

    root = autotrackable.AutoTrackable()
    root.f = f.get_concrete_function()
    with self.assertRaisesRegex(ValueError,
                                "tf.Variable inputs cannot be exported"):
      save.save(
          root,
          os.path.join(self.get_temp_dir(), "saved_model"),
          signatures=root.f)

  def test_export_correct_output_shapes(self):
    """Asserts that nodes are exported with the correct number of output shapes.

    After backpropagation rewrite, functions are rewritten with additional
    outputs. When exporting to SavedModel, the shapes of the additional outputs
    were incorrectly added to the FunctionDef proto (b/133666530).
    """
    obj = autotrackable.AutoTrackable()
    obj.v = variables.Variable(2.)

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    def f(x):
      return (math_ops.multiply(obj.v, x), math_ops.multiply(obj.v,
                                                             (x + 1)), None)

    obj.f = f

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    def g(x):
      return obj.f(x)[1]

    obj.g = g

    # After the following lines, the concrete functions of obj.g and obj.f are
    # rewritten with many extra outputs.
    with backprop.GradientTape():
      obj.g(constant_op.constant(3.0))

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(obj, save_dir, signatures={"g": obj.g})
    graph_def = loader_impl.parse_saved_model(save_dir).meta_graphs[0].graph_def

    def assert_correct_number_of_output_shapes(node):
      if node.op == "StatefulPartitionedCall":
        fn_name = node.attr["f"].func.name
        if fn_name.startswith("__inference_f"):
          self.assertLen(node.attr["_output_shapes"].list.shape, 2)
        if fn_name.startswith("__inference_g"):
          self.assertLen(node.attr["_output_shapes"].list.shape, 1)

    # `function_def` rather than `f`, so the tf.function `f` is not shadowed.
    for function_def in graph_def.library.function:
      if function_def.signature.name.startswith(
          ("__inference_f", "__inference_g")):
        for node in function_def.node_def:
          assert_correct_number_of_output_shapes(node)

  def test_save_cached_variable(self):
    with ops.Graph().as_default(), session_lib.Session() as session:
      obj = autotrackable.AutoTrackable()
      obj.v = variables.Variable(2., caching_device=lambda op: op.device)
      obj.w = variables.Variable(3.)
      session.run([obj.v.initializer, obj.w.initializer])

      @def_function.function(input_signature=[])
      def f():
        return obj.v + obj.w

      obj.f = f
      save_dir = os.path.join(self.get_temp_dir(), "saved_model")
      save.save(obj, save_dir, signatures=obj.f)
      self.assertAllClose({"output_0": 5}, _import_and_infer(save_dir, {}))

  @parameterized.named_parameters(
      ("_SaveDevices_ExportMetaGraph",
       save_options.VariablePolicy.SAVE_VARIABLE_DEVICES, True),
      ("_DiscardDevices_ExportMetaGraph", save_options.VariablePolicy.NONE,
       True),
      ("_SaveDevices_Save", save_options.VariablePolicy.SAVE_VARIABLE_DEVICES,
       False),
      ("_DiscardDevices_Save", save_options.VariablePolicy.NONE, False))
  def test_save_variable_devices(self, save_devices, meta_graph_only):
    self._configure_two_logical_cpus()

    root = autotrackable.AutoTrackable()
    with ops.device("CPU:0"):
      root.v0 = variables.Variable(1., name="v0")
    with ops.device("CPU:1"):
      root.v1 = variables.Variable(1., name="v1")

    options = save_options.SaveOptions(
        experimental_variable_policy=save_devices)
    file_name = os.path.join(self.get_temp_dir(), "saved_model")
    if meta_graph_only:
      save.export_meta_graph(obj=root, filename=file_name, options=options)
      meta = meta_graph.read_meta_graph_file(file_name)
    else:
      save.save(obj=root, export_dir=file_name, options=options)
      meta = loader_impl.parse_saved_model(file_name).meta_graphs[0]

    # Check devices in meta graph nodes.
    graph_def = meta.graph_def
    v0 = next((n for n in graph_def.node if n.name == "v0"), None)
    v1 = next((n for n in graph_def.node if n.name == "v1"), None)
    self._assert_devices_follow_policy(v0, v1, save_devices)

    # Check devices in object graph nodes.
    object_graph_def = meta.object_graph_def

    def find_saved_variable(name):
      return next((n.variable
                   for n in object_graph_def.nodes
                   if n.HasField("variable") and n.variable.name == name),
                  None)

    self._assert_devices_follow_policy(
        find_saved_variable("v0"), find_saved_variable("v1"), save_devices)

  @parameterized.named_parameters(
      ("_ExpandDistributedVariablesWithPolicy",
       save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, True),
      ("_ExpandDistributedVariablesWithoutPolicy",
       save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, False),
      ("_DiscardDistributedVariablesWithPolicy",
       save_options.VariablePolicy.NONE, True),
      ("_DiscardDistributedVariablesWithoutPolicy",
       save_options.VariablePolicy.NONE, False))
  def test_expand_distributed_variables(self, expand_strategy, policy):
    # 1. Create a context with both CPU:0 and CPU:1.
    self._configure_two_logical_cpus()

    # 2. Create and save a model under a mirrored strategy.
    file_name = os.path.join(self.get_temp_dir(), "saved_model.pb")
    strategy = mirrored_strategy.MirroredStrategy(["CPU:0", "CPU:1"])
    strategy.extended._use_var_policy = policy
    with strategy.scope():
      root = autotrackable.AutoTrackable()
      root.v = variables.Variable([1., 1.], name="v")

      @def_function.function(input_signature=[])
      def f():
        root.v.assign([2., 2.])

      root.f = f

      save.export_meta_graph(
          obj=root,
          filename=file_name,
          options=save_options.SaveOptions(
              experimental_variable_policy=expand_strategy))

    # 3. Read the output file and test behavior.
    meta_graph_def = meta_graph.read_meta_graph_file(file_name)
    object_graph = meta_graph_def.object_graph_def
    graph_def = meta_graph_def.graph_def
    v = next((n.variable
              for n in object_graph.nodes
              if n.HasField("variable") and n.variable.name == "v"), None)
    # Fail with a clear message rather than an AttributeError on `v.device`.
    self.assertIsNotNone(v)
    saved_function = next((fdef for fdef in graph_def.library.function
                           if "inference_f_" in fdef.signature.name), None)
    self.assertIsNotNone(saved_function)
    if (expand_strategy ==
        save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES):
      # experimental_save_variable_devices should have been automatically set.
      self.assertIn("CPU:0", v.device)
      components = v.experimental_distributed_variable_components
      self.assertLen(components, 2)
      v0 = next((x for x in components if x.name == "v"), None)
      v1 = next((x for x in components if x.name == "v/replica_1"), None)
      self.assertIsNotNone(v0)
      self.assertIsNotNone(v1)
      self.assertIn("CPU:0", v0.device)
      self.assertIn("CPU:1", v1.device)
      self.assertLen(saved_function.signature.input_arg, 2)
    else:
      self.assertEmpty(v.device)
      self.assertEmpty(v.experimental_distributed_variable_components)
      self.assertLen(saved_function.signature.input_arg, 1)

  def test_save_uninitialized_variable(self):
    root = autotrackable.AutoTrackable()
    root.uninitialized_variable = resource_variable_ops.UninitializedVariable(
        name="uninitialized_variable", dtype=dtypes.float32)
    root.initialized_variable = variables.Variable(
        1.0, name="initialized_variable")

    # TODO(b/149594077): Python loading does not work now partly because it
    # shouldn't, as the public API and semantics of uninitialized variables
    # are not properly defined, and officially supporting loading would end up
    # defining semantics "by usage." We should only allow loading once the API
    # is made official.
    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, export_dir)
    with self.assertRaisesRegex(FileNotFoundError,
                                "Key uninitialized_variable"):
      load.load(export_dir)
    with ops.Graph().as_default(), session_lib.Session() as session:
      # The final ValueError here (with "no variables to save") is confusing,
      # but errors upstream give the user the correct information (a
      # NotFoundError stating that the uninitialized_variable was not found in
      # the checkpoint).
      with self.assertRaises(ValueError):
        loader.load(session, [tag_constants.SERVING], export_dir)

  def test_concrete_function_with_set_shape(self):
    # Serialized concrete function should retain the shape from the TensorSpec,
    # instead of using the shape of the inputs (which are changed by set_shape).
    @def_function.function
    def f(x):
      x.set_shape((5, 1))
      return x

    root = autotrackable.AutoTrackable()
    path = os.path.join(self.get_temp_dir(), "saved_model")
    concrete = f.get_concrete_function(
        tensor_spec.TensorSpec((None, 1), name="name"))
    save.save(root, path, signatures={"key": concrete})
    imported = load.load(path)
    self.assertEqual(imported.signatures["key"].structured_input_signature[1],
                     {"name": tensor_spec.TensorSpec((None, 1), name="name")})

  def test_save_composite_tensor_signature(self):
    @def_function.function(
        input_signature=[ragged_tensor.RaggedTensorSpec(ragged_rank=2)])
    def f(x):
      return {"output_key": x}
    root = autotrackable.AutoTrackable()
    path = os.path.join(self.get_temp_dir(), "saved_model")
    inp = ragged_factory_ops.constant([[[1.0, 2.0], [3.0]], [[5.]]])
    flat_inp = {
        "x": constant_op.constant([1., 2., 3., 5]),
        "x_1": constant_op.constant([0, 2, 3], dtype=dtypes.int64),
        "x_2": constant_op.constant([0, 2, 3, 4], dtype=dtypes.int64)
    }
    save.save(root, path, signatures={"key": f.get_concrete_function()})

    # Test that the ragged signature can be loaded back into Python with V2 APIs
    imported = load.load(path)
    self.assertAllEqual(inp,
                        imported.signatures["key"](**flat_inp)["output_key"])

    # Try running the signature with V1 APIs.
    graph = ops.Graph()
    with graph.as_default(), session_lib.Session() as session:
      meta_graph_def = loader.load(session, [tag_constants.SERVING], path)
      signature = meta_graph_def.signature_def["key"]

      feed_dict = {}
      for arg_name in flat_inp:
        input_tensor = session.graph.get_tensor_by_name(
            signature.inputs[arg_name].name)
        feed_dict[input_tensor] = flat_inp[arg_name].numpy()

      # Get composite tensor components
      output_components = (
          signature.outputs["output_key"].composite_tensor.components)
      fetches = {}
      components_keys = ["x", "x_1", "x_2"]
      for k, output_tensor_info in zip(components_keys, output_components):
        fetches[k] = session.graph.get_tensor_by_name(output_tensor_info.name)

      outputs = session.run(fetches, feed_dict)

    self.assertAllClose(flat_inp, outputs)

  def test_save_uses_sanitized_signature_name(self):

    @def_function.function(
        input_signature=[ragged_tensor.RaggedTensorSpec(ragged_rank=2)])
    def f(x):
      return {"output_key": x}

    # Colons are not usable as name scopes.
    unsanitized_name = "foo:bar"
    root = autotrackable.AutoTrackable()
    path = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(
        root, path, signatures={unsanitized_name: f.get_concrete_function()})
    graph = ops.Graph()
    with graph.as_default(), session_lib.Session() as session:
      meta_graph_def = loader.load(session, [tag_constants.SERVING], path)
      signature = meta_graph_def.signature_def[unsanitized_name]
      tensor_names = [
          session.graph.get_tensor_by_name(signature.inputs[key].name).name
          for key in signature.inputs
      ]
      # The placeholder names will have the sanitized version.
      self.assertCountEqual(tensor_names,
                            ["foo_bar_x:0", "foo_bar_x_1:0", "foo_bar_x_2:0"])

  def test_save_returns_none(self):
    # Test that `tf.saved_model.save` API returns None to user.
    root = autotrackable.AutoTrackable()
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    result = save.save(root, save_dir)
    self.assertIsNone(result)
766
767
class DependencyTest(test.TestCase):
  """Tests for deserialization dependencies (saving-related only)."""

  def _export_dir(self):
    """Returns the SavedModel output directory shared by these tests."""
    return os.path.join(self.get_temp_dir(), "saved_model")

  def test_validate_dependencies(self):
    """Declaring already-tracked children as dependencies saves cleanly."""

    class Valid(autotrackable.AutoTrackable):

      def _deserialization_dependencies(self, children):
        # Every child is tracked by construction, so this set is valid.
        return children

    obj = Valid()
    obj.f = variables.Variable(1.0)
    save.save(obj, self._export_dir())

  def test_validate_dependencies_error_untracked(self):
    """A dependency that is not reachable from the root is rejected."""
    orphan = variables.Variable(1.0)

    class Invalid(autotrackable.AutoTrackable):

      def _deserialization_dependencies(self, children):
        del children  # Unused.
        return {"untracked": orphan}

    obj = Invalid()
    export_dir = self._export_dir()
    with self.assertRaisesRegex(ValueError, "Found an untracked dependency"):
      save.save(obj, export_dir)

  def test_validate_dependencies_error_cyclic(self):
    """Two objects that depend on one another form a rejected cycle."""

    class Invalid(autotrackable.AutoTrackable):

      def __init__(self):
        self.cycle_ref = None

      def _deserialization_dependencies(self, children):
        del children  # Unused.
        return {"cycle_ref": self.cycle_ref}

    first, second = Invalid(), Invalid()
    # Link the pair in both directions (first -> second, then second -> first).
    first.cycle_ref, second.cycle_ref = second, first
    export_dir = self._export_dir()
    with self.assertRaisesRegex(ValueError,
                                "dependency cycle in the saved Trackable"):
      save.save(first, export_dir)
814
815
class VariablePolicyEnumTest(test.TestCase):
  """Tests for `save_options.VariablePolicy` parsing and naming."""

  def testFromObj(self):
    """`from_obj` accepts None, enum members and case-insensitive strings."""
    policy = save_options.VariablePolicy
    # (raw input, expected VariablePolicy member)
    accepted = [
        (None, policy.NONE),
        (policy.SAVE_VARIABLE_DEVICES, policy.SAVE_VARIABLE_DEVICES),
        (policy.EXPAND_DISTRIBUTED_VARIABLES,
         policy.EXPAND_DISTRIBUTED_VARIABLES),
        ("save_variable_devices", policy.SAVE_VARIABLE_DEVICES),
        ("SaVe_VaRiAbLe_DeViCeS", policy.SAVE_VARIABLE_DEVICES),
        ("expand_distributed_variables", policy.EXPAND_DISTRIBUTED_VARIABLES),
        ("eXpAnD_dIsTrIbUtEd_VaRiAbLeS", policy.EXPAND_DISTRIBUTED_VARIABLES),
    ]
    for raw, expected in accepted:
      self.assertEqual(expected, policy.from_obj(raw))

    for invalid in ["not_a_valid_value", 2.0, []]:
      with self.assertRaisesRegex(ValueError, "invalid VariablePolicy value"):
        policy.from_obj(invalid)

  def testNamingConvention(self):
    """Enforces names are uppercase versions of values."""
    for member in save_options.VariablePolicy:
      if member == save_options.VariablePolicy.NONE:
        # NONE is the only member without a string value.
        self.assertIsNone(member.value)
        continue
      name, value = member.name, member.value
      self.assertEqual(name, name.upper())
      self.assertEqual(value, value.lower())
      self.assertEqual(name, value.upper())
854
855
class SavingOptionsTest(test.TestCase):
  """Tests for `save_options.SaveOptions` and the options it controls."""

  def testOpNameSpace(self):
    """Custom-namespace ops are rejected unless the namespace is whitelisted."""
    # TODO(kathywu): Add test that saves out SavedModel with a custom op when
    # the ">" character is allowed in op names.
    error_regex = "Attempted to save ops from non-whitelisted namespaces"

    graph_def = text_format.Parse("node { name: 'A' op: 'Test>CustomOp' }",
                                  graph_pb2.GraphDef())
    with self.assertRaisesRegex(ValueError, error_regex):
      save._verify_ops(graph_def, [])
    save._verify_ops(graph_def, ["Test"])

    # Test with multiple carets in op name. `text_format.Parse` does not clear
    # the message it parses into (repeated fields are appended), so a fresh
    # GraphDef is required: reusing the one above would keep the single-caret
    # node, which alone satisfies the assertRaisesRegex below and would mask a
    # failure of this case.
    graph_def = text_format.Parse(
        "node { name: 'A' op: 'Test>>A>CustomOp' }", graph_pb2.GraphDef())
    with self.assertRaisesRegex(ValueError, error_regex):
      save._verify_ops(graph_def, [])
    save._verify_ops(graph_def, ["Test"])

  def test_save_custom_op_with_no_whitelist_specified(self):
    # Test that we are able to save a model that contains a custom op with a
    # custom namespace when the user has not explicitly specified a namespace
    # whitelist (i.e. that we default to allowing all custom ops when saving
    # and no whitelist is specified, rather than throwing an exception).
    graph_def = graph_pb2.GraphDef()
    text_format.Parse("node { name: 'A' op: 'Test>CustomOp' }", graph_def)
    save._verify_ops(graph_def, namespace_whitelist=None)

    # If the user passes an empty list for the namespace whitelist rather than
    # nothing, we should then throw an exception if a custom op is used.
    with self.assertRaisesRegex(
        ValueError, "Attempted to save ops from non-whitelisted namespaces"):
      save._verify_ops(graph_def, [])

  def test_save_debug_info_enabled(self):
    """`save_debug_info=True` writes a debug-info file containing op traces."""
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(
        lambda x: math_ops.mul(2., x, name="DEBUG_INFO_OP"),
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(
        root,
        save_dir,
        root.f,
        options=save_options.SaveOptions(save_debug_info=True))
    debug_info_file_name = os.path.join(save_dir, "debug",
                                        "saved_model_debug_info.pb")
    self.assertTrue(os.path.exists(debug_info_file_name))
    debug_info = graph_debug_info_pb2.GraphDebugInfo()
    with open(debug_info_file_name, "rb") as f:
      debug_info.ParseFromString(f.read())

    # Verify that there is a trace for DEBUG_INFO_OP just to ensure that
    # function debug info tracing is nominally functioning.
    found_op = any(
        key.startswith("DEBUG_INFO_OP@") for key in debug_info.traces.keys())
    self.assertTrue(found_op, "Did not find DEBUG_INFO_OP in trace")

  def test_save_debug_info_disabled(self):
    """`save_debug_info=False` must not write the debug-info file."""
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(
        lambda x: math_ops.mul(2., x, name="DEBUG_INFO_OP"),
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(
        root,
        save_dir,
        root.f,
        options=save_options.SaveOptions(save_debug_info=False))
    debug_info_file_name = os.path.join(save_dir, "debug",
                                        "saved_model_debug_info.pb")
    self.assertFalse(os.path.exists(debug_info_file_name))

  def test_function_aliases(self):
    """An aliased function's concrete name is recorded in the MetaGraph."""
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(
        lambda x: 2. * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    options = save_options.SaveOptions(function_aliases={
        "my_func": root.f,
    })
    save.save(root, save_dir, root.f, options=options)
    function_cache = root.f._stateful_fn._list_all_concrete_functions()
    function_aliases = loader_impl.parse_saved_model(
        save_dir).meta_graphs[0].meta_info_def.function_aliases
    self.assertLen(function_cache, 1)
    self.assertEqual(function_cache[0].name.decode("utf-8"),
                     list(function_aliases.keys())[0])

  def test_accepts_io_device(self):
    options = save_options.SaveOptions()
    self.assertIsNone(options.experimental_io_device)
    options = save_options.SaveOptions(experimental_io_device="/job:localhost")
    self.assertEqual("/job:localhost", options.experimental_io_device)

  def test_accepts_variable_policy(self):
    options = save_options.SaveOptions()
    self.assertEqual(save_options.VariablePolicy.NONE,
                     options.experimental_variable_policy)
    # VariablePolicy instances.
    options = save_options.SaveOptions(experimental_variable_policy=save_options
                                       .VariablePolicy.SAVE_VARIABLE_DEVICES)
    self.assertEqual(save_options.VariablePolicy.SAVE_VARIABLE_DEVICES,
                     options.experimental_variable_policy)
    options = save_options.SaveOptions(
        experimental_variable_policy=save_options.VariablePolicy
        .EXPAND_DISTRIBUTED_VARIABLES)
    self.assertEqual(save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES,
                     options.experimental_variable_policy)
    # String conversions.
    options = save_options.SaveOptions(
        experimental_variable_policy="save_variable_devices")
    self.assertEqual(save_options.VariablePolicy.SAVE_VARIABLE_DEVICES,
                     options.experimental_variable_policy)
    options = save_options.SaveOptions(
        experimental_variable_policy="expand_distributed_variables")
    self.assertEqual(save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES,
                     options.experimental_variable_policy)
    with self.assertRaisesRegex(ValueError, "invalid VariablePolicy value"):
      save_options.SaveOptions(experimental_variable_policy="not_a_valid_value")
981
982
class AssetTests(test.TestCase):
  """Tests for saving models that reference asset files."""

  def setUp(self):
    super().setUp()
    # A three-line vocabulary shared by the tests in this class.
    self._vocab_path = os.path.join(self.get_temp_dir(), "vocab.txt")
    with open(self._vocab_path, "w") as f:
      f.write("alpha\nbeta\ngamma\n")

  def test_asset_path_returned(self):
    """The asset path returned after loading points into the moved model."""
    root = autotrackable.AutoTrackable()
    root.path = asset.Asset(self._vocab_path)
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    root.get_asset = def_function.function(lambda: root.path.asset_path)
    save.save(root, save_dir, signatures=root.get_asset.get_concrete_function())
    second_dir = os.path.join(self.get_temp_dir(), "second_dir")
    file_io.rename(save_dir, second_dir)
    imported_path = _import_and_infer(second_dir, {})["output_0"]
    self.assertIn(
        compat.as_str_any(second_dir), compat.as_str_any(imported_path))

  def test_table(self):
    """A lookup table keeps working after its source vocab file is deleted."""
    initializer = lookup_ops.TextFileInitializer(
        self._vocab_path,
        key_dtype=dtypes.string,
        key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
        value_dtype=dtypes.int64,
        value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
    root = checkpoint.Checkpoint(
        table=lookup_ops.HashTable(initializer, default_value=-1))
    root.table_user = def_function.function(
        root.table.lookup,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
    self.assertEqual(
        2, self.evaluate(root.table_user(constant_op.constant("gamma"))))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir)
    file_io.delete_file(self._vocab_path)
    self.assertAllClose({"output_0": [2, 0]},
                        _import_and_infer(save_dir,
                                          {"keys": ["gamma", "alpha"]}))
    second_dir = os.path.join(self.get_temp_dir(), "second_dir")
    # Asset paths should track the location the SavedModel is loaded from.
    file_io.rename(save_dir, second_dir)
    self.assertAllClose({"output_0": [2, 1]},
                        _import_and_infer(second_dir,
                                          {"keys": ["gamma", "beta"]}))

  def test_untracked_table_useful_message(self):
    """Saving a function that uses an untracked table names the table type."""
    root = module.Module()
    initializer = lookup_ops.TextFileInitializer(
        self._vocab_path,
        key_dtype=dtypes.string,
        key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
        value_dtype=dtypes.int64,
        value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
    table = lookup_ops.HashTable(initializer, default_value=-1)
    root.table_user = def_function.function(
        table.lookup,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
    root.table_user(constant_op.constant("gamma"))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12);
    # use the canonical name like the rest of this file.
    with self.assertRaisesRegex(AssertionError, "HashTable"):
      save.save(root, save_dir)

  def test_unused_asset(self):
    """An asset no function uses does not break saving or inference."""
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(
        lambda x: 2. * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    root.asset = asset.Asset(self._vocab_path)

    export_dir = os.path.join(self.get_temp_dir(), "save_dir")
    save.save(root, export_dir)
    self.assertAllClose({"output_0": [0.2]},
                        _import_and_infer(export_dir, {"x": [0.1]}))

  def test_sensible_function_building_exception(self):
    """Calling `save` from inside a tf.function raises a descriptive error."""
    root = checkpoint.Checkpoint(v=variables.Variable(2.))
    root.f = def_function.function(
        lambda x: 2. * root.v,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    export_dir = os.path.join(self.get_temp_dir(), "save_dir")

    @def_function.function
    def _calls_save():
      save.save(root, export_dir)

    with self.assertRaisesRegex(AssertionError, "tf.function"):
      _calls_save()
1072
1073
class ExportMetaGraphTests(test.TestCase):
  """Tests for `save.export_meta_graph`."""

  def test_export_meta_graph(self):
    """Signatures in an exported MetaGraph can be run from a TF1 session."""
    root = autotrackable.AutoTrackable()
    root.variable = resource_variable_ops.UninitializedVariable(
        name="some_variable", dtype=dtypes.float32)

    @def_function.function(input_signature=[tensor_spec.TensorSpec(None)])
    def multiply_var(x):
      return root.variable * x

    @def_function.function(input_signature=[tensor_spec.TensorSpec([])])
    def update(y):
      root.variable.assign_add(y)
      # TODO(b/150393409): All functions exported as signatures must have at
      # least one output.
      return 0

    @def_function.function(input_signature=[])
    def initialize():
      root.variable.assign(1.0)
      # TODO(b/150393409): All functions exported as signatures must have at
      # least one output.
      return 0

    meta_graph_path = os.path.join(self.get_temp_dir(), "meta_graph.pb")
    signature_functions = {
        "multiply_var": multiply_var,
        "initialize": initialize,
        "update": update
    }
    save.export_meta_graph(
        root, meta_graph_path, signatures=signature_functions)

    with ops.Graph().as_default(), session_lib.Session() as sess:
      saver.import_meta_graph(meta_graph_path)
      exported = meta_graph.read_meta_graph_file(meta_graph_path)

      def call(signature_key, feeds):
        # Thin adapter so each step below reads as (signature, inputs).
        return _run_signature(sess, exported, feeds, signature_key)

      # Initialize variable to 1.
      call("initialize", {})
      self.assertAllEqual(call("multiply_var", {"x": 3}), {"output_0": 3})

      # Adds 2 to the variable. Variable is now 3.
      call("update", {"y": 2})
      self.assertAllEqual(call("multiply_var", {"x": 4}), {"output_0": 12})
1122
1123
class FingerprintingTests(test.TestCase):
  """Tests for the `saved_model_fingerprinting` config flag."""

  def test_toggle_flag(self):
    """The fingerprinting flag defaults to off and can be switched on."""
    self.assertFalse(flags.config().saved_model_fingerprinting.value())
    # The flag is process-wide; restore the default when this test finishes so
    # the change does not leak into other tests run in the same process.
    self.addCleanup(
        lambda: flags.config().saved_model_fingerprinting.reset(False))
    flags.config().saved_model_fingerprinting.reset(True)
    self.assertTrue(flags.config().saved_model_fingerprinting.value())
1130
1131
if __name__ == "__main__":
  # Entry point when this file is executed directly: hand control to the
  # TensorFlow test runner imported as `test` (tensorflow.python.eager.test).
  test.main()
1134