xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/variable_helpers_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for variable_helpers.py."""
15
16from absl.testing import absltest
17
18import tensorflow as tf
19import tensorflow_federated as tff
20
21from fcp.artifact_building import artifact_constants
22from fcp.artifact_building import variable_helpers
23
24
@tff.federated_computation(
    tff.type_at_server(tf.int32), tff.type_at_clients(tf.float32)
)
def sample_comp(x, y):
  """Sample computation combining a secure sum with a weighted mean.

  Used by the aggregation tensor-spec tests in this file as the computation
  from which a distribute-aggregate form is derived.

  Args:
    x: An int32 value placed at the server.
    y: A float32 value placed at the clients.

  Returns:
    A pair `(secure_sum, weighted_mean)`: the secure sum (bitwidth argument 5)
    of `x` broadcast to the clients, and the mean of `[y, y]` weighted by `y`.
  """
  # Statement order is kept as-is: the computation is traced, and the tests
  # below assert on the structure of the compiled result.
  broadcast_x = tff.federated_broadcast(x)
  secure_sum = tff.federated_secure_sum_bitwidth(broadcast_x, 5)
  weighted_mean = tff.federated_mean([y, y], y)
  return (secure_sum, weighted_mean)
33
34
class VariableHelpersTest(absltest.TestCase):
  """Tests for variable creation, naming, and tensor-spec helpers."""

  def test_create_vars_for_tff_type(self):
    """A nested struct yields one scalar variable per leaf, named by path."""
    with tf.Graph().as_default():
      vl = variable_helpers.create_vars_for_tff_type(
          tff.to_type(
              [('a', tf.int32), ('b', [('c', tf.bool), ('d', tf.float32)])]
          ),
          'x',
      )
    self.assertLen(vl, 3)
    for v in vl:
      self.assertTrue(type(v).__name__.endswith('Variable'))
      self.assertEqual(v.shape.ndims, 0)
    self.assertEqual([v.dtype for v in vl], [tf.int32, tf.bool, tf.float32])
    self.assertEqual([v.name for v in vl], ['x/a:0', 'x/b/c:0', 'x/b/d:0'])

  def test_create_vars_for_tff_type_with_none_and_zero_shape(self):
    """Unknown and zero dims get a zero-sized initial value, unknown shape."""
    with tf.Graph().as_default():
      vl = variable_helpers.create_vars_for_tff_type(
          tff.TensorType(dtype=tf.int32, shape=[5, None, 0])
      )
      self.assertLen(vl, 1)
      test_variable = vl[0]
      # The initial value uses 0 for both the `None` and the `0` dimension,
      # while the variable's own shape reports both as unknown.
      self.assertEqual(test_variable.initial_value.shape.as_list(), [5, 0, 0])
      self.assertEqual(test_variable.shape.as_list(), [5, None, None])

  def test_create_vars_for_tff_federated_type(self):
    """A server-placed tensor yields a single variable named 'v'."""
    tff_type = tff.FederatedType(tff.TensorType(tf.int32), tff.SERVER)
    with tf.Graph().as_default():
      vl = variable_helpers.create_vars_for_tff_type(tff_type)

    self.assertLen(vl, 1)
    v = vl[0]
    self.assertTrue(type(v).__name__.endswith('Variable'))
    self.assertEqual(v.shape.ndims, 0)
    self.assertEqual(v.dtype, tf.int32)
    self.assertEqual(v.name, 'v:0')

  def test_create_vars_for_struct_of_tff_federated_types(self):
    """A struct of server-placed tensors yields one variable per field."""
    tff_type = tff.StructType([
        (
            'num_examples_secagg',
            tff.FederatedType(tff.TensorType(tf.int32), tff.SERVER),
        ),
        (
            'num_examples_simpleagg',
            tff.FederatedType(tff.TensorType(tf.int32), tff.SERVER),
        ),
    ])
    with tf.Graph().as_default():
      vl = variable_helpers.create_vars_for_tff_type(tff_type)

    self.assertLen(vl, 2)
    for v in vl:
      self.assertTrue(type(v).__name__.endswith('Variable'))
      self.assertEqual(v.shape.ndims, 0)
    # Whole-list checks do not depend on the loop variable, so they run once
    # after the per-variable checks rather than on every iteration.
    self.assertEqual([v.dtype for v in vl], [tf.int32, tf.int32])
    self.assertEqual(
        [v.name for v in vl],
        ['v/num_examples_secagg:0', 'v/num_examples_simpleagg:0'],
    )

  def test_create_vars_fails_for_client_placed_type(self):
    """A client-placed type is rejected with a TypeError."""
    tff_type = tff.FederatedType(tff.TensorType(tf.int32), tff.CLIENTS)
    with self.assertRaisesRegex(TypeError, 'Can only create vars'):
      with tf.Graph().as_default():
        _ = variable_helpers.create_vars_for_tff_type(tff_type)

  def test_create_vars_fails_for_struct_with_client_placed_type(self):
    """A struct containing a client-placed field is rejected."""
    tff_type = tff.StructType([
        (
            'num_examples_secagg',
            tff.FederatedType(tff.TensorType(tf.int32), tff.SERVER),
        ),
        (
            'num_examples_simpleagg',
            tff.FederatedType(tff.TensorType(tf.int32), tff.CLIENTS),
        ),
    ])
    with self.assertRaisesRegex(TypeError, 'Can only create vars'):
      with tf.Graph().as_default():
        _ = variable_helpers.create_vars_for_tff_type(tff_type)

  def test_variable_names_from_type_with_tensor_type_and_no_name(self):
    """A tensor type with no name given maps to the single name 'v'."""
    names = variable_helpers.variable_names_from_type(
        tff.TensorType(dtype=tf.int32)
    )
    self.assertEqual(names, ['v'])

  def test_variable_names_from_type_with_tensor_type(self):
    """A tensor type maps to exactly the name that was passed in."""
    names = variable_helpers.variable_names_from_type(
        tff.TensorType(dtype=tf.int32), 'test_name'
    )
    self.assertEqual(names, ['test_name'])

  def test_variable_names_from_type_with_federated_type(self):
    """A server-placed tensor type maps to exactly the given name."""
    names = variable_helpers.variable_names_from_type(
        tff.FederatedType(tff.TensorType(dtype=tf.int32), tff.SERVER),
        'test_name',
    )
    self.assertEqual(names, ['test_name'])

  def test_variable_names_from_type_with_named_tuple_type_and_no_name(self):
    """Nested struct fields become '/'-joined paths under 'v'."""
    names = variable_helpers.variable_names_from_type(
        tff.to_type(
            [('a', tf.int32), ('b', [('c', tf.bool), ('d', tf.float32)])]
        )
    )
    self.assertEqual(names, ['v/a', 'v/b/c', 'v/b/d'])

  def test_variable_names_from_type_with_named_tuple_type(self):
    """Nested struct fields become '/'-joined paths under the given name."""
    names = variable_helpers.variable_names_from_type(
        tff.to_type(
            [('a', tf.int32), ('b', [('c', tf.bool), ('d', tf.float32)])]
        ),
        'test_name',
    )
    self.assertEqual(names, ['test_name/a', 'test_name/b/c', 'test_name/b/d'])

  def test_variable_names_from_type_with_named_tuple_type_no_name_field(self):
    """Unnamed struct fields are identified by their positional index."""
    names = variable_helpers.variable_names_from_type(
        tff.to_type([(tf.int32), ('b', [(tf.bool), ('d', tf.float32)])]),
        'test_name',
    )
    self.assertEqual(names, ['test_name/0', 'test_name/b/0', 'test_name/b/d'])

  def test_get_flattened_tensor_specs_with_tensor_type(self):
    """A tensor type yields one spec carrying its name, shape, and dtype."""
    specs = variable_helpers.get_flattened_tensor_specs(
        tff.TensorType(dtype=tf.int32, shape=tf.TensorShape([3, 5])),
        'test_name',
    )
    self.assertEqual(
        specs,
        [
            tf.TensorSpec(
                name='test_name',
                shape=tf.TensorShape([3, 5]),
                dtype=tf.int32,
            )
        ],
    )

  def test_get_flattened_tensor_specs_with_federated_type(self):
    """A server-placed tensor type yields the spec of its member tensor."""
    specs = variable_helpers.get_flattened_tensor_specs(
        tff.FederatedType(
            tff.TensorType(dtype=tf.int32, shape=tf.TensorShape([3, 5])),
            tff.SERVER,
        ),
        'test_name',
    )
    self.assertEqual(
        specs,
        [
            tf.TensorSpec(
                name='test_name',
                shape=tf.TensorShape([3, 5]),
                dtype=tf.int32,
            )
        ],
    )

  def test_get_flattened_tensor_specs_with_tuple_type(self):
    """A nested struct yields one spec per leaf, named by path."""
    specs = variable_helpers.get_flattened_tensor_specs(
        tff.StructType([
            (
                'a',
                tff.TensorType(dtype=tf.int32, shape=tf.TensorShape([3, 5])),
            ),
            (
                'b',
                tff.StructType([
                    (tff.TensorType(dtype=tf.bool, shape=tf.TensorShape([4]))),
                    (
                        'd',
                        tff.TensorType(
                            dtype=tf.float32,
                            shape=tf.TensorShape([1, 3, 5]),
                        ),
                    ),
                ]),
            ),
        ]),
        'test_name',
    )
    self.assertEqual(
        specs,
        [
            tf.TensorSpec(
                name='test_name/a',
                shape=tf.TensorShape([3, 5]),
                dtype=tf.int32,
            ),
            tf.TensorSpec(
                name='test_name/b/0',
                shape=tf.TensorShape([4]),
                dtype=tf.bool,
            ),
            tf.TensorSpec(
                name='test_name/b/d',
                shape=tf.TensorShape([1, 3, 5]),
                dtype=tf.float32,
            ),
        ],
    )

  def test_get_grouped_input_tensor_specs_for_aggregations(self):
    """Input specs are grouped per intrinsic, then per intrinsic argument."""
    daf = tff.backends.mapreduce.get_distribute_aggregate_form_for_computation(
        sample_comp
    )
    grouped_input_tensor_specs = variable_helpers.get_grouped_input_tensor_specs_for_aggregations(
        daf.client_to_server_aggregation.to_building_block(),
        artifact_constants.AGGREGATION_INTRINSIC_ARG_SELECTION_INDEX_TO_NAME_DICT,
    )
    self.assertEqual(
        grouped_input_tensor_specs,
        [
            [  # federated_weighted_mean intrinsic args
                [  # federated_weighted_mean value arg
                    tf.TensorSpec(
                        name='update/0/0',
                        shape=tf.TensorShape([]),
                        dtype=tf.float32,
                    ),
                    tf.TensorSpec(
                        name='update/0/1',
                        shape=tf.TensorShape([]),
                        dtype=tf.float32,
                    ),
                ],
                [  # federated_weighted_mean weight arg
                    tf.TensorSpec(
                        name='update/1',
                        shape=tf.TensorShape([]),
                        dtype=tf.float32,
                    )
                ],
            ],
            [  # federated_secure_sum_bitwidth intrinsic args
                [  # federated_secure_sum_bitwidth value arg
                    tf.TensorSpec(
                        name='update/2',
                        shape=tf.TensorShape([]),
                        dtype=tf.int32,
                    )
                ],
                [  # federated_secure_sum_bitwidth bitwidth arg
                    tf.TensorSpec(
                        name='intermediate_state/0',
                        shape=tf.TensorShape([]),
                        dtype=tf.int32,
                    )
                ],
            ],
        ],
    )

  def test_get_grouped_output_tensor_specs_for_aggregations(self):
    """Output specs are grouped per aggregation intrinsic."""
    daf = tff.backends.mapreduce.get_distribute_aggregate_form_for_computation(
        sample_comp
    )
    grouped_output_tensor_specs = (
        variable_helpers.get_grouped_output_tensor_specs_for_aggregations(
            daf.client_to_server_aggregation.to_building_block()
        )
    )
    self.assertEqual(
        grouped_output_tensor_specs,
        [
            [  # federated_weighted_mean intrinsic output
                tf.TensorSpec(
                    name='intermediate_update/0/0/0',
                    shape=tf.TensorShape([]),
                    dtype=tf.float32,
                ),
                tf.TensorSpec(
                    name='intermediate_update/0/0/1',
                    shape=tf.TensorShape([]),
                    dtype=tf.float32,
                ),
            ],
            [  # federated_secure_sum_bitwidth intrinsic output
                tf.TensorSpec(
                    name='intermediate_update/0/1',
                    shape=tf.TensorShape([]),
                    dtype=tf.int32,
                )
            ],
        ],
    )
325
326
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
329