xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/python/serve_slices_registry_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for serve_slices_registry."""
15
16from unittest import mock
17
18from absl.testing import absltest
19import numpy as np
20import tensorflow as tf
21
22from fcp.tensorflow import serve_slices
23from fcp.tensorflow import serve_slices as serve_slices_registry
24
# Values passed to `serve_slices` as `server_val`, paired position-by-position
# with the numpy dtypes used to build the expected callback arguments.
SERVER_VAL = (1, 2.0, b'foo')
SERVER_VAL_NP_DTYPE = (np.int32, np.float32, object)

# Passed to `serve_slices` as `max_key` and expected back in the callback.
MAX_KEY = 44

# Select-fn names passed to `serve_slices`. The test only checks that the
# callback receives them unchanged, so the concrete strings are arbitrary.
SELECT_FN_INITIALIZE_OP = 'init_the_things'
SELECT_FN_TARGET_TENSOR_NAME = 'goobler'
SELECT_FN_FILENAME_TENSOR_NAME = 'goofy'
SELECT_FN_KEY_INPUT_TENSOR_NAME = 'bar'
# Kept as a list (not a tuple): the callback assertion compares with `==`.
SELECT_FN_SERVER_VAL_INPUT_TENSOR_NAMES = ['a', 'b', 'c']
33
34
class ServeSlicesRegistryTest(absltest.TestCase):
  """Tests for `register_serve_slices_callback`."""

  def test_register_serve_slices_callback(self):
    """Runs a `serve_slices` op and checks it reaches the registered callback.

    The op's output should be the callback's return value. The callback
    should be called exactly once, with the registration token followed by
    the arguments given to `serve_slices` (server values as numpy arrays).
    """
    with tf.Graph().as_default() as graph:
      # The callback token only exists once a callback has been registered
      # (below), so it enters the graph through a placeholder fed at run
      # time. The placeholder is fed by tensor object in `feed_dict`, so it
      # does not need a fixed name.
      callback_token = tf.compat.v1.placeholder(dtype=tf.string)
      served_at_id = serve_slices.serve_slices(
          callback_token=callback_token,
          server_val=SERVER_VAL,
          max_key=MAX_KEY,
          select_fn_initialize_op=SELECT_FN_INITIALIZE_OP,
          select_fn_server_val_input_tensor_names=(
              SELECT_FN_SERVER_VAL_INPUT_TENSOR_NAMES),
          select_fn_key_input_tensor_name=SELECT_FN_KEY_INPUT_TENSOR_NAME,
          select_fn_filename_input_tensor_name=SELECT_FN_FILENAME_TENSOR_NAME,
          select_fn_target_tensor_name=SELECT_FN_TARGET_TENSOR_NAME)

    served_at_value = 'address.at.which.data.is.served'
    mock_callback = mock.Mock(return_value=served_at_value)
    # Run the graph while inside the registration context, feeding in the
    # token it yields.
    with serve_slices_registry.register_serve_slices_callback(
        mock_callback) as token:
      with tf.compat.v1.Session(graph=graph) as session:
        served_at_out = session.run(
            served_at_id, feed_dict={callback_token: token})

    # The callback returns a str; the op's output is expected back as bytes.
    self.assertEqual(served_at_out, served_at_value.encode())

    expected_server_val = [
        np.array(v, dtype=dtype)
        for v, dtype in zip(SERVER_VAL, SERVER_VAL_NP_DTYPE)
    ]
    # NOTE(review): mock compares call arguments with `==`. On these 0-d
    # arrays that compares values only, so the dtypes listed in
    # SERVER_VAL_NP_DTYPE are not actually enforced by this assertion.
    mock_callback.assert_called_once_with(
        token,
        expected_server_val,
        MAX_KEY,
        SELECT_FN_INITIALIZE_OP,
        SELECT_FN_SERVER_VAL_INPUT_TENSOR_NAMES,
        SELECT_FN_KEY_INPUT_TENSOR_NAME,
        SELECT_FN_FILENAME_TENSOR_NAME,
        SELECT_FN_TARGET_TENSOR_NAME,
    )
73
74
if __name__ == '__main__':
  # Run this module's tests with absltest's runner when executed directly.
  absltest.main()
77