# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU hardware feature info."""
import enum
from tensorflow.core.protobuf.tpu import topology_pb2
from tensorflow.python.util.tf_export import tf_export


@tf_export("tpu.experimental.HardwareFeature")
class HardwareFeature(object):
  """Holds feature information about the TPU hardware.

  Wraps a hardware-feature protobuf and exposes its fields as Python values.
  """

  def __init__(self, tpu_hardware_feature_proto):
    """Stores TPU hardware feature info.

    Args:
      tpu_hardware_feature_proto: protobuf which describes the TPU hardware
        feature. Its `embedding_feature` field is read by the
        `embedding_feature` property.
    """
    self.tpu_hardware_feature_proto = tpu_hardware_feature_proto

  class EmbeddingFeature(enum.Enum):
    """Embedding feature flag strings.

    UNSUPPORTED: No embedding lookup accelerator is available on the TPU.
    V1: Embedding lookup accelerator V1. The embedding lookup operation can
      only be placed at the beginning of the computation. Only one instance
      of the embedding lookup layer is allowed.
    V2: Embedding lookup accelerator V2. The embedding lookup operation can be
      placed anywhere in the computation. Multiple instances of the embedding
      lookup layer are allowed.
    """
    UNSUPPORTED = "UNSUPPORTED"
    V1 = "V1"
    V2 = "V2"

  @classmethod
  def _embedding_feature_proto_to_string(cls, embedding_feature_proto):
    """Converts an embedding feature proto value to an `EmbeddingFeature`.

    Despite its name, this returns an `EmbeddingFeature` enum member rather
    than a plain string.

    Args:
      embedding_feature_proto: a
        `topology_pb2.TPUHardwareFeature.EmbeddingFeature` value.

    Returns:
      The matching `EmbeddingFeature` member. Values that are not in the
      mapping yield `EmbeddingFeature.UNSUPPORTED` instead of raising.
    """
    proto_enum = topology_pb2.TPUHardwareFeature.EmbeddingFeature
    proto_to_feature = {
        proto_enum.UNSUPPORTED: cls.EmbeddingFeature.UNSUPPORTED,
        proto_enum.V1: cls.EmbeddingFeature.V1,
        proto_enum.V2: cls.EmbeddingFeature.V2,
    }
    # Unrecognized proto values deliberately fall back to UNSUPPORTED.
    return proto_to_feature.get(embedding_feature_proto,
                                cls.EmbeddingFeature.UNSUPPORTED)

  @property
  def embedding_feature(self):
    """TPU embedding feature.

    Returns:
      A `HardwareFeature.EmbeddingFeature` member derived from the wrapped
      proto's `embedding_feature` field.
    """
    return self._embedding_feature_proto_to_string(
        self.tpu_hardware_feature_proto.embedding_feature)