# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""tff.Computation subclass for the demo Federated Computation platform."""

import functools
import re

import tensorflow_federated as tff

# A computation name is one or more `/`-separated runs of word characters,
# e.g. "train" or "task/eval". Checked with `fullmatch` in
# `FederatedComputation.__init__`, so partial matches are rejected.
COMPUTATION_NAME_REGEX = re.compile(r'\w+(/\w+)*')


class FederatedComputation(tff.Computation):
  """A tff.Computation that should be run in a tff.program.FederatedContext."""

  def __init__(self, comp: tff.Computation, *, name: str):
    """Constructs a new FederatedComputation object.

    Args:
      comp: The MapReduceForm- and DistributeAggregateForm- compatible
        computation that will be run. Only MapReduceForm compatibility is
        checked eagerly here; conversion to DistributeAggregateForm is
        deferred until `distribute_aggregate_form` is first accessed.
      name: A unique name for the computation. Must fully match
        `COMPUTATION_NAME_REGEX`.

    Raises:
      ValueError: If `name` does not fully match `COMPUTATION_NAME_REGEX`.
      Whatever `check_computation_compatible_with_map_reduce_form` raises if
      `comp` is not MapReduceForm-compatible (this check runs first).
    """
    tff.backends.mapreduce.check_computation_compatible_with_map_reduce_form(
        comp
    )  # pytype: disable=wrong-arg-types
    if not COMPUTATION_NAME_REGEX.fullmatch(name):
      raise ValueError(f'name must match "{COMPUTATION_NAME_REGEX.pattern}".')
    self._comp = comp
    self._name = name

  @functools.cached_property
  def map_reduce_form(self) -> tff.backends.mapreduce.MapReduceForm:
    """The underlying MapReduceForm representation.

    Computed from the wrapped computation on first access and cached on the
    instance afterwards.
    """
    return tff.backends.mapreduce.get_map_reduce_form_for_computation(  # pytype: disable=wrong-arg-types
        self._comp
    )

  @functools.cached_property
  def distribute_aggregate_form(
      self,
  ) -> tff.backends.mapreduce.DistributeAggregateForm:
    """The underlying DistributeAggregateForm representation.

    Computed from the wrapped computation on first access and cached on the
    instance afterwards.
    """
    return tff.backends.mapreduce.get_distribute_aggregate_form_for_computation(  # pytype: disable=wrong-arg-types
        self._comp
    )

  @property
  def wrapped_computation(self) -> tff.Computation:
    """The underlying tff.Computation."""
    return self._comp

  @property
  def name(self) -> str:
    """The name of the computation."""
    return self._name

  @property
  def type_signature(self) -> tff.Type:
    """The type signature of the wrapped computation."""
    return self._comp.type_signature

  def __call__(self, *args, **kwargs) -> ...:
    """Invokes this computation in the current TFF context.

    Positional arguments become unnamed elements, and keyword arguments become
    named elements, of a single `tff.structure.Struct`. That struct is passed
    to `invoke` on the context currently at the top of the context stack.
    """
    packed_arg = tff.structure.Struct(
        [(None, value) for value in args] + list(kwargs.items())
    )
    return tff.framework.get_context_stack().current.invoke(self, packed_arg)

  def __eq__(self, other: object) -> bool:
    """Compares by wrapped computation and name, consistent with `__hash__`."""
    if self is other:
      return True
    if not isinstance(other, FederatedComputation):
      return NotImplemented
    return self._comp == other._comp and self._name == other._name

  def __hash__(self) -> int:
    # Value-based hash over the same fields `__eq__` compares, so equal
    # objects always hash equally (required for use as a cache key).
    return hash((self._comp, self._name))