xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
17 
18 #include <vector>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/match.h"
22 #include "tensorflow/compiler/xla/literal_util.h"
23 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
24 #include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
30 #include "tensorflow/compiler/xla/service/hlo_module.h"
31 #include "tensorflow/compiler/xla/service/tuple_util.h"
32 #include "tensorflow/compiler/xla/service/while_util.h"
33 #include "tensorflow/compiler/xla/shape_tree.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/compiler/xla/status_macros.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/compiler/xla/window_util.h"
38 namespace xla {
39 
40 namespace {
41 // Replace `narrow_comp` with a new computation with `wide_shape` as input.
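// The widened computation takes `wide_shape` (a tuple whose leading elements
// match `narrow_shape`), extracts that prefix, calls `narrow_comp` on it, and
// then inlines the call so no extra call boundary remains.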
42 StatusOr<HloComputation*> WidenComputation(HloComputation* narrow_comp,
43                                            const Shape& wide_shape) {
44   TF_RET_CHECK(wide_shape.IsTuple());
45   const Shape& narrow_shape = narrow_comp->parameter_instruction(0)->shape();
46   if (Shape::Equal()(wide_shape, narrow_shape)) {
47     // No need to widen the computation.
48     return narrow_comp;
49   }
50   HloComputation* wide_comp = [&]() {
51     HloComputation::Builder builder(absl::StrCat("wide.", narrow_comp->name()));
52     builder.AddInstruction(
53         HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
54     return narrow_comp->parent()->AddEmbeddedComputation(builder.Build());
55   }();
56 
57   HloInstruction* wide_parameter = wide_comp->parameter_instruction(0);
58   HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
59       wide_parameter, narrow_shape.tuple_shapes_size());
60   HloInstruction* call_narrow_comp = wide_comp->AddInstruction(
61       HloInstruction::CreateCall(narrow_comp->root_instruction()->shape(),
62                                  {truncated_parameter}, narrow_comp));
63   wide_comp->set_root_instruction(call_narrow_comp,
64                                   /*accept_different_shape=*/true);
65   TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_comp).status());
66   return wide_comp;
67 }
68 }  // namespace
69 
70 class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
71  public:
72   explicit DynamicDimensionInferenceVisitor(
73       const DynamicParameterBinding& param_bindings,
74       DynamicDimensionInference* parent,
75       DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler,
76       DynamicDimensionInference::ShapeCheckMode shape_check_mode)
77       : param_bindings_(param_bindings),
78         parent_(parent),
79         custom_call_handler_(std::move(custom_call_handler)),
80         shape_check_mode_(shape_check_mode) {}
81 
82   Status DefaultAction(HloInstruction* hlo) override;
83 
84   static Status Run(HloComputation* computation,
85                     const DynamicParameterBinding& param_bindings,
86                     DynamicDimensionInference* parent,
87                     DynamicDimensionInference::CustomCallInferenceHandler
88                         custom_call_handler = nullptr,
89                     DynamicDimensionInference::ShapeCheckMode shape_check_mode =
90                         DynamicDimensionInference::ShapeCheckMode::kIgnore,
91                     const DynamicDimensionInference::AssertionGenerator&
92                         assertion_generator = nullptr) {
93     DynamicDimensionInferenceVisitor visitor(param_bindings, parent,
94                                              std::move(custom_call_handler),
95                                              shape_check_mode);
96 
97     TF_RETURN_IF_ERROR(computation->Accept(&visitor));
98     if (visitor.shape_assertion_ != nullptr) {
99       CHECK(assertion_generator);
100       assertion_generator(visitor.shape_assertion_);
101     }
102     return OkStatus();
103   }
104 
105   Status HandleParameter(HloInstruction* hlo) override;
106 
107   Status HandleReduce(HloInstruction* hlo) override;
108 
109   Status HandleDot(HloInstruction* hlo) override;
110 
111   Status HandleTuple(HloInstruction* hlo) override;
112 
113   Status HandleTranspose(HloInstruction* hlo) override;
114 
115   Status HandleDynamicReshape(HloInstruction* hlo) override;
116 
117   Status HandleReshape(HloInstruction* hlo) override;
118 
119   Status HandleSort(HloInstruction* hlo) override;
120 
121   Status HandlePad(HloInstruction* hlo) override;
122 
123   Status HandleCustomCall(HloInstruction* hlo) override;
124 
125   Status HandleBroadcast(HloInstruction* hlo) override;
126 
127   Status HandleGetDimensionSize(HloInstruction* hlo) override;
128 
129   Status HandleSetDimensionSize(HloInstruction* hlo) override;
130 
131   Status HandleSelect(HloInstruction* hlo) override;
132 
133   Status HandleConvolution(HloInstruction* hlo) override;
134 
135   Status HandleConcatenate(HloInstruction* hlo) override;
136 
137   Status HandleReduceWindow(HloInstruction* hlo) override;
138 
139   Status HandleReverse(HloInstruction* hlo) override;
140 
141   Status HandleSelectAndScatter(HloInstruction* hlo) override;
142 
143   Status HandleGetTupleElement(HloInstruction* hlo) override;
144 
145   Status HandleElementwiseUnary(HloInstruction* hlo) override;
146 
147   Status HandleElementwiseNary(HloInstruction* hlo);
148 
149   Status HandleElementwiseBinary(HloInstruction* hlo) override;
150 
151   Status HandleClamp(HloInstruction* hlo) override;
152 
153   Status HandleConditional(HloInstruction* hlo) override;
154 
155   Status HandleWhile(HloInstruction* hlo) override;
156 
157   Status HandleSlice(HloInstruction* hlo) override;
158 
159   Status HandleDynamicSlice(HloInstruction* hlo) override;
160 
161   Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
162 
163   Status HandleGather(HloInstruction* hlo) override;
164 
165   Status HandleScatter(HloInstruction* hlo) override;
166 
167   Status HandleMap(HloInstruction* hlo) override;
168 
169   Status HandleDomain(HloInstruction* hlo) override;
170 
171  private:
172   using OperandDynamicDimensionFn = std::function<Status(
173       HloInstruction* operand, ShapeIndex index, int64_t dimension,
174       int64_t operand_index, HloInstruction* dynamic_size)>;
175 
176   using DynamicDimensionFn = std::function<Status(
177       ShapeIndex index, int64_t dimension, HloInstruction* dynamic_size)>;
178 
179   Status HandleDynamicConvolutionForward(HloInstruction* hlo,
180                                          int64_t operand_index,
181                                          int64_t dimension,
182                                          HloInstruction* dynamic_size);
183 
184   Status HandleDynamicConvolutionKernelGrad(HloInstruction* hlo,
185                                             int64_t operand_index,
186                                             int64_t dimension);
187 
188   Status HandleDynamicConvolutionInputGrad(HloInstruction* hlo,
189                                            int64_t operand_index,
190                                            int64_t dimension);
191 
192   Status HandleDynamicWindowSamePadding(HloInstruction* hlo,
193                                         HloInstruction* dynamic_size,
194                                         int64_t operand_index,
195                                         int64_t dimension);
196 
197   Status ForEachOperandDynamicDimension(HloInstruction* inst,
198                                         const OperandDynamicDimensionFn&);
199   Status ForEachDynamicDimensionInOperand(HloInstruction* inst,
200                                           int64_t operand_index,
201                                           const OperandDynamicDimensionFn&);
202   Status ForEachDynamicDimension(HloInstruction* inst,
203                                  const DynamicDimensionFn& fn);
204 
205   // Insert shape check to make sure `dim1` is equal to `dim2`. If
206   // support_implicit_broadcast is true, the check will pass if either of them
207   // is 1, even if they are different.
208   Status InsertShapeCheck(HloInstruction* dim1, HloInstruction* dim2,
209                           bool support_implicit_broadcast);
210 
211   // Pass through a dynamic dimension from the input to the output with the
212   // same value and index in the shape. This is a helper function to handle
213   // trivial instructions like elementwise operations.
214   Status PassThroughDynamicDimension(HloInstruction*);
215 
216   // The dynamic parameter bindings of this computation.
217   const DynamicParameterBinding& param_bindings_;
218 
219   // A pointer to DynamicDimensionInference, used to update the dynamic mapping.
220   DynamicDimensionInference* parent_;
221 
222   // A handler for custom calls.
223   DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler_;
224 
225   // Indicates what to do at places where shape check is needed.
226   DynamicDimensionInference::ShapeCheckMode shape_check_mode_;
227 
228   // Value which has to be `true` for the shapes to match.
229   HloInstruction* shape_assertion_ = nullptr;
230 };
231 
232 Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) {
233   return ForEachOperandDynamicDimension(
234       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
235                int64_t operand_index, HloInstruction* dynamic_size) {
236         return UnimplementedStrCat(
237             "Asked to propagate a dynamic dimension from hlo ", operand->name(),
238             "@", index.ToString(), "@", dimension, " to hlo ", hlo->ToString(),
239             ", which is not implemented.");
240       });
241 }
242 
243 Status DynamicDimensionInferenceVisitor::HandleGetTupleElement(
244     HloInstruction* hlo) {
245   return ForEachOperandDynamicDimension(
246       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
247                int64_t operand_index, HloInstruction* dynamic_size) {
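        // The operand's dynamic dimension is recorded at shape index
        // {tuple_index, ...}; on the get-tuple-element result the leading
        // tuple_index is dropped from the index.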
248         if (hlo->tuple_index() == index[0]) {
249           ShapeIndex new_index(ShapeIndexView(index).subspan(1));
250           parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size);
251         }
252         return OkStatus();
253       });
254 }
255 
256 Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) {
257   return ForEachOperandDynamicDimension(
258       hlo, [&](HloInstruction*, ShapeIndex index, int64_t dimension,
259                int64_t operand_index, HloInstruction* dynamic_size) {
260         index.push_front(operand_index);
261         parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
262         return OkStatus();
263       });
264 }
265 
266 Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
267   return ForEachOperandDynamicDimension(
268       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
269                int64_t operand_index, HloInstruction* dynamic_size) {
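        // For a broadcast, hlo->dimensions()[i] gives the output dimension that
        // operand dimension i maps to; the dynamic size is forwarded unchanged.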
270         int64_t broadcast_dim = hlo->dimensions(dimension);
271         parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size);
272         return OkStatus();
273       });
274 }
275 
276 Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
277   if (hlo->custom_call_target() == "PadToStatic") {
278     for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
279       if (hlo->operand(0)->shape().is_dynamic_dimension(i)) {
280         HloInstruction* dynamic_size =
281             hlo->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
282                 ShapeUtil::MakeScalarShape(S32), hlo, i + 1));
283         // PadToStatic converts a dynamic dimension to a static dimension. It then
284         // returns the padded data output and the dynamic sizes of input
285         // dimensions.
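        // The result is a tuple: element 0 is the statically-shaped data and
        // element i + 1 holds the S32 size of dynamic dimension i of the input,
        // which is what the get-tuple-element above reads.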
286         ShapeIndex data_output = {0};
287         parent_->SetDynamicSize(hlo, data_output, i, dynamic_size);
288       }
289     }
290     return OkStatus();
291   }
292   if (custom_call_handler_) {
293     return custom_call_handler_(hlo, parent_);
294   }
295 
296   if (hlo->custom_call_target() == "DynamicConvolutionForward") {
297     // If the input feature dimension is dynamic and the kernel input feature
298     // dimension is static, we can infer that the input feature is also static.
299     // E.g.,:
300     // lhs = [B, X, Y, ?]
301     // rhs = [X, Y, I, O]
302     // dim_labels = b01f_01io
303     // We can infer that the dynamic dimension in lhs is actually the static size I.
304     const ConvolutionDimensionNumbers& dnums =
305         hlo->convolution_dimension_numbers();
306     HloInstruction* input_feature = parent_->GetDynamicSize(
307         hlo->mutable_operand(0), {}, dnums.input_feature_dimension());
308     HloInstruction* kernel_feature = parent_->GetDynamicSize(
309         hlo->mutable_operand(1), {}, dnums.kernel_input_feature_dimension());
310 
311     if (input_feature != nullptr && kernel_feature == nullptr) {
312       if (hlo->mutable_operand(0)->shape().dimensions(
313               dnums.input_feature_dimension()) ==
314           hlo->mutable_operand(1)->shape().dimensions(
315               dnums.kernel_input_feature_dimension()))
316         parent_->SetDynamicSize(hlo->mutable_operand(0), {},
317                                 dnums.input_feature_dimension(), nullptr);
318     }
319   }
320   return ForEachOperandDynamicDimension(
321       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
322                int64_t operand_index, HloInstruction* dynamic_size) {
323         // Resize custom call should propagate dynamic batch (0) and channel (3)
324         // dimensions.
325         if (hlo->custom_call_target() == "SliceToDynamic" ||
326             hlo->custom_call_target() == "Sharding" ||
327             (absl::StartsWith(hlo->custom_call_target(), "Resize") &&
328              (dimension == 0 || dimension == 3))) {
329           parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
330           return OkStatus();
331         }
332         if (hlo->custom_call_target() == "DynamicReduceWindowSamePadding") {
333           if (hlo->operand_count() > 2) {
334             return Unimplemented(
335                 "DynamicReduceWindowSamePadding doesn't support variadic "
336                 "reduce window %s",
337                 hlo->ToString());
338           }
339           return HandleDynamicWindowSamePadding(hlo, dynamic_size,
340                                                 operand_index, dimension);
341         }
342 
343         if (hlo->custom_call_target() == "DynamicSelectAndScatterSamePadding") {
344           if (operand_index == 1) {
345             // Operand 0 (input) determines dynamic output size. We ignore the
346             // dynamic size in operand 1 (output gradient).
347             return OkStatus();
348           }
349           parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
350           return OkStatus();
351         }
352 
353         if (hlo->custom_call_target() == "DynamicConvolutionInputGrad") {
354           return HandleDynamicConvolutionInputGrad(hlo, operand_index,
355                                                    dimension);
356         }
357 
358         if (hlo->custom_call_target() == "DynamicConvolutionKernelGrad") {
359           return HandleDynamicConvolutionKernelGrad(hlo, operand_index,
360                                                     dimension);
361         }
362 
363         if (hlo->custom_call_target() == "DynamicConvolutionForward") {
364           return HandleDynamicConvolutionForward(hlo, operand_index, dimension,
365                                                  dynamic_size);
366         }
367         return Unimplemented(
368             "CustomCall \"%s\" is not supported to have a dynamic dimension",
369             hlo->custom_call_target());
370       });
371 }
372 
373 Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) {
374   return ForEachOperandDynamicDimension(
375       hlo,
376       [&](HloInstruction* operand, ShapeIndex index, int64_t dynamic_dimension,
377           int64_t operand_index, HloInstruction* dynamic_size) {
378         HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
379         if (sort->values_count() == 0) {
380           parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size);
381         } else {
382           parent_->SetDynamicSize(hlo, {operand_index}, dynamic_dimension,
383                                   dynamic_size);
384         }
385 
386         return OkStatus();
387       });
388 }
389 
390 Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
391   return ForEachOperandDynamicDimension(
392       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
393                int64_t operand_index, HloInstruction* dynamic_size) {
394         if (operand_index != 0) {
395           return Unimplemented(
396               "Dynamic dimension on padding value is not supported");
397         }
398         const PaddingConfig_PaddingConfigDimension& padding_config =
399             hlo->padding_config().dimensions(dimension);
400 
401         HloInstruction* dynamic_size_adjusted = dynamic_size;
402         if (padding_config.interior_padding() != 0) {
403           // Adjust for interior padding:
404           // Size' = max((Size - 1), 0) * interior_padding + Size
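          // For example, with Size = 4 and interior_padding = 2:
          //   Size' = max(3, 0) * 2 + 4 = 10
          // (edge padding is added separately below).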
405           HloInstruction* one =
406               hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
407                   LiteralUtil::CreateR0<int32_t>(1)));
408           HloInstruction* zero =
409               hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
410                   LiteralUtil::CreateR0<int32_t>(0)));
411           HloInstruction* interior_padding = hlo->parent()->AddInstruction(
412               HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
413                   padding_config.interior_padding())));
414           dynamic_size_adjusted =
415               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
416                   dynamic_size_adjusted->shape(), HloOpcode::kSubtract,
417                   dynamic_size_adjusted, one));
418           dynamic_size_adjusted =
419               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
420                   dynamic_size_adjusted->shape(), HloOpcode::kMaximum,
421                   dynamic_size_adjusted, zero));
422           dynamic_size_adjusted =
423               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
424                   dynamic_size_adjusted->shape(), HloOpcode::kMultiply,
425                   dynamic_size_adjusted, interior_padding));
426           dynamic_size_adjusted =
427               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
428                   dynamic_size_adjusted->shape(), HloOpcode::kAdd,
429                   dynamic_size_adjusted, dynamic_size));
430         }
431         HloInstruction* adjustment = hlo->parent()->AddInstruction(
432             HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
433                 padding_config.edge_padding_low() +
434                 padding_config.edge_padding_high())));
435         dynamic_size_adjusted =
436             hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
437                 dynamic_size_adjusted->shape(), HloOpcode::kAdd,
438                 dynamic_size_adjusted, adjustment));
439         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted);
440         return OkStatus();
441       });
442 }
443 
444 Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
445   return ForEachOperandDynamicDimension(
446       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
447                int64_t operand_index, HloInstruction* dynamic_size) {
448         auto* reduce = Cast<HloReduceInstruction>(hlo);
449         int64_t operand_count = reduce->operand_count();
450         CHECK_EQ(operand_count % 2, 0);
451         if (operand_index >= reduce->input_count()) {
452           // Init values don't have dynamic sizes.
453           return OkStatus();
454         }
455         if ((absl::c_count(reduce->dimensions(), dimension) != 0)) {
456           // Dimension is to be reduced, stop tracing.
457           return OkStatus();
458         }
459 
460         // Find out the new dynamic dimension after reduce.
461         int64_t dimensions_not_reduced_count = 0;
462         for (int i = 0; i < operand->shape().rank(); ++i) {
463           if (dimension == i) {
464             // The dimensions of all data operands of a variadic reduce have
465             // to be the same.  This means that if one operand of variadic
466             // reduce has a dynamic dimension, we set all outputs to use the
467             // same dynamic size in corresponding dimensions.
468             ShapeUtil::ForEachSubshape(
469                 reduce->shape(),
470                 [&](const Shape& subshape, ShapeIndex reduce_result_index) {
471                   if (!ShapeUtil::IsLeafIndex(reduce->shape(),
472                                               reduce_result_index)) {
473                     return;
474                   }
475                   parent_->SetDynamicSize(reduce, reduce_result_index,
476                                           dimensions_not_reduced_count,
477                                           dynamic_size);
478                 });
479 
480             return OkStatus();
481           }
482           if (absl::c_count(reduce->dimensions(), i) == 0) {
483             dimensions_not_reduced_count++;
484           }
485         }
486 
487         return OkStatus();
488       });
489 }
490 
491 Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
492   return ForEachOperandDynamicDimension(hlo, [&](HloInstruction* operand,
493                                                  ShapeIndex operand_shape_index,
494                                                  int64_t operand_dimension,
495                                                  int64_t operand_index,
496                                                  HloInstruction* dynamic_size) {
497     // There are three types of dimensions in a dot:
498     // A. batch dims
499     // B. contracting dims
500     // C. non-batch non-contracting dims.
501     // The output dimensions of a dot has three parts with the following
502     // order:
503     // [(type A), (lhs type C), (rhs type C)]
504     //
505     // Note that both lhs and rhs have the same dimension sizes for batch,
506     // but the dimension index could be different.
507     //
508     // Given one dynamic input dimension, either lhs or rhs, we use a
509     // mapping to find the corresponding output dimension.
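    // Illustrative example (not tied to this HLO): for lhs = [b, m, k] and
    // rhs = [b, k, n] with batch dimension 0 and contracting dimension k, the
    // output is [b, m, n]; lhs dim m maps to output dim 1 and rhs dim n maps
    // to output dim 2, which is what the mapping below computes.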
510     HloInstruction* dot = hlo;
511     const DotDimensionNumbers& dimension_numbers = dot->dot_dimension_numbers();
512     // A map from the operand dimensions to result dimension.
513     absl::flat_hash_map<int64_t, int64_t> result_dim_mapping;
514     int64_t current_result_dims = 0;
515 
516     bool lhs = operand_index == 0;
517 
518     // The first loop keeps track of batch dimensions. RHS and LHS could have
519     // different batch dimension numbers.
520     if (lhs) {
521       for (int64_t i : dimension_numbers.lhs_batch_dimensions()) {
522         result_dim_mapping[i] = current_result_dims++;
523       }
524     } else {
525       for (int64_t i : dimension_numbers.rhs_batch_dimensions()) {
526         result_dim_mapping[i] = current_result_dims++;
527       }
528     }
529 
530     // Handle dimensions in the lhs.
531     for (int64_t i = 0; i < dot->operand(0)->shape().rank(); i++) {
532       // Look for non-contracting and non-batching dimension.
533       if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
534                                 i)) {
535         continue;
536       }
537       if (absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
538         continue;
539       }
540       if (lhs) {
541         result_dim_mapping[i] = current_result_dims;
542       }
543       current_result_dims++;
544     }
545 
546     // Handle dimensions in the rhs.
547     for (int64_t i = 0; i < dot->operand(1)->shape().rank(); i++) {
548       // Look for non-contracting and non-batching dimension.
549       if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
550                                 i)) {
551         continue;
552       }
553       if (absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
554         continue;
555       }
556       if (!lhs) {
557         result_dim_mapping[i] = current_result_dims;
558       }
559       current_result_dims++;
560     }
561 
562     // Check if the operand dim is in the result shape. If so, add another
563     // work item to trace that dimension.
564     auto iter = result_dim_mapping.find(operand_dimension);
565     if (iter != result_dim_mapping.end()) {
566       parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
567     }
568 
569     return OkStatus();
570   });
571 }
572 
573 Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
574   return ForEachOperandDynamicDimension(
575       hlo,
576       [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
577           int64_t operand_index, HloInstruction* dynamic_size) -> Status {
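        // Output dimension i of a transpose is drawn from input dimension
        // hlo->dimensions()[i], so the dynamic input dimension ends up at the
        // unique output position i with hlo->dimensions()[i] == dimension.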
578         int64_t permuted_dim = -1;
579         for (int64_t i = 0; i < hlo->dimensions().size(); ++i) {
580           if (hlo->dimensions()[i] == dimension) {
581             TF_RET_CHECK(permuted_dim == -1);
582             permuted_dim = i;
583           }
584         }
585         parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size);
586         return OkStatus();
587       });
588 }
589 
590 Status DynamicDimensionInferenceVisitor::HandleConvolution(
591     HloInstruction* hlo) {
592   return ForEachOperandDynamicDimension(
593       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
594                int64_t operand_index, HloInstruction* dynamic_size) {
595         HloInstruction* conv = hlo;
596         const ConvolutionDimensionNumbers& dimension_numbers =
597             conv->convolution_dimension_numbers();
598         if (operand_index == 0) {
599           if (dimension == dimension_numbers.input_batch_dimension()) {
600             parent_->SetDynamicSize(conv, {},
601                                     dimension_numbers.output_batch_dimension(),
602                                     dynamic_size);
603             return OkStatus();
604           }
605 
606           if (dimension == dimension_numbers.input_feature_dimension()) {
607             return OkStatus();
608           }
609         } else {
610           if (dimension == dimension_numbers.kernel_input_feature_dimension()) {
611             return OkStatus();
612           }
613         }
614 
615         return Unimplemented("Dynamic Spatial Convolution is not supported: %s",
616                              conv->ToString());
617       });
618 }
619 
620 Status DynamicDimensionInferenceVisitor::HandleConcatenate(
621     HloInstruction* hlo) {
622   // First handle the concatenate dimension. We do this by iterating through all
623   // operands while tracking both dynamic and static dimensions.
624 
625   // static_size is used to keep track of the concatenated size of the static
626   // dimensions.
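  // For example, concatenating [<=4], [3], [<=5] along dimension 0 gives
  // static_size = 3 and an output dynamic size of 3 + d0 + d2, where d0 and d2
  // are the dynamic sizes of the first and third operands.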
627   int64_t static_size = 0;
628   std::vector<HloInstruction*> dynamic_concat_dims;
629   for (int64_t i = 0; i < hlo->operand_count(); ++i) {
630     HloInstruction* dynamic_size = parent_->GetDynamicSize(
631         hlo->mutable_operand(i), {}, hlo->concatenate_dimension());
632     if (dynamic_size == nullptr) {
633       // This is a static dimension.
634       static_size +=
635           hlo->operand(i)->shape().dimensions(hlo->concatenate_dimension());
636     } else {
637       dynamic_concat_dims.push_back(dynamic_size);
638     }
639   }
640   // If concat dimension is dynamic, calculate its size by summing up static
641   // dims and dynamic dims together.
642   if (!dynamic_concat_dims.empty()) {
643     HloInstruction* dim_size_total =
644         hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
645             LiteralUtil::CreateR0<int32_t>(static_size)));
646     for (HloInstruction* dynamic_dim : dynamic_concat_dims) {
647       dim_size_total = hlo->parent()->AddInstruction(
648           HloInstruction::CreateBinary(dim_size_total->shape(), HloOpcode::kAdd,
649                                        dim_size_total, dynamic_dim));
650     }
651     parent_->SetDynamicSize(hlo, {}, hlo->concatenate_dimension(),
652                             dim_size_total);
653   }
654 
655   // Simply pass through non-concat dynamic dimensions.
656   return ForEachOperandDynamicDimension(
657       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
658                int64_t operand_index, HloInstruction* dynamic_size) {
659         int64_t concatenate_dimension = hlo->concatenate_dimension();
660         if (concatenate_dimension == dimension) {
661           return OkStatus();
662         }
663         parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
664         return OkStatus();
665       });
666 }
667 
668 Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
669     HloInstruction* gds) {
670   // Dynamic dimension doesn't propagate through GetDimensionSize:
671   //
672   //   Input: F32[x, y, z]
673   //     |
674   //   GetDimensionSize(1): S32[]
675   //
676   // The returned value is a scalar, which doesn't have any dynamic dimension in
677   // the shape (although the value contains the real size of the dynamic
678   // dimension of the input).
679   int64_t dim = gds->dimension();
680   HloInstruction* operand = gds->mutable_operand(0);
681   HloInstruction* dynamic_size = parent_->GetDynamicSize(operand, {}, dim);
682   HloComputation* computation = gds->parent();
683   if (dynamic_size != nullptr) {
684     TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(dynamic_size));
685     // The dependency between an instruction and its dynamic dimensions is not
686     // modeled in the IR. As gds is being replaced by dynamic_size, also tell
687     // dynamic dimension inference that the instruction is being replaced.
688     parent_->ReplaceAllDynamicDimensionUsesWith(gds, dynamic_size);
689   } else {
690     TF_RET_CHECK(dim < gds->operand(0)->shape().rank());
691     int32_t size = gds->operand(0)->shape().dimensions(dim);
692     HloInstruction* new_instr = computation->AddInstruction(
693         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(size)));
694     TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(new_instr));
695     parent_->ReplaceAllDynamicDimensionUsesWith(gds, new_instr);
696   }
697   return OkStatus();
698 }
699 
700 Status DynamicDimensionInferenceVisitor::HandleSetDimensionSize(
701     HloInstruction* hlo) {
702   bool dimension_is_static = false;
703   const HloInstruction* size = hlo->operand(1);
704   if (size->opcode() == HloOpcode::kConstant) {
705     // Check if we are setting a dimension size to its static size. If so,
706     // remove the dynamic dimension.
707     //
708     // size = s32[] constant(5)
709     // s32[2, 5] = set-dimension-size(s32[2,<=5]{1,0} %param, s32[] %size),
710     //                                                        dimensions={1}
711     // The result shape has no dynamic dimension.
712     TF_RET_CHECK(size->shape().rank() == 0);
713     if (size->literal().Get<int32_t>({}) ==
714         hlo->shape().dimensions(hlo->dimension())) {
715       dimension_is_static = true;
716     }
717   }
718 
719   if (!dimension_is_static) {
720     // Propagate dynamic dimension indicated by this set dimension size
721     // instruction.
722     parent_->SetDynamicSize(hlo, {}, hlo->dimension(), hlo->mutable_operand(1));
723   }
724 
725   // Also propagate dynamic dimensions already set by operands.
726   TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
727       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
728                int64_t operand_index, HloInstruction* dynamic_size) {
729         if (dimension != hlo->dimension()) {
730           parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
731         }
732         return OkStatus();
733       }));
734 
735   return OkStatus();
736 }
737 
738 Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionForward(
739     HloInstruction* hlo, int64_t operand_index, int64_t dimension,
740     HloInstruction* dynamic_size) {
741   TF_RET_CHECK(operand_index == 0);
742   const ConvolutionDimensionNumbers& dimension_numbers =
743       hlo->convolution_dimension_numbers();
744 
745   if (dimension == dimension_numbers.input_batch_dimension()) {
746     // Batch dimension is propagated without any changes.
747     parent_->SetDynamicSize(hlo, {}, dimension_numbers.output_batch_dimension(),
748                             dynamic_size);
749     return OkStatus();
750   }
751 
752   for (int64_t spatial_dim_index = 0;
753        spatial_dim_index < dimension_numbers.input_spatial_dimensions_size();
754        ++spatial_dim_index) {
755     int64_t input_spatial_dim =
756         dimension_numbers.input_spatial_dimensions(spatial_dim_index);
757     int64_t output_spatial_dim =
758         dimension_numbers.output_spatial_dimensions(spatial_dim_index);
759     if (dimension == input_spatial_dim) {
760       // This is a dynamic spatial dimension. Calculate the output size.
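      // GetWindowedOutputSize emits HLO that computes the output spatial size
      // from the dynamic input size, window size, dilation, stride, and the
      // convolution's padding type.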
761       WindowDimension window_dim = hlo->window().dimensions(spatial_dim_index);
762       DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
763           dynamic_size, window_dim.size(), window_dim.window_dilation(),
764           window_dim.stride(), hlo->padding_type());
765       TF_RET_CHECK(window_dim.base_dilation() == 1);
766       parent_->SetDynamicSize(hlo, {}, output_spatial_dim,
767                               dynamic_window_dims.output_size);
768       return OkStatus();
769     }
770   }
771   // The input feature dim disappears after convolution.
772   return OkStatus();
773 }
774 
775 Status DynamicDimensionInferenceVisitor::HandleDynamicWindowSamePadding(
776     HloInstruction* hlo, HloInstruction* dynamic_size, int64_t operand_index,
777     int64_t dimension) {
778   const Window& window = hlo->window();
779   const WindowDimension& window_dim = window.dimensions(dimension);
780   if (!window_util::IsTrivialWindowDimension(window_dim)) {
781     DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
782         dynamic_size, window_dim.size(), window_dim.window_dilation(),
783         window_dim.stride(), PaddingType::PADDING_SAME);
784     parent_->SetDynamicSize(hlo, {}, dimension,
785                             dynamic_window_dims.output_size);
786     return OkStatus();
787   }
788 
789   parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
790 
791   return OkStatus();
792 }
793 
794 Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionInputGrad(
795     HloInstruction* hlo, int64_t operand_index, int64_t dimension) {
796   // The output size of a convolution input grad is the corresponding input size.
797   HloInstruction* input_sizes = hlo->mutable_operand(0);
798   HloComputation* comp = hlo->parent();
799   TF_RET_CHECK(input_sizes->shape().rank() == 1) << hlo->ToString();
800   TF_RET_CHECK(input_sizes->shape().element_type() == S32) << hlo->ToString();
801   TF_RET_CHECK(input_sizes->shape().dimensions(0) ==
802                hlo->shape().dimensions_size())
803       << hlo->ToString();
804   // Slice to get the corresponding input size.
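  // For example, for dimension = 2 this slices input_sizes[2:3] and reshapes
  // the one-element vector into an S32 scalar holding the dynamic size.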
805   HloInstruction* slice = comp->AddInstruction(
806       HloInstruction::CreateSlice(ShapeUtil::MakeShape(S32, {1}), input_sizes,
807                                   {dimension}, {dimension + 1}, {1}));
808   HloInstruction* reshape = comp->AddInstruction(
809       HloInstruction::CreateReshape(ShapeUtil::MakeScalarShape(S32), slice));
810   parent_->SetDynamicSize(hlo, {}, dimension, reshape);
811   return OkStatus();
812 }
813 
814 Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionKernelGrad(
815     HloInstruction* hlo, int64_t operand_index, int64_t dimension) {
816   // Dynamic convolution kernel grad produces static shape outputs.
817   return OkStatus();
818 }
819 
820 Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension(
821     HloInstruction* hlo) {
822   return ForEachOperandDynamicDimension(
823       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
824                int64_t operand_index, HloInstruction* dynamic_size) {
825         parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
826         return OkStatus();
827       });
828 }
829 
830 Status DynamicDimensionInferenceVisitor::HandleDomain(HloInstruction* hlo) {
831   return PassThroughDynamicDimension(hlo);
832 }
833 
834 Status DynamicDimensionInferenceVisitor::HandleElementwiseUnary(
835     HloInstruction* hlo) {
836   return PassThroughDynamicDimension(hlo);
837 }
838 
839 Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {
840   return PassThroughDynamicDimension(hlo);
841 }
842 
843 Status DynamicDimensionInferenceVisitor::HandleElementwiseNary(
844     HloInstruction* hlo) {
845   HloComputation* comp = hlo->parent();
846   return ForEachOperandDynamicDimension(
847       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
848                int64_t operand_index, HloInstruction* dynamic_size) {
849         HloInstruction* existing_size =
850             parent_->GetDynamicSize(hlo, index, dimension);
851         if (existing_size == nullptr || existing_size == dynamic_size) {
852           parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
853         } else {
854           TF_RETURN_IF_ERROR(
855               InsertShapeCheck(existing_size, dynamic_size,
856                                /*support_implicit_broadcast=*/true));
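          // Implicit broadcast handling: a side whose dynamic size is 1 (and
          // smaller than the other) needs to be broadcast, in which case the
          // output's dynamic size is the larger of the two; otherwise the two
          // sizes are (checked above to be) equal and the smaller is used.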
857 
858           auto one = comp->AddInstruction(
859               HloInstruction::CreateConstant(LiteralUtil::One(S32)));
860 
861           auto operand_needs_broadcast =
862               comp->AddInstruction(HloInstruction::CreateCompare(
863                   ShapeUtil::MakeShape(PRED, {}), dynamic_size, existing_size,
864                   ComparisonDirection::kLt));
865           auto is_one = comp->AddInstruction(HloInstruction::CreateCompare(
866               ShapeUtil::MakeShape(PRED, {}), dynamic_size, one,
867               ComparisonDirection::kEq));
868           operand_needs_broadcast =
869               comp->AddInstruction(HloInstruction::CreateBinary(
870                   ShapeUtil::MakeShape(PRED, {}), HloOpcode::kAnd, is_one,
871                   operand_needs_broadcast));
872 
873           auto existing_needs_broadcast =
874               comp->AddInstruction(HloInstruction::CreateCompare(
875                   ShapeUtil::MakeShape(PRED, {}), existing_size, dynamic_size,
876                   ComparisonDirection::kLt));
877           is_one = comp->AddInstruction(HloInstruction::CreateCompare(
878               ShapeUtil::MakeShape(PRED, {}), existing_size, one,
879               ComparisonDirection::kEq));
880           existing_needs_broadcast =
881               comp->AddInstruction(HloInstruction::CreateBinary(
882                   ShapeUtil::MakeShape(PRED, {}), HloOpcode::kAnd, is_one,
883                   existing_needs_broadcast));
884 
885           auto needs_broadcast =
886               comp->AddInstruction(HloInstruction::CreateBinary(
887                   ShapeUtil::MakeShape(PRED, {}), HloOpcode::kOr,
888                   operand_needs_broadcast, existing_needs_broadcast));
889           auto max_size = comp->AddInstruction(HloInstruction::CreateBinary(
890               ShapeUtil::MakeScalarShape(S32), HloOpcode::kMaximum,
891               dynamic_size, existing_size));
892           auto min_size = comp->AddInstruction(HloInstruction::CreateBinary(
893               ShapeUtil::MakeScalarShape(S32), HloOpcode::kMinimum,
894               dynamic_size, existing_size));
895           auto select_size = comp->AddInstruction(HloInstruction::CreateTernary(
896               ShapeUtil::MakeScalarShape(S32), HloOpcode::kSelect,
897               needs_broadcast, max_size, min_size));
898           parent_->SetDynamicSize(hlo, index, dimension, select_size);
899         }
900         return OkStatus();
901       });
902 }
903 
904 Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
905     HloInstruction* hlo) {
906   return HandleElementwiseNary(hlo);
907 }
908 
909 Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {
910   return PassThroughDynamicDimension(hlo);
911 }
912 
913 Status DynamicDimensionInferenceVisitor::HandleDynamicReshape(
914     HloInstruction* hlo) {
915   HloDynamicReshapeInstruction* dynamic_reshape =
916       Cast<HloDynamicReshapeInstruction>(hlo);
917   for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
918     if (hlo->shape().is_dynamic_dimension(i)) {
919       parent_->SetDynamicSize(hlo, {}, i, dynamic_reshape->dim_sizes(i));
920     }
921   }
922   return OkStatus();
923 }
924 
925 Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
926   // First scan to see if we need to decompose the dynamic reshape into a
927   // flatten-unflatten pair. If so, find the dynamic dimension using
928   // hlo->inferred_dimension() and calculate the dynamic size for that
929   // dimension.
930   bool need_flatten_unflatten = false;
931   // For a reshape we need the inferred_dimension to be present to disambiguate
932   // dynamic dimensions of hlo. HloOpcode::kDynamicReshape on the other hand
933   // allows more precise specification of dynamic dimensions of hlo's shape.
934   if (hlo->inferred_dimension() != -1) {
935     TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
936         hlo,
937         [&](HloInstruction* operand, ShapeIndex index,
938             int64_t input_dynamic_dimension, int64_t operand_index,
939             HloInstruction* operand_dynamic_size) -> Status {
940           auto common_factors = CommonFactors(operand->shape().dimensions(),
941                                               hlo->shape().dimensions());
942           int64_t input_dim_start = -1;
943           int64_t input_dim_end = -1;
944           int64_t output_dim_start = -1;
945           int64_t output_dim_end = -1;
946           // Find common_factors that the input belongs to.
947           for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
948             auto start = common_factors[i];
949             auto end = common_factors[i + 1];
950             if (input_dynamic_dimension >= start.first &&
951                 input_dynamic_dimension < end.first) {
952               // Found the common_factor group that the input_dim belongs to.
953               input_dim_start = start.first;
954               input_dim_end = end.first;
955               output_dim_start = start.second;
956               output_dim_end = end.second;
957             }
958           }
959           if ((input_dim_end - input_dim_start) > 1 &&
960               (output_dim_end - output_dim_start) > 1) {
961             need_flatten_unflatten = true;
962           }
963           return OkStatus();
964         }));
965     if (need_flatten_unflatten) {
966       HloInstruction* operand = hlo->mutable_operand(0);
967       HloComputation* comp = hlo->parent();
968       HloInstruction* dynamic_size = comp->AddInstruction(
969           HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
970       int64_t static_size = 1;
971       for (int64_t i = 0; i < operand->shape().rank(); i++) {
972         HloInstruction* dynamic_dim_size =
973             parent_->GetDynamicSize(operand, {}, i);
974         if (dynamic_dim_size == nullptr) {
975           static_size *= operand->shape().dimensions(i);
976         } else {
977           dynamic_size = comp->AddInstruction(HloInstruction::CreateBinary(
978               dynamic_size->shape(), HloOpcode::kMultiply, dynamic_size,
979               dynamic_dim_size));
980         }
981       }
982       HloInstruction* static_size_hlo =
983           comp->AddInstruction(HloInstruction::CreateConstant(
984               LiteralUtil::CreateR0<int32_t>(static_size)));
985       // Total dynamic shape size.
986       dynamic_size = comp->AddInstruction(HloInstruction::CreateBinary(
987           dynamic_size->shape(), HloOpcode::kMultiply, dynamic_size,
988           static_size_hlo));
989 
990       int64_t size_without_inferred_dim =
991           ShapeUtil::ElementsIn(hlo->shape()) /
992           hlo->shape().dimensions(hlo->inferred_dimension());
993       HloInstruction* size_without_inferred_dim_hlo =
994           comp->AddInstruction(HloInstruction::CreateConstant(
995               LiteralUtil::CreateR0<int32_t>(size_without_inferred_dim)));
996       dynamic_size = comp->AddInstruction(HloInstruction::CreateBinary(
997           dynamic_size->shape(), HloOpcode::kDivide, dynamic_size,
998           size_without_inferred_dim_hlo));
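      // Worked example (illustrative): reshaping [<=4, 6] (dynamic size 2 in
      // dim 0) to [8, 3] with inferred_dimension = 0: the dynamic element
      // count is 2 * 6 = 12, size_without_inferred_dim = 24 / 8 = 3, so the
      // inferred dimension gets dynamic size 12 / 3 = 4.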
999       parent_->SetDynamicSize(hlo, {}, hlo->inferred_dimension(), dynamic_size);
1000       VLOG(3)
1001           << "Need to decopose a dynamic reshape to flatten-unflatten pair. "
1002           << comp->parent()->ToString();
1003       return OkStatus();
1004     }
1005   }
1006 
1007   return ForEachOperandDynamicDimension(
1008       hlo,
1009       [&](HloInstruction* operand, ShapeIndex index,
1010           int64_t input_dynamic_dimension, int64_t operand_index,
1011           HloInstruction* operand_dynamic_size) -> Status {
1012         HloInstruction* reshape = hlo;
1013         if (reshape->shape().rank() == 0) {
1014           VLOG(0) << "Reshaping a dynamic dimension into a scalar, which has "
1015                      "undefined behavior when input size is 0. The offending "
1016                      "instruction is: "
1017                   << reshape->ToString();
1018           return OkStatus();
1019         }
1020         auto common_factors = CommonFactors(operand->shape().dimensions(),
1021                                             reshape->shape().dimensions());
1022         int64_t input_dim_start = -1;
1023         int64_t input_dim_end = -1;
1024         int64_t output_dim_start = -1;
1025         int64_t output_dim_end = -1;
1026         // Find common_factors that the input belongs to.
1027         for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
1028           auto start = common_factors[i];
1029           auto end = common_factors[i + 1];
1030           if (input_dynamic_dimension >= start.first &&
1031               input_dynamic_dimension < end.first) {
1032             // Found the common_factor group that the input_dim belongs to.
1033             input_dim_start = start.first;
1034             input_dim_end = end.first;
1035             output_dim_start = start.second;
1036             output_dim_end = end.second;
1037           }
1038         }
1039 
1040         VLOG(2) << "Input dim start: " << input_dim_start
1041                 << " Input dim end: " << input_dim_end
1042                 << " output dim start: " << output_dim_start
1043                 << " output dim end: " << output_dim_end;
1044 
1045         if ((input_dim_end - input_dim_start) > 1 &&
1046             (output_dim_end - output_dim_start) > 1) {
1047           return InternalError(
1048               "Should be handled by decomposing reshape into "
1049               "flatten-unflatten pair. %s",
1050               hlo->ToString());
1051         }
1052 
1053         for (auto common_factor : common_factors) {
1054           // Expand the common factor to include degenerate output dimensions.
1055           if (common_factor.first == input_dim_start) {
1056             output_dim_start = std::min(output_dim_start, common_factor.second);
1057           }
1058           if (common_factor.first == input_dim_end) {
1059             output_dim_end = std::max(output_dim_end, common_factor.second);
1060           }
1061         }
1062 
1063         int64_t output_dynamic_dimension = -1;
1064 
1065         if (operand->shape().dimensions(input_dynamic_dimension) == 1) {
1066           // If the dynamic dimension has size 1, it can only be the
1067           // most-major or most-minor dimension.
1068           if (input_dynamic_dimension == 0) {
1069             output_dynamic_dimension = 0;
1070           } else if (input_dynamic_dimension == operand->shape().rank() - 1) {
1071             output_dynamic_dimension = reshape->shape().rank() - 1;
1072           }
1073 
1074           if (output_dynamic_dimension == -1) {
1075             return Unimplemented(
1076                 "Dynamic degenerated dimension that's not most-minor nor "
1077                 "most-major is not supported %s",
1078                 reshape->ToString());
1079           }
1080         }
1081 
1082         if (output_dynamic_dimension == -1 &&
1083             output_dim_end - output_dim_start == 1) {
1084           // Only one possible output dimension.
1085           output_dynamic_dimension = output_dim_start;
1086         }
1087 
1088         if (output_dynamic_dimension == -1 &&
1089             output_dim_end - output_dim_start > 1) {
1090           // One input dimension is split into multiple output dimensions.
1091           // The output dimensions are decomposed from the input's most-major dimension.
1092           // In this case, we don't know which one is dynamic, e.g., when we
1093           // have:
1094           //
1095           //           [<=a/c, c, b]
1096           //              | Reshape
1097           //           [<=a, b] // a is dynamic, has to be multiple of c.
1098           //             |  Reshape
1099           // [1, 1, ... , a/c, c, b]
1100           //
1101           // Any dimension from the first '1' to 'a/c' can be dynamic.
1102           //
1103           // We use the following logic to disambiguate:
1104           // 1. If the user sets "inferred_dimension", then use that as the
1105           // dynamic dimension.
1106           // 2. If a dimension in the reshape output is marked dynamic, use
1107           // that as the dynamic dimension.
1108           // E.g.:
1109           //     [<=4]
1110           //      |
1111           //   reshape
1112           //      |
1113           //   [1, <=2, 2]
1114           // We use second dim as dynamic dimension.
1115           //
1116           // 3. If the logic above cannot disambiguate, e.g.,:
1117           //
1118           //     [<=1]
1119           //      |
1120           //   reshape
1121           //      |
1122           //   [1, 1, 1]
1123           //
1124           //   We bail out and return an error.
1125           // TODO(yunxing): Further simplify this, remove 1. and fully rely
1126           // on 2.
1127           output_dynamic_dimension = reshape->inferred_dimension();
1128           if (output_dynamic_dimension == -1) {
1129             // Try to find the dynamic dimension from the result shape.
1130             for (int64_t i = output_dim_start; i < output_dim_end; ++i) {
1131               if (reshape->shape().is_dynamic_dimension(i)) {
1132                 output_dynamic_dimension = i;
1133               }
1134             }
1135           }
1136 
1137           if (output_dynamic_dimension == -1) {
1138             std::vector<int64_t> output_non_degenerated;
1139             for (int64_t i = output_dim_start; i < output_dim_end; ++i) {
1140               if (reshape->shape().dimensions(i) != 1) {
1141                 output_non_degenerated.push_back(i);
1142               }
1143             }
1144             if (output_non_degenerated.size() == 1) {
1145               output_dynamic_dimension = output_non_degenerated[0];
1146             }
1147           }
1148 
1149           if (output_dynamic_dimension == -1) {
1150             return InvalidArgument(
1151                 "Reshape's input dynamic dimension is decomposed into "
1152                 "multiple output dynamic dimensions, but the constraint is "
1153                 "ambiguous and XLA can't infer the output dimension %s. ",
1154                 hlo->ToString());
1155           }
1156         }
1157 
1158         CHECK_NE(output_dynamic_dimension, -1);
1159         const int64_t input_dim_size =
1160             operand->shape().dimensions(input_dynamic_dimension);
1161         const int64_t output_dim_size =
1162             reshape->shape().dimensions(output_dynamic_dimension);
1163         VLOG(2) << "input_dim_size: " << input_dim_size
1164                 << " output_dim_size: " << output_dim_size;
1165 
1166         if (input_dim_size == output_dim_size) {
1167           // Simply forward dynamic dimension.
1168           parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1169                                   operand_dynamic_size);
1170         }
1171 
1172         if (input_dim_size > output_dim_size) {
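          // The input dimension is split across several output dimensions, so
          // the dynamic size shrinks by the static ratio. For example, with
          // [<=6] -> [2, 3] and a dynamic size of 4: if the output dimension
          // of size 3 is the dynamic one, the divisor is 6 / 3 = 2 and the
          // output dynamic size becomes 4 / 2 = 2.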
1173           TF_RET_CHECK(input_dim_size % output_dim_size == 0)
1174               << reshape->ToString();
1175           const int64_t divisor = input_dim_size / output_dim_size;
1176           HloInstruction* divisor_hlo =
1177               hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
1178                   LiteralUtil::CreateR0<int32_t>(divisor)));
1179 
1180           HloInstruction* new_dynamic_size =
1181               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1182                   operand_dynamic_size->shape(), HloOpcode::kDivide,
1183                   operand_dynamic_size, divisor_hlo));
1184 
1185           parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1186                                   new_dynamic_size);
1187         }
1188 
1189         if (input_dim_size < output_dim_size) {
1190           // Input dimension is combined with other input dimensions.
1191           //
1192           // Adjust the output size by the ratio of dynamic_input_dim /
1193           // static_input_dim.
1194           //
1195           // For example, if we have [<=3, 3] -> [9] and the dynamic size is 2,
1196           // the new output dynamic size is 9 / 3 * 2 = 6.
1197           //
1198           // If it turns out the second dimension is also dynamic:
1199           // [<=3, <=3] -> [9], and the dynamic size is also 2, the new output
1200           // dynamic size is 6 / 3 * 2 = 4.
1201           //
1202           //
1203           HloInstruction* output_dynamic_size =
1204               parent_->GetDynamicSize(reshape, {}, output_dynamic_dimension);
1205           if (output_dynamic_size == nullptr) {
1206             output_dynamic_size =
1207                 hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
1208                     LiteralUtil::CreateR0<int32_t>(output_dim_size)));
1209           }
1210           HloInstruction* divisor_hlo = hlo->parent()->AddInstruction(
1211               HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
1212                   operand->shape().dimensions(input_dynamic_dimension))));
1213 
1214           HloInstruction* new_dynamic_size =
1215               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1216                   output_dynamic_size->shape(), HloOpcode::kDivide,
1217                   output_dynamic_size, divisor_hlo));
1218 
1219           new_dynamic_size =
1220               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1221                   output_dynamic_size->shape(), HloOpcode::kMultiply,
1222                   new_dynamic_size, operand_dynamic_size));
1223           parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1224                                   new_dynamic_size);
1225         }
1226 
1227         return OkStatus();
1228       });
1229 }
1230 
1231 Status DynamicDimensionInferenceVisitor::HandleReduceWindow(
1232     HloInstruction* hlo) {
1233   return ForEachOperandDynamicDimension(
1234       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
1235                int64_t operand_index, HloInstruction* dynamic_size) {
1236         auto* reduce_window = Cast<HloReduceWindowInstruction>(hlo);
1237         const WindowDimension& window_dim =
1238             reduce_window->window().dimensions(dimension);
1239 
1240         if (operand_index >= reduce_window->input_count()) {
1241           // Init values don't have dynamic sizes.
1242           return OkStatus();
1243         }
1244 
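        // Example (window parameters assumed): a window of size 3 with stride
        // 2, no dilation, and VALID padding over a dynamic input size d gives
        // an output size of roughly (d - 3) / 2 + 1; GetWindowedOutputSize
        // below emits HLO that computes this at run time.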
1245         if (!window_util::IsTrivialWindowDimension(window_dim)) {
1246           DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
1247               dynamic_size, window_dim.size(), window_dim.window_dilation(),
1248               window_dim.stride(), PaddingType::PADDING_VALID);
1249           dynamic_size = dynamic_window_dims.output_size;
1250         }
1251 
1252         // The dimensions of all data operands of a variadic reduce window have
1253         // to be the same. This means that if one operand of a variadic
1254         // reduce window has a dynamic dimension, we set all outputs to use the
1255         // same dynamic size in corresponding dimensions.
1256         ShapeUtil::ForEachSubshape(
1257             reduce_window->shape(),
1258             [&](const Shape& subshape, ShapeIndex reduce_window_result_index) {
1259               if (!ShapeUtil::IsLeafIndex(reduce_window->shape(),
1260                                           reduce_window_result_index)) {
1261                 return;
1262               }
1263               parent_->SetDynamicSize(reduce_window, reduce_window_result_index,
1264                                       dimension, dynamic_size);
1265             });
1266 
1267         return OkStatus();
1268       });
1269 }
1270 
1271 Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter(
1272     HloInstruction* hlo) {
1273   return ForEachOperandDynamicDimension(
1274       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
1275                int64_t operand_index, HloInstruction* dynamic_size) {
1276         if (operand_index == 1) {
1277           // Operand 0 (input) determines the dynamic output size. We ignore
1278           // the dynamic size in operand 1 (the output gradient).
1279           return OkStatus();
1280         }
1281         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1282 
1283         return OkStatus();
1284       });
1285 }
1286 
1287 Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
1288   return ForEachOperandDynamicDimension(
1289       hlo, [&](HloInstruction* operand, ShapeIndex /*index*/, int64_t dimension,
1290                int64_t /*operand_index*/, HloInstruction* dynamic_size) {
1291         if (hlo->slice_starts(dimension) != 0 ||
1292             hlo->slice_strides(dimension) != 1 ||
1293             hlo->slice_limits(dimension) !=
1294                 operand->shape().dimensions(dimension)) {
1295           // Slicing out a partial range eliminates the dynamic dimension.
1296           return OkStatus();
1297         }
1298 
1299         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1300 
1301         return OkStatus();
1302       });
1303 }
1304 
1305 Status DynamicDimensionInferenceVisitor::HandleDynamicSlice(
1306     HloInstruction* hlo) {
1307   return ForEachOperandDynamicDimension(
1308       hlo, [&](HloInstruction*, ShapeIndex /*index*/, int64_t dimension,
1309                int64_t /*operand_index*/, HloInstruction* dynamic_size) {
1310         if (hlo->shape().dimensions(dimension) !=
1311             hlo->operand(0)->shape().dimensions(dimension)) {
1312           // Slicing a single element out kills the dynamic dimension.
1313           if (hlo->shape().dimensions(dimension) == 1) {
1314             return OkStatus();
1315           }
1316           return Unimplemented(
1317               "Dynamic dimension propagation on DynamicSlice where a partial "
1318               "dimension is selected %s",
1319               hlo->ToString());
1320         }
1321 
1322         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1323 
1324         return OkStatus();
1325       });
1326 }
1327 
1328 Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice(
1329     HloInstruction* hlo) {
1330   return ForEachOperandDynamicDimension(
1331       hlo,
1332       [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64_t dimension,
1333           int64_t operand_index, HloInstruction* dynamic_size) {
1334         if (hlo->shape().dimensions(dimension) !=
1335             hlo->operand(0)->shape().dimensions(dimension)) {
1336           return Unimplemented(
1337               "Dynamic dimension propagation on DynamicUpdateSlice where a "
1338               "partial dimension is selected %s",
1339               hlo->ToString());
1340         }
1341 
1342         if (operand_index == 1 &&
1343             hlo->operand(1)->shape().dimensions(dimension) <
1344                 hlo->operand(0)->shape().dimensions(dimension)) {
1345           // DUS(input=[A], update=[<=B])
1346           //
1347           // If the update dim is smaller than the input dim (B < A), we are
1348           // doing a partial update; no need to set the output dynamic dimension.
1349           //
1350           // The dynamic shape in `update` doesn't change output dynamic shape.
1351           return OkStatus();
1352         }
1353 
1354         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1355 
1356         return OkStatus();
1357       });
1358 }
1359 
1360 Status DynamicDimensionInferenceVisitor::HandleReverse(HloInstruction* hlo) {
1361   return PassThroughDynamicDimension(hlo);
1362 }
1363 
1364 Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
1365   return ForEachOperandDynamicDimension(
1366       hlo, [&](HloInstruction* operand, ShapeIndex /*index*/,
1367                int64_t input_dynamic_dimension, int64_t operand_index,
1368                HloInstruction* dynamic_size) {
1369         const GatherDimensionNumbers& gather_dims =
1370             hlo->gather_dimension_numbers();
1371         if (operand_index != 1) {
1372           if (hlo->gather_slice_sizes()[input_dynamic_dimension] == 1) {
1373             // Gathering a size 1 dimension out of a dynamic dimension removes
1374             // the dynamicity.
1375             return OkStatus();
1376           }
1377           if (hlo->gather_slice_sizes()[input_dynamic_dimension] ==
1378               operand->shape().dimensions(input_dynamic_dimension)) {
1379             // Gathering a full-sized dimension out of a dynamic dimension
1380             // propagates the dynamicity to output.
1381             int64_t output_dimension = input_dynamic_dimension;
1382             for (int64_t collapsed_dim : gather_dims.collapsed_slice_dims()) {
1383               if (collapsed_dim < input_dynamic_dimension) {
1384                 // This output dimension is collapsed.
1385                 output_dimension--;
1386               }
1387             }
1388             parent_->SetDynamicSize(hlo, {}, output_dimension, dynamic_size);
1389             return OkStatus();
1390           }
1391           return Unimplemented(
1392               "Detected a dynamic dimension on the data input of gather, which "
1393               "is not supported: %s, %lld",
1394               hlo->ToString(), input_dynamic_dimension);
1395         }
1396         // Map the dynamic indices dimension to the corresponding output batch
1397         // dimension; output dims in offset_dims() are not batch dimensions.
1398         int64_t indices_rank = hlo->operand(1)->shape().rank();
1399         int64_t output_rank = hlo->shape().rank();
1400 
1401         // indices_dim is an iterator over indices dimensions.
1402         int64_t indices_dim = 0;
1403         // Find the corresponding batch dimension in the output.
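        // Example (shapes assumed): indices operand [<=B, 3] with
        // index_vector_dim = 1 and offset_dims = {1}: output dim 0 is the
        // batch dim corresponding to indices dim 0, so it receives the
        // dynamic size.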
1404         for (int64_t output_dim = 0; output_dim < output_rank; ++output_dim) {
1405           if (!absl::c_linear_search(gather_dims.offset_dims(), output_dim)) {
1406             // Skips index vector dimension.
1407             if (indices_dim == gather_dims.index_vector_dim()) {
1408               indices_dim++;
1409             }
1410             if (indices_dim++ == input_dynamic_dimension) {
1411               parent_->SetDynamicSize(hlo, {}, output_dim, dynamic_size);
1412               return OkStatus();
1413             }
1414           }
1415         }
1416         CHECK(indices_dim == indices_rank);
1417 
1418         return Unimplemented(
1419             "Detected a non-batch dynamic dimension of gather, "
1420             "which is not supported: %s",
1421             hlo->ToString());
1422       });
1423 }
1424 
1425 Status DynamicDimensionInferenceVisitor::HandleConditional(
1426     HloInstruction* hlo) {
1427   // Conditionals are handled by producing additional inputs and outputs of
1428   // the conditional instruction.
1429   std::vector<HloComputation*> new_branch_computations;
1430   std::vector<HloInstruction*> new_operands;
1431   // If the output of the conditional contains a dynamic dimension, we send
1432   // the dynamic dimension size out by adding an additional root element. A
1433   // mapping from the root instruction's dynamic dimension index (represented
1434   // by a shape index as output index and an int64_t dimension number) to
1435   // output index (represented by an int64_t) is tracked for the conditional
1436   // instruction (all branches should have the same mapping).
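  // For instance (shape assumed): a conditional whose branches return
  // (s32[<=4]) is rewritten so each branch returns (s32[<=4], s32[]), with the
  // trailing s32[] carrying the runtime size of the dynamic dimension.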
1437   ShapeTree<absl::flat_hash_map<int64_t, int64_t>> dynamic_output_mapping(
1438       hlo->shape());
1439 
1440   bool need_rewrite = false;
1441   for (int64_t branch_index = 0; branch_index < hlo->branch_count();
1442        ++branch_index) {
1443     std::vector<HloInstruction*> operands_to_add;
1444 
1445     absl::flat_hash_map<HloInstruction*, int64_t>
1446         dynamic_size_to_operand_id_index_map;
1447     // Only look at branch_index + 1, the correct operand index for a
1448     // given branch.
1449     const int64_t operand_index = branch_index + 1;
1450 
1451     int operand_count =
1452         hlo->operand(operand_index)->shape().tuple_shapes_size();
1453     // Prepare to pass dynamic dimensions into the new computation and add
1454     // dynamic dimension sizes as parameters to the new tuple.
1455     TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
1456         hlo, operand_index,
1457         [&](HloInstruction*, ShapeIndex, int64_t, int64_t,
1458             HloInstruction* dynamic_size) -> Status {
1459           TF_RET_CHECK(hlo->operand(operand_index)->shape().IsTuple())
1460               << "Only tuple-typed inputs can have dynamic dimensions. Please "
1461                  "file a bug against the XLA team.";
1462           const HloInstruction* tuple_operand = hlo->operand(operand_index);
1463           for (int64_t i = 0; i < tuple_operand->operand_count(); ++i) {
1464             // If the dynamic size is already an operand to the computation,
1465             // skip adding it to the computation input again.
1466             if (dynamic_size == tuple_operand->operand(i)) {
1467               dynamic_size_to_operand_id_index_map[dynamic_size] = i;
1468               return OkStatus();
1469             }
1470           }
1471           auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size);
1472           if (iter == dynamic_size_to_operand_id_index_map.end()) {
1473             operands_to_add.push_back(dynamic_size);
1474             dynamic_size_to_operand_id_index_map[dynamic_size] =
1475                 operand_count++;
1476           }
1477           return OkStatus();
1478         }));
1479 
1480     HloInstruction* original_input = hlo->mutable_operand(operand_index);
1481     HloComputation* branch_computation = hlo->branch_computation(branch_index);
1482 
1483     HloComputation* new_computation = branch_computation;
1484     HloInstruction* new_operand = hlo->mutable_operand(operand_index);
1485     if (!operands_to_add.empty()) {
1486       TF_RET_CHECK(original_input->shape().IsTuple());
1487       need_rewrite = true;
1488       new_operand = TupleUtil::AppendSuffix(original_input, operands_to_add);
1489       TF_ASSIGN_OR_RETURN(
1490           new_computation,
1491           WidenComputation(branch_computation, new_operand->shape()));
1492     }
1493     // Set the dynamic dimensions for the newly created branch computation's
1494     // parameters so that the hlos inside the computation can see dynamic
1495     // dimensions.
1496     DynamicParameterBinding dynamic_parameter_binding;
1497     TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
1498         hlo, operand_index,
1499         [&](HloInstruction*, ShapeIndex index, int64_t dimension,
1500             int64_t operand_index, HloInstruction* dynamic_size) {
1501           DynamicParameterBinding::DynamicParameter dynamic_parameter{
1502               0, {dynamic_size_to_operand_id_index_map[dynamic_size]}};
1503           DynamicParameterBinding::DynamicDimension dynamic_dimension{
1504               0, {index}, dimension};
1505           TF_RETURN_IF_ERROR(dynamic_parameter_binding.Bind(dynamic_parameter,
1506                                                             dynamic_dimension));
1507 
1508           return OkStatus();
1509         }));
1510     VLOG(2) << "dynamic_parameter_binding for conditional branch"
1511             << dynamic_parameter_binding;
1512     TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
1513         new_computation, dynamic_parameter_binding, parent_));
1514 
1515     new_branch_computations.push_back(new_computation);
1516     new_operands.push_back(new_operand);
1517   }
1518   int tuple_count = hlo->shape().tuple_shapes_size();
1519   // The dynamism of the output of branches can be different.
1520   // E.g.,
1521   //   true_branch  (s32[<=4])
1522   //   false_branch (s32[4])
1523   //
1524   // The following loop populates dynamic_output_mapping and accounts for
1525   // dynamism across all branches.
1526   ShapeUtil::ForEachSubshape(
1527       hlo->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
1528         if (!subshape.IsArray()) {
1529           return;
1530         }
1531         for (int64_t i = 0; i < subshape.rank(); ++i) {
1532           for (int64_t j = 0; j < new_branch_computations.size(); ++j) {
1533             HloInstruction* dynamic_size = parent_->GetDynamicSize(
1534                 new_branch_computations[j]->root_instruction(), index, i);
1535             if (dynamic_size) {
1536               if (dynamic_output_mapping.element(index).contains(i)) {
1537                 continue;
1538               }
1539               dynamic_output_mapping.mutable_element(index)->emplace(
1540                   i, tuple_count++);
1541             }
1542           }
1543         }
1544       });
1545   for (int64_t branch_index = 0; branch_index < hlo->branch_count();
1546        ++branch_index) {
1547     std::vector<HloInstruction*> hlos_to_add_in_root;
1548     // There may be some dynamic dimensions coming out of the computation; wire
1549     // them into the root instruction as additional tuple elements.
1550     ShapeUtil::ForEachSubshape(
1551         hlo->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
1552           if (!subshape.IsArray()) {
1553             return;
1554           }
1555           for (int64_t i = 0; i < subshape.rank(); ++i) {
1556             if (dynamic_output_mapping.element(index).contains(i)) {
1557               HloInstruction* dynamic_size = parent_->GetDynamicSize(
1558                   new_branch_computations[branch_index]->root_instruction(),
1559                   index, i);
1560               if (dynamic_size) {
1561                 hlos_to_add_in_root.push_back(dynamic_size);
1562               } else {
1563                 HloInstruction* constant_size =
1564                     new_branch_computations[branch_index]->AddInstruction(
1565                         HloInstruction::CreateConstant(
1566                             LiteralUtil::CreateR0<int32_t>(
1567                                 subshape.dimensions(i))));
1568                 hlos_to_add_in_root.push_back(constant_size);
1569               }
1570             }
1571           }
1572         });
1573 
1574     VLOG(2) << "hlos_to_add_in_root:" << hlos_to_add_in_root.size();
1575     if (!hlos_to_add_in_root.empty()) {
1576       need_rewrite = true;
1577       HloInstruction* new_branch_root = TupleUtil::AppendSuffix(
1578           new_branch_computations[branch_index]->root_instruction(),
1579           hlos_to_add_in_root);
1580       new_branch_computations[branch_index]->set_root_instruction(
1581           new_branch_root,
1582           /*accept_different_shape=*/true);
1583     }
1584   }
1585 
1586   if (!need_rewrite) {
1587     return OkStatus();
1588   }
1589   // Create a new conditional with the new operations and computations.
1590   HloInstruction* new_conditional =
1591       hlo->parent()->AddInstruction(HloInstruction::CreateConditional(
1592           new_branch_computations[0]->root_instruction()->shape(),
1593           hlo->mutable_operand(0), new_branch_computations, new_operands));
1594 
1595   HloInstruction* new_conditional_extracted = TupleUtil::ExtractPrefix(
1596       new_conditional, hlo->shape().tuple_shapes_size());
1597   // Now set the dynamic dimensions of the newly created conditional.
1598   dynamic_output_mapping.ForEachElement(
1599       [&](const ShapeIndex& index,
1600           const absl::flat_hash_map<int64_t, int64_t>& dim_to_output) {
1601         for (auto iter : dim_to_output) {
1602           int64_t dim = iter.first;
1603           int64_t output_index = iter.second;
1604           HloInstruction* dynamic_size = hlo->parent()->AddInstruction(
1605               HloInstruction::CreateGetTupleElement(
1606                   ShapeUtil::MakeScalarShape(S32), new_conditional,
1607                   output_index));
1608           parent_->SetDynamicSize(new_conditional, index, dim, dynamic_size);
1609           parent_->SetDynamicSize(new_conditional_extracted, index, dim,
1610                                   dynamic_size);
1611         }
1612       });
1613 
1614   TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_conditional_extracted));
1615   // Remove the original instruction even if it has side effects.
1616   TF_RETURN_IF_ERROR(hlo->parent()->RemoveInstruction(hlo));
1617   SetVisited(*new_conditional);
1618   SetVisited(*new_conditional_extracted);
1619   return OkStatus();
1620 }
1621 
1622 Status DynamicDimensionInferenceVisitor::HandleMap(HloInstruction* hlo) {
1623   return HandleElementwiseNary(hlo);
1624 }
1625 
1626 Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) {
1627   return ForEachOperandDynamicDimension(
1628       hlo,
1629       [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64_t dimension,
1630           int64_t operand_index, HloInstruction* operand_dynamic_size) {
1631         if (operand_index == 0) {
1632           parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size);
1633           return OkStatus();
1634         }
1635 
1636         const ScatterDimensionNumbers& scatter_dims =
1637             hlo->scatter_dimension_numbers();
1638         if (operand_index == 2 &&
1639             absl::c_linear_search(scatter_dims.update_window_dims(),
1640                                   dimension)) {
1641           // Dynamic update window dimension is only allowed if it is exactly
1642           // the same as the corresponding operand dimension.
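          // Example (shapes assumed): operand rank 3 with
          // inserted_window_dims = {0} gives update_window_dims_in_operand =
          // {1, 2}, so update_window_dims(i) == dimension maps the dynamic
          // update dim back to operand dim update_window_dims_in_operand[i].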
1643           std::vector<int64_t> update_window_dims_in_operand;
1644           for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
1645             if (absl::c_linear_search(scatter_dims.inserted_window_dims(), i)) {
1646               continue;
1647             }
1648             update_window_dims_in_operand.push_back(i);
1649           }
1650 
1651           for (int64_t i = 0; i < scatter_dims.update_window_dims_size(); ++i) {
1652             if (scatter_dims.update_window_dims(i) == dimension) {
1653               const Shape& operand_shape = hlo->operand(0)->shape();
1654               const Shape& update_shape = hlo->operand(2)->shape();
1655               int64_t dim_in_operand = update_window_dims_in_operand[i];
1656               if (operand_shape.dimensions(dim_in_operand) !=
1657                       update_shape.dimensions(dimension) ||
1658                   !operand_shape.is_dynamic_dimension(dim_in_operand)) {
1659                 return Unimplemented(
1660                     "Dynamic dimension of update window dims that are not the "
1661                     "same as corresponding operand dim is not supported: "
1662                     "%s",
1663                     hlo->ToString());
1664               }
1665               HloInstruction* base_dynamic_size = parent_->GetDynamicSize(
1666                   hlo->mutable_operand(0), {}, dim_in_operand);
1667               if (base_dynamic_size != operand_dynamic_size) {
1668                 return Unimplemented(
1669                     "Dynamic dimension size of update window dims that are not "
1670                     "the same as corresponding operand dim is not supported: "
1671                     "%s.\n Dynamic dim size of base: %s, dynamic dim size of "
1672                     "update: %s",
1673                     hlo->ToString(), base_dynamic_size->ToString(),
1674                     operand_dynamic_size->ToString());
1675               }
1676             }
1677           }
1678         }
1679         // The dynamic dimension is collapsed and won't show up in the output.
1680         // Do nothing here.
1681         return OkStatus();
1682       });
1683 }
1684 
1685 Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
1686   // If the output of the kWhile contains a dynamic dimension, we send the
1687   // dynamic dimension size into the while body by adding an additional root/body
1688   // element. A mapping from the root instruction's dynamic dimension index
1689   // (represented by a shape index as output index and an int64_t dimension
1690   // number) to output index (represented by an int64_t) is tracked for the
1691   // while instruction.
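  // For example (shape assumed): a while over (f32[<=8]) is rewritten via
  // WhileUtil so that the loop carries (f32[<=8], s32[]), where the trailing
  // s32[] holds the current dynamic size and may be updated by the body on
  // each iteration.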
1692   ShapeTree<absl::flat_hash_map<int64_t, int64_t>> dynamic_output_mapping(
1693       hlo->shape());
1694   std::vector<HloInstruction*> operands_to_add;
1695   const int original_tuple_count = hlo->shape().tuple_shapes_size();
1696   int operand_count = original_tuple_count;
1697   TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
1698       hlo, [&](HloInstruction*, ShapeIndex index, int64_t dim, int64_t,
1699                HloInstruction* dynamic_size) {
1700         operands_to_add.push_back(dynamic_size);
1701         dynamic_output_mapping.mutable_element(index)->emplace(dim,
1702                                                                operand_count++);
1703         return OkStatus();
1704       }));
1705   ShapeUtil::ForEachSubshape(
1706       hlo->while_body()->root_instruction()->shape(),
1707       [&](const Shape& subshape, const ShapeIndex& index) {
1708         if (!subshape.IsArray()) {
1709           return;
1710         }
1711         for (int64_t dim = 0; dim < subshape.rank(); ++dim) {
1712           if (subshape.is_dynamic_dimension(dim)) {
1713             if (!dynamic_output_mapping.mutable_element(index)->contains(dim)) {
1714               // This dynamic dimension doesn't come from an operand, but is
1715               // generated in the middle of the while body. Its initial size
1716               // should be static.
1717               operands_to_add.push_back(hlo->parent()->AddInstruction(
1718                   HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
1719                       subshape.dimensions(dim)))));
1720               dynamic_output_mapping.mutable_element(index)->emplace(
1721                   dim, operand_count++);
1722             }
1723           }
1724         }
1725       });
1726   DynamicParameterBinding binding_for_while;
1727   if (!operands_to_add.empty()) {
1728     // Only replace the while loop if there are new parameters to add.
1729     HloInstruction* old_tuple_operand = hlo->mutable_operand(0);
1730     TF_ASSIGN_OR_RETURN(
1731         WhileUtil::MakeInstructionsLiveInResult result,
1732         WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add));
1733     // WhileUtil creates a new while hlo and tuple. Update the dynamic size
1734     // mapping for the newly created tuple.
1735     HloInstruction* new_tuple_operand =
1736         result.new_while_instr->mutable_operand(0);
1737     parent_->CopyMapping(/*from=*/old_tuple_operand,
1738                          /*to=*/new_tuple_operand);
1739     hlo = result.new_while_instr;
1740     // We have replaced the while loop; now set the dynamic dimensions for the
1741     // newly created while loop so that the hlos that consume the while loop
1742     // can see the dynamic dimensions. Also set the dynamic parameter binding
1743     // for running inference in the while loop.
1744     TF_RETURN_IF_ERROR(dynamic_output_mapping.ForEachElementWithStatus(
1745         [&](const ShapeIndex& index,
1746             const absl::flat_hash_map<int64_t, int64_t>& dim_to_size) {
1747           for (auto key : dim_to_size) {
1748             int64_t dimension = key.first;
1749             const int64_t output_dynamic_size_index = key.second;
1750             DynamicParameterBinding::DynamicParameter dynamic_parameter{
1751                 0, {output_dynamic_size_index}};
1752             DynamicParameterBinding::DynamicDimension dynamic_dimension{
1753                 0, index, dimension};
1754             TF_RETURN_IF_ERROR(
1755                 binding_for_while.Bind(dynamic_parameter, dynamic_dimension));
1756             // This is the updated output dynamic size coming out of the
1757             // while loop.
1758             HloInstruction* output_dynamic_size = hlo->parent()->AddInstruction(
1759                 HloInstruction::CreateGetTupleElement(
1760                     ShapeUtil::MakeScalarShape(S32), hlo,
1761                     output_dynamic_size_index));
1762             parent_->SetDynamicSize(result.replacement_instr, index, dimension,
1763                                     output_dynamic_size);
1764           }
1765           return OkStatus();
1766         }));
1767     // Set the replacement instruction as visited to avoid visiting it again.
1768     SetVisited(*result.replacement_instr);
1769   }
1770   // Run inference in while body and condition.
1771   TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
1772       hlo->while_body(), binding_for_while, parent_));
1773   TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
1774       hlo->while_condition(), binding_for_while, parent_));
1775 
1776   if (operands_to_add.empty()) {
1777     // No dynamic dimension in the inputs and outputs.
1778     return OkStatus();
1779   }
1780 
1781   // The dynamic dimension size could have been changed in the loop body (e.g.,
1782   // a loop that pushes items onto a stack, where the stack size increases with
1783   // each iteration). Rewrite the dynamic dimension size at the root.
1784   HloInstruction* body_root = hlo->while_body()->root_instruction();
1785   std::vector<HloInstruction*> new_root_operands(body_root->operand_count(),
1786                                                  nullptr);
1787 
1788   // Original non-dynamic-dim operands of root are pass-through.
1789   for (int i = 0; i < original_tuple_count; ++i) {
1790     new_root_operands[i] =
1791         hlo->while_body()->AddInstruction(HloInstruction::CreateGetTupleElement(
1792             body_root->shape().tuple_shapes(i), body_root, i));
1793   }
1794   // Fill in the dynamic dimension sizes as the additional root operands.
1795   TF_RETURN_IF_ERROR(ForEachDynamicDimension(
1796       hlo->while_body()->root_instruction(),
1797       [&](ShapeIndex index, int64_t dim,
1798           HloInstruction* dynamic_size) -> Status {
1799         const int64_t output_index =
1800             dynamic_output_mapping.element(index).at(dim);
1801         new_root_operands[output_index] = dynamic_size;
1802         return OkStatus();
1803       }));
1804   for (auto operand : new_root_operands) {
1805     TF_RET_CHECK(operand != nullptr);
1806   }
1807   HloInstruction* new_body_root = hlo->while_body()->AddInstruction(
1808       HloInstruction::CreateTuple(new_root_operands));
1809   hlo->while_body()->set_root_instruction(new_body_root);
1810   return OkStatus();
1811 }
1812 
1813 Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) {
1814   return param_bindings_.ForEachBinding(
1815       [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
1816           const DynamicParameterBinding::DynamicDimension& dynamic_dimension) {
1817         if (dynamic_dimension.parameter_num != hlo->parameter_number()) {
1818           return OkStatus();
1819         }
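        // Example (binding assumed): DynamicParameter {param 0, index {1}}
        // bound to DynamicDimension {param 0, index {0}, dim 0} produces a
        // get-tuple-element(param, 1) whose value is recorded as the dynamic
        // size of dim 0 at shape index {0} of the parameter.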
1820         HloComputation* computation = hlo->parent();
1821         HloInstruction* target_parameter =
1822             computation->parameter_instruction(dynamic_dimension.parameter_num);
1823 
1824         HloInstruction* dynamic_size =
1825             computation->parameter_instruction(dynamic_parameter.parameter_num);
1826         for (int64_t i : dynamic_parameter.parameter_index) {
1827           dynamic_size =
1828               computation->AddInstruction(HloInstruction::CreateGetTupleElement(
1829                   ShapeUtil::GetSubshape(dynamic_size->shape(), {i}),
1830                   dynamic_size, i));
1831         }
1832 
1833         parent_->SetDynamicSize(target_parameter,
1834                                 dynamic_dimension.parameter_index,
1835                                 dynamic_dimension.dimension, dynamic_size);
1836         return OkStatus();
1837       });
1838 }
1839 
1840 Status DynamicDimensionInferenceVisitor::ForEachDynamicDimension(
1841     HloInstruction* inst, const DynamicDimensionFn& fn) {
1842   auto iter = parent_->per_hlo_dynamic_dimensions_.find(inst);
1843   if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
1844     for (auto& dynamic_dimension : iter->second) {
1845       HloInstruction* dynamic_size = parent_->GetDynamicSize(
1846           dynamic_dimension.inst, dynamic_dimension.index,
1847           dynamic_dimension.dim);
1848       TF_RETURN_IF_ERROR(
1849           fn(dynamic_dimension.index, dynamic_dimension.dim, dynamic_size));
1850     }
1851   }
1852   return OkStatus();
1853 }
1854 
1855 Status DynamicDimensionInferenceVisitor::InsertShapeCheck(
1856     HloInstruction* dim1, HloInstruction* dim2,
1857     bool support_implicit_broadcast) {
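  // In kRuntime mode, each dim1 == dim2 check is lowered to a kEq compare;
  // successive checks are combined with kAnd into the accumulated
  // shape_assertion_.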
1858   switch (shape_check_mode_) {
1859     case DynamicDimensionInference::kIgnore:
1860       return OkStatus();
1861     case DynamicDimensionInference::kCompileTime:
1862       return InvalidArgument(
1863           "Failed to prove the equality of two dimensions at compile time: "
1864           "%s vs %s",
1865           dim1->ToString(), dim2->ToString());
1866     case DynamicDimensionInference::kRuntime: {
1867       TF_ASSIGN_OR_RETURN(
1868           HloInstruction * assertion,
1869           MakeCompareHlo(Comparison::Direction::kEq, dim1, dim2));
1870       if (shape_assertion_ == nullptr) {
1871         shape_assertion_ = assertion;
1872       } else {
1873         TF_ASSIGN_OR_RETURN(
1874             shape_assertion_,
1875             MakeBinaryHlo(HloOpcode::kAnd, shape_assertion_, assertion));
1876       }
1877       return OkStatus();
1878     }
1879     default:
1880       LOG(FATAL) << "Unreachable";
1881   }
1882 }
1883 
1884 Status DynamicDimensionInferenceVisitor::ForEachDynamicDimensionInOperand(
1885     HloInstruction* inst, int64_t operand_index,
1886     const OperandDynamicDimensionFn& fn) {
1887   auto iter =
1888       parent_->per_hlo_dynamic_dimensions_.find(inst->operand(operand_index));
1889   if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
1890     for (auto& dynamic_dimension : iter->second) {
1891       HloInstruction* dynamic_size = parent_->GetDynamicSize(
1892           dynamic_dimension.inst, dynamic_dimension.index,
1893           dynamic_dimension.dim);
1894       TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index,
1895                             dynamic_dimension.dim, operand_index,
1896                             dynamic_size));
1897     }
1898   }
1899   return OkStatus();
1900 }
1901 
1902 Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension(
1903     HloInstruction* inst, const OperandDynamicDimensionFn& fn) {
1904   for (int64_t operand_index = 0; operand_index < inst->operand_count();
1905        ++operand_index) {
1906     TF_RETURN_IF_ERROR(
1907         ForEachDynamicDimensionInOperand(inst, operand_index, fn));
1908   }
1909   return OkStatus();
1910 }
1911 
1912 void DynamicDimensionInference::SetDynamicSize(HloInstruction* inst,
1913                                                const ShapeIndex& index,
1914                                                int64_t dim,
1915                                                HloInstruction* size) {
1916   VLOG(1) << "Set dimension inst " << inst->ToString() << " index "
1917           << index.ToString() << "@" << dim << " to " << size->ToShortString();
1918   Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index);
1919   CHECK(!subshape.IsTuple()) << "Can't set a tuple shape to dynamic dimension";
1920   CHECK(dim < subshape.rank() && dim >= 0)
1921       << "Asked to set invalid dynamic dimension. Shape: "
1922       << subshape.ToString() << ", Dimension: " << dim;
1923   DynamicDimension dynamic_dimension{inst, index, dim};
1924   // Updating a dynamic dimension twice overwrites the previous one.
1925   dynamic_mapping_[dynamic_dimension] = size;
1926   auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst);
1927   iter.first->second.emplace(dynamic_dimension);
1928 }
1929 
1930 void DynamicDimensionInference::CopyMapping(HloInstruction* from,
1931                                             HloInstruction* to) {
1932   auto iter = per_hlo_dynamic_dimensions_.find(from);
1933   if (iter != per_hlo_dynamic_dimensions_.end()) {
1934     for (auto& dynamic_dimension : iter->second) {
1935       HloInstruction* dynamic_size =
1936           GetDynamicSize(dynamic_dimension.inst, dynamic_dimension.index,
1937                          dynamic_dimension.dim);
1938       SetDynamicSize(to, dynamic_dimension.index, dynamic_dimension.dim,
1939                      dynamic_size);
1940     }
1941   }
1942 }
1943 
1944 /* static */
1945 StatusOr<DynamicDimensionInference> DynamicDimensionInference::Run(
1946     HloModule* module, CustomCallInferenceHandler custom_call_handler,
1947     ShapeCheckMode shape_check_mode,
1948     const AssertionGenerator& assertion_generator) {
1949   VLOG(2) << "Param Config " << module->dynamic_parameter_binding().ToString();
1950   DynamicDimensionInference inference(module, std::move(custom_call_handler),
1951                                       shape_check_mode, assertion_generator);
1952   TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions());
1953   return inference;
1954 }
1955 
1956 std::string DynamicDimensionInference::ToString() const {
1957   std::vector<std::string> pieces;
1958   pieces.push_back("DynamicDimensionInference: ");
1959   for (const auto& mapping : dynamic_mapping_) {
1960     const DynamicDimension& dynamic_dimension = mapping.first;
1961     pieces.push_back(absl::StrFormat(
1962         " -- instruction %s at %s has dim %lld as dynamic"
1963         " dimension, which is represented by instruction %s",
1964         dynamic_dimension.inst->ToString(), dynamic_dimension.index.ToString(),
1965         dynamic_dimension.dim, mapping.second->ToString()));
1966   }
1967   return absl::StrJoin(pieces, "\n");
1968 }
1969 
1970 DynamicDimensionInference::DynamicDimensionInference(
1971     HloModule* module, CustomCallInferenceHandler custom_call_handler,
1972     ShapeCheckMode shape_check_mode, AssertionGenerator assertion_generator)
1973     : module_(module),
1974       custom_call_handler_(std::move(custom_call_handler)),
1975       shape_check_mode_(shape_check_mode),
1976       assertion_generator_(assertion_generator) {}
1977 
1978 Status DynamicDimensionInference::AnalyzeDynamicDimensions() {
1979   return DynamicDimensionInferenceVisitor::Run(
1980       module_->entry_computation(), module_->dynamic_parameter_binding(), this,
1981       custom_call_handler_, shape_check_mode_, assertion_generator_);
1982 }
1983 
1984 void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith(
1985     HloInstruction* replace, HloInstruction* with) {
1986   CHECK(Shape::Equal().IgnoreLayout()(replace->shape(),
1987                                       ShapeUtil::MakeScalarShape(S32)));
1988   CHECK(Shape::Equal().IgnoreLayout()(with->shape(),
1989                                       ShapeUtil::MakeScalarShape(S32)));
1990   for (auto& kv : dynamic_mapping_) {
1991     if (kv.second == replace) {
1992       kv.second = with;
1993     }
1994   }
1995 }
1996 
1997 Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
1998                                                      HloInstruction* new_inst,
1999                                                      const ShapeIndex& index) {
2000   CHECK(Shape::Equal()(inst->shape(), new_inst->shape()));
2001 
2002   for (int64_t dim = 0; dim < inst->shape().rank(); ++dim) {
2003     DynamicDimension dynamic_dimension_new{new_inst, index, dim};
2004     DynamicDimension dynamic_dimension{inst, index, dim};
2005     auto iter = dynamic_mapping_.find(dynamic_dimension);
2006     if (iter != dynamic_mapping_.end()) {
2007       dynamic_mapping_.insert({dynamic_dimension_new, iter->second});
2008       auto iter = per_hlo_dynamic_dimensions_.try_emplace(new_inst);
2009       iter.first->second.emplace(dynamic_dimension_new);
2010     }
2011   }
2012 
2013   return OkStatus();
2014 }
2015 
2016 bool DynamicDimensionInference::HasDynamicDimension(
2017     HloInstruction* inst, ShapeIndexView index) const {
2018   bool has_dynamic_dim = false;
2019   ShapeUtil::ForEachSubshape(inst->shape(), [&](const Shape& subshape,
2020                                                 const ShapeIndex& subindex) {
2021     if (subshape.IsTuple()) {
2022       return;
2023     }
2024     if (ShapeIndexView(subindex).first(index.size()) != index) {
2025       return;
2026     }
2027     for (int64_t i = 0; i < subshape.dimensions_size(); ++i) {
2028       HloInstruction* operand_dynamic_size = GetDynamicSize(inst, subindex, i);
2029       if (operand_dynamic_size != nullptr) {
2030         has_dynamic_dim = true;
2031       }
2032     }
2033   });
2034   return has_dynamic_dim;
2035 }
2036 
2037 Status DynamicDimensionInference::Update(HloInstruction* inst) {
2038   DynamicParameterBinding parameter_binding;
2039   DynamicDimensionInferenceVisitor visitor(
2040       parameter_binding, this, custom_call_handler_, shape_check_mode_);
2041   return inst->Visit(&visitor);
2042 }
2043 
2044 HloInstruction* DynamicDimensionInference::GetDynamicSize(
2045     HloInstruction* inst, const ShapeIndex& index, int64_t dim) const {
2046   auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim});
2047   if (iter != dynamic_mapping_.end()) {
2048     return iter->second;
2049   }
2050   return nullptr;
2051 }
2052 
2053 std::vector<HloInstruction*> DynamicDimensionInference::GetDynamicSizes(
2054     HloInstruction* inst, const ShapeIndex& index) const {
2055   CHECK(ShapeUtil::IndexIsValid(inst->shape(), index));
2056   const int64_t rank = ShapeUtil::GetSubshape(inst->shape(), index).rank();
2057   std::vector<HloInstruction*> result(rank, nullptr);
2058   for (int64_t i = 0; i < rank; ++i) {
2059     result[i] = GetDynamicSize(inst, index, i);
2060   }
2061   return result;
2062 }
2063 
2064 }  // namespace xla
2065