xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/serve_slices.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2021 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Provides the `serve_slices` operation.
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard WorkerThis wraps the generated op and ensures that necessary shared libraries
17*14675a02SAndroid Build Coastguard Workerare loaded.
18*14675a02SAndroid Build Coastguard Worker"""
19*14675a02SAndroid Build Coastguard Worker
20*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf
21*14675a02SAndroid Build Coastguard Worker
22*14675a02SAndroid Build Coastguard Workerfrom fcp.tensorflow import _serve_slices_op
23*14675a02SAndroid Build Coastguard Workerfrom fcp.tensorflow import gen_serve_slices_py
24*14675a02SAndroid Build Coastguard Worker
# Load the shared library shipped next to this module at import time, so the
# op wrapped below is available before any caller uses it. The returned handle
# is kept at module level.
_serve_slices_so = tf.load_op_library(
    tf.compat.v1.resource_loader.get_path_to_datafile(
        './_serve_slices_op.so'
    )
)
27*14675a02SAndroid Build Coastguard Worker
28*14675a02SAndroid Build Coastguard Worker
def _to_tensor_list(list_of_python_values, dtype=None):
  """Converts every element of an iterable to a tensor.

  Args:
    list_of_python_values: An iterable of values, each of which is passed to
      `tf.convert_to_tensor` individually.
    dtype: Optional dtype forwarded to every `tf.convert_to_tensor` call. When
      `None`, no dtype is requested.

  Returns:
    A list of the converted values, in iteration order.
  """
  tensors = []
  for python_value in list_of_python_values:
    tensors.append(tf.convert_to_tensor(python_value, dtype=dtype))
  return tensors
34*14675a02SAndroid Build Coastguard Worker
35*14675a02SAndroid Build Coastguard Worker
def serve_slices(callback_token, server_val, max_key, select_fn_initialize_op,
                 select_fn_server_val_input_tensor_names,
                 select_fn_key_input_tensor_name,
                 select_fn_filename_input_tensor_name,
                 select_fn_target_tensor_name):
  """Invokes a preregistered `callback_token` to serve slices of a value.

  Besides the arguments listed here, `serve_slices` requires that a TensorFlow
  graph containing a selection function (`select_fn`) be provided to the server
  running `serve_slices`. This function is responsible for telling the server
  the names of the placeholder tensor inputs of the selection function
  (`select_fn_server_val_input_tensor_names`,
  `select_fn_key_input_tensor_name`, and
  `select_fn_filename_input_tensor_name`) and the target tensor to evaluate so
  that the slice is written to the provided filename
  (`select_fn_target_tensor_name`).

  Args:
    callback_token: A string ID corresponding to a callback registered with the
      `register_serve_slices_callback` function. That callback is invoked when
      `serve_slices` is called.
    server_val: A list of arbitrary-typed tensors from which slices may be
      generated using `select_fn`. These tensors must be passed into
      `select_fn` by writing them to the placeholder tensors named by
      `select_fn_server_val_input_tensor_names`, which must contain exactly
      one tensor name for each tensor in `server_val`.
    max_key: An integer indicating the maximum slice index which may be
      requested. Slice indices start at zero and may go up to `max_key`
      (inclusive).
    select_fn_initialize_op: An op to run before each call to `select_fn` in
      order to reinitialize any state `select_fn` may contain.
    select_fn_server_val_input_tensor_names: A list of names of the tensors
      that make up the `server_val` portion of the inputs to `select_fn`. Must
      be the same length as the number of tensors in `server_val`.
    select_fn_key_input_tensor_name: The name of the tensor that is the `key`
      input to `select_fn`.
    select_fn_filename_input_tensor_name: The name of the placeholder tensor
      that is the `filename` input to `select_fn`. The `filename` specifies
      where the resulting slice should be written.
    select_fn_target_tensor_name: The name of the `target` tensor to run, which
      results in `select_fn`'s output being written to `filename`.

  Returns:
    A string identifier given by the underlying callback which can be used by
    clients to access the generated slices.
  """

  def as_string_tensor(value):
    return tf.convert_to_tensor(value, dtype=tf.string)

  # The conversions below run in the same order as the op's argument list.
  token_tensor = as_string_tensor(callback_token)
  server_val_tensors = _to_tensor_list(server_val)
  max_key_tensor = tf.convert_to_tensor(max_key, dtype=tf.int32)
  initialize_op_tensor = as_string_tensor(select_fn_initialize_op)
  server_val_name_tensors = _to_tensor_list(
      select_fn_server_val_input_tensor_names, dtype=tf.string)
  key_name_tensor = as_string_tensor(select_fn_key_input_tensor_name)
  filename_name_tensor = as_string_tensor(select_fn_filename_input_tensor_name)
  target_name_tensor = as_string_tensor(select_fn_target_tensor_name)

  return gen_serve_slices_py.serve_slices(
      callback_token=token_tensor,
      server_val=server_val_tensors,
      max_key=max_key_tensor,
      select_fn_initialize_op=initialize_op_tensor,
      select_fn_server_val_input_tensor_names=server_val_name_tensors,
      select_fn_key_input_tensor_name=key_name_tensor,
      select_fn_filename_input_tensor_name=filename_name_tensor,
      select_fn_target_tensor_name=target_name_tensor)
95*14675a02SAndroid Build Coastguard Worker
96*14675a02SAndroid Build Coastguard Worker
def register_serve_slices_callback(callback):
  """Registers `callback` to be invoked by the `ServeSlices` op.

  The callback is not registered directly. It is wrapped in an adapter that
  parses each entry of `server_val` as a serialized `TensorProto` and converts
  it to an ndarray before forwarding the call.

  Args:
    callback: A callable invoked as
      `callback(callback_token, server_val, *args)`, where `server_val` is the
      list of converted ndarrays and all remaining positional arguments are
      passed through unchanged.

  Returns:
    The result of `_serve_slices_op.register_serve_slices_callback` for the
    adapter.
  """

  def callback_adapter(callback_token, server_val, *args):
    # A throwaway proto is built only to reach its `FromString` parser, which
    # turns each serialized TensorProto back into a message.
    parse_tensor_proto = tf.make_tensor_proto(0).FromString
    ndarrays = []
    for serialized in server_val:
      ndarrays.append(tf.make_ndarray(parse_tensor_proto(serialized)))
    return callback(callback_token, ndarrays, *args)

  return _serve_slices_op.register_serve_slices_callback(callback_adapter)
108