/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
18
19 #include <iosfwd>
20 #include <optional>
21 #include <string>
22
23 #include "tensorflow/compiler/xla/comparison_util.h"
24 #include "tensorflow/compiler/xla/statusor.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27
28 namespace xla {
29
// High-level optimizer instruction opcodes -- these are linear-algebra level
// opcodes. They are a flattened form of the UnaryOp, BinaryOp, ... opcodes
// present in the XLA service protobuf.
//
// See the XLA documentation for the semantics of each opcode.
//
// HLO_OPCODE_LIST is an X-macro: it invokes the caller-supplied macro V once
// per opcode, in the order listed below. Each entry has the format:
//   (enum_name, opcode_name, arity)
// where arity is either an integer literal or kHloOpcodeIsVariadic.
//
// The HloOpcode enumerators are generated from this list in listing order, so
// inserting, removing or reordering entries changes their numeric values.
//
// Note: Do not use ':' in opcode names. It is used as a special character
// in these places:
// - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to
//   separate the opcode from the fusion kind
// - In fully qualified names (HloInstruction::FullyQualifiedName()), to
//   separate the qualifiers (name of the computation and potentially the
//   fusion instruction) from the name
//
// If you change one of these opcodes, please make the corresponding change to
// the MHLO opset to keep both opsets synchronized.
// LINT.IfChange
#define HLO_OPCODE_LIST(V) \
  V(kAbs, "abs", 1) \
  V(kAdd, "add", 2) \
  V(kAddDependency, "add-dependency", 2) \
  V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \
  V(kAllGather, "all-gather", kHloOpcodeIsVariadic) \
  V(kAllGatherStart, "all-gather-start", kHloOpcodeIsVariadic) \
  V(kAllGatherDone, "all-gather-done", 1) \
  V(kAllReduce, "all-reduce", kHloOpcodeIsVariadic) \
  V(kAllReduceStart, "all-reduce-start", kHloOpcodeIsVariadic) \
  V(kAllReduceDone, "all-reduce-done", 1) \
  V(kAllToAll, "all-to-all", kHloOpcodeIsVariadic) \
  V(kAsyncStart, "async-start", kHloOpcodeIsVariadic) \
  V(kAsyncUpdate, "async-update", 1) \
  V(kAsyncDone, "async-done", 1) \
  V(kAtan2, "atan2", 2) \
  V(kBatchNormGrad, "batch-norm-grad", 5) \
  V(kBatchNormInference, "batch-norm-inference", 5) \
  V(kBatchNormTraining, "batch-norm-training", 3) \
  V(kBitcast, "bitcast", 1) \
  V(kBitcastConvert, "bitcast-convert", 1) \
  V(kBroadcast, "broadcast", 1) \
  V(kCall, "call", kHloOpcodeIsVariadic) \
  V(kCeil, "ceil", 1) \
  V(kCholesky, "cholesky", 1) \
  V(kClamp, "clamp", 3) \
  V(kCollectivePermute, "collective-permute", kHloOpcodeIsVariadic) \
  V(kCollectivePermuteStart, "collective-permute-start", kHloOpcodeIsVariadic) \
  V(kCollectivePermuteDone, "collective-permute-done", 1) \
  V(kClz, "count-leading-zeros", 1) \
  V(kCompare, "compare", 2) \
  V(kComplex, "complex", 2) \
  V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
  V(kConditional, "conditional", kHloOpcodeIsVariadic) \
  V(kConstant, "constant", 0) \
  V(kConvert, "convert", 1) \
  V(kConvolution, "convolution", 2) \
  V(kCopy, "copy", 1) \
  V(kCopyDone, "copy-done", 1) \
  V(kCopyStart, "copy-start", 1) \
  V(kCos, "cosine", 1) \
  V(kCustomCall, "custom-call", kHloOpcodeIsVariadic) \
  V(kDivide, "divide", 2) \
  V(kDomain, "domain", 1) \
  V(kDot, "dot", 2) \
  V(kDynamicSlice, "dynamic-slice", kHloOpcodeIsVariadic) \
  V(kDynamicUpdateSlice, "dynamic-update-slice", kHloOpcodeIsVariadic) \
  V(kExp, "exponential", 1) \
  V(kExpm1, "exponential-minus-one", 1) \
  V(kFft, "fft", 1) \
  V(kFloor, "floor", 1) \
  V(kFusion, "fusion", kHloOpcodeIsVariadic) \
  V(kGather, "gather", 2) \
  V(kGetDimensionSize, "get-dimension-size", 1) \
  V(kSetDimensionSize, "set-dimension-size", 2) \
  V(kGetTupleElement, "get-tuple-element", 1) \
  V(kImag, "imag", 1) \
  V(kInfeed, "infeed", 1) \
  V(kIota, "iota", 0) \
  V(kIsFinite, "is-finite", 1) \
  V(kLog, "log", 1) \
  V(kLog1p, "log-plus-one", 1) \
  V(kLogistic, "logistic", 1) \
  V(kAnd, "and", 2) \
  V(kNot, "not", 1) \
  V(kOptimizationBarrier, "opt-barrier", 1) \
  V(kOr, "or", 2) \
  V(kXor, "xor", 2) \
  V(kMap, "map", kHloOpcodeIsVariadic) \
  V(kMaximum, "maximum", 2) \
  V(kMinimum, "minimum", 2) \
  V(kMultiply, "multiply", 2) \
  V(kNegate, "negate", 1) \
  V(kOutfeed, "outfeed", 2) \
  V(kPad, "pad", 2) \
  V(kParameter, "parameter", 0) \
  V(kPartitionId, "partition-id", 0) \
  V(kPopulationCount, "popcnt", 1) \
  V(kPower, "power", 2) \
  V(kReal, "real", 1) \
  V(kRecv, "recv", 1) \
  V(kRecvDone, "recv-done", 1) \
  V(kReduce, "reduce", kHloOpcodeIsVariadic) \
  V(kReducePrecision, "reduce-precision", 1) \
  V(kReduceScatter, "reduce-scatter", kHloOpcodeIsVariadic) \
  V(kReduceWindow, "reduce-window", kHloOpcodeIsVariadic) \
  V(kRemainder, "remainder", 2) \
  V(kReplicaId, "replica-id", 0) \
  V(kReshape, "reshape", 1) \
  V(kDynamicReshape, "dynamic-reshape", kHloOpcodeIsVariadic) \
  V(kReverse, "reverse", 1) \
  V(kRng, "rng", kHloOpcodeIsVariadic) \
  V(kRngGetAndUpdateState, "rng-get-and-update-state", 0) \
  V(kRngBitGenerator, "rng-bit-generator", 1) \
  V(kRoundNearestAfz, "round-nearest-afz", 1) \
  V(kRoundNearestEven, "round-nearest-even", 1) \
  V(kRsqrt, "rsqrt", 1) \
  V(kScatter, "scatter", kHloOpcodeIsVariadic) \
  V(kSelect, "select", 3) \
  V(kSelectAndScatter, "select-and-scatter", 3) \
  V(kSend, "send", 2) \
  V(kSendDone, "send-done", 1) \
  V(kShiftLeft, "shift-left", 2) \
  V(kShiftRightArithmetic, "shift-right-arithmetic", 2) \
  V(kShiftRightLogical, "shift-right-logical", 2) \
  V(kSign, "sign", 1) \
  V(kSin, "sine", 1) \
  V(kSlice, "slice", 1) \
  V(kSort, "sort", kHloOpcodeIsVariadic) \
  V(kSqrt, "sqrt", 1) \
  V(kCbrt, "cbrt", 1) \
  V(kSubtract, "subtract", 2) \
  V(kTanh, "tanh", 1) \
  V(kTranspose, "transpose", 1) \
  V(kTriangularSolve, "triangular-solve", 2) \
  V(kTuple, "tuple", kHloOpcodeIsVariadic) \
  V(kWhile, "while", 1)
// LINT.ThenChange(../../mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td)
168
169 enum class HloOpcode {
170 #define DECLARE_ENUM(enum_name, opcode_name, ...) enum_name,
171 HLO_OPCODE_LIST(DECLARE_ENUM)
172 #undef DECLARE_ENUM
173 };
174
// Arity value that denotes that an operator is variadic. This is the sentinel
// written in the arity column of HLO_OPCODE_LIST; it is negative so it cannot
// collide with the non-negative arities listed there.
enum {
  kHloOpcodeIsVariadic = -1,
};
179
180 // Returns a string representation of the opcode.
181 std::string HloOpcodeString(HloOpcode opcode);
182
183 // Retrieves the opcode enum by name if the opcode exists.
184 StatusOr<HloOpcode> StringToHloOpcode(const std::string& opcode_name);
185
186 inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) {
187 return os << HloOpcodeString(opcode);
188 }
189
190 // Returns true iff the given opcode is a comparison operation.
191 bool HloOpcodeIsComparison(HloOpcode opcode);
192
193 // Returns true iff the given opcode has variadic operands.
194 bool HloOpcodeIsVariadic(HloOpcode opcode);
195
196 // Returns the arity of opcode. If the opcode is variadic,
197 // returns nullopt.
198 std::optional<int> HloOpcodeArity(HloOpcode opcode);
199
200 // Returns true if the given opcode is one of kAsyncStart, kAsyncUpdate, or
201 // kAsyncDone.
202 bool HloOpcodeIsAsync(HloOpcode opcode);
203
204 // True if the op takes two arguments and order doesn't matter.
HloOpcodeIsBinaryCommutative(HloOpcode opcode)205 inline bool HloOpcodeIsBinaryCommutative(HloOpcode opcode) {
206 switch (opcode) {
207 case HloOpcode::kAdd:
208 case HloOpcode::kMultiply:
209 case HloOpcode::kMaximum:
210 case HloOpcode::kMinimum:
211 case HloOpcode::kAnd:
212 case HloOpcode::kOr:
213 case HloOpcode::kXor:
214 return true;
215 default:
216 return false;
217 }
218 }
219
220 // Returns the number of HloOpcode values.
HloOpcodeCount()221 inline const uint32_t HloOpcodeCount() {
222 #define HLO_COUNT_ONE(...) +1
223 #define HLO_XLIST_LENGTH(list) list(HLO_COUNT_ONE)
224 return HLO_XLIST_LENGTH(HLO_OPCODE_LIST);
225 }
226
227 } // namespace xla
228
229 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
230