# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for variable_helpers.py."""

from absl.testing import absltest

import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import artifact_constants
from fcp.artifact_building import variable_helpers


@tff.federated_computation(
    tff.type_at_server(tf.int32), tff.type_at_clients(tf.float32)
)
def sample_comp(x, y):
  """Federated computation used as a fixture by the aggregation tests.

  Args:
    x: An int32 value placed at the server.
    y: A float32 value placed at the clients.

  Returns:
    A pair `(output1, output2)` where `output1` is the secure sum (bitwidth 5)
    of `x` broadcast to the clients, and `output2` is the federated mean of
    `[y, y]` with `y` passed as the weight.
  """
  a = tff.federated_broadcast(x)
  output1 = tff.federated_secure_sum_bitwidth(a, 5)
  output2 = tff.federated_mean([y, y], y)
  return output1, output2


class VariableHelpersTest(absltest.TestCase):
  """Tests for the variable/tensor-spec helpers in `variable_helpers`."""

  def _assert_scalar_variable(self, var):
    """Asserts that `var` is a variable-like object of rank 0.

    The check is made on the class-name suffix rather than with `isinstance`.
    NOTE(review): presumably because the concrete variable class depends on
    the TF execution mode -- confirm before tightening this check.

    Args:
      var: The object returned by `create_vars_for_tff_type` to check.
    """
    self.assertTrue(type(var).__name__.endswith('Variable'))
    self.assertEqual(var.shape.ndims, 0)

  def test_create_vars_for_tff_type(self):
    """A nested named struct yields one scalar variable per leaf."""
    with tf.Graph().as_default():
      vl = variable_helpers.create_vars_for_tff_type(
          tff.to_type(
              [('a', tf.int32), ('b', [('c', tf.bool), ('d', tf.float32)])]
          ),
          'x',
      )
    self.assertLen(vl, 3)
    for v in vl:
      self._assert_scalar_variable(v)
    self.assertEqual([v.dtype for v in vl], [tf.int32, tf.bool, tf.float32])
    self.assertEqual([v.name for v in vl], ['x/a:0', 'x/b/c:0', 'x/b/d:0'])

  def test_create_vars_for_tff_type_with_none_and_zero_shape(self):
    """Checks variable and initial-value shapes for a `[5, None, 0]` type."""
    with tf.Graph().as_default():
      vl = variable_helpers.create_vars_for_tff_type(
          tff.TensorType(dtype=tf.int32, shape=[5, None, 0])
      )
    self.assertLen(vl, 1)
    test_variable = vl[0]
    # The initial value is fully defined ([5, 0, 0]) while the variable's own
    # static shape reports None for both the None and the 0 dimension.
    self.assertEqual(test_variable.initial_value.shape.as_list(), [5, 0, 0])
    self.assertEqual(test_variable.shape.as_list(), [5, None, None])

  def test_create_vars_for_tff_federated_type(self):
    """A server-placed tensor type yields a single variable named 'v'."""
    tff_type = tff.FederatedType(tff.TensorType(tf.int32), tff.SERVER)
    with tf.Graph().as_default():
      vl = variable_helpers.create_vars_for_tff_type(tff_type)

    self.assertLen(vl, 1)
    v = vl[0]
    self._assert_scalar_variable(v)
    self.assertEqual(v.dtype, tf.int32)
    self.assertEqual(v.name, 'v:0')

  def test_create_vars_for_struct_of_tff_federated_types(self):
    """A struct of server-placed types yields one variable per field."""
    tff_type = tff.StructType([
        (
            'num_examples_secagg',
            tff.FederatedType(tff.TensorType(tf.int32), tff.SERVER),
        ),
        (
            'num_examples_simpleagg',
            tff.FederatedType(tff.TensorType(tf.int32), tff.SERVER),
        ),
    ])
    with tf.Graph().as_default():
      vl = variable_helpers.create_vars_for_tff_type(tff_type)

    self.assertLen(vl, 2)
    for v in vl:
      self._assert_scalar_variable(v)
    self.assertEqual([v.dtype for v in vl], [tf.int32, tf.int32])
    self.assertEqual(
        [v.name for v in vl],
        ['v/num_examples_secagg:0', 'v/num_examples_simpleagg:0'],
    )

  def test_create_vars_fails_for_client_placed_type(self):
    """A client-placed type is rejected with a TypeError."""
    tff_type = tff.FederatedType(tff.TensorType(tf.int32), tff.CLIENTS)
    with self.assertRaisesRegex(TypeError, 'Can only create vars'):
      with tf.Graph().as_default():
        _ = variable_helpers.create_vars_for_tff_type(tff_type)

  def test_create_vars_fails_for_struct_with_client_placed_type(self):
    """A struct containing any client-placed field is rejected."""
    tff_type = tff.StructType([
        (
            'num_examples_secagg',
            tff.FederatedType(tff.TensorType(tf.int32), tff.SERVER),
        ),
        (
            'num_examples_simpleagg',
            tff.FederatedType(tff.TensorType(tf.int32), tff.CLIENTS),
        ),
    ])
    with self.assertRaisesRegex(TypeError, 'Can only create vars'):
      with tf.Graph().as_default():
        _ = variable_helpers.create_vars_for_tff_type(tff_type)

  def test_variable_names_from_type_with_tensor_type_and_no_name(self):
    """Without an explicit name the single variable is called 'v'."""
    names = variable_helpers.variable_names_from_type(
        tff.TensorType(dtype=tf.int32)
    )
    self.assertEqual(names, ['v'])

  def test_variable_names_from_type_with_tensor_type(self):
    """An explicit name is used verbatim for a tensor type."""
    names = variable_helpers.variable_names_from_type(
        tff.TensorType(dtype=tf.int32), 'test_name'
    )
    self.assertEqual(names, ['test_name'])

  def test_variable_names_from_type_with_federated_type(self):
    """A server-placed type is named like its member tensor type."""
    names = variable_helpers.variable_names_from_type(
        tff.FederatedType(tff.TensorType(dtype=tf.int32), tff.SERVER),
        'test_name',
    )
    self.assertEqual(names, ['test_name'])

  def test_variable_names_from_type_with_named_tuple_type_and_no_name(self):
    """Nested field names are joined with '/' under the default 'v' prefix."""
    names = variable_helpers.variable_names_from_type(
        tff.to_type(
            [('a', tf.int32), ('b', [('c', tf.bool), ('d', tf.float32)])]
        )
    )
    self.assertEqual(names, ['v/a', 'v/b/c', 'v/b/d'])

  def test_variable_names_from_type_with_named_tuple_type(self):
    """Nested field names are joined with '/' under the given prefix."""
    names = variable_helpers.variable_names_from_type(
        tff.to_type(
            [('a', tf.int32), ('b', [('c', tf.bool), ('d', tf.float32)])]
        ),
        'test_name',
    )
    self.assertEqual(names, ['test_name/a', 'test_name/b/c', 'test_name/b/d'])

  def test_variable_names_from_type_with_named_tuple_type_no_name_field(self):
    """Unnamed struct fields are named by their positional index."""
    names = variable_helpers.variable_names_from_type(
        tff.to_type([(tf.int32), ('b', [(tf.bool), ('d', tf.float32)])]),
        'test_name',
    )
    self.assertEqual(names, ['test_name/0', 'test_name/b/0', 'test_name/b/d'])

  def test_get_flattened_tensor_specs_with_tensor_type(self):
    """A tensor type maps to a single TensorSpec with the given name."""
    specs = variable_helpers.get_flattened_tensor_specs(
        tff.TensorType(dtype=tf.int32, shape=tf.TensorShape([3, 5])),
        'test_name',
    )
    self.assertEqual(
        specs,
        [
            tf.TensorSpec(
                name='test_name',
                shape=tf.TensorShape([3, 5]),
                dtype=tf.int32,
            )
        ],
    )

  def test_get_flattened_tensor_specs_with_federated_type(self):
    """A server-placed type maps to the spec of its member tensor type."""
    specs = variable_helpers.get_flattened_tensor_specs(
        tff.FederatedType(
            tff.TensorType(dtype=tf.int32, shape=tf.TensorShape([3, 5])),
            tff.SERVER,
        ),
        'test_name',
    )
    self.assertEqual(
        specs,
        [
            tf.TensorSpec(
                name='test_name',
                shape=tf.TensorShape([3, 5]),
                dtype=tf.int32,
            )
        ],
    )

  def test_get_flattened_tensor_specs_with_tuple_type(self):
    """A nested struct is flattened into one TensorSpec per leaf."""
    specs = variable_helpers.get_flattened_tensor_specs(
        tff.StructType([
            (
                'a',
                tff.TensorType(dtype=tf.int32, shape=tf.TensorShape([3, 5])),
            ),
            (
                'b',
                tff.StructType([
                    (tff.TensorType(dtype=tf.bool, shape=tf.TensorShape([4]))),
                    (
                        'd',
                        tff.TensorType(
                            dtype=tf.float32,
                            shape=tf.TensorShape([1, 3, 5]),
                        ),
                    ),
                ]),
            ),
        ]),
        'test_name',
    )
    self.assertEqual(
        specs,
        [
            tf.TensorSpec(
                name='test_name/a',
                shape=tf.TensorShape([3, 5]),
                dtype=tf.int32,
            ),
            tf.TensorSpec(
                name='test_name/b/0',
                shape=tf.TensorShape([4]),
                dtype=tf.bool,
            ),
            tf.TensorSpec(
                name='test_name/b/d',
                shape=tf.TensorShape([1, 3, 5]),
                dtype=tf.float32,
            ),
        ],
    )

  def test_get_grouped_input_tensor_specs_for_aggregations(self):
    """Input specs are grouped per intrinsic and, within it, per argument."""
    daf = tff.backends.mapreduce.get_distribute_aggregate_form_for_computation(
        sample_comp
    )
    grouped_input_tensor_specs = (
        variable_helpers.get_grouped_input_tensor_specs_for_aggregations(
            daf.client_to_server_aggregation.to_building_block(),
            artifact_constants.AGGREGATION_INTRINSIC_ARG_SELECTION_INDEX_TO_NAME_DICT,
        )
    )
    self.assertEqual(
        grouped_input_tensor_specs,
        [
            [  # federated_weighted_mean intrinsic args
                [  # federated_weighted_mean value arg
                    tf.TensorSpec(
                        name='update/0/0',
                        shape=tf.TensorShape([]),
                        dtype=tf.float32,
                    ),
                    tf.TensorSpec(
                        name='update/0/1',
                        shape=tf.TensorShape([]),
                        dtype=tf.float32,
                    ),
                ],
                [  # federated_weighted_mean weight arg
                    tf.TensorSpec(
                        name='update/1',
                        shape=tf.TensorShape([]),
                        dtype=tf.float32,
                    )
                ],
            ],
            [  # federated_secure_sum_bitwidth intrinsic args
                [  # federated_secure_sum_bitwidth value arg
                    tf.TensorSpec(
                        name='update/2',
                        shape=tf.TensorShape([]),
                        dtype=tf.int32,
                    )
                ],
                [  # federated_secure_sum_bitwidth bitwidth arg
                    tf.TensorSpec(
                        name='intermediate_state/0',
                        shape=tf.TensorShape([]),
                        dtype=tf.int32,
                    )
                ],
            ],
        ],
    )

  def test_get_grouped_output_tensor_specs_for_aggregations(self):
    """Output specs are grouped per aggregation intrinsic."""
    daf = tff.backends.mapreduce.get_distribute_aggregate_form_for_computation(
        sample_comp
    )
    grouped_output_tensor_specs = (
        variable_helpers.get_grouped_output_tensor_specs_for_aggregations(
            daf.client_to_server_aggregation.to_building_block()
        )
    )
    self.assertEqual(
        grouped_output_tensor_specs,
        [
            [  # federated_weighted_mean intrinsic output
                tf.TensorSpec(
                    name='intermediate_update/0/0/0',
                    shape=tf.TensorShape([]),
                    dtype=tf.float32,
                ),
                tf.TensorSpec(
                    name='intermediate_update/0/0/1',
                    shape=tf.TensorShape([]),
                    dtype=tf.float32,
                ),
            ],
            [  # federated_secure_sum_bitwidth intrinsic output
                tf.TensorSpec(
                    name='intermediate_update/0/1',
                    shape=tf.TensorShape([]),
                    dtype=tf.int32,
                )
            ],
        ],
    )


if __name__ == '__main__':
  absltest.main()