xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Serves as a common "main" function for all the SavedModel V1 tests.

There is a fair amount of setup needed to initialize tensorflow and get it
into the TF1 graph execution mode (resource variables enabled, eager execution
disabled) that these tests need. This hides that boilerplate.
"""
20
import os
import tempfile

from absl import app
from absl import flags
from absl import logging
import tensorflow.compat.v1 as tf

from tensorflow.python import pywrap_mlir  # pylint: disable=g-direct-tensorflow-import
28
# When --save_model_path is left empty, do_test exports to a fresh temporary
# path. Pointing it at a fixed location (e.g. under /tmp) makes debugging the
# tests easier (see README.md).
flags.DEFINE_string(
    name='save_model_path',
    default='',
    help='Path to save the model to.')
FLAGS = flags.FLAGS
32
33
def set_tf_options():
  """Configures TensorFlow for the SavedModel V1 importer tests.

  Default TF1.x uses reference variables, which are not supported by the
  SavedModel V1 importer, so resource variables are enabled here. Eager
  execution is disabled so graph-mode constructs such as `tf.Session` work.
  """
  tf.enable_resource_variables()
  # `tf` is already `tensorflow.compat.v1`; call the v1 API directly instead
  # of going through the redundant `tf.compat.v1` path.
  tf.disable_eager_execution()
40
41
# This function takes a `create_signature` functor, as opposed to the
# signature map itself, because building the graph has to be delayed until
# after absl and tensorflow have run various initialization steps.
def do_test(create_signature,
            canonicalize=False,
            show_debug_info=False,
            use_lite=False,
            lift_variables=True):
  """Runs test.

  1. Performs absl and tf "main"-like initialization that must run before almost
     anything else.
  2. Converts signature_def_map to SavedModel V1.
  3. Converts SavedModel V1 to MLIR.
  4. Prints the textual MLIR to stdout (it is expected that the caller will have
     FileCheck checks in its file to check this output).

  This is only for use by the MLIR SavedModel importer tests.

  Args:
    create_signature: A functor that returns a tuple of (signature_def_map,
      init_op, assets_collection). signature_def_map is a map from string key
      to signature_def. The key will be used as function name in the resulting
      MLIR. init_op is passed to the SavedModelBuilder as `main_op`, and
      assets_collection as `assets_collection`.
    canonicalize: If true, canonicalizer will be run on the resulting MLIR.
    show_debug_info: If true, shows debug locations in the resulting MLIR.
    use_lite: If true, importer will not do any graph transformation such as
      lift variables.
    lift_variables: If false, no variable lifting will be done on the graph.
      Ignored when `use_lite` is true.
  """

  # Make LOG(ERROR) in C++ code show up on the console.
  # All `Status` passed around in the C++ API seem to eventually go into
  # `LOG(ERROR)`, so this makes them print out by default.
  logging.set_stderrthreshold('error')

  def app_main(argv):
    """Function passed to absl.app.run."""
    if len(argv) > 1:
      raise app.UsageError('Too many command-line arguments.')
    if FLAGS.save_model_path:
      save_model_path = FLAGS.save_model_path
    else:
      # tempfile.mktemp() is deprecated and racy (another process can claim
      # the name first). Create a private directory atomically and leave the
      # export path itself nonexistent for SavedModelBuilder to create.
      save_model_path = os.path.join(tempfile.mkdtemp(), 'model.saved_model')

    signature_def_map, init_op, assets_collection = create_signature()

    # The session is only needed while exporting; the context manager makes
    # sure it is closed afterwards instead of leaking.
    with tf.Session() as sess:
      sess.run(tf.initializers.global_variables())
      builder = tf.saved_model.builder.SavedModelBuilder(save_model_path)
      builder.add_meta_graph_and_variables(
          sess, [tf.saved_model.tag_constants.SERVING],
          signature_def_map,
          main_op=init_op,
          assets_collection=assets_collection,
          strip_default_attrs=True)
      builder.save()

    logging.info('Saved model to: %s', save_model_path)
    exported_names = ''
    upgrade_legacy = True
    # Comma-separated tag set of the meta graph written above.
    tags = ','.join([tf.saved_model.tag_constants.SERVING])
    if use_lite:
      mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir_lite(
          save_model_path, exported_names, tags, upgrade_legacy,
          show_debug_info)
      # We don't strictly need this, but it serves as a handy sanity check
      # for the pass-pipeline API, which is otherwise a bit annoying to test.
      # The pipeline shouldn't affect these tests in any way.
      mlir = pywrap_mlir.experimental_run_pass_pipeline(
          mlir, 'tf-standard-pipeline', show_debug_info)
    else:
      mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
          save_model_path, exported_names, tags, lift_variables,
          upgrade_legacy, show_debug_info)

    if canonicalize:
      mlir = pywrap_mlir.experimental_run_pass_pipeline(
          mlir, 'canonicalize', show_debug_info)
    print(mlir)

  app.run(app_main)
126