1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_map.h"
20
21 namespace tensorflow {
XlaResourceOpKindToString(XlaResourceOpKind op_kind)22 /*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString(
23 XlaResourceOpKind op_kind) {
24 switch (op_kind) {
25 case XlaResourceOpKind::kRead:
26 return "Read";
27 case XlaResourceOpKind::kWrite:
28 return "Write";
29 case XlaResourceOpKind::kReadWrite:
30 return "Modify";
31 }
32 }
33
34 static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
CreateResourceOpInfoMap()35 CreateResourceOpInfoMap() {
36 auto* result = new absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>;
37
38 auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
39 XlaResourceKind resource_kind) {
40 auto insert_result =
41 result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)});
42 CHECK(insert_result.second);
43 };
44
45 auto kRead = XlaResourceOpKind::kRead;
46 auto kWrite = XlaResourceOpKind::kWrite;
47 auto kReadWrite = XlaResourceOpKind::kReadWrite;
48
49 auto kVariable = XlaResourceKind::kVariable;
50 auto kStack = XlaResourceKind::kStack;
51 auto kTensorArray = XlaResourceKind::kTensorArray;
52
53 // clang-format off
54 add("AssignAddVariableOp" , kReadWrite, kVariable);
55 add("AssignSubVariableOp" , kReadWrite, kVariable);
56 add("AssignVariableOp" , kWrite, kVariable);
57 add("AssignVariableXlaConcatND" , kWrite, kVariable);
58 add("CollectiveReduceV2" , kRead, kVariable);
59 add("ReadVariableOp" , kRead, kVariable);
60 add("ReadVariableXlaSplitND" , kRead, kVariable);
61 add("ResourceApplyAdaMax" , kReadWrite, kVariable);
62 add("ResourceApplyAdadelta" , kReadWrite, kVariable);
63 add("ResourceApplyAdagrad" , kReadWrite, kVariable);
64 add("ResourceApplyAdagradV2" , kReadWrite, kVariable),
65 add("ResourceApplyAdagradDA" , kReadWrite, kVariable);
66 add("ResourceApplyAdam" , kReadWrite, kVariable);
67 add("ResourceApplyAddSign" , kReadWrite, kVariable);
68 add("ResourceApplyCenteredRMSProp" , kReadWrite, kVariable);
69 add("ResourceApplyFtrl" , kReadWrite, kVariable);
70 add("ResourceApplyFtrlV2" , kReadWrite, kVariable);
71 add("ResourceApplyGradientDescent" , kReadWrite, kVariable);
72 add("ResourceApplyMomentum" , kReadWrite, kVariable);
73 add("ResourceApplyKerasMomentum" , kReadWrite, kVariable);
74 add("ResourceApplyPowerSign" , kReadWrite, kVariable);
75 add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable);
76 add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable);
77 add("ResourceApplyRMSProp" , kReadWrite, kVariable);
78 add("ResourceGather" , kRead, kVariable);
79 add("ResourceScatterAdd" , kReadWrite, kVariable);
80 add("ResourceScatterDiv" , kReadWrite, kVariable);
81 add("ResourceScatterMax" , kReadWrite, kVariable);
82 add("ResourceScatterMin" , kReadWrite, kVariable);
83 add("ResourceScatterMul" , kReadWrite, kVariable);
84 add("ResourceScatterNdAdd" , kReadWrite, kVariable);
85 add("ResourceScatterNdSub" , kReadWrite, kVariable);
86 add("ResourceScatterNdUpdate" , kReadWrite, kVariable);
87 add("ResourceScatterSub" , kReadWrite, kVariable);
88 add("ResourceScatterUpdate" , kReadWrite, kVariable);
89 add("ResourceStridedSliceAssign" , kReadWrite, kVariable);
90 add("RngReadAndSkip" , kReadWrite, kVariable);
91 add("RngSkip" , kReadWrite, kVariable);
92 add("StatefulStandardNormalV2" , kReadWrite, kVariable);
93 add("StatefulTruncatedNormal" , kReadWrite, kVariable);
94 add("StatefulUniform" , kReadWrite, kVariable);
95 add("StatefulUniformFullInt" , kReadWrite, kVariable);
96 add("StatefulUniformInt" , kReadWrite, kVariable);
97 add("VarIsInitializedOp" , kRead, kVariable);
98 add("VariableShape" , kRead, kVariable);
99
100 add("StackV2" , kWrite, kStack);
101 add("StackCloseV2" , kRead, kStack);
102 add("StackPopV2" , kReadWrite, kStack);
103 add("StackPushV2" , kReadWrite, kStack);
104
105 add("TensorArrayV3" , kWrite, kTensorArray);
106 add("TensorArrayConcatV3" , kRead, kTensorArray);
107 add("TensorArrayGatherV3" , kRead, kTensorArray);
108 add("TensorArrayScatterV3" , kWrite, kTensorArray);
109 add("TensorArrayGradV3" , kRead, kTensorArray);
110 add("TensorArrayCloseV3" , kRead, kTensorArray);
111 add("TensorArrayReadV3" , kRead, kTensorArray);
112 add("TensorArraySizeV3" , kRead, kTensorArray);
113 add("TensorArraySplitV3" , kWrite, kTensorArray);
114 add("TensorArrayWriteV3" , kWrite, kTensorArray);
115 // clang-format on
116
117 return result;
118 }
119
120 static const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>&
GetStaticResourceOpInfoMap()121 GetStaticResourceOpInfoMap() {
122 static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
123 op_info_map = CreateResourceOpInfoMap();
124 return *op_info_map;
125 }
126
GetResourceOpInfoForOp(absl::string_view op)127 const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
128 const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>& op_infos =
129 GetStaticResourceOpInfoMap();
130 auto it = op_infos.find(op);
131 return it == op_infos.end() ? nullptr : &it->second;
132 }
133
134 namespace resource_op_table_internal {
GetKnownResourceOps()135 std::vector<absl::string_view> GetKnownResourceOps() {
136 std::vector<absl::string_view> result;
137 for (const auto& p : GetStaticResourceOpInfoMap()) {
138 result.push_back(p.first);
139 }
140 absl::c_sort(result);
141 return result;
142 }
143 } // namespace resource_op_table_internal
144 } // namespace tensorflow
145