1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <string> 22 #include <utility> 23 24 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 25 #include "tensorflow/compiler/xla/service/shape_inference.h" 26 27 namespace xla { 28 29 // Callback to return shape size, in bytes. 30 using ShapeSizeFn = std::function<int64_t(const Shape&)>; 31 32 struct HloVerifierOpts { MakeLayoutSensitiveHloVerifierOpts33 HloVerifierOpts&& MakeLayoutSensitive() { 34 layout_sensitive = true; 35 return std::move(*this); 36 } 37 WithLayoutSensitiveHloVerifierOpts38 HloVerifierOpts&& WithLayoutSensitive(bool layout_sensitive_p) { 39 layout_sensitive = layout_sensitive_p; 40 return std::move(*this); 41 } 42 WithAllowMixedPrecisionHloVerifierOpts43 HloVerifierOpts&& WithAllowMixedPrecision(bool allow_mixed_precision_p) { 44 allow_mixed_precision = allow_mixed_precision_p; 45 return std::move(*this); 46 } 47 AllowMixedPrecisionHloVerifierOpts48 HloVerifierOpts&& AllowMixedPrecision() { 49 allow_mixed_precision = true; 50 return std::move(*this); 51 } 52 VerifyBroadcastDimensionsOrderHloVerifierOpts53 HloVerifierOpts&& VerifyBroadcastDimensionsOrder() { 54 verify_broadcast_dimensions_order = true; 55 return std::move(*this); 56 } 57 VerifyReshapeIsBitcastHloVerifierOpts58 HloVerifierOpts&& VerifyReshapeIsBitcast() { 59 verify_reshape_is_bitcast = true; 60 return std::move(*this); 61 } 62 VerifyCustomCallNestedComputationThreadNameHloVerifierOpts63 HloVerifierOpts&& VerifyCustomCallNestedComputationThreadName() { 64 verify_custom_call_nested_computation_thread_name = true; 65 return std::move(*this); 66 } 67 WithAllowBitcastToHaveDifferentSizeHloVerifierOpts68 HloVerifierOpts&& WithAllowBitcastToHaveDifferentSize(bool allow) { 69 allow_bitcast_to_have_different_size = allow; 70 return std::move(*this); 71 } 72 WithInstructionCanChangeLayoutHloVerifierOpts73 HloVerifierOpts&& WithInstructionCanChangeLayout( 74 const HloPredicate& instruction_can_change_layout_p) { 75 instruction_can_change_layout = instruction_can_change_layout_p; 76 return std::move(*this); 77 } 78 WithCustomShapeSizeHloVerifierOpts79 HloVerifierOpts&& WithCustomShapeSize(const ShapeSizeFn& shape_size_p) { 80 shape_size = shape_size_p; 81 return std::move(*this); 82 } 83 IsLayoutSensitiveHloVerifierOpts84 bool IsLayoutSensitive() const { return layout_sensitive; } 85 AllowMixedPrecisionHloVerifierOpts86 bool AllowMixedPrecision() const { return allow_mixed_precision; } 87 InstructionCanChangeLayoutHloVerifierOpts88 const HloPredicate& InstructionCanChangeLayout() const { 89 return instruction_can_change_layout; 90 } 91 InstructionCanChangeLayoutHloVerifierOpts92 bool InstructionCanChangeLayout(const HloInstruction* instruction) const { 93 return !instruction_can_change_layout || 94 instruction_can_change_layout(instruction); 95 } 96 ShapeSizeHloVerifierOpts97 int64_t ShapeSize(const Shape& shape) const { return shape_size(shape); } 98 99 // If the verifier is layout-sensitive, shapes must be equal to what's 100 // expected. Otherwise, the shapes must simply be compatible. 101 bool layout_sensitive = false; 102 103 // Whether the inputs and output of an instruction can contain both F32s and 104 // BF16s. Tuples that include both F32s and BF16s are allowed regardless of 105 // this flag. 106 bool allow_mixed_precision = false; 107 108 // Check that `dimensions` attribute of broadcast is sorted. 109 bool verify_broadcast_dimensions_order = false; 110 111 // Check that reshape is a physical bitcast. 112 bool verify_reshape_is_bitcast = false; 113 114 // Check that custom call's called computations have same thread name as 115 // parent computation. 116 bool verify_custom_call_nested_computation_thread_name = true; 117 118 // Whether bitcast should have the same size, including all paddings. 119 bool allow_bitcast_to_have_different_size = false; 120 121 HloPredicate instruction_can_change_layout; 122 123 // Returns a target-specific shape size. 124 ShapeSizeFn shape_size = [](const Shape& shape) { 125 return ShapeUtil::ByteSizeOf(shape); 126 }; 127 }; 128 129 // Visitor which verifies that the output shape is correctly set. Verifies 130 // against the inferred shape for the instruction. 131 class ShapeVerifier : public DfsHloVisitor { 132 public: ShapeVerifier(const HloVerifierOpts & opts)133 explicit ShapeVerifier(const HloVerifierOpts& opts) : opts_(opts) {} 134 135 // Verifies that entry computation layout matches parameters and root shape of 136 // the module's entry computation. 137 virtual Status VerifyEntryComputationLayout(const HloModule& module); 138 139 Status Preprocess(HloInstruction* hlo) override; 140 141 Status HandleElementwiseUnary(HloInstruction* hlo) override; 142 Status HandleElementwiseBinary(HloInstruction* hlo) override; 143 Status HandleClamp(HloInstruction* clamp) override; 144 Status HandleSelect(HloInstruction* select) override; 145 Status HandleConcatenate(HloInstruction* concatenate) override; 146 Status HandleIota(HloInstruction* hlo) override; 147 Status HandleConvert(HloInstruction* convert) override; 148 Status HandleBitcastConvert(HloInstruction* convert) override; 149 Status HandleCopy(HloInstruction* copy) override; 150 Status HandleDot(HloInstruction* dot) override; 151 Status HandleConvolution(HloInstruction* convolution) override; 152 Status HandleFft(HloInstruction* fft) override; 153 Status HandleCholesky(HloInstruction* hlo) override; 154 Status HandleTriangularSolve(HloInstruction* hlo) override; 155 Status HandleAllGather(HloInstruction* hlo) override; 156 Status HandleAllGatherStart(HloInstruction* hlo) override; 157 Status HandleAllGatherDone(HloInstruction* hlo) override; 158 Status HandleAllReduce(HloInstruction* hlo) override; 159 Status HandleAllReduceStart(HloInstruction* hlo) override; 160 Status HandleAllReduceDone(HloInstruction* hlo) override; 161 Status HandleAllToAll(HloInstruction* hlo) override; 162 Status HandleCollectivePermute(HloInstruction* hlo) override; 163 Status HandleCollectivePermuteStart(HloInstruction* hlo) override; 164 Status HandleCollectivePermuteDone(HloInstruction* hlo) override; 165 Status HandlePartitionId(HloInstruction* hlo) override; 166 Status HandleReplicaId(HloInstruction* hlo) override; 167 Status HandleReducePrecision(HloInstruction* reduce_precision) override; 168 Status HandleInfeed(HloInstruction*) override; 169 Status HandleOptimizationBarrier(HloInstruction* hlo) override; 170 Status HandleOutfeed(HloInstruction*) override; 171 Status HandleRng(HloInstruction*) override; 172 Status HandleRngBitGenerator(HloInstruction*) override; 173 Status HandleRngGetAndUpdateState(HloInstruction*) override; 174 Status HandleReverse(HloInstruction* reverse) override; 175 Status HandleSort(HloInstruction* sort) override; 176 Status HandleConstant(HloInstruction* constant) override; 177 Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; 178 Status HandleReduce(HloInstruction* reduce) override; 179 Status HandleBitcast(HloInstruction* bitcast) override; 180 Status HandleBroadcast(HloInstruction* broadcast) override; 181 Status HandleReshape(HloInstruction* reshape) override; 182 Status HandleDynamicReshape(HloInstruction* dynamic_reshape) override; 183 Status HandleTranspose(HloInstruction* transpose) override; 184 Status HandleParameter(HloInstruction*) override; 185 Status HandleFusion(HloInstruction*) override; 186 Status HandleCall(HloInstruction* call) override; 187 Status HandleCustomCall(HloInstruction*) override; 188 Status HandleSlice(HloInstruction* slice) override; 189 Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; 190 Status HandleDynamicUpdateSlice( 191 HloInstruction* dynamic_update_slice) override; 192 Status HandleTuple(HloInstruction* tuple) override; 193 Status HandleMap(HloInstruction* map) override; 194 Status HandleReduceScatter(HloInstruction* hlo) override; 195 Status HandleReduceWindow(HloInstruction* reduce_window) override; 196 Status HandleSelectAndScatter(HloInstruction* instruction) override; 197 Status HandleWhile(HloInstruction* xla_while) override; 198 Status HandleConditional(HloInstruction* conditional) override; 199 Status HandlePad(HloInstruction* pad) override; 200 Status HandleAsyncStart(HloInstruction* async_start) override; 201 Status HandleAsyncUpdate(HloInstruction* async_update) override; 202 Status HandleAsyncDone(HloInstruction* async_done) override; 203 Status HandleCopyStart(HloInstruction* copy_start) override; 204 Status HandleCopyDone(HloInstruction* copy_done) override; 205 Status HandleSend(HloInstruction* send) override; 206 Status HandleSendDone(HloInstruction* send_done) override; 207 Status HandleRecv(HloInstruction* recv) override; 208 Status HandleRecvDone(HloInstruction* recv_done) override; 209 Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; 210 Status HandleBatchNormInference( 211 HloInstruction* batch_norm_inference) override; 212 Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; 213 Status HandleGather(HloInstruction* gather) override; 214 Status HandleScatter(HloInstruction* scatter) override; 215 Status HandleAfterAll(HloInstruction* token) override; 216 Status HandleGetDimensionSize(HloInstruction* get_size) override; 217 Status HandleSetDimensionSize(HloInstruction* set_size) override; 218 Status HandleAddDependency(HloInstruction* add_dependency) override; 219 FinishVisit(HloInstruction *)220 Status FinishVisit(HloInstruction*) override { return OkStatus(); } 221 222 protected: 223 // Check the instruction's shape against the shape given by ShapeInference 224 // and return an appropriate error if there is a mismatch. 225 Status CheckShape(const HloInstruction* instruction, 226 const Shape& inferred_shape, 227 bool only_compare_minor_to_major_in_layout = false); 228 229 // Overload which takes a StatusOr to reduce boilerplate in the caller. 230 Status CheckShape(const HloInstruction* instruction, 231 const StatusOr<Shape>& inferred_shape_status); 232 233 // Check a unary (binary, etc) instruction's shape against the inferred shape. 234 Status CheckUnaryShape(const HloInstruction* instruction); 235 Status CheckBinaryShape(const HloInstruction* instruction); 236 Status CheckTernaryShape(const HloInstruction* instruction); 237 Status CheckVariadicShape(const HloInstruction* instruction); 238 239 private: 240 // Helpers that switch on layout_sensitive_. 241 bool ShapesSame(const Shape& a, const Shape& b, 242 bool minor_to_major_only = false, 243 bool ignore_memory_space = false) { 244 if (!opts_.layout_sensitive) { 245 return ShapeUtil::Compatible(a, b); 246 } 247 Shape::Equal equal; 248 if (ignore_memory_space) { 249 equal.IgnoreMemorySpaceInLayout(); 250 } 251 if (minor_to_major_only) { 252 equal.MinorToMajorOnlyInLayout(); 253 } 254 return equal(a, b); 255 } 256 257 bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b, 258 bool minor_to_major_only = false) { 259 if (!opts_.layout_sensitive) { 260 return ShapeUtil::CompatibleIgnoringFpPrecision(a, b); 261 } 262 Shape::Equal equal; 263 if (minor_to_major_only) { 264 equal.MinorToMajorOnlyInLayout(); 265 } 266 equal.IgnoreFpPrecision(); 267 return equal(a, b); 268 } 269 StringifyShape(const Shape & s)270 std::string StringifyShape(const Shape& s) { 271 return opts_.layout_sensitive ? ShapeUtil::HumanStringWithLayout(s) 272 : ShapeUtil::HumanString(s); 273 } 274 275 // Helpers that switch on allow_mixed_precision_. SameElementType(const Shape & a,const Shape & b)276 bool SameElementType(const Shape& a, const Shape& b) { 277 return opts_.allow_mixed_precision 278 ? ShapeUtil::SameElementTypeIgnoringFpPrecision(a, b) 279 : ShapeUtil::SameElementType(a, b); 280 } 281 282 // Checks that the given operand of the given instruction is of type TOKEN. 283 Status CheckIsTokenOperand(const HloInstruction* instruction, 284 int64_t operand_no); 285 286 // Checks that the shape of the given operand of the given instruction matches 287 // the given parameter of the given computation. 288 Status CheckOperandAndParameter(const HloInstruction* instruction, 289 int64_t operand_number, 290 const HloComputation* computation, 291 int64_t parameter_number); 292 293 // Returns true if the shapes of the two operands have the same element type, 294 // and the result shape either has the same element type as the operand shapes 295 // or mixed precision is allowed and the result shape and the operand shapes 296 // have floating point element types. 297 bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1, 298 const Shape& result_shape); 299 300 const HloVerifierOpts& opts_; 301 }; 302 303 // An interface used to encapsulate target-specific verification quirks. 304 class TargetVerifierMetadata { 305 public: TargetVerifierMetadata(HloVerifierOpts && opts)306 explicit TargetVerifierMetadata(HloVerifierOpts&& opts) : opts_(opts) { 307 CHECK(opts.instruction_can_change_layout == nullptr || 308 opts.layout_sensitive); 309 } 310 311 virtual std::unique_ptr<ShapeVerifier> GetVerifier() const = 0; 312 TargetVerifierMetadata()313 TargetVerifierMetadata() {} ~TargetVerifierMetadata()314 virtual ~TargetVerifierMetadata() {} 315 316 TargetVerifierMetadata(const TargetVerifierMetadata&) = delete; 317 TargetVerifierMetadata& operator=(const TargetVerifierMetadata&) = delete; 318 GetVerifierOpts()319 const HloVerifierOpts& GetVerifierOpts() const { return opts_; } 320 321 private: 322 HloVerifierOpts opts_; 323 }; 324 325 // The default implementation of TargetVerifierMetadata, used unless the target 326 // needs to override it. 327 class DefaultVerifierMetadata : public TargetVerifierMetadata { 328 public: DefaultVerifierMetadata(HloVerifierOpts && opts)329 explicit DefaultVerifierMetadata(HloVerifierOpts&& opts) 330 : TargetVerifierMetadata(std::move(opts)) {} 331 332 // Creates a ShapeVerifier that checks that shapes match inferred 333 // expectations. This creates a new verifier every time because ShapeVerifier, 334 // being a DfsHloVisitor, is stateful. We want a clean object for each run of 335 // the verifier. GetVerifier()336 std::unique_ptr<ShapeVerifier> GetVerifier() const override { 337 return std::make_unique<ShapeVerifier>(GetVerifierOpts()); 338 } 339 }; 340 341 // HLO pass that verifies invariants of HLO instructions for each computation in 342 // the module. 343 class HloVerifier : public HloModulePass { 344 public: 345 HloVerifier( 346 bool layout_sensitive, bool allow_mixed_precision, 347 HloPredicate instruction_can_change_layout_func = {}, 348 std::function<int64_t(const Shape&)> shape_size_func = 349 [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); }) 350 : HloVerifier(HloVerifierOpts{} 351 .WithLayoutSensitive(layout_sensitive) 352 .WithAllowMixedPrecision(allow_mixed_precision) 353 .WithInstructionCanChangeLayout( 354 instruction_can_change_layout_func) 355 .WithCustomShapeSize(shape_size_func)) {} 356 HloVerifier(HloVerifierOpts && opts)357 explicit HloVerifier(HloVerifierOpts&& opts) 358 : target_metadata_( 359 std::make_unique<DefaultVerifierMetadata>(std::move(opts))), 360 context_("Unknown") {} 361 362 // Uses custom target metadata 363 explicit HloVerifier(std::unique_ptr<TargetVerifierMetadata> target_metadata, 364 absl::string_view context = "Unknown") target_metadata_(std::move (target_metadata))365 : target_metadata_(std::move(target_metadata)), context_(context) {} 366 367 ~HloVerifier() override = default; name()368 absl::string_view name() const override { return "hlo-verifier"; } 369 370 // Never returns true; no instructions are ever modified by this pass. 371 using HloPassInterface::Run; 372 using HloPassInterface::RunOnModuleGroup; 373 StatusOr<bool> Run( 374 HloModule* module, 375 const absl::flat_hash_set<absl::string_view>& execution_threads) override; 376 377 private: 378 // Owns verifier config. 379 std::unique_ptr<TargetVerifierMetadata> target_metadata_; 380 381 // The hlo pass when the verifier is invoked. 382 std::string context_; 383 }; 384 385 } // namespace xla 386 387 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ 388