# /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================*/
r"""Generate models in testdata for use in tests.

If this script is being run via `<build-cmd> run`, pass an absolute path.
Otherwise, this script will attempt to write to a non-writable directory.

Example:
<build-cmd> run //third_party/tensorflow/cc/experimental/libtf:generate_testdata \
  -- \
  --path=`pwd`/third_party/tensorflow/cc/experimental/libtf/tests/testdata/ \
  --model_name=simple-model
"""
import os

from absl import app
from absl import flags

from tensorflow.python.compat import v2_compat
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import saved_model

# Both flags are marked as required in the `__main__` guard at the bottom of
# this file, so their `None` defaults are never used on a normal run.
TESTDATA_PATH = flags.DEFINE_string(
    "path", None, help="Path to testdata directory.")

MODEL_NAME = flags.DEFINE_string(
    "model_name", None, help="Name of model to generate.")


class DataStructureModel(module.Module):
  """Model used for testing data structures in the C++ API."""

  def __init__(self):
    """Populates the module with list and dict attributes of several kinds."""
    # Lists holding a Python float, a constant tensor, and variables.
    self.arr1 = [1.]
    self.const_arr = [constant_op.constant(1.)]
    self.var_arr = [variables.Variable(1.), variables.Variable(2.)]
    # Dicts keyed by string, holding a Python float and variables.
    self.dict1 = {"a": 1.}
    self.var_dict = {"a": variables.Variable(1.), "b": variables.Variable(2.)}


class SimpleModel(module.Module):
  """A simple model used for exercising the C++ API."""

  @def_function.function(input_signature=[
      tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
  ])
  def test_float(self, x):
    """Returns the scalar float32 input `x` multiplied by 3.0."""
    return constant_op.constant(3.0) * x

  @def_function.function(input_signature=[
      tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32),
  ])
  def test_int(self, x):
    """Returns the scalar int32 input `x` multiplied by 3."""
    return constant_op.constant(3) * x

  @def_function.function(input_signature=[
      tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
      tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
  ])
  def test_add(self, x, y):
    """Returns the sum of the scalar float32 inputs `x` and `y`."""
    # Test a function with multiple arguments.
    return x + y


# Maps each accepted --model_name value to the class that builds that model.
TEST_MODELS = {
    "simple-model": SimpleModel,
    "data-structure-model": DataStructureModel,
}


def get_model(name):
  """Instantiates the test model registered under `name`.

  Args:
    name: Key into `TEST_MODELS`, e.g. "simple-model".

  Returns:
    A new instance of the class that `TEST_MODELS` maps `name` to.

  Raises:
    ValueError: If `name` is not a key of `TEST_MODELS`.
  """
  if name not in TEST_MODELS:
    # Interpolate the offending name so the error says what was rejected.
    raise ValueError("Model name '{}' not in TEST_MODELS".format(name))
  return TEST_MODELS[name]()


def main(unused_argv):
  """Builds the model named by --model_name and saves it under --path.

  The model is written with `saved_model.save` to the directory
  `<path>/<model_name>`.

  Args:
    unused_argv: Positional command-line arguments passed by `app.run`;
      ignored.
  """
  model = get_model(MODEL_NAME.value)
  path = os.path.join(TESTDATA_PATH.value, MODEL_NAME.value)
  saved_model.save(model, path)


if __name__ == "__main__":
  v2_compat.enable_v2_behavior()
  # Fail fast at flag-parsing time rather than on a None path/model name.
  flags.mark_flag_as_required("path")
  flags.mark_flag_as_required("model_name")
  app.run(main)