# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Standalone utility to generate some test saved models.

Invoked as `<program> {export_path} {ModuleName}`, where `ModuleName` is one
of the keys of `MODULE_CTORS`.
"""

import os

from absl import app

from tensorflow.python.client import session as session_lib
from tensorflow.python.compat import v2_compat
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import saved_model
from tensorflow.python.trackable import asset


class VarsAndArithmeticObjectGraph(module.Module):
  """Three vars (one in a sub-module) and compute method.

  Holds variables `x` (1.0) and `y` (2.0), plus a child module with variable
  `z` (3.0) and a constant tensor `c` (5.0).
  """

  def __init__(self):
    self.x = variables.Variable(1.0, name="variable_x")
    self.y = variables.Variable(2.0, name="variable_y")
    self.child = module.Module()
    self.child.z = variables.Variable(3.0, name="child_variable")
    self.child.c = ops.convert_to_tensor(5.0)

  @def_function.function(input_signature=[
      tensor_spec.TensorSpec((), dtypes.float32),
      tensor_spec.TensorSpec((), dtypes.float32)
  ])
  def compute(self, a, b):
    """Returns `(a + x) * (b + y) / z + c` for scalar float32 `a` and `b`."""
    return (a + self.x) * (b + self.y) / (self.child.z) + self.child.c


class ReferencesParent(module.Module):
  """Child module that keeps a reference back to its parent module."""

  def __init__(self, parent):
    super(ReferencesParent, self).__init__()
    self.parent = parent
    self.my_variable = variables.Variable(3., name="MyVariable")


# Creates a cyclic object graph.
class CyclicModule(module.Module):
  """Module whose child points back at it, forming a cycle."""

  def __init__(self):
    super(CyclicModule, self).__init__()
    self.child = ReferencesParent(self)


class AssetModule(module.Module):
  """Module tracking a text-file Asset and a function that reads it."""

  def __init__(self):
    self.asset = asset.Asset(
        test.test_src_dir_path("cc/saved_model/testdata/test_asset.txt"))

  @def_function.function(input_signature=[])
  def read_file(self):
    """Returns the contents of the tracked asset file."""
    return io_ops.read_file(self.asset)


class StaticHashTableModule(module.Module):
  """A module with an Asset, StaticHashTable, and a lookup function.

  The table is initialized from the asset file, keyed by each whole line
  (string) with that line's line number (int64) as value. Missing keys map
  to -1.
  """

  def __init__(self):
    self.asset = asset.Asset(
        test.test_src_dir_path(
            "cc/saved_model/testdata/static_hashtable_asset.txt"))
    self.table = lookup_ops.StaticHashTable(
        lookup_ops.TextFileInitializer(self.asset, dtypes.string,
                                       lookup_ops.TextFileIndex.WHOLE_LINE,
                                       dtypes.int64,
                                       lookup_ops.TextFileIndex.LINE_NUMBER),
        -1)

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec(shape=None, dtype=dtypes.string)])
  def lookup(self, word):
    """Returns the table values (line numbers, or -1) for string `word`."""
    return self.table.lookup(word)


def get_simple_session():
  """Returns a graph-mode Session with one initialized float variable.

  Side effect: disables eager execution for the rest of the process.
  """
  ops.disable_eager_execution()
  sess = session_lib.Session()
  variables.Variable(1.)
  sess.run(variables.global_variables_initializer())
  return sess


# Maps a ModuleName given on the command line to (constructor, version).
# Version 2 entries are saved with `saved_model.save`. Version 1 entries
# return a Session, which is saved with the TF1 `SavedModelBuilder`.
MODULE_CTORS = {
    "VarsAndArithmeticObjectGraph": (VarsAndArithmeticObjectGraph, 2),
    "CyclicModule": (CyclicModule, 2),
    "AssetModule": (AssetModule, 2),
    "StaticHashTableModule": (StaticHashTableModule, 2),
    "SimpleV1Model": (get_simple_session, 1)
}


def main(args):
  """Builds the requested module and writes it as a SavedModel.

  Args:
    args: Command-line arguments as passed by `app.run`: the program name,
      the export path, and the module name (a key of `MODULE_CTORS`).

  Returns:
    None on success, 1 if the argument count is wrong, or 2 if the module
    name is not a key of `MODULE_CTORS`.

  Raises:
    OSError: from `os.makedirs` if `export_path` cannot be created (e.g. it
      already exists).
  """
  if len(args) != 3:
    print("Expected: {export_path} {ModuleName}")
    print("Allowed ModuleNames:", MODULE_CTORS.keys())
    return 1

  _, export_path, module_name = args
  # Look the entry up before unpacking: unpacking `MODULE_CTORS.get(...)`
  # directly raises TypeError for an unknown name. That would skip this
  # friendly error and its dedicated exit code.
  entry = MODULE_CTORS.get(module_name)
  if entry is None:
    print("Expected ModuleName to be one of:", MODULE_CTORS.keys())
    return 2
  module_ctor, version = entry
  os.makedirs(export_path)

  tf_module = module_ctor()
  if version == 2:
    options = save_options.SaveOptions(save_debug_info=True)
    saved_model.save(tf_module, export_path, options=options)
  else:
    # Version 1 constructors return a Session, which is saved as a TF1
    # SavedModel tagged "serve".
    builder = saved_model.builder.SavedModelBuilder(export_path)
    builder.add_meta_graph_and_variables(tf_module, ["serve"])
    builder.save()


if __name__ == "__main__":
  v2_compat.enable_v2_behavior()
  app.run(main)