# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for serve_slices_registry."""

from unittest import mock

from absl.testing import absltest
import numpy as np
import tensorflow as tf

from fcp.tensorflow import serve_slices
from fcp.tensorflow import serve_slices as serve_slices_registry

# Arguments passed to the `serve_slices` op. The test checks that each one is
# forwarded to the registered callback.
SERVER_VAL = (1, 2.0, b'foo')
# NumPy dtypes used to build the arrays the callback is expected to receive,
# one per element of SERVER_VAL.
SERVER_VAL_NP_DTYPE = (np.int32, np.float32, object)
MAX_KEY = 44
SELECT_FN_INITIALIZE_OP = 'init_the_things'
SELECT_FN_SERVER_VAL_INPUT_TENSOR_NAMES = ['a', 'b', 'c']
SELECT_FN_KEY_INPUT_TENSOR_NAME = 'bar'
SELECT_FN_FILENAME_TENSOR_NAME = 'goofy'
SELECT_FN_TARGET_TENSOR_NAME = 'goobler'


class ServeSlicesRegistryTest(absltest.TestCase):
  """Checks that a registered callback is invoked by the `serve_slices` op."""

  def test_register_serve_slices_callback(self):
    """Runs `serve_slices` in a session and verifies the callback round trip.

    The op's output must equal the (encoded) value returned by the callback,
    and the callback must have been called exactly once with the token and
    the arguments that were given to the op.
    """
    with tf.Graph().as_default() as graph:
      # The token only exists once the callback has been registered, which
      # happens after the graph is built, so it is supplied through a
      # placeholder that is fed at `session.run` time.
      callback_token = tf.compat.v1.placeholder(dtype=tf.string)
      served_at_id = serve_slices.serve_slices(
          callback_token=callback_token,
          server_val=SERVER_VAL,
          max_key=MAX_KEY,
          select_fn_initialize_op=SELECT_FN_INITIALIZE_OP,
          select_fn_server_val_input_tensor_names=SELECT_FN_SERVER_VAL_INPUT_TENSOR_NAMES,
          select_fn_key_input_tensor_name=SELECT_FN_KEY_INPUT_TENSOR_NAME,
          select_fn_filename_input_tensor_name=SELECT_FN_FILENAME_TENSOR_NAME,
          select_fn_target_tensor_name=SELECT_FN_TARGET_TENSOR_NAME)

    served_at_value = 'address.at.which.data.is.served'
    mock_callback = mock.Mock(return_value=served_at_value)
    with serve_slices_registry.register_serve_slices_callback(
        mock_callback) as token:
      with tf.compat.v1.Session(graph=graph) as session:
        served_at_out = session.run(
            served_at_id, feed_dict={callback_token: token})

    # The string tensor fetched from the session comes back as bytes.
    self.assertEqual(served_at_out, served_at_value.encode())

    # The callback is expected to receive the server values as NumPy arrays.
    expected_server_val = [
        np.array(value, dtype=dtype)
        for value, dtype in zip(SERVER_VAL, SERVER_VAL_NP_DTYPE)
    ]
    mock_callback.assert_called_once_with(
        token,
        expected_server_val,
        MAX_KEY,
        SELECT_FN_INITIALIZE_OP,
        SELECT_FN_SERVER_VAL_INPUT_TENSOR_NAMES,
        SELECT_FN_KEY_INPUT_TENSOR_NAME,
        SELECT_FN_FILENAME_TENSOR_NAME,
        SELECT_FN_TARGET_TENSOR_NAME,
    )


if __name__ == '__main__':
  absltest.main()