1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
17
18 #include <vector>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/match.h"
22 #include "tensorflow/compiler/xla/literal_util.h"
23 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
24 #include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
30 #include "tensorflow/compiler/xla/service/hlo_module.h"
31 #include "tensorflow/compiler/xla/service/tuple_util.h"
32 #include "tensorflow/compiler/xla/service/while_util.h"
33 #include "tensorflow/compiler/xla/shape_tree.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/compiler/xla/status_macros.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/compiler/xla/window_util.h"

namespace xla {
39
40 namespace {
41 // Replace `narrow_comp` with a new computation with `wide_shape` as input.
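// The widened computation ignores the extra trailing tuple elements: it
// extracts the leading elements matching `narrow_comp`'s parameter shape and
// feeds them to an inlined copy of `narrow_comp`.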
StatusOr<HloComputation*> WidenComputation(HloComputation* narrow_comp,
                                           const Shape& wide_shape) {
44 TF_RET_CHECK(wide_shape.IsTuple());
45 const Shape& narrow_shape = narrow_comp->parameter_instruction(0)->shape();
46 if (Shape::Equal()(wide_shape, narrow_shape)) {
47 // No need to widen the computation.
48 return narrow_comp;
49 }
50 HloComputation* wide_comp = [&]() {
51 HloComputation::Builder builder(absl::StrCat("wide.", narrow_comp->name()));
52 builder.AddInstruction(
53 HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
54 return narrow_comp->parent()->AddEmbeddedComputation(builder.Build());
55 }();
56
57 HloInstruction* wide_parameter = wide_comp->parameter_instruction(0);
58 HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
59 wide_parameter, narrow_shape.tuple_shapes_size());
60 HloInstruction* call_narrow_comp = wide_comp->AddInstruction(
61 HloInstruction::CreateCall(narrow_comp->root_instruction()->shape(),
62 {truncated_parameter}, narrow_comp));
63 wide_comp->set_root_instruction(call_narrow_comp,
64 /*accept_different_shape=*/true);
65 TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_comp).status());
66 return wide_comp;
67 }
68 } // namespace
69
70 class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
71 public:
  explicit DynamicDimensionInferenceVisitor(
      const DynamicParameterBinding& param_bindings,
      DynamicDimensionInference* parent,
      DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler,
      DynamicDimensionInference::ShapeCheckMode shape_check_mode)
77 : param_bindings_(param_bindings),
78 parent_(parent),
79 custom_call_handler_(std::move(custom_call_handler)),
80 shape_check_mode_(shape_check_mode) {}
81
82 Status DefaultAction(HloInstruction* hlo) override;
83
  static Status Run(HloComputation* computation,
                    const DynamicParameterBinding& param_bindings,
                    DynamicDimensionInference* parent,
                    DynamicDimensionInference::CustomCallInferenceHandler
                        custom_call_handler = nullptr,
                    DynamicDimensionInference::ShapeCheckMode shape_check_mode =
                        DynamicDimensionInference::ShapeCheckMode::kIgnore,
                    const DynamicDimensionInference::AssertionGenerator&
                        assertion_generator = nullptr) {
93 DynamicDimensionInferenceVisitor visitor(param_bindings, parent,
94 std::move(custom_call_handler),
95 shape_check_mode);
96
97 TF_RETURN_IF_ERROR(computation->Accept(&visitor));
98 if (visitor.shape_assertion_ != nullptr) {
99 CHECK(assertion_generator);
100 assertion_generator(visitor.shape_assertion_);
101 }
    return OkStatus();
103 }
104
105 Status HandleParameter(HloInstruction* hlo) override;
106
107 Status HandleReduce(HloInstruction* hlo) override;
108
109 Status HandleDot(HloInstruction* hlo) override;
110
111 Status HandleTuple(HloInstruction* hlo) override;
112
113 Status HandleTranspose(HloInstruction* hlo) override;
114
115 Status HandleDynamicReshape(HloInstruction* hlo) override;
116
117 Status HandleReshape(HloInstruction* hlo) override;
118
119 Status HandleSort(HloInstruction* hlo) override;
120
121 Status HandlePad(HloInstruction* hlo) override;
122
123 Status HandleCustomCall(HloInstruction* hlo) override;
124
125 Status HandleBroadcast(HloInstruction* hlo) override;
126
127 Status HandleGetDimensionSize(HloInstruction* hlo) override;
128
129 Status HandleSetDimensionSize(HloInstruction* hlo) override;
130
131 Status HandleSelect(HloInstruction* hlo) override;
132
133 Status HandleConvolution(HloInstruction* hlo) override;
134
135 Status HandleConcatenate(HloInstruction* hlo) override;
136
137 Status HandleReduceWindow(HloInstruction* hlo) override;
138
139 Status HandleReverse(HloInstruction* hlo) override;
140
141 Status HandleSelectAndScatter(HloInstruction* hlo) override;
142
143 Status HandleGetTupleElement(HloInstruction* hlo) override;
144
145 Status HandleElementwiseUnary(HloInstruction* hlo) override;
146
147 Status HandleElementwiseNary(HloInstruction* hlo);
148
149 Status HandleElementwiseBinary(HloInstruction* hlo) override;
150
151 Status HandleClamp(HloInstruction* hlo) override;
152
153 Status HandleConditional(HloInstruction* hlo) override;
154
155 Status HandleWhile(HloInstruction* hlo) override;
156
157 Status HandleSlice(HloInstruction* hlo) override;
158
159 Status HandleDynamicSlice(HloInstruction* hlo) override;
160
161 Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
162
163 Status HandleGather(HloInstruction* hlo) override;
164
165 Status HandleScatter(HloInstruction* hlo) override;
166
167 Status HandleMap(HloInstruction* hlo) override;
168
169 Status HandleDomain(HloInstruction* hlo) override;
170
171 private:
172 using OperandDynamicDimensionFn = std::function<Status(
173 HloInstruction* operand, ShapeIndex index, int64_t dimension,
174 int64_t operand_index, HloInstruction* dynamic_size)>;
175
176 using DynamicDimensionFn = std::function<Status(
177 ShapeIndex index, int64_t dimension, HloInstruction* dynamic_size)>;
178
179 Status HandleDynamicConvolutionForward(HloInstruction* hlo,
180 int64_t operand_index,
181 int64_t dimension,
182 HloInstruction* dynamic_size);
183
184 Status HandleDynamicConvolutionKernelGrad(HloInstruction* hlo,
185 int64_t operand_index,
186 int64_t dimension);
187
188 Status HandleDynamicConvolutionInputGrad(HloInstruction* hlo,
189 int64_t operand_index,
190 int64_t dimension);
191
192 Status HandleDynamicWindowSamePadding(HloInstruction* hlo,
193 HloInstruction* dynamic_size,
194 int64_t operand_index,
195 int64_t dimension);
196
197 Status ForEachOperandDynamicDimension(HloInstruction* inst,
198 const OperandDynamicDimensionFn&);
199 Status ForEachDynamicDimensionInOperand(HloInstruction* inst,
200 int64_t operand_index,
201 const OperandDynamicDimensionFn&);
202 Status ForEachDynamicDimension(HloInstruction* inst,
203 const DynamicDimensionFn& fn);
204
205 // Insert shape check to make sure `dim1` is equal to `dim2`. If
206 // support_implicit_broadcast is true, the check will pass if either of them
207 // is 1, even if they are different.
208 Status InsertShapeCheck(HloInstruction* dim1, HloInstruction* dim2,
209 bool support_implicit_broadcast);
210
211 // Pass through a dynamic dimension from the input to the output with the
212 // same value and index in the shape. This is a helper function to handle
213 // trivial instructions like elementwise operations.
214 Status PassThroughDynamicDimension(HloInstruction*);
215
216 // The dynamic parameter bindings of this computation.
217 const DynamicParameterBinding& param_bindings_;
218
219 // A pointer to DynamicDimensionInference, used to update the dynamic mapping.
220 DynamicDimensionInference* parent_;
221
222 // A handler for custom calls.
223 DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler_;
224
225 // Indicates what to do at places where shape check is needed.
226 DynamicDimensionInference::ShapeCheckMode shape_check_mode_;
227
228 // Value which has to be `true` for the shapes to match.
229 HloInstruction* shape_assertion_ = nullptr;
230 };
231
Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) {
233 return ForEachOperandDynamicDimension(
234 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
235 int64_t operand_index, HloInstruction* dynamic_size) {
236 return UnimplementedStrCat(
237 "Asked to propagate a dynamic dimension from hlo ", operand->name(),
238 "@", index.ToString(), "@", dimension, " to hlo ", hlo->ToString(),
239 ", which is not implemented.");
240 });
241 }
242
Status DynamicDimensionInferenceVisitor::HandleGetTupleElement(
    HloInstruction* hlo) {
245 return ForEachOperandDynamicDimension(
246 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
247 int64_t operand_index, HloInstruction* dynamic_size) {
248 if (hlo->tuple_index() == index[0]) {
249 ShapeIndex new_index(ShapeIndexView(index).subspan(1));
250 parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size);
251 }
252 return OkStatus();
253 });
254 }
255
Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) {
257 return ForEachOperandDynamicDimension(
258 hlo, [&](HloInstruction*, ShapeIndex index, int64_t dimension,
259 int64_t operand_index, HloInstruction* dynamic_size) {
260 index.push_front(operand_index);
261 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
262 return OkStatus();
263 });
264 }
265
Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
267 return ForEachOperandDynamicDimension(
268 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
269 int64_t operand_index, HloInstruction* dynamic_size) {
270 int64_t broadcast_dim = hlo->dimensions(dimension);
271 parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size);
272 return OkStatus();
273 });
274 }
275
Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
277 if (hlo->custom_call_target() == "PadToStatic") {
278 for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
279 if (hlo->operand(0)->shape().is_dynamic_dimension(i)) {
280 HloInstruction* dynamic_size =
281 hlo->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
282 ShapeUtil::MakeScalarShape(S32), hlo, i + 1));
        // PadToStatic converts a dynamic dimension to a static dimension. It
        // then returns the padded data output and the dynamic sizes of the
        // input dimensions.
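        // For example, an f32[<=10] input yields a (f32[10], s32[]) output
        // tuple, where tuple element i + 1 carries the runtime size of input
        // dimension i (hence the get-tuple-element at index i + 1 above).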
286 ShapeIndex data_output = {0};
287 parent_->SetDynamicSize(hlo, data_output, i, dynamic_size);
288 }
289 }
290 return OkStatus();
291 }
292 if (custom_call_handler_) {
293 return custom_call_handler_(hlo, parent_);
294 }
295
296 if (hlo->custom_call_target() == "DynamicConvolutionForward") {
    // If the input feature dimension is dynamic and the kernel feature
    // dimension is static, we can infer that the input feature dimension is
    // actually static.
    // E.g.,:
    //   lhs = [B, X, Y, ?]
    //   rhs = [X, Y, I, O]
    //   dim_labels = b01f_01io
    // We can infer that the dynamic feature dimension in lhs has the static
    // size I.
304 const ConvolutionDimensionNumbers& dnums =
305 hlo->convolution_dimension_numbers();
306 HloInstruction* input_feature = parent_->GetDynamicSize(
307 hlo->mutable_operand(0), {}, dnums.input_feature_dimension());
308 HloInstruction* kernel_feature = parent_->GetDynamicSize(
309 hlo->mutable_operand(1), {}, dnums.kernel_input_feature_dimension());
310
311 if (input_feature != nullptr && kernel_feature == nullptr) {
312 if (hlo->mutable_operand(0)->shape().dimensions(
313 dnums.input_feature_dimension()) ==
314 hlo->mutable_operand(1)->shape().dimensions(
315 dnums.kernel_input_feature_dimension()))
316 parent_->SetDynamicSize(hlo->mutable_operand(0), {},
317 dnums.input_feature_dimension(), nullptr);
318 }
319 }
320 return ForEachOperandDynamicDimension(
321 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
322 int64_t operand_index, HloInstruction* dynamic_size) {
323 // Resize custom call should propagate dynamic batch (0) and channel (3)
324 // dimensions.
325 if (hlo->custom_call_target() == "SliceToDynamic" ||
326 hlo->custom_call_target() == "Sharding" ||
327 (absl::StartsWith(hlo->custom_call_target(), "Resize") &&
328 (dimension == 0 || dimension == 3))) {
329 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
330 return OkStatus();
331 }
332 if (hlo->custom_call_target() == "DynamicReduceWindowSamePadding") {
333 if (hlo->operand_count() > 2) {
334 return Unimplemented(
335 "DynamicReduceWindowSamePadding doesn't support variadic "
336 "reduce window %s",
337 hlo->ToString());
338 }
339 return HandleDynamicWindowSamePadding(hlo, dynamic_size,
340 operand_index, dimension);
341 }
342
343 if (hlo->custom_call_target() == "DynamicSelectAndScatterSamePadding") {
344 if (operand_index == 1) {
345 // Operand 0 (input) determines dynamic output size. We ignore the
346 // dynamic size in the operand 1 (output gradient).
347 return OkStatus();
348 }
349 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
350 return OkStatus();
351 }
352
353 if (hlo->custom_call_target() == "DynamicConvolutionInputGrad") {
354 return HandleDynamicConvolutionInputGrad(hlo, operand_index,
355 dimension);
356 }
357
358 if (hlo->custom_call_target() == "DynamicConvolutionKernelGrad") {
359 return HandleDynamicConvolutionKernelGrad(hlo, operand_index,
360 dimension);
361 }
362
363 if (hlo->custom_call_target() == "DynamicConvolutionForward") {
364 return HandleDynamicConvolutionForward(hlo, operand_index, dimension,
365 dynamic_size);
366 }
367 return Unimplemented(
368 "CustomCall \"%s\" is not supported to have a dynamic dimension",
369 hlo->custom_call_target());
370 });
371 }
372
Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) {
374 return ForEachOperandDynamicDimension(
375 hlo,
376 [&](HloInstruction* operand, ShapeIndex index, int64_t dynamic_dimension,
377 int64_t operand_index, HloInstruction* dynamic_size) {
378 HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
379 if (sort->values_count() == 0) {
380 parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size);
381 } else {
382 parent_->SetDynamicSize(hlo, {operand_index}, dynamic_dimension,
383 dynamic_size);
384 }
385
386 return OkStatus();
387 });
388 }
389
Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
391 return ForEachOperandDynamicDimension(
392 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
393 int64_t operand_index, HloInstruction* dynamic_size) {
394 if (operand_index != 0) {
395 return Unimplemented(
396 "Dynamic dimension on padding value is not supported");
397 }
398 const PaddingConfig_PaddingConfigDimension& padding_config =
399 hlo->padding_config().dimensions(dimension);
400
401 HloInstruction* dynamic_size_adjusted = dynamic_size;
402 if (padding_config.interior_padding() != 0) {
403 // Adjust for interior padding :
404 // Size' = max((Size - 1), 0) * interior_padding + Size
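          // E.g., Size = 4 with interior_padding = 2 gives
          // max(3, 0) * 2 + 4 = 10 (three 2-element gaps are inserted).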
405 HloInstruction* one =
406 hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
407 LiteralUtil::CreateR0<int32_t>(1)));
408 HloInstruction* zero =
409 hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
410 LiteralUtil::CreateR0<int32_t>(0)));
411 HloInstruction* interior_padding = hlo->parent()->AddInstruction(
412 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
413 padding_config.interior_padding())));
414 dynamic_size_adjusted =
415 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
416 dynamic_size_adjusted->shape(), HloOpcode::kSubtract,
417 dynamic_size_adjusted, one));
418 dynamic_size_adjusted =
419 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
420 dynamic_size_adjusted->shape(), HloOpcode::kMaximum,
421 dynamic_size_adjusted, zero));
422 dynamic_size_adjusted =
423 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
424 dynamic_size_adjusted->shape(), HloOpcode::kMultiply,
425 dynamic_size_adjusted, interior_padding));
426 dynamic_size_adjusted =
427 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
428 dynamic_size_adjusted->shape(), HloOpcode::kAdd,
429 dynamic_size_adjusted, dynamic_size));
430 }
431 HloInstruction* adjustment = hlo->parent()->AddInstruction(
432 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
433 padding_config.edge_padding_low() +
434 padding_config.edge_padding_high())));
435 dynamic_size_adjusted =
436 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
437 dynamic_size_adjusted->shape(), HloOpcode::kAdd,
438 dynamic_size_adjusted, adjustment));
439 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted);
440 return OkStatus();
441 });
442 }
443
Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
445 return ForEachOperandDynamicDimension(
446 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
447 int64_t operand_index, HloInstruction* dynamic_size) {
448 auto* reduce = Cast<HloReduceInstruction>(hlo);
449 int64_t operand_count = reduce->operand_count();
450 CHECK_EQ(operand_count % 2, 0);
451 if (operand_index >= reduce->input_count()) {
          // Init values don't have dynamic sizes.
453 return OkStatus();
454 }
455 if ((absl::c_count(reduce->dimensions(), dimension) != 0)) {
456 // Dimension is to be reduced, stop tracing.
457 return OkStatus();
458 }
459
460 // Find out the new dynamic dimension after reduce.
461 int64_t dimensions_not_reduced_count = 0;
462 for (int i = 0; i < operand->shape().rank(); ++i) {
463 if (dimension == i) {
464 // The dimensions of all data operands of a variadic reduce have
465 // to be the same. This means that if one operand of variadic
466 // reduce has a dynamic dimension, we set all outputs to use the
467 // same dynamic size in corresponding dimensions.
468 ShapeUtil::ForEachSubshape(
469 reduce->shape(),
470 [&](const Shape& subshape, ShapeIndex reduce_result_index) {
471 if (!ShapeUtil::IsLeafIndex(reduce->shape(),
472 reduce_result_index)) {
473 return;
474 }
475 parent_->SetDynamicSize(reduce, reduce_result_index,
476 dimensions_not_reduced_count,
477 dynamic_size);
478 });
479
480 return OkStatus();
481 }
482 if (absl::c_count(reduce->dimensions(), i) == 0) {
483 dimensions_not_reduced_count++;
484 }
485 }
486
487 return OkStatus();
488 });
489 }
490
Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
492 return ForEachOperandDynamicDimension(hlo, [&](HloInstruction* operand,
493 ShapeIndex operand_shape_index,
494 int64_t operand_dimension,
495 int64_t operand_index,
496 HloInstruction* dynamic_size) {
497 // There are three types of dimensions in a dot:
498 // A. batch dims
499 // B. contracting dims
500 // C. non-batch non-contracting dims.
    // The output dimensions of a dot have three parts with the following
    // order:
503 // [(type A), (lhs type C), (rhs type C)]
504 //
505 // Note that both lhs and rhs have the same dimension sizes for batch,
506 // but the dimension index could be different.
507 //
508 // Given one dynamic input dimension, either lhs or rhs, we use a
509 // mapping to find the corresponding output dimension.
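    // For example, with lhs = [b, m, k] and rhs = [b, k, n], batch dimension 0
    // on both sides, and the k dimensions contracted, the output is [b, m, n]:
    // the batch dim, then the lhs non-contracting dim m, then the rhs
    // non-contracting dim n.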
510 HloInstruction* dot = hlo;
511 const DotDimensionNumbers& dimension_numbers = dot->dot_dimension_numbers();
512 // A map from the operand dimensions to result dimension.
513 absl::flat_hash_map<int64_t, int64_t> result_dim_mapping;
514 int64_t current_result_dims = 0;
515
516 bool lhs = operand_index == 0;
517
    // The first loop keeps track of the batch dimensions. The lhs and rhs
    // could have different batch dimension numbers.
520 if (lhs) {
521 for (int64_t i : dimension_numbers.lhs_batch_dimensions()) {
522 result_dim_mapping[i] = current_result_dims++;
523 }
524 } else {
525 for (int64_t i : dimension_numbers.rhs_batch_dimensions()) {
526 result_dim_mapping[i] = current_result_dims++;
527 }
528 }
529
530 // Handle dimensions in the lhs.
531 for (int64_t i = 0; i < dot->operand(0)->shape().rank(); i++) {
532 // Look for non-contracting and non-batching dimension.
533 if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
534 i)) {
535 continue;
536 }
537 if (absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
538 continue;
539 }
540 if (lhs) {
541 result_dim_mapping[i] = current_result_dims;
542 }
543 current_result_dims++;
544 }
545
546 // Handle dimensions in the rhs.
547 for (int64_t i = 0; i < dot->operand(1)->shape().rank(); i++) {
548 // Look for non-contracting and non-batching dimension.
549 if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
550 i)) {
551 continue;
552 }
553 if (absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
554 continue;
555 }
556 if (!lhs) {
557 result_dim_mapping[i] = current_result_dims;
558 }
559 current_result_dims++;
560 }
561
562 // Check if the operand dim is in the result shape. If so, add another
563 // work item to trace that dimension.
564 auto iter = result_dim_mapping.find(operand_dimension);
565 if (iter != result_dim_mapping.end()) {
566 parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
567 }
568
569 return OkStatus();
570 });
571 }
572
Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
574 return ForEachOperandDynamicDimension(
575 hlo,
576 [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
577 int64_t operand_index, HloInstruction* dynamic_size) -> Status {
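        // Find the output dimension that reads from the dynamic input
        // dimension: output dim i takes its data from input dim
        // hlo->dimensions(i). E.g., for permutation {2, 0, 1}, input dim 0
        // becomes output dim 1.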
578 int64_t permuted_dim = -1;
579 for (int64_t i = 0; i < hlo->dimensions().size(); ++i) {
580 if (hlo->dimensions()[i] == dimension) {
581 TF_RET_CHECK(permuted_dim == -1);
582 permuted_dim = i;
583 }
584 }
585 parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size);
586 return OkStatus();
587 });
588 }
589
Status DynamicDimensionInferenceVisitor::HandleConvolution(
    HloInstruction* hlo) {
592 return ForEachOperandDynamicDimension(
593 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
594 int64_t operand_index, HloInstruction* dynamic_size) {
595 HloInstruction* conv = hlo;
596 const ConvolutionDimensionNumbers& dimension_numbers =
597 conv->convolution_dimension_numbers();
598 if (operand_index == 0) {
599 if (dimension == dimension_numbers.input_batch_dimension()) {
600 parent_->SetDynamicSize(conv, {},
601 dimension_numbers.output_batch_dimension(),
602 dynamic_size);
603 return OkStatus();
604 }
605
606 if (dimension == dimension_numbers.input_feature_dimension()) {
607 return OkStatus();
608 }
609 } else {
610 if (dimension == dimension_numbers.kernel_input_feature_dimension()) {
611 return OkStatus();
612 }
613 }
614
615 return Unimplemented("Dynamic Spatial Convolution is not supported: %s",
616 conv->ToString());
617 });
618 }
619
Status DynamicDimensionInferenceVisitor::HandleConcatenate(
    HloInstruction* hlo) {
622 // First handle concatenate dimensions. We do this by iterating through all
623 // operands while tracking both dynamic and static dimensions.
624
  // static_size is used to keep track of the concatenated size of the static
  // dimensions.
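  // E.g., concatenating f32[<=4], f32[3], f32[<=5] along dimension 0 yields a
  // dynamic output size of 3 + d0 + d2, where d0 and d2 are the runtime sizes
  // of the two dynamic operands.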
627 int64_t static_size = 0;
628 std::vector<HloInstruction*> dynamic_concat_dims;
629 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
630 HloInstruction* dynamic_size = parent_->GetDynamicSize(
631 hlo->mutable_operand(i), {}, hlo->concatenate_dimension());
632 if (dynamic_size == nullptr) {
633 // This is a static dimension.
634 static_size +=
635 hlo->operand(i)->shape().dimensions(hlo->concatenate_dimension());
636 } else {
637 dynamic_concat_dims.push_back(dynamic_size);
638 }
639 }
640 // If concat dimension is dynamic, calculate its size by summing up static
641 // dims and dynamic dims together.
642 if (!dynamic_concat_dims.empty()) {
643 HloInstruction* dim_size_total =
644 hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
645 LiteralUtil::CreateR0<int32_t>(static_size)));
646 for (HloInstruction* dynamic_dim : dynamic_concat_dims) {
647 dim_size_total = hlo->parent()->AddInstruction(
648 HloInstruction::CreateBinary(dim_size_total->shape(), HloOpcode::kAdd,
649 dim_size_total, dynamic_dim));
650 }
651 parent_->SetDynamicSize(hlo, {}, hlo->concatenate_dimension(),
652 dim_size_total);
653 }
654
655 // Simply pass through non-concat dynamic dimensions.
656 return ForEachOperandDynamicDimension(
657 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
658 int64_t operand_index, HloInstruction* dynamic_size) {
659 int64_t concatenate_dimension = hlo->concatenate_dimension();
660 if (concatenate_dimension == dimension) {
661 return OkStatus();
662 }
663 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
664 return OkStatus();
665 });
666 }
667
Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
    HloInstruction* gds) {
670 // Dynamic dimension doesn't propagate through GetDimensionSize:
671 //
672 // Input: F32[x, y, z]
673 // |
674 // GetDimensionSize(1): S32[]
675 //
676 // The returned value is a scalar, which doesn't have any dynamic dimension in
677 // the shape (although the value contains the real size of the dynamic
678 // dimension of the input).
679 int64_t dim = gds->dimension();
680 HloInstruction* operand = gds->mutable_operand(0);
681 HloInstruction* dynamic_size = parent_->GetDynamicSize(operand, {}, dim);
682 HloComputation* computation = gds->parent();
683 if (dynamic_size != nullptr) {
684 TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(dynamic_size));
685 // The dependency between an instruction and its dynamic dimensions is not
    // modeled in the IR. As gds is being replaced by dynamic_size, also tell
687 // dynamic dimension inference that the instruction is being replaced.
688 parent_->ReplaceAllDynamicDimensionUsesWith(gds, dynamic_size);
689 } else {
690 TF_RET_CHECK(dim < gds->operand(0)->shape().rank());
691 int32_t size = gds->operand(0)->shape().dimensions(dim);
692 HloInstruction* new_instr = computation->AddInstruction(
693 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(size)));
694 TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(new_instr));
695 parent_->ReplaceAllDynamicDimensionUsesWith(gds, new_instr);
696 }
697 return OkStatus();
698 }
699
Status DynamicDimensionInferenceVisitor::HandleSetDimensionSize(
    HloInstruction* hlo) {
702 bool dimension_is_static = false;
703 const HloInstruction* size = hlo->operand(1);
704 if (size->opcode() == HloOpcode::kConstant) {
    // Check if we are setting a dimension size to its static size. If so,
    // remove the dynamic dimension.
707 //
708 // size = s32[] constant(5)
709 // s32[2, 5] = set-dimension-size(s32[2,<=5]{1,0} %param, s32[] %size),
710 // dimensions={1}
711 // The result shape has no dynamic dimension.
712 TF_RET_CHECK(size->shape().rank() == 0);
713 if (size->literal().Get<int32_t>({}) ==
714 hlo->shape().dimensions(hlo->dimension())) {
715 dimension_is_static = true;
716 }
717 }
718
719 if (!dimension_is_static) {
720 // Propagate dynamic dimension indicated by this set dimension size
721 // instruction.
722 parent_->SetDynamicSize(hlo, {}, hlo->dimension(), hlo->mutable_operand(1));
723 }
724
  // Also propagate dynamic dimensions already set by the operands.
726 TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
727 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
728 int64_t operand_index, HloInstruction* dynamic_size) {
729 if (dimension != hlo->dimension()) {
730 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
731 }
732 return OkStatus();
733 }));
734
735 return OkStatus();
736 }
737
Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionForward(
    HloInstruction* hlo, int64_t operand_index, int64_t dimension,
    HloInstruction* dynamic_size) {
741 TF_RET_CHECK(operand_index == 0);
742 const ConvolutionDimensionNumbers& dimension_numbers =
743 hlo->convolution_dimension_numbers();
744
745 if (dimension == dimension_numbers.input_batch_dimension()) {
746 // Batch dimension is propagated without any changes.
747 parent_->SetDynamicSize(hlo, {}, dimension_numbers.output_batch_dimension(),
748 dynamic_size);
749 return OkStatus();
750 }
751
752 for (int64_t spatial_dim_index = 0;
753 spatial_dim_index < dimension_numbers.input_spatial_dimensions_size();
754 ++spatial_dim_index) {
755 int64_t input_spatial_dim =
756 dimension_numbers.input_spatial_dimensions(spatial_dim_index);
757 int64_t output_spatial_dim =
758 dimension_numbers.output_spatial_dimensions(spatial_dim_index);
759 if (dimension == input_spatial_dim) {
760 // This is a dynamic spatial dimension. Calculate the output size.
761 WindowDimension window_dim = hlo->window().dimensions(spatial_dim_index);
762 DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
763 dynamic_size, window_dim.size(), window_dim.window_dilation(),
764 window_dim.stride(), hlo->padding_type());
765 TF_RET_CHECK(window_dim.base_dilation() == 1);
766 parent_->SetDynamicSize(hlo, {}, output_spatial_dim,
767 dynamic_window_dims.output_size);
768 return OkStatus();
769 }
770 }
771 // Input Feature dim disappears after convolution.
772 return OkStatus();
773 }
774
Status DynamicDimensionInferenceVisitor::HandleDynamicWindowSamePadding(
    HloInstruction* hlo, HloInstruction* dynamic_size, int64_t operand_index,
    int64_t dimension) {
778 const Window& window = hlo->window();
779 const WindowDimension& window_dim = window.dimensions(dimension);
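  // For a trivial window (size 1, stride 1, no padding or dilation) the output
  // size equals the input size, so the dynamic size is forwarded unchanged;
  // otherwise compute the windowed output size under SAME padding.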
780 if (!window_util::IsTrivialWindowDimension(window_dim)) {
781 DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
782 dynamic_size, window_dim.size(), window_dim.window_dilation(),
783 window_dim.stride(), PaddingType::PADDING_SAME);
784 parent_->SetDynamicSize(hlo, {}, dimension,
785 dynamic_window_dims.output_size);
786 return OkStatus();
787 }
788
789 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
790
791 return OkStatus();
792 }
793
Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionInputGrad(
    HloInstruction* hlo, int64_t operand_index, int64_t dimension) {
  // The output size of a convolution input grad is the corresponding input
  // size.
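  // Operand 0 holds the original input's dimension sizes as a rank-1 s32
  // vector (e.g. [N, H, W, C] for a 4-D input); the slice + reshape below
  // extract the entry for `dimension` as a scalar dynamic size.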
797 HloInstruction* input_sizes = hlo->mutable_operand(0);
798 HloComputation* comp = hlo->parent();
799 TF_RET_CHECK(input_sizes->shape().rank() == 1) << hlo->ToString();
800 TF_RET_CHECK(input_sizes->shape().element_type() == S32) << hlo->ToString();
801 TF_RET_CHECK(input_sizes->shape().dimensions(0) ==
802 hlo->shape().dimensions_size())
803 << hlo->ToString();
804 // Slice to get corresponding input size.
805 HloInstruction* slice = comp->AddInstruction(
806 HloInstruction::CreateSlice(ShapeUtil::MakeShape(S32, {1}), input_sizes,
807 {dimension}, {dimension + 1}, {1}));
808 HloInstruction* reshape = comp->AddInstruction(
809 HloInstruction::CreateReshape(ShapeUtil::MakeScalarShape(S32), slice));
810 parent_->SetDynamicSize(hlo, {}, dimension, reshape);
811 return OkStatus();
812 }
813
Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionKernelGrad(
    HloInstruction* hlo, int64_t operand_index, int64_t dimension) {
816 // Dynamic convolution kernel grad produces static shape outputs.
817 return OkStatus();
818 }
819
Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension(
    HloInstruction* hlo) {
822 return ForEachOperandDynamicDimension(
823 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
824 int64_t operand_index, HloInstruction* dynamic_size) {
825 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
826 return OkStatus();
827 });
828 }
829
Status DynamicDimensionInferenceVisitor::HandleDomain(HloInstruction* hlo) {
831 return PassThroughDynamicDimension(hlo);
832 }
833
Status DynamicDimensionInferenceVisitor::HandleElementwiseUnary(
    HloInstruction* hlo) {
836 return PassThroughDynamicDimension(hlo);
837 }
838
Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {
840 return PassThroughDynamicDimension(hlo);
841 }
842
Status DynamicDimensionInferenceVisitor::HandleElementwiseNary(
    HloInstruction* hlo) {
845 HloComputation* comp = hlo->parent();
846 return ForEachOperandDynamicDimension(
847 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
848 int64_t operand_index, HloInstruction* dynamic_size) {
849 HloInstruction* existing_size =
850 parent_->GetDynamicSize(hlo, index, dimension);
851 if (existing_size == nullptr || existing_size == dynamic_size) {
852 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
853 } else {
854 TF_RETURN_IF_ERROR(
855 InsertShapeCheck(existing_size, dynamic_size,
856 /*support_implicit_broadcast=*/true));
857
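          // The two sizes disagree; under implicit broadcast semantics one of
          // them must be 1 at runtime. If either side needs broadcasting, the
          // output dynamic size is the larger of the two (e.g. sizes 1 and 5
          // give 5); otherwise the shape check above guarantees they are
          // equal, so the smaller (== larger) size is used.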
858 auto one = comp->AddInstruction(
859 HloInstruction::CreateConstant(LiteralUtil::One(S32)));
860
861 auto operand_needs_broadcast =
862 comp->AddInstruction(HloInstruction::CreateCompare(
863 ShapeUtil::MakeShape(PRED, {}), dynamic_size, existing_size,
864 ComparisonDirection::kLt));
865 auto is_one = comp->AddInstruction(HloInstruction::CreateCompare(
866 ShapeUtil::MakeShape(PRED, {}), dynamic_size, one,
867 ComparisonDirection::kEq));
868 operand_needs_broadcast =
869 comp->AddInstruction(HloInstruction::CreateBinary(
870 ShapeUtil::MakeShape(PRED, {}), HloOpcode::kAnd, is_one,
871 operand_needs_broadcast));
872
873 auto existing_needs_broadcast =
874 comp->AddInstruction(HloInstruction::CreateCompare(
875 ShapeUtil::MakeShape(PRED, {}), existing_size, dynamic_size,
876 ComparisonDirection::kLt));
877 is_one = comp->AddInstruction(HloInstruction::CreateCompare(
878 ShapeUtil::MakeShape(PRED, {}), existing_size, one,
879 ComparisonDirection::kEq));
880 existing_needs_broadcast =
881 comp->AddInstruction(HloInstruction::CreateBinary(
882 ShapeUtil::MakeShape(PRED, {}), HloOpcode::kAnd, is_one,
883 existing_needs_broadcast));
884
885 auto needs_broadcast =
886 comp->AddInstruction(HloInstruction::CreateBinary(
887 ShapeUtil::MakeShape(PRED, {}), HloOpcode::kOr,
888 operand_needs_broadcast, existing_needs_broadcast));
889 auto max_size = comp->AddInstruction(HloInstruction::CreateBinary(
890 ShapeUtil::MakeScalarShape(S32), HloOpcode::kMaximum,
891 dynamic_size, existing_size));
892 auto min_size = comp->AddInstruction(HloInstruction::CreateBinary(
893 ShapeUtil::MakeScalarShape(S32), HloOpcode::kMinimum,
894 dynamic_size, existing_size));
895 auto select_size = comp->AddInstruction(HloInstruction::CreateTernary(
896 ShapeUtil::MakeScalarShape(S32), HloOpcode::kSelect,
897 needs_broadcast, max_size, min_size));
898 parent_->SetDynamicSize(hlo, index, dimension, select_size);
899 }
900 return OkStatus();
901 });
902 }
903
Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
    HloInstruction* hlo) {
906 return HandleElementwiseNary(hlo);
907 }
908
Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {
910 return PassThroughDynamicDimension(hlo);
911 }
912
Status DynamicDimensionInferenceVisitor::HandleDynamicReshape(
    HloInstruction* hlo) {
915 HloDynamicReshapeInstruction* dynamic_reshape =
916 Cast<HloDynamicReshapeInstruction>(hlo);
917 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
918 if (hlo->shape().is_dynamic_dimension(i)) {
919 parent_->SetDynamicSize(hlo, {}, i, dynamic_reshape->dim_sizes(i));
920 }
921 }
922 return OkStatus();
923 }
924
Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
926 // First scan to see if we need to decompose the dynamic reshape into a
927 // flatten-unflatten pair. If so, find the dynamic dimension using
928 // hlo->inferred_dimension() and calculate the dynamic size for that
929 // dimension.
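  // E.g., reshaping [<=4, 4] -> [2, 8] maps two input dimensions onto two
  // output dimensions at once, so there is no per-dimension correspondence;
  // the dynamic size of the inferred output dimension is instead computed
  // from the total (dynamic) element count below.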
930 bool need_flatten_unflatten = false;
931 // For a reshape we need the inferred_dimension to be present to disambiguate
932 // dynamic dimensions of hlo. HloOpcode::kDynamicReshape on the other hand
933 // allows more precise specification of dynamic dimensions of hlo's shape.
934 if (hlo->inferred_dimension() != -1) {
935 TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
936 hlo,
937 [&](HloInstruction* operand, ShapeIndex index,
938 int64_t input_dynamic_dimension, int64_t operand_index,
939 HloInstruction* operand_dynamic_size) -> Status {
940 auto common_factors = CommonFactors(operand->shape().dimensions(),
941 hlo->shape().dimensions());
942 int64_t input_dim_start = -1;
943 int64_t input_dim_end = -1;
944 int64_t output_dim_start = -1;
945 int64_t output_dim_end = -1;
946 // Find common_factors that the input belongs to.
947 for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
948 auto start = common_factors[i];
949 auto end = common_factors[i + 1];
950 if (input_dynamic_dimension >= start.first &&
951 input_dynamic_dimension < end.first) {
952 // Found the common_factor group that the input_dim belongs to.
953 input_dim_start = start.first;
954 input_dim_end = end.first;
955 output_dim_start = start.second;
956 output_dim_end = end.second;
957 }
958 }
959 if ((input_dim_end - input_dim_start) > 1 &&
960 (output_dim_end - output_dim_start) > 1) {
961 need_flatten_unflatten = true;
962 }
963 return OkStatus();
964 }));
965 if (need_flatten_unflatten) {
966 HloInstruction* operand = hlo->mutable_operand(0);
967 HloComputation* comp = hlo->parent();
968 HloInstruction* dynamic_size = comp->AddInstruction(
969 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
970 int64_t static_size = 1;
971 for (int64_t i = 0; i < operand->shape().rank(); i++) {
972 HloInstruction* dynamic_dim_size =
973 parent_->GetDynamicSize(operand, {}, i);
974 if (dynamic_dim_size == nullptr) {
975 static_size *= operand->shape().dimensions(i);
976 } else {
977 dynamic_size = comp->AddInstruction(HloInstruction::CreateBinary(
978 dynamic_size->shape(), HloOpcode::kMultiply, dynamic_size,
979 dynamic_dim_size));
980 }
981 }
982 HloInstruction* static_size_hlo =
983 comp->AddInstruction(HloInstruction::CreateConstant(
984 LiteralUtil::CreateR0<int32_t>(static_size)));
985 // Total dynamic shape size.
986 dynamic_size = comp->AddInstruction(HloInstruction::CreateBinary(
987 dynamic_size->shape(), HloOpcode::kMultiply, dynamic_size,
988 static_size_hlo));
989
990 int64_t size_without_inferred_dim =
991 ShapeUtil::ElementsIn(hlo->shape()) /
992 hlo->shape().dimensions(hlo->inferred_dimension());
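      // E.g., for an output shape [2, <=3, 2] with inferred dimension 1,
      // size_without_inferred_dim = 12 / 3 = 4; if the operand holds 8
      // elements at runtime, the inferred dynamic size is 8 / 4 = 2.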
993 HloInstruction* size_without_inferred_dim_hlo =
994 comp->AddInstruction(HloInstruction::CreateConstant(
995 LiteralUtil::CreateR0<int32_t>(size_without_inferred_dim)));
996 dynamic_size = comp->AddInstruction(HloInstruction::CreateBinary(
997 dynamic_size->shape(), HloOpcode::kDivide, dynamic_size,
998 size_without_inferred_dim_hlo));
999 parent_->SetDynamicSize(hlo, {}, hlo->inferred_dimension(), dynamic_size);
1000 VLOG(3)
1001 << "Need to decopose a dynamic reshape to flatten-unflatten pair. "
1002 << comp->parent()->ToString();
1003 return OkStatus();
1004 }
1005 }
1006
1007 return ForEachOperandDynamicDimension(
1008 hlo,
1009 [&](HloInstruction* operand, ShapeIndex index,
1010 int64_t input_dynamic_dimension, int64_t operand_index,
1011 HloInstruction* operand_dynamic_size) -> Status {
1012 HloInstruction* reshape = hlo;
1013 if (reshape->shape().rank() == 0) {
1014 VLOG(0) << "Reshaping a dynamic dimension into a scalar, which has "
1015 "undefined behavior when input size is 0. The offending "
1016 "instruction is: "
1017 << reshape->ToString();
1018 return OkStatus();
1019 }
1020 auto common_factors = CommonFactors(operand->shape().dimensions(),
1021 reshape->shape().dimensions());
1022 int64_t input_dim_start = -1;
1023 int64_t input_dim_end = -1;
1024 int64_t output_dim_start = -1;
1025 int64_t output_dim_end = -1;
1026 // Find common_factors that the input belongs to.
1027 for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
1028 auto start = common_factors[i];
1029 auto end = common_factors[i + 1];
1030 if (input_dynamic_dimension >= start.first &&
1031 input_dynamic_dimension < end.first) {
1032 // Found the common_factor group that the input_dim belongs to.
1033 input_dim_start = start.first;
1034 input_dim_end = end.first;
1035 output_dim_start = start.second;
1036 output_dim_end = end.second;
1037 }
1038 }
1039
1040 VLOG(2) << "Input dim start: " << input_dim_start
1041 << " Input dim end: " << input_dim_end
1042 << " output dim start: " << output_dim_start
1043 << " output dim end: " << output_dim_end;
1044
1045 if ((input_dim_end - input_dim_start) > 1 &&
1046 (output_dim_end - output_dim_start) > 1) {
1047 return InternalError(
1048 "Should be handled by decomposing reshape into "
1049 "flatten-unflatten pair. %s",
1050 hlo->ToString());
1051 }
1052
1053 for (auto common_factor : common_factors) {
1054 // Expand common factor to include degenerated output dimensions.
1055 if (common_factor.first == input_dim_start) {
1056 output_dim_start = std::min(output_dim_start, common_factor.second);
1057 }
1058 if (common_factor.first == input_dim_end) {
1059 output_dim_end = std::max(output_dim_end, common_factor.second);
1060 }
1061 }
1062
1063 int64_t output_dynamic_dimension = -1;
1064
1065 if (operand->shape().dimensions(input_dynamic_dimension) == 1) {
1066 // If dynamic dimension is 1, it can only be most-major or
1067 // most-minor.
1068 if (input_dynamic_dimension == 0) {
1069 output_dynamic_dimension = 0;
1070 } else if (input_dynamic_dimension == operand->shape().rank() - 1) {
1071 output_dynamic_dimension = reshape->shape().rank() - 1;
1072 }
1073
1074 if (output_dynamic_dimension == -1) {
1075 return Unimplemented(
1076 "Dynamic degenerated dimension that's not most-minor nor "
1077 "most-major is not supported %s",
1078 reshape->ToString());
1079 }
1080 }
1081
1082 if (output_dynamic_dimension == -1 &&
1083 output_dim_end - output_dim_start == 1) {
1084 // Only one possible output dimension.
1085 output_dynamic_dimension = output_dim_start;
1086 }
1087
1088 if (output_dynamic_dimension == -1 &&
1089 output_dim_end - output_dim_start > 1) {
          // One input dimension is split into multiple output dimensions.
          // The output dimensions are decomposed from the input's most-major
          // dimension.
1092 // In this case, we don't know which one is dynamic, e.g., when we
1093 // have:
1094 //
1095 // [<=a/c, c, b]
1096 // | Reshape
1097 // [<=a, b] // a is dynamic, has to be multiple of c.
1098 // | Reshape
1099 // [1, 1, ... , a/c, c, b]
1100 //
1101 // Any dimension from the first '1' to 'a/c' can be dynamic.
1102 //
          // We use the following logic to disambiguate:
1104 // 1. If the user sets "inferred_dimension", then use that as
1105 // dynamic dimension.
          // 2. If exactly one dimension in the reshape's output is dynamic,
          //    use that as the dynamic dimension.
1108 // E.g.:
1109 // [<=4]
1110 // |
1111 // reshape
1112 // |
1113 // [1, <=2, 2]
1114 // We use second dim as dynamic dimension.
1115 //
          // 3. If neither of the above can disambiguate, e.g.,:
1117 //
1118 // [<=1]
1119 // |
1120 // reshape
1121 // |
1122 // [1, 1, 1]
1123 //
1124 // We bail out and return an error.
1125 // TODO(yunxing): Further simplify this, remove 1. and fully rely
1126 // on 2.
1127 output_dynamic_dimension = reshape->inferred_dimension();
1128 if (output_dynamic_dimension == -1) {
            // Try to find the dynamic dimension from the result shape.
1130 for (int64_t i = output_dim_start; i < output_dim_end; ++i) {
1131 if (reshape->shape().is_dynamic_dimension(i)) {
1132 output_dynamic_dimension = i;
1133 }
1134 }
1135 }
1136
1137 if (output_dynamic_dimension == -1) {
1138 std::vector<int64_t> output_non_degenerated;
1139 for (int64_t i = output_dim_start; i < output_dim_end; ++i) {
1140 if (reshape->shape().dimensions(i) != 1) {
1141 output_non_degenerated.push_back(i);
1142 }
1143 }
1144 if (output_non_degenerated.size() == 1) {
1145 output_dynamic_dimension = output_non_degenerated[0];
1146 }
1147 }
1148
1149 if (output_dynamic_dimension == -1) {
1150 return InvalidArgument(
1151 "Reshape's input dynamic dimension is decomposed into "
1152 "multiple output dynamic dimensions, but the constraint is "
1153 "ambiguous and XLA can't infer the output dimension %s. ",
1154 hlo->ToString());
1155 }
1156 }
1157
1158 CHECK_NE(output_dynamic_dimension, -1);
1159 const int64_t input_dim_size =
1160 operand->shape().dimensions(input_dynamic_dimension);
1161 const int64_t output_dim_size =
1162 reshape->shape().dimensions(output_dynamic_dimension);
1163 VLOG(2) << "input_dim_size: " << input_dim_size
1164 << " output_dim_size: " << output_dim_size;
1165
1166 if (input_dim_size == output_dim_size) {
1167 // Simply forward dynamic dimension.
1168 parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1169 operand_dynamic_size);
1170 }
1171
1172 if (input_dim_size > output_dim_size) {
1173 TF_RET_CHECK(input_dim_size % output_dim_size == 0)
1174 << reshape->ToString();
1175 const int64_t divisor = input_dim_size / output_dim_size;
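            // E.g., reshaping [<=6] -> [3, 2] with output dimension 0 dynamic:
            // divisor = 6 / 3 = 2, so a runtime size of 4 becomes 4 / 2 = 2.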
1176 HloInstruction* divisor_hlo =
1177 hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
1178 LiteralUtil::CreateR0<int32_t>(divisor)));
1179
1180 HloInstruction* new_dynamic_size =
1181 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1182 operand_dynamic_size->shape(), HloOpcode::kDivide,
1183 operand_dynamic_size, divisor_hlo));
1184
1185 parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1186 new_dynamic_size);
1187 }
1188
1189 if (input_dim_size < output_dim_size) {
1190 // Input dimension is combined with other input dimensions.
1191 //
1192 // Adjust the output size by the ratio of dynamic_input_dim /
1193 // static_input_dim.
1194 //
            // For example, if we have [<=3, 3] -> [9] and the dynamic size is
            // 2, the new output dynamic size is 9 / 3 * 2 = 6.
1197 //
1198 // If it turns out the second dimension is also dynamic:
1199 // [<=3, <=3] -> [9], and the dynamic size is also 2, the new output
1200 // dynamic size is 6 / 3 * 2 = 4.
1201 //
1202 //
1203 HloInstruction* output_dynamic_size =
1204 parent_->GetDynamicSize(reshape, {}, output_dynamic_dimension);
1205 if (output_dynamic_size == nullptr) {
1206 output_dynamic_size =
1207 hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
1208 LiteralUtil::CreateR0<int32_t>(output_dim_size)));
1209 }
1210 HloInstruction* divisor_hlo = hlo->parent()->AddInstruction(
1211 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
1212 operand->shape().dimensions(input_dynamic_dimension))));
1213
1214 HloInstruction* new_dynamic_size =
1215 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1216 output_dynamic_size->shape(), HloOpcode::kDivide,
1217 output_dynamic_size, divisor_hlo));
1218
1219 new_dynamic_size =
1220 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1221 output_dynamic_size->shape(), HloOpcode::kMultiply,
1222 new_dynamic_size, operand_dynamic_size));
1223 parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1224 new_dynamic_size);
1225 }
1226
1227 return OkStatus();
1228 });
1229 }
1230
Status DynamicDimensionInferenceVisitor::HandleReduceWindow(
    HloInstruction* hlo) {
1233 return ForEachOperandDynamicDimension(
1234 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
1235 int64_t operand_index, HloInstruction* dynamic_size) {
1236 auto* reduce_window = Cast<HloReduceWindowInstruction>(hlo);
1237 const WindowDimension& window_dim =
1238 reduce_window->window().dimensions(dimension);
1239
1240 if (operand_index >= reduce_window->input_count()) {
          // Init values don't have dynamic sizes.
1242 return OkStatus();
1243 }
1244
1245 if (!window_util::IsTrivialWindowDimension(window_dim)) {
1246 DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
1247 dynamic_size, window_dim.size(), window_dim.window_dilation(),
1248 window_dim.stride(), PaddingType::PADDING_VALID);
1249 dynamic_size = dynamic_window_dims.output_size;
1250 }
1251
1252 // The dimensions of all data operands of a variadic reduce window have
1253 // to be the same. This means that if one operand of variadic
1254 // reduce has a dynamic dimension, we set all outputs to use the
1255 // same dynamic size in corresponding dimensions.
1256 ShapeUtil::ForEachSubshape(
1257 reduce_window->shape(),
1258 [&](const Shape& subshape, ShapeIndex reduce_window_result_index) {
1259 if (!ShapeUtil::IsLeafIndex(reduce_window->shape(),
1260 reduce_window_result_index)) {
1261 return;
1262 }
1263 parent_->SetDynamicSize(reduce_window, reduce_window_result_index,
1264 dimension, dynamic_size);
1265 });
1266
1267 return OkStatus();
1268 });
1269 }
1270
Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter(
    HloInstruction* hlo) {
1273 return ForEachOperandDynamicDimension(
1274 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
1275 int64_t operand_index, HloInstruction* dynamic_size) {
1276 if (operand_index == 1) {
1277 // Operand 0 (input) determines dynamic output size. We ignore the
1278 // dynamic size in the operand 1 (output gradient).
1279 return OkStatus();
1280 }
1281 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1282
1283 return OkStatus();
1284 });
1285 }
1286
Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
1288 return ForEachOperandDynamicDimension(
1289 hlo, [&](HloInstruction* operand, ShapeIndex /*index*/, int64_t dimension,
1290 int64_t /*operand_index*/, HloInstruction* dynamic_size) {
1291 if (hlo->slice_starts(dimension) != 0 ||
1292 hlo->slice_strides(dimension) != 1 ||
1293 hlo->slice_limits(dimension) !=
1294 operand->shape().dimensions(dimension)) {
          // Slicing out only part of the dimension eliminates the dynamic
          // dimension.
1296 return OkStatus();
1297 }
1298
1299 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1300
1301 return OkStatus();
1302 });
1303 }
1304
Status DynamicDimensionInferenceVisitor::HandleDynamicSlice(
    HloInstruction* hlo) {
1307 return ForEachOperandDynamicDimension(
1308 hlo, [&](HloInstruction*, ShapeIndex /*index*/, int64_t dimension,
1309 int64_t /*operand_index*/, HloInstruction* dynamic_size) {
1310 if (hlo->shape().dimensions(dimension) !=
1311 hlo->operand(0)->shape().dimensions(dimension)) {
1312 // Slicing a single element out kills the dynamic dimension.
1313 if (hlo->shape().dimensions(dimension) == 1) {
1314 return OkStatus();
1315 }
1316 return Unimplemented(
1317 "Dynamic dimension propagation on DynamicSlice where a partial "
1318 "dimension is selected %s",
1319 hlo->ToString());
1320 }
1321
1322 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1323
1324 return OkStatus();
1325 });
1326 }
1327
Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice(
    HloInstruction* hlo) {
1330 return ForEachOperandDynamicDimension(
1331 hlo,
1332 [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64_t dimension,
1333 int64_t operand_index, HloInstruction* dynamic_size) {
1334 if (hlo->shape().dimensions(dimension) !=
1335 hlo->operand(0)->shape().dimensions(dimension)) {
1336 return Unimplemented(
1337 "Dynamic dimension propagation on DynamicUpdateSlice where a "
1338 "partial dimension is selected %s",
1339 hlo->ToString());
1340 }
1341
1342 if (operand_index == 1 &&
1343 hlo->operand(1)->shape().dimensions(dimension) <
1344 hlo->operand(0)->shape().dimensions(dimension)) {
1345 // DUS(input=[A], update=[<=B])
1346 //
          // If the update dim is smaller than the input dim (B < A), we are doing
1348 // a partial update, no need to set the output dynamic dimension.
1349 //
1350 // The dynamic shape in `update` doesn't change output dynamic shape.
1351 return OkStatus();
1352 }
1353
1354 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1355
1356 return OkStatus();
1357 });
1358 }
1359
Status DynamicDimensionInferenceVisitor::HandleReverse(HloInstruction* hlo) {
1361 return PassThroughDynamicDimension(hlo);
1362 }
1363
Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
1365 return ForEachOperandDynamicDimension(
1366 hlo, [&](HloInstruction* operand, ShapeIndex /*index*/,
1367 int64_t input_dynamic_dimension, int64_t operand_index,
1368 HloInstruction* dynamic_size) {
1369 const GatherDimensionNumbers& gather_dims =
1370 hlo->gather_dimension_numbers();
1371 if (operand_index != 1) {
1372 if (hlo->gather_slice_sizes()[input_dynamic_dimension] == 1) {
1373 // Gathering a size 1 dimension out of a dynamic dimension removes
1374 // the dynamicity.
1375 return OkStatus();
1376 }
1377 if (hlo->gather_slice_sizes()[input_dynamic_dimension] ==
1378 operand->shape().dimensions(input_dynamic_dimension)) {
1379 // Gathering a full-sized dimension out of a dynamic dimension
1380 // propagates the dynamicity to output.
1381 int64_t output_dimension = input_dynamic_dimension;
1382 for (int64_t collapsed_dim : gather_dims.collapsed_slice_dims()) {
1383 if (collapsed_dim < input_dynamic_dimension) {
1384 // This output dimension is collapsed.
1385 output_dimension--;
1386 }
1387 }
1388 parent_->SetDynamicSize(hlo, {}, output_dimension, dynamic_size);
1389 return OkStatus();
1390 }
1391 return Unimplemented(
1392 "Detects a dynamic dimension on the data input of gather, which "
1393 "is not supported: %s, %lld",
1394 hlo->ToString(), input_dynamic_dimension);
1395 }
1396 // A mapping from output to input batch dim number. -1 means not a batch
1397 // dimension.
1398 int64_t indices_rank = hlo->operand(1)->shape().rank();
1399 int64_t output_rank = hlo->shape().rank();
1400
1401 // indices_dim is an iterator over indices dimensions.
1402 int64_t indices_dim = 0;
1403 // Find the corresponding batch dimension in the output.
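        // E.g., for indices of shape [A, B, 2] with index_vector_dim = 2 and
        // offset_dims = {2}, output dims 0 and 1 are batch dims taken from
        // indices dims 0 and 1; a dynamic size on indices dim 1 therefore
        // propagates to output dim 1.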
1404 for (int64_t output_dim = 0; output_dim < output_rank; ++output_dim) {
1405 if (!absl::c_linear_search(gather_dims.offset_dims(), output_dim)) {
1406 // Skips index vector dimension.
1407 if (indices_dim == gather_dims.index_vector_dim()) {
1408 indices_dim++;
1409 }
1410 if (indices_dim++ == input_dynamic_dimension) {
1411 parent_->SetDynamicSize(hlo, {}, output_dim, dynamic_size);
1412 return OkStatus();
1413 }
1414 }
1415 }
1416 CHECK(indices_dim == indices_rank);
1417
1418 return Unimplemented(
1419 "Detects a non-batch dynamic dimension of gather, "
1420 "which is not supported: %s",
1421 hlo->ToString());
1422 });
1423 }
1424
1425 Status DynamicDimensionInferenceVisitor::HandleConditional(
1426 HloInstruction* hlo) {
1427 // Conditionals are handled by producing additional inputs and outputs of
1428 // the conditional instruction.
1429 std::vector<HloComputation*> new_branch_computations;
1430 std::vector<HloInstruction*> new_operands;
1431   // If the output of the conditional contains a dynamic dimension, we send the
1432   // dynamic dimension size out by adding an additional root element. A mapping
1433   // from the root instruction's dynamic dimension index (represented by a shape
1434   // index as output index and an int64_t dimension number) to output index
1435   // (represented by an int64_t) is tracked for the conditional instruction (all
1436   // branches should have the same mapping).
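  // Illustrative sketch of the rewrite (names and shapes made up, not from the
  // original source):
  //   before: conditional(pred, arg_t, arg_f) returning (f32[<=4])
  //   after:  conditional(pred, arg_t', arg_f') returning (f32[<=4], s32[])
  // where the trailing s32[] element carries the runtime size of the dynamic
  // dimension and dynamic_output_mapping maps that dimension to tuple index 1.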
1437 ShapeTree<absl::flat_hash_map<int64_t, int64_t>> dynamic_output_mapping(
1438 hlo->shape());
1439
1440 bool need_rewrite = false;
1441 for (int64_t branch_index = 0; branch_index < hlo->branch_count();
1442 ++branch_index) {
1443 std::vector<HloInstruction*> operands_to_add;
1444
1445 absl::flat_hash_map<HloInstruction*, int64_t>
1446 dynamic_size_to_operand_id_index_map;
1447 // Only look at branch_index + 1, the correct operand index for a
1448 // given branch.
1449 const int64_t operand_index = branch_index + 1;
1450
1451 int operand_count =
1452 hlo->operand(operand_index)->shape().tuple_shapes_size();
1453     // Prepare to pass the dynamic dimension sizes into the new computation by
1454     // appending them as extra elements of the operand tuple.
1455 TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
1456 hlo, operand_index,
1457 [&](HloInstruction*, ShapeIndex, int64_t, int64_t,
1458 HloInstruction* dynamic_size) -> Status {
1459 TF_RET_CHECK(hlo->operand(operand_index)->shape().IsTuple())
1460 << "Only tuple typed inputs can have dynamic dimension. Please "
1461 "file a bug against XLA team.";
1462 const HloInstruction* tuple_operand = hlo->operand(operand_index);
1463 for (int64_t i = 0; i < tuple_operand->operand_count(); ++i) {
1464 // If the dynamic size is already an operand to the computation,
1465 // skip adding it to the computation input again.
1466 if (dynamic_size == tuple_operand->operand(i)) {
1467 dynamic_size_to_operand_id_index_map[dynamic_size] = i;
1468 return OkStatus();
1469 }
1470 }
1471 auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size);
1472 if (iter == dynamic_size_to_operand_id_index_map.end()) {
1473 operands_to_add.push_back(dynamic_size);
1474 dynamic_size_to_operand_id_index_map[dynamic_size] =
1475 operand_count++;
1476 }
1477 return OkStatus();
1478 }));
1479
1480 HloInstruction* original_input = hlo->mutable_operand(operand_index);
1481 HloComputation* branch_computation = hlo->branch_computation(branch_index);
1482
1483 HloComputation* new_computation = branch_computation;
1484 HloInstruction* new_operand = hlo->mutable_operand(operand_index);
1485 if (!operands_to_add.empty()) {
1486 TF_RET_CHECK(original_input->shape().IsTuple());
1487 need_rewrite = true;
1488 new_operand = TupleUtil::AppendSuffix(original_input, operands_to_add);
1489 TF_ASSIGN_OR_RETURN(
1490 new_computation,
1491 WidenComputation(branch_computation, new_operand->shape()));
1492 }
1493 // Set the dynamic dimensions for the newly created branch computation's
1494 // parameters so that the hlos inside the computation can see dynamic
1495 // dimensions.
1496 DynamicParameterBinding dynamic_parameter_binding;
1497 TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
1498 hlo, operand_index,
1499 [&](HloInstruction*, ShapeIndex index, int64_t dimension,
1500 int64_t operand_index, HloInstruction* dynamic_size) {
1501 DynamicParameterBinding::DynamicParameter dynamic_parameter{
1502 0, {dynamic_size_to_operand_id_index_map[dynamic_size]}};
1503 DynamicParameterBinding::DynamicDimension dynamic_dimension{
1504 0, {index}, dimension};
1505 TF_RETURN_IF_ERROR(dynamic_parameter_binding.Bind(dynamic_parameter,
1506 dynamic_dimension));
1507
1508 return OkStatus();
1509 }));
1510 VLOG(2) << "dynamic_parameter_binding for conditional branch"
1511 << dynamic_parameter_binding;
1512 TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
1513 new_computation, dynamic_parameter_binding, parent_));
1514
1515 new_branch_computations.push_back(new_computation);
1516 new_operands.push_back(new_operand);
1517 }
1518 int tuple_count = hlo->shape().tuple_shapes_size();
1519 // The dynamism of the output of branches can be different.
1520 // E.g.,
1521 // true_branch (s32[<=4])
1522 // false_branch (s32[4])
1523 //
1524   // The following loop populates dynamic_output_mapping and accounts for
1525 // dynamism across all branches.
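  // Where a branch's output is static in a dimension that is dynamic in another
  // branch, the loop further below emits a constant size (the static bound) for
  // that branch, so every branch root ends up with the same extended tuple shape.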
1526 ShapeUtil::ForEachSubshape(
1527 hlo->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
1528 if (!subshape.IsArray()) {
1529 return;
1530 }
1531 for (int64_t i = 0; i < subshape.rank(); ++i) {
1532 for (int64_t j = 0; j < new_branch_computations.size(); ++j) {
1533 HloInstruction* dynamic_size = parent_->GetDynamicSize(
1534 new_branch_computations[j]->root_instruction(), index, i);
1535 if (dynamic_size) {
1536 if (dynamic_output_mapping.element(index).contains(i)) {
1537 continue;
1538 }
1539 dynamic_output_mapping.mutable_element(index)->emplace(
1540 i, tuple_count++);
1541 }
1542 }
1543 }
1544 });
1545 for (int64_t branch_index = 0; branch_index < hlo->branch_count();
1546 ++branch_index) {
1547 std::vector<HloInstruction*> hlos_to_add_in_root;
1548     // There may be some dynamic dimensions coming out of the computation; wire
1549     // them into the root instruction as additional tuple elements.
1550 ShapeUtil::ForEachSubshape(
1551 hlo->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
1552 if (!subshape.IsArray()) {
1553 return;
1554 }
1555 for (int64_t i = 0; i < subshape.rank(); ++i) {
1556 if (dynamic_output_mapping.element(index).contains(i)) {
1557 HloInstruction* dynamic_size = parent_->GetDynamicSize(
1558 new_branch_computations[branch_index]->root_instruction(),
1559 index, i);
1560 if (dynamic_size) {
1561 hlos_to_add_in_root.push_back(dynamic_size);
1562 } else {
1563 HloInstruction* constant_size =
1564 new_branch_computations[branch_index]->AddInstruction(
1565 HloInstruction::CreateConstant(
1566 LiteralUtil::CreateR0<int32_t>(
1567 subshape.dimensions(i))));
1568 hlos_to_add_in_root.push_back(constant_size);
1569 }
1570 }
1571 }
1572 });
1573
1574 VLOG(2) << "hlos_to_add_in_root:" << hlos_to_add_in_root.size();
1575 if (!hlos_to_add_in_root.empty()) {
1576 need_rewrite = true;
1577 HloInstruction* new_branch_root = TupleUtil::AppendSuffix(
1578 new_branch_computations[branch_index]->root_instruction(),
1579 hlos_to_add_in_root);
1580 new_branch_computations[branch_index]->set_root_instruction(
1581 new_branch_root,
1582 /*accept_different_shape=*/true);
1583 }
1584 }
1585
1586 if (!need_rewrite) {
1587 return OkStatus();
1588 }
1589 // Create a new conditional with the new operations and computations.
1590 HloInstruction* new_conditional =
1591 hlo->parent()->AddInstruction(HloInstruction::CreateConditional(
1592 new_branch_computations[0]->root_instruction()->shape(),
1593 hlo->mutable_operand(0), new_branch_computations, new_operands));
1594
1595 HloInstruction* new_conditional_extracted = TupleUtil::ExtractPrefix(
1596 new_conditional, hlo->shape().tuple_shapes_size());
1597 // Now set the dynamic dimensions of the newly created conditional.
1598 dynamic_output_mapping.ForEachElement(
1599 [&](const ShapeIndex& index,
1600 const absl::flat_hash_map<int64_t, int64_t>& dim_to_output) {
1601 for (auto iter : dim_to_output) {
1602 int64_t dim = iter.first;
1603 int64_t output_index = iter.second;
1604 HloInstruction* dynamic_size = hlo->parent()->AddInstruction(
1605 HloInstruction::CreateGetTupleElement(
1606 ShapeUtil::MakeScalarShape(S32), new_conditional,
1607 output_index));
1608 parent_->SetDynamicSize(new_conditional, index, dim, dynamic_size);
1609 parent_->SetDynamicSize(new_conditional_extracted, index, dim,
1610 dynamic_size);
1611 }
1612 });
1613
1614 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_conditional_extracted));
1615   // Remove the original instruction even if it has side-effects.
1616 TF_RETURN_IF_ERROR(hlo->parent()->RemoveInstruction(hlo));
1617 SetVisited(*new_conditional);
1618 SetVisited(*new_conditional_extracted);
1619 return OkStatus();
1620 }
1621
1622 Status DynamicDimensionInferenceVisitor::HandleMap(HloInstruction* hlo) {
1623 return HandleElementwiseNary(hlo);
1624 }
1625
1626 Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) {
1627 return ForEachOperandDynamicDimension(
1628 hlo,
1629 [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64_t dimension,
1630 int64_t operand_index, HloInstruction* operand_dynamic_size) {
1631 if (operand_index == 0) {
1632 parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size);
1633 return OkStatus();
1634 }
1635
1636 const ScatterDimensionNumbers& scatter_dims =
1637 hlo->scatter_dimension_numbers();
1638 if (operand_index == 2 &&
1639 absl::c_linear_search(scatter_dims.update_window_dims(),
1640 dimension)) {
1641 // Dynamic update window dimension is only allowed if it is exactly
1642 // the same as the corresponding operand dimension.
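              // Illustrative example (made up, not from the original source):
              //   scatter(operand=f32[<=10, 3], indices=s32[N, 1],
              //           updates=f32[N, <=10, 3])
              //   is accepted only because the dynamic update window dim maps onto
              //   operand dim 0, which is itself dynamic with the same size; any
              //   mismatch falls into the Unimplemented branches below.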
1643 std::vector<int64_t> update_window_dims_in_operand;
1644 for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
1645 if (absl::c_linear_search(scatter_dims.inserted_window_dims(), i)) {
1646 continue;
1647 }
1648 update_window_dims_in_operand.push_back(i);
1649 }
1650
1651 for (int64_t i = 0; i < scatter_dims.update_window_dims_size(); ++i) {
1652 if (scatter_dims.update_window_dims(i) == dimension) {
1653 const Shape& operand_shape = hlo->operand(0)->shape();
1654 const Shape& update_shape = hlo->operand(2)->shape();
1655 int64_t dim_in_operand = update_window_dims_in_operand[i];
1656 if (operand_shape.dimensions(dim_in_operand) !=
1657 update_shape.dimensions(dimension) ||
1658 !operand_shape.is_dynamic_dimension(dim_in_operand)) {
1659 return Unimplemented(
1660 "Dynamic dimension of update window dims that are not the "
1661 "same as corresponding operand dim is not supported: "
1662 "%s",
1663 hlo->ToString());
1664 }
1665 HloInstruction* base_dynamic_size = parent_->GetDynamicSize(
1666 hlo->mutable_operand(0), {}, dim_in_operand);
1667 if (base_dynamic_size != operand_dynamic_size) {
1668 return Unimplemented(
1669 "Dynamic dimension size of update window dims that are not "
1670 "the same as corresponding operand dim is not supported: "
1671 "%s.\n Dynamic dim size of base: %s, dynamic dim size of "
1672 "update: %s",
1673 hlo->ToString(), base_dynamic_size->ToString(),
1674 operand_dynamic_size->ToString());
1675 }
1676 }
1677 }
1678 }
1679 // The dynamic dimension is collapsed and won't show up in the output.
1680 // Do nothing here.
1681 return OkStatus();
1682 });
1683 }
1684
1685 Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
1686   // If the output of the kWhile contains a dynamic dimension, we send the
1687   // dynamic dimension size into the while body by adding an additional root/body
1688 // element. A mapping from the root instruction's dynamic dimension index
1689 // (represented by a shape index as output index and an int64_t dimension
1690 // number) to output index (represented by an int64_t) is tracked for the
1691 // while instruction.
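  // Illustrative sketch (shapes made up, not from the original source): a while
  // loop carrying (f32[<=8]) is rewritten to carry (f32[<=8], s32[]), where the
  // trailing s32[] element holds the runtime size of dimension 0 and is threaded
  // through the body and condition via `binding_for_while` below.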
1692 ShapeTree<absl::flat_hash_map<int64_t, int64_t>> dynamic_output_mapping(
1693 hlo->shape());
1694 std::vector<HloInstruction*> operands_to_add;
1695 const int original_tuple_count = hlo->shape().tuple_shapes_size();
1696 int operand_count = original_tuple_count;
1697 TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
1698 hlo, [&](HloInstruction*, ShapeIndex index, int64_t dim, int64_t,
1699 HloInstruction* dynamic_size) {
1700 operands_to_add.push_back(dynamic_size);
1701 dynamic_output_mapping.mutable_element(index)->emplace(dim,
1702 operand_count++);
1703 return OkStatus();
1704 }));
1705 ShapeUtil::ForEachSubshape(
1706 hlo->while_body()->root_instruction()->shape(),
1707 [&](const Shape& subshape, const ShapeIndex& index) {
1708 if (!subshape.IsArray()) {
1709 return;
1710 }
1711 for (int64_t dim = 0; dim < subshape.rank(); ++dim) {
1712 if (subshape.is_dynamic_dimension(dim)) {
1713 if (!dynamic_output_mapping.mutable_element(index)->contains(dim)) {
1714                 // This dynamic dimension doesn't come from an operand, but is
1715 // generated in the middle of the while body. Its initial size
1716 // should be static.
1717 operands_to_add.push_back(hlo->parent()->AddInstruction(
1718 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
1719 subshape.dimensions(dim)))));
1720 dynamic_output_mapping.mutable_element(index)->emplace(
1721 dim, operand_count++);
1722 }
1723 }
1724 }
1725 });
1726 DynamicParameterBinding binding_for_while;
1727 if (!operands_to_add.empty()) {
1728 // Only replace the while loop if there are new parameters to add.
1729 HloInstruction* old_tuple_operand = hlo->mutable_operand(0);
1730 TF_ASSIGN_OR_RETURN(
1731 WhileUtil::MakeInstructionsLiveInResult result,
1732 WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add));
1733 // WhileUtil creates a new while hlo and tuple. Update the dynamic size
1734 // mapping for the newly created tuple.
1735 HloInstruction* new_tuple_operand =
1736 result.new_while_instr->mutable_operand(0);
1737 parent_->CopyMapping(/*from=*/old_tuple_operand,
1738 /*to=*/new_tuple_operand);
1739 hlo = result.new_while_instr;
1740     // We have replaced the while loop; now set the dynamic dimensions for the
1741     // newly created while loop so that the hlos that consume the while loop
1742     // can see the dynamic dimensions. Also set the dynamic parameter binding
1743     // for running inference in the while loop.
1744 TF_RETURN_IF_ERROR(dynamic_output_mapping.ForEachElementWithStatus(
1745 [&](const ShapeIndex& index,
1746 const absl::flat_hash_map<int64_t, int64_t>& dim_to_size) {
1747 for (auto key : dim_to_size) {
1748 int64_t dimension = key.first;
1749 const int64_t output_dynamic_size_index = key.second;
1750 DynamicParameterBinding::DynamicParameter dynamic_parameter{
1751 0, {output_dynamic_size_index}};
1752 DynamicParameterBinding::DynamicDimension dynamic_dimension{
1753 0, index, dimension};
1754 TF_RETURN_IF_ERROR(
1755 binding_for_while.Bind(dynamic_parameter, dynamic_dimension));
1756 // This is the updated output dynamic size coming out of hlo while
1757 // loop.
1758 HloInstruction* output_dynamic_size = hlo->parent()->AddInstruction(
1759 HloInstruction::CreateGetTupleElement(
1760 ShapeUtil::MakeScalarShape(S32), hlo,
1761 output_dynamic_size_index));
1762 parent_->SetDynamicSize(result.replacement_instr, index, dimension,
1763 output_dynamic_size);
1764 }
1765 return OkStatus();
1766 }));
1767 // Set the replacement instruction as visited to avoid visiting it again.
1768 SetVisited(*result.replacement_instr);
1769 }
1770 // Run inference in while body and condition.
1771 TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
1772 hlo->while_body(), binding_for_while, parent_));
1773 TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
1774 hlo->while_condition(), binding_for_while, parent_));
1775
1776 if (operands_to_add.empty()) {
1777 // No dynamic dimension in the inputs and outputs.
1778 return OkStatus();
1779 }
1780
1781   // The dynamic dimension size could have been changed in the loop body (e.g., a
1782   // loop that pushes items onto a stack, where the stack size increases with
1783   // each iteration). Rewrite the dynamic dimension size at the root.
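  // Concretely, if the body root was tuple(data, ...), it is rebuilt below as
  // tuple(gte(root, 0), ..., size_0, ...), so each trailing element carries the
  // possibly-updated dynamic size into the next iteration.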
1784 HloInstruction* body_root = hlo->while_body()->root_instruction();
1785 std::vector<HloInstruction*> new_root_operands(body_root->operand_count(),
1786 nullptr);
1787
1788 // Original non-dynamic-dim operands of root are pass-through.
1789 for (int i = 0; i < original_tuple_count; ++i) {
1790 new_root_operands[i] =
1791 hlo->while_body()->AddInstruction(HloInstruction::CreateGetTupleElement(
1792 body_root->shape().tuple_shapes(i), body_root, i));
1793 }
1794   // Add the dynamic dimension sizes as additional root operands.
1795 TF_RETURN_IF_ERROR(ForEachDynamicDimension(
1796 hlo->while_body()->root_instruction(),
1797 [&](ShapeIndex index, int64_t dim,
1798 HloInstruction* dynamic_size) -> Status {
1799 const int64_t output_index =
1800 dynamic_output_mapping.element(index).at(dim);
1801 new_root_operands[output_index] = dynamic_size;
1802 return OkStatus();
1803 }));
1804 for (auto operand : new_root_operands) {
1805 TF_RET_CHECK(operand != nullptr);
1806 }
1807 HloInstruction* new_body_root = hlo->while_body()->AddInstruction(
1808 HloInstruction::CreateTuple(new_root_operands));
1809 hlo->while_body()->set_root_instruction(new_body_root);
1810 return OkStatus();
1811 }
1812
1813 Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) {
1814 return param_bindings_.ForEachBinding(
1815 [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
1816 const DynamicParameterBinding::DynamicDimension& dynamic_dimension) {
1817 if (dynamic_dimension.parameter_num != hlo->parameter_number()) {
1818 return OkStatus();
1819 }
1820 HloComputation* computation = hlo->parent();
1821 HloInstruction* target_parameter =
1822 computation->parameter_instruction(dynamic_dimension.parameter_num);
1823
1824 HloInstruction* dynamic_size =
1825 computation->parameter_instruction(dynamic_parameter.parameter_num);
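        // Walk down dynamic_parameter.parameter_index with get-tuple-element
        // instructions so that `dynamic_size` ends up pointing at the scalar
        // that carries the runtime size.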
1826 for (int64_t i : dynamic_parameter.parameter_index) {
1827 dynamic_size =
1828 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
1829 ShapeUtil::GetSubshape(dynamic_size->shape(), {i}),
1830 dynamic_size, i));
1831 }
1832
1833 parent_->SetDynamicSize(target_parameter,
1834 dynamic_dimension.parameter_index,
1835 dynamic_dimension.dimension, dynamic_size);
1836 return OkStatus();
1837 });
1838 }
1839
1840 Status DynamicDimensionInferenceVisitor::ForEachDynamicDimension(
1841 HloInstruction* inst, const DynamicDimensionFn& fn) {
1842 auto iter = parent_->per_hlo_dynamic_dimensions_.find(inst);
1843 if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
1844 for (auto& dynamic_dimension : iter->second) {
1845 HloInstruction* dynamic_size = parent_->GetDynamicSize(
1846 dynamic_dimension.inst, dynamic_dimension.index,
1847 dynamic_dimension.dim);
1848 TF_RETURN_IF_ERROR(
1849 fn(dynamic_dimension.index, dynamic_dimension.dim, dynamic_size));
1850 }
1851 }
1852 return OkStatus();
1853 }
1854
1855 Status DynamicDimensionInferenceVisitor::InsertShapeCheck(
1856 HloInstruction* dim1, HloInstruction* dim2,
1857 bool support_implicit_broadcast) {
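  // kIgnore skips the check, kCompileTime rejects the mismatch outright, and
  // kRuntime AND-s an equality predicate into shape_assertion_ so the check can
  // be asserted at run time.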
1858 switch (shape_check_mode_) {
1859 case DynamicDimensionInference::kIgnore:
1860       return OkStatus();
1861 case DynamicDimensionInference::kCompileTime:
1862 return InvalidArgument(
1863 "Fail to proof the equality of two dimensions at compile time: "
1864 "%s vs %s",
1865 dim1->ToString(), dim2->ToString());
1866 case DynamicDimensionInference::kRuntime: {
1867 TF_ASSIGN_OR_RETURN(
1868 HloInstruction * assertion,
1869 MakeCompareHlo(Comparison::Direction::kEq, dim1, dim2));
1870 if (shape_assertion_ == nullptr) {
1871 shape_assertion_ = assertion;
1872 } else {
1873 TF_ASSIGN_OR_RETURN(
1874 shape_assertion_,
1875 MakeBinaryHlo(HloOpcode::kAnd, shape_assertion_, assertion));
1876 }
1877 return OkStatus();
1878 }
1879 default:
1880 LOG(FATAL) << "Unreachable";
1881 }
1882 }
1883
1884 Status DynamicDimensionInferenceVisitor::ForEachDynamicDimensionInOperand(
1885 HloInstruction* inst, int64_t operand_index,
1886 const OperandDynamicDimensionFn& fn) {
1887 auto iter =
1888 parent_->per_hlo_dynamic_dimensions_.find(inst->operand(operand_index));
1889 if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
1890 for (auto& dynamic_dimension : iter->second) {
1891 HloInstruction* dynamic_size = parent_->GetDynamicSize(
1892 dynamic_dimension.inst, dynamic_dimension.index,
1893 dynamic_dimension.dim);
1894 TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index,
1895 dynamic_dimension.dim, operand_index,
1896 dynamic_size));
1897 }
1898 }
1899 return OkStatus();
1900 }
1901
1902 Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension(
1903 HloInstruction* inst, const OperandDynamicDimensionFn& fn) {
1904 for (int64_t operand_index = 0; operand_index < inst->operand_count();
1905 ++operand_index) {
1906 TF_RETURN_IF_ERROR(
1907 ForEachDynamicDimensionInOperand(inst, operand_index, fn));
1908 }
1909 return OkStatus();
1910 }
1911
1912 void DynamicDimensionInference::SetDynamicSize(HloInstruction* inst,
1913 const ShapeIndex& index,
1914 int64_t dim,
1915 HloInstruction* size) {
1916 VLOG(1) << "Set dimension inst " << inst->ToString() << " index "
1917 << index.ToString() << "@" << dim << " to " << size->ToShortString();
1918 Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index);
1919 CHECK(!subshape.IsTuple()) << "Can't set a tuple shape to dynamic dimension";
1920 CHECK(dim < subshape.rank() && dim >= 0)
1921 << "Asked to set invalid dynamic dimension. Shape: "
1922 << subshape.ToString() << ", Dimension: " << dim;
1923 DynamicDimension dynamic_dimension{inst, index, dim};
1924 // Updating a dynamic dimension twice overwrites the previous one.
1925 dynamic_mapping_[dynamic_dimension] = size;
1926 auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst);
1927 iter.first->second.emplace(dynamic_dimension);
1928 }
1929
1930 void DynamicDimensionInference::CopyMapping(HloInstruction* from,
1931 HloInstruction* to) {
1932 auto iter = per_hlo_dynamic_dimensions_.find(from);
1933 if (iter != per_hlo_dynamic_dimensions_.end()) {
1934 for (auto& dynamic_dimension : iter->second) {
1935 HloInstruction* dynamic_size =
1936 GetDynamicSize(dynamic_dimension.inst, dynamic_dimension.index,
1937 dynamic_dimension.dim);
1938 SetDynamicSize(to, dynamic_dimension.index, dynamic_dimension.dim,
1939 dynamic_size);
1940 }
1941 }
1942 }
1943
1944 /* static */
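// Example usage (an illustrative sketch, not from the original source; assumes
// the caller owns a valid HloModule* `module` and an HloInstruction* `instr`):
//
//   TF_ASSIGN_OR_RETURN(
//       DynamicDimensionInference inference,
//       DynamicDimensionInference::Run(module, /*custom_call_handler=*/nullptr,
//                                      DynamicDimensionInference::kIgnore,
//                                      /*assertion_generator=*/nullptr));
//   HloInstruction* size = inference.GetDynamicSize(instr, /*index=*/{}, /*dim=*/0);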
1945 StatusOr<DynamicDimensionInference> DynamicDimensionInference::Run(
1946 HloModule* module, CustomCallInferenceHandler custom_call_handler,
1947 ShapeCheckMode shape_check_mode,
1948 const AssertionGenerator& assertion_generator) {
1949 VLOG(2) << "Param Config " << module->dynamic_parameter_binding().ToString();
1950 DynamicDimensionInference inference(module, std::move(custom_call_handler),
1951 shape_check_mode, assertion_generator);
1952 TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions());
1953 return inference;
1954 }
1955
1956 std::string DynamicDimensionInference::ToString() const {
1957 std::vector<std::string> pieces;
1958 pieces.push_back("DynamicDimensionInference: ");
1959 for (const auto& mapping : dynamic_mapping_) {
1960 const DynamicDimension& dynamic_dimension = mapping.first;
1961 pieces.push_back(absl::StrFormat(
1962 " -- instruction %s at %s has dim %lld as dynamic"
1963 " dimension, which is represented by instruction %s",
1964 dynamic_dimension.inst->ToString(), dynamic_dimension.index.ToString(),
1965 dynamic_dimension.dim, mapping.second->ToString()));
1966 }
1967 return absl::StrJoin(pieces, "\n");
1968 }
1969
1970 DynamicDimensionInference::DynamicDimensionInference(
1971 HloModule* module, CustomCallInferenceHandler custom_call_handler,
1972 ShapeCheckMode shape_check_mode, AssertionGenerator assertion_generator)
1973 : module_(module),
1974 custom_call_handler_(std::move(custom_call_handler)),
1975 shape_check_mode_(shape_check_mode),
1976 assertion_generator_(assertion_generator) {}
1977
1978 Status DynamicDimensionInference::AnalyzeDynamicDimensions() {
1979 return DynamicDimensionInferenceVisitor::Run(
1980 module_->entry_computation(), module_->dynamic_parameter_binding(), this,
1981 custom_call_handler_, shape_check_mode_, assertion_generator_);
1982 }
1983
1984 void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith(
1985 HloInstruction* replace, HloInstruction* with) {
1986 CHECK(Shape::Equal().IgnoreLayout()(replace->shape(),
1987 ShapeUtil::MakeScalarShape(S32)));
1988 CHECK(Shape::Equal().IgnoreLayout()(with->shape(),
1989 ShapeUtil::MakeScalarShape(S32)));
1990 for (auto& kv : dynamic_mapping_) {
1991 if (kv.second == replace) {
1992 kv.second = with;
1993 }
1994 }
1995 }
1996
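// Copies the dynamic-size mapping of `inst` at `index` onto `new_inst` (the two
// shapes must be equal), so consumers of the replacement instruction keep
// observing the same dynamic dimensions.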
1997 Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
1998 HloInstruction* new_inst,
1999 const ShapeIndex& index) {
2000 CHECK(Shape::Equal()(inst->shape(), new_inst->shape()));
2001
2002 for (int64_t dim = 0; dim < inst->shape().rank(); ++dim) {
2003 DynamicDimension dynamic_dimension_new{new_inst, index, dim};
2004 DynamicDimension dynamic_dimension{inst, index, dim};
2005 auto iter = dynamic_mapping_.find(dynamic_dimension);
2006 if (iter != dynamic_mapping_.end()) {
2007 dynamic_mapping_.insert({dynamic_dimension_new, iter->second});
2008 auto iter = per_hlo_dynamic_dimensions_.try_emplace(new_inst);
2009 iter.first->second.emplace(dynamic_dimension_new);
2010 }
2011 }
2012
2013 return OkStatus();
2014 }
2015
2016 bool DynamicDimensionInference::HasDynamicDimension(
2017 HloInstruction* inst, ShapeIndexView index) const {
2018 bool has_dynamic_dim = false;
2019 ShapeUtil::ForEachSubshape(inst->shape(), [&](const Shape& subshape,
2020 const ShapeIndex& subindex) {
2021 if (subshape.IsTuple()) {
2022 return;
2023 }
2024 if (ShapeIndexView(subindex).first(index.size()) != index) {
2025 return;
2026 }
2027 for (int64_t i = 0; i < subshape.dimensions_size(); ++i) {
2028 HloInstruction* operand_dynamic_size = GetDynamicSize(inst, subindex, i);
2029 if (operand_dynamic_size != nullptr) {
2030 has_dynamic_dim = true;
2031 }
2032 }
2033 });
2034 return has_dynamic_dim;
2035 }
2036
2037 Status DynamicDimensionInference::Update(HloInstruction* inst) {
2038 DynamicParameterBinding parameter_binding;
2039 DynamicDimensionInferenceVisitor visitor(
2040 parameter_binding, this, custom_call_handler_, shape_check_mode_);
2041 return inst->Visit(&visitor);
2042 }
2043
2044 HloInstruction* DynamicDimensionInference::GetDynamicSize(
2045 HloInstruction* inst, const ShapeIndex& index, int64_t dim) const {
2046 auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim});
2047 if (iter != dynamic_mapping_.end()) {
2048 return iter->second;
2049 }
2050 return nullptr;
2051 }
2052
2053 std::vector<HloInstruction*> DynamicDimensionInference::GetDynamicSizes(
2054 HloInstruction* inst, const ShapeIndex& index) const {
2055 CHECK(ShapeUtil::IndexIsValid(inst->shape(), index));
2056 const int64_t rank = ShapeUtil::GetSubshape(inst->shape(), index).rank();
2057 std::vector<HloInstruction*> result(rank, nullptr);
2058 for (int64_t i = 0; i < rank; ++i) {
2059     result[i] = GetDynamicSize(inst, index, i);
2060 }
2061 return result;
2062 }
2063
2064 } // namespace xla
2065