# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Serves as a common "main" function for all the SavedModel tests.

There is a fair amount of setup needed to initialize tensorflow and get it
into a proper TF2 execution mode. This hides that boilerplate.
"""

import os
import tempfile
from absl import app
from absl import flags
from absl import logging
import tensorflow.compat.v1 as tf

from tensorflow.python import pywrap_mlir  # pylint: disable=g-direct-tensorflow-import

# Use /tmp to make debugging the tests easier (see README.md)
flags.DEFINE_string('save_model_path', '', 'Path to save the model to.')
FLAGS = flags.FLAGS


def set_tf_options():
  """Switches TF to resource variables and disables eager execution."""
  # Default TF1.x uses reference variables that are not supported by SavedModel
  # v1 Importer. To use SavedModel V1 Importer, resource variables should be
  # enabled.
  tf.enable_resource_variables()
  # `tf` is already `tensorflow.compat.v1`, so call the v1 symbol directly.
  tf.disable_eager_execution()


# This function needs to take a "create_signature" callable, as opposed to just
# the signature_def_map itself, because the creation of the graph has to be
# delayed until after absl and tensorflow have run various initialization
# steps.
def do_test(create_signature,
            canonicalize=False,
            show_debug_info=False,
            use_lite=False,
            lift_variables=True):
  """Runs test.

  1. Performs absl and tf "main"-like initialization that must run before almost
     anything else.
  2. Converts signature_def_map to SavedModel V1
  3. Converts SavedModel V1 to MLIR
  4. Prints the textual MLIR to stdout (it is expected that the caller will have
     FileCheck checks in its file to check this output).

  This is only for use by the MLIR SavedModel importer tests.

  The SavedModel is written to the path given by the `--save_model_path` flag
  or, when the flag is empty, to a freshly created temporary location.

  Args:
    create_signature: A functor that returns signature_def_map, init_op and
      assets_collection. signature_def_map is a map from string key to
      signature_def. The key will be used as function name in the resulting
      MLIR.
    canonicalize: If true, canonicalizer will be run on the resulting MLIR.
    show_debug_info: If true, shows debug locations in the resulting MLIR.
    use_lite: If true, importer will not do any graph transformation such as
      lift variables.
    lift_variables: If false, no variable lifting will be done on the graph.
      Only consulted when `use_lite` is false.
  """

  # Make LOG(ERROR) in C++ code show up on the console.
  # All `Status` passed around in the C++ API seem to eventually go into
  # `LOG(ERROR)`, so this makes them print out by default.
  logging.set_stderrthreshold('error')

  def app_main(argv):
    """Function passed to absl.app.run."""
    if len(argv) > 1:
      raise app.UsageError('Too many command-line arguments.')

    if FLAGS.save_model_path:
      save_model_path = FLAGS.save_model_path
    else:
      # tempfile.mktemp() is deprecated and racy (the returned name can be
      # claimed by someone else before it is used). Create a private directory
      # securely instead and export into a child path that does not exist yet,
      # so the builder still receives a fresh, non-existent path.
      save_model_path = os.path.join(
          tempfile.mkdtemp(suffix='.saved_model'), 'model')

    signature_def_map, init_op, assets_collection = create_signature()

    # The same tag set is used both for exporting and for importing.
    tags = [tf.saved_model.tag_constants.SERVING]

    # Use the session as a context manager so it is closed once the model has
    # been written, instead of leaking it for the rest of the process.
    with tf.Session() as sess:
      sess.run(tf.initializers.global_variables())
      builder = tf.saved_model.builder.SavedModelBuilder(save_model_path)
      builder.add_meta_graph_and_variables(
          sess,
          tags,
          signature_def_map,
          main_op=init_op,
          assets_collection=assets_collection,
          strip_default_attrs=True)
      builder.save()

    logging.info('Saved model to: %s', save_model_path)
    exported_names = ''
    upgrade_legacy = True
    joined_tags = ','.join(tags)
    if use_lite:
      mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir_lite(
          save_model_path, exported_names, joined_tags, upgrade_legacy,
          show_debug_info)
      # We don't strictly need this, but it serves as a handy sanity check
      # for that API, which is otherwise a bit annoying to test.
      # The canonicalization shouldn't affect these tests in any way.
      mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir,
                                                        'tf-standard-pipeline',
                                                        show_debug_info)
    else:
      mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
          save_model_path, exported_names, joined_tags, lift_variables,
          upgrade_legacy, show_debug_info)

    if canonicalize:
      mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir, 'canonicalize',
                                                        show_debug_info)
    print(mlir)

  app.run(app_main)