# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides the `serve_slices` operation.

This wraps the generated op and ensures that necessary shared libraries
are loaded.
"""

import tensorflow as tf

from fcp.tensorflow import _serve_slices_op
from fcp.tensorflow import gen_serve_slices_py

_serve_slices_so = tf.load_op_library(
    tf.compat.v1.resource_loader.get_path_to_datafile('./_serve_slices_op.so'))


def _to_tensor_list(list_of_python_values, dtype=None):
  """Converts each value in `list_of_python_values` to a tensor of `dtype`."""
  return [
      tf.convert_to_tensor(subvalue, dtype=dtype)
      for subvalue in list_of_python_values
  ]


def serve_slices(callback_token, server_val, max_key, select_fn_initialize_op,
                 select_fn_server_val_input_tensor_names,
                 select_fn_key_input_tensor_name,
                 select_fn_filename_input_tensor_name,
                 select_fn_target_tensor_name):
  """Calls into a preregistered `callback_token` to serve slices of a value.

  In addition to the arguments to this function, `serve_slices` requires that
  a TensorFlow graph containing a selection function (`select_fn`) be provided
  to the server running `serve_slices`. `serve_slices` is responsible for
  providing the server with the names of the placeholder tensor inputs to the
  selection function (`select_fn_server_val_input_tensor_names`,
  `select_fn_key_input_tensor_name`, and `select_fn_filename_input_tensor_name`)
  and the name of the target tensor to evaluate so that the slice is written to
  the provided filename (`select_fn_target_tensor_name`).

  Args:
    callback_token: A string ID corresponding to a callback registered with the
      `register_serve_slices_callback` function. The callback will be invoked
      when the op created by `serve_slices` is executed.
    server_val: A list of arbitrary-typed tensors from which slices may be
      generated using `select_fn`. These tensors must be passed into
      `select_fn` by writing them to the placeholder tensors named by
      `select_fn_server_val_input_tensor_names`, which must contain exactly one
      tensor name for each tensor in `server_val`.
    max_key: An integer indicating the maximum slice index that may be
      requested. Slice indices start at zero and may go up to `max_key`
      (inclusive).
    select_fn_initialize_op: The name of an op to run before each call to
      `select_fn` in order to reinitialize any state `select_fn` may contain.
    select_fn_server_val_input_tensor_names: A list of names of the placeholder
      tensors that make up the `server_val` portion of the inputs to
      `select_fn`. Must be the same length as the number of tensors in
      `server_val`.
    select_fn_key_input_tensor_name: The name of the placeholder tensor that is
      the `key` input to `select_fn`.
    select_fn_filename_input_tensor_name: The name of the placeholder tensor
      that is the `filename` input to `select_fn`. The `filename` specifies
      where the resulting slice should be written.
    select_fn_target_tensor_name: The name of the `target` tensor to run, which
      causes `select_fn`'s output to be written to `filename`.

  Returns:
    A string identifier given by the underlying callback, which clients can use
    to access the generated slices.
  """
  return gen_serve_slices_py.serve_slices(
      callback_token=tf.convert_to_tensor(callback_token, dtype=tf.string),
      server_val=_to_tensor_list(server_val),
      max_key=tf.convert_to_tensor(max_key, dtype=tf.int32),
      select_fn_initialize_op=tf.convert_to_tensor(
          select_fn_initialize_op, dtype=tf.string),
      select_fn_server_val_input_tensor_names=_to_tensor_list(
          select_fn_server_val_input_tensor_names, dtype=tf.string),
      select_fn_key_input_tensor_name=tf.convert_to_tensor(
          select_fn_key_input_tensor_name, dtype=tf.string),
      select_fn_filename_input_tensor_name=tf.convert_to_tensor(
          select_fn_filename_input_tensor_name, dtype=tf.string),
      select_fn_target_tensor_name=tf.convert_to_tensor(
          select_fn_target_tensor_name, dtype=tf.string))


def register_serve_slices_callback(callback):
  """Registers a callback to be invoked by the `ServeSlices` op.

  Args:
    callback: A callable invoked as `callback(callback_token, server_val,
      *args)`, where `server_val` is a list of NumPy arrays decoded from the
      serialized `TensorProto`s received from the op, and `args` holds the
      remaining arguments supplied by the op.

  Returns:
    The result of registering the wrapped callback with the underlying op
    library, which identifies the callback for use as the `callback_token`
    argument of `serve_slices`.
  """

  def callback_adapter(callback_token, server_val, *args):
    # Convert the serialized TensorProtos to ndarrays.
    # `FromString` is a classmethod of the protobuf message; a throwaway
    # TensorProto instance is created here only to reach it.
    tensor_proto = tf.make_tensor_proto(0)
    converted_server_val = [
        tf.make_ndarray(tensor_proto.FromString(val)) for val in server_val
    ]
    return callback(callback_token, converted_server_val, *args)

  return _serve_slices_op.register_serve_slices_callback(callback_adapter)
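

# Illustrative sketch only; nothing below is part of the fcp API. It shows the
# shape of a callback that could be passed to `register_serve_slices_callback`.
# ASSUMPTION: the arguments after `server_val` arrive in the same order as the
# parameters of `serve_slices` (`max_key`, `select_fn_initialize_op`, ...).
# The `_example_served` store and the 'served-N' identifier format are
# hypothetical. A real callback would run `select_fn` for each key and write
# each slice to the file named via `select_fn_filename_input_tensor_name`.
# Hypothetical usage, assuming registration returns the callback token:
#   token = register_serve_slices_callback(_example_serve_slices_callback)
_example_served = {}


def _example_serve_slices_callback(callback_token, server_val, max_key,
                                   select_fn_initialize_op,
                                   select_fn_server_val_input_tensor_names,
                                   select_fn_key_input_tensor_name,
                                   select_fn_filename_input_tensor_name,
                                   select_fn_target_tensor_name):
  """Records a serving request and returns a hypothetical slices identifier."""
  # One placeholder name is required per tensor in `server_val`.
  if len(server_val) != len(select_fn_server_val_input_tensor_names):
    raise ValueError(
        'Expected one input tensor name per tensor in `server_val`.')
  request = {
      'callback_token': callback_token,
      'server_val': list(server_val),
      # Slice keys run from 0 up to and including `max_key`.
      'keys': list(range(int(max_key) + 1)),
      'initialize_op': select_fn_initialize_op,
      'server_val_input_names': list(select_fn_server_val_input_tensor_names),
      'key_input_name': select_fn_key_input_tensor_name,
      'filename_input_name': select_fn_filename_input_tensor_name,
      'target_name': select_fn_target_tensor_name,
  }
  served_id = 'served-{}'.format(len(_example_served))
  _example_served[served_id] = request
  # The `serve_slices` docstring says the callback returns a string identifier
  # that clients can use to access the generated slices.
  return served_id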