xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/resource_operation_table.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_map.h"
20 
21 namespace tensorflow {
XlaResourceOpKindToString(XlaResourceOpKind op_kind)22 /*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString(
23     XlaResourceOpKind op_kind) {
24   switch (op_kind) {
25     case XlaResourceOpKind::kRead:
26       return "Read";
27     case XlaResourceOpKind::kWrite:
28       return "Write";
29     case XlaResourceOpKind::kReadWrite:
30       return "Modify";
31   }
32 }
33 
34 static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
CreateResourceOpInfoMap()35 CreateResourceOpInfoMap() {
36   auto* result = new absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>;
37 
38   auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
39                  XlaResourceKind resource_kind) {
40     auto insert_result =
41         result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)});
42     CHECK(insert_result.second);
43   };
44 
45   auto kRead = XlaResourceOpKind::kRead;
46   auto kWrite = XlaResourceOpKind::kWrite;
47   auto kReadWrite = XlaResourceOpKind::kReadWrite;
48 
49   auto kVariable = XlaResourceKind::kVariable;
50   auto kStack = XlaResourceKind::kStack;
51   auto kTensorArray = XlaResourceKind::kTensorArray;
52 
53   // clang-format off
54   add("AssignAddVariableOp"                  , kReadWrite, kVariable);
55   add("AssignSubVariableOp"                  , kReadWrite, kVariable);
56   add("AssignVariableOp"                     , kWrite,     kVariable);
57   add("AssignVariableXlaConcatND"            , kWrite,     kVariable);
58   add("CollectiveReduceV2"                   , kRead,      kVariable);
59   add("ReadVariableOp"                       , kRead,      kVariable);
60   add("ReadVariableXlaSplitND"               , kRead,      kVariable);
61   add("ResourceApplyAdaMax"                  , kReadWrite, kVariable);
62   add("ResourceApplyAdadelta"                , kReadWrite, kVariable);
63   add("ResourceApplyAdagrad"                 , kReadWrite, kVariable);
64   add("ResourceApplyAdagradV2"               , kReadWrite, kVariable),
65   add("ResourceApplyAdagradDA"               , kReadWrite, kVariable);
66   add("ResourceApplyAdam"                    , kReadWrite, kVariable);
67   add("ResourceApplyAddSign"                 , kReadWrite, kVariable);
68   add("ResourceApplyCenteredRMSProp"         , kReadWrite, kVariable);
69   add("ResourceApplyFtrl"                    , kReadWrite, kVariable);
70   add("ResourceApplyFtrlV2"                  , kReadWrite, kVariable);
71   add("ResourceApplyGradientDescent"         , kReadWrite, kVariable);
72   add("ResourceApplyMomentum"                , kReadWrite, kVariable);
73   add("ResourceApplyKerasMomentum"           , kReadWrite, kVariable);
74   add("ResourceApplyPowerSign"               , kReadWrite, kVariable);
75   add("ResourceApplyProximalAdagrad"         , kReadWrite, kVariable);
76   add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable);
77   add("ResourceApplyRMSProp"                 , kReadWrite, kVariable);
78   add("ResourceGather"                       , kRead,      kVariable);
79   add("ResourceScatterAdd"                   , kReadWrite, kVariable);
80   add("ResourceScatterDiv"                   , kReadWrite, kVariable);
81   add("ResourceScatterMax"                   , kReadWrite, kVariable);
82   add("ResourceScatterMin"                   , kReadWrite, kVariable);
83   add("ResourceScatterMul"                   , kReadWrite, kVariable);
84   add("ResourceScatterNdAdd"                 , kReadWrite, kVariable);
85   add("ResourceScatterNdSub"                 , kReadWrite, kVariable);
86   add("ResourceScatterNdUpdate"              , kReadWrite, kVariable);
87   add("ResourceScatterSub"                   , kReadWrite, kVariable);
88   add("ResourceScatterUpdate"                , kReadWrite, kVariable);
89   add("ResourceStridedSliceAssign"           , kReadWrite, kVariable);
90   add("RngReadAndSkip"                       , kReadWrite, kVariable);
91   add("RngSkip"                              , kReadWrite, kVariable);
92   add("StatefulStandardNormalV2"             , kReadWrite, kVariable);
93   add("StatefulTruncatedNormal"              , kReadWrite, kVariable);
94   add("StatefulUniform"                      , kReadWrite, kVariable);
95   add("StatefulUniformFullInt"               , kReadWrite, kVariable);
96   add("StatefulUniformInt"                   , kReadWrite, kVariable);
97   add("VarIsInitializedOp"                   , kRead,      kVariable);
98   add("VariableShape"                        , kRead,      kVariable);
99 
100   add("StackV2"                              , kWrite,     kStack);
101   add("StackCloseV2"                         , kRead,      kStack);
102   add("StackPopV2"                           , kReadWrite, kStack);
103   add("StackPushV2"                          , kReadWrite, kStack);
104 
105   add("TensorArrayV3"                        , kWrite,     kTensorArray);
106   add("TensorArrayConcatV3"                  , kRead,      kTensorArray);
107   add("TensorArrayGatherV3"                  , kRead,      kTensorArray);
108   add("TensorArrayScatterV3"                 , kWrite,     kTensorArray);
109   add("TensorArrayGradV3"                    , kRead,      kTensorArray);
110   add("TensorArrayCloseV3"                   , kRead,      kTensorArray);
111   add("TensorArrayReadV3"                    , kRead,      kTensorArray);
112   add("TensorArraySizeV3"                    , kRead,      kTensorArray);
113   add("TensorArraySplitV3"                   , kWrite,     kTensorArray);
114   add("TensorArrayWriteV3"                   , kWrite,     kTensorArray);
115   // clang-format on
116 
117   return result;
118 }
119 
120 static const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>&
GetStaticResourceOpInfoMap()121 GetStaticResourceOpInfoMap() {
122   static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
123       op_info_map = CreateResourceOpInfoMap();
124   return *op_info_map;
125 }
126 
GetResourceOpInfoForOp(absl::string_view op)127 const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
128   const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>& op_infos =
129       GetStaticResourceOpInfoMap();
130   auto it = op_infos.find(op);
131   return it == op_infos.end() ? nullptr : &it->second;
132 }
133 
134 namespace resource_op_table_internal {
GetKnownResourceOps()135 std::vector<absl::string_view> GetKnownResourceOps() {
136   std::vector<absl::string_view> result;
137   for (const auto& p : GetStaticResourceOpInfoMap()) {
138     result.push_back(p.first);
139   }
140   absl::c_sort(result);
141   return result;
142 }
143 }  // namespace resource_op_table_internal
144 }  // namespace tensorflow
145