xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_verifier.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 
24 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
25 #include "tensorflow/compiler/xla/service/shape_inference.h"
26 
27 namespace xla {
28 
29 // Callback to return shape size, in bytes.
30 using ShapeSizeFn = std::function<int64_t(const Shape&)>;
31 
32 struct HloVerifierOpts {
MakeLayoutSensitiveHloVerifierOpts33   HloVerifierOpts&& MakeLayoutSensitive() {
34     layout_sensitive = true;
35     return std::move(*this);
36   }
37 
WithLayoutSensitiveHloVerifierOpts38   HloVerifierOpts&& WithLayoutSensitive(bool layout_sensitive_p) {
39     layout_sensitive = layout_sensitive_p;
40     return std::move(*this);
41   }
42 
WithAllowMixedPrecisionHloVerifierOpts43   HloVerifierOpts&& WithAllowMixedPrecision(bool allow_mixed_precision_p) {
44     allow_mixed_precision = allow_mixed_precision_p;
45     return std::move(*this);
46   }
47 
AllowMixedPrecisionHloVerifierOpts48   HloVerifierOpts&& AllowMixedPrecision() {
49     allow_mixed_precision = true;
50     return std::move(*this);
51   }
52 
VerifyBroadcastDimensionsOrderHloVerifierOpts53   HloVerifierOpts&& VerifyBroadcastDimensionsOrder() {
54     verify_broadcast_dimensions_order = true;
55     return std::move(*this);
56   }
57 
VerifyReshapeIsBitcastHloVerifierOpts58   HloVerifierOpts&& VerifyReshapeIsBitcast() {
59     verify_reshape_is_bitcast = true;
60     return std::move(*this);
61   }
62 
VerifyCustomCallNestedComputationThreadNameHloVerifierOpts63   HloVerifierOpts&& VerifyCustomCallNestedComputationThreadName() {
64     verify_custom_call_nested_computation_thread_name = true;
65     return std::move(*this);
66   }
67 
WithAllowBitcastToHaveDifferentSizeHloVerifierOpts68   HloVerifierOpts&& WithAllowBitcastToHaveDifferentSize(bool allow) {
69     allow_bitcast_to_have_different_size = allow;
70     return std::move(*this);
71   }
72 
WithInstructionCanChangeLayoutHloVerifierOpts73   HloVerifierOpts&& WithInstructionCanChangeLayout(
74       const HloPredicate& instruction_can_change_layout_p) {
75     instruction_can_change_layout = instruction_can_change_layout_p;
76     return std::move(*this);
77   }
78 
WithCustomShapeSizeHloVerifierOpts79   HloVerifierOpts&& WithCustomShapeSize(const ShapeSizeFn& shape_size_p) {
80     shape_size = shape_size_p;
81     return std::move(*this);
82   }
83 
IsLayoutSensitiveHloVerifierOpts84   bool IsLayoutSensitive() const { return layout_sensitive; }
85 
AllowMixedPrecisionHloVerifierOpts86   bool AllowMixedPrecision() const { return allow_mixed_precision; }
87 
InstructionCanChangeLayoutHloVerifierOpts88   const HloPredicate& InstructionCanChangeLayout() const {
89     return instruction_can_change_layout;
90   }
91 
InstructionCanChangeLayoutHloVerifierOpts92   bool InstructionCanChangeLayout(const HloInstruction* instruction) const {
93     return !instruction_can_change_layout ||
94            instruction_can_change_layout(instruction);
95   }
96 
ShapeSizeHloVerifierOpts97   int64_t ShapeSize(const Shape& shape) const { return shape_size(shape); }
98 
99   // If the verifier is layout-sensitive, shapes must be equal to what's
100   // expected.  Otherwise, the shapes must simply be compatible.
101   bool layout_sensitive = false;
102 
103   // Whether the inputs and output of an instruction can contain both F32s and
104   // BF16s. Tuples that include both F32s and BF16s are allowed regardless of
105   // this flag.
106   bool allow_mixed_precision = false;
107 
108   // Check that `dimensions` attribute of broadcast is sorted.
109   bool verify_broadcast_dimensions_order = false;
110 
111   // Check that reshape is a physical bitcast.
112   bool verify_reshape_is_bitcast = false;
113 
114   // Check that custom call's called computations have same thread name as
115   // parent computation.
116   bool verify_custom_call_nested_computation_thread_name = true;
117 
118   // Whether bitcast should have the same size, including all paddings.
119   bool allow_bitcast_to_have_different_size = false;
120 
121   HloPredicate instruction_can_change_layout;
122 
123   // Returns a target-specific shape size.
124   ShapeSizeFn shape_size = [](const Shape& shape) {
125     return ShapeUtil::ByteSizeOf(shape);
126   };
127 };
128 
129 // Visitor which verifies that the output shape is correctly set. Verifies
130 // against the inferred shape for the instruction.
131 class ShapeVerifier : public DfsHloVisitor {
132  public:
ShapeVerifier(const HloVerifierOpts & opts)133   explicit ShapeVerifier(const HloVerifierOpts& opts) : opts_(opts) {}
134 
135   // Verifies that entry computation layout matches parameters and root shape of
136   // the module's entry computation.
137   virtual Status VerifyEntryComputationLayout(const HloModule& module);
138 
139   Status Preprocess(HloInstruction* hlo) override;
140 
141   Status HandleElementwiseUnary(HloInstruction* hlo) override;
142   Status HandleElementwiseBinary(HloInstruction* hlo) override;
143   Status HandleClamp(HloInstruction* clamp) override;
144   Status HandleSelect(HloInstruction* select) override;
145   Status HandleConcatenate(HloInstruction* concatenate) override;
146   Status HandleIota(HloInstruction* hlo) override;
147   Status HandleConvert(HloInstruction* convert) override;
148   Status HandleBitcastConvert(HloInstruction* convert) override;
149   Status HandleCopy(HloInstruction* copy) override;
150   Status HandleDot(HloInstruction* dot) override;
151   Status HandleConvolution(HloInstruction* convolution) override;
152   Status HandleFft(HloInstruction* fft) override;
153   Status HandleCholesky(HloInstruction* hlo) override;
154   Status HandleTriangularSolve(HloInstruction* hlo) override;
155   Status HandleAllGather(HloInstruction* hlo) override;
156   Status HandleAllGatherStart(HloInstruction* hlo) override;
157   Status HandleAllGatherDone(HloInstruction* hlo) override;
158   Status HandleAllReduce(HloInstruction* hlo) override;
159   Status HandleAllReduceStart(HloInstruction* hlo) override;
160   Status HandleAllReduceDone(HloInstruction* hlo) override;
161   Status HandleAllToAll(HloInstruction* hlo) override;
162   Status HandleCollectivePermute(HloInstruction* hlo) override;
163   Status HandleCollectivePermuteStart(HloInstruction* hlo) override;
164   Status HandleCollectivePermuteDone(HloInstruction* hlo) override;
165   Status HandlePartitionId(HloInstruction* hlo) override;
166   Status HandleReplicaId(HloInstruction* hlo) override;
167   Status HandleReducePrecision(HloInstruction* reduce_precision) override;
168   Status HandleInfeed(HloInstruction*) override;
169   Status HandleOptimizationBarrier(HloInstruction* hlo) override;
170   Status HandleOutfeed(HloInstruction*) override;
171   Status HandleRng(HloInstruction*) override;
172   Status HandleRngBitGenerator(HloInstruction*) override;
173   Status HandleRngGetAndUpdateState(HloInstruction*) override;
174   Status HandleReverse(HloInstruction* reverse) override;
175   Status HandleSort(HloInstruction* sort) override;
176   Status HandleConstant(HloInstruction* constant) override;
177   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
178   Status HandleReduce(HloInstruction* reduce) override;
179   Status HandleBitcast(HloInstruction* bitcast) override;
180   Status HandleBroadcast(HloInstruction* broadcast) override;
181   Status HandleReshape(HloInstruction* reshape) override;
182   Status HandleDynamicReshape(HloInstruction* dynamic_reshape) override;
183   Status HandleTranspose(HloInstruction* transpose) override;
184   Status HandleParameter(HloInstruction*) override;
185   Status HandleFusion(HloInstruction*) override;
186   Status HandleCall(HloInstruction* call) override;
187   Status HandleCustomCall(HloInstruction*) override;
188   Status HandleSlice(HloInstruction* slice) override;
189   Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
190   Status HandleDynamicUpdateSlice(
191       HloInstruction* dynamic_update_slice) override;
192   Status HandleTuple(HloInstruction* tuple) override;
193   Status HandleMap(HloInstruction* map) override;
194   Status HandleReduceScatter(HloInstruction* hlo) override;
195   Status HandleReduceWindow(HloInstruction* reduce_window) override;
196   Status HandleSelectAndScatter(HloInstruction* instruction) override;
197   Status HandleWhile(HloInstruction* xla_while) override;
198   Status HandleConditional(HloInstruction* conditional) override;
199   Status HandlePad(HloInstruction* pad) override;
200   Status HandleAsyncStart(HloInstruction* async_start) override;
201   Status HandleAsyncUpdate(HloInstruction* async_update) override;
202   Status HandleAsyncDone(HloInstruction* async_done) override;
203   Status HandleCopyStart(HloInstruction* copy_start) override;
204   Status HandleCopyDone(HloInstruction* copy_done) override;
205   Status HandleSend(HloInstruction* send) override;
206   Status HandleSendDone(HloInstruction* send_done) override;
207   Status HandleRecv(HloInstruction* recv) override;
208   Status HandleRecvDone(HloInstruction* recv_done) override;
209   Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override;
210   Status HandleBatchNormInference(
211       HloInstruction* batch_norm_inference) override;
212   Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
213   Status HandleGather(HloInstruction* gather) override;
214   Status HandleScatter(HloInstruction* scatter) override;
215   Status HandleAfterAll(HloInstruction* token) override;
216   Status HandleGetDimensionSize(HloInstruction* get_size) override;
217   Status HandleSetDimensionSize(HloInstruction* set_size) override;
218   Status HandleAddDependency(HloInstruction* add_dependency) override;
219 
FinishVisit(HloInstruction *)220   Status FinishVisit(HloInstruction*) override { return OkStatus(); }
221 
222  protected:
223   // Check the instruction's shape against the shape given by ShapeInference
224   // and return an appropriate error if there is a mismatch.
225   Status CheckShape(const HloInstruction* instruction,
226                     const Shape& inferred_shape,
227                     bool only_compare_minor_to_major_in_layout = false);
228 
229   // Overload which takes a StatusOr to reduce boilerplate in the caller.
230   Status CheckShape(const HloInstruction* instruction,
231                     const StatusOr<Shape>& inferred_shape_status);
232 
233   // Check a unary (binary, etc) instruction's shape against the inferred shape.
234   Status CheckUnaryShape(const HloInstruction* instruction);
235   Status CheckBinaryShape(const HloInstruction* instruction);
236   Status CheckTernaryShape(const HloInstruction* instruction);
237   Status CheckVariadicShape(const HloInstruction* instruction);
238 
239  private:
240   // Helpers that switch on layout_sensitive_.
241   bool ShapesSame(const Shape& a, const Shape& b,
242                   bool minor_to_major_only = false,
243                   bool ignore_memory_space = false) {
244     if (!opts_.layout_sensitive) {
245       return ShapeUtil::Compatible(a, b);
246     }
247     Shape::Equal equal;
248     if (ignore_memory_space) {
249       equal.IgnoreMemorySpaceInLayout();
250     }
251     if (minor_to_major_only) {
252       equal.MinorToMajorOnlyInLayout();
253     }
254     return equal(a, b);
255   }
256 
257   bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b,
258                                      bool minor_to_major_only = false) {
259     if (!opts_.layout_sensitive) {
260       return ShapeUtil::CompatibleIgnoringFpPrecision(a, b);
261     }
262     Shape::Equal equal;
263     if (minor_to_major_only) {
264       equal.MinorToMajorOnlyInLayout();
265     }
266     equal.IgnoreFpPrecision();
267     return equal(a, b);
268   }
269 
StringifyShape(const Shape & s)270   std::string StringifyShape(const Shape& s) {
271     return opts_.layout_sensitive ? ShapeUtil::HumanStringWithLayout(s)
272                                   : ShapeUtil::HumanString(s);
273   }
274 
275   // Helpers that switch on allow_mixed_precision_.
SameElementType(const Shape & a,const Shape & b)276   bool SameElementType(const Shape& a, const Shape& b) {
277     return opts_.allow_mixed_precision
278                ? ShapeUtil::SameElementTypeIgnoringFpPrecision(a, b)
279                : ShapeUtil::SameElementType(a, b);
280   }
281 
282   // Checks that the given operand of the given instruction is of type TOKEN.
283   Status CheckIsTokenOperand(const HloInstruction* instruction,
284                              int64_t operand_no);
285 
286   // Checks that the shape of the given operand of the given instruction matches
287   // the given parameter of the given computation.
288   Status CheckOperandAndParameter(const HloInstruction* instruction,
289                                   int64_t operand_number,
290                                   const HloComputation* computation,
291                                   int64_t parameter_number);
292 
293   // Returns true if the shapes of the two operands have the same element type,
294   // and the result shape either has the same element type as the operand shapes
295   // or mixed precision is allowed and the result shape and the operand shapes
296   // have floating point element types.
297   bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1,
298                                  const Shape& result_shape);
299 
300   const HloVerifierOpts& opts_;
301 };
302 
303 // An interface used to encapsulate target-specific verification quirks.
304 class TargetVerifierMetadata {
305  public:
TargetVerifierMetadata(HloVerifierOpts && opts)306   explicit TargetVerifierMetadata(HloVerifierOpts&& opts) : opts_(opts) {
307     CHECK(opts.instruction_can_change_layout == nullptr ||
308           opts.layout_sensitive);
309   }
310 
311   virtual std::unique_ptr<ShapeVerifier> GetVerifier() const = 0;
312 
TargetVerifierMetadata()313   TargetVerifierMetadata() {}
~TargetVerifierMetadata()314   virtual ~TargetVerifierMetadata() {}
315 
316   TargetVerifierMetadata(const TargetVerifierMetadata&) = delete;
317   TargetVerifierMetadata& operator=(const TargetVerifierMetadata&) = delete;
318 
GetVerifierOpts()319   const HloVerifierOpts& GetVerifierOpts() const { return opts_; }
320 
321  private:
322   HloVerifierOpts opts_;
323 };
324 
325 // The default implementation of TargetVerifierMetadata, used unless the target
326 // needs to override it.
327 class DefaultVerifierMetadata : public TargetVerifierMetadata {
328  public:
DefaultVerifierMetadata(HloVerifierOpts && opts)329   explicit DefaultVerifierMetadata(HloVerifierOpts&& opts)
330       : TargetVerifierMetadata(std::move(opts)) {}
331 
332   // Creates a ShapeVerifier that checks that shapes match inferred
333   // expectations. This creates a new verifier every time because ShapeVerifier,
334   // being a DfsHloVisitor, is stateful. We want a clean object for each run of
335   // the verifier.
GetVerifier()336   std::unique_ptr<ShapeVerifier> GetVerifier() const override {
337     return std::make_unique<ShapeVerifier>(GetVerifierOpts());
338   }
339 };
340 
341 // HLO pass that verifies invariants of HLO instructions for each computation in
342 // the module.
343 class HloVerifier : public HloModulePass {
344  public:
345   HloVerifier(
346       bool layout_sensitive, bool allow_mixed_precision,
347       HloPredicate instruction_can_change_layout_func = {},
348       std::function<int64_t(const Shape&)> shape_size_func =
349           [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); })
350       : HloVerifier(HloVerifierOpts{}
351                         .WithLayoutSensitive(layout_sensitive)
352                         .WithAllowMixedPrecision(allow_mixed_precision)
353                         .WithInstructionCanChangeLayout(
354                             instruction_can_change_layout_func)
355                         .WithCustomShapeSize(shape_size_func)) {}
356 
HloVerifier(HloVerifierOpts && opts)357   explicit HloVerifier(HloVerifierOpts&& opts)
358       : target_metadata_(
359             std::make_unique<DefaultVerifierMetadata>(std::move(opts))),
360         context_("Unknown") {}
361 
362   // Uses custom target metadata
363   explicit HloVerifier(std::unique_ptr<TargetVerifierMetadata> target_metadata,
364                        absl::string_view context = "Unknown")
target_metadata_(std::move (target_metadata))365       : target_metadata_(std::move(target_metadata)), context_(context) {}
366 
367   ~HloVerifier() override = default;
name()368   absl::string_view name() const override { return "hlo-verifier"; }
369 
370   // Never returns true; no instructions are ever modified by this pass.
371   using HloPassInterface::Run;
372   using HloPassInterface::RunOnModuleGroup;
373   StatusOr<bool> Run(
374       HloModule* module,
375       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
376 
377  private:
378   // Owns verifier config.
379   std::unique_ptr<TargetVerifierMetadata> target_metadata_;
380 
381   // The hlo pass when the verifier is invoked.
382   std::string context_;
383 };
384 
385 }  // namespace xla
386 
387 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
388