xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/dynamic_padder.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
16 
17 #include <algorithm>
18 #include <functional>
19 #include <optional>
20 #include <vector>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/strings/str_format.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/comparison_util.h"
28 #include "tensorflow/compiler/xla/literal.h"
29 #include "tensorflow/compiler/xla/literal_util.h"
30 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
31 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
32 #include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
33 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
34 #include "tensorflow/compiler/xla/service/hlo_computation.h"
35 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
36 #include "tensorflow/compiler/xla/service/hlo_dce.h"
37 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
38 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
39 #include "tensorflow/compiler/xla/service/hlo_module.h"
40 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
41 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
42 #include "tensorflow/compiler/xla/service/shape_inference.h"
43 #include "tensorflow/compiler/xla/shape_util.h"
44 #include "tensorflow/compiler/xla/status_macros.h"
45 #include "tensorflow/compiler/xla/util.h"
46 #include "tensorflow/compiler/xla/window_util.h"
47 #include "tensorflow/compiler/xla/xla_data.pb.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/monitoring/gauge.h"
50 #include "tensorflow/core/platform/errors.h"
51 #include "tensorflow/core/platform/statusor.h"
52 
53 namespace xla {
54 
55 namespace {
56 
// Process-wide monitoring gauge recording whether the dynamic padder has been
// used (set elsewhere in this pass when a rewrite actually happens).
auto* dynamic_padding_gauge = tensorflow::monitoring::Gauge<bool, 0>::New(
    "/tensorflow/core/use_dynamic_padding_gauge",
    "Tracks if dynamic padder is used.");
60 
61 // ChooseIdentityValue looks at the instruction's operand, returns a
62 // identity value which, when padded, doesn't change the result of the
63 // instruction.
64 //
65 // nullopt is returned if padding doesn't need to be reset.
ChooseIdentityValue(HloInstruction * inst,int64_t operand_number)66 StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst,
67                                               int64_t operand_number) {
68   // Padding on elementwise operation doesn't affect the result of the effective
69   // data.
70   if (inst->IsElementwise()) {
71     return nullptr;
72   }
73   if (inst->opcode() == HloOpcode::kSelectAndScatter ||
74       inst->IsCustomCall("DynamicSelectAndScatterSamePadding")) {
75     if (operand_number == 1) {
76       return inst->mutable_operand(2);
77     }
78     TF_RET_CHECK(operand_number == 0);
79     HloComputation* select = inst->called_computations()[0];
80 
81     if (Match(select->root_instruction(),
82               match::Compare(match::Parameter(), match::Parameter())
83                   .WithComparisonDirection(ComparisonDirection::kGe))) {
84       return inst->AddInstruction(HloInstruction::CreateConstant(
85           LiteralUtil::MinValue(inst->operand(0)->shape().element_type())));
86     } else {
87       return Unimplemented(
88           "Only select and scatter with `max` as select function is "
89           "supported, got %s",
90           select->ToString());
91     }
92   }
93   switch (inst->opcode()) {
94     case HloOpcode::kReduce: {
95       auto* reduce = Cast<HloReduceInstruction>(inst);
96       TF_RET_CHECK(operand_number < reduce->input_count())
97           << "Only data operand with dynamic dimension is valid.";
98       // Variadic reduce has different init value for different operand, given
99       // a data operand number, find the init value index.
100       int64_t init_value_index = reduce->input_count() + operand_number;
101       return inst->mutable_operand(init_value_index);
102     }
103     case HloOpcode::kReduceWindow: {
104       auto* reduce_window = Cast<HloReduceWindowInstruction>(inst);
105       TF_RET_CHECK(operand_number < reduce_window->input_count())
106           << "Only data operand with dynamic dimension is valid.";
107       // Variadic reduce has different init value for different operand, given
108       // a data operand number, find the init value index.
109       int64_t init_value_index = reduce_window->input_count() + operand_number;
110       return inst->mutable_operand(init_value_index);
111     }
112 
113     case HloOpcode::kConvolution:
114     case HloOpcode::kDot: {
115       // Use 0 as padding value for convolution and dot.
116       //
117       // Note that the output type (inst->shape().element_type()) isn't
118       // necessarily the same as the input type (element type of operands).  For
119       // example, a dot can take s8 as input and output s32.
120       PrimitiveType ptype = inst->operand(0)->shape().element_type();
121       return inst->AddInstruction(
122           HloInstruction::CreateConstant(LiteralUtil::Zero(ptype)));
123     }
124 
125     case HloOpcode::kPad:
126       return inst->mutable_operand(1);
127     case HloOpcode::kScatter: {
128       if (operand_number != 1) {
129         return nullptr;
130       }
131       PrimitiveType indices_ptype =
132           inst->operand(operand_number)->shape().element_type();
133 
134       return inst->AddInstruction(
135           HloInstruction::CreateConstant(LiteralUtil::MaxValue(indices_ptype)));
136     }
137     case HloOpcode::kParameter:
138     case HloOpcode::kGather:
139     case HloOpcode::kDynamicSlice:
140     case HloOpcode::kDynamicUpdateSlice:
141     case HloOpcode::kGetDimensionSize:
142     case HloOpcode::kSetDimensionSize:
143     case HloOpcode::kConcatenate:
144     case HloOpcode::kReshape:
145     case HloOpcode::kReverse:
146     case HloOpcode::kTuple:
147     case HloOpcode::kAllReduce:
148     case HloOpcode::kReduceScatter:
149     case HloOpcode::kBroadcast:
150     case HloOpcode::kTranspose:
151     case HloOpcode::kSort:
152     case HloOpcode::kSlice:
153     case HloOpcode::kDomain:
154       return nullptr;
155     case HloOpcode::kCustomCall:
156       // Assume that custom calls created by the client are valid with padded
157       // dynamic dimensions.
158       return nullptr;
159     default:
160       return UnimplementedStrCat("Unimplemented padding for instruction: ",
161                                  inst->ToString());
162   }
163 }
164 
ReplaceGetSize(HloInstruction * instr,DynamicDimensionInference * dynamic_dimension_inference)165 StatusOr<bool> ReplaceGetSize(
166     HloInstruction* instr,
167     DynamicDimensionInference* dynamic_dimension_inference) {
168   if (instr->opcode() != HloOpcode::kGetDimensionSize) {
169     return false;
170   }
171   HloComputation* computation = instr->parent();
172 
173   TF_ASSIGN_OR_RETURN(auto legal_shape,
174                       ShapeInference::InferGetDimensionSizeShape(
175                           instr->operand(0)->shape(), instr->dimension()));
176   TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape))
177       << "instr->shape() " << instr->shape().ToString() << " , "
178       << "legal_shape " << legal_shape.ToString();
179   TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32));
180   HloInstruction* operand = instr->mutable_operand(0);
181   int64_t dim = instr->dimension();
182   HloInstruction* dynamic_size =
183       dynamic_dimension_inference->GetDynamicSize(operand, {}, dim);
184   if (dynamic_size != nullptr) {
185     TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
186     // The dependency between a instruction and its dynamic dimensions is not
187     // modeled in the IR. As instr is being replaced by dynamic_size, also tell
188     // dynamic dimension inference that the instruction is being replaced.
189     dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(
190         instr, dynamic_size);
191   } else {
192     int32_t size = instr->operand(0)->shape().dimensions(dim);
193     HloInstruction* new_instr = computation->AddInstruction(
194         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(size)));
195     TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
196     dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr,
197                                                                     new_instr);
198   }
199   return true;
200 }
201 
ReplaceSetSize(HloInstruction * instr)202 StatusOr<bool> ReplaceSetSize(HloInstruction* instr) {
203   if (instr->opcode() != HloOpcode::kSetDimensionSize) {
204     return false;
205   }
206 
207   TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
208       instr->shape(), instr->operand(0)->shape()))
209       << "instr->shape() " << instr->shape().ToString() << " , "
210       << "instruction operand shape " << instr->operand(0)->shape();
211   HloInstruction* operand = instr->mutable_operand(0);
212 
213   TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
214   return true;
215 }
216 
ReplaceSetBound(HloInstruction * instr)217 StatusOr<bool> ReplaceSetBound(HloInstruction* instr) {
218   if (instr->opcode() != HloOpcode::kCustomCall ||
219       instr->custom_call_target() != "SetBound") {
220     return false;
221   }
222 
223   TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
224       instr->shape(), instr->operand(0)->shape()))
225       << "instr->shape() " << instr->shape().ToString() << " , "
226       << "instruction operand shape " << instr->operand(0)->shape();
227   HloInstruction* operand = instr->mutable_operand(0);
228 
229   TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
230   return true;
231 }
232 
ShouldSkipPadOnOperand(const HloInstruction * inst,int64_t operand_num,int64_t dimension)233 bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64_t operand_num,
234                             int64_t dimension) {
235   switch (inst->opcode()) {
236     case HloOpcode::kConvolution: {
237       if (operand_num == 0) {
238         if (dimension ==
239             inst->convolution_dimension_numbers().input_batch_dimension()) {
240           return true;
241         }
242         const auto& spatial_dims =
243             inst->convolution_dimension_numbers().input_spatial_dimensions();
244         for (int64_t spatial_dim = 0; spatial_dim < spatial_dims.size();
245              ++spatial_dim) {
246           // A spatial dimemnsion with a window of size 1 does not need
247           // padding.
248           if (spatial_dims[spatial_dim] == dimension &&
249               inst->window().dimensions(spatial_dim).size() == 1) {
250             return true;
251           }
252         }
253       }
254       return operand_num == 1 &&
255              (dimension == inst->convolution_dimension_numbers()
256                                .kernel_output_feature_dimension());
257     }
258     case HloOpcode::kDot: {
259       if (operand_num == 0) {
260         return !absl::c_linear_search(
261             inst->dot_dimension_numbers().lhs_contracting_dimensions(),
262             dimension);
263       }
264       return !absl::c_linear_search(
265           inst->dot_dimension_numbers().rhs_contracting_dimensions(),
266           dimension);
267     }
268     case HloOpcode::kReduce:
269       return !absl::c_linear_search(inst->dimensions(), dimension);
270     case HloOpcode::kSelectAndScatter:
271     case HloOpcode::kReduceWindow:
272       return inst->window().dimensions(dimension).size() == 1;
273     default:
274       return false;
275   }
276 }
277 
278 // Generates a mask representing the effective area of data and padded area of
279 // data using iota and dynamic_size. For example, given a dimension of 7
280 // elements and 5 effective elements:
281 //
282 // iota = [0, 1, 2, 3, 4, 5, 6]
283 // broadcast_dynamic_size = [5, 5, 5, 5, 5, 5, 5]
284 // mask = lt(iota, broadcast_dynamic_size) = [t, t, t, t, t, f, f]
285 //
286 // Once the mask is generated, the input data is then padded using the
287 // mask and pad value.
288 //
PadWithScalar(HloInstruction * inst,int64_t dim,HloInstruction * dynamic_size,HloInstruction * padding_scalar)289 HloInstruction* PadWithScalar(HloInstruction* inst, int64_t dim,
290                               HloInstruction* dynamic_size,
291                               HloInstruction* padding_scalar) {
292   CHECK(inst != nullptr && dynamic_size != nullptr &&
293         padding_scalar != nullptr);
294   const Shape mask_shape =
295       ShapeUtil::ChangeElementType(inst->shape(), xla::S32);
296   const Shape pred_shape =
297       ShapeUtil::ChangeElementType(inst->shape(), xla::PRED);
298   HloInstruction* iota =
299       inst->AddInstruction(HloInstruction::CreateIota(mask_shape, dim));
300 
301   HloInstruction* broadcasted_effective_size = inst->AddInstruction(
302       HloInstruction::CreateBroadcast(mask_shape, dynamic_size, {}));
303   HloInstruction* pred = inst->AddInstruction(HloInstruction::CreateCompare(
304       pred_shape, iota, broadcasted_effective_size, ComparisonDirection::kLt));
305 
306   HloInstruction* broadcasted_identity_value = inst->AddInstruction(
307       HloInstruction::CreateBroadcast(inst->shape(), padding_scalar, {}));
308   HloInstruction* padded = inst->AddInstruction(
309       HloInstruction::CreateTernary(inst->shape(), HloOpcode::kSelect, pred,
310                                     inst, broadcasted_identity_value));
311   return padded;
312 }
313 
// Generate a 1-0 mask for input_dim where 1 means data in dynamic shape.
//
// When split_input is true, the mask is built over dimension `input_dim` of
// the reshape's operand and `output_dims`/`output_dynamic_dims` describe the
// reshape's output; when false, the roles of operand and output are swapped.
// `one` and `zero` are S32 scalar constants supplied by the caller and used
// to materialize the final mask. Returns nullptr when none of the relevant
// output dimensions is dynamic, i.e. no rewrite is needed.
HloInstruction* GenerateBinaryMask(
    HloInstruction* reshape, int64_t input_dim,
    absl::Span<const int64_t> output_dims,
    absl::Span<HloInstruction*> output_dynamic_dims, HloInstruction* one,
    HloInstruction* zero, bool split_input) {
  Shape input_shape =
      split_input ? reshape->operand(0)->shape() : reshape->shape();
  Shape output_shape =
      split_input ? reshape->shape() : reshape->operand(0)->shape();
  // 1-D S32 / PRED shapes spanning the dimension being masked.
  const Shape mask_input_shape =
      ShapeUtil::MakeShape(xla::S32, {input_shape.dimensions(input_dim)});
  const Shape pred_input_shape =
      ShapeUtil::MakeShape(xla::PRED, {input_shape.dimensions(input_dim)});
  // Start with an all-true mask and AND in per-dimension comparisons below.
  HloInstruction* pred_true = reshape->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* input_shape_pred_mask = reshape->AddInstruction(
      HloInstruction::CreateBroadcast(pred_input_shape, pred_true, {}));
  bool need_rewrite = false;
  // Iota contains a linear index for each element in input shape.
  HloInstruction* iota =
      reshape->AddInstruction(HloInstruction::CreateIota(mask_input_shape, 0));

  // Compute the multi-dimensional indices from a linear index and
  // compare to dynamic dimension size to generate the mask.
  // For a 2x3x3 shape, iota is first set to:
  //   [0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,16,17]
  // iota % 3 gives the index for the last dimension.
  //   [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
  // Then iota goes to:
  //   [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5] (after div 3)
  // iota % 3 gives the index of the second last dimension.
  //   [0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2]
  // Then iota goes to:
  //   [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1] (after div 3)
  // It gives the index of the major dimension.
  // For example, element 16 in the original iota will in the end get index
  // (1, 2, 1). Each index is used for generating the mask (if necessary) by
  // comparing to the dynamic size value for that dimension.
  //
  // Skip index 0 since there is no need to rewrite a major output dimension.
  for (int64_t i = 1; i < output_dims.size(); ++i) {
    if (output_dynamic_dims[output_dims[i]] != nullptr) {
      // If there is dynamic dimension in the output, need to rewrite the input.
      need_rewrite = true;
      break;
    }
  }
  if (!need_rewrite) {
    return nullptr;
  }

  // Walk output dimensions from minor to major, peeling one index per
  // iteration off the linear iota (see the worked example above).
  for (int64_t i = output_dims.size() - 1; i > 0; i--) {
    const int64_t output_dim = output_dims[i];
    HloInstruction* dynamic_size = output_dynamic_dims[output_dim];
    HloInstruction* static_output_dim_size = reshape->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
            output_shape.dimensions(output_dim))));
    HloInstruction* broadcasted_static_output_dim_size =
        reshape->AddInstruction(HloInstruction::CreateBroadcast(
            mask_input_shape, static_output_dim_size, {}));
    if (dynamic_size != nullptr) {
      // Generate the mask for output_dim.
      HloInstruction* dim_index =
          reshape->AddInstruction(HloInstruction::CreateBinary(
              mask_input_shape, HloOpcode::kRemainder, iota,
              broadcasted_static_output_dim_size));
      HloInstruction* broadcasted_effective_size = reshape->AddInstruction(
          HloInstruction::CreateBroadcast(mask_input_shape, dynamic_size, {}));
      HloInstruction* selected =
          reshape->AddInstruction(HloInstruction::CreateCompare(
              pred_input_shape, dim_index, broadcasted_effective_size,
              ComparisonDirection::kLt));

      // Merge the mask.
      input_shape_pred_mask = reshape->AddInstruction(
          HloInstruction::CreateBinary(pred_input_shape, HloOpcode::kAnd,
                                       input_shape_pred_mask, selected));
    }

    // Update iota values by "shifting out" dimension i.
    iota = reshape->AddInstruction(
        HloInstruction::CreateBinary(mask_input_shape, HloOpcode::kDivide, iota,
                                     broadcasted_static_output_dim_size));
  }

  // Materialize the PRED mask as S32 ones and zeros.
  HloInstruction* broadcasted_one = reshape->AddInstruction(
      HloInstruction::CreateBroadcast(mask_input_shape, one, {}));
  HloInstruction* broadcasted_zero = reshape->AddInstruction(
      HloInstruction::CreateBroadcast(mask_input_shape, zero, {}));
  return reshape->AddInstruction(HloInstruction::CreateTernary(
      mask_input_shape, HloOpcode::kSelect, input_shape_pred_mask,
      broadcasted_one, broadcasted_zero));
}
408 
409 // In a reshape if a dynamic dimension is splitted into multiple output
410 // dimensions, we need to rewrite the input of the reshape.
411 //
412 // The reason for this is that a continuous input may not be evenly reshaped
413 // into output.  Image we have [<=6] where valid data has size 4 and padding (P)
414 // data has size 2: [a,b,c,d,P,P]
415 //
416 // And we have a reshape that produces dynamic output dimensions.
417 //
418 // [<=6]
419 //  |
420 // Reshape
421 //  |
422 // [2, <=3]
423 //
424 // This should produce the same result as if the data has no padding:
425 //
426 // [4]     // [a, b, c, d]
427 //  |
428 // Reshape
429 //  |
430 // [2, 2]  // [[a,b], [c,d]]
431 //
432 // Without reshape rewriting, the result looks like:
433 //
434 // [[a,b,c]
435 //  [d,P,P]], which is incorrect.
436 //
437 // We need to rewrite the reshape such that it produces:
438 // [[a,b,P]
439 //  [c,d,P]]
440 //
441 // The way we do this is by a 4-steps cumsum-gather algorithm:
442 //
443 // 1.First we use the output shape to generate a binary 0-1 masking, which masks
444 // out the padded area of the flattened output shape:
445 // [1,1,0,1,1,0]
446 //
447 // 2.We then do a cumsum with the mask:
448 //  [1,2,2,3,4,4] and subtract it with 1:
449 //  [0,1,1,2,3,3]
450 //
451 // 3.Use the result of cumsum as gather indices to rearrange the original
452 // data. Feed the original input [a,b,c,d,P,P] and indices into gather.
453 //
454 //  operand [a,b,c,d,P,P], indices [0,1,1,2,3,3]
455 //     |                    |
456 //   Gather-----------------+
457 //     |
458 //     v
459 //  value[a,b,b,c,d,d], which is equivalent to [a,b,P,c,d,P] as padding value
460 //  doesn't matter.
461 //
462 //
463 // 4.Feed the sorted input to original reshape[6]->[2,3], we can now get the
464 // correct result:
465 //  [[a,b,P]
466 //   [c,d,P]]
467 //
RewriteDynamicReshapeSplitInput(HloInstruction * reshape,int64_t input_dim,absl::Span<const int64_t> output_dims,absl::Span<HloInstruction * > output_dynamic_dims,DynamicDimensionInference * dynamic_dimension_inference)468 StatusOr<bool> RewriteDynamicReshapeSplitInput(
469     HloInstruction* reshape, int64_t input_dim,
470     absl::Span<const int64_t> output_dims,
471     absl::Span<HloInstruction*> output_dynamic_dims,
472     DynamicDimensionInference* dynamic_dimension_inference) {
473   VLOG(2) << "Reshaping input dim " << input_dim << " to "
474           << VectorString(output_dims);
475   const Shape operand_shape = reshape->operand(0)->shape();
476   TF_RET_CHECK(output_dims.size() > 1);
477 
478   const Shape mask_input_shape =
479       ShapeUtil::MakeShape(xla::S32, {operand_shape.dimensions(input_dim)});
480   const Shape pred_input_shape =
481       ShapeUtil::MakeShape(xla::PRED, {operand_shape.dimensions(input_dim)});
482 
483   HloInstruction* zero = reshape->AddInstruction(
484       HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
485   HloInstruction* one = reshape->AddInstruction(
486       HloInstruction::CreateConstant(LiteralUtil::One(S32)));
487 
488   // Step 1 -- generate binary mask.
489   HloInstruction* input_shape_binary_mask =
490       GenerateBinaryMask(reshape, input_dim, output_dims, output_dynamic_dims,
491                          one, zero, /*split_input=*/true);
492   if (input_shape_binary_mask == nullptr) {
493     // No need to rewrite.
494     VLOG(2) << "No need to rewrite";
495     return false;
496   }
497 
498   // Step 2. Do a cumsum on the binary mask.
499 
500   auto embedded_builder = HloComputation::Builder("add");
501   {
502     auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
503         0, ShapeUtil::MakeShape(S32, {}), "lhs"));
504     auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
505         1, ShapeUtil::MakeShape(S32, {}), "rhs"));
506     embedded_builder.AddInstruction(
507         HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
508   }
509 
510   HloComputation* add =
511       reshape->GetModule()->AddEmbeddedComputation(embedded_builder.Build());
512   Window cumsum_window;
513   // First dimension is unchanged.
514   WindowDimension* dim = cumsum_window.add_dimensions();
515   dim->set_size(operand_shape.dimensions(input_dim));
516   dim->set_stride(1);
517   dim->set_padding_low(operand_shape.dimensions(input_dim) - 1);
518   dim->set_padding_high(0);
519   dim->set_window_dilation(1);
520   dim->set_base_dilation(1);
521   HloInstruction* cumsum =
522       reshape->AddInstruction(HloInstruction::CreateReduceWindow(
523           mask_input_shape, input_shape_binary_mask, zero, cumsum_window, add));
524 
525   HloInstruction* broadcast_ones = reshape->AddInstruction(
526       HloInstruction::CreateBroadcast(mask_input_shape, one, {}));
527   cumsum = reshape->AddInstruction(HloInstruction::CreateBinary(
528       mask_input_shape, HloOpcode::kSubtract, cumsum, broadcast_ones));
529 
530   GatherDimensionNumbers gather_dim_numbers;
531   // Use gather to rearrange the input dim dimension.
532   for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) {
533     // Offset dim is every dimension including newly added size 1 dim, except
534     // for input_dim, which acts as a batch_dim.
535     if (i != input_dim) {
536       gather_dim_numbers.add_offset_dims(i);
537     }
538   }
539   // The dimension to rewrite is the index dim.
540   gather_dim_numbers.add_start_index_map(input_dim);
541   gather_dim_numbers.set_index_vector_dim(1);
542   gather_dim_numbers.add_collapsed_slice_dims(input_dim);
543 
544   // Step 3. Gather.
545 
546   // Temporarily removes dynamic dimension before entering gather -- we want the
547   // gather to ignore dynamic dimension.
548   HloInstruction* operand_static_dim_size =
549       reshape->AddInstruction(HloInstruction::CreateConstant(
550           LiteralUtil::CreateR0<int32_t>(operand_shape.dimensions(input_dim))));
551   HloInstruction* operand_static =
552       reshape->AddInstruction(HloInstruction::CreateSetDimensionSize(
553           operand_shape, reshape->mutable_operand(0), operand_static_dim_size,
554           input_dim));
555 
556   std::vector<int64_t> slice_sizes(operand_shape.dimensions().begin(),
557                                    operand_shape.dimensions().end());
558   slice_sizes[input_dim] = 1;
559   HloInstruction* gather = reshape->AddInstruction(HloInstruction::CreateGather(
560       ShapeUtil::MakeShape(operand_shape.element_type(),
561                            operand_shape.dimensions()),
562       operand_static, cumsum, gather_dim_numbers, slice_sizes, true));
563 
564   // Step 4: Feed gather input to original reshape.
565 
566   TF_RETURN_IF_ERROR(reshape->ReplaceOperandWith(0, gather));
567 
568   HloInstruction* reshape_dynamic = reshape;
569 
570   auto users = reshape->users();
571 
572   // Forward the output dynamic dimension.
573   for (int64_t output_dim : output_dims) {
574     HloInstruction* output_dynamic_size =
575         dynamic_dimension_inference->GetDynamicSize(reshape, {}, output_dim);
576     if (output_dynamic_size != nullptr) {
577       reshape_dynamic =
578           reshape->AddInstruction(HloInstruction::CreateSetDimensionSize(
579               reshape->shape(), reshape_dynamic, output_dynamic_size,
580               output_dim));
581     }
582   }
583 
584   for (auto* user : users) {
585     TF_RETURN_IF_ERROR(reshape->ReplaceUseWith(user, reshape_dynamic));
586   }
587   TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
588       reshape, reshape_dynamic, {}));
589 
590   return true;
591 }
592 
593 // RewriteDynamicReshapeCombineInput is similar to
594 // RewriteDynamicReshapeSplitInput, in a reshape if multiple dimensions are
595 // combined into one dimension, we need to rewrite the output.
596 //
597 // The reason for this is that a continuous input may not be evenly reshaped
598 // into output.  Imagine we have [2, <=3] where second dimension has size 2 and
599 // padding(P) data has size 1:
600 // [[a,b,P]
601 //  [c,d,P]]
602 //
603 // And we have a reshape that combines these two input dimensions.
604 //
605 // [2, <=3]
606 //  |
607 // Reshape
608 //  |
609 // [6]
610 //
611 // This should produce the same result as if the data has no padding:
612 //
613 // [2, 2]     // [[a, b], [c, d]]
614 //  |
615 // Reshape
616 //  |
617 // [4]  // [a,b,c,d]
618 //
619 // Without rewriting, the result would be:
620 //
621 // [a,b,P,c,d,P], which is incorrect.
622 //
623 // We need to rewrite the reshape such that it produces:
624 // [a,b,c,d,P,P]
625 //
626 // The way we do this is by a 4-steps sort-gather algorithm:
627 //
628 // 1.First we use the input shape to generate a binary 0-1 masking, which masks
629 // out the padded area of the output:
630 //  [1,1,0,1,1,0]
631 //
632 // 2.We then generate an iota mask using the output shape:
633 //  [0,1,2,3,4,5]
634 //
635 // 3.Stable sort the iota mask using the binary mask as key:
636 //  key  [1,1,0,1,1,0]
637 //  value[0,1,2,3,4,5]
638 //     | Sort by key
639 //     v
640 //  key  [1,1,1,1,0,0]
641 //  value[0,1,3,4,2,5]
642 //
643 // 4.Gather the original output [a,b,P,c,d,P] using the sorted iota mask:
644 //      original output       gather indices
645 //       [a,b,P,c,d,P]         [0,1,3,4,2,5]
646 //            |                    |
647 //          Gather ----------------+
648 //            |
649 //       [a,b,c,d,P,P]
650 //
RewriteDynamicReshapeCombineInput(HloInstruction * reshape,absl::Span<const int64_t> input_dims,int64_t output_dim,absl::Span<HloInstruction * > input_dynamic_dims,DynamicDimensionInference * dynamic_dimension_inference)651 StatusOr<bool> RewriteDynamicReshapeCombineInput(
652     HloInstruction* reshape, absl::Span<const int64_t> input_dims,
653     int64_t output_dim, absl::Span<HloInstruction*> input_dynamic_dims,
654     DynamicDimensionInference* dynamic_dimension_inference) {
655   // Rewrite dynamic reshape into reshape followed by a sort, all padded
656   // data will be moved to the end.
657   HloInstruction* zero = reshape->AddInstruction(
658       HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
659   HloInstruction* one = reshape->AddInstruction(
660       HloInstruction::CreateConstant(LiteralUtil::One(S32)));
661   const Shape output_shape = reshape->shape();
662   const Shape input_shape = reshape->operand(0)->shape();
663   const Shape mask_output_shape =
664       ShapeUtil::MakeShape(xla::S32, {output_shape.dimensions(output_dim)});
665 
666   // Step 1.
667   // Generate binary mask.
668   HloInstruction* output_shape_binary_mask =
669       GenerateBinaryMask(reshape, output_dim, input_dims, input_dynamic_dims,
670                          one, zero, /*split_input=*/false);
671   if (output_shape_binary_mask == nullptr) {
672     VLOG(2) << "No need to rewrite";
673     return false;
674   }
675 
676   // Step 2.
677   // Generate an iota with output shape.
678   HloInstruction* iota =
679       reshape->AddInstruction(HloInstruction::CreateIota(mask_output_shape, 0));
680 
681   // Step 3.
682   // Stable sort the iota mask using the binary mask as key and iota as value:
683 
684   // Build computation for sort, key is the mask, value is the iota.
685   HloComputation::Builder comp_builder("compare");
686   HloInstruction* lhs_key =
687       comp_builder.AddInstruction(HloInstruction::CreateParameter(
688           0, ShapeUtil::MakeScalarShape(S32), "lhs_key"));
689   HloInstruction* rhs_key =
690       comp_builder.AddInstruction(HloInstruction::CreateParameter(
691           1, ShapeUtil::MakeScalarShape(S32), "rhs_key"));
692 
693   // Values for lhs and rhs
694   comp_builder.AddInstruction(HloInstruction::CreateParameter(
695       2, ShapeUtil::MakeScalarShape(S32), "lhs_value"));
696   comp_builder.AddInstruction(HloInstruction::CreateParameter(
697       3, ShapeUtil::MakeScalarShape(S32), "rhs_value"));
698   comp_builder.AddInstruction(
699       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), lhs_key,
700                                     rhs_key, ComparisonDirection::kGt));
701   HloComputation* compare =
702       reshape->GetModule()->AddEmbeddedComputation(comp_builder.Build());
703 
704   // Use mask_reshaped as key, sort reshaped data as value.
705   HloInstruction* sort = reshape->AddInstruction(HloInstruction::CreateSort(
706       ShapeUtil::MakeTupleShape({mask_output_shape, mask_output_shape}), 0,
707       {output_shape_binary_mask, iota}, compare,
708       /*is_stable=*/true));
709 
710   HloInstruction* gather_indices = reshape->AddInstruction(
711       HloInstruction::CreateGetTupleElement(mask_output_shape, sort, 1));
712 
713   // Step 4.Gather the original output using the sorted iota mask:
714 
715   GatherDimensionNumbers gather_dim_numbers;
716   // Use gather to rearrange the output dim dimension.
717   for (int64_t i = 0; i < output_shape.dimensions_size(); ++i) {
718     // Offset dim is every dimension including newly added size 1 dim, except
719     // for input_dim, which acts as a batch_dim.
720     if (i != output_dim) {
721       gather_dim_numbers.add_offset_dims(i);
722     }
723   }
724   // The dimension to rewrite is the index dim.
725   gather_dim_numbers.add_start_index_map(output_dim);
726   gather_dim_numbers.set_index_vector_dim(1);
727   gather_dim_numbers.add_collapsed_slice_dims(output_dim);
728 
729   HloInstruction* static_dim_size = reshape->AddInstruction(
730       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
731           reshape->shape().dimensions(output_dim))));
732 
733   // Temporarily removes dynamic dimension of the reshape before we send it to
734   // the sort -- we want padded area to also participate in the gather.
735   HloInstruction* reshape_static =
736       reshape->AddInstruction(HloInstruction::CreateSetDimensionSize(
737           reshape->shape(), reshape, static_dim_size, output_dim));
738   std::vector<int64_t> gather_slice_sizes(output_shape.dimensions().begin(),
739                                           output_shape.dimensions().end());
740   gather_slice_sizes[output_dim] = 1;
741   HloInstruction* gather = reshape->AddInstruction(HloInstruction::CreateGather(
742       output_shape, reshape_static, gather_indices, gather_dim_numbers,
743       gather_slice_sizes, true));
744 
745   // Forward dynamic size to the newly created gather.
746   HloInstruction* output_dynamic_size =
747       dynamic_dimension_inference->GetDynamicSize(reshape, {}, output_dim);
748   TF_RET_CHECK(output_dynamic_size != nullptr);
749   gather = reshape->AddInstruction(HloInstruction::CreateSetDimensionSize(
750       gather->shape(), gather, output_dynamic_size, output_dim));
751   auto users = reshape->users();
752   for (auto* user : users) {
753     // Avoid cycles by not replacing the static reshape and get_dimension_size.
754     if (user != reshape_static && user != output_dynamic_size) {
755       TF_RETURN_IF_ERROR(reshape->ReplaceUseWith(user, gather));
756     }
757   }
758 
759   if (reshape == reshape->parent()->root_instruction()) {
760     reshape->parent()->set_root_instruction(gather);
761   }
762 
763   TF_RETURN_IF_ERROR(
764       dynamic_dimension_inference->ForwardDynamicSize(reshape, gather, {}));
765 
766   return true;
767 }
768 
// Rewrites a dynamic reshape for a single dimension group: either one input
// dimension split into several output dimensions, or several input dimensions
// combined into one output dimension.  Returns true if a rewrite was
// performed, false if none was needed.
StatusOr<bool> RewriteDynamicReshapeSingleGroup(
    HloInstruction* reshape, absl::Span<const int64_t> input_dims,
    absl::Span<const int64_t> output_dims,
    absl::Span<HloInstruction*> input_dynamic_dims,
    absl::Span<HloInstruction*> output_dynamic_dims,
    DynamicDimensionInference* dynamic_dimension_inference) {
  VLOG(2) << "Rewriting dynamic reshape " << reshape->ToString()
          << " input dims: " << VectorString(input_dims)
          << " output dims: " << VectorString(output_dims);

  const Shape operand_shape = reshape->operand(0)->shape();
  const Shape output_shape = reshape->shape();

  if (input_dims.size() == 1) {
    int64_t input_dim = input_dims[0];
    // Size 1 dimension doesn't need a rewrite.
    if (operand_shape.dimensions()[input_dim] == 1) {
      return false;
    }
    // One input dimension is split into multiple output dimensions.
    return RewriteDynamicReshapeSplitInput(reshape, input_dim, output_dims,
                                           output_dynamic_dims,
                                           dynamic_dimension_inference);
  }

  if (output_dims.size() == 1) {
    int64_t output_dim = output_dims[0];
    // Size 1 dimension doesn't need a rewrite.
    if (output_shape.dimensions()[output_dim] == 1) {
      return false;
    }
    // Multiple input dimensions are combined into one output dimension.
    return RewriteDynamicReshapeCombineInput(reshape, input_dims, output_dim,
                                             input_dynamic_dims,
                                             dynamic_dimension_inference);
  }

  // Shouldn't get here -- the caller only invokes this helper for
  // one-to-many or many-to-one dimension groups.
  TF_RET_CHECK(false);
  return false;
}
809 
RewriteReverse(HloInstruction * reverse,DynamicDimensionInference * dynamic_dimension_inference)810 StatusOr<bool> RewriteReverse(
811     HloInstruction* reverse,
812     DynamicDimensionInference* dynamic_dimension_inference) {
813   // When we have [A, B, C, D, E] and reverse them, we get [E, D, C, B, A].
814   // However, if the dynamic size is 2, we expect B, A to be in front:
815   // [B, A, P, P, P].
816   //
817   // We do this by running a pad and dynamic slice on the result:
818   // [A, B, C, D, E]
819   //      |
820   //    reverse
821   //      |
822   // [E, D, C, B, A]
823   //      |
824   //     pad # Use pad to double the size of the dimension to avoid OOB.
825   //      |
826   // [E, D, C, B, A, P, P, P, P, P]
827   //      |
828   //  dynamic slice
829   //      |
830   // [B, A, P, P, P]
831   auto reverse_dims = reverse->dimensions();
832   const Shape& reverse_shape = reverse->shape();
833   std::set<int64_t> dynamic_reverse_dims;
834   for (int64_t reverse_dim : reverse_dims) {
835     HloInstruction* dynamic_size =
836         dynamic_dimension_inference->GetDynamicSize(reverse, {}, reverse_dim);
837     if (dynamic_size == nullptr) {
838       // Reverse dimension is not dynamic -- no rewrite needed.
839       continue;
840     }
841     dynamic_reverse_dims.insert(reverse_dim);
842   }
843 
844   if (dynamic_reverse_dims.empty()) {
845     // We only need to rewrite dynamic dimensions that are also reverse
846     // dimensions.
847     return false;
848   }
849 
850   PaddingConfig padding;
851   // Doubles dynamic dimension size using a pad.
852   Shape pad_shape = reverse_shape;
853   for (int i = 0; i < reverse_shape.rank(); ++i) {
854     auto dimension = padding.add_dimensions();
855     if (dynamic_reverse_dims.count(i) > 0) {
856       dimension->set_edge_padding_low(0);
857       dimension->set_edge_padding_high(reverse_shape.dimensions(i));
858       dimension->set_interior_padding(0);
859       pad_shape.set_dimensions(i, 2 * pad_shape.dimensions(i));
860     }
861   }
862   HloInstruction* cloned_reverse = reverse->AddInstruction(reverse->Clone());
863   HloInstruction* zero = reverse->AddInstruction(HloInstruction::CreateConstant(
864       LiteralUtil::Zero(pad_shape.element_type())));
865   HloInstruction* pad = reverse->AddInstruction(
866       HloInstruction::CreatePad(pad_shape, cloned_reverse, zero, padding));
867   std::vector<HloInstruction*> start_indices;
868   start_indices.reserve(reverse_shape.rank());
869   for (int i = 0; i < reverse_shape.rank(); ++i) {
870     if (dynamic_reverse_dims.count(i) > 0) {
871       // Start at bound_size - dynamic_size.
872       HloInstruction* bound_size =
873           reverse->AddInstruction(HloInstruction::CreateConstant(
874               LiteralUtil::CreateR0<int32_t>(reverse_shape.dimensions(i))));
875       HloInstruction* dynamic_size =
876           dynamic_dimension_inference->GetDynamicSize(reverse, {}, i);
877       HloInstruction* start_offset =
878           reverse->AddInstruction(HloInstruction::CreateBinary(
879               ShapeUtil::MakeScalarShape(S32), HloOpcode::kSubtract, bound_size,
880               dynamic_size));
881       start_indices.push_back(start_offset);
882     } else {
883       HloInstruction* zero = reverse->AddInstruction(
884           HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
885       start_indices.push_back(zero);
886     }
887   }
888   HloInstruction* dynamic_reverse =
889       reverse->AddInstruction(HloInstruction::CreateDynamicSlice(
890           reverse_shape, pad, start_indices, reverse_shape.dimensions()));
891   TF_RETURN_IF_ERROR(
892       reverse->parent()->ReplaceInstruction(reverse, dynamic_reverse));
893   TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
894       reverse, dynamic_reverse, {}));
895   return true;
896 }
897 
// Rewrites `input` so a windowed op (convolution / reduce-window /
// select-and-scatter) can consume it without dynamic padding.  For each
// window dimension with a non-null entry in `padding_before`, static padding
// (filled with `padding_value`) is applied, then a dynamic-slice whose start
// offset depends on the runtime `padding_before` amount repositions the data.
// The corresponding entries in `input_window` are cleared (padding removed,
// base dilation reset to 1) because that padding is now baked into the data.
// `window_dim_to_shape_dim` maps a window-dimension index to the matching
// dimension index of `input`'s shape.  Returns the rewritten input.
HloInstruction* RewriteInputWithDynamicPadding(
    HloInstruction* conv, HloInstruction* input, HloInstruction* padding_value,
    absl::Span<HloInstruction*> padding_before, Window* input_window,
    std::function<int64_t(int64_t)> window_dim_to_shape_dim) {
  HloInstruction* zero_s32 = conv->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  // Padded shape represents the bounded shape after dynamic padding.
  Shape padded_shape = input->shape();
  PaddingConfig padding_configs;
  for (int64_t i = 0; i < input->shape().rank(); ++i) {
    PaddingConfig::PaddingConfigDimension padding_dim;
    *padding_configs.add_dimensions() = padding_dim;
  }
  std::vector<HloInstruction*> start_indices(input->shape().rank(), zero_s32);
  for (int64_t dim_index = 0; dim_index < input_window->dimensions_size();
       ++dim_index) {
    if (padding_before[dim_index] == nullptr) {
      // This window dimension has no dynamic padding to materialize.
      continue;
    }
    int64_t shape_dim = window_dim_to_shape_dim(dim_index);

    WindowDimension* window_dim = input_window->mutable_dimensions(dim_index);
    auto* padding_dim = padding_configs.mutable_dimensions(shape_dim);
    const int64_t dilated_window_size = window_util::DilatedBound(
        window_dim->size(), window_dim->window_dilation());
    // Use dilated window size as low padding and static padding_high +
    // padding_low as high padding to make sure the following dynamic slice is
    // valid and doesn't go out of bound.
    //
    // See go/xla-dynamic-spatial-dim for more details.
    padding_dim->set_edge_padding_low(dilated_window_size);
    padding_dim->set_edge_padding_high(window_dim->padding_high() +
                                       window_dim->padding_low());
    padding_dim->set_interior_padding(window_dim->base_dilation() - 1);
    // Start the slice at (static edge_padding_low) - (dynamic padding_before)
    // so the dynamically padded region lands where the window expects it.
    HloInstruction* slicing_start =
        conv->AddInstruction(HloInstruction::CreateBinary(
            ShapeUtil::MakeScalarShape(S32), HloOpcode::kSubtract,
            conv->AddInstruction(
                HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
                    padding_dim->edge_padding_low()))),
            padding_before[dim_index]));
    start_indices[shape_dim] = slicing_start;

    // Bounded dimension size after low/high padding plus base dilation.
    padded_shape.mutable_dimensions()[shape_dim] =
        window_dim->padding_low() +
        window_util::DilatedBound(padded_shape.dimensions(shape_dim),
                                  window_dim->base_dilation()) +
        window_dim->padding_high();
    // Padding and base dilation are now materialized in the data itself;
    // clear them from the window so the consumer doesn't apply them twice.
    window_dim->clear_padding_high();
    window_dim->clear_padding_low();
    window_dim->set_base_dilation(1);
    input->mutable_shape()->set_dynamic_dimension(shape_dim, false);
  }
  // Reconstruct dynamic padding using pad and dynamic slice.

  HloInstruction* pad =
      MakePadHlo(input, padding_value, padding_configs).ValueOrDie();
  input = conv->AddInstruction(HloInstruction::CreateDynamicSlice(
      padded_shape, pad, start_indices, padded_shape.dimensions()));
  return input;
}
959 
RewriteDynamicConvolutionInputGrad(HloInstruction * custom_call_conv,DynamicDimensionInference * dynamic_dimension_inference)960 StatusOr<bool> RewriteDynamicConvolutionInputGrad(
961     HloInstruction* custom_call_conv,
962     DynamicDimensionInference* dynamic_dimension_inference) {
963   HloInstruction* grad = custom_call_conv->mutable_operand(1);
964   HloInstruction* kernel = custom_call_conv->mutable_operand(2);
965   TF_RET_CHECK(kernel->shape().is_static());
966   auto dnums = custom_call_conv->convolution_dimension_numbers();
967   Window window = custom_call_conv->window();
968   HloInstruction* zero =
969       custom_call_conv->AddInstruction(HloInstruction::CreateConstant(
970           LiteralUtil::Zero(custom_call_conv->shape().element_type())));
971   std::vector<HloInstruction*> padding_before(
972       dnums.input_spatial_dimensions_size(), nullptr);
973   for (int64_t spatial_dim_index = 0;
974        spatial_dim_index < dnums.input_spatial_dimensions_size();
975        ++spatial_dim_index) {
976     int64_t input_spatial_dim =
977         dnums.input_spatial_dimensions(spatial_dim_index);
978     HloInstruction* operand_dynamic_size =
979         dynamic_dimension_inference->GetDynamicSize(
980             custom_call_conv->mutable_operand(1), {}, input_spatial_dim);
981     if (operand_dynamic_size == nullptr) {
982       continue;
983     }
984     grad = PadWithScalar(grad, input_spatial_dim, operand_dynamic_size, zero);
985     HloInstruction* slice =
986         custom_call_conv->AddInstruction(HloInstruction::CreateSlice(
987             ShapeUtil::MakeShape(S32, {1}),
988             custom_call_conv->mutable_operand(0), {input_spatial_dim},
989             {input_spatial_dim + 1}, {1}));
990     HloInstruction* dynamic_input_size = custom_call_conv->AddInstruction(
991         HloInstruction::CreateReshape(ShapeUtil::MakeScalarShape(S32), slice));
992     const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
993     // Window stride of forward prop is same as base dilation of backward prop.
994     DynamicWindowDims dynamic_window_dims = GetWindowedInputGradSize(
995         dynamic_input_size, /*window_size=*/window_dim.size(),
996         /*window_dilation=*/window_dim.window_dilation(),
997         /*window_stride=*/window_dim.base_dilation(),
998         custom_call_conv->padding_type());
999     padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
1000   }
1001 
1002   if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
1003     grad = RewriteInputWithDynamicPadding(
1004         custom_call_conv, grad, zero, absl::MakeSpan(padding_before), &window,
1005         [&](int64_t dim) { return dnums.input_spatial_dimensions(dim); });
1006   }
1007 
1008   PrecisionConfig precision_config;
1009   if (custom_call_conv->precision_config().operand_precision_size() == 3) {
1010     // We are not interested in the precision config of the first operand, which
1011     // is the input_sizes.
1012     *precision_config.mutable_operand_precision() = {
1013         custom_call_conv->precision_config().operand_precision().begin() + 1,
1014         custom_call_conv->precision_config().operand_precision().end()};
1015   }
1016   HloInstruction* static_conv =
1017       custom_call_conv->AddInstruction(HloInstruction::CreateConvolve(
1018           custom_call_conv->shape(), grad, kernel,
1019           custom_call_conv->feature_group_count(),
1020           custom_call_conv->batch_group_count(), window,
1021           custom_call_conv->convolution_dimension_numbers(),
1022           custom_call_conv->precision_config()));
1023   TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
1024   TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
1025       custom_call_conv, static_conv, {}));
1026   return true;
1027 }
1028 
RewriteDynamicConvolutionForward(HloInstruction * custom_call_conv,DynamicDimensionInference * dynamic_dimension_inference)1029 StatusOr<bool> RewriteDynamicConvolutionForward(
1030     HloInstruction* custom_call_conv,
1031     DynamicDimensionInference* dynamic_dimension_inference) {
1032   HloInstruction* input = custom_call_conv->mutable_operand(0);
1033   HloInstruction* kernel = custom_call_conv->mutable_operand(1);
1034   TF_RET_CHECK(kernel->shape().is_static());
1035   TF_RET_CHECK(input->shape().is_dynamic());
1036   Window window = custom_call_conv->window();
1037   auto dnums = custom_call_conv->convolution_dimension_numbers();
1038   HloInstruction* zero =
1039       custom_call_conv->AddInstruction(HloInstruction::CreateConstant(
1040           LiteralUtil::Zero(custom_call_conv->shape().element_type())));
1041   std::vector<HloInstruction*> padding_before(
1042       dnums.input_spatial_dimensions_size(), nullptr);
1043   for (int64_t spatial_dim_index = 0;
1044        spatial_dim_index < dnums.input_spatial_dimensions_size();
1045        ++spatial_dim_index) {
1046     int64_t input_spatial_dim =
1047         dnums.input_spatial_dimensions(spatial_dim_index);
1048     HloInstruction* operand_dynamic_size =
1049         dynamic_dimension_inference->GetDynamicSize(
1050             custom_call_conv->mutable_operand(0), {}, input_spatial_dim);
1051     if (operand_dynamic_size == nullptr) {
1052       continue;
1053     }
1054 
1055     input = PadWithScalar(input, input_spatial_dim, operand_dynamic_size, zero);
1056     const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
1057     DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
1058         operand_dynamic_size, window_dim.size(), window_dim.window_dilation(),
1059         window_dim.stride(), custom_call_conv->padding_type());
1060     padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
1061   }
1062   // Input feature dim can be dynamic too, reset it to zero.
1063   const int64_t input_feature_dim = dnums.input_feature_dimension();
1064   if (HloInstruction* input_feature_dynamic_size =
1065           dynamic_dimension_inference->GetDynamicSize(
1066               custom_call_conv->mutable_operand(0), {}, input_feature_dim)) {
1067     input = PadWithScalar(input, input_feature_dim, input_feature_dynamic_size,
1068                           zero);
1069   }
1070 
1071   if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
1072     input = RewriteInputWithDynamicPadding(
1073         custom_call_conv, input, zero, absl::MakeSpan(padding_before), &window,
1074         [&](int64_t dim) { return dnums.input_spatial_dimensions(dim); });
1075   }
1076 
1077   HloInstruction* static_conv =
1078       custom_call_conv->AddInstruction(HloInstruction::CreateConvolve(
1079           custom_call_conv->shape(), input, kernel,
1080           custom_call_conv->feature_group_count(),
1081           custom_call_conv->batch_group_count(), window,
1082           custom_call_conv->convolution_dimension_numbers(),
1083           custom_call_conv->precision_config()));
1084   TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
1085   TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
1086       custom_call_conv, static_conv, {}));
1087   return true;
1088 }
1089 
// Rewrites a dynamic convolution kernel-gradient custom call into a static
// convolution: dynamic regions of the activations and gradients are zeroed
// out so padded elements contribute nothing to the gradient, and SAME
// padding is materialized explicitly.
StatusOr<bool> RewriteDynamicConvolutionKernelGrad(
    HloInstruction* custom_call_conv,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* activations = custom_call_conv->mutable_operand(0);
  HloInstruction* gradients = custom_call_conv->mutable_operand(1);
  TF_RET_CHECK(activations->shape().is_dynamic());
  TF_RET_CHECK(gradients->shape().is_dynamic());
  Window window = custom_call_conv->window();
  auto dnums = custom_call_conv->convolution_dimension_numbers();
  HloInstruction* zero =
      custom_call_conv->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(custom_call_conv->shape().element_type())));
  std::vector<HloInstruction*> padding_before(
      dnums.input_spatial_dimensions_size(), nullptr);
  for (int64_t spatial_dim_index = 0;
       spatial_dim_index < dnums.input_spatial_dimensions_size();
       ++spatial_dim_index) {
    int64_t input_spatial_dim =
        dnums.input_spatial_dimensions(spatial_dim_index);
    int64_t kernel_spatial_dim =
        dnums.kernel_spatial_dimensions(spatial_dim_index);
    // Zero out the activations past their dynamic bound.
    HloInstruction* activations_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(
            custom_call_conv->mutable_operand(0), {}, input_spatial_dim);
    if (activations_dynamic_size != nullptr) {
      activations = PadWithScalar(activations, input_spatial_dim,
                                  activations_dynamic_size, zero);
    }

    // Zero out the gradients past their dynamic bound; in this backward
    // convolution the gradients' spatial dims take the kernel-dim role.
    HloInstruction* gradients_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(
            custom_call_conv->mutable_operand(1), {}, kernel_spatial_dim);
    if (gradients_dynamic_size != nullptr) {
      gradients = PadWithScalar(gradients, kernel_spatial_dim,
                                gradients_dynamic_size, zero);
    }
    if (activations_dynamic_size == nullptr ||
        gradients_dynamic_size == nullptr) {
      // A spatial dimension is expected to be dynamic in both operands or in
      // neither; anything else is an inconsistency.
      TF_RET_CHECK(activations_dynamic_size == nullptr &&
                   gradients_dynamic_size == nullptr);
      continue;
    }
    int64_t output_spatial_dim =
        dnums.output_spatial_dimensions(spatial_dim_index);
    const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
    // For the kernel gradient, the forward output acts as the window; note
    // stride and window dilation swap roles (see argument annotations).
    DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
        activations_dynamic_size, /*window_size=*/
        custom_call_conv->shape().dimensions(output_spatial_dim),
        /*window_dilation=*/window_dim.stride(),
        /*window_stride=*/window_dim.window_dilation(),
        custom_call_conv->padding_type());
    padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
  }

  // We only need to pad input feature on lhs to 0 -- it's mathematically
  // equivalent to padding both lhs and rhs to 0.
  const int64_t input_feature_dim = dnums.input_feature_dimension();
  if (HloInstruction* input_feature_dynamic_size =
          dynamic_dimension_inference->GetDynamicSize(
              custom_call_conv->mutable_operand(0), {}, input_feature_dim)) {
    activations = PadWithScalar(activations, input_feature_dim,
                                input_feature_dynamic_size, zero);
  }

  if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
    activations = RewriteInputWithDynamicPadding(
        custom_call_conv, activations, zero, absl::MakeSpan(padding_before),
        &window,
        [&](int64_t dim) { return dnums.input_spatial_dimensions(dim); });
  }

  HloInstruction* static_conv =
      custom_call_conv->AddInstruction(HloInstruction::CreateConvolve(
          custom_call_conv->shape(), activations, gradients,
          custom_call_conv->feature_group_count(),
          custom_call_conv->batch_group_count(), window,
          custom_call_conv->convolution_dimension_numbers(),
          custom_call_conv->precision_config()));
  TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
  TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
      custom_call_conv, static_conv, {}));
  return true;
}
1173 
// Rewrites a dynamic reduce-window custom call with SAME padding into a
// static reduce-window: the area beyond each dynamic bound is filled with
// the init value, and the dynamic SAME padding is materialized via
// RewriteInputWithDynamicPadding.
StatusOr<bool> RewriteDynamicReduceWindowSamePadding(
    HloInstruction* hlo,
    DynamicDimensionInference* dynamic_dimension_inference) {
  if (hlo->shape().IsTuple()) {
    // TODO (b/73062247) variadic reduce window is not yet supported here.
    return Unimplemented("DynamicReduceWindowSamePadding not yet supported.");
  }
  HloInstruction* input = hlo->mutable_operand(0);
  HloInstruction* init = hlo->mutable_operand(1);
  int64_t rank = hlo->shape().rank();
  Window window = hlo->window();
  // One entry per dimension; non-null means dynamic padding must be applied.
  std::vector<HloInstruction*> padding_before(hlo->shape().rank(), nullptr);
  for (int64_t dim_index = 0; dim_index < rank; ++dim_index) {
    HloInstruction* operand_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(hlo->mutable_operand(0), {},
                                                    dim_index);
    if (operand_dynamic_size == nullptr) {
      // Static dimension -- nothing to rewrite here.
      continue;
    }
    const WindowDimension& window_dim = window.dimensions(dim_index);
    if (window_util::IsTrivialWindowDimension(window_dim)) {
      // A trivial window dimension cannot mix padded data into the result.
      continue;
    }
    // Fill the area beyond the dynamic bound with the init value so that
    // reducing over it leaves the result unchanged.
    input = PadWithScalar(input, dim_index, operand_dynamic_size, init);

    DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
        operand_dynamic_size, window_dim.size(), window_dim.window_dilation(),
        window_dim.stride(), PaddingType::PADDING_SAME);
    padding_before[dim_index] = dynamic_window_dims.padding_before;
  }

  input = RewriteInputWithDynamicPadding(
      hlo, input, init, absl::MakeSpan(padding_before), &window,
      [](int64_t dim) { return dim; });

  HloInstruction* rewritten =
      hlo->AddInstruction(HloInstruction::CreateReduceWindow(
          hlo->shape(), input, init, window, hlo->called_computations()[0]));
  TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rewritten));
  TF_RETURN_IF_ERROR(
      dynamic_dimension_inference->ForwardDynamicSize(hlo, rewritten, {}));
  return true;
}
1217 
// Rewrites a dynamic select-and-scatter custom call with SAME padding into a
// static select-and-scatter.  The input's dynamic region is filled with a
// value chosen by ChooseIdentityValue for the select computation, the
// source's dynamic region is filled with the init value, and the dynamic
// padding is materialized via RewriteInputWithDynamicPadding.
StatusOr<bool> RewriteDynamicSelectAndScatterSamePadding(
    HloInstruction* hlo,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* input = hlo->mutable_operand(0);
  HloInstruction* source = hlo->mutable_operand(1);
  HloInstruction* init = hlo->mutable_operand(2);
  TF_ASSIGN_OR_RETURN(HloInstruction * input_padding_value,
                      ChooseIdentityValue(hlo, /*operand_number=*/0));
  int64_t rank = hlo->shape().rank();
  Window window = hlo->window();
  // One entry per dimension; non-null means dynamic padding must be applied.
  std::vector<HloInstruction*> padding_before(hlo->shape().rank(), nullptr);
  for (int64_t dim_index = 0; dim_index < rank; ++dim_index) {
    const WindowDimension& window_dim = window.dimensions(dim_index);
    if (window_util::IsTrivialWindowDimension(window_dim)) {
      // A trivial window dimension cannot mix padded data into the result.
      continue;
    }
    HloInstruction* operand_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(hlo->mutable_operand(0), {},
                                                    dim_index);
    if (operand_dynamic_size == nullptr) {
      continue;
    }

    // Pad the input's dynamic region with the select identity value.
    input = PadWithScalar(input, dim_index, operand_dynamic_size,
                          input_padding_value);

    HloInstruction* source_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(hlo->mutable_operand(1), {},
                                                    dim_index);
    if (source_dynamic_size == nullptr) {
      continue;
    }
    // Pad the source's dynamic region with the scatter init value.
    source = PadWithScalar(source, dim_index, source_dynamic_size, init);

    DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
        operand_dynamic_size, window_dim.size(), window_dim.window_dilation(),
        window_dim.stride(), PaddingType::PADDING_SAME);
    padding_before[dim_index] = dynamic_window_dims.padding_before;
  }

  input = RewriteInputWithDynamicPadding(
      hlo, input, input_padding_value, absl::MakeSpan(padding_before), &window,
      [](int64_t dim) { return dim; });

  // RewriteInputWithDynamicPadding adds padding to the input. However those
  // inputs should not be materialized in select and scatter's output and we
  // need to slice them out using dynamic slice. To prevent the dynamic slice
  // going OOB, we first add some high-pad to the output to leave enough space.
  HloInstruction* rewritten =
      hlo->AddInstruction(HloInstruction::CreateSelectAndScatter(
          input->shape(), input, hlo->called_computations()[0], window, source,
          init, hlo->called_computations()[1]));
  std::vector<HloInstruction*> start_indices(
      input->shape().rank(), hlo->AddInstruction(HloInstruction::CreateConstant(
                                 LiteralUtil::Zero(S32))));
  PaddingConfig padding_configs;
  for (int64_t dim_index = 0; dim_index < rank; ++dim_index) {
    PaddingConfig::PaddingConfigDimension padding_dim;
    if (padding_before[dim_index] != nullptr) {
      const WindowDimension& window_dim = window.dimensions(dim_index);
      const int64_t dilated_window_size = window_util::DilatedBound(
          window_dim.size(), window_dim.window_dilation());
      padding_dim.set_edge_padding_high(dilated_window_size);
      start_indices[dim_index] = padding_before[dim_index];
    }
    *padding_configs.add_dimensions() = padding_dim;
  }
  HloInstruction* padded =
      MakePadHlo(rewritten, init, padding_configs).ValueOrDie();
  rewritten = hlo->AddInstruction(HloInstruction::CreateDynamicSlice(
      hlo->shape(), padded, start_indices, hlo->shape().dimensions()));
  TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rewritten));
  TF_RETURN_IF_ERROR(
      dynamic_dimension_inference->ForwardDynamicSize(hlo, rewritten, {}));
  return true;
}
1294 
RewriteDynamicConcat(HloInstruction * concat,DynamicDimensionInference * dynamic_dimension_inference)1295 StatusOr<bool> RewriteDynamicConcat(
1296     HloInstruction* concat,
1297     DynamicDimensionInference* dynamic_dimension_inference) {
1298   const int64_t concat_dim = concat->concatenate_dimension();
1299   if (dynamic_dimension_inference->GetDynamicSize(concat, {}, concat_dim) ==
1300       nullptr) {
1301     // Concat dimension is not dynamic -- no rewrite needed.
1302     return false;
1303   }
1304   std::vector<HloInstruction*> offsets;
1305   for (int64_t i = 0; i < concat->shape().dimensions_size(); ++i) {
1306     offsets.push_back(concat->AddInstruction(
1307         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0))));
1308   }
1309   HloInstruction* rewritten_concat = concat;
1310   // Keep track of previous users before rewrite so that we can update their
1311   // operands later.
1312   auto prev_users = concat->users();
1313   for (int64_t i = 0; i < concat->operand_count(); ++i) {
1314     // Rewrite the concat by dynamic update slicing operand into the concat dim.
1315     HloInstruction* operand = concat->mutable_operand(i);
1316     rewritten_concat =
1317         concat->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1318             rewritten_concat->shape(), rewritten_concat, operand, offsets));
1319     // Update the offset of concat dimension by adding the size of the concat
1320     // dimension of the operand to it.
1321     HloInstruction* dynamic_size =
1322         dynamic_dimension_inference->GetDynamicSize(operand, {}, concat_dim);
1323     if (dynamic_size == nullptr) {
1324       HloInstruction* static_size = concat->AddInstruction(
1325           HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
1326               operand->shape().dimensions(concat_dim))));
1327       offsets[concat_dim] = concat->AddInstruction(HloInstruction::CreateBinary(
1328           ShapeUtil::MakeScalarShape(S32), HloOpcode::kAdd, offsets[concat_dim],
1329           static_size));
1330     } else {
1331       offsets[concat_dim] = concat->AddInstruction(HloInstruction::CreateBinary(
1332           ShapeUtil::MakeScalarShape(S32), HloOpcode::kAdd, offsets[concat_dim],
1333           dynamic_size));
1334     }
1335   }
1336   TF_RETURN_IF_ERROR(concat->ReplaceUsesWith(prev_users, rewritten_concat));
1337   TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
1338       concat, rewritten_concat, {}));
1339   return true;
1340 }
1341 
RewriteDynamicSort(HloInstruction * hlo,DynamicDimensionInference * dynamic_dimension_inference)1342 StatusOr<bool> RewriteDynamicSort(
1343     HloInstruction* hlo,
1344     DynamicDimensionInference* dynamic_dimension_inference) {
1345   HloInstruction* dynamic_size = nullptr;
1346   HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
1347   int64_t sort_dim = sort->sort_dimension();
1348   // Find the dynamic dimension in the operand.
1349   for (auto* operand : sort->operands()) {
1350     if (dynamic_size == nullptr) {
1351       dynamic_size =
1352           dynamic_dimension_inference->GetDynamicSize(operand, {}, sort_dim);
1353     }
1354   }
1355 
1356   if (dynamic_size == nullptr) {
1357     // Not a dynamic sort, ignore.
1358     return false;
1359   }
1360 
1361   Shape operand_shape =
1362       ShapeUtil::ChangeElementType(sort->operand(0)->shape(), S32);
1363   HloInstruction* iota =
1364       hlo->AddInstruction(HloInstruction::CreateIota(operand_shape, sort_dim));
1365   HloInstruction* dynamic_size_broadcasted = hlo->AddInstruction(
1366       HloInstruction::CreateBroadcast(operand_shape, dynamic_size, {}));
1367   HloInstruction* lt = hlo->AddInstruction(HloInstruction::CreateCompare(
1368       ShapeUtil::ChangeElementType(operand_shape, PRED), iota,
1369       dynamic_size_broadcasted, ComparisonDirection::kLt));
1370   sort->AppendOperand(lt);
1371 
1372   const int64_t param_number_before_rewritten =
1373       sort->called_computations()[0]->num_parameters();
1374   auto new_param_0 = HloInstruction::CreateParameter(
1375       param_number_before_rewritten, ShapeUtil::MakeScalarShape(PRED),
1376       "inbound_lhs");
1377   auto new_param_1 = HloInstruction::CreateParameter(
1378       param_number_before_rewritten + 1, ShapeUtil::MakeScalarShape(PRED),
1379       "inbound_rhs");
1380   std::vector<const HloInstruction*> extra_parameters{new_param_0.get(),
1381                                                       new_param_1.get()};
1382   HloComputation* sort_comp = sort->parent()->parent()->AddEmbeddedComputation(
1383       sort->called_computations()[0]->CloneWithReplacements(
1384           /*replacements=*/nullptr, extra_parameters));
1385   auto inbound_lhs =
1386       sort_comp->parameter_instruction(param_number_before_rewritten);
1387   auto inbound_rhs =
1388       sort_comp->parameter_instruction(param_number_before_rewritten + 1);
1389   sort->ReplaceCalledComputations(
1390       [&](HloComputation* comp) { return sort_comp; });
1391 
1392   // inbound_lhs & (sort_comp | !in_bound_rhs)
1393   // Select the lhs if it is in bounds and the rhs is out of bounds or the
1394   // sort_comp returns true.
1395   auto out_of_bound_rhs = sort_comp->AddInstruction(HloInstruction::CreateUnary(
1396       ShapeUtil::MakeScalarShape(PRED), HloOpcode::kNot, inbound_rhs));
1397   auto sort_comp_or_out_of_bound_rhs =
1398       sort_comp->AddInstruction(HloInstruction::CreateBinary(
1399           ShapeUtil::MakeScalarShape(PRED), HloOpcode::kOr,
1400           sort_comp->root_instruction(), out_of_bound_rhs));
1401 
1402   auto new_root = sort_comp->AddInstruction(HloInstruction::CreateBinary(
1403       ShapeUtil::MakeScalarShape(PRED), HloOpcode::kAnd, inbound_lhs,
1404       sort_comp_or_out_of_bound_rhs));
1405   sort_comp->set_root_instruction(new_root);
1406   Shape compare_shape =
1407       ShapeUtil::ChangeElementType(sort->operand(0)->shape(), PRED);
1408   if (sort->shape().IsTuple()) {
1409     // For sort that is already tuple, simply add another result to the tuple.
1410     *sort->mutable_shape()->add_tuple_shapes() =
1411         ShapeUtil::ChangeElementType(operand_shape, PRED);
1412   } else {
1413     auto sort_users = sort->users();
1414     auto sort_clone = hlo->AddInstruction(sort->Clone());
1415     *sort_clone->mutable_shape() = ShapeUtil::MakeTupleShape(
1416         {sort->shape(), ShapeUtil::ChangeElementType(operand_shape, PRED)});
1417     auto rewritten_sort = hlo->AddInstruction(
1418         HloInstruction::CreateGetTupleElement(sort->shape(), sort_clone, 0));
1419     for (HloInstruction* user : sort_users) {
1420       TF_RETURN_IF_ERROR(sort->ReplaceUseWith(user, rewritten_sort));
1421     }
1422     TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
1423         sort, rewritten_sort, {}));
1424     if (hlo->parent()->root_instruction() == sort) {
1425       hlo->parent()->set_root_instruction(rewritten_sort);
1426     }
1427   }
1428 
1429   return true;
1430 }
1431 
RewriteDynamicBinaryOp(HloInstruction * binary,DynamicDimensionInference * dynamic_dimension_inference)1432 StatusOr<bool> RewriteDynamicBinaryOp(
1433     HloInstruction* binary,
1434     DynamicDimensionInference* dynamic_dimension_inference) {
1435   HloInstruction* operand_0 = binary->mutable_operand(0);
1436   HloInstruction* operand_1 = binary->mutable_operand(1);
1437 
1438   TF_RET_CHECK(operand_0->shape().rank() == operand_1->shape().rank());
1439   auto dims_0 = dynamic_dimension_inference->GetDynamicSizes(operand_0, {});
1440   auto dims_1 = dynamic_dimension_inference->GetDynamicSizes(operand_1, {});
1441   bool changed = false;
1442   for (int64_t i = 0; i < dims_0.size(); ++i) {
1443     HloInstruction* dim_0 = dims_0[i];
1444     HloInstruction* dim_1 = dims_1[i];
1445 
1446     if (dims_0[i] != dims_1[i] && dims_0[i] != nullptr &&
1447         dims_1[i] != nullptr) {
1448       changed = true;
1449       // It is possible that a dynamic dimension of one operand is size 1 while
1450       // the other is greater than one. According to implicit broadcast
1451       // semantics, we need to insert broadcast in this case to make the dynamic
1452       // shape match.
1453 
1454       // An implicit broadcast is inserted by slicing the small shape into a
1455       // size 1 slice, reshape out the size 1 dimension then broadcast to the
1456       // full shape:
1457       //
1458       // Input [2, <=5, 3]
1459       //   |
1460       // Slice [2, 1, 3]
1461       //   |
1462       // Reshape [2, 3]
1463       //   |
1464       // Broadcast [2, 5, 3]
1465       auto rewrite_operand = [&](HloInstruction* pred,
1466                                  HloInstruction* operand) -> HloInstruction* {
1467         Shape static_shape = operand->shape();
1468         static_shape.clear_dynamic_dimensions();
1469         pred = binary->AddInstruction(HloInstruction::CreateBroadcast(
1470             ShapeUtil::ChangeElementType(static_shape, PRED), pred, {}));
1471         Shape slice_shape = static_shape;
1472         slice_shape.set_dimensions(i, 1);
1473         std::vector<int64_t> start_indices(slice_shape.rank(), 0);
1474         std::vector<int64_t> strides(slice_shape.rank(), 1);
1475         HloInstruction* slice = binary->AddInstruction(
1476             HloInstruction::CreateSlice(slice_shape, operand, start_indices,
1477                                         slice_shape.dimensions(), strides));
1478         Shape reshape_shape = ShapeUtil::DeleteDimension(i, slice_shape);
1479         HloInstruction* reshape = binary->AddInstruction(
1480             HloInstruction::CreateReshape(reshape_shape, slice));
1481         std::vector<int64_t> broadcast_dims;
1482         broadcast_dims.reserve(static_shape.rank() - 1);
1483         // Broadcast to all dims execpt for i.
1484         for (int64_t j = 0; j < static_shape.rank(); ++j) {
1485           if (j != i) {
1486             broadcast_dims.push_back(j);
1487           }
1488         }
1489 
1490         HloInstruction* broadcast = binary->parent()->AddInstruction(
1491             HloInstruction::CreateBroadcast(static_shape, reshape,
1492                                             broadcast_dims),
1493             "implicit_broadcast");
1494 
1495         // Use a select instead of conditional as elementwise operations promote
1496         // more fusion.
1497         HloInstruction* select =
1498             binary->AddInstruction(HloInstruction::CreateTernary(
1499                 static_shape, HloOpcode::kSelect, pred, broadcast, operand));
1500         return select;
1501       };
1502 
1503       HloInstruction* one = binary->AddInstruction(
1504           HloInstruction::CreateConstant(LiteralUtil::One(S32)));
1505 
1506       auto operand_0_needs_broadcast = binary->parent()->AddInstruction(
1507           HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_0,
1508                                         dim_1, ComparisonDirection::kLt),
1509           "lhs_less_than_rhs");
1510       auto is_one = binary->parent()->AddInstruction(
1511           HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_0,
1512                                         one, ComparisonDirection::kEq),
1513           "lhs_is_one");
1514       operand_0_needs_broadcast = binary->parent()->AddInstruction(
1515           HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
1516                                        HloOpcode::kAnd, is_one,
1517                                        operand_0_needs_broadcast),
1518           "lhs_needs_implicit_broadcast");
1519       operand_0 = rewrite_operand(operand_0_needs_broadcast, operand_0);
1520 
1521       auto operand_1_needs_broadcast = binary->parent()->AddInstruction(
1522           HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_1,
1523                                         dim_0, ComparisonDirection::kLt),
1524           "rhs_less_than_lhs");
1525       is_one = binary->parent()->AddInstruction(
1526           HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_1,
1527                                         one, ComparisonDirection::kEq),
1528           "rhs_is_one");
1529       operand_1_needs_broadcast = binary->parent()->AddInstruction(
1530           HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
1531                                        HloOpcode::kAnd, is_one,
1532                                        operand_1_needs_broadcast),
1533           "lhs_needs_implicit_broadcast");
1534       operand_1 = rewrite_operand(operand_1_needs_broadcast, operand_1);
1535     }
1536   }
1537   if (changed) {
1538     TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(0, operand_0));
1539     TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(1, operand_1));
1540   }
1541   return changed;
1542 }
1543 
// Rewrites a dynamic-update-slice whose `update` operand has dynamic
// dimensions smaller than the base, so that the padded tail of `update` does
// not clobber the base region beyond the real (dynamic) update extent.
StatusOr<bool> RewriteDynamicUpdateSlice(
    HloInstruction* hlo,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloDynamicUpdateSliceInstruction* dus =
      Cast<HloDynamicUpdateSliceInstruction>(hlo);
  // Suppose we have a base area that we want to update:
  // +------------------------+
  // |                        |
  // |                  base  |
  // |                        |
  // +------------------------+
  //
  // A partial update with dynamic padding looks like this:
  //
  //           +------+-------+
  //           |update|padding|
  //           +------+-------+
  //
  // We don't want the padding to overwrite the base area:
  //
  // +------------------------+
  // |         +------+-------+
  // |<-begin->|update|padding| (what we want to avoid)
  // |         +------+-------+
  // +------------------------+
  //
  // Instead we want to keep the base area untouched except for the update
  // region:
  //
  // +------------------------+
  // |         +------+       |
  // |<-begin->|update|  base | (what we want)
  // |         +------+       |
  // +------------------------+
  //
  // We do this by dynamic slicing the base area out first with the same begin
  // index:
  //
  //           +--------------+
  // <-begin-> |         base |
  //           +--------------+
  //
  // Then replace the update's padding part with base:
  //
  //           +------+-------+
  //           |update|  base |
  //           +------+-------+
  //
  // Then do the DUS.

  HloInstruction* update = dus->mutable_operand(1);
  HloInstruction* base = dus->mutable_operand(0);
  // Per-dimension dynamic sizes of `update` (nullptr where static or where
  // the update already spans the full base dimension).
  std::vector<HloInstruction*> dynamic_dims_in_partial_update(
      update->shape().rank(), nullptr);
  bool needs_rewrite = false;
  for (int64_t i = 0; i < update->shape().rank(); ++i) {
    // Only dimensions where the update is strictly smaller than the base can
    // leave padding inside the written window.
    if (update->shape().dimensions(i) < base->shape().dimensions(i)) {
      HloInstruction* dynamic_dim =
          dynamic_dimension_inference->GetDynamicSize(update, {}, i);

      if (dynamic_dim != nullptr) {
        dynamic_dims_in_partial_update[i] = dynamic_dim;
        needs_rewrite = true;
      }
    }
  }

  if (!needs_rewrite) {
    return false;
  }
  // Collect the DUS start indices (operands 2..N) to reuse for the base
  // dynamic-slice below.
  std::vector<HloInstruction*> indices;
  indices.reserve(dus->operand_count() - 2);
  for (int64_t i = 2; i < dus->operand_count(); ++i) {
    indices.push_back(dus->mutable_operand(i));
  }
  // Slice the base region that the DUS would overwrite, at the same offsets.
  HloInstruction* base_slice =
      dus->AddInstruction(HloInstruction::CreateDynamicSlice(
          update->shape(), base, indices, update->shape().dimensions()));

  // For each dynamic dimension, select update values inside the dynamic
  // extent and original base values in the padded tail.
  for (int64_t i = 0; i < dynamic_dims_in_partial_update.size(); ++i) {
    HloInstruction* dynamic_dim = dynamic_dims_in_partial_update[i];
    if (dynamic_dim != nullptr) {
      Shape mask_shape_int = ShapeUtil::ChangeElementType(update->shape(), S32);
      Shape mask_shape_pred =
          ShapeUtil::ChangeElementType(update->shape(), PRED);
      // Generate mask using iota and dynamic_dim: true where index < size.
      HloInstruction* iota =
          dus->AddInstruction(HloInstruction::CreateIota(mask_shape_int, i));
      HloInstruction* broadcast_dim = dus->AddInstruction(
          HloInstruction::CreateBroadcast(mask_shape_int, dynamic_dim, {}));
      HloInstruction* pred = dus->AddInstruction(HloInstruction::CreateCompare(
          mask_shape_pred, iota, broadcast_dim, ComparisonDirection::kLt));
      // Update `update` to include base.
      update = dus->AddInstruction(HloInstruction::CreateTernary(
          update->shape(), HloOpcode::kSelect, pred, update, base_slice));
    }
  }
  // Feed the padded-with-base update back into the original DUS.
  TF_RETURN_IF_ERROR(dus->ReplaceOperandWith(1, update));

  return true;
}
1645 
// Rewrites a reshape that involves dynamic dimensions. Dimensions are grouped
// by CommonFactors; groups where one side has multiple dims AND the other
// side has multiple dims cannot be handled directly, so the reshape is first
// decomposed into a flatten (to rank 1) + unflatten pair and this function
// recurses on each half. Otherwise each many-to-one / one-to-many group is
// handled by RewriteDynamicReshapeSingleGroup.
StatusOr<bool> RewriteDynamicReshape(
    HloInstruction* reshape,
    DynamicDimensionInference* dynamic_dimension_inference) {
  bool changed = false;
  HloInstruction* operand = reshape->mutable_operand(0);
  // Dynamic size instruction (or nullptr) for every input dimension.
  std::vector<HloInstruction*> input_dynamic_dims;
  for (int64_t dim = 0; dim < operand->shape().dimensions_size(); ++dim) {
    input_dynamic_dims.push_back(
        dynamic_dimension_inference->GetDynamicSize(operand, {}, dim));
  }

  // Dynamic size instruction (or nullptr) for every output dimension.
  std::vector<HloInstruction*> output_dynamic_dims;
  for (int64_t dim = 0; dim < reshape->shape().dimensions_size(); ++dim) {
    output_dynamic_dims.push_back(
        dynamic_dimension_inference->GetDynamicSize(reshape, {}, dim));
  }

  auto common_factors = CommonFactors(operand->shape().dimensions(),
                                      reshape->shape().dimensions());

  // Scan first to see if we need to decompose the reshape to a
  // flatten-unflatten pair.
  bool need_flatten_unflatten = false;
  // True if the given OUTPUT dimension is dynamic (either tracked by
  // inference or marked dynamic on the shape itself).
  auto is_dynamic_dimension = [&](int64_t dim) {
    HloInstruction* operand_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(reshape, {}, dim);
    return operand_dynamic_size != nullptr ||
           reshape->shape().is_dynamic_dimension(dim);
  };

  auto should_skip_common_factor_group = [&](DimensionVector input_dims,
                                             DimensionVector output_dims) {
    if (input_dims.empty() || output_dims.empty()) {
      return true;
    }
    if (absl::c_none_of(output_dims, is_dynamic_dimension)) {
      // Don't need to rewrite any group without dynamic dimensions.
      VLOG(2) << "All dimensions are static in this common factor group";
      return true;
    }
    if (input_dims.size() == 1 && output_dims.size() == 1) {
      // The dimension is unchanged. No rewrite needed.
      return true;
    }
    return false;
  };

  // First pass: detect any many-to-many group, which forces decomposition.
  for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
    auto start = common_factors[i];
    auto end = common_factors[i + 1];
    DimensionVector input_dims;
    DimensionVector output_dims;
    for (int64_t dim = start.first; dim < end.first; ++dim) {
      input_dims.push_back(dim);
    }
    for (int64_t dim = start.second; dim < end.second; ++dim) {
      output_dims.push_back(dim);
    }
    if (should_skip_common_factor_group(input_dims, output_dims)) {
      continue;
    }
    if (input_dims.size() > 1 && output_dims.size() > 1) {
      need_flatten_unflatten = true;
      break;
    }
  }

  if (need_flatten_unflatten) {
    VLOG(2) << "Rewrite dynamic reshape to flatten-unflatten pair. "
            << reshape->ToString();
    int64_t num_elements = ShapeUtil::ElementsIn(operand->shape());
    Shape flattened_shape =
        ShapeUtil::MakeShape(operand->shape().element_type(), {num_elements});
    HloInstruction* flatten = operand->AddInstruction(
        HloInstruction::CreateReshape(flattened_shape, operand));

    // Compute the dynamic element count of the flattened dimension by
    // starting from the static element count and, for each dynamic input
    // dimension, dividing out its static size and multiplying in its
    // dynamic size.
    HloInstruction* dynamic_size =
        operand->AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::CreateR0<int32_t>(num_elements)));
    for (int64_t i = 0; i < operand->shape().rank(); i++) {
      HloInstruction* dynamic_dim_size =
          dynamic_dimension_inference->GetDynamicSize(operand, {}, i);
      if (dynamic_dim_size != nullptr) {
        HloInstruction* static_dim_size = operand->AddInstruction(
            HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
                operand->shape().dimensions(i))));
        dynamic_size = operand->AddInstruction(HloInstruction::CreateBinary(
            dynamic_size->shape(), HloOpcode::kDivide, dynamic_size,
            static_dim_size));
        dynamic_size = operand->AddInstruction(HloInstruction::CreateBinary(
            dynamic_size->shape(), HloOpcode::kMultiply, dynamic_size,
            dynamic_dim_size));
      }
    }
    dynamic_dimension_inference->SetDynamicSize(flatten, {}, 0, dynamic_size);

    HloInstruction* unflatten = reshape->AddInstruction(
        HloInstruction::CreateReshape(reshape->shape(), flatten));
    TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
        reshape, unflatten, {}));

    // Recurse on both halves; each is now a pure many-to-one or one-to-many
    // reshape. The returned flags are irrelevant since we return true below.
    TF_ASSIGN_OR_RETURN(
        bool changed_unused,
        RewriteDynamicReshape(flatten, dynamic_dimension_inference));
    TF_ASSIGN_OR_RETURN(
        changed_unused,
        RewriteDynamicReshape(unflatten, dynamic_dimension_inference));
    TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(unflatten));

    return true;
  }

  // Find common_factors that the input belongs to.
  for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
    auto start = common_factors[i];
    auto end = common_factors[i + 1];
    DimensionVector input_dims;
    DimensionVector output_dims;
    for (int64_t dim = start.first; dim < end.first; ++dim) {
      input_dims.push_back(dim);
    }
    for (int64_t dim = start.second; dim < end.second; ++dim) {
      output_dims.push_back(dim);
    }

    VLOG(2) << "input_dims: " << VectorString(input_dims);
    VLOG(2) << "output_dims: " << VectorString(output_dims);

    if (should_skip_common_factor_group(input_dims, output_dims)) {
      continue;
    }
    if (input_dims.size() > 1 && output_dims.size() > 1) {
      return InternalError(
          "Should be handled by decomposing reshape into "
          "flatten-unflatten pair. %s",
          reshape->ToString());
    }

    TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicReshapeSingleGroup(
                                    reshape, input_dims, output_dims,
                                    absl::MakeSpan(input_dynamic_dims),
                                    absl::MakeSpan(output_dynamic_dims),
                                    dynamic_dimension_inference));
    changed |= c;
  }

  // kDynamicReshape carries its output sizes as operands; once the rewrite
  // above has materialized the dynamism, replace it with a static reshape.
  if (reshape->opcode() == HloOpcode::kDynamicReshape) {
    auto* static_reshape =
        reshape->AddInstruction(HloInstruction::CreateReshape(
            reshape->shape(), reshape->mutable_operand(0)));
    TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(static_reshape));
    TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
        reshape, static_reshape, {}));
    changed = true;
  }

  return changed;
}
1804 
1805 // Insert pad-to-static after `inst` if `inst` has dynamic dimensions in it.
1806 // Recurse into tuple instructions.
InsertPadToStaticOnInstruction(HloInstruction * inst)1807 StatusOr<HloInstruction*> InsertPadToStaticOnInstruction(HloInstruction* inst) {
1808   if (inst->shape().is_static()) {
1809     return inst;
1810   }
1811   if (!inst->shape().IsTuple()) {
1812     // The output shape of pad static is a tuple. The 0th element is the data
1813     // output, which is the same as input shape, but without dynamic dimensions;
1814     // i-th element is the dynamic dimension size for i-1th input dimension.
1815     Shape data_output_shape = inst->shape();  // 0th element.
1816     data_output_shape.clear_dynamic_dimensions();
1817     Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
1818     for (int64_t i = 0; i < inst->shape().rank(); ++i) {
1819       ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
1820                                     &output_shape);
1821     }
1822     HloInstruction* pad_to_static =
1823         inst->AddInstruction(HloInstruction::CreateCustomCall(
1824             output_shape, {inst}, "PadToStatic", ""));
1825     HloInstruction* data_output =
1826         inst->AddInstruction(HloInstruction::CreateGetTupleElement(
1827             data_output_shape, pad_to_static, 0));
1828     return data_output;
1829   }
1830 
1831   TF_RET_CHECK(inst->shape().IsTuple());
1832   std::vector<HloInstruction*> static_tuple_elements;
1833   for (int64_t i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
1834     // For each tuple element, if it is static, pass it through. If it is
1835     // dynamic, recursively call this function again.
1836     HloInstruction* gte =
1837         inst->AddInstruction(HloInstruction::CreateGetTupleElement(
1838             inst->shape().tuple_shapes(i), inst, i));
1839 
1840     if (gte->shape().is_static()) {
1841       static_tuple_elements.push_back(gte);
1842     } else {
1843       TF_ASSIGN_OR_RETURN(HloInstruction * static_gte,
1844                           InsertPadToStaticOnInstruction(gte));
1845       static_tuple_elements.push_back(static_gte);
1846     }
1847   }
1848 
1849   return inst->AddInstruction(
1850       HloInstruction::CreateTuple(static_tuple_elements));
1851 }
1852 
1853 // Inserts PadToStatic for parameters and custom-calls which "materialize"
1854 // dynamic outputs given only static inputs.
InsertPadToStaticAfterModuleInputs(HloModule * module)1855 Status InsertPadToStaticAfterModuleInputs(HloModule* module) {
1856   std::vector<HloInstruction*> params;
1857   HloComputation* entry = module->entry_computation();
1858   for (HloComputation* comp : module->MakeNonfusionComputationsSorted()) {
1859     for (HloInstruction* instr : comp->instructions()) {
1860       if (!instr->shape().is_static() &&
1861           ((instr->opcode() == HloOpcode::kParameter && comp == entry) ||
1862            instr->opcode() == HloOpcode::kCustomCall) &&
1863           absl::c_all_of(instr->operands(), [&](HloInstruction* operand) {
1864             return operand->shape().is_static();
1865           })) {
1866         LOG(ERROR) << "Inserting PadToStatic for instruction: "
1867                    << instr->ToString();
1868         auto users = instr->users();
1869         TF_ASSIGN_OR_RETURN(HloInstruction * instr_static,
1870                             InsertPadToStaticOnInstruction(instr));
1871         for (auto* user : users) {
1872           TF_RETURN_IF_ERROR(instr->ReplaceUseWith(user, instr_static));
1873         }
1874         if (instr == entry->root_instruction()) {
1875           module->entry_computation()->set_root_instruction(instr_static);
1876         }
1877       }
1878     }
1879   }
1880   return OkStatus();
1881 }
1882 
1883 // Remove all dynamic shapes between pad-to-static and slice-to-dynamic.
1884 //
1885 // After this visitor the entry computation then looks like:
1886 //  Param(dynamic)
1887 //    |
1888 //   GTE (dynamic)
1889 //    |
1890 //  PadToStatic(static)
1891 //    |
1892 //   .... regular computation with static shapes.
1893 //    |
1894 //  SliceToDynamic(dynamic)
1895 //    |
1896 // ROOT tuple (dynamic)
class DynamicShapeRemovingVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit DynamicShapeRemovingVisitor(
      const DynamicPadderOptions::OpSupportsDynamismHandler&
          op_supports_dynamism_handler,
      DynamicDimensionInference* dynamic_dimension_inference)
      : op_supports_dynamism_handler_(op_supports_dynamism_handler),
        dynamic_dimension_inference_(dynamic_dimension_inference) {}

  Status DefaultAction(HloInstruction* hlo) override;

  Status HandleCustomCall(HloInstruction* hlo) override;

  Status HandleTuple(HloInstruction* hlo) override;
  Status HandleGetTupleElement(HloInstruction* hlo) override;

  Status HandleParameter(HloInstruction* hlo) override;

  // Runs the visitor over `computation`. If `require_dynamic_output` is true
  // and the root has dynamic dimensions, a static-to-dynamic conversion
  // (SliceToDynamic) is inserted so the computation's output stays in
  // dynamic form.
  static Status Run(HloComputation* computation,
                    const DynamicPadderOptions::OpSupportsDynamismHandler&
                        op_supports_dynamism_handler,
                    DynamicDimensionInference* dynamic_shape_inference,
                    bool require_dynamic_output) {
    DynamicShapeRemovingVisitor visitor(op_supports_dynamism_handler,
                                        dynamic_shape_inference);
    TF_RETURN_IF_ERROR(computation->Accept(&visitor));
    // If the output is required to be in dynamic form, insert a static to
    // dynamic conversion as root.
    if (require_dynamic_output) {
      HloInstruction* root = computation->root_instruction();
      if (dynamic_shape_inference->HasDynamicDimension(root)) {
        TF_ASSIGN_OR_RETURN(HloInstruction * new_root,
                            visitor.ConvertToDynamic(root));
        computation->set_root_instruction(new_root);
      }
    }
    return OkStatus();
  }

 private:
  // If a tensor produced by `inst` is in dynamic form, convert it to static and
  // returns the new instruction.
  StatusOr<HloInstruction*> ConvertToStatic(HloInstruction* inst);

  // If a tensor produced by `inst` is in static form, convert it to dynamic and
  // returns the new instruction.
  StatusOr<HloInstruction*> ConvertToDynamic(HloInstruction* inst);

  // Callback deciding whether an op can natively consume/produce dynamic
  // shapes; not owned.
  const DynamicPadderOptions::OpSupportsDynamismHandler&
      op_supports_dynamism_handler_;

  // Not owned.
  DynamicDimensionInference* dynamic_dimension_inference_;
};
1950 
// Converts a static-form result into dynamic form by attaching a
// "SliceToDynamic" custom call fed with the per-dimension runtime sizes.
// Tuples are converted element-wise (recursively) and re-assembled.
StatusOr<HloInstruction*> DynamicShapeRemovingVisitor::ConvertToDynamic(
    HloInstruction* inst) {
  const Shape& shape = inst->shape();
  if (shape.IsTuple()) {
    std::vector<HloInstruction*> dynamic_operands;
    for (int64_t i = 0; i < shape.tuple_shapes_size(); ++i) {
      auto gte = inst->AddInstruction(HloInstruction::CreateGetTupleElement(
          shape.tuple_shapes(i), inst, i));
      if (dynamic_dimension_inference_->HasDynamicDimension(inst, {i})) {
        // Propagate dynamic-size info to the new GTE before recursing on it.
        TF_RETURN_IF_ERROR(dynamic_dimension_inference_->Update(gte));
        TF_ASSIGN_OR_RETURN(auto dynamic, ConvertToDynamic(gte));
        dynamic_operands.push_back(dynamic);
      } else {
        dynamic_operands.push_back(gte);
      }
    }
    return inst->AddInstruction(HloInstruction::CreateTuple(dynamic_operands));
  } else {
    // Collect the data input, as well as dimension sizes, and feed them to
    // slice to dynamic to create a dynamic tensor.
    Shape output_shape = shape;  // 0th element.
    CHECK(output_shape.is_static());
    std::vector<HloInstruction*> slice_operand;
    slice_operand.push_back(inst);
    for (int64_t i = 0; i < output_shape.dimensions_size(); ++i) {
      auto dimension_size =
          dynamic_dimension_inference_->GetDynamicSize(inst, {}, i);
      if (dimension_size == nullptr) {
        // Static dimension: feed its constant size to SliceToDynamic.
        dimension_size = inst->AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::CreateR0<int32_t>(output_shape.dimensions(i))));
      } else {
        output_shape.set_dynamic_dimension(i, true);
      }
      slice_operand.push_back(dimension_size);
    }
    return inst->AddInstruction(HloInstruction::CreateCustomCall(
        output_shape, slice_operand, "SliceToDynamic"));
  }
}
1990 
// Converts `inst`, whose shape must currently be dynamic, into static
// (padded) form via a "PadToStatic" custom call; only the padded data output
// of that call is returned — the per-dimension size outputs are dropped here.
// Tuples are converted recursively through get-tuple-elements.
StatusOr<HloInstruction*> DynamicShapeRemovingVisitor::ConvertToStatic(
    HloInstruction* inst) {
  const Shape& shape = inst->shape();
  CHECK(shape.is_dynamic());
  if (shape.IsTuple()) {
    std::vector<HloInstruction*> static_operands;
    for (int64_t i = 0; i < shape.tuple_shapes_size(); ++i) {
      auto gte = inst->AddInstruction(HloInstruction::CreateGetTupleElement(
          shape.tuple_shapes(i), inst, i));
      // Keep dynamic-dimension inference in sync with the new GTE.
      TF_RETURN_IF_ERROR(dynamic_dimension_inference_->Update(gte));
      auto operand = inst->mutable_operand(i);
      if (shape.tuple_shapes(i).is_dynamic()) {
        TF_ASSIGN_OR_RETURN(auto static_inst, ConvertToStatic(gte));
        static_operands.push_back(static_inst);
      } else {
        // NOTE(review): pushes inst's i-th *operand* rather than the GTE
        // created above. That is only equivalent when `inst` is a kTuple
        // whose operand i produces element i — presumably true for callers
        // here, but confirm for non-tuple producers of tuple-shaped values.
        static_operands.push_back(operand);
      }
    }
    return inst->AddInstruction(HloInstruction::CreateTuple(static_operands));
  } else {
    // The output shape of pad static is a tuple. The 0th element is the data
    // output, which is the same as input shape, but without dynamic dimensions.
    // i-th element is the dynamic dimension size for i-1th input dimension.
    Shape data_output_shape = shape;  // 0th element.
    data_output_shape.clear_dynamic_dimensions();
    Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
    for (int64_t i = 0; i < shape.rank(); ++i) {
      // One S32 scalar per input dimension, carrying that dimension's
      // runtime size.
      ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
                                    &output_shape);
    }
    HloInstruction* pad_to_static =
        inst->AddInstruction(HloInstruction::CreateCustomCall(
            output_shape, {inst}, "PadToStatic", ""));
    // Only the padded data (element 0) is needed by callers; the size
    // outputs remain available through the PadToStatic tuple if required.
    HloInstruction* data_output =
        inst->AddInstruction(HloInstruction::CreateGetTupleElement(
            data_output_shape, pad_to_static, 0));
    return data_output;
  }
}
2030 
// Reconciles `hlo`'s operand/output dynamism with what the op supports:
// ops without dynamic support are forced fully static (operands converted
// via PadToStatic, output dynamic dimensions cleared); ops that *require*
// dynamic inputs get static operands wrapped with SliceToDynamic.
Status DynamicShapeRemovingVisitor::DefaultAction(HloInstruction* hlo) {
  const bool input_is_dynamic = absl::c_any_of(
      hlo->operands(),
      [](const HloInstruction* hlo) { return hlo->shape().is_dynamic(); });

  // By default, ops don't support dynamic lowering.
  OpDynamismSupport op_support = OpDynamismSupport::kNoSupport;
  if (op_supports_dynamism_handler_) {
    op_support = op_supports_dynamism_handler_(hlo);
  }
  if (op_support == OpDynamismSupport::kNoSupport) {
    // Called computations (e.g. reduction bodies) must also see static
    // parameter shapes when the caller is lowered statically.
    for (auto* sub_computation : hlo->called_computations()) {
      for (auto* param : sub_computation->parameter_instructions()) {
        param->mutable_shape()->clear_dynamic_dimensions();
      }
    }
  }
  // If the input to an op is static and the op doesn't support
  // dynamic output, remove dynamism in output -- dynamic_padder should have
  // rewritten it to support static shapes.
  if (!input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) {
    hlo->mutable_shape()->clear_dynamic_dimensions();
    return OkStatus();
  }

  // Op doesn't support dynamic tensor: For each operand rewrite dynamic input
  // into static input using pad_to_static.
  if (input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) {
    VLOG(1) << "op doesn't support dynamic tensor: " << hlo->ToString();
    for (int64_t i = 0; i < hlo->operand_count(); ++i) {
      if (hlo->operand(i)->shape().is_dynamic()) {
        TF_ASSIGN_OR_RETURN(auto static_operand,
                            ConvertToStatic(hlo->mutable_operand(i)));
        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, static_operand));
      }
    }
    // This op doesn't support dynamic lowering so the op has to be static.
    hlo->mutable_shape()->clear_dynamic_dimensions();
    return OkStatus();
  }

  // If the op requires dynamic tensor and input is static -- construct a
  // dynamic tensor from the static tensor to feed it.
  if (!input_is_dynamic && op_support == OpDynamismSupport::kRequired) {
    VLOG(1) << "op doesn't support static tensor: " << hlo->ToString();
    for (int64_t i = 0; i < hlo->operand_count(); ++i) {
      auto operand = hlo->mutable_operand(i);
      // Only operands with an inferred dynamic dimension need conversion;
      // genuinely static operands are fed through unchanged.
      if (dynamic_dimension_inference_->HasDynamicDimension(operand)) {
        TF_ASSIGN_OR_RETURN(auto dynamic_operand,
                            ConvertToDynamic(hlo->mutable_operand(i)));
        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, dynamic_operand));
      }
    }
    return OkStatus();
  }

  // Remaining case (op supports — but does not require — dynamism): nothing
  // to rewrite.
  return OkStatus();
}
2089 
HandleGetTupleElement(HloInstruction * hlo)2090 Status DynamicShapeRemovingVisitor::HandleGetTupleElement(HloInstruction* hlo) {
2091   *hlo->mutable_shape() =
2092       hlo->operand(0)->shape().tuple_shapes(hlo->tuple_index());
2093   return OkStatus();
2094 }
2095 
HandleTuple(HloInstruction * hlo)2096 Status DynamicShapeRemovingVisitor::HandleTuple(HloInstruction* hlo) {
2097   for (int64_t i = 0; i < hlo->operand_count(); ++i) {
2098     *hlo->mutable_shape()->mutable_tuple_shapes(i) = hlo->operand(i)->shape();
2099   }
2100   return OkStatus();
2101 }
2102 
// Parameters are left untouched: their shapes are controlled by the caller
// (or by DynamicPadder::Run's entry-parameter handling), not by this visitor.
Status DynamicShapeRemovingVisitor::HandleParameter(HloInstruction* hlo) {
  return OkStatus();
}
2106 
HandleCustomCall(HloInstruction * hlo)2107 Status DynamicShapeRemovingVisitor::HandleCustomCall(HloInstruction* hlo) {
2108   if (hlo->custom_call_target() == "SliceToDynamic" ||
2109       hlo->custom_call_target() == "PadToStatic") {
2110     // Those ops support are created to handle dynamic tensors so by their
2111     // nature they support dynamic lowering.
2112     return OkStatus();
2113   }
2114 
2115   return DefaultAction(hlo);
2116 }
2117 
2118 }  // namespace
2119 
// Pass entry point. Runs dynamic-dimension inference, pads dynamic operands
// of statically-lowered ops with identity values, rewrites ops with special
// dynamic semantics (concat, reverse, sort, reshape, DUS, conv custom
// calls), inserts dynamic<->static conversions at support boundaries, and
// finally lowers get/set-dimension-size and custom-call set-bound ops.
// Returns true if the module changed.
StatusOr<bool> DynamicPadder::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  VLOG(2) << "Pre DynamicPadder HLO:";
  XLA_VLOG_LINES(2, module->ToString());
  // Removes dynamic dimensions on parameters if there is already a binding for
  // it. We do this because we have two different APIs to express a dynamic
  // dimension:
  //
  // 1. Dynamic dimension as specified directly in the shape -- Needed for
  // PyTorch.
  //
  // 2. Dynamic dimension using dynamic parameter binding object. This
  // is needed for tensorflow.
  //
  // For case 1, we will insert "pad-to-static" instruction in the
  // beginning of xla execution, to make it into a static layout.
  //
  // For case 2, since it already has a static layout, we remove the
  // dynamic dimension.
  //
  // TODO(b/145140571): Convert all API invocations to case 1.
  //
  TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().ForEachBinding(
      [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
          const DynamicParameterBinding::DynamicDimension& dynamic_dimension)
          -> Status {
        HloInstruction* parameter =
            module->entry_computation()->parameter_instruction(
                dynamic_dimension.parameter_num);
        // Mark the bound dimension static (false) on the entry parameter.
        ShapeUtil::UpdateDynamicDimension(parameter->mutable_shape(),
                                          dynamic_dimension.parameter_index,
                                          dynamic_dimension.dimension, false);
        return OkStatus();
      }));

  TF_RETURN_IF_ERROR(InsertPadToStaticAfterModuleInputs(module));
  TF_ASSIGN_OR_RETURN(
      DynamicDimensionInference dynamic_dimension_inference,
      DynamicDimensionInference::Run(module, options_.custom_call_handler,
                                     options_.shape_check_mode,
                                     options_.assertion_generator));

  std::vector<HloComputation*> computations =
      module->MakeComputationPostOrder(execution_threads);

  // Phase 1: per-instruction rewrites. Ops with their own dynamic semantics
  // get dedicated rewriters; everything else falls through to the generic
  // identity-value padding loop at the bottom.
  for (HloComputation* computation : computations) {
    for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
      OpDynamismSupport has_dynamism_support = OpDynamismSupport::kNoSupport;
      if (options_.op_supports_dynamism_handler != nullptr) {
        has_dynamism_support = options_.op_supports_dynamism_handler(inst);
      }
      // This op supports dynamic lowering, no padding is required.
      if (has_dynamism_support != OpDynamismSupport::kNoSupport) {
        continue;
      }
      if (inst->opcode() == HloOpcode::kConcatenate) {
        TF_ASSIGN_OR_RETURN(
            bool c, RewriteDynamicConcat(inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }
      if (inst->opcode() == HloOpcode::kReverse) {
        TF_ASSIGN_OR_RETURN(bool c,
                            RewriteReverse(inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }
      if (inst->opcode() == HloOpcode::kSort) {
        TF_ASSIGN_OR_RETURN(
            bool c, RewriteDynamicSort(inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }
      if (inst->opcode() == HloOpcode::kReshape ||
          inst->opcode() == HloOpcode::kDynamicReshape) {
        TF_ASSIGN_OR_RETURN(
            bool c, RewriteDynamicReshape(inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }

      // Elementwise binary with dynamic shapes have implicit broadcast
      // semantics.
      if (inst->IsElementwiseBinary()) {
        TF_ASSIGN_OR_RETURN(
            bool c, RewriteDynamicBinaryOp(inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }

      if (inst->opcode() == HloOpcode::kDynamicUpdateSlice) {
        TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicUpdateSlice(
                                        inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }

      if (inst->IsCustomCall("DynamicConvolutionInputGrad")) {
        TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicConvolutionInputGrad(
                                        inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }

      if (inst->IsCustomCall("DynamicConvolutionForward")) {
        TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicConvolutionForward(
                                        inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }

      if (inst->IsCustomCall("DynamicConvolutionKernelGrad")) {
        TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicConvolutionKernelGrad(
                                        inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }

      if (inst->IsCustomCall("DynamicReduceWindowSamePadding")) {
        TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicReduceWindowSamePadding(
                                        inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }

      if (inst->IsCustomCall("DynamicSelectAndScatterSamePadding")) {
        TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicSelectAndScatterSamePadding(
                                        inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }

      // Generic path: pad each dynamic dimension of each array operand with
      // the op's identity value so static lowering yields correct results.
      for (int64_t operand_num = 0; operand_num < inst->operand_count();
           ++operand_num) {
        HloInstruction* original_operand = inst->mutable_operand(operand_num);
        HloInstruction* operand = original_operand;
        if (!operand->shape().IsArray()) {
          continue;
        }

        for (int64_t input_dim = 0; input_dim < operand->shape().rank();
             ++input_dim) {
          HloInstruction* operand_dynamic_size =
              dynamic_dimension_inference.GetDynamicSize(original_operand, {},
                                                         input_dim);
          if (operand_dynamic_size == nullptr) {
            continue;
          }
          VLOG(2) << "Has dynamic dimension of operand" << operand_num << " @"
                  << input_dim;

          if (ShouldSkipPadOnOperand(inst, operand_num, input_dim)) {
            continue;
          }

          // nullptr identity means padded garbage doesn't affect this op's
          // result, so no padding is needed.
          TF_ASSIGN_OR_RETURN(HloInstruction * identity_value,
                              ChooseIdentityValue(inst, operand_num));
          if (identity_value == nullptr) {
            continue;
          }

          HloInstruction* padded = PadWithScalar(
              operand, input_dim, operand_dynamic_size, identity_value);
          TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, padded));
          // Chain further dimension pads onto the already-padded operand.
          operand = inst->mutable_operand(operand_num);
          changed = true;
        }
      }
    }
  }

  // There are ops that only support dynamic lowering and ops that only support
  // static lowering, add dynamic<->static tensor conversion around the boundary
  // between those ops, as well as the root instruction.
  computations = module->MakeComputationPostOrder(execution_threads);
  // Reverse postorder so that if caller doesn't support dynamic tensor (while,
  // etc), change their called computation to only take static tensors.
  for (auto it = computations.rbegin(); it != computations.rend(); ++it) {
    HloComputation* computation = *it;
    // if slice_dynamic_output_ is set and this is entry computation, we need
    // the output tensor to be in dynamic form.
    bool require_dynamic_output = options_.slice_dynamic_output &&
                                  computation == module->entry_computation();
    // Requiring dynamic output implies we will insert conversions, so the
    // module is considered changed.
    changed |= require_dynamic_output;
    TF_RETURN_IF_ERROR(DynamicShapeRemovingVisitor::Run(
        computation, options_.op_supports_dynamism_handler,
        &dynamic_dimension_inference,
        /*require_dynamic_output=*/require_dynamic_output));
  }

  if (changed) {
    dynamic_padding_gauge->GetCell()->Set(changed);
    module->set_is_dynamic(true);
  }

  // Lower get-dimension-size ops to the inferred size values.
  for (auto* computation : module->computations(execution_threads)) {
    for (auto instruction : computation->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(
          bool c, ReplaceGetSize(instruction, &dynamic_dimension_inference));
      changed |= c;
    }
  }

  // Lower set-dimension-size and SetBound custom calls.
  for (auto* computation : module->computations(execution_threads)) {
    for (auto instruction : computation->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(bool c, ReplaceSetSize(instruction));
      changed |= c;

      TF_ASSIGN_OR_RETURN(c, ReplaceSetBound(instruction));
      changed |= c;
    }
  }

  // Clean up instructions orphaned by the rewrites above.
  HloDCE dce;
  TF_ASSIGN_OR_RETURN(bool c, dce.Run(module, execution_threads));
  changed |= c;

  VLOG(2) << "Post DynamicPadder HLO:";
  XLA_VLOG_LINES(2, module->ToString());
  return changed;
}
2343 
2344 }  // namespace xla
2345