# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for data_spec.py.""" import collections from absl.testing import absltest import tensorflow as tf import tensorflow_federated as tff from fcp.artifact_building import data_spec from fcp.protos import plan_pb2 _TEST_EXAMPLE_SELECTOR = plan_pb2.ExampleSelector( collection_uri='app://fake_uri' ) class DataSpecTest(absltest.TestCase): def test_construction_with_valid_arguments(self): preprocessing_fn = lambda ds: ds.batch(10) ds = data_spec.DataSpec(_TEST_EXAMPLE_SELECTOR, preprocessing_fn) self.assertIs(ds.example_selector_proto, _TEST_EXAMPLE_SELECTOR) self.assertIs(ds.preprocessing_fn, preprocessing_fn) def test_is_data_spec_or_structure(self): preprocessing_fn = lambda ds: ds.batch(10) ds = data_spec.DataSpec(_TEST_EXAMPLE_SELECTOR, preprocessing_fn) self.assertTrue(data_spec.is_data_spec_or_structure(ds)) self.assertTrue(data_spec.is_data_spec_or_structure([ds, ds])) self.assertTrue(data_spec.is_data_spec_or_structure({'a': ds})) self.assertFalse(data_spec.is_data_spec_or_structure(10)) self.assertFalse(data_spec.is_data_spec_or_structure({'a': 10})) def test_type_signature(self): def parsing_fn(serialized_example): parsing_dict = { 'key': tf.io.FixedLenFeature(shape=[1], dtype=tf.int64), } parsed_example = tf.io.parse_example(serialized_example, parsing_dict) return collections.OrderedDict([('key', parsed_example['key'])]) preprocessing_fn = lambda ds: ds.map(parsing_fn) ds = data_spec.DataSpec(_TEST_EXAMPLE_SELECTOR, preprocessing_fn) expected_type = tff.SequenceType( tff.types.to_type( collections.OrderedDict( [('key', tf.TensorSpec(shape=(1,), dtype=tf.int64))] ) ) ) self.assertEqual(ds.type_signature, expected_type) if __name__ == '__main__': tf.compat.v1.enable_v2_behavior() absltest.main()