/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_verifier.h"

#include <memory>
#include <optional>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {

namespace {

bool IsCallerInstruction(HloInstruction* hlo) {
  switch (hlo->opcode()) {
    case HloOpcode::kAsyncStart:
    case HloOpcode::kAsyncUpdate:
    case HloOpcode::kAsyncDone:
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kWhile:
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kMap:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kScatter:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSort:
    case HloOpcode::kFusion:
    case HloOpcode::kCustomCall:
      return true;
    default:
      return false;
  }
}

Status CheckOperandCount(const HloInstruction* hlo, int expected) {
  if (hlo->operand_count() != expected) {
    return InternalError("Expected %d operands for %s instruction: %s",
                         expected, HloOpcodeString(hlo->opcode()),
                         hlo->ToString());
  }
  return OkStatus();
}

Status CheckParameterCount(const HloInstruction* calling_instruction,
                           const HloComputation* computation, int expected) {
  if (computation->num_parameters() != expected) {
    return InternalError(
        "Expected computation %s called from %s to have %d parameters, has %d",
        computation->name(), calling_instruction->name(), expected,
        computation->num_parameters());
  }
  return OkStatus();
}

int64_t GetSubgroupSize(HloCollectiveInstruction* hlo,
                        CollectiveOpGroupMode group_mode) {
  const HloModuleConfig& config = hlo->GetModule()->config();
  // Empty replica groups imply all replicas form a single group.
  int64_t replica_subgroup_size =
      hlo->replica_groups().empty()
          ? 0
          : hlo->replica_groups()[0].replica_ids_size();
  switch (group_mode) {
    case CollectiveOpGroupMode::kCrossReplica:
    case CollectiveOpGroupMode::kCrossReplicaAndPartition: {
      int64_t replica_subgroup_size =
          hlo->replica_groups().empty()
              ? config.replica_count()
              : hlo->replica_groups()[0].replica_ids_size();
      if (group_mode == CollectiveOpGroupMode::kCrossReplicaAndPartition) {
        // Replicas from all partitions participate.
        replica_subgroup_size *= config.num_partitions();
      }
      return replica_subgroup_size;
    }
    case CollectiveOpGroupMode::kFlattenedID:
      return replica_subgroup_size;
    case CollectiveOpGroupMode::kCrossPartition:
      return hlo->replica_groups().empty()
                 ? config.num_partitions()
                 : hlo->replica_groups()[0].replica_ids_size();
  }
}

Status CheckNestedComputationThreadNameEqual(const HloComputation* comp,
                                             bool skip_nested_async_op_check) {
  for (const HloInstruction* instr : comp->instructions()) {
    if (skip_nested_async_op_check && instr->IsAsynchronous()) {
      continue;
    }
    for (const HloComputation* called_cmp : instr->called_computations()) {
      if (called_cmp->execution_thread() != comp->execution_thread()) {
        return InternalError(
            "Nested computation's thread name (%s) does not match its caller "
            "computation's thread name (%s).",
            called_cmp->execution_thread(), comp->execution_thread());
      }
      TF_RETURN_IF_ERROR(CheckNestedComputationThreadNameEqual(
          called_cmp, skip_nested_async_op_check));
    }
  }
  return OkStatus();
}
}  // namespace

Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
  if (!hlo->called_computations().empty() && !IsCallerInstruction(hlo)) {
    return InternalError(
        "Called computations specified for non-caller instruction %s",
        hlo->ToString());
  }
  std::optional<int> arity = HloOpcodeArity(hlo->opcode());
  if (arity) {
    TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity));
  }
  return OkStatus();
}

Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) {
  return CheckUnaryShape(hlo);
}

Status ShapeVerifier::HandleElementwiseBinary(HloInstruction* hlo) {
  return CheckBinaryShape(hlo);
}

Status ShapeVerifier::HandleClamp(HloInstruction* clamp) {
  return CheckTernaryShape(clamp);
}

Status ShapeVerifier::HandleSelect(HloInstruction* select) {
  return CheckTernaryShape(select);
}

Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) {
  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : concatenate->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(concatenate,
                    ShapeInference::InferConcatOpShape(
                        operand_shapes, concatenate->concatenate_dimension()));
}

Status ShapeVerifier::HandleConvert(HloInstruction* convert) {
  return CheckShape(convert, ShapeInference::InferConvertShape(
                                 convert->operand(0)->shape(),
                                 convert->shape().element_type()));
}

Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) {
  return CheckShape(convert, ShapeInference::InferBitcastConvertShape(
                                 convert->operand(0)->shape(),
                                 convert->shape().element_type()));
}

Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
  return CheckUnaryShape(copy);
}

Status ShapeVerifier::HandleDot(HloInstruction* dot) {
  TF_ASSIGN_OR_RETURN(
      const Shape expected,
      ShapeInference::InferDotOpShape(
          dot->operand(0)->shape(), dot->operand(1)->shape(),
          dot->dot_dimension_numbers(),
          /*preferred_element_type=*/dot->shape().element_type()));
  return CheckShape(dot, expected);
}

Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
  TF_ASSIGN_OR_RETURN(
      Shape expected,
      ShapeInference::InferConvolveShape(
          convolution->operand(0)->shape(), convolution->operand(1)->shape(),
          convolution->feature_group_count(), convolution->batch_group_count(),
          convolution->window(), convolution->convolution_dimension_numbers(),
          /*preferred_element_type=*/convolution->shape().element_type()));
  return CheckShape(convolution, expected);
}

Status ShapeVerifier::HandleFft(HloInstruction* fft) {
  TF_ASSIGN_OR_RETURN(
      const Shape expected,
      ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(),
                                    fft->fft_length()));
  return CheckShape(fft, expected);
}

Status ShapeVerifier::HandleTriangularSolve(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(const Shape expected,
                      ShapeInference::InferTriangularSolveShape(
                          hlo->operand(0)->shape(), hlo->operand(1)->shape(),
                          hlo->triangular_solve_options()));
  return CheckShape(hlo, expected);
}

Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) {
  TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1));
  TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferCholeskyShape(
                                                hlo->operand(0)->shape()));
  return CheckShape(hlo, expected);
}

Status ShapeVerifier::HandleOptimizationBarrier(HloInstruction* hlo) {
  TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1));
  return CheckShape(hlo, hlo->operand(0)->shape());
}

// Checks that `hlo`'s set of ReplicaGroups:
//
// - names each replica 0 through n-1 exactly once (where n is either number of
//   replicas, or number of partitions, or their product)
// - does not contain any empty ReplicaGroups.
//
// Note that although none of the groups may be empty, `hlo` is allowed to have
// empty groups when group mode is not kFlattenedID. That just means it has one
// big group.
//
// In general, if replica groups is not empty, all replica groups should be of
// the same size. The exception is all-reduce, where non-uniform replica groups
// are allowed. This is controlled by `uniform_replica_group_size`.
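//
// A hypothetical example: with replica_count=4 and kCrossReplica mode,
// replica_groups={{0,1},{2,3}} satisfies these conditions, while
// {{0,1},{1,2}} (replica 1 repeated) or {{0,1}} (replicas 2 and 3 never named)
// would be rejected.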
static Status CheckReplicaGroups(HloInstruction* hlo,
                                 CollectiveOpGroupMode group_mode,
                                 bool uniform_replica_group_size = true) {
  if (!hlo->replica_groups().empty()) {
    absl::flat_hash_set<int64_t> replicas_seen;
    for (const ReplicaGroup& g : hlo->replica_groups()) {
      if (g.replica_ids().empty()) {
        return InternalError(
            "Instruction cannot have an empty replica group: %s",
            hlo->ToString());
      }
      for (int64_t i : g.replica_ids()) {
        if (!replicas_seen.insert(i).second) {
          return InternalError(
              "Replica %d is repeated in instruction's replica-groups: %s", i,
              hlo->ToString());
        }
      }
    }
    size_t n = replicas_seen.size();
    for (int64_t i = 0; i < n; ++i) {
      if (!replicas_seen.count(i)) {
        return InternalError(
            "Replica %d is not named in instruction's replica-groups: %s", i,
            hlo->ToString());
      }
    }

    // replica-groups have numbers [0, n). This n should be either replica or
    // partition count, or their product. In some cases, replica and/or
    // partition count is not set in the HloModule config and has a default
    // value of 1. For those cases, skip this part of the verification.
    int64_t replica_count = hlo->GetModule()->config().replica_count();
    int64_t num_partitions = hlo->GetModule()->config().num_partitions();
    switch (group_mode) {
      case CollectiveOpGroupMode::kCrossReplica:
      case CollectiveOpGroupMode::kCrossReplicaAndPartition: {
        TF_RET_CHECK(replica_count == 1 || n == replica_count)
            << "In " << CollectiveOpGroupModeToString(group_mode)
            << " mode, replica groups should contain " << replica_count
            << " replicas, but found " << n << ": " << hlo->ToString();
        break;
      }
      case CollectiveOpGroupMode::kCrossPartition: {
        TF_RET_CHECK(num_partitions == 1 || n == num_partitions)
            << "In " << CollectiveOpGroupModeToString(group_mode)
            << " mode, replica groups should contain " << num_partitions
            << " partitions, but found " << n << ": " << hlo->ToString();
        break;
      }
      case CollectiveOpGroupMode::kFlattenedID: {
        const int64_t num_flattened_ids = replica_count * num_partitions;
        TF_RET_CHECK(num_flattened_ids == 1 || n == num_flattened_ids)
            << "In " << CollectiveOpGroupModeToString(group_mode)
            << " mode, replica groups should contain " << num_flattened_ids
            << " flattened IDs, but found " << n << ": " << hlo->ToString();
        break;
      }
    }

    if (uniform_replica_group_size) {
      int64_t size = hlo->replica_groups()[0].replica_ids_size();
      for (const ReplicaGroup& g : hlo->replica_groups()) {
        TF_RET_CHECK(size == g.replica_ids_size())
            << "Replica groups expected to be of uniform size";
      }
    }
  } else {
    TF_RET_CHECK(group_mode != CollectiveOpGroupMode::kFlattenedID)
        << "Replica groups must be specified in flattened-id mode";
  }

  return OkStatus();
}

static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
                                             int64_t* computed_shard_count) {
  auto ag = Cast<HloAllGatherInstruction>(hlo);
  CHECK_NE(computed_shard_count, nullptr) << "Expected a shard count as input";
  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
                      GetCollectiveOpGroupMode(ag->channel_id().has_value(),
                                               ag->use_global_device_ids()));
  TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, group_mode));
  TF_RET_CHECK(ag->all_gather_dimension() >= 0);

  int64_t shard_count;
  for (int64_t i = 0; i < ag->operand_count(); ++i) {
    TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(i)->shape().rank());

    Shape output_shape;
    if (hlo->opcode() == HloOpcode::kAllGather) {
      output_shape = (ag->operand_count() == 1) ? ag->shape()
                                                : ag->shape().tuple_shapes(i);
    } else {
      TF_RET_CHECK(hlo->opcode() == HloOpcode::kAllGatherStart);
      output_shape = (ag->operand_count() == 1)
                         ? ag->shape().tuple_shapes(1)
                         : ag->shape().tuple_shapes(1).tuple_shapes(i);
    }
    TF_RET_CHECK(ag->all_gather_dimension() < output_shape.rank());
    if (i == 0) {
      shard_count = CeilOfRatio(
          output_shape.dimensions(ag->all_gather_dimension()),
          ag->operand(i)->shape().dimensions(ag->all_gather_dimension()));
    }
  }

  int64_t subgroup_size = GetSubgroupSize(ag, group_mode);
  // If replica and partition count is not explicitly set, it will have a
  // default value of 1, in which case the subgroup_size will be 1 as well. Skip
  // these verification checks in that case.
  TF_RET_CHECK(subgroup_size == 1 || shard_count == subgroup_size)
      << "shard_count = " << shard_count
      << ", subgroup_size = " << subgroup_size << ", " << hlo->ToString();
  *computed_shard_count = shard_count;
  return OkStatus();
}

Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
  auto ag = Cast<HloAllGatherInstruction>(hlo);
  int64_t shard_count;
  TF_RETURN_IF_ERROR(CheckCommonAllGatherInvariants(hlo, &shard_count));
  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(
      ag, ShapeInference::InferAllGatherShape(
              operand_shapes, ag->all_gather_dimension(), shard_count));
}

Status ShapeVerifier::HandleAllGatherStart(HloInstruction* hlo) {
  auto ag = Cast<HloAllGatherInstruction>(hlo);
  int64_t shard_count;
  TF_RETURN_IF_ERROR(CheckCommonAllGatherInvariants(hlo, &shard_count));
  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(
      ag, ShapeInference::InferAllGatherStartShape(
              operand_shapes, ag->all_gather_dimension(), shard_count));
}

Status ShapeVerifier::HandleAllGatherDone(HloInstruction* hlo) {
  return CheckShape(
      hlo, ShapeInference::InferAllGatherDoneShape(hlo->operand(0)->shape()));
}

Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) {
  auto ar = Cast<HloAllReduceInstruction>(hlo);
  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
                      GetCollectiveOpGroupMode(ar->channel_id().has_value(),
                                               ar->use_global_device_ids()));
  TF_RETURN_IF_ERROR(
      CheckReplicaGroups(ar, group_mode, /*uniform_replica_group_size=*/false));

  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(hlo, ShapeInference::InferAllReduceShape(operand_shapes));
}

Status ShapeVerifier::HandleReduceScatter(HloInstruction* hlo) {
  auto ars = Cast<HloReduceScatterInstruction>(hlo);
  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
                      GetCollectiveOpGroupMode(ars->channel_id().has_value(),
                                               ars->use_global_device_ids()));
  TF_RETURN_IF_ERROR(CheckReplicaGroups(ars, group_mode));
  TF_RET_CHECK(ars->scatter_dimension() >= 0);

  for (int64_t i = 0; i < ars->operand_count(); ++i) {
    TF_RET_CHECK(ars->scatter_dimension() < ars->operand(i)->shape().rank());

    const Shape& output_shape = (ars->operand_count() == 1)
                                    ? ars->shape()
                                    : ars->shape().tuple_shapes(i);
    TF_RET_CHECK(ars->scatter_dimension() < output_shape.rank());
  }

  const Shape& output0_shape =
      (ars->operand_count() == 1) ? ars->shape() : ars->shape().tuple_shapes(0);
  int64_t shard_count =
      CeilOfRatio(ars->operand(0)->shape().dimensions(ars->scatter_dimension()),
                  output0_shape.dimensions(ars->scatter_dimension()));
  int64_t subgroup_size = GetSubgroupSize(ars, group_mode);
  // If replica and partition count is not explicitly set, it will have a
  // default value of 1, in which case the subgroup_size will be 1 as well. Skip
  // these verification checks in that case.
  TF_RET_CHECK(subgroup_size == 1 || shard_count == subgroup_size)
      << "shard_count = " << shard_count
      << ", subgroup_size = " << subgroup_size << ", " << hlo->ToString();

  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(ars,
                    ShapeInference::InferReduceScatterShape(
                        operand_shapes, ars->scatter_dimension(), shard_count));
}

Status ShapeVerifier::HandleAllReduceStart(HloInstruction* hlo) {
  auto ar = Cast<HloAllReduceInstruction>(hlo);
  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
                      GetCollectiveOpGroupMode(ar->channel_id().has_value(),
                                               ar->use_global_device_ids()));
  TF_RETURN_IF_ERROR(
      CheckReplicaGroups(ar, group_mode, /*uniform_replica_group_size=*/false));

  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(hlo,
                    ShapeInference::InferAllReduceStartShape(operand_shapes));
}

Status ShapeVerifier::HandleAllReduceDone(HloInstruction* hlo) {
  return CheckShape(
      hlo, ShapeInference::InferAllReduceDoneShape(hlo->operand(0)->shape()));
}

Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
  auto* all_to_all = Cast<HloAllToAllInstruction>(hlo);
  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
                      GetCollectiveOpGroupMode(
                          all_to_all->channel_id().has_value(), std::nullopt));

  TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, group_mode));

  TF_RET_CHECK(all_to_all != nullptr);

  if (all_to_all->split_dimension()) {
    int64_t split_count = GetSubgroupSize(all_to_all, group_mode);
    TF_RET_CHECK(hlo->operand_count() == 1);
    return CheckShape(
        hlo, ShapeInference::InferAllToAllShape(
                 hlo->operand(0)->shape(), *all_to_all->split_dimension(),
                 *all_to_all->split_dimension(), split_count));
  } else {
    std::vector<const Shape*> operand_shapes;
    for (const HloInstruction* operand : hlo->operands()) {
      operand_shapes.push_back(&operand->shape());
    }
    return CheckShape(hlo,
                      ShapeInference::InferAllToAllTupleShape(operand_shapes));
  }
}

Status ShapeVerifier::HandlePartitionId(HloInstruction* hlo) {
  return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
}

Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) {
  return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
}

namespace {

Status CheckBufferOffset(const Shape& buffer_shape,
                         const Shape& buffer_offset_shape) {
  if (!buffer_offset_shape.IsTuple()) {
    return InternalError("Buffer offset is not tuple.");
  }
  bool all_is_array =
      absl::c_all_of(buffer_offset_shape.tuple_shapes(),
                     [](const Shape& shape) { return shape.IsArray(); });
  bool all_is_tuple =
      absl::c_all_of(buffer_offset_shape.tuple_shapes(),
                     [](const Shape& shape) { return shape.IsTuple(); });
  if (!all_is_array && !all_is_tuple) {
    return InternalError(
        "Buffer offset should either be a tuple of arrays or "
        "a tuple of tuples.");
  }

  if (all_is_tuple) {
    if (absl::c_any_of(buffer_offset_shape.tuple_shapes(),
                       [&buffer_shape](const Shape& shape) {
                         return ShapeUtil::TupleElementCount(shape) !=
                                buffer_shape.rank();
                       })) {
      return InternalError(
          "Buffer offset index should have the same number of "
          "elements as the buffer's rank.");
    }
  } else {
    if (buffer_offset_shape.tuple_shapes_size() != buffer_shape.rank()) {
      return InternalError(
          "Buffer offset index should have the same number of "
          "elements as the buffer's rank.");
    }
  }
  return OkStatus();
}

Status CheckInplaceCollectivePermute(HloInstruction* collective_permute) {
  if (collective_permute->operand_count() == 1) {
    return OkStatus();
  }
  if (collective_permute->operand_count() != 4) {
    return InternalError("Unexpected number of operands: %d.",
                         collective_permute->operand_count());
  }

  const Shape& input_buffer_shape = collective_permute->operand(0)->shape();
  const Shape& output_buffer_shape = collective_permute->operand(1)->shape();
  const Shape& input_offset_shape = collective_permute->operand(2)->shape();
  const Shape& output_offset_shape = collective_permute->operand(3)->shape();

  if (input_buffer_shape.IsArray() && output_buffer_shape.IsArray()) {
    Status check_input_buffer_offset =
        CheckBufferOffset(input_buffer_shape, input_offset_shape);
    if (!check_input_buffer_offset.ok()) {
      return check_input_buffer_offset;
    }
    Status check_output_buffer_offset =
        CheckBufferOffset(output_buffer_shape, output_offset_shape);
    if (!check_output_buffer_offset.ok()) {
      return check_output_buffer_offset;
    }
  } else if (input_buffer_shape.IsTuple() && output_buffer_shape.IsTuple()) {
    if (ShapeUtil::TupleElementCount(input_buffer_shape) !=
        ShapeUtil::TupleElementCount(output_buffer_shape)) {
      return InternalError("Unmatching input buffers and output buffers.");
    }
    if (!input_offset_shape.IsTuple() ||
        ShapeUtil::TupleElementCount(input_offset_shape) !=
            ShapeUtil::TupleElementCount(input_buffer_shape)) {
      return InternalError("Unmatching input buffers and input offset.");
    }
    for (int i = 0; i < input_buffer_shape.tuple_shapes_size(); ++i) {
      Status check_input_buffer_offset =
          CheckBufferOffset(input_buffer_shape.tuple_shapes(i),
                            input_offset_shape.tuple_shapes(i));
      if (!check_input_buffer_offset.ok()) {
        return check_input_buffer_offset;
      }
    }
    if (!output_offset_shape.IsTuple() ||
        ShapeUtil::TupleElementCount(output_offset_shape) !=
            ShapeUtil::TupleElementCount(output_buffer_shape)) {
      return InternalError("Unmatching output buffers and output offset.");
    }
    for (int i = 0; i < output_buffer_shape.tuple_shapes_size(); ++i) {
      Status check_output_buffer_offset =
          CheckBufferOffset(output_buffer_shape.tuple_shapes(i),
                            output_offset_shape.tuple_shapes(i));
      if (!check_output_buffer_offset.ok()) {
        return check_output_buffer_offset;
      }
    }
  } else {
    return InternalError("Unmatching input buffers and output buffers.");
  }
  return OkStatus();
}

Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo,
                                     CollectiveOpGroupMode group_mode) {
  // A source or target cannot appear twice in the collective-permute's
  // source-target pairs. Also, based on the group formation mode, check if the
  // source and target IDs are within expected range.

  // Note: for collective-permute, only kCrossReplica and kCrossPartition modes
  // are valid.
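  //
  // A hypothetical example: with replica_count=4 and kCrossReplica mode,
  // source_target_pairs={{0,1},{1,2},{2,3}} passes these checks, while
  // {{0,1},{0,2}} repeats source 0 and is rejected (unless the four-operand
  // in-place form raises allowed_seen_count below).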
  const HloModuleConfig& config = hlo->GetModule()->config();
  const int64_t limit = group_mode == CollectiveOpGroupMode::kCrossReplica
                            ? config.replica_count()
                            : config.num_partitions();
  absl::flat_hash_map<int64_t, std::vector<int64_t>> seen_source_to_targets;
  absl::flat_hash_map<int64_t, std::vector<int64_t>> seen_target_to_sources;
  int allowed_seen_count = 1;
  if (hlo->operand_count() == 4) {
    if (hlo->operand(0)->shape().IsArray()) {
      allowed_seen_count = hlo->operand(2)->shape().tuple_shapes_size();
    } else {
      allowed_seen_count =
          hlo->operand(2)->shape().tuple_shapes(0).tuple_shapes_size();
    }
  }

  for (const auto& p : hlo->source_target_pairs()) {
    TF_RET_CHECK(p.first >= 0)
        << "Source " << p.first
        << " in the instruction's source-target pair must be >= 0 : "
        << hlo->ToString();
    TF_RET_CHECK(limit == 1 || p.first < limit)
        << "Source " << p.first
        << " in the instruction's source-target pair must be < " << limit
        << " : " << hlo->ToString();
    if (seen_source_to_targets.contains(p.first) &&
        seen_source_to_targets[p.first].size() == allowed_seen_count) {
      if (allowed_seen_count == 1) {
        return InternalError(
            "Source %d appears more than once in instruction's source-target "
            "pairs: %s",
            p.first, hlo->ToString());
      } else {
        return InternalError(
            "Source %d appears more than %d times in instruction's "
            "source-target "
            "pairs: %s",
            p.first, allowed_seen_count, hlo->ToString());
      }
    } else {
      seen_source_to_targets[p.first].push_back(p.second);
    }
    TF_RET_CHECK(p.second >= 0)
        << "Target " << p.second
        << " in the instruction's source-target pair must be >= 0 : "
        << hlo->ToString();
    TF_RET_CHECK(limit == 1 || p.second < limit)
        << "Target " << p.second
        << " in the instruction's source-target pair must be < " << limit
        << " : " << hlo->ToString();
    if (seen_target_to_sources.contains(p.second) &&
        seen_target_to_sources[p.second].size() == allowed_seen_count) {
      if (allowed_seen_count == 1) {
        return InternalError(
            "Target %d appears more than once in instruction's source-target "
            "pairs: %s",
            p.second, hlo->ToString());
      } else {
        return InternalError(
            "Target %d appears more than %d times in instruction's "
            "source-target "
            "pairs: %s",
            p.second, allowed_seen_count, hlo->ToString());
      }
    } else {
      seen_target_to_sources[p.second].push_back(p.first);
    }
  }
  return OkStatus();
}

}  // namespace

Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(
      CollectiveOpGroupMode group_mode,
      GetCollectiveOpGroupMode(hlo->channel_id().has_value(),
                               /*use_global_device_ids=*/std::nullopt));
  TF_RETURN_IF_ERROR(CheckInplaceCollectivePermute(hlo));
  TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo, group_mode));
  std::vector<const Shape*> operand_shapes;
  absl::c_transform(
      hlo->operands(), std::back_inserter(operand_shapes),
      [](const HloInstruction* operand) { return &(operand->shape()); });
  return CheckShape(
      hlo, ShapeInference::InferCollectivePermuteShape(operand_shapes));
}

Status ShapeVerifier::HandleCollectivePermuteStart(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(
      CollectiveOpGroupMode group_mode,
      GetCollectiveOpGroupMode(hlo->channel_id().has_value(),
                               /*use_global_device_ids=*/std::nullopt));
  TF_RETURN_IF_ERROR(CheckInplaceCollectivePermute(hlo));
  TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo, group_mode));
  std::vector<const Shape*> operand_shapes;
  absl::c_transform(
      hlo->operands(), std::back_inserter(operand_shapes),
      [](const HloInstruction* operand) { return &(operand->shape()); });
  return CheckShape(
      hlo, ShapeInference::InferCollectivePermuteStartShape(operand_shapes));
}

Status ShapeVerifier::HandleCollectivePermuteDone(HloInstruction* hlo) {
  return CheckShape(hlo, ShapeInference::InferCollectivePermuteDoneShape(
                             hlo->operand(0)->shape()));
}

Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
  return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
                                          reduce_precision->operand(0)->shape(),
                                          reduce_precision->exponent_bits(),
                                          reduce_precision->mantissa_bits()));
}

Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction,
                                          int64_t operand_no) {
  const HloInstruction* token = instruction->operand(operand_no);
  if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) {
    return InternalError(
        "Expected operand %d to be token-shaped, actual shape is "
        "%s:\n%s",
        operand_no, StringifyShape(token->shape()), instruction->ToString());
  }
  return OkStatus();
}

Status ShapeVerifier::CheckOperandAndParameter(
    const HloInstruction* instruction, int64_t operand_number,
    const HloComputation* computation, int64_t parameter_number) {
  const HloInstruction* operand = instruction->operand(operand_number);
  const HloInstruction* parameter =
      computation->parameter_instruction(parameter_number);
  if (!ShapesSame(operand->shape(), parameter->shape())) {
    return InternalError("Operand %s shape does not match parameter's %s in %s",
                         operand->ToString(), parameter->ToString(),
                         instruction->ToString());
  }
  return OkStatus();
}

Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
  HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
  TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));

  // The output of infeed is a tuple containing the data value and a token.
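  // For example, an infeed whose infeed_shape() is f32[4] must have the
  // overall shape (f32[4], token[]).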
  return CheckShape(infeed,
                    ShapeUtil::MakeTupleShape(
                        {infeed->infeed_shape(), ShapeUtil::MakeTokenShape()}));
}

Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
  HloOutfeedInstruction* outfeed = Cast<HloOutfeedInstruction>(instruction);
  TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1));

  // Outfeed has a separate shape field for the value which is outfed to the
  // host. The shape of the instruction itself is always a token.
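  // For example, when outfeeding an f32[8] operand, outfeed_shape() must be
  // f32[8] while the outfeed instruction itself is token-shaped.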
  if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) {
    return InternalError(
        "Expected outfeed shape to be equal to operand's shape %s, "
        "actual shape is %s:\n%s",
        StringifyShape(outfeed->operand(0)->shape()),
        StringifyShape(outfeed->outfeed_shape()), outfeed->ToString());
  }
  return CheckShape(outfeed, ShapeUtil::MakeTokenShape());
}

bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0,
                                              const Shape& shape_1,
                                              const Shape& result_shape) {
  return ShapeUtil::SameElementType(shape_0, shape_1) &&
         (ShapeUtil::SameElementType(shape_0, result_shape) ||
          (opts_.allow_mixed_precision &&
           ShapeUtil::SameElementTypeIgnoringFpPrecision(shape_0,
                                                         result_shape)));
}

Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
  TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2));

  const Shape& shape_0 = instruction->operand(0)->shape();
  const Shape& shape_1 = instruction->operand(1)->shape();
  if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) {
    return InternalError(
        "Expected scalar types for the two operands of Rng instruction: %s",
        instruction->ToString());
  }

  if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) {
    return InternalError(
        "Expected compatible element types for the result and the two operands"
        " of Rng instruction: %s",
        instruction->ToString());
  }

  PrimitiveType element_type = shape_0.element_type();
  switch (instruction->random_distribution()) {
    case RNG_UNIFORM:
      if (!primitive_util::IsFloatingPointType(element_type) &&
          !primitive_util::IsIntegralType(element_type) &&
          element_type != PRED) {
        return InternalError(
            "Element type not supported."
            " Expected element to be of floating point type, integral type or"
            " predicate type for RngUniform: %s",
            instruction->ToString());
      }
      break;

    case RNG_NORMAL:
      if (!primitive_util::IsFloatingPointType(element_type)) {
        return InternalError(
            "Element type not supported."
            " Expected element to be FloatingPointType for RngNormal: %s",
            instruction->ToString());
      }
      break;
    default:
      return InternalError(
          "Invalid Rng distribution %s",
          RandomDistribution_Name(instruction->random_distribution()));
  }

  return OkStatus();
}

Status ShapeVerifier::HandleRngBitGenerator(HloInstruction* hlo) {
  if (!hlo->shape().IsTuple()) {
    return OkStatus();
  }
  if (hlo->shape().IsTuple() && hlo->shape().tuple_shapes_size() != 2) {
    return InternalError(
        "Expected tuple shape with 2 elements for RngBitGenerator. Got: %s",
        hlo->shape().ToString());
  }
  if (!ShapeUtil::Compatible(hlo->operand(0)->shape(),
                             hlo->shape().tuple_shapes(0))) {
    return InternalError(
        "Expected state shape to match between input and output for "
        "RngBitGenerator. Got %s vs. %s",
        hlo->operand(0)->shape().ToString(),
        hlo->shape().tuple_shapes(0).ToString());
  }
  return OkStatus();
}

Status ShapeVerifier::HandleRngGetAndUpdateState(HloInstruction* instruction) {
  TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0));
  const Shape& result_shape = instruction->shape();
  const Shape expected_shape = ShapeUtil::MakeShape(U64, {2});
  if (!ShapeUtil::Compatible(result_shape, expected_shape)) {
    return InternalError(
        "Invalid RngGetAndUpdateState, expect result to have shape %s, got %s ",
        StringifyShape(expected_shape), StringifyShape(result_shape));
  }

  return OkStatus();
}

Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
  return CheckShape(
      reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(),
                                                 reverse->dimensions()));
}

Status ShapeVerifier::HandleSort(HloInstruction* hlo) {
  HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
  if (sort->operand_count() < 1) {
    return InternalError("Expected at least 1 operand for %s instruction: %s",
                         HloOpcodeString(sort->opcode()), sort->ToString());
  }
  HloComputation* compare = sort->to_apply();

  // Check that the 'compare' computation returns a PRED.
  Shape compare_shape = compare->root_instruction()->shape();
  if (!ShapeUtil::Compatible(compare_shape, ShapeUtil::MakeShape(PRED, {}))) {
    return InternalError(
        "The Sort compare computation shape does not lead to a scalar "
        "predicate shape: %s",
        StringifyShape(compare_shape));
  }

  // Check that the number of parameters of the 'compare' computation is
  // correct.
  TF_RETURN_IF_ERROR(
      CheckParameterCount(sort, compare, sort->operand_count() * 2));

  // Verify that the operands of the compare computation have the correct
  // scalar shapes.
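  // For example (illustrative), a sort of (f32[10], s32[10]) uses a comparator
  // with parameters (f32[], f32[], s32[], s32[]): two scalars per operand, in
  // operand order.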
  for (int64_t parameter_idx = 0; parameter_idx < compare->num_parameters();
       ++parameter_idx) {
    int64_t operand_idx = parameter_idx / 2;
    Shape expected_scalar_shape = ShapeUtil::MakeShape(
        sort->operand(operand_idx)->shape().element_type(), {});
    Shape actual_parameter_shape =
        compare->parameter_instruction(parameter_idx)->shape();
    if (!ShapeUtil::CompatibleIgnoringFpPrecision(expected_scalar_shape,
                                                  actual_parameter_shape)) {
      return InternalError(
          "Expected the %lld-th parameter of the compare computation of sort "
          "to have shape %s, but got %s",
          parameter_idx, StringifyShape(expected_scalar_shape),
          StringifyShape(actual_parameter_shape));
    }
  }

  // Verify that all operand shapes have the same dimensions.
  for (int64_t operand = 1; operand < sort->operand_count(); ++operand) {
    if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(),
                                   sort->operand(operand)->shape())) {
      return InternalError(
          "Expected sort to have the same dimensions for all operands. "
          "First operand shape is: %s\n, shape (operand index %lld) is: %s",
          StringifyShape(sort->operand(0)->shape()), operand,
          StringifyShape(sort->operand(operand)->shape()));
    }
  }

  // Verify the sort_dimension.
  if (sort->sort_dimension() >= sort->operand(0)->shape().rank()) {
    return InternalError(
        "Expected the sort_dimension %d of sort to be smaller than the rank %d "
        "of the operand(s).",
        sort->sort_dimension(), sort->operand(0)->shape().rank());
  }

  return CheckVariadicShape(sort);
}

Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
  if (!Cast<HloConstantInstruction>(constant)->HasLiteral()) {
    return InternalError("Constant is required to have a valid literal: %s",
                         constant->ToString());
  }
  return CheckShape(constant, constant->literal().shape(),
                    /*only_compare_minor_to_major_in_layout=*/true);
}

Status ShapeVerifier::HandleIota(HloInstruction* hlo) {
  auto* iota = Cast<HloIotaInstruction>(hlo);
  if (!iota->shape().IsArray()) {
    return InternalError("Iota does not support non-array result.");
  }
  const int64_t rank = iota->shape().rank();
  if (rank == 0) {
    return InternalError("Iota does not support scalars.");
  }
  int64_t iota_dimension = iota->iota_dimension();
  if (iota_dimension >= rank || iota_dimension < 0) {
    return InternalError(
        "The iota dimension cannot go beyond the operation rank or be "
        "negative.");
  }

  PrimitiveType primitive_type = iota->shape().element_type();
  if (!primitive_util::IsIntegralType(primitive_type) &&
      !primitive_util::IsFloatingPointType(primitive_type) &&
      !primitive_util::IsComplexType(primitive_type)) {
    return InvalidArgument(
        "Only support iota of integral, floating point or complex primitive "
        "types, got %s",
        PrimitiveType_Name(primitive_type));
  }

  return OkStatus();
}

Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  return CheckShape(get_tuple_element,
                    ShapeInference::InferGetTupleElementShape(
                        get_tuple_element->operand(0)->shape(),
                        get_tuple_element->tuple_index()));
}

namespace {
Status SameElementTypesForOperandsAndToApplyParameters(
    const HloInstruction& instruction, int64_t num_operands_to_check) {
  const ProgramShape& to_apply = instruction.to_apply()->ComputeProgramShape();
  for (int i = 0; i < num_operands_to_check; ++i) {
    const Shape& parameter_shape = to_apply.parameters(i);
    const Shape& operand_shape = instruction.operands()[i]->shape();
    if (!ShapeUtil::SameElementType(parameter_shape, operand_shape)) {
      return InvalidArgument(
          "Shape mismatch between to_apply computation"
          " parameter and operand %d in %s.",
          i, instruction.ToString().c_str());
    }
  }
  return OkStatus();
}
}  // namespace

Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
  if (reduce->operand_count() % 2 != 0) {
    return InternalError(
        "Expected an even number of operands for %s instruction: %s",
        HloOpcodeString(reduce->opcode()), reduce->ToString());
  }

  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : reduce->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  TF_RETURN_IF_ERROR(
      CheckShape(reduce, ShapeInference::InferReduceShape(
                             operand_shapes, reduce->dimensions(),
                             reduce->to_apply()->ComputeProgramShape())));

  return opts_.allow_mixed_precision
             ? OkStatus()
             : SameElementTypesForOperandsAndToApplyParameters(
                   *reduce, reduce->operand_count());
}

Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
  const Shape& output_shape = bitcast->shape();
  const Shape& operand_shape = bitcast->operand(0)->shape();
  if (opts_.layout_sensitive &&
      opts_.shape_size(output_shape) != opts_.shape_size(operand_shape)) {
    // Allow bitcast that has the same data size but different trailing
    // paddings.
    if (!opts_.allow_bitcast_to_have_different_size ||
        !(output_shape.is_static() && operand_shape.is_static() &&
          (ShapeUtil::ArrayDataSize(output_shape) ==
           ShapeUtil::ArrayDataSize(operand_shape)))) {
      return InternalError(
          "Bitcast cannot have different shape sizes of output (%d) and "
          "operand "
          "(%d) (%s) (%s)",
          opts_.shape_size(output_shape), opts_.shape_size(operand_shape),
          output_shape.ToString(true), operand_shape.ToString(true));
    }
  }
  return OkStatus();
}

Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
  // HLO broadcast has no exact analog at the client level so there is no
  // ShapeInference method. Check the output shape explicitly.
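  // For example (illustrative), broadcasting an f32[3] operand into an
  // f32[2,3] result with dimensions={1} maps operand dimension 0 to output
  // dimension 1, which must have the same size (3).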
1069 const Shape& operand_shape = broadcast->operand(0)->shape();
1070 // Check for mixed precision.
1071 TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape));
1072 TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size());
1073 for (int64_t operand_dimension = 0; operand_dimension < operand_shape.rank();
1074 ++operand_dimension) {
1075 int64_t output_dimension = broadcast->dimensions()[operand_dimension];
1076 TF_RET_CHECK((output_dimension < broadcast->shape().rank()) &&
1077 output_dimension >= 0 &&
1078 (broadcast->shape().dimensions(output_dimension) ==
1079 operand_shape.dimensions(operand_dimension)))
1080 << broadcast->ToString() << " operand shape " << operand_shape;
1081 }
1082 return OkStatus();
1083 }
1084
HandleDynamicReshape(HloInstruction * dynamic_reshape)1085 Status ShapeVerifier::HandleDynamicReshape(HloInstruction* dynamic_reshape) {
1086 // Check for mixed precision.
1087 const Shape& operand_shape = dynamic_reshape->operand(0)->shape();
1088 TF_RET_CHECK(SameElementType(dynamic_reshape->shape(), operand_shape));
1089 TF_RET_CHECK(ShapeUtil::ElementsIn(dynamic_reshape->shape()) ==
1090 ShapeUtil::ElementsIn(operand_shape));
1091 TF_RET_CHECK(dynamic_reshape->shape().rank() + 1 ==
1092 dynamic_reshape->operand_count());
1093 for (int64_t i = 1; i < dynamic_reshape->operand_count(); ++i) {
1094 TF_RET_CHECK(dynamic_reshape->operand(i)->shape().element_type() == S32);
1095 }
1096 return OkStatus();
1097 }
1098
HandleReshape(HloInstruction * reshape)1099 Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
1100 // Check for mixed precision.
1101 const Shape& operand_shape = reshape->operand(0)->shape();
1102 TF_RET_CHECK(SameElementType(reshape->shape(), operand_shape));
1103 TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
1104 ShapeUtil::ElementsIn(operand_shape));
1105 return OkStatus();
1106 }
1107
HandleTranspose(HloInstruction * transpose)1108 Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
1109 return CheckShape(
1110 transpose, ShapeInference::InferTransposeShape(
1111 transpose->operand(0)->shape(), transpose->dimensions()));
1112 }
1113
HandleParameter(HloInstruction * hlo)1114 Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
1115 return OkStatus();
1116 }
1117
HandleFusion(HloInstruction * fusion)1118 Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
1119 if (fusion->called_computations().size() != 1) {
1120 return InternalError(
1121 "Fusion has a non-unary number of called computations (%s)",
1122 fusion->ToString().c_str());
1123 }
1124 const Shape& root_computation_shape =
1125 fusion->called_computations()[0]->root_instruction()->shape();
1126 if (!ShapesSame(fusion->shape(), root_computation_shape)) {
1127 return InternalError(
1128 "Fused computation shape (%s) is not equal to the fusion shape (%s)",
1129 root_computation_shape.ToString(true), fusion->shape().ToString(true));
1130 }
1131
1132 auto& fused_parameters = fusion->fused_parameters();
1133 if (fused_parameters.size() != fusion->operand_count()) {
1134 return InternalError(
1135 "Fused parameter count (%d) does not match the number of operands (%d)"
1136 " passed to the fusion instruction in: %s.",
1137 fused_parameters.size(), fusion->operand_count(),
1138 fusion->ToString().c_str());
1139 }
1140 for (HloInstruction* fused_param : fused_parameters) {
1141 int64_t param_no = fused_param->parameter_number();
1142 if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) {
1143 return InternalError(
1144 "Shape mismatch between parameter number %d and its operand in "
1145 "%s.",
1146 param_no, fusion->ToString().c_str());
1147 }
1148 }
1149 return OkStatus();
1150 }
1151
HandleCall(HloInstruction * call)1152 Status ShapeVerifier::HandleCall(HloInstruction* call) {
1153 TF_RETURN_IF_ERROR(
1154 CheckParameterCount(call, call->to_apply(), call->operand_count()));
1155 for (int64_t i = 0; i < call->to_apply()->num_parameters(); ++i) {
1156 TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
1157 }
1158 // The shape of kCall should match the shape of the computation it calls.
1159 return CheckShape(call, call->to_apply()->root_instruction()->shape());
1160 }
1161
HandleCustomCall(HloInstruction * instruction)1162 Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) {
1163 const HloCustomCallInstruction* custom_call =
1164 DynCast<const HloCustomCallInstruction>(instruction);
1165 TF_RET_CHECK(custom_call != nullptr);
1166 if (custom_call->layout_constrained()) {
1167 // If the layout is constrained, verify all the respective shapes have
1168 // layouts and that the constrained operand shapes match the shapes of the
1169 // operands.
1170 TF_RET_CHECK(LayoutUtil::HasLayout(custom_call->shape()));
1171 TF_RET_CHECK(custom_call->operand_count() ==
1172 custom_call->operand_shapes_with_layout().size());
1173 for (int64_t i = 0; i < custom_call->operand_count(); ++i) {
1174 const Shape& operand_shape_with_layout =
1175 custom_call->operand_shapes_with_layout()[i];
1176 TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(),
1177 operand_shape_with_layout))
1178 << custom_call->operand(i)->shape().ToString() << " operand "
1179 << operand_shape_with_layout.ToString();
1180 TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
1181 }
1182 }
1183 for (const auto& pair : custom_call->output_to_operand_aliasing()) {
1184 TF_RET_CHECK(pair.second.first < custom_call->operand_count())
1185 << "Invalid aliasing operand index.";
1186 TF_RET_CHECK(ShapeUtil::IndexIsValid(
1187 custom_call->operand(pair.second.first)->shape(), pair.second.second))
1188 << "Invalid aliasing operand shape index.";
1189 TF_RET_CHECK(ShapeUtil::IndexIsValid(custom_call->shape(), pair.first))
1190 << "Invalid aliasing output shape index.";
1191 const Shape& output_subshape =
1192 ShapeUtil::GetSubshape(custom_call->shape(), pair.first);
1193 const Shape& operand_subshape = ShapeUtil::GetSubshape(
1194 custom_call->operand(pair.second.first)->shape(), pair.second.second);
1195 if (opts_.layout_sensitive) {
1196 TF_RET_CHECK(operand_subshape == output_subshape)
1197 << "Different aliasing shapes: " << operand_subshape.ToString()
1198 << " vs " << output_subshape.ToString();
1199 } else {
1200 TF_RET_CHECK(ShapeUtil::Compatible(output_subshape, operand_subshape))
1201 << "Different aliasing shapes: " << operand_subshape.ToString()
1202 << " vs " << output_subshape.ToString();
1203 }
1204 }
1205 return OkStatus();
1206 }
1207
HandleSlice(HloInstruction * slice)1208 Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
1209 return CheckShape(slice,
1210 ShapeInference::InferSliceShape(
1211 slice->operand(0)->shape(), slice->slice_starts(),
1212 slice->slice_limits(), slice->slice_strides()));
1213 }
1214
HandleDynamicSlice(HloInstruction * dynamic_slice)1215 Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) {
1216 return CheckShape(
1217 dynamic_slice,
1218 ShapeInference::InferDynamicSliceShape(
1219 dynamic_slice->operand(0)->shape(),
1220 Cast<HloDynamicSliceInstruction>(dynamic_slice)->index_shapes(),
1221 dynamic_slice->dynamic_slice_sizes()));
1222 }
1223
HandleDynamicUpdateSlice(HloInstruction * dynamic_update_slice)1224 Status ShapeVerifier::HandleDynamicUpdateSlice(
1225 HloInstruction* dynamic_update_slice) {
1226 return CheckShape(
1227 dynamic_update_slice,
1228 ShapeInference::InferDynamicUpdateSliceShape(
1229 dynamic_update_slice->operand(0)->shape(),
1230 dynamic_update_slice->operand(1)->shape(),
1231 Cast<HloDynamicUpdateSliceInstruction>(dynamic_update_slice)
1232 ->index_shapes()));
1233 }
1234
HandleTuple(HloInstruction * tuple)1235 Status ShapeVerifier::HandleTuple(HloInstruction* tuple) {
1236 return CheckVariadicShape(tuple);
1237 }
1238
HandleMap(HloInstruction * map)1239 Status ShapeVerifier::HandleMap(HloInstruction* map) {
1240 std::vector<const Shape*> operand_shapes;
1241 int64_t max_operand_rank = 0;
1242 for (const HloInstruction* operand : map->operands()) {
1243 operand_shapes.push_back(&operand->shape());
1244 max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
1245 }
1246 // TODO(b/65689298) Remove code below once Map is generalized to accept
1247 // arbitrary map dimensions.
1248 std::vector<int64_t> map_dims(max_operand_rank);
1249 std::iota(map_dims.begin(), map_dims.end(), 0);
1250
1251 TF_RETURN_IF_ERROR(CheckShape(
1252 map,
1253 ShapeInference::InferMapShape(
1254 operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims)));
1255
1256 return opts_.allow_mixed_precision
1257 ? OkStatus()
1258 : SameElementTypesForOperandsAndToApplyParameters(
1259 *map, map->operand_count());
1260 }
1261
HandleReduceWindow(HloInstruction * reduce_window)1262 Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) {
1263 VLOG(2) << "Verify reduce window:" << reduce_window->ToString() << "\n";
1264 auto reduce_window_instr = Cast<HloReduceWindowInstruction>(reduce_window);
1265 auto input_shapes = reduce_window_instr->input_shapes();
1266 VLOG(2) << "reduce window input shape count: " << input_shapes.size() << "\n";
1267 auto init_shapes = reduce_window_instr->init_value_shapes();
1268 VLOG(2) << "reduce instruction is :" << reduce_window->ToString() << "\n";
1269 TF_RETURN_IF_ERROR(CheckShape(
1270 reduce_window, ShapeInference::InferReduceWindowShape(
1271 input_shapes, init_shapes, reduce_window->window(),
1272 reduce_window->to_apply()->ComputeProgramShape())));
1273
1274 return opts_.allow_mixed_precision
1275 ? OkStatus()
1276 : SameElementTypesForOperandsAndToApplyParameters(
1277 *reduce_window, reduce_window->operand_count());
1278 }
1279
HandleSelectAndScatter(HloInstruction * instruction)1280 Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
1281 return CheckShape(
1282 instruction,
1283 ShapeInference::InferSelectAndScatterShape(
1284 instruction->operand(0)->shape(),
1285 instruction->select()->ComputeProgramShape(), instruction->window(),
1286 instruction->operand(1)->shape(), instruction->operand(2)->shape(),
1287 instruction->scatter()->ComputeProgramShape()));
1288 }
1289
HandleWhile(HloInstruction * xla_while)1290 Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
1291 TF_RETURN_IF_ERROR(
1292 CheckParameterCount(xla_while, xla_while->while_body(), 1));
1293 TF_RETURN_IF_ERROR(
1294 CheckParameterCount(xla_while, xla_while->while_condition(), 1));
1295 TF_RETURN_IF_ERROR(
1296 CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
1297 TF_RETURN_IF_ERROR(
1298 CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
1299 const Shape& conditional_shape =
1300 xla_while->while_condition()->root_instruction()->shape();
1301 if (!ShapeUtil::Compatible(conditional_shape,
1302 ShapeUtil::MakeShape(PRED, {}))) {
1303 return InternalError(
1304 "Conditional computation shape does not lead to a scalar predicate "
1305 "shape: %s",
1306 StringifyShape(conditional_shape));
1307 }
1308 // The shape of kWhile should match the shape of the body computation it
1309 // calls.
1310 return CheckShape(xla_while,
1311 xla_while->while_body()->root_instruction()->shape());
1312 }
1313
HandleConditional(HloInstruction * conditional)1314 Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
1315 if (!ShapeUtil::IsScalar(conditional->operand(0)->shape())) {
1316 return InvalidArgument(
1317 "The first operand of conditional must be a scalar. Got %s",
1318 conditional->operand(0)->shape().DebugString());
1319 }
1320 const int num_branches = conditional->branch_count();
1321 PrimitiveType operand0_type = conditional->operand(0)->shape().element_type();
1322 if (operand0_type == PRED) {
1323 TF_RET_CHECK(num_branches == 2);
1324 } else {
1325 if (operand0_type != S32) {
1326 return InvalidArgument(
1327 "The first operand of indexed conditional must be a scalar of S32. "
1328 "Got"
1329 " type %s.",
1330 PrimitiveType_Name(operand0_type));
1331 }
1332 TF_RET_CHECK(num_branches >= 1);
1333 }
1334 TF_RETURN_IF_ERROR(CheckOperandCount(conditional, num_branches + 1));
1335 for (int j = 0; j < num_branches; ++j) {
1336 TF_RETURN_IF_ERROR(CheckParameterCount(
1337 conditional, conditional->branch_computation(j), 1));
1338 TF_RETURN_IF_ERROR(CheckOperandAndParameter(
1339 conditional, j + 1, conditional->branch_computation(j), 0));
1340 TF_RETURN_IF_ERROR(CheckShape(
1341 conditional,
1342 conditional->branch_computation(j)->root_instruction()->shape()));
1343 }
1344 return OkStatus();
1345 }
1346
1347 Status ShapeVerifier::HandlePad(HloInstruction* pad) {
1348 return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(),
1349 pad->operand(1)->shape(),
1350 pad->padding_config()));
1351 }
1352
1353 namespace {
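// Checks that the operand of an async-update/async-done is the async-start or
// async-update of the same chain: it must wrap the identical computation (and
// execution thread) and carry the same async group id.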
1354 Status CheckAsyncOpOperand(const HloInstruction* async_op) {
1355 const HloInstruction* operand = async_op->operand(0);
1356 if (operand->opcode() != HloOpcode::kAsyncStart &&
1357 operand->opcode() != HloOpcode::kAsyncUpdate) {
1358 return InternalError(
1359 "%s expects its operand to be async-start or async-update, found "
1360 "%s.",
1361 HloOpcodeString(async_op->opcode()),
1362 HloOpcodeString(operand->opcode()));
1363 }
1364 if (*async_op->async_wrapped_computation() !=
1365 *operand->async_wrapped_computation()) {
1366 return InternalError(
1367 "The %s expects its wrapped async computation to be identical to its "
1368 "operand's wrapped async computation (%s vs %s), thread name (%s vs "
1369 "%s).",
1370 HloOpcodeString(async_op->opcode()),
1371 async_op->async_wrapped_instruction()->ToString(),
1372 operand->async_wrapped_instruction()->ToString(),
1373 async_op->async_wrapped_computation()->execution_thread(),
1374 operand->async_wrapped_computation()->execution_thread());
1375 }
1376 if (async_op->async_group_id() != operand->async_group_id()) {
1377 return InternalError(
1378 "%s expects its operand to have the same group id (%s vs %s).",
1379 HloOpcodeString(async_op->opcode()),
1380 async_op->async_group_id() ? absl::StrCat(*async_op->async_group_id())
1381 : "none",
1382 operand->async_group_id() ? absl::StrCat(*operand->async_group_id())
1383 : "none");
1384 }
1385 return OkStatus();
1386 }
1387
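// Checks the tuple structure of an async value: element {0} is the tuple of
// the wrapped computation's parameters and element {1} is its result. As an
// illustrative example, wrapping f32[16] negate(f32[16]) would give a shape
// roughly of the form ((f32[16]), f32[16], ...), where any trailing tuple
// elements may hold implementation-specific context.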
1388 Status CheckAsyncOpComputationShapes(const HloInstruction* async_op,
1389 const Shape& async_shape) {
1390 if (!async_shape.IsTuple() || async_shape.tuple_shapes_size() < 2) {
1391 return InternalError(
1392 "The %s expects the async shape to be a tuple of at least two "
1393 "elements, found %s.",
1394 HloOpcodeString(async_op->opcode()), async_shape.ToString());
1395 }
1396 ProgramShape computation_shape =
1397 async_op->async_wrapped_computation()->ComputeProgramShape();
1398 Shape param_shape = ShapeUtil::MakeTupleShape(computation_shape.parameters());
1399 if (async_shape.tuple_shapes(0) != param_shape) {
1400 return InternalError(
1401 "The %s expects the async shape at index {0} to match async "
1402 "computation parameter shape (%s vs %s).",
1403 HloOpcodeString(async_op->opcode()),
1404 async_shape.tuple_shapes(0).ToString(/*print_layout=*/true),
1405 param_shape.ToString(/*print_layout=*/true));
1406 }
1407 if (async_shape.tuple_shapes(1) != computation_shape.result()) {
1408 return InternalError(
1409 "The %s expects the async shape at index {1} to match the async "
1410 "computation root shape (%s vs %s).",
1411 HloOpcodeString(async_op->opcode()),
1412 async_shape.tuple_shapes(1).ToString(/*print_layout=*/true),
1413 computation_shape.result().ToString(/*print_layout=*/true));
1414 }
1415 return OkStatus();
1416 }
1417
1418 Status CheckAsyncOpComputationThreadName(const HloInstruction* async_op) {
1419 std::optional<absl::string_view> async_execution_thread =
1420 async_op->async_execution_thread();
1421 if (async_execution_thread !=
1422 async_op->async_wrapped_computation()->execution_thread()) {
1423 return InternalError(
1424 "async-start expects same async thread name as wrapped computation's "
1425 "thread name (%s vs %s).",
1426 async_execution_thread ? absl::StrCat(*async_execution_thread) : "none",
1427 async_op->async_wrapped_computation()->execution_thread());
1428 }
1429 return CheckNestedComputationThreadNameEqual(
1430 async_op->async_wrapped_computation(),
1431 /*skip_nested_async_op_check=*/false);
1432 }
1433
1434 Status CheckCallableInstructionThreadName(const HloInstruction* instruction,
1435 bool skip_nested_async_op_check) {
1436 for (const HloComputation* computation : instruction->called_computations()) {
1437 if (instruction->parent() != nullptr) {
1438 if (instruction->parent()->execution_thread() !=
1439 computation->execution_thread()) {
1440 return InternalError(
1441 "callable instruction %s expects parent computation thread name "
1442 "same as called computation's thread name (%s vs %s).",
1443 instruction->ToString(), instruction->parent()->execution_thread(),
1444 computation->execution_thread());
1445 }
1446 }
1447 TF_RETURN_IF_ERROR(CheckNestedComputationThreadNameEqual(
1448 computation, skip_nested_async_op_check));
1449 }
1450 return OkStatus();
1451 }
1452 } // namespace
1453
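// In addition to the generic async shape checks, each operand i of async-start
// must match element i of the parameter tuple at shape index {0}.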
1454 Status ShapeVerifier::HandleAsyncStart(HloInstruction* async_start) {
1455 TF_RETURN_IF_ERROR(
1456 CheckAsyncOpComputationShapes(async_start, async_start->shape()));
1457 TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_start));
1458 const Shape& param_shape = async_start->shape().tuple_shapes(0);
1459 for (int i = 0; i < async_start->operand_count(); ++i) {
1460 if (param_shape.tuple_shapes(i) != async_start->operand(i)->shape()) {
1461 return InternalError(
1462 "The %s expects the shape of operand %d to match the async shape at "
1463 "index {0} (%s vs %s).",
1464 HloOpcodeString(async_start->opcode()), i,
1465 async_start->operand(i)->shape().ToString(/*print_layout=*/true),
1466 param_shape.tuple_shapes(i).ToString(/*print_layout=*/true));
1467 }
1468 }
1469 return OkStatus();
1470 }
1471
1472 Status ShapeVerifier::HandleAsyncUpdate(HloInstruction* async_update) {
1473 TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_update));
1474 if (async_update->operand(0)->shape() != async_update->shape()) {
1475 return InternalError(
1476 "The %s expects the shape of operand and output to match (%s vs %s).",
1477 HloOpcodeString(async_update->opcode()),
1478 async_update->operand(0)->shape().ToString(),
1479 async_update->shape().ToString());
1480 }
1481 TF_RETURN_IF_ERROR(
1482 CheckAsyncOpComputationShapes(async_update, async_update->shape()));
1483 return CheckAsyncOpOperand(async_update);
1484 }
1485
1486 Status ShapeVerifier::HandleAsyncDone(HloInstruction* async_done) {
1487 TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_done));
1488 TF_RETURN_IF_ERROR(CheckAsyncOpComputationShapes(
1489 async_done, async_done->operand(0)->shape()));
1490 const Shape& root_shape = async_done->operand(0)->shape().tuple_shapes(1);
1491 if (root_shape != async_done->shape()) {
1492 return InternalError(
1493 "The %s expects the shape of output to match the async shape at index "
1494 "{1} (%s vs %s).",
1495 HloOpcodeString(async_done->opcode()), async_done->shape().ToString(),
1496 root_shape.ToString());
1497 }
1498 return CheckAsyncOpOperand(async_done);
1499 }
1500
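// copy-start produces (destination buffer, source buffer, u32[] context),
// where both buffer shapes equal the operand's shape. Illustrative example:
// copying an f32[2,3] value yields the tuple (f32[2,3], f32[2,3], u32[]).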
1501 Status ShapeVerifier::HandleCopyStart(HloInstruction* copy_start) {
1502 return CheckShape(copy_start,
1503 ShapeUtil::MakeTupleShape({copy_start->operand(0)->shape(),
1504 copy_start->operand(0)->shape(),
1505 ShapeUtil::MakeShape(U32, {})}),
1506 /*only_compare_minor_to_major_in_layout=*/true);
1507 }
1508
1509 Status ShapeVerifier::HandleCopyDone(HloInstruction* copy_done) {
1510 const Shape& operand_shape = copy_done->operand(0)->shape();
1511 const Shape& dest_shape = ShapeUtil::GetTupleElementShape(operand_shape, 0);
1512 const Shape& src_shape = ShapeUtil::GetTupleElementShape(operand_shape, 1);
1513 if (!ShapesSame(dest_shape, src_shape,
1514 /*minor_to_major_only=*/false,
1515 /*ignore_memory_space=*/true)) {
1516 return InternalError(
1517 "Source and destination buffers in CopyDone arguments need to be the "
1518 "same shape; found %s and %s\n%s",
1519 StringifyShape(dest_shape), StringifyShape(src_shape),
1520 copy_done->ToString());
1521 }
1522 return CheckShape(copy_done, ShapeUtil::GetTupleElementShape(
1523 copy_done->operand(0)->shape(), 0));
1524 }
1525
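// send produces (data, u32[] context, token) and recv produces
// (received data, u32[] context, token); e.g. sending an f32[8] value is
// expected to yield the tuple (f32[8], u32[], token) (illustrative shapes).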
1526 Status ShapeVerifier::HandleSend(HloInstruction* send) {
1527 return CheckShape(send,
1528 ShapeUtil::MakeTupleShape({send->operand(0)->shape(),
1529 ShapeUtil::MakeShape(U32, {}),
1530 ShapeUtil::MakeTokenShape()}),
1531 /*only_compare_minor_to_major_in_layout=*/true);
1532 }
1533
1534 Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
1535 return CheckShape(send_done, ShapeUtil::MakeTokenShape());
1536 }
1537
1538 Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
1539 return CheckShape(
1540 recv,
1541 ShapeUtil::MakeTupleShape(
1542 {ShapeUtil::GetTupleElementShape(recv->shape(), 0),
1543 ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}),
1544 /*only_compare_minor_to_major_in_layout=*/true);
1545 }
1546
1547 Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
1548 return CheckShape(
1549 recv_done,
1550 ShapeUtil::MakeTupleShape(
1551 {ShapeUtil::GetTupleElementShape(recv_done->operand(0)->shape(), 0),
1552 ShapeUtil::MakeTokenShape()}));
1553 }
1554
1555 Status ShapeVerifier::HandleBatchNormTraining(
1556 HloInstruction* batch_norm_training) {
1557 return CheckShape(batch_norm_training,
1558 ShapeInference::InferBatchNormTrainingShape(
1559 batch_norm_training->operand(0)->shape(),
1560 batch_norm_training->operand(1)->shape(),
1561 batch_norm_training->operand(2)->shape(),
1562 batch_norm_training->feature_index()));
1563 }
1564
1565 Status ShapeVerifier::HandleBatchNormInference(
1566 HloInstruction* batch_norm_inference) {
1567 return CheckShape(batch_norm_inference,
1568 ShapeInference::InferBatchNormInferenceShape(
1569 batch_norm_inference->operand(0)->shape(),
1570 batch_norm_inference->operand(1)->shape(),
1571 batch_norm_inference->operand(2)->shape(),
1572 batch_norm_inference->operand(3)->shape(),
1573 batch_norm_inference->operand(4)->shape(),
1574 batch_norm_inference->feature_index()));
1575 }
1576
1577 Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
1578 return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape(
1579 batch_norm_grad->operand(0)->shape(),
1580 batch_norm_grad->operand(1)->shape(),
1581 batch_norm_grad->operand(2)->shape(),
1582 batch_norm_grad->operand(3)->shape(),
1583 batch_norm_grad->operand(4)->shape(),
1584 batch_norm_grad->feature_index()));
1585 }
1586
1587 namespace {
1588
1589 // Checks that the instruction does not have mixed precision floating point
1590 // inputs.
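// As an illustrative example, add(f32[4], bf16[4]) is rejected here, while
// tuple(f32[4], bf16[4]) is allowed because kTuple merely groups buffers.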
1591 Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
1592 switch (instruction->opcode()) {
1593 // The following opcodes are exempt from the mixed-precision check because
1594 // they only pass data through or group buffers via tuples, so the
1595 // precisions of their operand buffers can legitimately differ.
1596 case HloOpcode::kCall:
1597 case HloOpcode::kConditional:
1598 case HloOpcode::kConstant:
1599 case HloOpcode::kConvolution:
1600 case HloOpcode::kDot:
1601 case HloOpcode::kAllReduce:
1602 case HloOpcode::kAllReduceStart:
1603 case HloOpcode::kAllReduceDone:
1604 case HloOpcode::kAsyncDone:
1605 case HloOpcode::kAsyncUpdate:
1606 case HloOpcode::kAsyncStart:
1607 case HloOpcode::kCopyDone:
1608 case HloOpcode::kCopyStart:
1609 case HloOpcode::kCustomCall:
1610 case HloOpcode::kDomain:
1611 case HloOpcode::kFusion:
1612 case HloOpcode::kGetTupleElement:
1613 case HloOpcode::kOptimizationBarrier:
1614 case HloOpcode::kInfeed:
1615 case HloOpcode::kOutfeed:
1616 case HloOpcode::kParameter:
1617 case HloOpcode::kRecv:
1618 case HloOpcode::kRecvDone:
1619 case HloOpcode::kReducePrecision:
1620 case HloOpcode::kReduceWindow:
1621 case HloOpcode::kSend:
1622 case HloOpcode::kSendDone:
1623 case HloOpcode::kSort:
1624 case HloOpcode::kTuple:
1625 case HloOpcode::kWhile:
1626 break;
1627 default: {
1628 PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID;
1629 for (auto operand : instruction->operands()) {
1630 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
1631 operand->shape(),
1632 [&](const Shape& subshape, const ShapeIndex& index) {
1633 if (!ShapeUtil::ElementIsFloating(subshape)) {
1634 return OkStatus();
1635 }
1636 if (fp_type == PRIMITIVE_TYPE_INVALID) {
1637 fp_type = subshape.element_type();
1638 } else if (fp_type != subshape.element_type()) {
1639 return InternalError(
1640 "Seen floating point types of different precisions in "
1641 "%s, but mixed precision is disallowed.",
1642 instruction->ToString());
1643 }
1644 return OkStatus();
1645 }));
1646 }
1647 }
1648 }
1649 return OkStatus();
1650 }
1651
1652 } // namespace
1653
1654 Status ShapeVerifier::HandleGather(HloInstruction* gather) {
1655 return CheckShape(
1656 gather,
1657 ShapeInference::InferGatherShape(
1658 gather->operand(0)->shape(), gather->operand(1)->shape(),
1659 gather->gather_dimension_numbers(), gather->gather_slice_sizes()));
1660 }
1661
1662 Status ShapeVerifier::HandleScatter(HloInstruction* scatter) {
1663 absl::InlinedVector<const Shape*, 3> arg_shapes;
1664 arg_shapes.reserve(scatter->operand_count());
1665 for (const HloInstruction* operand : scatter->operands()) {
1666 arg_shapes.push_back(&operand->shape());
1667 }
1668 return CheckShape(scatter,
1669 ShapeInference::InferScatterShape(
1670 arg_shapes, scatter->to_apply()->ComputeProgramShape(),
1671 scatter->scatter_dimension_numbers()));
1672 }
1673
1674 Status ShapeVerifier::HandleAfterAll(HloInstruction* token) {
1675 std::vector<const Shape*> operand_shapes;
1676 for (const HloInstruction* operand : token->operands()) {
1677 operand_shapes.push_back(&operand->shape());
1678 }
1679 return CheckShape(token, ShapeUtil::MakeTokenShape());
1680 }
1681
1682 Status ShapeVerifier::HandleAddDependency(HloInstruction* add_dependency) {
1683 TF_RETURN_IF_ERROR(CheckIsTokenOperand(add_dependency, 1));
1684 return CheckShape(add_dependency, add_dependency->operand(0)->shape());
1685 }
1686
1687 Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) {
1688 return CheckShape(get_size,
1689 ShapeInference::InferGetDimensionSizeShape(
1690 get_size->operand(0)->shape(), get_size->dimension()));
1691 }
1692
1693 Status ShapeVerifier::HandleSetDimensionSize(HloInstruction* set_size) {
1694 return CheckShape(set_size,
1695 ShapeInference::InferSetDimensionSizeShape(
1696 set_size->operand(0)->shape(),
1697 set_size->operand(1)->shape(), set_size->dimension()));
1698 }
1699
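// Checks that `instruction`'s shape matches `inferred_shape`. For most opcodes
// layouts are ignored and, when mixed precision is allowed, BF16/F32
// differences are tolerated; the opcodes listed below must match exactly
// (modulo the minor-to-major-only option). For example, an add whose actual
// shape is bf16[4] checked against an inferred f32[4] passes only when
// opts_.allow_mixed_precision is set (illustrative shapes).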
1700 Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
1701 const Shape& inferred_shape,
1702 bool only_compare_minor_to_major_in_layout) {
1703 // If opts_.allow_mixed_precision is false, check whether any operands have
1704 // different floating-point precisions. We need this check because
1705 // ShapeInference itself allows mixed-precision inputs.
1706 if (!opts_.allow_mixed_precision) {
1707 TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction));
1708 }
1709
1710 // Check if the output shape matches the expected shape.
1711 //
1712 // We treat BF16 and F32 as compatible types if mixed precision is allowed,
1713 // but only when the instruction defines the BF16/F32 buffer.
1714 bool equal = [&] {
1715 switch (instruction->opcode()) {
1716 // The opcodes below can't have implicit layout conversions, nor can they
1717 // implicitly transform f32 -> bf16. Fundamentally these are either
1718 // reinterpreting existing data (e.g. kBitcast) or shuffling data around
1719 // without modifying it (e.g. kGetTupleElement).
1720 case HloOpcode::kBitcast:
1721 case HloOpcode::kCall:
1722 case HloOpcode::kConditional:
1723 case HloOpcode::kConstant:
1724 case HloOpcode::kCopyDone:
1725 case HloOpcode::kCopyStart:
1726 case HloOpcode::kCustomCall:
1727 case HloOpcode::kGetTupleElement:
1728 case HloOpcode::kInfeed:
1729 case HloOpcode::kOutfeed:
1730 case HloOpcode::kOptimizationBarrier:
1731 case HloOpcode::kParameter:
1732 case HloOpcode::kRecv:
1733 case HloOpcode::kRecvDone:
1734 case HloOpcode::kSend:
1735 case HloOpcode::kSendDone:
1736 case HloOpcode::kTuple:
1737 case HloOpcode::kWhile:
1738 return ShapesSame(instruction->shape(), inferred_shape,
1739 only_compare_minor_to_major_in_layout);
1740 case HloOpcode::kDynamicUpdateSlice:
1741 // DynamicUpdateSlice has "in-place" update semantics, but inside fusions
1742 // memory-space propagation doesn't propagate memory spaces all the way,
1743 // which can cause mismatches. Relax the constraint in that case.
1745 return ShapesSame(instruction->shape(), inferred_shape,
1746 only_compare_minor_to_major_in_layout,
1747 /*ignore_memory_space=*/
1748 instruction->parent()->IsFusionComputation());
1749
1750 // We allow arbitrary layout and f32->bf16 transformations on all other
1751 // instructions, although this may be made more strict pending discussion
1752 // in b/112709536.
1753 default:
1754 if (opts_.allow_mixed_precision) {
1755 return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(),
1756 inferred_shape);
1757 } else {
1758 return ShapeUtil::Compatible(instruction->shape(), inferred_shape);
1759 }
1760 }
1761 }();
1762 if (!equal) {
1763 return InternalError(
1764 "Expected instruction to have shape equal to %s, actual "
1765 "shape is %s:\n%s",
1766 StringifyShape(inferred_shape), StringifyShape(instruction->shape()),
1767 instruction->ToString());
1768 }
1769 return OkStatus();
1770 }
1771
1772 Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
1773 const StatusOr<Shape>& inferred_shape_status) {
1774 if (!inferred_shape_status.ok()) {
1775 Status s = inferred_shape_status.status();
1776 tensorflow::errors::AppendToMessage(&s, ", for instruction ",
1777 instruction->ToString());
1778 return s;
1779 }
1780 return CheckShape(instruction, inferred_shape_status.ValueOrDie());
1781 }
1782
1783 Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) {
1784 return CheckShape(instruction,
1785 ShapeInference::InferUnaryOpShape(instruction->opcode(),
1786 instruction->operand(0)));
1787 }
1788
1789 Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) {
1790 return CheckShape(
1791 instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(),
1792 instruction->operand(0),
1793 instruction->operand(1)));
1794 }
1795
1796 Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) {
1797 return CheckShape(instruction,
1798 ShapeInference::InferTernaryOpShape(
1799 instruction->opcode(), instruction->operand(0),
1800 instruction->operand(1), instruction->operand(2)));
1801 }
1802
1803 Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
1804 return CheckShape(instruction,
1805 ShapeInference::InferVariadicOpShape(
1806 instruction->opcode(), instruction->operands()));
1807 }
1808
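// Checks that the entry computation's parameter count and parameter/result
// shapes are compatible with the shapes recorded in the module's entry
// computation layout; layouts themselves are only validated here, not
// required to match exactly.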
1809 Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) {
1810 const HloComputation* computation = module.entry_computation();
1811 const auto& layout = module.entry_computation_layout();
1812 const ShapeLayout& result_layout = layout.result_layout();
1813
1814 TF_RETURN_IF_ERROR(
1815 ShapeUtil::ValidateShapeWithOptionalLayout(result_layout.shape()));
1816
1817 if (!ShapeUtil::Compatible(computation->root_instruction()->shape(),
1818 result_layout.shape())) {
1819 return InternalError(
1820 "Shape of the root instruction of entry computation (%s) should be "
1821 "compatible with the one specified in the module's entry computation layout (%s)",
1822 ShapeUtil::HumanString(computation->root_instruction()->shape()),
1823 ShapeUtil::HumanString(result_layout.shape()));
1824 }
1825
1826 if (computation->num_parameters() != layout.parameter_count()) {
1827 return InternalError(
1828 "Number of parameters in entry computation layout (%d) must be the same "
1829 "as the number of parameters of the entry computation (%d)",
1830 layout.parameter_count(), computation->num_parameters());
1831 }
1832
1833 for (int i = 0; i < computation->num_parameters(); ++i) {
1834 const HloInstruction* parameter = computation->parameter_instruction(i);
1835 TF_RETURN_IF_ERROR(
1836 ShapeUtil::ValidateShapeWithOptionalLayout(layout.parameter_shape(i)));
1837 if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) {
1838 return InternalError(
1839 "Shape of the entry computation parameter %d is %s, but it should be "
1840 "compatible with the one specified in the module's entry computation "
1841 "layout (%s)",
1842 i, ShapeUtil::HumanString(parameter->shape()),
1843 ShapeUtil::HumanString(layout.parameter_shape(i)));
1844 }
1845 }
1846
1847 return OkStatus();
1848 }
1849
1850 std::string ComputationsToString(
1851 absl::Span<HloComputation* const> computations) {
1852 return absl::StrJoin(computations, ",",
1853 [](std::string* s, const HloComputation* computation) {
1854 s->append(computation->name());
1855 });
1856 }
1857
1858 // Verifies various invariants about the structure of the HLO:
1859 //
1860 // (1) each instruction has a non-null parent() set to the HloComputation
1861 // which contains it.
1863 //
1864 // (2) each computation has a non-null parent() set to the HloModule which
1865 // contains it.
1866 //
1867 // (3) the operands of each instruction are in the same computation as the
1868 // instruction.
1869 Status VerifyHloStructure(HloModule* module) {
1870 for (const HloComputation* computation : module->computations()) {
1871 if (computation->parent() == nullptr) {
1872 return InternalError("Computation %s has a null parent pointer",
1873 computation->name());
1874 }
1875 if (computation->parent() != module) {
1876 return InternalError(
1877 "Computation %s parent() does not point to parent module",
1878 computation->name());
1879 }
1880
1881 for (const HloInstruction* instruction : computation->instructions()) {
1882 if (instruction->parent() == nullptr) {
1883 return InternalError("Instruction %s has a null parent pointer",
1884 instruction->name());
1885 }
1886 if (instruction->parent() != computation) {
1887 return InternalError(
1888 "Instruction %s parent() does not point to parent computation",
1889 instruction->name());
1890 }
1891 }
1892 }
1893
1894 // Check that operands are in the same computation separately from verifying
1895 // parent() correctness so conditions like a null HloInstruction::parent()
1896 // are identified and reported explicitly above rather than reporting a
1897 // mismatched operand.
1898 for (const HloComputation* computation : module->computations()) {
1899 for (const HloInstruction* instruction : computation->instructions()) {
1900 for (int i = 0; i < instruction->operand_count(); ++i) {
1901 const HloInstruction* operand = instruction->operand(i);
1902 if (operand->parent() != instruction->parent()) {
1903 return InternalError(
1904 "Operand %d (%s) of instruction %s is in a different "
1905 "computation: %s vs %s",
1906 i, operand->name(), instruction->name(),
1907 operand->parent() ? operand->parent()->name() : "(null)",
1908 instruction->parent()->name());
1909 }
1910 }
1911 }
1912 }
1913 return OkStatus();
1914 }
1915
1916 namespace {
1917
1918 // Returns true if the given Shape has a TOKEN shape as any subshape.
1919 bool ShapeContainsToken(const Shape& shape) {
1920 bool contains_token = false;
1921 ShapeUtil::ForEachSubshape(
1922 shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
1923 if (subshape.IsToken()) {
1924 contains_token = true;
1925 }
1926 });
1927 return contains_token;
1928 }
1929
1930 // Verifies that all types entering and exiting the entry computation are
1931 // legal.
1932 Status VerifyEntryAndExitShapes(const HloModule& module) {
1933 // Tokens cannot be passed as entry parameters.
1934 // TODO(b/80000000): Remove this constraint.
1935 for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
1936 HloInstruction* param =
1937 module.entry_computation()->parameter_instruction(i);
1938 if (ShapeContainsToken(param->shape())) {
1939 return InternalError(
1940 "Entry parameter %d is or contains a token shape: %s", i,
1941 ShapeUtil::HumanString(param->shape()));
1942 }
1943 }
1944 return OkStatus();
1945 }
1946
1947 // Checks if the given two instructions share the same channel id.
1948 Status CheckSameChannel(const HloInstruction* instr1,
1949 const HloInstruction* instr2) {
1950 if (instr1->channel_id() != instr2->channel_id()) {
1951 return InternalError(
1952 "Expected to have the same channel id, actual channel ids are: %s "
1953 "(%d), %s (%d)",
1954 instr1->ToString(), *instr1->channel_id(), instr2->ToString(),
1955 *instr2->channel_id());
1956 }
1957 return OkStatus();
1958 }
1959
1960 // Checks if the given two instructions have the same is_host_transfer
1961 // attribute value. Instructions must be send/recv instructions or their
1962 // 'done' variant.
1963 Status CheckSameIsHostTransfer(const HloInstruction* instr1,
1964 const HloInstruction* instr2) {
1965 const HloSendRecvInstruction* send_recv1 =
1966 DynCast<const HloSendRecvInstruction>(instr1);
1967 const HloSendRecvInstruction* send_recv2 =
1968 DynCast<const HloSendRecvInstruction>(instr2);
1969 TF_RET_CHECK(send_recv1 != nullptr);
1970 TF_RET_CHECK(send_recv2 != nullptr);
1971 if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
1972 return InternalError(
1973 "Expected instructions to have the same is-host-transfer property: "
1974 "%s, "
1975 "%s ",
1976 instr1->ToString(), instr2->ToString());
1977 }
1978 return OkStatus();
1979 }
1980
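// Checks that `instruction` has exactly one user and that the user's opcode is
// one of `expected_users`.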
1981 Status VerifySingleUser(const HloInstruction* instruction,
1982 const absl::flat_hash_set<HloOpcode>& expected_users) {
1983 TF_RET_CHECK(instruction->users().size() == 1)
1984 << "The " << HloOpcodeString(instruction->opcode())
1985 << " instruction requires one consumer, found "
1986 << instruction->users().size();
1987
1988 const HloInstruction* user = instruction->users().front();
1989 TF_RET_CHECK(expected_users.contains(user->opcode()))
1990 << "The consumer of a " << HloOpcodeString(instruction->opcode())
1991 << " instruction needs to be one of ("
1992 << absl::StrJoin(expected_users, ", ",
1993 [](std::string* out, HloOpcode opcode) {
1994 out->append(HloOpcodeString(opcode));
1995 })
1996 << "), found " << HloOpcodeString(user->opcode());
1997 return OkStatus();
1998 }
1999
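// Checks that `instruction` has exactly one operand and that the operand's
// opcode is one of `expected_operands`.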
2000 Status VerifySingleOperand(const HloInstruction* instruction,
2001 const std::vector<HloOpcode>& expected_operands) {
2002 TF_RET_CHECK(instruction->operands().size() == 1)
2003 << "The " << HloOpcodeString(instruction->opcode())
2004 << " instruction requires one operand, found "
2005 << instruction->operands().size();
2006
2007 const HloInstruction* operand = instruction->operand(0);
2008 TF_RET_CHECK(absl::c_find(expected_operands, operand->opcode()) !=
2009 expected_operands.end())
2010 << "The operand of a " << HloOpcodeString(instruction->opcode())
2011 << " instruction needs to be "
2012 << absl::StrJoin(expected_operands, " or ",
2013 [](std::string* out, HloOpcode opcode) {
2014 out->append(HloOpcodeString(opcode));
2015 })
2016 << ", found " << HloOpcodeString(operand->opcode());
2017 return OkStatus();
2018 }
2019
2020 // Checks asynchronous instruction pairs.
2021 Status VerifyAsynchronousInstructionPairs(const HloModule& module) {
2022 // Each start-style async op must feed its matching update/done op, and vice versa.
2023 for (const HloComputation* computation : module.computations()) {
2024 for (const HloInstruction* instruction : computation->instructions()) {
2025 switch (instruction->opcode()) {
2026 case HloOpcode::kAsyncStart: {
2027 TF_RETURN_IF_ERROR(VerifySingleUser(
2028 instruction, {HloOpcode::kAsyncUpdate, HloOpcode::kAsyncDone}));
2029 break;
2030 }
2031 case HloOpcode::kAsyncUpdate: {
2032 TF_RETURN_IF_ERROR(VerifySingleOperand(
2033 instruction, {HloOpcode::kAsyncStart, HloOpcode::kAsyncUpdate}));
2034 TF_RETURN_IF_ERROR(VerifySingleUser(
2035 instruction, {HloOpcode::kAsyncUpdate, HloOpcode::kAsyncDone}));
2036 break;
2037 }
2038 case HloOpcode::kAsyncDone: {
2039 TF_RETURN_IF_ERROR(VerifySingleOperand(
2040 instruction, {HloOpcode::kAsyncStart, HloOpcode::kAsyncUpdate}));
2041 break;
2042 }
2043 case HloOpcode::kAllReduceStart: {
2044 TF_RETURN_IF_ERROR(
2045 VerifySingleUser(instruction, {HloOpcode::kAllReduceDone}));
2046 break;
2047 }
2048 case HloOpcode::kAllReduceDone: {
2049 TF_RETURN_IF_ERROR(
2050 VerifySingleOperand(instruction, {HloOpcode::kAllReduceStart}));
2051 break;
2052 }
2053 case HloOpcode::kCopyStart: {
2054 TF_RETURN_IF_ERROR(
2055 VerifySingleUser(instruction, {HloOpcode::kCopyDone}));
2056 break;
2057 }
2058 case HloOpcode::kCopyDone: {
2059 TF_RETURN_IF_ERROR(
2060 VerifySingleOperand(instruction, {HloOpcode::kCopyStart}));
2061 break;
2062 }
2063 case HloOpcode::kCollectivePermuteStart: {
2064 TF_RETURN_IF_ERROR(VerifySingleUser(
2065 instruction, {HloOpcode::kCollectivePermuteDone}));
2066 break;
2067 }
2068 case HloOpcode::kCollectivePermuteDone: {
2069 TF_RETURN_IF_ERROR(VerifySingleOperand(
2070 instruction, {HloOpcode::kCollectivePermuteStart}));
2071 break;
2072 }
2073 default:
2074 break;
2075 }
2076 }
2077 }
2078 return OkStatus();
2079 }
2080
2081 // Checks that AllReduce instructions in the module are either all layout
2082 // constrained or all unconstrained.
2083 Status VerifyLayoutConstrainedAllReduce(const HloModule& module) {
2084 const HloAllReduceInstruction* reference = nullptr;
2085 for (const HloComputation* computation : module.computations()) {
2086 for (const HloInstruction* instruction : computation->instructions()) {
2087 if ((instruction->opcode() != HloOpcode::kAllReduce) &&
2088 (instruction->opcode() != HloOpcode::kAllReduceStart)) {
2089 continue;
2090 }
2091 auto all_reduce = DynCast<HloAllReduceInstruction>(instruction);
2092 if (!reference) {
2093 reference = all_reduce;
2094 }
2095 if (reference->constrain_layout() != all_reduce->constrain_layout()) {
2096 return FailedPrecondition(
2097 "HloModule has a mix of layout constrained and unconstrained "
2098 "AllReduce instructions.");
2099 }
2100 }
2101 }
2102 return OkStatus();
2103 }
2104
2105 // Checks various invariants of channel instructions (send/recv and
2106 // collectives).
2107 Status VerifyChannels(const HloModule& module) {
2108 absl::flat_hash_map<int64_t, std::vector<const HloInstruction*>>
2109 channel_instructions;
2110
2111 // Send/Recv instruction must have a single user: the corresponding
2112 // SendDone/RecvDone instruction with a matching channel id.
2113 for (const HloComputation* computation : module.computations()) {
2114 for (const HloInstruction* instruction : computation->instructions()) {
2115 auto channel_instr = DynCast<HloChannelInstruction>(instruction);
2116 if (!channel_instr || !channel_instr->channel_id()) {
2117 continue;
2118 }
2119 channel_instructions[*channel_instr->channel_id()].push_back(instruction);
2120
2121 switch (instruction->opcode()) {
2122 case HloOpcode::kSend: {
2123 TF_RET_CHECK(instruction->users().size() == 1);
2124 const HloInstruction* send_done = instruction->users().front();
2125 TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
2126 TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done));
2127 TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done));
2128 break;
2129 }
2130 case HloOpcode::kRecv: {
2131 TF_RET_CHECK(instruction->users().size() == 1);
2132 const HloInstruction* recv_done = instruction->users().front();
2133 TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
2134 TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done));
2135 TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done));
2136 break;
2137 }
2138 case HloOpcode::kSendDone:
2139 TF_RET_CHECK(instruction->operands().size() == 1);
2140 TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend);
2141 break;
2142 case HloOpcode::kRecvDone:
2143 TF_RET_CHECK(instruction->operands().size() == 1);
2144 TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv);
2145 break;
2146 default:
2147 break;
2148 }
2149 }
2150 }
2151
2152 // Iterate over each channel to check invariants.
2153 for (auto& pair : channel_instructions) {
2154 auto& instructions = pair.second;
2155 const HloInstruction* first = instructions[0];
2156 auto sendrecv = DynCast<HloSendRecvInstruction>(first);
2157 if (sendrecv) {
2158 absl::flat_hash_set<HloOpcode> opcodes;
2159 for (const HloInstruction* instr : instructions) {
2160 opcodes.insert(instr->opcode());
2161 auto cast = DynCast<HloSendRecvInstruction>(instr);
2162 TF_RET_CHECK(cast != nullptr)
2163 << "channel " << pair.first
2164 << " is used for different types of channel instructions";
2165 }
2166 if (sendrecv->is_host_transfer()) {
2167 TF_RET_CHECK(instructions.size() == 2)
2168 << "channel " << pair.first
2169 << " is used for multiple host send/recv instructions";
2170 } else {
2171 TF_RET_CHECK(instructions.size() == opcodes.size())
2172 << "channel " << pair.first
2173 << " is used for multiple send/recv instructions";
2174 }
2175 } else {
2176 for (const HloInstruction* instr : instructions) {
2177 TF_RET_CHECK(first->opcode() == instr->opcode())
2178 << "channel " << pair.first
2179 << " is used for different types of channel instructions";
2180 }
2181 }
2182 }
2183
2184 return OkStatus();
2185 }
2186
2187 // CHECKs various invariants of a fusion instruction.
2188 Status CheckFusionInstruction(HloInstruction* fusion) {
2189 // The parent fusion instruction of the fusion computation must be 'fusion'.
2190 HloComputation* fused_computation = fusion->fused_instructions_computation();
2191 if (fusion != fused_computation->FusionInstruction()) {
2192 return InternalError(
2193 "Instruction of fused computation does not match expected "
2194 "instruction "
2195 "%s.",
2196 fusion->ToString());
2197 }
2198
2199 // Fused root instruction and fused parameters must all be owned by the
2200 // fusion computation.
2201 bool root_owned = false;
2202 const std::vector<HloInstruction*>& fused_parameters =
2203 fusion->fused_parameters();
2204 const HloInstruction* fused_root = fusion->fused_expression_root();
2205 std::vector<bool> parameter_owned(fused_parameters.size(), false);
2206 for (auto* instruction : fused_computation->instructions()) {
2207 if (fused_root == instruction) {
2208 if (root_owned) {
2209 return InternalError("Root appears more than once in %s.",
2210 fusion->ToString());
2211 }
2212 root_owned = true;
2213 }
2214 for (int i = 0; i < fused_parameters.size(); ++i) {
2215 if (fused_parameters[i] == instruction) {
2216 if (parameter_owned[i]) {
2217 return InternalError("Parameter appears more than once in %s.",
2218 fusion->ToString());
2219 }
2220 parameter_owned[i] = true;
2221 }
2222 }
2223 }
2224 if (!root_owned) {
2225 return InternalError("Root not found in computation of %s.",
2226 fusion->ToString());
2227 }
2228 // Make sure all the parameter_owned entries are set
2229 for (int i = 0; i < parameter_owned.size(); i++) {
2230 if (!parameter_owned[i]) {
2231 return InternalError("Parameter %d not found in computation of %s.", i,
2232 fusion->ToString());
2233 }
2234 }
2235
2236 // Fused root must have no users.
2237 if (fused_root->user_count() != 0) {
2238 return InternalError("Root of %s may not have users.", fusion->ToString());
2239 }
2240
2241 // All uses of fused instructions must be in the fusion computation, and
2242 // every non-root instruction must have at least one use.
2243 for (auto* instruction :
2244 fusion->fused_instructions_computation()->instructions()) {
2245 if (instruction != fused_root) {
2246 if (instruction->user_count() == 0) {
2247 return InternalError("Non-root instruction %s in %s must have users.",
2248 instruction->ToString(), fusion->ToString());
2249 }
2250 for (auto& user : instruction->users()) {
2251 if (fused_computation != user->parent()) {
2252 return InternalError(
2253 "Non-root instruction %s in %s may not have external users.",
2254 instruction->ToString(), fusion->ToString());
2255 }
2256 }
2257 }
2258 }
2259
2260 // Fused parameter instructions must be numbered contiguously and match up
2261 // (shapes equal) with their respective operand.
2262 CHECK_EQ(fusion->operands().size(), fused_parameters.size());
2263 std::vector<bool> parameter_numbers(fused_parameters.size(), false);
2264 for (auto fused_param : fused_parameters) {
2265 int64_t param_no = fused_param->parameter_number();
2266 if (param_no < 0) {
2267 return InternalError("Unexpected negative parameter number %d in %s.",
2268 param_no, fusion->ToString());
2269 }
2270 if (param_no >= fused_parameters.size()) {
2271 return InternalError(
2272 "Unexpected parameter number %d in %s: higher than the number of "
2273 "parameters %lu.",
2274 param_no, fusion->ToString(), fused_parameters.size());
2275 }
2276 if (parameter_numbers[param_no]) {
2277 return InternalError(
2278 "Did not expect parameter number %d more than once in %s.", param_no,
2279 fusion->ToString());
2280 }
2281 parameter_numbers[param_no] = true;
2282 }
2283 // Make sure all the parameter_numbers entries were seen.
2284 for (int i = 0; i < parameter_numbers.size(); i++) {
2285 if (!parameter_numbers[i]) {
2286 return InternalError("Did not see parameter number %d in %s.", i,
2287 fusion->ToString());
2288 }
2289 }
2290
2291 TF_RET_CHECK(fusion->called_computations() ==
2292 absl::Span<HloComputation* const>(
2293 {fusion->fused_instructions_computation()}))
2294 << "Fusion HLO calls computations other than the "
2295 "fused_instructions_computation: "
2296 << fusion->ToString() << " fusion->fused_instructions_computation(): "
2297 << fusion->fused_instructions_computation()->ToString()
2298 << " fusion->called_computations(): "
2299 << ComputationsToString(fusion->called_computations());
2300
2301 for (const auto& fused : fusion->fused_instructions()) {
2302 TF_RET_CHECK(fused->parent() == fusion->fused_instructions_computation())
2303 << "Fused HLO was missing a parent: " << fused->ToString()
2304 << " parent: " << fused->parent()
2305 << " computation: " << fusion->parent();
2306 }
2307
2308 // TODO(b/65423525): We'd like to check that all operands are distinct.
2309 // This is currently disabled due to the invariant being violated by
2310 // multi-output fusion.
2311 return OkStatus();
2312 }
2313
2314 // Checks that the operand shapes are compatible to the output shape, i.e.,
2315 // that there are no implicit broadcasts.
2316 Status CheckElementwiseInstruction(HloInstruction* instruction) {
2317 const Shape& out_shape = instruction->shape();
2318 for (HloInstruction* operand : instruction->operands()) {
2319 const Shape& operand_shape = operand->shape();
2320 if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
2321 return FailedPrecondition(
2322 "Implicit broadcast is not allowed in HLO. "
2323 "Found different shapes for instruction %s.\n"
2324 "output: %s\noperand: %s\n",
2325 HloOpcodeString(instruction->opcode()),
2326 ShapeUtil::HumanString(out_shape),
2327 ShapeUtil::HumanString(operand_shape));
2328 }
2329 }
2330 if (auto* comparison = DynCast<HloCompareInstruction>(instruction)) {
2331 const Shape& operand_shape = comparison->operand(1)->shape();
2332 PrimitiveType operand_element_type = operand_shape.element_type();
2333 Comparison::Type default_comparison_type =
2334 Comparison::DefaultComparisonType(operand_element_type);
2335 if (primitive_util::IsFloatingPointType(operand_element_type)) {
2336 if (comparison->type() != Comparison::Type::kFloat &&
2337 comparison->type() != Comparison::Type::kFloatTotalOrder) {
2338 return FailedPrecondition(
2339 "Expected comparison type %s or %s.\n"
2340 "actual: %s\noperand: %s\n",
2341 ComparisonTypeToString(Comparison::Type::kFloat),
2342 ComparisonTypeToString(Comparison::Type::kFloatTotalOrder),
2343 ComparisonTypeToString(comparison->type()),
2344 ShapeUtil::HumanString(operand_shape));
2345 }
2346 } else if (comparison->type() != default_comparison_type) {
2347 return FailedPrecondition(
2348 "Expected comparison type %s.\n"
2349 "actual: %s\noperand: %s\n",
2350 ComparisonTypeToString(default_comparison_type),
2351 ComparisonTypeToString(comparison->type()),
2352 ShapeUtil::HumanString(operand_shape));
2353 }
2354 }
2355 return OkStatus();
2356 }
2357
2358 // Visitor which verifies various fields on the HLO instruction. This class does
2359 // not check result shape as that is checked in the ShapeVerifier.
2360 class InstructionVerifier : public DfsHloVisitorWithDefault {
2361 public:
2362 explicit InstructionVerifier(const HloVerifierOpts& opts) : opts_(opts) {}
2363
2364 Status DefaultAction(HloInstruction*) override { return OkStatus(); }
2365
2366 Status HandleFusion(HloInstruction* fusion) override {
2367 TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(
2368 fusion, /*skip_nested_async_op_check=*/false));
2369 return CheckFusionInstruction(fusion);
2370 }
2371
2372 Status HandleBroadcast(HloInstruction* broadcast) override {
2373 // If you see this failure then someone has confused the difference
2374 // between the HLO broadcast op, and the UserComputation broadcast
2375 // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
2376 // or ComputationLowerer::Visit()
2377 TF_RET_CHECK(broadcast->dimensions().size() ==
2378 broadcast->operand(0)->shape().rank())
2379 << "Broadcast HLO (" << broadcast->ToShortString()
2380 << ") has invalid number of dimensions: "
2381 << broadcast->dimensions().size()
2382 << " != " << broadcast->operand(0)->shape().rank();
2383 if (opts_.verify_broadcast_dimensions_order) {
2384 TF_RET_CHECK(absl::c_is_sorted(broadcast->dimensions()))
2385 << "Broadcast dimensions should be ordered, got: "
2386 << broadcast->ToString();
2387 }
2388 return OkStatus();
2389 }
2390
2391 Status HandleBitcastConvert(HloInstruction* c) override {
2392 // Shape verifier will check all we need.
2393 return OkStatus();
2394 }
2395
2396 Status HandleWhile(HloInstruction* xla_while) override {
2397 auto* while_cond = xla_while->while_condition();
2398 auto* while_body = xla_while->while_body();
2399 if (while_cond->num_parameters() != 1) {
2400 return FailedPrecondition(
2401 "While condition must have exactly 1 parameter; had %d : %s",
2402 while_cond->num_parameters(), while_cond->ToString());
2403 }
2404 if (while_body->num_parameters() != 1) {
2405 return FailedPrecondition(
2406 "While body must have exactly 1 parameter; had %d : %s",
2407 while_body->num_parameters(), while_body->ToString());
2408 }
2409 if (xla_while->operand_count() != 1) {
2410 return FailedPrecondition(
2411 "While loop must have exactly one operand; had %d : %s",
2412 xla_while->operand_count(), xla_while->ToString());
2413 }
2414 // Allow kWhile to contain computations on separate thread.
2415 TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(
2416 xla_while, /*skip_nested_async_op_check=*/true));
2417 return OkStatus();
2418 }
2419
2420 Status HandleCall(HloInstruction* call) override {
2421 // Allow kCall to contain computations on separate thread.
2422 return CheckCallableInstructionThreadName(
2423 call, /*skip_nested_async_op_check=*/true);
2424 }
2425
2426 Status HandleConditional(HloInstruction* conditional) override {
2427 for (int b = 0; b < conditional->branch_count(); ++b) {
2428 if (conditional->branch_computation(b)->num_parameters() != 1) {
2429 return FailedPrecondition(
2430 "Branch computation %s of %s must have 1 parameter instead of %d",
2431 conditional->branch_computation(b)->name(), conditional->ToString(),
2432 conditional->branch_computation(b)->num_parameters());
2433 }
2434 }
2435 // Allow kConditional to contain computations on separate thread.
2436 TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(
2437 conditional, /*skip_nested_async_op_check=*/true));
2438 return OkStatus();
2439 }
2440
2441 Status HandleElementwiseUnary(HloInstruction* instruction) override {
2442 return CheckElementwiseInstruction(instruction);
2443 }
2444
2445 Status HandleElementwiseBinary(HloInstruction* instruction) override {
2446 return CheckElementwiseInstruction(instruction);
2447 }
2448
2449 Status HandleGetTupleElement(HloInstruction* gte) override {
2450 TF_RET_CHECK(gte->operand(0)->shape().IsTuple());
2451 return OkStatus();
2452 }
2453
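// Checks that the transpose's output dimensions are the permutation of the
// operand's dimensions given by dimensions(); e.g. transposing f32[2,3] with
// dimensions={1,0} must produce f32[3,2] (illustrative shapes).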
2454 Status HandleTranspose(HloInstruction* transpose) override {
2455 const Shape& shape = transpose->shape();
2456 const HloInstruction* operand = transpose->operand(0);
2457 TF_RET_CHECK(shape.dimensions().size() == transpose->dimensions().size());
2458 TF_RET_CHECK(shape.dimensions().size() ==
2459 transpose->operand(0)->shape().dimensions().size());
2460 TF_RET_CHECK(std::equal(
2461 shape.dimensions().begin(), shape.dimensions().end(),
2462 Permute(operand->shape().dimensions(), transpose->dimensions())
2463 .begin()))
2464 << "shape: " << shape << ", operand->shape(): " << operand->shape()
2465 << ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ")
2466 << "}";
2467 return OkStatus();
2468 }
2469
2470 Status HandleAllReduce(HloInstruction* crs) override {
2471 if (crs->channel_id().has_value()) {
2472 TF_RET_CHECK(crs->channel_id().value() > 0)
2473 << "All reduce channel id must be greater than 0 for "
2474 << crs->ToShortString();
2475 }
2476 return OkStatus();
2477 }
2478
2479 Status HandleReshape(HloInstruction* hlo) override {
2480 if (opts_.verify_reshape_is_bitcast && !hlo->IsFused()) {
2481 TF_RET_CHECK(
2482 ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape()))
2483 << "Reshape should be a physical bitcast, got: " << hlo->ToString();
2484 }
2485 return OkStatus();
2486 }
2487
2488 Status HandleCustomCall(HloInstruction* hlo) override {
2489 if (opts_.verify_custom_call_nested_computation_thread_name) {
2490 // Allow kCustomCall to contain computations on separate thread.
2491 return CheckCallableInstructionThreadName(
2492 hlo, /*skip_nested_async_op_check=*/true);
2493 }
2494 return OkStatus();
2495 }
2496
2497 Status Preprocess(HloInstruction* instruction) override {
2498 auto previous = instructions_by_name_.find(instruction->name());
2499 TF_RET_CHECK(previous == instructions_by_name_.end())
2500 << "HLO has name that is not unique within module:\n"
2501 << instruction->ToString()
2502 << " in computation: " << instruction->parent()->name()
2503 << "\nPrevious HLO with same name:\n"
2504 << previous->second->ToString()
2505 << " in computation: " << previous->second->parent()->name();
2506 instructions_by_name_[instruction->name()] = instruction;
2507 return OkStatus();
2508 }
2509
2510 Status Postprocess(HloInstruction* instruction) override {
2511 if (!opts_.InstructionCanChangeLayout(instruction) &&
2512 LayoutUtil::IsDenseArray(instruction->shape()) &&
2513 instruction->shape().has_layout()) {
2514 const Shape& result_shape = instruction->shape();
2515 const Layout& result_layout = result_shape.layout();
2516 for (HloInstruction* operand : instruction->operands()) {
2517 const Shape& operand_shape = operand->shape();
2518 if (LayoutUtil::IsDenseArray(operand_shape) &&
2519 operand_shape.rank() == result_shape.rank() &&
2520 operand_shape.has_layout()) {
2521 const Layout& operand_layout = operand_shape.layout();
2522 TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
2523 << "Instruction shouldn't change layouts "
2524 << instruction->ToString() << " From " << result_shape << " To "
2525 << operand_shape;
2526 }
2527 }
2528 }
2529
2530 return OkStatus();
2531 }
2532
2533 private:
2534 absl::flat_hash_map<std::string, const HloInstruction*> instructions_by_name_;
2535 const HloVerifierOpts& opts_;
2536 };
2537
2538 } // namespace
2539
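// A minimal standalone usage sketch (illustrative only; the exact constructor
// overloads and option builders may differ from what is assumed here):
//   HloVerifier verifier(/*layout_sensitive=*/false,
//                        /*allow_mixed_precision=*/false);
//   TF_RETURN_IF_ERROR(verifier.Run(module).status());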
2540 StatusOr<bool> HloVerifier::Run(
2541 HloModule* module,
2542 const absl::flat_hash_set<absl::string_view>& execution_threads) {
2543 auto disabled = module->config().debug_options().xla_disable_hlo_passes();
2544 if (std::find(disabled.begin(), disabled.end(), name()) != disabled.end()) {
2545 return false;
2546 }
2547 auto status_or_changed = [&]() -> StatusOr<bool> {
2548 TF_RET_CHECK(!module->name().empty());
2549
2550 if (module->entry_computation()->IsFusionComputation()) {
2551 return InvalidArgument(
2552 "Module entry computation cannot be a fusion computation");
2553 }
2554
2555 TF_RETURN_IF_ERROR(VerifyHloStructure(module));
2556 TF_RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module));
2557 TF_RETURN_IF_ERROR(VerifyChannels(*module));
2558
2559 std::unique_ptr<ShapeVerifier> shape_verifier =
2560 target_metadata_->GetVerifier();
2561 InstructionVerifier instruction_verifier(
2562 target_metadata_->GetVerifierOpts());
2563 for (auto* computation : module->computations(execution_threads)) {
2564 TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
2565 TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
2566 }
2567
2568 TF_RETURN_IF_ERROR(shape_verifier->VerifyEntryComputationLayout(*module));
2569 TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
2570
2571 // If the module has a schedule, it must be valid.
2572 if (module->has_schedule()) {
2573 TF_RETURN_IF_ERROR(module->schedule().Verify());
2574 }
2575
2576 TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(
2577 *module, [this](const Shape& shape) -> int64_t {
2578 if (target_metadata_->GetVerifierOpts().IsLayoutSensitive()) {
2579 return target_metadata_->GetVerifierOpts().ShapeSize(shape);
2580 } else {
2581 return 0;
2582 }
2583 }));
2584
2585 TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module));
2586 TF_RETURN_IF_ERROR(VerifyLayoutConstrainedAllReduce(*module));
2587 return false;
2588 }();
2589 if (status_or_changed.ok()) {
2590 return status_or_changed.ValueOrDie();
2591 }
2592 return Status(status_or_changed.status().code(),
2593 absl::StrCat("during context [", context_, "]: ",
2594 status_or_changed.status().error_message()));
2595 }
2596
2597 } // namespace xla
2598