xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/tpu_hardware_feature.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""TPU hardware feature info."""
16import enum
17from tensorflow.core.protobuf.tpu import topology_pb2
18from tensorflow.python.util.tf_export import tf_export
19
20
@tf_export("tpu.experimental.HardwareFeature")
class HardwareFeature(object):
  """Holds the feature information describing a TPU's hardware capabilities.

  Wraps a TPU hardware feature proto and exposes its fields as Python values.
  The proto is presumably a `topology_pb2.TPUHardwareFeature` message, since
  its `embedding_feature` field is decoded with that message's
  `EmbeddingFeature` enum.
  """

  def __init__(self, tpu_hardware_feature_proto):
    """Stores the TPU hardware feature info.

    Args:
      tpu_hardware_feature_proto: protobuf describing the TPU hardware
        features. Its `embedding_feature` field is read by the
        `embedding_feature` property.
    """
    self.tpu_hardware_feature_proto = tpu_hardware_feature_proto

  class EmbeddingFeature(enum.Enum):
    """Embedding feature flag strings.

    UNSUPPORTED: No embedding lookup accelerator available on the TPU.
    V1: Embedding lookup accelerator V1. The embedding lookup operation can only
        be placed at the beginning of the computation. Only one instance of the
        embedding lookup layer is allowed.
    V2: Embedding lookup accelerator V2. The embedding lookup operation can be
        placed anywhere in the computation. Multiple instances of the embedding
        lookup layer are allowed.
    """
    UNSUPPORTED = "UNSUPPORTED"
    V1 = "V1"
    V2 = "V2"

  @classmethod
  def _embedding_feature_proto_to_string(cls, embedding_feature_proto):
    """Converts an embedding feature proto enum value to an `EmbeddingFeature`.

    Despite the method name, the result is an `EmbeddingFeature` enum member
    (whose `.value` is a string), not a plain `str`.

    Args:
      embedding_feature_proto: a
        `topology_pb2.TPUHardwareFeature.EmbeddingFeature` enum value.

    Returns:
      The matching `EmbeddingFeature` member. Any value not covered by the
      mapping below yields `EmbeddingFeature.UNSUPPORTED` instead of raising.
    """
    proto_enum = topology_pb2.TPUHardwareFeature.EmbeddingFeature
    # Resolve the Python enum through `cls` so subclasses are respected.
    proto_to_feature = {
        proto_enum.UNSUPPORTED: cls.EmbeddingFeature.UNSUPPORTED,
        proto_enum.V1: cls.EmbeddingFeature.V1,
        proto_enum.V2: cls.EmbeddingFeature.V2,
    }
    return proto_to_feature.get(embedding_feature_proto,
                                cls.EmbeddingFeature.UNSUPPORTED)

  @property
  def embedding_feature(self):
    """TPU embedding feature.

    Returns:
      An `EmbeddingFeature` enum member decoded from the wrapped proto's
      `embedding_feature` field; unrecognized values map to
      `EmbeddingFeature.UNSUPPORTED`.
    """
    return self._embedding_feature_proto_to_string(
        self.tpu_hardware_feature_proto.embedding_feature)
72        self.tpu_hardware_feature_proto.embedding_feature)
73