# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Writes a GraphDef to a file for testing `ServeSlices`."""

from absl import app
from absl import flags
import tensorflow as tf

from fcp.tensorflow import serve_slices

# Name of the string placeholder through which the code running the graph
# feeds in the callback token.
CALLBACK_TOKEN_PLACEHOLDER_TENSOR = 'callback_token'
# Name given (via `tf.identity`) to the output of `serve_slices`, so the code
# running the graph can fetch it by a stable name.
SERVED_AT_TENSOR = 'served_at_id'

# Fixed arguments baked into the `serve_slices` call below.
# NOTE(review): the code that loads the written graph presumably checks for
# these exact values; keep them in sync with that consumer.
SERVER_VAL = (1, 2.0, 'foo')
MAX_KEY = 44
SELECT_FN_INITIALIZE_OP = 'init_the_things'
# One name per element of `SERVER_VAL` (both have three entries).
SELECT_FN_SERVER_VAL_INPUT_TENSOR_NAMES = ['a', 'b', 'c']
SELECT_FN_KEY_INPUT_TENSOR_NAME = 'bar'
SELECT_FN_FILENAME_TENSOR_NAME = 'goofy'
SELECT_FN_TARGET_TENSOR_NAME = 'goobler'

flags.DEFINE_string('output', None, 'The path to the output file.')
FLAGS = flags.FLAGS


def make_graph():
  """Builds and returns a `tf.Graph` which calls `ServeSlices`.

  The graph has one string placeholder, named
  `CALLBACK_TOKEN_PLACEHOLDER_TENSOR`, which is passed to
  `serve_slices.serve_slices` along with the fixed module-level arguments.
  The result of that call is exposed under the name `SERVED_AT_TENSOR`.

  Returns:
    The newly constructed `tf.Graph`.
  """
  graph = tf.Graph()
  with graph.as_default():
    # Create a placeholder with a fixed name to allow the code running the graph
    # to provide input.
    callback_token = tf.compat.v1.placeholder(
        name=CALLBACK_TOKEN_PLACEHOLDER_TENSOR, dtype=tf.string)
    served_at_id = serve_slices.serve_slices(
        callback_token=callback_token,
        server_val=SERVER_VAL,
        max_key=MAX_KEY,
        select_fn_initialize_op=SELECT_FN_INITIALIZE_OP,
        select_fn_server_val_input_tensor_names=SELECT_FN_SERVER_VAL_INPUT_TENSOR_NAMES,
        select_fn_key_input_tensor_name=SELECT_FN_KEY_INPUT_TENSOR_NAME,
        select_fn_filename_input_tensor_name=SELECT_FN_FILENAME_TENSOR_NAME,
        select_fn_target_tensor_name=SELECT_FN_TARGET_TENSOR_NAME)
    # Create a tensor with a fixed name to allow the code running the graph to
    # receive output.
    tf.identity(served_at_id, name=SERVED_AT_TENSOR)
  return graph


def main(argv):
  """Writes the text form of `make_graph()`'s GraphDef to `--output`.

  Args:
    argv: Positional command-line arguments supplied by `app.run`; unused.
  """
  del argv  # Unused.
  # `str()` on a protobuf message yields its text-format representation.
  graph_def_str = str(make_graph().as_graph_def())
  # Use an explicit encoding so the output does not depend on the platform's
  # locale default.
  with open(FLAGS.output, 'w', encoding='utf-8') as output_file:
    output_file.write(graph_def_str)


if __name__ == '__main__':
  flags.mark_flag_as_required('output')
  app.run(main)