xref: /aosp_15_r20/external/tensorflow/tensorflow/cc/experimental/libtf/tests/generate_testdata.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#   http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================*/
15r"""Generate models in testdata for use in tests.
16
17If this script is being run via `<build-cmd> run`, pass an absolute path.
18Otherwise, this script will attempt to write to a non-writable directory.
19
20Example:
21<build-cmd> run //third_party/tensorflow/cc/experimental/libtf:generate_testdata
22 -- \
 --path=`pwd`/third_party/tensorflow/cc/experimental/libtf/tests/testdata/ \
24 --model_name=simple-model
25"""
26import os
27
28from absl import app
29from absl import flags
30
31from tensorflow.python.compat import v2_compat
32from tensorflow.python.eager import def_function
33from tensorflow.python.framework import constant_op
34from tensorflow.python.framework import dtypes
35from tensorflow.python.framework import tensor_spec
36from tensorflow.python.module import module
37from tensorflow.python.ops import variables
38from tensorflow.python.saved_model import saved_model
39
# Directory that receives the generated SavedModel; main() joins it with the
# --model_name value to form the export directory.
TESTDATA_PATH = flags.DEFINE_string(
    name="path",
    default=None,
    help="Path to testdata directory.",
)

# Which entry of TEST_MODELS to instantiate and save.
MODEL_NAME = flags.DEFINE_string(
    name="model_name",
    default=None,
    help="Name of model to generate.",
)
45
46
class DataStructureModel(module.Module):
  """Model used for testing data structures in the C++ API.

  Holds Python lists and dicts whose elements are plain floats, constant
  tensors, or variables.
  """

  def __init__(self):
    # Run tf.Module's initializer first so `name` / `name_scope` are set up,
    # exactly as for SimpleModel, which simply inherits Module.__init__.
    super().__init__()
    # List of a plain Python float.
    self.arr1 = [1.]
    # List holding a single constant tensor.
    self.const_arr = [constant_op.constant(1.)]
    # List of two variables.
    self.var_arr = [variables.Variable(1.), variables.Variable(2.)]
    # Dict mapping a string key to a plain Python float.
    self.dict1 = {"a": 1.}
    # Dict mapping string keys to variables.
    self.var_dict = {"a": variables.Variable(1.), "b": variables.Variable(2.)}
56
57
class SimpleModel(module.Module):
  """A simple model used for exercising the C++ API.

  Exposes three `tf.function`s over scalar inputs, each with a fixed input
  signature.
  """

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec((), dtypes.float32),))
  def test_float(self, x):
    """Returns the scalar float32 `x` multiplied by a constant 3.0."""
    three = constant_op.constant(3.0)
    return three * x

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec((), dtypes.int32),))
  def test_int(self, x):
    """Returns the scalar int32 `x` multiplied by a constant 3."""
    three = constant_op.constant(3)
    return three * x

  @def_function.function(input_signature=(
      tensor_spec.TensorSpec((), dtypes.float32),
      tensor_spec.TensorSpec((), dtypes.float32),
  ))
  def test_add(self, x, y):
    """Returns `x + y`; exercises a function taking more than one argument."""
    total = x + y
    return total
80
81
# Models this script can generate, keyed by the --model_name value. Each value
# is called with no arguments by get_model() to build the model.
TEST_MODELS = {
    # Scalar tf.functions (float, int and two-argument add).
    "simple-model": SimpleModel,
    # Lists/dicts of floats, constant tensors and variables.
    "data-structure-model": DataStructureModel,
}
86
87
def get_model(name, *, models=None):
  """Instantiates the test model registered under `name`.

  Args:
    name: Key of the model to build, e.g. "simple-model".
    models: Optional mapping from model name to a zero-argument factory
      (typically a model class). Defaults to `TEST_MODELS`.

  Returns:
    A new object produced by calling the registered factory.

  Raises:
    ValueError: If `name` is not a key of the registry.
  """
  registry = TEST_MODELS if models is None else models
  try:
    factory = registry[name]
  except KeyError:
    # Report the unknown name and the valid choices so the failure is
    # actionable; `from None` hides the internal KeyError from the traceback.
    raise ValueError(
        "Model name '{}' not in TEST_MODELS. Available models: {}".format(
            name, ", ".join(map(str, registry)))) from None
  return factory()
92
93
def main(unused_argv):
  """Builds the model selected by --model_name and saves it under --path.

  The SavedModel is written to os.path.join(--path, --model_name).

  Args:
    unused_argv: Positional command-line arguments; ignored.
  """
  del unused_argv  # Unused; configuration comes from the flags.
  model_name = MODEL_NAME.value
  model = get_model(model_name)
  export_dir = os.path.join(TESTDATA_PATH.value, model_name)
  saved_model.save(model, export_dir)
99
100
if __name__ == "__main__":
  # Enable TF2 behavior before app.run() parses flags and invokes main().
  v2_compat.enable_v2_behavior()
  # Both flags must be supplied on the command line.
  for required_flag in ("path", "model_name"):
    flags.mark_flag_as_required(required_flag)
  app.run(main)
106