/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_verifier.h"

#include <memory>
#include <optional>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {

namespace {

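// Returns true if `hlo` is an instruction whose opcode embeds called
// computations (control flow, fusion, reductions and collectives with a
// reducer, sort comparators, custom calls, etc.).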
bool IsCallerInstruction(HloInstruction* hlo) {
  switch (hlo->opcode()) {
    case HloOpcode::kAsyncStart:
    case HloOpcode::kAsyncUpdate:
    case HloOpcode::kAsyncDone:
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kWhile:
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kMap:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kScatter:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSort:
    case HloOpcode::kFusion:
    case HloOpcode::kCustomCall:
      return true;
    default:
      return false;
  }
}

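// Verifies that `hlo` has exactly `expected` operands.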
Status CheckOperandCount(const HloInstruction* hlo, int expected) {
  if (hlo->operand_count() != expected) {
    return InternalError("Expected %d operands for %s instruction: %s",
                         expected, HloOpcodeString(hlo->opcode()),
                         hlo->ToString());
  }
  return OkStatus();
}

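// Verifies that `computation`, called from `calling_instruction`, declares
// exactly `expected` parameters.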
Status CheckParameterCount(const HloInstruction* calling_instruction,
                           const HloComputation* computation, int expected) {
  if (computation->num_parameters() != expected) {
    return InternalError(
        "Expected computation %s called from %s to have %d parameters, has %d",
        computation->name(), calling_instruction->name(), expected,
        computation->num_parameters());
  }
  return OkStatus();
}

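// Returns the number of participants in each communication subgroup of the
// collective `hlo`, as implied by its replica groups and group mode.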
int64_t GetSubgroupSize(HloCollectiveInstruction* hlo,
                        CollectiveOpGroupMode group_mode) {
  const HloModuleConfig& config = hlo->GetModule()->config();
  // empty replica groups imply all replicas form a single group.
  int64_t replica_subgroup_size =
      hlo->replica_groups().empty()
          ? 0
          : hlo->replica_groups()[0].replica_ids_size();
  switch (group_mode) {
    case CollectiveOpGroupMode::kCrossReplica:
    case CollectiveOpGroupMode::kCrossReplicaAndPartition: {
      int64_t replica_subgroup_size =
          hlo->replica_groups().empty()
              ? config.replica_count()
              : hlo->replica_groups()[0].replica_ids_size();
      if (group_mode == CollectiveOpGroupMode::kCrossReplicaAndPartition) {
        // Replicas from all partitions participate.
        replica_subgroup_size *= config.num_partitions();
      }
      return replica_subgroup_size;
    }
    case CollectiveOpGroupMode::kFlattenedID:
      return replica_subgroup_size;
    case CollectiveOpGroupMode::kCrossPartition:
      return hlo->replica_groups().empty()
                 ? config.num_partitions()
                 : hlo->replica_groups()[0].replica_ids_size();
  }
}

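// Recursively verifies that every computation called (directly or
// transitively) from `comp` runs on the same execution thread, optionally
// skipping asynchronous callers.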
Status CheckNestedComputationThreadNameEqual(const HloComputation* comp,
                                             bool skip_nested_async_op_check) {
  for (const HloInstruction* instr : comp->instructions()) {
    if (skip_nested_async_op_check && instr->IsAsynchronous()) {
      continue;
    }
    for (const HloComputation* called_cmp : instr->called_computations()) {
      if (called_cmp->execution_thread() != comp->execution_thread()) {
        return InternalError(
            "Nested computations must have the same execution thread name as "
            "their caller (%s vs %s).",
            called_cmp->execution_thread(), comp->execution_thread());
      }
      TF_RETURN_IF_ERROR(CheckNestedComputationThreadNameEqual(
          called_cmp, skip_nested_async_op_check));
    }
  }
  return OkStatus();
}
}  // namespace

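// Common checks applied to every instruction before the opcode-specific
// handler runs: only caller instructions may have called computations, and
// the operand count must match the opcode's fixed arity (when it has one).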
Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
  if (!hlo->called_computations().empty() && !IsCallerInstruction(hlo)) {
    return InternalError(
        "Called computations specified for non-caller instruction %s",
        hlo->ToString());
  }
  std::optional<int> arity = HloOpcodeArity(hlo->opcode());
  if (arity) {
    TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity));
  }
  return OkStatus();
}

Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) {
  return CheckUnaryShape(hlo);
}

Status ShapeVerifier::HandleElementwiseBinary(HloInstruction* hlo) {
  return CheckBinaryShape(hlo);
}

Status ShapeVerifier::HandleClamp(HloInstruction* clamp) {
  return CheckTernaryShape(clamp);
}

Status ShapeVerifier::HandleSelect(HloInstruction* select) {
  return CheckTernaryShape(select);
}

Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) {
  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : concatenate->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(concatenate,
                    ShapeInference::InferConcatOpShape(
                        operand_shapes, concatenate->concatenate_dimension()));
}

Status ShapeVerifier::HandleConvert(HloInstruction* convert) {
  return CheckShape(convert, ShapeInference::InferConvertShape(
                                 convert->operand(0)->shape(),
                                 convert->shape().element_type()));
}

Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) {
  return CheckShape(convert, ShapeInference::InferBitcastConvertShape(
                                 convert->operand(0)->shape(),
                                 convert->shape().element_type()));
}

Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
  return CheckUnaryShape(copy);
}

Status ShapeVerifier::HandleDot(HloInstruction* dot) {
  TF_ASSIGN_OR_RETURN(
      const Shape expected,
      ShapeInference::InferDotOpShape(
          dot->operand(0)->shape(), dot->operand(1)->shape(),
          dot->dot_dimension_numbers(),
          /*preferred_element_type=*/dot->shape().element_type()));
  return CheckShape(dot, expected);
}

Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
  TF_ASSIGN_OR_RETURN(
      Shape expected,
      ShapeInference::InferConvolveShape(
          convolution->operand(0)->shape(), convolution->operand(1)->shape(),
          convolution->feature_group_count(), convolution->batch_group_count(),
          convolution->window(), convolution->convolution_dimension_numbers(),
          /*preferred_element_type=*/convolution->shape().element_type()));
  return CheckShape(convolution, expected);
}

Status ShapeVerifier::HandleFft(HloInstruction* fft) {
  TF_ASSIGN_OR_RETURN(
      const Shape expected,
      ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(),
                                    fft->fft_length()));
  return CheckShape(fft, expected);
}

Status ShapeVerifier::HandleTriangularSolve(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(const Shape expected,
                      ShapeInference::InferTriangularSolveShape(
                          hlo->operand(0)->shape(), hlo->operand(1)->shape(),
                          hlo->triangular_solve_options()));
  return CheckShape(hlo, expected);
}

Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) {
  TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1));
  TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferCholeskyShape(
                                                hlo->operand(0)->shape()));
  return CheckShape(hlo, expected);
}

Status ShapeVerifier::HandleOptimizationBarrier(HloInstruction* hlo) {
  TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1));
  return CheckShape(hlo, hlo->operand(0)->shape());
}

// Checks that `hlo`'s set of ReplicaGroups:
//
//  - names each replica 0 through n-1 exactly once (where n is either the
//    number of replicas, the number of partitions, or their product)
//  - does not contain any empty ReplicaGroups.
//
// Note that although none of the groups may be empty, `hlo` is allowed to
// have an empty replica-groups list when the group mode is not kFlattenedID.
// That just means it has one big group.
//
// In general, if replica groups is not empty, all replica groups should be of
// the same size. The exception is all-reduce, where non-uniform replica groups
// are allowed. This is controlled by `uniform_replica_group_size`.
static Status CheckReplicaGroups(HloInstruction* hlo,
                                 CollectiveOpGroupMode group_mode,
                                 bool uniform_replica_group_size = true) {
  if (!hlo->replica_groups().empty()) {
    absl::flat_hash_set<int64_t> replicas_seen;
    for (const ReplicaGroup& g : hlo->replica_groups()) {
      if (g.replica_ids().empty()) {
        return InternalError(
            "Instruction cannot have an empty replica group: %s",
            hlo->ToString());
      }
      for (int64_t i : g.replica_ids()) {
        if (!replicas_seen.insert(i).second) {
          return InternalError(
              "Replica %d is repeated in instruction's replica-groups: %s", i,
              hlo->ToString());
        }
      }
    }
    size_t n = replicas_seen.size();
    for (int64_t i = 0; i < n; ++i) {
      if (!replicas_seen.count(i)) {
        return InternalError(
            "Replica %d is not named in instruction's replica-groups: %s", i,
            hlo->ToString());
      }
    }

    // replica-groups have numbers [0, n). This n should be either replica or
    // partition count, or their product. In some cases, replica and/or
    // partition count is not set in the HloModule config and has a default
    // value of 1. For those cases, skip this part of the verification.
    int64_t replica_count = hlo->GetModule()->config().replica_count();
    int64_t num_partitions = hlo->GetModule()->config().num_partitions();
    switch (group_mode) {
      case CollectiveOpGroupMode::kCrossReplica:
      case CollectiveOpGroupMode::kCrossReplicaAndPartition: {
        TF_RET_CHECK(replica_count == 1 || n == replica_count)
            << "In " << CollectiveOpGroupModeToString(group_mode)
            << " mode, replica groups should contain " << replica_count
            << " replicas, but found " << n << ": " << hlo->ToString();
        break;
      }
      case CollectiveOpGroupMode::kCrossPartition: {
        TF_RET_CHECK(num_partitions == 1 || n == num_partitions)
            << "In " << CollectiveOpGroupModeToString(group_mode)
            << " mode, replica groups should contain " << num_partitions
            << " partitions, but found " << n << ": " << hlo->ToString();
        break;
      }
      case CollectiveOpGroupMode::kFlattenedID: {
        const int64_t num_flattened_ids = replica_count * num_partitions;
        TF_RET_CHECK(num_flattened_ids == 1 || n == num_flattened_ids)
            << "In " << CollectiveOpGroupModeToString(group_mode)
            << " mode, replica groups should contain " << num_flattened_ids
            << " flattened IDs, but found " << n << ": " << hlo->ToString();
        break;
      }
    }

    if (uniform_replica_group_size) {
      int64_t size = hlo->replica_groups()[0].replica_ids_size();
      for (const ReplicaGroup& g : hlo->replica_groups()) {
        TF_RET_CHECK(size == g.replica_ids_size())
            << "Replica groups expected to be of uniform size";
      }
    }
  } else {
    TF_RET_CHECK(group_mode != CollectiveOpGroupMode::kFlattenedID)
        << "Replica groups must be specified in flattened-id mode";
  }

  return OkStatus();
}

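// Shared checks for all-gather and all-gather-start: validates the replica
// groups and the all-gather dimension, and checks that the shard count implied
// by the operand/output shapes matches the subgroup size. The shard count is
// returned through `computed_shard_count`.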
static Status CheckCommonAllGatherInvariants(HloInstruction* hlo,
                                             int64_t* computed_shard_count) {
  auto ag = Cast<HloAllGatherInstruction>(hlo);
  CHECK_NE(computed_shard_count, nullptr) << "Expected a shard count as input";
  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
                      GetCollectiveOpGroupMode(ag->channel_id().has_value(),
                                               ag->use_global_device_ids()));
  TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, group_mode));
  TF_RET_CHECK(ag->all_gather_dimension() >= 0);

  int64_t shard_count;
  for (int64_t i = 0; i < ag->operand_count(); ++i) {
    TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(i)->shape().rank());

    Shape output_shape;
    if (hlo->opcode() == HloOpcode::kAllGather) {
      output_shape = (ag->operand_count() == 1) ? ag->shape()
                                                : ag->shape().tuple_shapes(i);
    } else {
      TF_RET_CHECK(hlo->opcode() == HloOpcode::kAllGatherStart);
      output_shape = (ag->operand_count() == 1)
                         ? ag->shape().tuple_shapes(1)
                         : ag->shape().tuple_shapes(1).tuple_shapes(i);
    }
    TF_RET_CHECK(ag->all_gather_dimension() < output_shape.rank());
    if (i == 0) {
      shard_count = CeilOfRatio(
          output_shape.dimensions(ag->all_gather_dimension()),
          ag->operand(i)->shape().dimensions(ag->all_gather_dimension()));
    }
  }

  int64_t subgroup_size = GetSubgroupSize(ag, group_mode);
  // If replica and partition count is not explicitly set, it will have a
  // default value of 1, in which case the subgroup_size will be 1 as well. Skip
  // these verification checks in that case.
  TF_RET_CHECK(subgroup_size == 1 || shard_count == subgroup_size)
      << "shard_count = " << shard_count
      << ", subgroup_size = " << subgroup_size << ", " << hlo->ToString();
  *computed_shard_count = shard_count;
  return OkStatus();
}

Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
  auto ag = Cast<HloAllGatherInstruction>(hlo);
  int64_t shard_count;
  TF_RETURN_IF_ERROR(CheckCommonAllGatherInvariants(hlo, &shard_count));
  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(
      ag, ShapeInference::InferAllGatherShape(
              operand_shapes, ag->all_gather_dimension(), shard_count));
}

Status ShapeVerifier::HandleAllGatherStart(HloInstruction* hlo) {
  auto ag = Cast<HloAllGatherInstruction>(hlo);
  int64_t shard_count;
  TF_RETURN_IF_ERROR(CheckCommonAllGatherInvariants(hlo, &shard_count));
  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(
      ag, ShapeInference::InferAllGatherStartShape(
              operand_shapes, ag->all_gather_dimension(), shard_count));
}

Status ShapeVerifier::HandleAllGatherDone(HloInstruction* hlo) {
  return CheckShape(
      hlo, ShapeInference::InferAllGatherDoneShape(hlo->operand(0)->shape()));
}

Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) {
  auto ar = Cast<HloAllReduceInstruction>(hlo);
  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
                      GetCollectiveOpGroupMode(ar->channel_id().has_value(),
                                               ar->use_global_device_ids()));
  TF_RETURN_IF_ERROR(
      CheckReplicaGroups(ar, group_mode, /*uniform_replica_group_size=*/false));

  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(hlo, ShapeInference::InferAllReduceShape(operand_shapes));
}

Status ShapeVerifier::HandleReduceScatter(HloInstruction* hlo) {
  auto ars = Cast<HloReduceScatterInstruction>(hlo);
  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
                      GetCollectiveOpGroupMode(ars->channel_id().has_value(),
                                               ars->use_global_device_ids()));
  TF_RETURN_IF_ERROR(CheckReplicaGroups(ars, group_mode));
  TF_RET_CHECK(ars->scatter_dimension() >= 0);

  for (int64_t i = 0; i < ars->operand_count(); ++i) {
    TF_RET_CHECK(ars->scatter_dimension() < ars->operand(i)->shape().rank());

    const Shape& output_shape = (ars->operand_count() == 1)
                                    ? ars->shape()
                                    : ars->shape().tuple_shapes(i);
    TF_RET_CHECK(ars->scatter_dimension() < output_shape.rank());
  }

  const Shape& output0_shape =
      (ars->operand_count() == 1) ? ars->shape() : ars->shape().tuple_shapes(0);
  int64_t shard_count =
      CeilOfRatio(ars->operand(0)->shape().dimensions(ars->scatter_dimension()),
                  output0_shape.dimensions(ars->scatter_dimension()));
  int64_t subgroup_size = GetSubgroupSize(ars, group_mode);
  // If replica and partition count is not explicitly set, it will have a
  // default value of 1, in which case the subgroup_size will be 1 as well. Skip
  // these verification checks in that case.
  TF_RET_CHECK(subgroup_size == 1 || shard_count == subgroup_size)
      << "shard_count = " << shard_count
      << ", subgroup_size = " << subgroup_size << ", " << hlo->ToString();

  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(ars,
                    ShapeInference::InferReduceScatterShape(
                        operand_shapes, ars->scatter_dimension(), shard_count));
}

Status ShapeVerifier::HandleAllReduceStart(HloInstruction* hlo) {
  auto ar = Cast<HloAllReduceInstruction>(hlo);
  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
                      GetCollectiveOpGroupMode(ar->channel_id().has_value(),
                                               ar->use_global_device_ids()));
  TF_RETURN_IF_ERROR(
      CheckReplicaGroups(ar, group_mode, /*uniform_replica_group_size=*/false));

  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(hlo,
                    ShapeInference::InferAllReduceStartShape(operand_shapes));
}

Status ShapeVerifier::HandleAllReduceDone(HloInstruction* hlo) {
  return CheckShape(
      hlo, ShapeInference::InferAllReduceDoneShape(hlo->operand(0)->shape()));
}

Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
  auto* all_to_all = Cast<HloAllToAllInstruction>(hlo);
  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
                      GetCollectiveOpGroupMode(
                          all_to_all->channel_id().has_value(), std::nullopt));

  TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, group_mode));

  TF_RET_CHECK(all_to_all != nullptr);

  if (all_to_all->split_dimension()) {
    int64_t split_count = GetSubgroupSize(all_to_all, group_mode);
    TF_RET_CHECK(hlo->operand_count() == 1);
    return CheckShape(
        hlo, ShapeInference::InferAllToAllShape(
                 hlo->operand(0)->shape(), *all_to_all->split_dimension(),
                 *all_to_all->split_dimension(), split_count));
  } else {
    std::vector<const Shape*> operand_shapes;
    for (const HloInstruction* operand : hlo->operands()) {
      operand_shapes.push_back(&operand->shape());
    }
    return CheckShape(hlo,
                      ShapeInference::InferAllToAllTupleShape(operand_shapes));
  }
}

Status ShapeVerifier::HandlePartitionId(HloInstruction* hlo) {
  return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
}

Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) {
  return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
}

namespace {

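// Verifies one offset operand of an in-place collective-permute: the offset
// must be a tuple providing one element per dimension of `buffer_shape`,
// either directly (a tuple of arrays) or as a tuple of such index tuples.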
Status CheckBufferOffset(const Shape& buffer_shape,
                         const Shape& buffer_offset_shape) {
  if (!buffer_offset_shape.IsTuple()) {
    return InternalError("Buffer offset is not tuple.");
  }
  bool all_is_array =
      absl::c_all_of(buffer_offset_shape.tuple_shapes(),
                     [](const Shape& shape) { return shape.IsArray(); });
  bool all_is_tuple =
      absl::c_all_of(buffer_offset_shape.tuple_shapes(),
                     [](const Shape& shape) { return shape.IsTuple(); });
  if (!all_is_array && !all_is_tuple) {
    return InternalError(
        "Buffer offset should either be a tuple of arrays or "
        "a tuple of tuples.");
  }

  if (all_is_tuple) {
    if (absl::c_any_of(buffer_offset_shape.tuple_shapes(),
                       [&buffer_shape](const Shape& shape) {
                         return ShapeUtil::TupleElementCount(shape) !=
                                buffer_shape.rank();
                       })) {
      return InternalError(
          "Buffer offset index should have the same number of "
          "elements as the buffer's rank.");
    }
  } else {
    if (buffer_offset_shape.tuple_shapes_size() != buffer_shape.rank()) {
      return InternalError(
          "Buffer offset index should have the same number of "
          "elements as the buffer's rank.");
    }
  }
  return OkStatus();
}

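// For the 4-operand (in-place) form of collective-permute, checks that the
// input/output buffers and their offset operands have consistent structure.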
Status CheckInplaceCollectivePermute(HloInstruction* collective_permute) {
  if (collective_permute->operand_count() == 1) {
    return OkStatus();
  }
  if (collective_permute->operand_count() != 4) {
    return InternalError("Unexpected number of operands: %d.",
                         collective_permute->operand_count());
  }

  const Shape& input_buffer_shape = collective_permute->operand(0)->shape();
  const Shape& output_buffer_shape = collective_permute->operand(1)->shape();
  const Shape& input_offset_shape = collective_permute->operand(2)->shape();
  const Shape& output_offset_shape = collective_permute->operand(3)->shape();

  if (input_buffer_shape.IsArray() && output_buffer_shape.IsArray()) {
    Status check_input_buffer_offset =
        CheckBufferOffset(input_buffer_shape, input_offset_shape);
    if (!check_input_buffer_offset.ok()) {
      return check_input_buffer_offset;
    }
    Status check_output_buffer_offset =
        CheckBufferOffset(output_buffer_shape, output_offset_shape);
    if (!check_output_buffer_offset.ok()) {
      return check_output_buffer_offset;
    }
  } else if (input_buffer_shape.IsTuple() && output_buffer_shape.IsTuple()) {
    if (ShapeUtil::TupleElementCount(input_buffer_shape) !=
        ShapeUtil::TupleElementCount(output_buffer_shape)) {
      return InternalError("Mismatched input and output buffers.");
    }
    if (!input_offset_shape.IsTuple() ||
        ShapeUtil::TupleElementCount(input_offset_shape) !=
            ShapeUtil::TupleElementCount(input_buffer_shape)) {
      return InternalError("Mismatched input buffers and input offsets.");
    }
    for (int i = 0; i < input_buffer_shape.tuple_shapes_size(); ++i) {
      Status check_input_buffer_offset =
          CheckBufferOffset(input_buffer_shape.tuple_shapes(i),
                            input_offset_shape.tuple_shapes(i));
      if (!check_input_buffer_offset.ok()) {
        return check_input_buffer_offset;
      }
    }
    if (!output_offset_shape.IsTuple() ||
        ShapeUtil::TupleElementCount(output_offset_shape) !=
            ShapeUtil::TupleElementCount(output_buffer_shape)) {
      return InternalError("Mismatched output buffers and output offsets.");
    }
    for (int i = 0; i < output_buffer_shape.tuple_shapes_size(); ++i) {
      Status check_output_buffer_offset =
          CheckBufferOffset(output_buffer_shape.tuple_shapes(i),
                            output_offset_shape.tuple_shapes(i));
      if (!check_output_buffer_offset.ok()) {
        return check_output_buffer_offset;
      }
    }
  } else {
    return InternalError("Mismatched input and output buffers.");
  }
  return OkStatus();
}

Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo,
                                     CollectiveOpGroupMode group_mode) {
  // A source or target cannot appear twice in the collective-permute's
  // source-target pairs. Also, based on the group formation mode, check if the
  // source and target IDs are within expected range.

  // Note: for collective-permute, only kCrossReplica and kCrossPartition modes
  // are valid.
  const HloModuleConfig& config = hlo->GetModule()->config();
  const int64_t limit = group_mode == CollectiveOpGroupMode::kCrossReplica
                            ? config.replica_count()
                            : config.num_partitions();
  absl::flat_hash_map<int64_t, std::vector<int64_t>> seen_source_to_targets;
  absl::flat_hash_map<int64_t, std::vector<int64_t>> seen_target_to_sources;
  int allowed_seen_count = 1;
  if (hlo->operand_count() == 4) {
    if (hlo->operand(0)->shape().IsArray()) {
      allowed_seen_count = hlo->operand(2)->shape().tuple_shapes_size();
    } else {
      allowed_seen_count =
          hlo->operand(2)->shape().tuple_shapes(0).tuple_shapes_size();
    }
  }

  for (const auto& p : hlo->source_target_pairs()) {
    TF_RET_CHECK(p.first >= 0)
        << "Source " << p.first
        << " in the instruction's source-target pair must be >= 0 : "
        << hlo->ToString();
    TF_RET_CHECK(limit == 1 || p.first < limit)
        << "Source " << p.first
        << " in the instruction's source-target pair must be < " << limit
        << " : " << hlo->ToString();
    if (seen_source_to_targets.contains(p.first) &&
        seen_source_to_targets[p.first].size() == allowed_seen_count) {
      if (allowed_seen_count == 1) {
        return InternalError(
            "Source %d appears more than once in instruction's source-target "
            "pairs: %s",
            p.first, hlo->ToString());
      } else {
        return InternalError(
            "Source %d appears more than %d times in instruction's "
            "source-target "
            "pairs: %s",
            p.first, allowed_seen_count, hlo->ToString());
      }
    } else {
      seen_source_to_targets[p.first].push_back(p.second);
    }
    TF_RET_CHECK(p.second >= 0)
        << "Target " << p.second
        << " in the instruction's source-target pair must be >= 0 : "
        << hlo->ToString();
    TF_RET_CHECK(limit == 1 || p.second < limit)
        << "Target " << p.second
        << " in the instruction's source-target pair must be < " << limit
        << " : " << hlo->ToString();
    if (seen_target_to_sources.contains(p.second) &&
        seen_target_to_sources[p.second].size() == allowed_seen_count) {
      if (allowed_seen_count == 1) {
        return InternalError(
            "Target %d appears more than once in instruction's source-target "
            "pairs: %s",
            p.second, hlo->ToString());
      } else {
        return InternalError(
            "Target %d appears more than %d times in instruction's "
            "source-target "
            "pairs: %s",
            p.second, allowed_seen_count, hlo->ToString());
      }
    } else {
      seen_target_to_sources[p.second].push_back(p.first);
    }
  }
  return OkStatus();
}

}  // namespace

Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(
      CollectiveOpGroupMode group_mode,
      GetCollectiveOpGroupMode(hlo->channel_id().has_value(),
                               /*use_global_device_ids=*/std::nullopt));
  TF_RETURN_IF_ERROR(CheckInplaceCollectivePermute(hlo));
  TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo, group_mode));
  std::vector<const Shape*> operand_shapes;
  absl::c_transform(
      hlo->operands(), std::back_inserter(operand_shapes),
      [](const HloInstruction* operand) { return &(operand->shape()); });
  return CheckShape(
      hlo, ShapeInference::InferCollectivePermuteShape(operand_shapes));
}

Status ShapeVerifier::HandleCollectivePermuteStart(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(
      CollectiveOpGroupMode group_mode,
      GetCollectiveOpGroupMode(hlo->channel_id().has_value(),
                               /*use_global_device_ids=*/std::nullopt));
  TF_RETURN_IF_ERROR(CheckInplaceCollectivePermute(hlo));
  TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo, group_mode));
  std::vector<const Shape*> operand_shapes;
  absl::c_transform(
      hlo->operands(), std::back_inserter(operand_shapes),
      [](const HloInstruction* operand) { return &(operand->shape()); });
  return CheckShape(
      hlo, ShapeInference::InferCollectivePermuteStartShape(operand_shapes));
}

Status ShapeVerifier::HandleCollectivePermuteDone(HloInstruction* hlo) {
  return CheckShape(hlo, ShapeInference::InferCollectivePermuteDoneShape(
                             hlo->operand(0)->shape()));
}

Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
  return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
                                          reduce_precision->operand(0)->shape(),
                                          reduce_precision->exponent_bits(),
                                          reduce_precision->mantissa_bits()));
}

Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction,
                                          int64_t operand_no) {
  const HloInstruction* token = instruction->operand(operand_no);
  if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) {
    return InternalError(
        "Expected operand %d to be token-shaped, actual shape is "
        "%s:\n%s",
        operand_no, StringifyShape(token->shape()), instruction->ToString());
  }
  return OkStatus();
}

Status ShapeVerifier::CheckOperandAndParameter(
    const HloInstruction* instruction, int64_t operand_number,
    const HloComputation* computation, int64_t parameter_number) {
  const HloInstruction* operand = instruction->operand(operand_number);
  const HloInstruction* parameter =
      computation->parameter_instruction(parameter_number);
  if (!ShapesSame(operand->shape(), parameter->shape())) {
    return InternalError("Operand %s shape does not match parameter's %s in %s",
                         operand->ToString(), parameter->ToString(),
                         instruction->ToString());
  }
  return OkStatus();
}

Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
  HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
  TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));

  // The output of infeed is a tuple containing the data value and a token.
  return CheckShape(infeed,
                    ShapeUtil::MakeTupleShape(
                        {infeed->infeed_shape(), ShapeUtil::MakeTokenShape()}));
}

Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
  HloOutfeedInstruction* outfeed = Cast<HloOutfeedInstruction>(instruction);
  TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1));

  // Outfeed has a separate shape field for the value which is outfed to the
  // host. The shape of the instruction itself is always a token.
  if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) {
    return InternalError(
        "Expected outfeed shape to be equal to operand's shape %s, "
        "actual shape is %s:\n%s",
        StringifyShape(outfeed->operand(0)->shape()),
        StringifyShape(outfeed->outfeed_shape()), outfeed->ToString());
  }
  return CheckShape(outfeed, ShapeUtil::MakeTokenShape());
}

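// Returns true if the two operand shapes share an element type that also
// matches the result's element type, either exactly or ignoring floating-point
// precision when mixed precision is allowed.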
bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0,
                                              const Shape& shape_1,
                                              const Shape& result_shape) {
  return ShapeUtil::SameElementType(shape_0, shape_1) &&
         (ShapeUtil::SameElementType(shape_0, result_shape) ||
          (opts_.allow_mixed_precision &&
           ShapeUtil::SameElementTypeIgnoringFpPrecision(shape_0,
                                                         result_shape)));
}

Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
  TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2));

  const Shape& shape_0 = instruction->operand(0)->shape();
  const Shape& shape_1 = instruction->operand(1)->shape();
  if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) {
    return InternalError(
        "Expected scalar types for the two operands of Rng instruction: %s",
        instruction->ToString());
  }

  if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) {
    return InternalError(
        "Expected compatible element types for the result and the two operands"
        " of Rng instruction: %s",
        instruction->ToString());
  }

  PrimitiveType element_type = shape_0.element_type();
  switch (instruction->random_distribution()) {
    case RNG_UNIFORM:
      if (!primitive_util::IsFloatingPointType(element_type) &&
          !primitive_util::IsIntegralType(element_type) &&
          element_type != PRED) {
        return InternalError(
            "Element type not supported."
            " Expected element to be of floating point type, integral type or"
            " predicate type for RngUniform: %s",
            instruction->ToString());
      }
      break;

    case RNG_NORMAL:
      if (!primitive_util::IsFloatingPointType(element_type)) {
        return InternalError(
            "Element type not supported."
            " Expected element to be FloatingPointType for RngNormal: %s",
            instruction->ToString());
      }
      break;
    default:
      return InternalError(
          "Invalid Rng distribution %s",
          RandomDistribution_Name(instruction->random_distribution()));
  }

  return OkStatus();
}

Status ShapeVerifier::HandleRngBitGenerator(HloInstruction* hlo) {
  if (!hlo->shape().IsTuple()) {
    return OkStatus();
  }
  if (hlo->shape().IsTuple() && hlo->shape().tuple_shapes_size() != 2) {
    return InternalError(
        "Expected tuple shape with 2 elements for RngBitGenerator. Got: %s",
        hlo->shape().ToString());
  }
  if (!ShapeUtil::Compatible(hlo->operand(0)->shape(),
                             hlo->shape().tuple_shapes(0))) {
    return InternalError(
        "Expected state shape to match between input and output for "
        "RngBitGenerator. Got %s vs. %s",
        hlo->operand(0)->shape().ToString(),
        hlo->shape().tuple_shapes(0).ToString());
  }
  return OkStatus();
}

Status ShapeVerifier::HandleRngGetAndUpdateState(HloInstruction* instruction) {
  TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0));
  const Shape& result_shape = instruction->shape();
  const Shape expected_shape = ShapeUtil::MakeShape(U64, {2});
  if (!ShapeUtil::Compatible(result_shape, expected_shape)) {
    return InternalError(
        "Invalid RngGetAndUpdateState, expect result to have shape %s, got %s ",
        StringifyShape(expected_shape), StringifyShape(result_shape));
  }

  return OkStatus();
}

Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
  return CheckShape(
      reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(),
                                                 reverse->dimensions()));
}

Status ShapeVerifier::HandleSort(HloInstruction* hlo) {
  HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
  if (sort->operand_count() < 1) {
    return InternalError("Expected at least 1 operand for %s instruction: %s",
                         HloOpcodeString(sort->opcode()), sort->ToString());
  }
  HloComputation* compare = sort->to_apply();

  // Check that the 'compare' computation returns a PRED.
  Shape compare_shape = compare->root_instruction()->shape();
  if (!ShapeUtil::Compatible(compare_shape, ShapeUtil::MakeShape(PRED, {}))) {
    return InternalError(
        "The Sort compare computation shape does not lead to a scalar "
        "predicate shape: %s",
        StringifyShape(compare_shape));
  }

  // Check that the number of parameters of the 'compare' computation is
  // correct.
  TF_RETURN_IF_ERROR(
      CheckParameterCount(sort, compare, sort->operand_count() * 2));

  // Verify that the operands of the compare computation have the correct
  // scalar shapes.
  for (int64_t parameter_idx = 0; parameter_idx < compare->num_parameters();
       ++parameter_idx) {
    int64_t operand_idx = parameter_idx / 2;
    Shape expected_scalar_shape = ShapeUtil::MakeShape(
        sort->operand(operand_idx)->shape().element_type(), {});
    Shape actual_parameter_shape =
        compare->parameter_instruction(parameter_idx)->shape();
    if (!ShapeUtil::CompatibleIgnoringFpPrecision(expected_scalar_shape,
                                                  actual_parameter_shape)) {
      return InternalError(
          "Expected the %lld-th parameter of the compare computation of sort "
          "to have shape %s, but got %s",
          parameter_idx, StringifyShape(expected_scalar_shape),
          StringifyShape(actual_parameter_shape));
    }
  }

  // Verify that all operand shapes have the same dimensions.
  for (int64_t operand = 1; operand < sort->operand_count(); ++operand) {
    if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(),
                                   sort->operand(operand)->shape())) {
      return InternalError(
          "Expected sort to have the same dimensions for all operands. "
          "First operand shape is: %s\n, shape (operand index %lld) is: %s",
          StringifyShape(sort->operand(0)->shape()), operand,
          StringifyShape(sort->operand(operand)->shape()));
    }
  }

  // Verify the sort_dimension.
  if (sort->sort_dimension() >= sort->operand(0)->shape().rank()) {
    return InternalError(
        "Expected the sort_dimension %d of sort to be smaller than the rank %d "
        "of the operand(s).",
        sort->sort_dimension(), sort->operand(0)->shape().rank());
  }

  return CheckVariadicShape(sort);
}

Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
  if (!Cast<HloConstantInstruction>(constant)->HasLiteral()) {
    return InternalError("Constant is required to have a valid literal: %s",
                         constant->ToString());
  }
  return CheckShape(constant, constant->literal().shape(),
                    /*only_compare_minor_to_major_in_layout=*/true);
}

Status ShapeVerifier::HandleIota(HloInstruction* hlo) {
  auto* iota = Cast<HloIotaInstruction>(hlo);
  if (!iota->shape().IsArray()) {
    return InternalError("Iota does not support non-array result.");
  }
  const int64_t rank = iota->shape().rank();
  if (rank == 0) {
    return InternalError("Iota does not support scalars.");
  }
  int64_t iota_dimension = iota->iota_dimension();
  if (iota_dimension >= rank || iota_dimension < 0) {
    return InternalError(
        "The iota dimension cannot go beyond the operation rank or be "
        "negative.");
  }

  PrimitiveType primitive_type = iota->shape().element_type();
  if (!primitive_util::IsIntegralType(primitive_type) &&
      !primitive_util::IsFloatingPointType(primitive_type) &&
      !primitive_util::IsComplexType(primitive_type)) {
    return InvalidArgument(
        "Only support iota of integral, floating point or complex primitive "
        "types, got %s",
        PrimitiveType_Name(primitive_type));
  }

  return OkStatus();
}

Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  return CheckShape(get_tuple_element,
                    ShapeInference::InferGetTupleElementShape(
                        get_tuple_element->operand(0)->shape(),
                        get_tuple_element->tuple_index()));
}

namespace {
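// Checks that the first `num_operands_to_check` operands have the same element
// types as the corresponding parameters of the instruction's to_apply
// computation.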
Status SameElementTypesForOperandsAndToApplyParameters(
    const HloInstruction& instruction, int64_t num_operands_to_check) {
  const ProgramShape& to_apply = instruction.to_apply()->ComputeProgramShape();
  for (int i = 0; i < num_operands_to_check; ++i) {
    const Shape& parameter_shape = to_apply.parameters(i);
    const Shape& operand_shape = instruction.operands()[i]->shape();
    if (!ShapeUtil::SameElementType(parameter_shape, operand_shape)) {
      return InvalidArgument(
          "Shape mismatch between to_apply computation"
          " parameter and operand %d in %s.",
          i, instruction.ToString().c_str());
    }
  }
  return OkStatus();
}
}  // namespace

Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
  if (reduce->operand_count() % 2 != 0) {
    return InternalError(
        "Expected an even number of operands for %s instruction: %s",
        HloOpcodeString(reduce->opcode()), reduce->ToString());
  }

  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : reduce->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  TF_RETURN_IF_ERROR(
      CheckShape(reduce, ShapeInference::InferReduceShape(
                             operand_shapes, reduce->dimensions(),
                             reduce->to_apply()->ComputeProgramShape())));

  return opts_.allow_mixed_precision
             ? OkStatus()
             : SameElementTypesForOperandsAndToApplyParameters(
                   *reduce, reduce->operand_count());
}

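// In layout-sensitive mode a bitcast may not change the byte size of the
// buffer, unless differently-sized bitcasts are explicitly allowed and the
// static array data sizes still match (i.e. only trailing padding differs).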
Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
  const Shape& output_shape = bitcast->shape();
  const Shape& operand_shape = bitcast->operand(0)->shape();
  if (opts_.layout_sensitive &&
      opts_.shape_size(output_shape) != opts_.shape_size(operand_shape)) {
    // Allow bitcast that has the same data size but different trailing
    // paddings.
    if (!opts_.allow_bitcast_to_have_different_size ||
        !(output_shape.is_static() && operand_shape.is_static() &&
          (ShapeUtil::ArrayDataSize(output_shape) ==
           ShapeUtil::ArrayDataSize(operand_shape)))) {
      return InternalError(
          "Bitcast cannot have different shape sizes of output (%d) and "
          "operand "
          "(%d) (%s) (%s)",
          opts_.shape_size(output_shape), opts_.shape_size(operand_shape),
          output_shape.ToString(true), operand_shape.ToString(true));
    }
  }
  return OkStatus();
}

Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
  // HLO broadcast has no exact analog at the client level so there is no
  // ShapeInference method. Check the output shape explicitly.
  const Shape& operand_shape = broadcast->operand(0)->shape();
  // Check for mixed precision.
  TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape));
  TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size());
  for (int64_t operand_dimension = 0; operand_dimension < operand_shape.rank();
       ++operand_dimension) {
    int64_t output_dimension = broadcast->dimensions()[operand_dimension];
    TF_RET_CHECK((output_dimension < broadcast->shape().rank()) &&
                 output_dimension >= 0 &&
                 (broadcast->shape().dimensions(output_dimension) ==
                  operand_shape.dimensions(operand_dimension)))
        << broadcast->ToString() << " operand shape " << operand_shape;
  }
  return OkStatus();
}

Status ShapeVerifier::HandleDynamicReshape(HloInstruction* dynamic_reshape) {
  // Check for mixed precision.
  const Shape& operand_shape = dynamic_reshape->operand(0)->shape();
  TF_RET_CHECK(SameElementType(dynamic_reshape->shape(), operand_shape));
  TF_RET_CHECK(ShapeUtil::ElementsIn(dynamic_reshape->shape()) ==
               ShapeUtil::ElementsIn(operand_shape));
  TF_RET_CHECK(dynamic_reshape->shape().rank() + 1 ==
               dynamic_reshape->operand_count());
  for (int64_t i = 1; i < dynamic_reshape->operand_count(); ++i) {
    TF_RET_CHECK(dynamic_reshape->operand(i)->shape().element_type() == S32);
  }
  return OkStatus();
}

Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
  // Check for mixed precision.
  const Shape& operand_shape = reshape->operand(0)->shape();
  TF_RET_CHECK(SameElementType(reshape->shape(), operand_shape));
  TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
               ShapeUtil::ElementsIn(operand_shape));
  return OkStatus();
}

Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
  return CheckShape(
      transpose, ShapeInference::InferTransposeShape(
                     transpose->operand(0)->shape(), transpose->dimensions()));
}

Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
  return OkStatus();
}

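// A fusion must call exactly one computation whose root shape matches the
// fusion's shape and whose parameters correspond one-to-one with the fusion's
// operands.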
HandleFusion(HloInstruction * fusion)1118 Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
1119   if (fusion->called_computations().size() != 1) {
1120     return InternalError(
1121         "Fusion has a non-unary number of called computations (%s)",
1122         fusion->ToString().c_str());
1123   }
1124   const Shape& root_computation_shape =
1125       fusion->called_computations()[0]->root_instruction()->shape();
1126   if (!ShapesSame(fusion->shape(), root_computation_shape)) {
1127     return InternalError(
1128         "Fused computation shape (%s) is not equal to the fusion shape (%s)",
1129         root_computation_shape.ToString(true), fusion->shape().ToString(true));
1130   }
1131 
1132   auto& fused_parameters = fusion->fused_parameters();
1133   if (fused_parameters.size() != fusion->operand_count()) {
1134     return InternalError(
1135         "Fused parameter count (%d) does not match the number of operands (%d)"
1136         " passed to the fusion instruction in: %s.",
1137         fused_parameters.size(), fusion->operand_count(),
1138         fusion->ToString().c_str());
1139   }
1140   for (HloInstruction* fused_param : fused_parameters) {
1141     int64_t param_no = fused_param->parameter_number();
1142     if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) {
1143       return InternalError(
1144           "Shape mismatch between parameter number %d and its operand in "
1145           "%s.",
1146           param_no, fusion->ToString().c_str());
1147     }
1148   }
1149   return OkStatus();
1150 }
1151 
HandleCall(HloInstruction * call)1152 Status ShapeVerifier::HandleCall(HloInstruction* call) {
1153   TF_RETURN_IF_ERROR(
1154       CheckParameterCount(call, call->to_apply(), call->operand_count()));
1155   for (int64_t i = 0; i < call->to_apply()->num_parameters(); ++i) {
1156     TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
1157   }
1158   // The shape of kCall should match the shape of the computation it calls.
1159   return CheckShape(call, call->to_apply()->root_instruction()->shape());
1160 }
1161 
HandleCustomCall(HloInstruction * instruction)1162 Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) {
1163   const HloCustomCallInstruction* custom_call =
1164       DynCast<const HloCustomCallInstruction>(instruction);
1165   TF_RET_CHECK(custom_call != nullptr);
1166   if (custom_call->layout_constrained()) {
1167     // If the layout is constrained, verify all the respective shapes have
1168     // layouts and that the constrained operand shapes match the shapes of the
1169     // operands.
1170     TF_RET_CHECK(LayoutUtil::HasLayout(custom_call->shape()));
1171     TF_RET_CHECK(custom_call->operand_count() ==
1172                  custom_call->operand_shapes_with_layout().size());
1173     for (int64_t i = 0; i < custom_call->operand_count(); ++i) {
1174       const Shape& operand_shape_with_layout =
1175           custom_call->operand_shapes_with_layout()[i];
1176       TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(),
1177                                          operand_shape_with_layout))
1178           << custom_call->operand(i)->shape().ToString() << " operand "
1179           << operand_shape_with_layout.ToString();
1180       TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
1181     }
1182   }
1183   for (const auto& pair : custom_call->output_to_operand_aliasing()) {
1184     TF_RET_CHECK(pair.second.first < custom_call->operand_count())
1185         << "Invalid aliasing operand index.";
1186     TF_RET_CHECK(ShapeUtil::IndexIsValid(
1187         custom_call->operand(pair.second.first)->shape(), pair.second.second))
1188         << "Invalid aliasing operand shape index.";
1189     TF_RET_CHECK(ShapeUtil::IndexIsValid(custom_call->shape(), pair.first))
1190         << "Invalid aliasing output shape index.";
1191     const Shape& output_subshape =
1192         ShapeUtil::GetSubshape(custom_call->shape(), pair.first);
1193     const Shape& operand_subshape = ShapeUtil::GetSubshape(
1194         custom_call->operand(pair.second.first)->shape(), pair.second.second);
1195     if (opts_.layout_sensitive) {
1196       TF_RET_CHECK(operand_subshape == output_subshape)
1197           << "Different aliasing shapes: " << operand_subshape.ToString()
1198           << " vs " << output_subshape.ToString();
1199     } else {
1200       TF_RET_CHECK(ShapeUtil::Compatible(output_subshape, operand_subshape))
1201           << "Different aliasing shapes: " << operand_subshape.ToString()
1202           << " vs " << output_subshape.ToString();
1203     }
1204   }
1205   return OkStatus();
1206 }
1207 
HandleSlice(HloInstruction * slice)1208 Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
1209   return CheckShape(slice,
1210                     ShapeInference::InferSliceShape(
1211                         slice->operand(0)->shape(), slice->slice_starts(),
1212                         slice->slice_limits(), slice->slice_strides()));
1213 }
1214 
HandleDynamicSlice(HloInstruction * dynamic_slice)1215 Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) {
1216   return CheckShape(
1217       dynamic_slice,
1218       ShapeInference::InferDynamicSliceShape(
1219           dynamic_slice->operand(0)->shape(),
1220           Cast<HloDynamicSliceInstruction>(dynamic_slice)->index_shapes(),
1221           dynamic_slice->dynamic_slice_sizes()));
1222 }
1223 
HandleDynamicUpdateSlice(HloInstruction * dynamic_update_slice)1224 Status ShapeVerifier::HandleDynamicUpdateSlice(
1225     HloInstruction* dynamic_update_slice) {
1226   return CheckShape(
1227       dynamic_update_slice,
1228       ShapeInference::InferDynamicUpdateSliceShape(
1229           dynamic_update_slice->operand(0)->shape(),
1230           dynamic_update_slice->operand(1)->shape(),
1231           Cast<HloDynamicUpdateSliceInstruction>(dynamic_update_slice)
1232               ->index_shapes()));
1233 }
1234 
HandleTuple(HloInstruction * tuple)1235 Status ShapeVerifier::HandleTuple(HloInstruction* tuple) {
1236   return CheckVariadicShape(tuple);
1237 }
1238 
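// Verifies kMap against shape inference, mapping over all dimensions of the
// highest-rank operand. Unless mixed precision is allowed, the operands and
// the mapped computation's parameters must also share element types.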
HandleMap(HloInstruction * map)1239 Status ShapeVerifier::HandleMap(HloInstruction* map) {
1240   std::vector<const Shape*> operand_shapes;
1241   int64_t max_operand_rank = 0;
1242   for (const HloInstruction* operand : map->operands()) {
1243     operand_shapes.push_back(&operand->shape());
1244     max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
1245   }
1246   // TODO(b/65689298) Remove code below once Map is generalized to accept
1247   // arbitrary map dimensions.
1248   std::vector<int64_t> map_dims(max_operand_rank);
1249   std::iota(map_dims.begin(), map_dims.end(), 0);
1250 
1251   TF_RETURN_IF_ERROR(CheckShape(
1252       map,
1253       ShapeInference::InferMapShape(
1254           operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims)));
1255 
1256   return opts_.allow_mixed_precision
1257              ? OkStatus()
1258              : SameElementTypesForOperandsAndToApplyParameters(
1259                    *map, map->operand_count());
1260 }
1261 
HandleReduceWindow(HloInstruction * reduce_window)1262 Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) {
1263   VLOG(2) << "Verify reduce window:" << reduce_window->ToString() << "\n";
1264   auto reduce_window_instr = Cast<HloReduceWindowInstruction>(reduce_window);
1265   auto input_shapes = reduce_window_instr->input_shapes();
1266   VLOG(2) << "reduce window input shape count: " << input_shapes.size() << "\n";
1267   auto init_shapes = reduce_window_instr->init_value_shapes();
1268   VLOG(2) << "reduce instruction is :" << reduce_window->ToString() << "\n";
1269   TF_RETURN_IF_ERROR(CheckShape(
1270       reduce_window, ShapeInference::InferReduceWindowShape(
1271                          input_shapes, init_shapes, reduce_window->window(),
1272                          reduce_window->to_apply()->ComputeProgramShape())));
1273 
1274   return opts_.allow_mixed_precision
1275              ? OkStatus()
1276              : SameElementTypesForOperandsAndToApplyParameters(
1277                    *reduce_window, reduce_window->operand_count());
1278 }
1279 
HandleSelectAndScatter(HloInstruction * instruction)1280 Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
1281   return CheckShape(
1282       instruction,
1283       ShapeInference::InferSelectAndScatterShape(
1284           instruction->operand(0)->shape(),
1285           instruction->select()->ComputeProgramShape(), instruction->window(),
1286           instruction->operand(1)->shape(), instruction->operand(2)->shape(),
1287           instruction->scatter()->ComputeProgramShape()));
1288 }
1289 
HandleWhile(HloInstruction * xla_while)1290 Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
1291   TF_RETURN_IF_ERROR(
1292       CheckParameterCount(xla_while, xla_while->while_body(), 1));
1293   TF_RETURN_IF_ERROR(
1294       CheckParameterCount(xla_while, xla_while->while_condition(), 1));
1295   TF_RETURN_IF_ERROR(
1296       CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
1297   TF_RETURN_IF_ERROR(
1298       CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
1299   const Shape& conditional_shape =
1300       xla_while->while_condition()->root_instruction()->shape();
1301   if (!ShapeUtil::Compatible(conditional_shape,
1302                              ShapeUtil::MakeShape(PRED, {}))) {
1303     return InternalError(
1304         "While condition computation must return a scalar PRED shape, but "
1305         "returns: %s",
1306         StringifyShape(conditional_shape));
1307   }
1308   // The shape of kWhile should match the shape of the body computation it
1309   // calls.
1310   return CheckShape(xla_while,
1311                     xla_while->while_body()->root_instruction()->shape());
1312 }
1313 
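// Verifies kConditional: operand 0 must be a scalar selector (PRED selects
// between exactly two branches, S32 indexes one of >= 1 branches), there must
// be one branch operand per branch computation, and every branch root must
// produce the conditional's shape.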
HandleConditional(HloInstruction * conditional)1314 Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
1315   if (!ShapeUtil::IsScalar(conditional->operand(0)->shape())) {
1316     return InvalidArgument(
1317         "The first operand of conditional must be a scalar. Got %s",
1318         conditional->operand(0)->shape().DebugString());
1319   }
1320   const int num_branches = conditional->branch_count();
1321   PrimitiveType operand0_type = conditional->operand(0)->shape().element_type();
1322   if (operand0_type == PRED) {
1323     TF_RET_CHECK(num_branches == 2);
1324   } else {
1325     if (operand0_type != S32) {
1326       return InvalidArgument(
1327           "The first operand of indexed conditional must be a scalar of S32. "
1328           "Got"
1329           " type %s.",
1330           PrimitiveType_Name(operand0_type));
1331     }
1332     TF_RET_CHECK(num_branches >= 1);
1333   }
1334   TF_RETURN_IF_ERROR(CheckOperandCount(conditional, num_branches + 1));
1335   for (int j = 0; j < num_branches; ++j) {
1336     TF_RETURN_IF_ERROR(CheckParameterCount(
1337         conditional, conditional->branch_computation(j), 1));
1338     TF_RETURN_IF_ERROR(CheckOperandAndParameter(
1339         conditional, j + 1, conditional->branch_computation(j), 0));
1340     TF_RETURN_IF_ERROR(CheckShape(
1341         conditional,
1342         conditional->branch_computation(j)->root_instruction()->shape()));
1343   }
1344   return OkStatus();
1345 }
1346 
HandlePad(HloInstruction * pad)1347 Status ShapeVerifier::HandlePad(HloInstruction* pad) {
1348   return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(),
1349                                                        pad->operand(1)->shape(),
1350                                                        pad->padding_config()));
1351 }
1352 
1353 namespace {
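// Checks that the operand of an async-update/async-done is an async-start or
// async-update that wraps the same computation and carries the same async
// group id.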
CheckAsyncOpOperand(const HloInstruction * async_op)1354 Status CheckAsyncOpOperand(const HloInstruction* async_op) {
1355   const HloInstruction* operand = async_op->operand(0);
1356   if (operand->opcode() != HloOpcode::kAsyncStart &&
1357       operand->opcode() != HloOpcode::kAsyncUpdate) {
1358     return InternalError(
1359         "%s expects its operand to be async-start or async-update, found "
1360         "%s.",
1361         HloOpcodeString(async_op->opcode()),
1362         HloOpcodeString(operand->opcode()));
1363   }
1364   if (*async_op->async_wrapped_computation() !=
1365       *operand->async_wrapped_computation()) {
1366     return InternalError(
1367         "The %s expects its wrapped async computation to be identical to its "
1368         "operand's wrapped async computation (%s vs %s); execution threads "
1369         "are (%s vs %s).",
1370         HloOpcodeString(async_op->opcode()),
1371         async_op->async_wrapped_instruction()->ToString(),
1372         operand->async_wrapped_instruction()->ToString(),
1373         async_op->async_wrapped_computation()->execution_thread(),
1374         operand->async_wrapped_computation()->execution_thread());
1375   }
1376   if (async_op->async_group_id() != operand->async_group_id()) {
1377     return InternalError(
1378         "%s expects its operand to have the same group id (%s vs %s).",
1379         HloOpcodeString(async_op->opcode()),
1380         async_op->async_group_id() ? absl::StrCat(*async_op->async_group_id())
1381                                    : "none",
1382         operand->async_group_id() ? absl::StrCat(*operand->async_group_id())
1383                                   : "none");
1384   }
1385   return OkStatus();
1386 }
1387 
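// Checks the shape convention shared by async ops: the async value must be a
// tuple whose element {0} is a tuple of the wrapped computation's parameter
// shapes and whose element {1} is the wrapped computation's result shape.
// Additional elements (e.g. backend-specific context) are not inspected here.
// As an illustrative example, wrapping a computation (f32[8], s32[]) -> f32[8]
// yields an async shape of the form ((f32[8], s32[]), f32[8], ...).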
CheckAsyncOpComputationShapes(const HloInstruction * async_op,const Shape & async_shape)1388 Status CheckAsyncOpComputationShapes(const HloInstruction* async_op,
1389                                      const Shape& async_shape) {
1390   if (!async_shape.IsTuple() || async_shape.tuple_shapes_size() < 2) {
1391     return InternalError(
1392         "The %s expects the async shape to be a tuple of at least two "
1393         "elements, found %s.",
1394         HloOpcodeString(async_op->opcode()), async_shape.ToString());
1395   }
1396   ProgramShape computation_shape =
1397       async_op->async_wrapped_computation()->ComputeProgramShape();
1398   Shape param_shape = ShapeUtil::MakeTupleShape(computation_shape.parameters());
1399   if (async_shape.tuple_shapes(0) != param_shape) {
1400     return InternalError(
1401         "The %s expects the async shape at index {0} to match async "
1402         "computation parameter shape (%s vs %s).",
1403         HloOpcodeString(async_op->opcode()),
1404         async_shape.tuple_shapes(0).ToString(/*print_layout=*/true),
1405         param_shape.ToString(/*print_layout=*/true));
1406   }
1407   if (async_shape.tuple_shapes(1) != computation_shape.result()) {
1408     return InternalError(
1409         "The %s expects the async shape at index {1} to match the async "
1410         "computation root shape (%s vs %s).",
1411         HloOpcodeString(async_op->opcode()),
1412         async_shape.tuple_shapes(1).ToString(/*print_layout=*/true),
1413         computation_shape.result().ToString(/*print_layout=*/true));
1414   }
1415   return OkStatus();
1416 }
1417 
CheckAsyncOpComputationThreadName(const HloInstruction * async_op)1418 Status CheckAsyncOpComputationThreadName(const HloInstruction* async_op) {
1419   std::optional<absl::string_view> async_execution_thread =
1420       async_op->async_execution_thread();
1421   if (async_execution_thread !=
1422       async_op->async_wrapped_computation()->execution_thread()) {
1423     return InternalError(
1424         "Async op expects its execution thread name to match its wrapped "
1425         "computation's thread name (%s vs %s).",
1426         async_execution_thread ? absl::StrCat(*async_execution_thread) : "none",
1427         async_op->async_wrapped_computation()->execution_thread());
1428   }
1429   return CheckNestedComputationThreadNameEqual(
1430       async_op->async_wrapped_computation(),
1431       /*skip_nested_async_op_check=*/false);
1432 }
1433 
CheckCallableInstructionThreadName(const HloInstruction * instruction,bool skip_nested_async_op_check)1434 Status CheckCallableInstructionThreadName(const HloInstruction* instruction,
1435                                           bool skip_nested_async_op_check) {
1436   for (const HloComputation* computation : instruction->called_computations()) {
1437     if (instruction->parent() != nullptr) {
1438       if (instruction->parent()->execution_thread() !=
1439           computation->execution_thread()) {
1440         return InternalError(
1441             "callable instruction %s expects parent computation thread name "
1442             "same as called computation's thread name (%s vs %s).",
1443             instruction->ToString(), instruction->parent()->execution_thread(),
1444             computation->execution_thread());
1445       }
1446     }
1447     TF_RETURN_IF_ERROR(CheckNestedComputationThreadNameEqual(
1448         computation, skip_nested_async_op_check));
1449   }
1450   return OkStatus();
1451 }
1452 }  // namespace
1453 
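// Verifies async-start: in addition to the common async shape checks, each
// operand's shape must match the corresponding element of the async shape's
// parameter tuple at index {0}.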
HandleAsyncStart(HloInstruction * async_start)1454 Status ShapeVerifier::HandleAsyncStart(HloInstruction* async_start) {
1455   TF_RETURN_IF_ERROR(
1456       CheckAsyncOpComputationShapes(async_start, async_start->shape()));
1457   TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_start));
1458   const Shape& param_shape = async_start->shape().tuple_shapes(0);
1459   for (int i = 0; i < async_start->operand_count(); ++i) {
1460     if (param_shape.tuple_shapes(i) != async_start->operand(i)->shape()) {
1461       return InternalError(
1462           "The %s expects the shape of operand %d to match the async shape at "
1463           "index {0} (%s vs %s).",
1464           HloOpcodeString(async_start->opcode()), i,
1465           async_start->operand(i)->shape().ToString(/*print_layout=*/true),
1466           param_shape.tuple_shapes(i).ToString(/*print_layout=*/true));
1467     }
1468   }
1469   return OkStatus();
1470 }
1471 
HandleAsyncUpdate(HloInstruction * async_update)1472 Status ShapeVerifier::HandleAsyncUpdate(HloInstruction* async_update) {
1473   TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_update));
1474   if (async_update->operand(0)->shape() != async_update->shape()) {
1475     return InternalError(
1476         "The %s expects the shape of operand and output to match (%s vs %s).",
1477         HloOpcodeString(async_update->opcode()),
1478         async_update->operand(0)->shape().ToString(),
1479         async_update->shape().ToString());
1480   }
1481   TF_RETURN_IF_ERROR(
1482       CheckAsyncOpComputationShapes(async_update, async_update->shape()));
1483   return CheckAsyncOpOperand(async_update);
1484 }
1485 
HandleAsyncDone(HloInstruction * async_done)1486 Status ShapeVerifier::HandleAsyncDone(HloInstruction* async_done) {
1487   TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_done));
1488   TF_RETURN_IF_ERROR(CheckAsyncOpComputationShapes(
1489       async_done, async_done->operand(0)->shape()));
1490   const Shape& root_shape = async_done->operand(0)->shape().tuple_shapes(1);
1491   if (root_shape != async_done->shape()) {
1492     return InternalError(
1493         "The %s expects the shape of output to match the async shape at index "
1494         "{1} (%s vs %s).",
1495         HloOpcodeString(async_done->opcode()), async_done->shape().ToString(),
1496         root_shape.ToString());
1497   }
1498   return CheckAsyncOpOperand(async_done);
1499 }
1500 
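// copy-start produces a (destination, source, context) tuple in which the
// first two elements share the operand's shape and the context is a U32
// scalar; layouts are compared on minor-to-major order only.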
HandleCopyStart(HloInstruction * copy_start)1501 Status ShapeVerifier::HandleCopyStart(HloInstruction* copy_start) {
1502   return CheckShape(copy_start,
1503                     ShapeUtil::MakeTupleShape({copy_start->operand(0)->shape(),
1504                                                copy_start->operand(0)->shape(),
1505                                                ShapeUtil::MakeShape(U32, {})}),
1506                     /*only_compare_minor_to_major_in_layout=*/true);
1507 }
1508 
HandleCopyDone(HloInstruction * copy_done)1509 Status ShapeVerifier::HandleCopyDone(HloInstruction* copy_done) {
1510   const Shape& operand_shape = copy_done->operand(0)->shape();
1511   const Shape& dest_shape = ShapeUtil::GetTupleElementShape(operand_shape, 0);
1512   const Shape& src_shape = ShapeUtil::GetTupleElementShape(operand_shape, 1);
1513   if (!ShapesSame(dest_shape, src_shape,
1514                   /*minor_to_major_only=*/false,
1515                   /*ignore_memory_space=*/true)) {
1516     return InternalError(
1517         "Source and destination buffers in CopyDone arguments need to be the "
1518         "same shape; found %s and %s\n%s",
1519         StringifyShape(dest_shape), StringifyShape(src_shape),
1520         copy_done->ToString());
1521   }
1522   return CheckShape(copy_done, ShapeUtil::GetTupleElementShape(
1523                                    copy_done->operand(0)->shape(), 0));
1524 }
1525 
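// Send/recv shape conventions checked below: send produces
// (operand, U32 context, token) and send-done reduces it to a token; recv
// produces (data, U32 context, token) and recv-done produces (data, token).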
HandleSend(HloInstruction * send)1526 Status ShapeVerifier::HandleSend(HloInstruction* send) {
1527   return CheckShape(send,
1528                     ShapeUtil::MakeTupleShape({send->operand(0)->shape(),
1529                                                ShapeUtil::MakeShape(U32, {}),
1530                                                ShapeUtil::MakeTokenShape()}),
1531                     /*only_compare_minor_to_major_in_layout=*/true);
1532 }
1533 
HandleSendDone(HloInstruction * send_done)1534 Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
1535   return CheckShape(send_done, ShapeUtil::MakeTokenShape());
1536 }
1537 
HandleRecv(HloInstruction * recv)1538 Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
1539   return CheckShape(
1540       recv,
1541       ShapeUtil::MakeTupleShape(
1542           {ShapeUtil::GetTupleElementShape(recv->shape(), 0),
1543            ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}),
1544       /*only_compare_minor_to_major_in_layout=*/true);
1545 }
1546 
HandleRecvDone(HloInstruction * recv_done)1547 Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
1548   return CheckShape(
1549       recv_done,
1550       ShapeUtil::MakeTupleShape(
1551           {ShapeUtil::GetTupleElementShape(recv_done->operand(0)->shape(), 0),
1552            ShapeUtil::MakeTokenShape()}));
1553 }
1554 
HandleBatchNormTraining(HloInstruction * batch_norm_training)1555 Status ShapeVerifier::HandleBatchNormTraining(
1556     HloInstruction* batch_norm_training) {
1557   return CheckShape(batch_norm_training,
1558                     ShapeInference::InferBatchNormTrainingShape(
1559                         batch_norm_training->operand(0)->shape(),
1560                         batch_norm_training->operand(1)->shape(),
1561                         batch_norm_training->operand(2)->shape(),
1562                         batch_norm_training->feature_index()));
1563 }
1564 
HandleBatchNormInference(HloInstruction * batch_norm_inference)1565 Status ShapeVerifier::HandleBatchNormInference(
1566     HloInstruction* batch_norm_inference) {
1567   return CheckShape(batch_norm_inference,
1568                     ShapeInference::InferBatchNormInferenceShape(
1569                         batch_norm_inference->operand(0)->shape(),
1570                         batch_norm_inference->operand(1)->shape(),
1571                         batch_norm_inference->operand(2)->shape(),
1572                         batch_norm_inference->operand(3)->shape(),
1573                         batch_norm_inference->operand(4)->shape(),
1574                         batch_norm_inference->feature_index()));
1575 }
1576 
HandleBatchNormGrad(HloInstruction * batch_norm_grad)1577 Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
1578   return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape(
1579                                          batch_norm_grad->operand(0)->shape(),
1580                                          batch_norm_grad->operand(1)->shape(),
1581                                          batch_norm_grad->operand(2)->shape(),
1582                                          batch_norm_grad->operand(3)->shape(),
1583                                          batch_norm_grad->operand(4)->shape(),
1584                                          batch_norm_grad->feature_index()));
1585 }
1586 
1587 namespace {
1588 
1589 // Checks that the instruction does not have mixed precision floating point
1590 // inputs.
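// For example, an add whose operands are f32[4] and bf16[4] is rejected here
// because it mixes floating-point precisions, whereas a tuple of f32[4] and
// bf16[4] is accepted because kTuple is allow-listed below.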
CheckMixedPrecisionOperands(const HloInstruction * instruction)1591 Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
1592   switch (instruction->opcode()) {
1593     // Allow-list the following opcodes for mixed-precision check, because
1594     // they involve data pass through or grouping via tuples, where the
1595     // precisions of buffers can be different.
1596     case HloOpcode::kCall:
1597     case HloOpcode::kConditional:
1598     case HloOpcode::kConstant:
1599     case HloOpcode::kConvolution:
1600     case HloOpcode::kDot:
1601     case HloOpcode::kAllReduce:
1602     case HloOpcode::kAllReduceStart:
1603     case HloOpcode::kAllReduceDone:
1604     case HloOpcode::kAsyncDone:
1605     case HloOpcode::kAsyncUpdate:
1606     case HloOpcode::kAsyncStart:
1607     case HloOpcode::kCopyDone:
1608     case HloOpcode::kCopyStart:
1609     case HloOpcode::kCustomCall:
1610     case HloOpcode::kDomain:
1611     case HloOpcode::kFusion:
1612     case HloOpcode::kGetTupleElement:
1613     case HloOpcode::kOptimizationBarrier:
1614     case HloOpcode::kInfeed:
1615     case HloOpcode::kOutfeed:
1616     case HloOpcode::kParameter:
1617     case HloOpcode::kRecv:
1618     case HloOpcode::kRecvDone:
1619     case HloOpcode::kReducePrecision:
1620     case HloOpcode::kReduceWindow:
1621     case HloOpcode::kSend:
1622     case HloOpcode::kSendDone:
1623     case HloOpcode::kSort:
1624     case HloOpcode::kTuple:
1625     case HloOpcode::kWhile:
1626       break;
1627     default: {
1628       PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID;
1629       for (auto operand : instruction->operands()) {
1630         TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
1631             operand->shape(),
1632             [&](const Shape& subshape, const ShapeIndex& index) {
1633               if (!ShapeUtil::ElementIsFloating(subshape)) {
1634                 return OkStatus();
1635               }
1636               if (fp_type == PRIMITIVE_TYPE_INVALID) {
1637                 fp_type = subshape.element_type();
1638               } else if (fp_type != subshape.element_type()) {
1639                 return InternalError(
1640                     "Seen floating point types of different precisions in "
1641                     "%s, but mixed precision is disallowed.",
1642                     instruction->ToString());
1643               }
1644               return OkStatus();
1645             }));
1646       }
1647     }
1648   }
1649   return OkStatus();
1650 }
1651 
1652 }  // namespace
1653 
HandleGather(HloInstruction * gather)1654 Status ShapeVerifier::HandleGather(HloInstruction* gather) {
1655   return CheckShape(
1656       gather,
1657       ShapeInference::InferGatherShape(
1658           gather->operand(0)->shape(), gather->operand(1)->shape(),
1659           gather->gather_dimension_numbers(), gather->gather_slice_sizes()));
1660 }
1661 
HandleScatter(HloInstruction * scatter)1662 Status ShapeVerifier::HandleScatter(HloInstruction* scatter) {
1663   absl::InlinedVector<const Shape*, 3> arg_shapes;
1664   arg_shapes.reserve(scatter->operand_count());
1665   for (const HloInstruction* operand : scatter->operands()) {
1666     arg_shapes.push_back(&operand->shape());
1667   }
1668   return CheckShape(scatter,
1669                     ShapeInference::InferScatterShape(
1670                         arg_shapes, scatter->to_apply()->ComputeProgramShape(),
1671                         scatter->scatter_dimension_numbers()));
1672 }
1673 
HandleAfterAll(HloInstruction * token)1674 Status ShapeVerifier::HandleAfterAll(HloInstruction* token) {
1675   std::vector<const Shape*> operand_shapes;
1676   for (const HloInstruction* operand : token->operands()) {
1677     operand_shapes.push_back(&operand->shape());
1678   }
1679   return CheckShape(token, ShapeUtil::MakeTokenShape());
1680 }
1681 
HandleAddDependency(HloInstruction * add_dependency)1682 Status ShapeVerifier::HandleAddDependency(HloInstruction* add_dependency) {
1683   TF_RETURN_IF_ERROR(CheckIsTokenOperand(add_dependency, 1));
1684   return CheckShape(add_dependency, add_dependency->operand(0)->shape());
1685 }
1686 
HandleGetDimensionSize(HloInstruction * get_size)1687 Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) {
1688   return CheckShape(get_size,
1689                     ShapeInference::InferGetDimensionSizeShape(
1690                         get_size->operand(0)->shape(), get_size->dimension()));
1691 }
1692 
HandleSetDimensionSize(HloInstruction * set_size)1693 Status ShapeVerifier::HandleSetDimensionSize(HloInstruction* set_size) {
1694   return CheckShape(set_size,
1695                     ShapeInference::InferSetDimensionSizeShape(
1696                         set_size->operand(0)->shape(),
1697                         set_size->operand(1)->shape(), set_size->dimension()));
1698 }
1699 
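// Compares the instruction's actual shape against the shape inferred from its
// operands. The opcodes listed below may not change layouts or element
// precision; all other opcodes only need compatible shapes, optionally
// ignoring BF16/F32 precision differences when mixed precision is allowed.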
CheckShape(const HloInstruction * instruction,const Shape & inferred_shape,bool only_compare_minor_to_major_in_layout)1700 Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
1701                                  const Shape& inferred_shape,
1702                                  bool only_compare_minor_to_major_in_layout) {
1703   // If allow_mixed_precision_ is false, check if there are operands with
1704   // different precisions. We need this check because ShapeInference allows
1705   // mixed precision inputs.
1706   if (!opts_.allow_mixed_precision) {
1707     TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction));
1708   }
1709 
1710   // Check if the output shape matches the expected shape.
1711   //
1712   // We treat BF16 and F32 as compatible types if mixed precision is allowed,
1713   // but only when the instruction defines the BF16/F32 buffer.
1714   bool equal = [&] {
1715     switch (instruction->opcode()) {
1716       // The opcodes below can't have implicit layout conversions, nor can they
1717       // implicitly transform f32 -> bf16.  Fundamentally these are either
1718       // reinterpreting existing data (e.g. kBitcast) or shuffling data around
1719       // without modifying it (e.g. kGetTupleElement).
1720       case HloOpcode::kBitcast:
1721       case HloOpcode::kCall:
1722       case HloOpcode::kConditional:
1723       case HloOpcode::kConstant:
1724       case HloOpcode::kCopyDone:
1725       case HloOpcode::kCopyStart:
1726       case HloOpcode::kCustomCall:
1727       case HloOpcode::kGetTupleElement:
1728       case HloOpcode::kInfeed:
1729       case HloOpcode::kOutfeed:
1730       case HloOpcode::kOptimizationBarrier:
1731       case HloOpcode::kParameter:
1732       case HloOpcode::kRecv:
1733       case HloOpcode::kRecvDone:
1734       case HloOpcode::kSend:
1735       case HloOpcode::kSendDone:
1736       case HloOpcode::kTuple:
1737       case HloOpcode::kWhile:
1738         return ShapesSame(instruction->shape(), inferred_shape,
1739                           only_compare_minor_to_major_in_layout);
1740       case HloOpcode::kDynamicUpdateSlice:
1741         // DynamicUpdateSlice has "in-place" update semantics, but inside
1742         // fusions memory space propagation doesn't propagate the memory
1743         // spaces all the way, which can cause mismatches. Relax the
1744         // constraint in that case.
1745         return ShapesSame(instruction->shape(), inferred_shape,
1746                           only_compare_minor_to_major_in_layout,
1747                           /*ignore_memory_space=*/
1748                           instruction->parent()->IsFusionComputation());
1749 
1750       // We allow arbitrary layout and f32->bf16 transformations on all other
1751       // instructions, although this may be made more strict pending discussion
1752       // in b/112709536.
1753       default:
1754         if (opts_.allow_mixed_precision) {
1755           return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(),
1756                                                           inferred_shape);
1757         } else {
1758           return ShapeUtil::Compatible(instruction->shape(), inferred_shape);
1759         }
1760     }
1761   }();
1762   if (!equal) {
1763     return InternalError(
1764         "Expected instruction to have shape equal to %s, actual "
1765         "shape is %s:\n%s",
1766         StringifyShape(inferred_shape), StringifyShape(instruction->shape()),
1767         instruction->ToString());
1768   }
1769   return OkStatus();
1770 }
1771 
CheckShape(const HloInstruction * instruction,const StatusOr<Shape> & inferred_shape_status)1772 Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
1773                                  const StatusOr<Shape>& inferred_shape_status) {
1774   if (!inferred_shape_status.ok()) {
1775     Status s = inferred_shape_status.status();
1776     tensorflow::errors::AppendToMessage(&s, ", for instruction ",
1777                                         instruction->ToString());
1778     return s;
1779   }
1780   return CheckShape(instruction, inferred_shape_status.ValueOrDie());
1781 }
1782 
CheckUnaryShape(const HloInstruction * instruction)1783 Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) {
1784   return CheckShape(instruction,
1785                     ShapeInference::InferUnaryOpShape(instruction->opcode(),
1786                                                       instruction->operand(0)));
1787 }
1788 
CheckBinaryShape(const HloInstruction * instruction)1789 Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) {
1790   return CheckShape(
1791       instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(),
1792                                                       instruction->operand(0),
1793                                                       instruction->operand(1)));
1794 }
1795 
CheckTernaryShape(const HloInstruction * instruction)1796 Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) {
1797   return CheckShape(instruction,
1798                     ShapeInference::InferTernaryOpShape(
1799                         instruction->opcode(), instruction->operand(0),
1800                         instruction->operand(1), instruction->operand(2)));
1801 }
1802 
CheckVariadicShape(const HloInstruction * instruction)1803 Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
1804   return CheckShape(instruction,
1805                     ShapeInference::InferVariadicOpShape(
1806                         instruction->opcode(), instruction->operands()));
1807 }
1808 
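// Checks that the entry computation's parameter and root shapes are compatible
// with the shapes recorded in the module's entry computation layout, and that
// the parameter counts agree.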
VerifyEntryComputationLayout(const HloModule & module)1809 Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) {
1810   const HloComputation* computation = module.entry_computation();
1811   const auto& layout = module.entry_computation_layout();
1812   const ShapeLayout& result_layout = layout.result_layout();
1813 
1814   TF_RETURN_IF_ERROR(
1815       ShapeUtil::ValidateShapeWithOptionalLayout(result_layout.shape()));
1816 
1817   if (!ShapeUtil::Compatible(computation->root_instruction()->shape(),
1818                              result_layout.shape())) {
1819     return InternalError(
1820         "Shape of the root instruction of the entry computation (%s) should be "
1821         "compatible with the one specified in the module's entry computation layout (%s)",
1822         ShapeUtil::HumanString(computation->root_instruction()->shape()),
1823         ShapeUtil::HumanString(result_layout.shape()));
1824   }
1825 
1826   if (computation->num_parameters() != layout.parameter_count()) {
1827     return InternalError(
1828         "Number of parameters in entry computation layout (%d) must be the "
1829         "same as the number of parameters of the entry computation (%d)",
1830         layout.parameter_count(), computation->num_parameters());
1831   }
1832 
1833   for (int i = 0; i < computation->num_parameters(); ++i) {
1834     const HloInstruction* parameter = computation->parameter_instruction(i);
1835     TF_RETURN_IF_ERROR(
1836         ShapeUtil::ValidateShapeWithOptionalLayout(layout.parameter_shape(i)));
1837     if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) {
1838       return InternalError(
1839           "Shape of the entry computation parameter %d is %s, but it should "
1840           "be compatible with the one specified in the module's entry "
1841           "computation layout %s",
1842           i, ShapeUtil::HumanString(parameter->shape()),
1843           ShapeUtil::HumanString(layout.parameter_shape(i)));
1844     }
1845   }
1846 
1847   return OkStatus();
1848 }
1849 
ComputationsToString(absl::Span<HloComputation * const> computations)1850 std::string ComputationsToString(
1851     absl::Span<HloComputation* const> computations) {
1852   return absl::StrJoin(computations, ",",
1853                        [](std::string* s, const HloComputation* computation) {
1854                          s->append(computation->name());
1855                        });
1856 }
1857 
1858 // Verifies various invariants about the structure of the HLO:
1859 //
1860 // (1) each instruction has a non-null parent() set to the HloComputation
1861 // which
1862 //     contains it.
1863 //
1864 // (2) each computation has a non-null parent() set to the HloModule which
1865 //     contains it.
1866 //
1867 // (3) the operands of each instruction are in the same computation as the
1868 //     instruction.
VerifyHloStructure(HloModule * module)1869 Status VerifyHloStructure(HloModule* module) {
1870   for (const HloComputation* computation : module->computations()) {
1871     if (computation->parent() == nullptr) {
1872       return InternalError("Computation %s has a null parent pointer",
1873                            computation->name());
1874     }
1875     if (computation->parent() != module) {
1876       return InternalError(
1877           "Computation %s parent() does not point to parent module",
1878           computation->name());
1879     }
1880 
1881     for (const HloInstruction* instruction : computation->instructions()) {
1882       if (instruction->parent() == nullptr) {
1883         return InternalError("Instruction %s has a null parent pointer",
1884                              instruction->name());
1885       }
1886       if (instruction->parent() != computation) {
1887         return InternalError(
1888             "Instruction %s parent() does not point to parent computation",
1889             instruction->name());
1890       }
1891     }
1892   }
1893 
1894   // Check that operands are in the same computation separately from verifying
1895   // parent() correctness so conditions like a null HloInstruction::parent()
1896   // are identified and reported explicitly above rather than reporting a
1897   // mismatched operand.
1898   for (const HloComputation* computation : module->computations()) {
1899     for (const HloInstruction* instruction : computation->instructions()) {
1900       for (int i = 0; i < instruction->operand_count(); ++i) {
1901         const HloInstruction* operand = instruction->operand(i);
1902         if (operand->parent() != instruction->parent()) {
1903           return InternalError(
1904               "Operand %d (%s) of instruction %s is in a different "
1905               "computation: %s vs %s",
1906               i, operand->name(), instruction->name(),
1907               operand->parent() ? operand->parent()->name() : "(null)",
1908               instruction->parent()->name());
1909         }
1910       }
1911     }
1912   }
1913   return OkStatus();
1914 }
1915 
1916 namespace {
1917 
1918 // Returns true if the given Shape has a TOKEN shape as any subshape.
ShapeContainsToken(const Shape & shape)1919 bool ShapeContainsToken(const Shape& shape) {
1920   bool contains_token = false;
1921   ShapeUtil::ForEachSubshape(
1922       shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
1923         if (subshape.IsToken()) {
1924           contains_token = true;
1925         }
1926       });
1927   return contains_token;
1928 }
1929 
1930 // Verifies that all types entering and exiting the entry computation are
1931 // legal.
VerifyEntryAndExitShapes(const HloModule & module)1932 Status VerifyEntryAndExitShapes(const HloModule& module) {
1933   // Tokens cannot be passed as entry parameters.
1934   // TODO(b/80000000): Remove this constraint.
1935   for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
1936     HloInstruction* param =
1937         module.entry_computation()->parameter_instruction(i);
1938     if (ShapeContainsToken(param->shape())) {
1939       return InternalError(
1940           "Entry parameter %d is or contains a token shape: %s", i,
1941           ShapeUtil::HumanString(param->shape()));
1942     }
1943   }
1944   return OkStatus();
1945 }
1946 
1947 // Checks if the given two instructions share the same channel id.
CheckSameChannel(const HloInstruction * instr1,const HloInstruction * instr2)1948 Status CheckSameChannel(const HloInstruction* instr1,
1949                         const HloInstruction* instr2) {
1950   if (instr1->channel_id() != instr2->channel_id()) {
1951     return InternalError(
1952         "Expected to have the same channel id, actual channel ids are: %s "
1953         "(%d), %s (%d)",
1954         instr1->ToString(), *instr1->channel_id(), instr2->ToString(),
1955         *instr2->channel_id());
1956   }
1957   return OkStatus();
1958 }
1959 
1960 // Checks if the given two instructions have the same is_host_transfer
1961 // attribute value. Instructions must be send/recv instructions or their
1962 // 'done' variant.
CheckSameIsHostTransfer(const HloInstruction * instr1,const HloInstruction * instr2)1963 Status CheckSameIsHostTransfer(const HloInstruction* instr1,
1964                                const HloInstruction* instr2) {
1965   const HloSendRecvInstruction* send_recv1 =
1966       DynCast<const HloSendRecvInstruction>(instr1);
1967   const HloSendRecvInstruction* send_recv2 =
1968       DynCast<const HloSendRecvInstruction>(instr2);
1969   TF_RET_CHECK(send_recv1 != nullptr);
1970   TF_RET_CHECK(send_recv2 != nullptr);
1971   if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
1972     return InternalError(
1973         "Expected instructions to have the same is-host-transfer property: "
1974         "%s, "
1975         "%s",
1976         instr1->ToString(), instr2->ToString());
1977   }
1978   return OkStatus();
1979 }
1980 
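// Checks that `instruction` has exactly one user and that the user's opcode is
// one of `expected_users`.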
VerifySingleUser(const HloInstruction * instruction,const absl::flat_hash_set<HloOpcode> & expected_users)1981 Status VerifySingleUser(const HloInstruction* instruction,
1982                         const absl::flat_hash_set<HloOpcode>& expected_users) {
1983   TF_RET_CHECK(instruction->users().size() == 1)
1984       << "The " << HloOpcodeString(instruction->opcode())
1985       << " instruction requires one consumer, found "
1986       << instruction->users().size();
1987 
1988   const HloInstruction* user = instruction->users().front();
1989   TF_RET_CHECK(expected_users.contains(user->opcode()))
1990       << "The consumer of a " << HloOpcodeString(instruction->opcode())
1991       << " instruction needs to be one of ("
1992       << absl::StrJoin(expected_users, ", ",
1993                        [](std::string* out, HloOpcode opcode) {
1994                          out->append(HloOpcodeString(opcode));
1995                        })
1996       << "), found " << HloOpcodeString(user->opcode());
1997   return OkStatus();
1998 }
1999 
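// Checks that `instruction` has exactly one operand and that the operand's
// opcode is one of `expected_operands`.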
VerifySingleOperand(const HloInstruction * instruction,const std::vector<HloOpcode> & expected_operands)2000 Status VerifySingleOperand(const HloInstruction* instruction,
2001                            const std::vector<HloOpcode>& expected_operands) {
2002   TF_RET_CHECK(instruction->operands().size() == 1)
2003       << "The " << HloOpcodeString(instruction->opcode())
2004       << " instruction requires one operand, found "
2005       << instruction->operands().size();
2006 
2007   const HloInstruction* operand = instruction->operand(0);
2008   TF_RET_CHECK(absl::c_find(expected_operands, operand->opcode()) !=
2009                expected_operands.end())
2010       << "The operand of a " << HloOpcodeString(instruction->opcode())
2011       << " instruction needs to be "
2012       << absl::StrJoin(expected_operands, " or ",
2013                        [](std::string* out, HloOpcode opcode) {
2014                          out->append(HloOpcodeString(opcode));
2015                        })
2016       << ", found " << HloOpcodeString(operand->opcode());
2017   return OkStatus();
2018 }
2019 
2020 // Checks asynchronous instruction pairs.
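// The pairs verified below are async-start/-update/-done,
// all-reduce-start/-done, copy-start/-done, and
// collective-permute-start/-done.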
VerifyAsynchronousInstructionPairs(const HloModule & module)2021 Status VerifyAsynchronousInstructionPairs(const HloModule& module) {
2022   // Each *-start/*-update op must be consumed by its matching *-update/*-done op.
2023   for (const HloComputation* computation : module.computations()) {
2024     for (const HloInstruction* instruction : computation->instructions()) {
2025       switch (instruction->opcode()) {
2026         case HloOpcode::kAsyncStart: {
2027           TF_RETURN_IF_ERROR(VerifySingleUser(
2028               instruction, {HloOpcode::kAsyncUpdate, HloOpcode::kAsyncDone}));
2029           break;
2030         }
2031         case HloOpcode::kAsyncUpdate: {
2032           TF_RETURN_IF_ERROR(VerifySingleOperand(
2033               instruction, {HloOpcode::kAsyncStart, HloOpcode::kAsyncUpdate}));
2034           TF_RETURN_IF_ERROR(VerifySingleUser(
2035               instruction, {HloOpcode::kAsyncUpdate, HloOpcode::kAsyncDone}));
2036           break;
2037         }
2038         case HloOpcode::kAsyncDone: {
2039           TF_RETURN_IF_ERROR(VerifySingleOperand(
2040               instruction, {HloOpcode::kAsyncStart, HloOpcode::kAsyncUpdate}));
2041           break;
2042         }
2043         case HloOpcode::kAllReduceStart: {
2044           TF_RETURN_IF_ERROR(
2045               VerifySingleUser(instruction, {HloOpcode::kAllReduceDone}));
2046           break;
2047         }
2048         case HloOpcode::kAllReduceDone: {
2049           TF_RETURN_IF_ERROR(
2050               VerifySingleOperand(instruction, {HloOpcode::kAllReduceStart}));
2051           break;
2052         }
2053         case HloOpcode::kCopyStart: {
2054           TF_RETURN_IF_ERROR(
2055               VerifySingleUser(instruction, {HloOpcode::kCopyDone}));
2056           break;
2057         }
2058         case HloOpcode::kCopyDone: {
2059           TF_RETURN_IF_ERROR(
2060               VerifySingleOperand(instruction, {HloOpcode::kCopyStart}));
2061           break;
2062         }
2063         case HloOpcode::kCollectivePermuteStart: {
2064           TF_RETURN_IF_ERROR(VerifySingleUser(
2065               instruction, {HloOpcode::kCollectivePermuteDone}));
2066           break;
2067         }
2068         case HloOpcode::kCollectivePermuteDone: {
2069           TF_RETURN_IF_ERROR(VerifySingleOperand(
2070               instruction, {HloOpcode::kCollectivePermuteStart}));
2071           break;
2072         }
2073         default:
2074           break;
2075       }
2076     }
2077   }
2078   return OkStatus();
2079 }
2080 
2081 // Checks that AllReduce instructions in the module are either all layout
2082 // constrained or all unconstrained.
VerifyLayoutConstrainedAllReduce(const HloModule & module)2083 Status VerifyLayoutConstrainedAllReduce(const HloModule& module) {
2084   const HloAllReduceInstruction* reference = nullptr;
2085   for (const HloComputation* computation : module.computations()) {
2086     for (const HloInstruction* instruction : computation->instructions()) {
2087       if ((instruction->opcode() != HloOpcode::kAllReduce) &&
2088           (instruction->opcode() != HloOpcode::kAllReduceStart)) {
2089         continue;
2090       }
2091       auto all_reduce = DynCast<HloAllReduceInstruction>(instruction);
2092       if (!reference) {
2093         reference = all_reduce;
2094       }
2095       if (reference->constrain_layout() != all_reduce->constrain_layout()) {
2096         return FailedPrecondition(
2097             "HloModule has a mix of layout constrained and unconstrained "
2098             "AllReduce instructions.");
2099       }
2100     }
2101   }
2102   return OkStatus();
2103 }
2104 
2105 // Checks various invariants of channel instructions (send/recv and
2106 // collectives).
VerifyChannels(const HloModule & module)2107 Status VerifyChannels(const HloModule& module) {
2108   absl::flat_hash_map<int64_t, std::vector<const HloInstruction*>>
2109       channel_instructions;
2110 
2111   // Each Send/Recv instruction must have a single user: the corresponding
2112   // SendDone/RecvDone with a matching channel.
2113   for (const HloComputation* computation : module.computations()) {
2114     for (const HloInstruction* instruction : computation->instructions()) {
2115       auto channel_instr = DynCast<HloChannelInstruction>(instruction);
2116       if (!channel_instr || !channel_instr->channel_id()) {
2117         continue;
2118       }
2119       channel_instructions[*channel_instr->channel_id()].push_back(instruction);
2120 
2121       switch (instruction->opcode()) {
2122         case HloOpcode::kSend: {
2123           TF_RET_CHECK(instruction->users().size() == 1);
2124           const HloInstruction* send_done = instruction->users().front();
2125           TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
2126           TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done));
2127           TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done));
2128           break;
2129         }
2130         case HloOpcode::kRecv: {
2131           TF_RET_CHECK(instruction->users().size() == 1);
2132           const HloInstruction* recv_done = instruction->users().front();
2133           TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
2134           TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done));
2135           TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done));
2136           break;
2137         }
2138         case HloOpcode::kSendDone:
2139           TF_RET_CHECK(instruction->operands().size() == 1);
2140           TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend);
2141           break;
2142         case HloOpcode::kRecvDone:
2143           TF_RET_CHECK(instruction->operands().size() == 1);
2144           TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv);
2145           break;
2146         default:
2147           break;
2148       }
2149     }
2150   }
2151 
2152   // Iterate over each channel to check invariants.
2153   for (auto& pair : channel_instructions) {
2154     auto& instructions = pair.second;
2155     const HloInstruction* first = instructions[0];
2156     auto sendrecv = DynCast<HloSendRecvInstruction>(first);
2157     if (sendrecv) {
2158       absl::flat_hash_set<HloOpcode> opcodes;
2159       for (const HloInstruction* instr : instructions) {
2160         opcodes.insert(instr->opcode());
2161         auto cast = DynCast<HloSendRecvInstruction>(instr);
2162         TF_RET_CHECK(cast != nullptr)
2163             << "channel " << pair.first
2164             << " is used for different types of channel instructions";
2165       }
2166       if (sendrecv->is_host_transfer()) {
2167         TF_RET_CHECK(instructions.size() == 2)
2168             << "channel " << pair.first
2169             << " is used for multiple host send/recv instructions";
2170       } else {
2171         TF_RET_CHECK(instructions.size() == opcodes.size())
2172             << "channel " << pair.first
2173             << " is used for multiple send/recv instructions";
2174       }
2175     } else {
2176       for (const HloInstruction* instr : instructions) {
2177         TF_RET_CHECK(first->opcode() == instr->opcode())
2178             << "channel " << pair.first
2179             << " is used for different types of channel instructions";
2180       }
2181     }
2182   }
2183 
2184   return OkStatus();
2185 }
2186 
2187 // CHECKs various invariants of a fusion instruction.
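// In particular: the fused computation's FusionInstruction() must be this
// fusion, the fused root and parameters must be owned by the fused
// computation, the fused root must have no users, every non-root fused
// instruction must have at least one user and only users inside the fusion,
// and fused parameter numbers must be contiguous and correspond one-to-one to
// the fusion's operands.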
CheckFusionInstruction(HloInstruction * fusion)2188 Status CheckFusionInstruction(HloInstruction* fusion) {
2189   // The parent fusion instruction of the fusion computation must be 'fusion'.
2190   HloComputation* fused_computation = fusion->fused_instructions_computation();
2191   if (fusion != fused_computation->FusionInstruction()) {
2192     return InternalError(
2193         "The fused computation's fusion instruction does not match the "
2194         "expected fusion instruction "
2195         "%s.",
2196         fusion->ToString());
2197   }
2198 
2199   // Fused root instruction and fused parameters must all be owned by the
2200   // fusion computation.
2201   bool root_owned = false;
2202   const std::vector<HloInstruction*>& fused_parameters =
2203       fusion->fused_parameters();
2204   const HloInstruction* fused_root = fusion->fused_expression_root();
2205   std::vector<bool> parameter_owned(fused_parameters.size(), false);
2206   for (auto* instruction : fused_computation->instructions()) {
2207     if (fused_root == instruction) {
2208       if (root_owned) {
2209         return InternalError("Root appears more than once in %s.",
2210                              fusion->ToString());
2211       }
2212       root_owned = true;
2213     }
2214     for (int i = 0; i < fused_parameters.size(); ++i) {
2215       if (fused_parameters[i] == instruction) {
2216         if (parameter_owned[i]) {
2217           return InternalError("Parameter appears more than once in %s.",
2218                                fusion->ToString());
2219         }
2220         parameter_owned[i] = true;
2221       }
2222     }
2223   }
2224   if (!root_owned) {
2225     return InternalError("Root not found in computation of %s.",
2226                          fusion->ToString());
2227   }
2228   // Make sure all the parameter_owned entries are set
2229   for (int i = 0; i < parameter_owned.size(); i++) {
2230     if (!parameter_owned[i]) {
2231       return InternalError("Parameter %d not found in computation of %s.", i,
2232                            fusion->ToString());
2233     }
2234   }
2235 
2236   // Fused root must have no users.
2237   if (fused_root->user_count() != 0) {
2238     return InternalError("Root of %s may not have users.", fusion->ToString());
2239   }
2240 
2241   // All uses of fused instructions must be in the fusion computation, and
2242   // every non-root instruction must have at least one use.
2243   for (auto* instruction :
2244        fusion->fused_instructions_computation()->instructions()) {
2245     if (instruction != fused_root) {
2246       if (instruction->user_count() == 0) {
2247         return InternalError("Non-root instruction %s in %s must have users.",
2248                              instruction->ToString(), fusion->ToString());
2249       }
2250       for (auto& user : instruction->users()) {
2251         if (fused_computation != user->parent()) {
2252           return InternalError(
2253               "Non-root instruction %s in %s may not have external users.",
2254               instruction->ToString(), fusion->ToString());
2255         }
2256       }
2257     }
2258   }
2259 
2260   // Fused parameter instructions must be numbered contiguously and match up
2261   // (shapes equal) with their respective operand.
2262   CHECK_EQ(fusion->operands().size(), fused_parameters.size());
2263   std::vector<bool> parameter_numbers(fused_parameters.size(), false);
2264   for (auto fused_param : fused_parameters) {
2265     int64_t param_no = fused_param->parameter_number();
2266     if (param_no < 0) {
2267       return InternalError("Unexpected negative parameter number %d in %s.",
2268                            param_no, fusion->ToString());
2269     }
2270     if (param_no >= fused_parameters.size()) {
2271       return InternalError(
2272           "Unexpected parameter number %d in %s: higher than the number of "
2273           "parameters %lu.",
2274           param_no, fusion->ToString(), fused_parameters.size());
2275     }
2276     if (parameter_numbers[param_no]) {
2277       return InternalError(
2278           "Did not expect parameter number %d more than once in %s.", param_no,
2279           fusion->ToString());
2280     }
2281     parameter_numbers[param_no] = true;
2282   }
2283   // Make sure all the parameter_numbers entries were seen.
2284   for (int i = 0; i < parameter_numbers.size(); i++) {
2285     if (!parameter_numbers[i]) {
2286       return InternalError("Did not see parameter number %d in %s.", i,
2287                            fusion->ToString());
2288     }
2289   }
2290 
2291   TF_RET_CHECK(fusion->called_computations() ==
2292                absl::Span<HloComputation* const>(
2293                    {fusion->fused_instructions_computation()}))
2294       << "Fusion HLO calls computations other than the "
2295          "fused_instructions_computation: "
2296       << fusion->ToString() << " fusion->fused_instructions_computation(): "
2297       << fusion->fused_instructions_computation()->ToString()
2298       << " fusion->called_computations(): "
2299       << ComputationsToString(fusion->called_computations());
2300 
2301   for (const auto& fused : fusion->fused_instructions()) {
2302     TF_RET_CHECK(fused->parent() == fusion->fused_instructions_computation())
2303         << "Fused HLO was missing a parent: " << fused->ToString()
2304         << " parent: " << fused->parent()
2305         << " computation: " << fusion->parent();
2306   }
2307 
2308   // TODO(b/65423525): We'd like to check that all operands are distinct.
2309   // This is currently disabled due to the invariant being violated by
2310   // multi-output fusion.
2311   return OkStatus();
2312 }
2313 
2314 // Checks that the operand shapes are compatible with the output shape, i.e.,
2315 // that there are no implicit broadcasts.
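// For example, add(f32[2,3] a, f32[3] b) would need an implicit broadcast of
// `b` and is rejected; the operand must first be expanded to f32[2,3] with an
// explicit broadcast.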
CheckElementwiseInstruction(HloInstruction * instruction)2316 Status CheckElementwiseInstruction(HloInstruction* instruction) {
2317   const Shape& out_shape = instruction->shape();
2318   for (HloInstruction* operand : instruction->operands()) {
2319     const Shape& operand_shape = operand->shape();
2320     if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
2321       return FailedPrecondition(
2322           "Implicit broadcast is not allowed in HLO. "
2323           "Found different shapes for instruction %s.\n"
2324           "output: %s\noperand: %s\n",
2325           HloOpcodeString(instruction->opcode()),
2326           ShapeUtil::HumanString(out_shape),
2327           ShapeUtil::HumanString(operand_shape));
2328     }
2329   }
2330   if (auto* comparison = DynCast<HloCompareInstruction>(instruction)) {
2331     const Shape& operand_shape = comparison->operand(1)->shape();
2332     PrimitiveType operand_element_type = operand_shape.element_type();
2333     Comparison::Type default_comparison_type =
2334         Comparison::DefaultComparisonType(operand_element_type);
    if (primitive_util::IsFloatingPointType(operand_element_type)) {
      if (comparison->type() != Comparison::Type::kFloat &&
          comparison->type() != Comparison::Type::kFloatTotalOrder) {
        return FailedPrecondition(
            "Expected comparison type %s or %s.\n"
            "actual: %s\noperand: %s\n",
            ComparisonTypeToString(Comparison::Type::kFloat),
            ComparisonTypeToString(Comparison::Type::kFloatTotalOrder),
            ComparisonTypeToString(comparison->type()),
            ShapeUtil::HumanString(operand_shape));
      }
    } else if (comparison->type() != default_comparison_type) {
      return FailedPrecondition(
          "Expected comparison type %s.\n"
          "actual: %s\noperand: %s\n",
          ComparisonTypeToString(default_comparison_type),
          ComparisonTypeToString(comparison->type()),
          ShapeUtil::HumanString(operand_shape));
    }
  }
  return OkStatus();
}

// Visitor which verifies various fields on the HLO instruction. This class
// does not check the result shape, as that is checked by the ShapeVerifier.
class InstructionVerifier : public DfsHloVisitorWithDefault {
 public:
  explicit InstructionVerifier(const HloVerifierOpts& opts) : opts_(opts) {}

  Status DefaultAction(HloInstruction*) override { return OkStatus(); }

  Status HandleFusion(HloInstruction* fusion) override {
    TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(
        fusion, /*skip_nested_async_op_check=*/false));
    return CheckFusionInstruction(fusion);
  }

HandleBroadcast(HloInstruction * broadcast)2372   Status HandleBroadcast(HloInstruction* broadcast) override {
2373     // If you see this failure then someone has confused the difference
2374     // between the HLO broadcast op, and the UserComputation broadcast
2375     // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
2376     // or ComputationLowerer::Visit()
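    // As an illustration of the checks below, a broadcast such as
    //   %b = f32[4,8,16] broadcast(f32[4,16] %x), dimensions={0,2}
    // carries one `dimensions` entry per operand dimension (two entries for a
    // rank-2 operand), and the entries are in sorted order.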
    TF_RET_CHECK(broadcast->dimensions().size() ==
                 broadcast->operand(0)->shape().rank())
        << "Broadcast HLO (" << broadcast->ToShortString()
        << ") has invalid number of dimensions: "
        << broadcast->dimensions().size()
        << " != " << broadcast->operand(0)->shape().rank();
    if (opts_.verify_broadcast_dimensions_order) {
      TF_RET_CHECK(absl::c_is_sorted(broadcast->dimensions()))
          << "Broadcast dimensions should be ordered, got: "
          << broadcast->ToString();
    }
    return OkStatus();
  }

  Status HandleBitcastConvert(HloInstruction* c) override {
    // Shape verifier will check all we need.
    return OkStatus();
  }

  Status HandleWhile(HloInstruction* xla_while) override {
    auto* while_cond = xla_while->while_condition();
    auto* while_body = xla_while->while_body();
    if (while_cond->num_parameters() != 1) {
      return FailedPrecondition(
          "While condition must have exactly 1 parameter; had %d : %s",
          while_cond->num_parameters(), while_cond->ToString());
    }
    if (while_body->num_parameters() != 1) {
      return FailedPrecondition(
          "While body must have exactly 1 parameter; had %d : %s",
          while_body->num_parameters(), while_body->ToString());
    }
    if (xla_while->operand_count() != 1) {
      return FailedPrecondition(
          "While loop must have exactly one operand; had %d : %s",
          xla_while->operand_count(), xla_while->ToString());
    }
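    // For example, a well-formed while instruction looks like
    //   %w = (s32[], f32[8]) while((s32[], f32[8]) %init),
    //        condition=%cond, body=%body
    // with a single loop-carried operand, and with %cond and %body each
    // taking that value as their only parameter.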
    // Allow kWhile to contain computations on a separate thread.
    TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(
        xla_while, /*skip_nested_async_op_check=*/true));
    return OkStatus();
  }

  Status HandleCall(HloInstruction* call) override {
    // Allow kCall to contain computations on a separate thread.
    return CheckCallableInstructionThreadName(
        call, /*skip_nested_async_op_check=*/true);
  }

  Status HandleConditional(HloInstruction* conditional) override {
    for (int b = 0; b < conditional->branch_count(); ++b) {
      if (conditional->branch_computation(b)->num_parameters() != 1) {
        return FailedPrecondition(
            "Branch computation %s of %s must have exactly 1 parameter; had %d",
            conditional->branch_computation(b)->name(), conditional->ToString(),
            conditional->branch_computation(b)->num_parameters());
      }
    }
    // Allow kConditional to contain computations on a separate thread.
    TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(
        conditional, /*skip_nested_async_op_check=*/true));
    return OkStatus();
  }

  Status HandleElementwiseUnary(HloInstruction* instruction) override {
    return CheckElementwiseInstruction(instruction);
  }

  Status HandleElementwiseBinary(HloInstruction* instruction) override {
    return CheckElementwiseInstruction(instruction);
  }

  Status HandleGetTupleElement(HloInstruction* gte) override {
    TF_RET_CHECK(gte->operand(0)->shape().IsTuple());
    return OkStatus();
  }

  Status HandleTranspose(HloInstruction* transpose) override {
    const Shape& shape = transpose->shape();
    const HloInstruction* operand = transpose->operand(0);
    TF_RET_CHECK(shape.dimensions().size() == transpose->dimensions().size());
    TF_RET_CHECK(shape.dimensions().size() ==
                 transpose->operand(0)->shape().dimensions().size());
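    // E.g. for %t = f32[16,4,8] transpose(f32[4,8,16] %x), dimensions={2,0,1},
    // every output dimension i must equal operand dimension dimensions[i]:
    // Permute({4,8,16}, {2,0,1}) == {16,4,8}.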
    TF_RET_CHECK(std::equal(
        shape.dimensions().begin(), shape.dimensions().end(),
        Permute(operand->shape().dimensions(), transpose->dimensions())
            .begin()))
        << "shape: " << shape << ", operand->shape(): " << operand->shape()
        << ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ")
        << "}";
    return OkStatus();
  }

  Status HandleAllReduce(HloInstruction* crs) override {
    if (crs->channel_id().has_value()) {
      TF_RET_CHECK(crs->channel_id().value() > 0)
          << "All reduce channel id must be greater than 0 for "
          << crs->ToShortString();
    }
    return OkStatus();
  }

  Status HandleReshape(HloInstruction* hlo) override {
    if (opts_.verify_reshape_is_bitcast && !hlo->IsFused()) {
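      // For example, with a row-major operand a reshape from f32[2,3]{1,0} to
      // f32[6]{0} preserves the linear element order and is a bitcast; the
      // same reshape from a column-major f32[2,3]{0,1} operand is not.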
      TF_RET_CHECK(
          ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape()))
          << "Reshape should be a physical bitcast, got: " << hlo->ToString();
    }
    return OkStatus();
  }

  Status HandleCustomCall(HloInstruction* hlo) override {
    if (opts_.verify_custom_call_nested_computation_thread_name) {
      // Allow kCustomCall to contain computations on a separate thread.
      return CheckCallableInstructionThreadName(
          hlo, /*skip_nested_async_op_check=*/true);
    }
    return OkStatus();
  }

  Status Preprocess(HloInstruction* instruction) override {
    auto previous = instructions_by_name_.find(instruction->name());
    TF_RET_CHECK(previous == instructions_by_name_.end())
        << "HLO has name that is not unique within module:\n"
        << instruction->ToString()
        << " in computation: " << instruction->parent()->name()
        << "\nPrevious HLO with same name:\n"
        << previous->second->ToString()
        << " in computation: " << previous->second->parent()->name();
    instructions_by_name_[instruction->name()] = instruction;
    return OkStatus();
  }

  Status Postprocess(HloInstruction* instruction) override {
    if (!opts_.InstructionCanChangeLayout(instruction) &&
        LayoutUtil::IsDenseArray(instruction->shape()) &&
        instruction->shape().has_layout()) {
      const Shape& result_shape = instruction->shape();
      const Layout& result_layout = result_shape.layout();
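      // For instance, if the target reports that an elementwise negate cannot
      // change layouts, a negate producing f32[16,8]{0,1} from an
      // f32[16,8]{1,0} operand would be flagged here.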
      for (HloInstruction* operand : instruction->operands()) {
        const Shape& operand_shape = operand->shape();
        if (LayoutUtil::IsDenseArray(operand_shape) &&
            operand_shape.rank() == result_shape.rank() &&
            operand_shape.has_layout()) {
          const Layout& operand_layout = operand_shape.layout();
          TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
              << "Instruction shouldn't change layouts "
              << instruction->ToString() << " From " << result_shape << " To "
              << operand_shape;
        }
      }
    }

    return OkStatus();
  }

 private:
  absl::flat_hash_map<std::string, const HloInstruction*> instructions_by_name_;
  const HloVerifierOpts& opts_;
};

}  // namespace

StatusOr<bool> HloVerifier::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  auto disabled = module->config().debug_options().xla_disable_hlo_passes();
  if (std::find(disabled.begin(), disabled.end(), name()) != disabled.end()) {
    return false;
  }
  auto status_or_changed = [&]() -> StatusOr<bool> {
    TF_RET_CHECK(!module->name().empty());

    if (module->entry_computation()->IsFusionComputation()) {
      return InvalidArgument(
          "Module entry computation cannot be a fusion computation");
    }

    TF_RETURN_IF_ERROR(VerifyHloStructure(module));
    TF_RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module));
    TF_RETURN_IF_ERROR(VerifyChannels(*module));

    std::unique_ptr<ShapeVerifier> shape_verifier =
        target_metadata_->GetVerifier();
    InstructionVerifier instruction_verifier(
        target_metadata_->GetVerifierOpts());
    for (auto* computation : module->computations(execution_threads)) {
      TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
      TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
    }

    TF_RETURN_IF_ERROR(shape_verifier->VerifyEntryComputationLayout(*module));
    TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));

    // If the module has a schedule, it must be valid.
    if (module->has_schedule()) {
      TF_RETURN_IF_ERROR(module->schedule().Verify());
    }

    TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(
        *module, [this](const Shape& shape) -> int64_t {
          if (target_metadata_->GetVerifierOpts().IsLayoutSensitive()) {
            return target_metadata_->GetVerifierOpts().ShapeSize(shape);
          } else {
            return 0;
          }
        }));

    TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module));
    TF_RETURN_IF_ERROR(VerifyLayoutConstrainedAllReduce(*module));
    return false;
  }();
  if (status_or_changed.ok()) {
    return status_or_changed.ValueOrDie();
  }
  return Status(status_or_changed.status().code(),
                absl::StrCat("during context [", context_, "]: ",
                             status_or_changed.status().error_message()));
}

}  // namespace xla