xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/make_serve_slices_test_graph.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2021 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Writes a GraphDef to a file for testing `ServeSlices`."""
15
16from absl import app
17from absl import flags
18import tensorflow as tf
19
20from fcp.tensorflow import serve_slices
21
# Fixed tensor names and argument values baked into the generated graph. The
# code that later loads and runs the written GraphDef must use the same names
# to feed input and fetch output.

# Name of the `tf.string` placeholder fed as `callback_token` to `serve_slices`.
CALLBACK_TOKEN_PLACEHOLDER_TENSOR = 'callback_token'
# Name of the identity op that re-exposes the value returned by `serve_slices`.
SERVED_AT_TENSOR = 'served_at_id'
# Passed verbatim as `serve_slices(server_val=...)`.
SERVER_VAL = (1, 2.0, 'foo')
# Passed verbatim as `serve_slices(max_key=...)`.
MAX_KEY = 44
# Passed verbatim as `serve_slices(select_fn_initialize_op=...)`.
SELECT_FN_INITIALIZE_OP = 'init_the_things'
# Passed as `select_fn_server_val_input_tensor_names`. It has one name per
# element of SERVER_VAL (both have three entries); presumably they correspond
# positionally. Verify against `serve_slices`.
SELECT_FN_SERVER_VAL_INPUT_TENSOR_NAMES = ['a', 'b', 'c']
# Passed as `serve_slices(select_fn_key_input_tensor_name=...)`.
SELECT_FN_KEY_INPUT_TENSOR_NAME = 'bar'
# Passed as `serve_slices(select_fn_filename_input_tensor_name=...)`.
# NOTE(review): the constant omits "INPUT", unlike the keyword it feeds.
SELECT_FN_FILENAME_TENSOR_NAME = 'goofy'
# Passed as `serve_slices(select_fn_target_tensor_name=...)`.
SELECT_FN_TARGET_TENSOR_NAME = 'goobler'

# Destination path for the GraphDef written by `main`. It has no default; the
# `__main__` guard marks it as required.
flags.DEFINE_string('output', None, 'The path to the output file.')
FLAGS = flags.FLAGS
34
35
def make_graph():
  """Constructs a `tf.Graph` that invokes `ServeSlices` on fixed test inputs.

  The graph exposes two well-known tensor names, so whatever runs it can wire
  up its input and output without inspecting the graph:

  * `CALLBACK_TOKEN_PLACEHOLDER_TENSOR`: a `tf.string` placeholder that is fed
    to `serve_slices` as its `callback_token` argument.
  * `SERVED_AT_TENSOR`: an identity of the value `serve_slices` returns.

  Every other `serve_slices` argument comes from the module-level constants.

  Returns:
    The newly constructed `tf.Graph`.
  """
  # Static (non-tensor) arguments; building this dict creates no TF ops.
  static_args = dict(
      server_val=SERVER_VAL,
      max_key=MAX_KEY,
      select_fn_initialize_op=SELECT_FN_INITIALIZE_OP,
      select_fn_server_val_input_tensor_names=(
          SELECT_FN_SERVER_VAL_INPUT_TENSOR_NAMES),
      select_fn_key_input_tensor_name=SELECT_FN_KEY_INPUT_TENSOR_NAME,
      select_fn_filename_input_tensor_name=SELECT_FN_FILENAME_TENSOR_NAME,
      select_fn_target_tensor_name=SELECT_FN_TARGET_TENSOR_NAME,
  )

  result_graph = tf.Graph()
  with result_graph.as_default():
    # Input contract: a placeholder under a fixed, known name.
    token = tf.compat.v1.placeholder(
        dtype=tf.string, name=CALLBACK_TOKEN_PLACEHOLDER_TENSOR)
    served_at = serve_slices.serve_slices(callback_token=token, **static_args)
    # Output contract: re-expose the result under a fixed, known name.
    tf.identity(served_at, name=SERVED_AT_TENSOR)
  return result_graph
57
58
def main(argv):
  """Writes the text-format `GraphDef` built by `make_graph()` to `--output`.

  Args:
    argv: Positional command-line arguments forwarded by `app.run`; unused.
  """
  del argv  # Unused.
  # `str()` on a protobuf message returns its text-format representation.
  # Serialize before opening the file, so that a failure while building the
  # graph does not truncate an existing output file.
  graph_def_str = str(make_graph().as_graph_def())
  # Pin the encoding so the bytes written do not depend on the host locale.
  with open(FLAGS.output, 'w', encoding='utf-8') as output_file:
    output_file.write(graph_def_str)
64
65
if __name__ == '__main__':
  # `--output` is defined with a default of None, so make absl reject any
  # invocation that omits it. This must be registered before `app.run` parses
  # argv and calls `main`.
  flags.mark_flag_as_required('output')
  app.run(main)
69