/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/dynamic_padder.h"

#include <algorithm>
#include <functional>
#include <optional>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"

namespace xla {

namespace {

auto* dynamic_padding_gauge = tensorflow::monitoring::Gauge<bool, 0>::New(
    "/tensorflow/core/use_dynamic_padding_gauge",
    "Tracks if dynamic padder is used.");

// ChooseIdentityValue looks at the instruction's operand and returns an
// identity value which, when used as padding, doesn't change the result of
// the instruction.
//
// nullopt is returned if the padding doesn't need to be reset.
StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst,
                                              int64_t operand_number) {
  // Padding the operands of an elementwise operation doesn't affect the
  // effective data in the result.
  if (inst->IsElementwise()) {
    return nullptr;
  }
  if (inst->opcode() == HloOpcode::kSelectAndScatter ||
      inst->IsCustomCall("DynamicSelectAndScatterSamePadding")) {
    if (operand_number == 1) {
      return inst->mutable_operand(2);
    }
    TF_RET_CHECK(operand_number == 0);
    HloComputation* select = inst->called_computations()[0];

    if (Match(select->root_instruction(),
              match::Compare(match::Parameter(), match::Parameter())
                  .WithComparisonDirection(ComparisonDirection::kGe))) {
      return inst->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::MinValue(inst->operand(0)->shape().element_type())));
    } else {
      return Unimplemented(
          "Only select and scatter with `max` as select function is "
          "supported, got %s",
          select->ToString());
    }
  }
  switch (inst->opcode()) {
    case HloOpcode::kReduce: {
      auto* reduce = Cast<HloReduceInstruction>(inst);
      TF_RET_CHECK(operand_number < reduce->input_count())
          << "Only data operand with dynamic dimension is valid.";
      // A variadic reduce has a different init value for each data operand;
      // given a data operand number, find the matching init value index.
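      // For example, a variadic reduce over two data operands has operands
      // (data0, data1, init0, init1), so input_count() == 2 and data operand
      // 1 maps to the init value at operand index 2 + 1 == 3.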
      int64_t init_value_index = reduce->input_count() + operand_number;
      return inst->mutable_operand(init_value_index);
    }
    case HloOpcode::kReduceWindow: {
      auto* reduce_window = Cast<HloReduceWindowInstruction>(inst);
      TF_RET_CHECK(operand_number < reduce_window->input_count())
          << "Only data operand with dynamic dimension is valid.";
      // A variadic reduce window has a different init value for each data
      // operand; given a data operand number, find the init value index.
      int64_t init_value_index = reduce_window->input_count() + operand_number;
      return inst->mutable_operand(init_value_index);
    }

    case HloOpcode::kConvolution:
    case HloOpcode::kDot: {
      // Use 0 as padding value for convolution and dot.
      //
      // Note that the output type (inst->shape().element_type()) isn't
      // necessarily the same as the input type (element type of operands). For
      // example, a dot can take s8 as input and output s32.
      PrimitiveType ptype = inst->operand(0)->shape().element_type();
      return inst->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(ptype)));
    }

    case HloOpcode::kPad:
      return inst->mutable_operand(1);
    case HloOpcode::kScatter: {
      if (operand_number != 1) {
        return nullptr;
      }
      PrimitiveType indices_ptype =
          inst->operand(operand_number)->shape().element_type();

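      // Padding the indices with the maximum index value pushes the padded
      // updates out of bounds, so they should be dropped rather than applied
      // (assuming XLA's drop-out-of-bounds scatter semantics).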
      return inst->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::MaxValue(indices_ptype)));
    }
    case HloOpcode::kParameter:
    case HloOpcode::kGather:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kTuple:
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kBroadcast:
    case HloOpcode::kTranspose:
    case HloOpcode::kSort:
    case HloOpcode::kSlice:
    case HloOpcode::kDomain:
      return nullptr;
    case HloOpcode::kCustomCall:
      // Assume that custom calls created by the client are valid with padded
      // dynamic dimensions.
      return nullptr;
    default:
      return UnimplementedStrCat("Unimplemented padding for instruction: ",
                                 inst->ToString());
  }
}

StatusOr<bool> ReplaceGetSize(
    HloInstruction* instr,
    DynamicDimensionInference* dynamic_dimension_inference) {
  if (instr->opcode() != HloOpcode::kGetDimensionSize) {
    return false;
  }
  HloComputation* computation = instr->parent();

  TF_ASSIGN_OR_RETURN(auto legal_shape,
                      ShapeInference::InferGetDimensionSizeShape(
                          instr->operand(0)->shape(), instr->dimension()));
  TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape))
      << "instr->shape() " << instr->shape().ToString() << " , "
      << "legal_shape " << legal_shape.ToString();
  TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32));
  HloInstruction* operand = instr->mutable_operand(0);
  int64_t dim = instr->dimension();
  HloInstruction* dynamic_size =
      dynamic_dimension_inference->GetDynamicSize(operand, {}, dim);
  if (dynamic_size != nullptr) {
    TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
    // The dependency between an instruction and its dynamic dimensions is not
    // modeled in the IR. As instr is being replaced by dynamic_size, also tell
    // dynamic dimension inference that the instruction is being replaced.
    dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(
        instr, dynamic_size);
  } else {
    int32_t size = instr->operand(0)->shape().dimensions(dim);
    HloInstruction* new_instr = computation->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(size)));
    TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
    dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr,
                                                                    new_instr);
  }
  return true;
}

StatusOr<bool> ReplaceSetSize(HloInstruction* instr) {
  if (instr->opcode() != HloOpcode::kSetDimensionSize) {
    return false;
  }

  TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
      instr->shape(), instr->operand(0)->shape()))
      << "instr->shape() " << instr->shape().ToString() << " , "
      << "instruction operand shape " << instr->operand(0)->shape();
  HloInstruction* operand = instr->mutable_operand(0);

  TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
  return true;
}

StatusOr<bool> ReplaceSetBound(HloInstruction* instr) {
  if (instr->opcode() != HloOpcode::kCustomCall ||
      instr->custom_call_target() != "SetBound") {
    return false;
  }

  TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
      instr->shape(), instr->operand(0)->shape()))
      << "instr->shape() " << instr->shape().ToString() << " , "
      << "instruction operand shape " << instr->operand(0)->shape();
  HloInstruction* operand = instr->mutable_operand(0);

  TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
  return true;
}

bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64_t operand_num,
                            int64_t dimension) {
  switch (inst->opcode()) {
    case HloOpcode::kConvolution: {
      if (operand_num == 0) {
        if (dimension ==
            inst->convolution_dimension_numbers().input_batch_dimension()) {
          return true;
        }
        const auto& spatial_dims =
            inst->convolution_dimension_numbers().input_spatial_dimensions();
        for (int64_t spatial_dim = 0; spatial_dim < spatial_dims.size();
             ++spatial_dim) {
          // A spatial dimension with a window size of 1 does not need
          // padding.
          if (spatial_dims[spatial_dim] == dimension &&
              inst->window().dimensions(spatial_dim).size() == 1) {
            return true;
          }
        }
      }
      return operand_num == 1 &&
             (dimension == inst->convolution_dimension_numbers()
                               .kernel_output_feature_dimension());
    }
    case HloOpcode::kDot: {
      if (operand_num == 0) {
        return !absl::c_linear_search(
            inst->dot_dimension_numbers().lhs_contracting_dimensions(),
            dimension);
      }
      return !absl::c_linear_search(
          inst->dot_dimension_numbers().rhs_contracting_dimensions(),
          dimension);
    }
    case HloOpcode::kReduce:
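      // A dynamic dimension that is not being reduced passes through to the
      // output unchanged, so its padding never mixes with real data and no
      // identity-value reset is needed.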
      return !absl::c_linear_search(inst->dimensions(), dimension);
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kReduceWindow:
      return inst->window().dimensions(dimension).size() == 1;
    default:
      return false;
  }
}

// Generates a mask representing the effective area of data and padded area of
// data using iota and dynamic_size. For example, given a dimension of 7
// elements and 5 effective elements:
//
// iota = [0, 1, 2, 3, 4, 5, 6]
// broadcast_dynamic_size = [5, 5, 5, 5, 5, 5, 5]
// mask = lt(iota, broadcast_dynamic_size) = [t, t, t, t, t, f, f]
//
// Once the mask is generated, the input data is then padded using the
// mask and pad value.
//
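// Completing the example above with a padding scalar p, the final select is:
//
// select(mask, input, broadcast(p)) = [d0, d1, d2, d3, d4, p, p]
//
// where d0..d4 are the effective elements (illustration only).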
HloInstruction* PadWithScalar(HloInstruction* inst, int64_t dim,
                              HloInstruction* dynamic_size,
                              HloInstruction* padding_scalar) {
  CHECK(inst != nullptr && dynamic_size != nullptr &&
        padding_scalar != nullptr);
  const Shape mask_shape =
      ShapeUtil::ChangeElementType(inst->shape(), xla::S32);
  const Shape pred_shape =
      ShapeUtil::ChangeElementType(inst->shape(), xla::PRED);
  HloInstruction* iota =
      inst->AddInstruction(HloInstruction::CreateIota(mask_shape, dim));

  HloInstruction* broadcasted_effective_size = inst->AddInstruction(
      HloInstruction::CreateBroadcast(mask_shape, dynamic_size, {}));
  HloInstruction* pred = inst->AddInstruction(HloInstruction::CreateCompare(
      pred_shape, iota, broadcasted_effective_size, ComparisonDirection::kLt));

  HloInstruction* broadcasted_identity_value = inst->AddInstruction(
      HloInstruction::CreateBroadcast(inst->shape(), padding_scalar, {}));
  HloInstruction* padded = inst->AddInstruction(
      HloInstruction::CreateTernary(inst->shape(), HloOpcode::kSelect, pred,
                                    inst, broadcasted_identity_value));
  return padded;
}

// Generates a 1-0 mask for input_dim, where 1 marks elements inside the
// dynamic (effective) bound and 0 marks padding.
HloInstruction* GenerateBinaryMask(
    HloInstruction* reshape, int64_t input_dim,
    absl::Span<const int64_t> output_dims,
    absl::Span<HloInstruction*> output_dynamic_dims, HloInstruction* one,
    HloInstruction* zero, bool split_input) {
  Shape input_shape =
      split_input ? reshape->operand(0)->shape() : reshape->shape();
  Shape output_shape =
      split_input ? reshape->shape() : reshape->operand(0)->shape();
  const Shape mask_input_shape =
      ShapeUtil::MakeShape(xla::S32, {input_shape.dimensions(input_dim)});
  const Shape pred_input_shape =
      ShapeUtil::MakeShape(xla::PRED, {input_shape.dimensions(input_dim)});
  HloInstruction* pred_true = reshape->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* input_shape_pred_mask = reshape->AddInstruction(
      HloInstruction::CreateBroadcast(pred_input_shape, pred_true, {}));
  bool need_rewrite = false;
  // Iota contains a linear index for each element in input shape.
  HloInstruction* iota =
      reshape->AddInstruction(HloInstruction::CreateIota(mask_input_shape, 0));

  // Compute the multi-dimensional indices from a linear index and
  // compare to dynamic dimension size to generate the mask.
  // For a 2x3x3 shape, iota is first set to:
  // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,16,17]
  // iota % 3 gives the index for the last dimension.
  // [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
  // Then iota goes to:
  // [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5] (after div 3)
  // iota % 3 gives the index of the second last dimension.
  // [0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2]
  // Then iota goes to:
  // [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1] (after div 3)
  // It gives the index of the major dimension.
  // For example, element 16 in the original iota will in the end get index
  // (1, 2, 1). Each index is used for generating the mask (if necessary) by
  // comparing to the dynamic size value for that dimension.
  //
  // Skip index 0 since there is no need to rewrite a major output dimension.
  for (int64_t i = 1; i < output_dims.size(); ++i) {
    if (output_dynamic_dims[output_dims[i]] != nullptr) {
      // If any non-major output dimension is dynamic, we need to rewrite the
      // input.
      need_rewrite = true;
      break;
    }
  }
  if (!need_rewrite) {
    return nullptr;
  }

  for (int64_t i = output_dims.size() - 1; i > 0; i--) {
    const int64_t output_dim = output_dims[i];
    HloInstruction* dynamic_size = output_dynamic_dims[output_dim];
    HloInstruction* static_output_dim_size = reshape->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
            output_shape.dimensions(output_dim))));
    HloInstruction* broadcasted_static_output_dim_size =
        reshape->AddInstruction(HloInstruction::CreateBroadcast(
            mask_input_shape, static_output_dim_size, {}));
    if (dynamic_size != nullptr) {
      // Generate the mask for output_dim.
      HloInstruction* dim_index =
          reshape->AddInstruction(HloInstruction::CreateBinary(
              mask_input_shape, HloOpcode::kRemainder, iota,
              broadcasted_static_output_dim_size));
      HloInstruction* broadcasted_effective_size = reshape->AddInstruction(
          HloInstruction::CreateBroadcast(mask_input_shape, dynamic_size, {}));
      HloInstruction* selected =
          reshape->AddInstruction(HloInstruction::CreateCompare(
              pred_input_shape, dim_index, broadcasted_effective_size,
              ComparisonDirection::kLt));

      // Merge the mask.
      input_shape_pred_mask = reshape->AddInstruction(
          HloInstruction::CreateBinary(pred_input_shape, HloOpcode::kAnd,
                                       input_shape_pred_mask, selected));
    }

    // Update iota values by "shifting out" dimension i.
    iota = reshape->AddInstruction(
        HloInstruction::CreateBinary(mask_input_shape, HloOpcode::kDivide, iota,
                                     broadcasted_static_output_dim_size));
  }

  HloInstruction* broadcasted_one = reshape->AddInstruction(
      HloInstruction::CreateBroadcast(mask_input_shape, one, {}));
  HloInstruction* broadcasted_zero = reshape->AddInstruction(
      HloInstruction::CreateBroadcast(mask_input_shape, zero, {}));
  return reshape->AddInstruction(HloInstruction::CreateTernary(
      mask_input_shape, HloOpcode::kSelect, input_shape_pred_mask,
      broadcasted_one, broadcasted_zero));
}

// In a reshape, if a dynamic dimension is split into multiple output
// dimensions, we need to rewrite the input of the reshape.
//
// The reason for this is that a contiguous input may not be evenly reshaped
// into the output. Imagine we have [<=6] where the valid data has size 4 and
// the padding (P) data has size 2: [a,b,c,d,P,P]
//
// And we have a reshape that produces dynamic output dimensions.
//
// [<=6]
//  |
// Reshape
//  |
// [2, <=3]
//
// This should produce the same result as if the data had no padding:
//
// [4]     // [a, b, c, d]
//  |
// Reshape
//  |
// [2, 2]  // [[a,b], [c,d]]
//
// Without reshape rewriting, the result looks like:
//
// [[a,b,c]
//  [d,P,P]], which is incorrect.
//
// We need to rewrite the reshape such that it produces:
// [[a,b,P]
//  [c,d,P]]
//
// The way we do this is by a 4-step cumsum-gather algorithm:
//
// 1. First we use the output shape to generate a binary 0-1 mask, which masks
// out the padded area of the flattened output shape:
// [1,1,0,1,1,0]
//
// 2. We then do a cumsum with the mask:
// [1,2,2,3,4,4] and subtract 1 from it:
// [0,1,1,2,3,3]
//
// 3. Use the result of the cumsum as gather indices to rearrange the original
// data. Feed the original input [a,b,c,d,P,P] and the indices into gather.
//
// operand [a,b,c,d,P,P], indices [0,1,1,2,3,3]
//            |                    |
//          Gather-----------------+
//            |
//            v
// value [a,b,b,c,d,d], which is equivalent to [a,b,P,c,d,P] as the padding
// value doesn't matter.
//
//
// 4. Feed the rearranged input to the original reshape [6]->[2,3]; we now get
// the correct result:
// [[a,b,P]
//  [c,d,P]]
//
StatusOr<bool> RewriteDynamicReshapeSplitInput(
    HloInstruction* reshape, int64_t input_dim,
    absl::Span<const int64_t> output_dims,
    absl::Span<HloInstruction*> output_dynamic_dims,
    DynamicDimensionInference* dynamic_dimension_inference) {
  VLOG(2) << "Reshaping input dim " << input_dim << " to "
          << VectorString(output_dims);
  const Shape operand_shape = reshape->operand(0)->shape();
  TF_RET_CHECK(output_dims.size() > 1);

  const Shape mask_input_shape =
      ShapeUtil::MakeShape(xla::S32, {operand_shape.dimensions(input_dim)});
  const Shape pred_input_shape =
      ShapeUtil::MakeShape(xla::PRED, {operand_shape.dimensions(input_dim)});

  HloInstruction* zero = reshape->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  HloInstruction* one = reshape->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::One(S32)));

  // Step 1 -- generate binary mask.
  HloInstruction* input_shape_binary_mask =
      GenerateBinaryMask(reshape, input_dim, output_dims, output_dynamic_dims,
                         one, zero, /*split_input=*/true);
  if (input_shape_binary_mask == nullptr) {
    // No need to rewrite.
    VLOG(2) << "No need to rewrite";
    return false;
  }

  // Step 2. Do a cumsum on the binary mask.

  auto embedded_builder = HloComputation::Builder("add");
  {
    auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
        0, ShapeUtil::MakeShape(S32, {}), "lhs"));
    auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
        1, ShapeUtil::MakeShape(S32, {}), "rhs"));
    embedded_builder.AddInstruction(
        HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
  }

  HloComputation* add =
      reshape->GetModule()->AddEmbeddedComputation(embedded_builder.Build());
  Window cumsum_window;
  // A window of size N with low padding N - 1 and stride 1 makes the
  // reduce-window below compute an inclusive cumsum over the single mask
  // dimension.
  WindowDimension* dim = cumsum_window.add_dimensions();
  dim->set_size(operand_shape.dimensions(input_dim));
  dim->set_stride(1);
  dim->set_padding_low(operand_shape.dimensions(input_dim) - 1);
  dim->set_padding_high(0);
  dim->set_window_dilation(1);
  dim->set_base_dilation(1);
  HloInstruction* cumsum =
      reshape->AddInstruction(HloInstruction::CreateReduceWindow(
          mask_input_shape, input_shape_binary_mask, zero, cumsum_window, add));

  HloInstruction* broadcast_ones = reshape->AddInstruction(
      HloInstruction::CreateBroadcast(mask_input_shape, one, {}));
  cumsum = reshape->AddInstruction(HloInstruction::CreateBinary(
      mask_input_shape, HloOpcode::kSubtract, cumsum, broadcast_ones));
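  // Continuing the running example from the comment above: mask [1,1,0,1,1,0]
  // has cumsum [1,2,2,3,4,4]; subtracting one gives gather indices
  // [0,1,1,2,3,3].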

  GatherDimensionNumbers gather_dim_numbers;
  // Use gather to rearrange the input dim dimension.
  for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) {
    // Offset dim is every dimension including newly added size 1 dim, except
    // for input_dim, which acts as a batch_dim.
    if (i != input_dim) {
      gather_dim_numbers.add_offset_dims(i);
    }
  }
  // The dimension to rewrite is the index dim.
  gather_dim_numbers.add_start_index_map(input_dim);
  gather_dim_numbers.set_index_vector_dim(1);
  gather_dim_numbers.add_collapsed_slice_dims(input_dim);

  // Step 3. Gather.

  // Temporarily remove the dynamic dimension before entering gather -- we want
  // the gather to ignore the dynamic dimension.
  HloInstruction* operand_static_dim_size =
      reshape->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<int32_t>(operand_shape.dimensions(input_dim))));
  HloInstruction* operand_static =
      reshape->AddInstruction(HloInstruction::CreateSetDimensionSize(
          operand_shape, reshape->mutable_operand(0), operand_static_dim_size,
          input_dim));

  std::vector<int64_t> slice_sizes(operand_shape.dimensions().begin(),
                                   operand_shape.dimensions().end());
  slice_sizes[input_dim] = 1;
  HloInstruction* gather = reshape->AddInstruction(HloInstruction::CreateGather(
      ShapeUtil::MakeShape(operand_shape.element_type(),
                           operand_shape.dimensions()),
      operand_static, cumsum, gather_dim_numbers, slice_sizes, true));
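  // In the running example this computes
  // gather([a,b,c,d,P,P], [0,1,1,2,3,3]) = [a,b,b,c,d,d], i.e. [a,b,P,c,d,P]
  // up to the (irrelevant) values in the padded slots.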

  // Step 4: Feed gather input to original reshape.

  TF_RETURN_IF_ERROR(reshape->ReplaceOperandWith(0, gather));

  HloInstruction* reshape_dynamic = reshape;

  auto users = reshape->users();

  // Forward the output dynamic dimension.
  for (int64_t output_dim : output_dims) {
    HloInstruction* output_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(reshape, {}, output_dim);
    if (output_dynamic_size != nullptr) {
      reshape_dynamic =
          reshape->AddInstruction(HloInstruction::CreateSetDimensionSize(
              reshape->shape(), reshape_dynamic, output_dynamic_size,
              output_dim));
    }
  }

  for (auto* user : users) {
    TF_RETURN_IF_ERROR(reshape->ReplaceUseWith(user, reshape_dynamic));
  }
  TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
      reshape, reshape_dynamic, {}));

  return true;
}

// RewriteDynamicReshapeCombineInput is similar to
// RewriteDynamicReshapeSplitInput: in a reshape, if multiple input dimensions
// are combined into one output dimension, we need to rewrite the output.
//
// The reason for this is that a contiguous input may not be evenly reshaped
// into the output. Imagine we have [2, <=3] where the second dimension has
// effective size 2 and the padding (P) data has size 1:
// [[a,b,P]
//  [c,d,P]]
//
// And we have a reshape that combines these two input dimensions.
//
// [2, <=3]
//  |
// Reshape
//  |
// [6]
//
// This should produce the same result as if the data had no padding:
//
// [2, 2]  // [[a, b], [c, d]]
//  |
// Reshape
//  |
// [4]     // [a,b,c,d]
//
// Without rewriting, the result would be:
//
// [a,b,P,c,d,P], which is incorrect.
//
// We need to rewrite the reshape such that it produces:
// [a,b,c,d,P,P]
//
// The way we do this is by a 4-step sort-gather algorithm:
//
// 1. First we use the input shape to generate a binary 0-1 mask, which masks
// out the padded area of the output:
// [1,1,0,1,1,0]
//
// 2. We then generate an iota mask using the output shape:
// [0,1,2,3,4,5]
//
// 3. Stable sort the iota mask using the binary mask as key:
// key  [1,1,0,1,1,0]
// value[0,1,2,3,4,5]
//   | Sort by key
//   v
// key  [1,1,1,1,0,0]
// value[0,1,3,4,2,5]
//
// 4. Gather the original output [a,b,P,c,d,P] using the sorted iota mask:
//   original output       gather indices
//   [a,b,P,c,d,P]         [0,1,3,4,2,5]
//        |                      |
//      Gather ------------------+
//        |
//   [a,b,c,d,P,P]
//
StatusOr<bool> RewriteDynamicReshapeCombineInput(
    HloInstruction* reshape, absl::Span<const int64_t> input_dims,
    int64_t output_dim, absl::Span<HloInstruction*> input_dynamic_dims,
    DynamicDimensionInference* dynamic_dimension_inference) {
  // Rewrite dynamic reshape into reshape followed by a sort, all padded
  // data will be moved to the end.
  HloInstruction* zero = reshape->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  HloInstruction* one = reshape->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::One(S32)));
  const Shape output_shape = reshape->shape();
  const Shape input_shape = reshape->operand(0)->shape();
  const Shape mask_output_shape =
      ShapeUtil::MakeShape(xla::S32, {output_shape.dimensions(output_dim)});

  // Step 1.
  // Generate binary mask.
  HloInstruction* output_shape_binary_mask =
      GenerateBinaryMask(reshape, output_dim, input_dims, input_dynamic_dims,
                         one, zero, /*split_input=*/false);
  if (output_shape_binary_mask == nullptr) {
    VLOG(2) << "No need to rewrite";
    return false;
  }

  // Step 2.
  // Generate an iota with output shape.
  HloInstruction* iota =
      reshape->AddInstruction(HloInstruction::CreateIota(mask_output_shape, 0));

  // Step 3.
  // Stable sort the iota mask using the binary mask as key and iota as value:

  // Build computation for sort, key is the mask, value is the iota.
  HloComputation::Builder comp_builder("compare");
  HloInstruction* lhs_key =
      comp_builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeScalarShape(S32), "lhs_key"));
  HloInstruction* rhs_key =
      comp_builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeScalarShape(S32), "rhs_key"));

  // Values for lhs and rhs.
  comp_builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeScalarShape(S32), "lhs_value"));
  comp_builder.AddInstruction(HloInstruction::CreateParameter(
      3, ShapeUtil::MakeScalarShape(S32), "rhs_value"));
  comp_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), lhs_key,
                                    rhs_key, ComparisonDirection::kGt));
  HloComputation* compare =
      reshape->GetModule()->AddEmbeddedComputation(comp_builder.Build());

  // Use mask_reshaped as key, sort reshaped data as value.
  HloInstruction* sort = reshape->AddInstruction(HloInstruction::CreateSort(
      ShapeUtil::MakeTupleShape({mask_output_shape, mask_output_shape}), 0,
      {output_shape_binary_mask, iota}, compare,
      /*is_stable=*/true));
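  // As in step 3 of the comment above: stably sorting values [0,1,2,3,4,5] by
  // keys [1,1,0,1,1,0] in descending key order (kGt) yields values
  // [0,1,3,4,2,5], which become the gather indices below.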

  HloInstruction* gather_indices = reshape->AddInstruction(
      HloInstruction::CreateGetTupleElement(mask_output_shape, sort, 1));

  // Step 4. Gather the original output using the sorted iota mask:

  GatherDimensionNumbers gather_dim_numbers;
  // Use gather to rearrange the output dim dimension.
  for (int64_t i = 0; i < output_shape.dimensions_size(); ++i) {
    // Offset dim is every dimension including newly added size 1 dim, except
    // for output_dim, which acts as a batch_dim.
    if (i != output_dim) {
      gather_dim_numbers.add_offset_dims(i);
    }
  }
  // The dimension to rewrite is the index dim.
  gather_dim_numbers.add_start_index_map(output_dim);
  gather_dim_numbers.set_index_vector_dim(1);
  gather_dim_numbers.add_collapsed_slice_dims(output_dim);

  HloInstruction* static_dim_size = reshape->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
          reshape->shape().dimensions(output_dim))));

  // Temporarily remove the dynamic dimension of the reshape before we send it
  // to the gather -- we want the padded area to also participate in the
  // gather.
  HloInstruction* reshape_static =
      reshape->AddInstruction(HloInstruction::CreateSetDimensionSize(
          reshape->shape(), reshape, static_dim_size, output_dim));
  std::vector<int64_t> gather_slice_sizes(output_shape.dimensions().begin(),
                                          output_shape.dimensions().end());
  gather_slice_sizes[output_dim] = 1;
  HloInstruction* gather = reshape->AddInstruction(HloInstruction::CreateGather(
      output_shape, reshape_static, gather_indices, gather_dim_numbers,
      gather_slice_sizes, true));

  // Forward dynamic size to the newly created gather.
  HloInstruction* output_dynamic_size =
      dynamic_dimension_inference->GetDynamicSize(reshape, {}, output_dim);
  TF_RET_CHECK(output_dynamic_size != nullptr);
  gather = reshape->AddInstruction(HloInstruction::CreateSetDimensionSize(
      gather->shape(), gather, output_dynamic_size, output_dim));
  auto users = reshape->users();
  for (auto* user : users) {
    // Avoid cycles by not replacing the static reshape and get_dimension_size.
    if (user != reshape_static && user != output_dynamic_size) {
      TF_RETURN_IF_ERROR(reshape->ReplaceUseWith(user, gather));
    }
  }

  if (reshape == reshape->parent()->root_instruction()) {
    reshape->parent()->set_root_instruction(gather);
  }

  TF_RETURN_IF_ERROR(
      dynamic_dimension_inference->ForwardDynamicSize(reshape, gather, {}));

  return true;
}

StatusOr<bool> RewriteDynamicReshapeSingleGroup(
    HloInstruction* reshape, absl::Span<const int64_t> input_dims,
    absl::Span<const int64_t> output_dims,
    absl::Span<HloInstruction*> input_dynamic_dims,
    absl::Span<HloInstruction*> output_dynamic_dims,
    DynamicDimensionInference* dynamic_dimension_inference) {
  VLOG(2) << "Rewriting dynamic reshape " << reshape->ToString()
          << " input dims: " << VectorString(input_dims)
          << " output dims: " << VectorString(output_dims);

  const Shape operand_shape = reshape->operand(0)->shape();
  const Shape output_shape = reshape->shape();

  if (input_dims.size() == 1) {
    int64_t input_dim = input_dims[0];
    // Size 1 dimension doesn't need a rewrite.
    if (operand_shape.dimensions()[input_dim] == 1) {
      return false;
    }
    // One input dimension is split into multiple output dimensions.
    return RewriteDynamicReshapeSplitInput(reshape, input_dim, output_dims,
                                           output_dynamic_dims,
                                           dynamic_dimension_inference);
  }

  if (output_dims.size() == 1) {
    int64_t output_dim = output_dims[0];
    if (output_shape.dimensions()[output_dim] == 1) {
      return false;
    }
    // Multiple input dimensions are combined into one output dimension.
    return RewriteDynamicReshapeCombineInput(reshape, input_dims, output_dim,
                                             input_dynamic_dims,
                                             dynamic_dimension_inference);
  }

  // Shouldn't get here.
  TF_RET_CHECK(false);
  return false;
}

StatusOr<bool> RewriteReverse(
    HloInstruction* reverse,
    DynamicDimensionInference* dynamic_dimension_inference) {
  // When we have [A, B, C, D, E] and reverse them, we get [E, D, C, B, A].
  // However, if the dynamic size is 2, we expect B, A to be in front:
  // [B, A, P, P, P].
  //
  // We do this by running a pad and dynamic slice on the result:
  // [A, B, C, D, E]
  //       |
  //    reverse
  //       |
  // [E, D, C, B, A]
  //       |
  //      pad # Use pad to double the size of the dimension to avoid OOB.
  //       |
  // [E, D, C, B, A, P, P, P, P, P]
  //       |
  //  dynamic slice
  //       |
  // [B, A, P, P, P]
  auto reverse_dims = reverse->dimensions();
  const Shape& reverse_shape = reverse->shape();
  std::set<int64_t> dynamic_reverse_dims;
  for (int64_t reverse_dim : reverse_dims) {
    HloInstruction* dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(reverse, {}, reverse_dim);
    if (dynamic_size == nullptr) {
      // Reverse dimension is not dynamic -- no rewrite needed.
      continue;
    }
    dynamic_reverse_dims.insert(reverse_dim);
  }

  if (dynamic_reverse_dims.empty()) {
    // We only need to rewrite dynamic dimensions that are also reverse
    // dimensions.
    return false;
  }

  PaddingConfig padding;
  // Doubles dynamic dimension size using a pad.
  Shape pad_shape = reverse_shape;
  for (int i = 0; i < reverse_shape.rank(); ++i) {
    auto dimension = padding.add_dimensions();
    if (dynamic_reverse_dims.count(i) > 0) {
      dimension->set_edge_padding_low(0);
      dimension->set_edge_padding_high(reverse_shape.dimensions(i));
      dimension->set_interior_padding(0);
      pad_shape.set_dimensions(i, 2 * pad_shape.dimensions(i));
    }
  }
  HloInstruction* cloned_reverse = reverse->AddInstruction(reverse->Clone());
  HloInstruction* zero = reverse->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(pad_shape.element_type())));
  HloInstruction* pad = reverse->AddInstruction(
      HloInstruction::CreatePad(pad_shape, cloned_reverse, zero, padding));
  std::vector<HloInstruction*> start_indices;
  start_indices.reserve(reverse_shape.rank());
  for (int i = 0; i < reverse_shape.rank(); ++i) {
    if (dynamic_reverse_dims.count(i) > 0) {
      // Start at bound_size - dynamic_size.
      HloInstruction* bound_size =
          reverse->AddInstruction(HloInstruction::CreateConstant(
              LiteralUtil::CreateR0<int32_t>(reverse_shape.dimensions(i))));
      HloInstruction* dynamic_size =
          dynamic_dimension_inference->GetDynamicSize(reverse, {}, i);
      HloInstruction* start_offset =
          reverse->AddInstruction(HloInstruction::CreateBinary(
              ShapeUtil::MakeScalarShape(S32), HloOpcode::kSubtract, bound_size,
              dynamic_size));
      start_indices.push_back(start_offset);
    } else {
      HloInstruction* zero = reverse->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
      start_indices.push_back(zero);
    }
  }
  HloInstruction* dynamic_reverse =
      reverse->AddInstruction(HloInstruction::CreateDynamicSlice(
          reverse_shape, pad, start_indices, reverse_shape.dimensions()));
  TF_RETURN_IF_ERROR(
      reverse->parent()->ReplaceInstruction(reverse, dynamic_reverse));
  TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
      reverse, dynamic_reverse, {}));
  return true;
}

HloInstruction* RewriteInputWithDynamicPadding(
    HloInstruction* conv, HloInstruction* input, HloInstruction* padding_value,
    absl::Span<HloInstruction*> padding_before, Window* input_window,
    std::function<int64_t(int64_t)> window_dim_to_shape_dim) {
  HloInstruction* zero_s32 = conv->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  // Padded shape represents the bounded shape after dynamic padding.
  Shape padded_shape = input->shape();
  PaddingConfig padding_configs;
  for (int64_t i = 0; i < input->shape().rank(); ++i) {
    PaddingConfig::PaddingConfigDimension padding_dim;
    *padding_configs.add_dimensions() = padding_dim;
  }
  std::vector<HloInstruction*> start_indices(input->shape().rank(), zero_s32);
  for (int64_t dim_index = 0; dim_index < input_window->dimensions_size();
       ++dim_index) {
    if (padding_before[dim_index] == nullptr) {
      continue;
    }
    int64_t shape_dim = window_dim_to_shape_dim(dim_index);

    WindowDimension* window_dim = input_window->mutable_dimensions(dim_index);
    auto* padding_dim = padding_configs.mutable_dimensions(shape_dim);
    const int64_t dilated_window_size = window_util::DilatedBound(
        window_dim->size(), window_dim->window_dilation());
    // Use dilated window size as low padding and static padding_high +
    // padding_low as high padding to make sure the following dynamic slice is
    // valid and doesn't go out of bound.
    //
    // See go/xla-dynamic-spatial-dim for more details.
    padding_dim->set_edge_padding_low(dilated_window_size);
    padding_dim->set_edge_padding_high(window_dim->padding_high() +
                                       window_dim->padding_low());
    padding_dim->set_interior_padding(window_dim->base_dilation() - 1);
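    // Note: for SAME padding, padding_before should never exceed
    // dilated_window_size - 1, so the slicing start computed below
    // (edge_padding_low - padding_before) stays non-negative.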
    HloInstruction* slicing_start =
        conv->AddInstruction(HloInstruction::CreateBinary(
            ShapeUtil::MakeScalarShape(S32), HloOpcode::kSubtract,
            conv->AddInstruction(
                HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
                    padding_dim->edge_padding_low()))),
            padding_before[dim_index]));
    start_indices[shape_dim] = slicing_start;

    padded_shape.mutable_dimensions()[shape_dim] =
        window_dim->padding_low() +
        window_util::DilatedBound(padded_shape.dimensions(shape_dim),
                                  window_dim->base_dilation()) +
        window_dim->padding_high();
    window_dim->clear_padding_high();
    window_dim->clear_padding_low();
    window_dim->set_base_dilation(1);
    input->mutable_shape()->set_dynamic_dimension(shape_dim, false);
  }
  // Reconstruct dynamic padding using pad and dynamic slice.

  HloInstruction* pad =
      MakePadHlo(input, padding_value, padding_configs).ValueOrDie();
  input = conv->AddInstruction(HloInstruction::CreateDynamicSlice(
      padded_shape, pad, start_indices, padded_shape.dimensions()));
  return input;
}

StatusOr<bool> RewriteDynamicConvolutionInputGrad(
    HloInstruction* custom_call_conv,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* grad = custom_call_conv->mutable_operand(1);
  HloInstruction* kernel = custom_call_conv->mutable_operand(2);
  TF_RET_CHECK(kernel->shape().is_static());
  auto dnums = custom_call_conv->convolution_dimension_numbers();
  Window window = custom_call_conv->window();
  HloInstruction* zero =
      custom_call_conv->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(custom_call_conv->shape().element_type())));
  std::vector<HloInstruction*> padding_before(
      dnums.input_spatial_dimensions_size(), nullptr);
  for (int64_t spatial_dim_index = 0;
       spatial_dim_index < dnums.input_spatial_dimensions_size();
       ++spatial_dim_index) {
    int64_t input_spatial_dim =
        dnums.input_spatial_dimensions(spatial_dim_index);
    HloInstruction* operand_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(
            custom_call_conv->mutable_operand(1), {}, input_spatial_dim);
    if (operand_dynamic_size == nullptr) {
      continue;
    }
    grad = PadWithScalar(grad, input_spatial_dim, operand_dynamic_size, zero);
    HloInstruction* slice =
        custom_call_conv->AddInstruction(HloInstruction::CreateSlice(
            ShapeUtil::MakeShape(S32, {1}),
            custom_call_conv->mutable_operand(0), {input_spatial_dim},
            {input_spatial_dim + 1}, {1}));
    HloInstruction* dynamic_input_size = custom_call_conv->AddInstruction(
        HloInstruction::CreateReshape(ShapeUtil::MakeScalarShape(S32), slice));
    const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
    // The window stride of the forward prop is the same as the base dilation
    // of the backward prop.
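    // (Computing the input gradient dilates the incoming gradient by the
    // forward stride, which is why the stride reappears as base dilation
    // here.)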
    DynamicWindowDims dynamic_window_dims = GetWindowedInputGradSize(
        dynamic_input_size, /*window_size=*/window_dim.size(),
        /*window_dilation=*/window_dim.window_dilation(),
        /*window_stride=*/window_dim.base_dilation(),
        custom_call_conv->padding_type());
    padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
  }

  if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
    grad = RewriteInputWithDynamicPadding(
        custom_call_conv, grad, zero, absl::MakeSpan(padding_before), &window,
        [&](int64_t dim) { return dnums.input_spatial_dimensions(dim); });
  }

  PrecisionConfig precision_config;
  if (custom_call_conv->precision_config().operand_precision_size() == 3) {
    // We are not interested in the precision config of the first operand,
    // which is the input_sizes.
    *precision_config.mutable_operand_precision() = {
        custom_call_conv->precision_config().operand_precision().begin() + 1,
        custom_call_conv->precision_config().operand_precision().end()};
  }
  HloInstruction* static_conv =
      custom_call_conv->AddInstruction(HloInstruction::CreateConvolve(
          custom_call_conv->shape(), grad, kernel,
          custom_call_conv->feature_group_count(),
          custom_call_conv->batch_group_count(), window,
          custom_call_conv->convolution_dimension_numbers(),
          precision_config));
  TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
  TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
      custom_call_conv, static_conv, {}));
  return true;
}

StatusOr<bool> RewriteDynamicConvolutionForward(
    HloInstruction* custom_call_conv,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* input = custom_call_conv->mutable_operand(0);
  HloInstruction* kernel = custom_call_conv->mutable_operand(1);
  TF_RET_CHECK(kernel->shape().is_static());
  TF_RET_CHECK(input->shape().is_dynamic());
  Window window = custom_call_conv->window();
  auto dnums = custom_call_conv->convolution_dimension_numbers();
  HloInstruction* zero =
      custom_call_conv->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(custom_call_conv->shape().element_type())));
  std::vector<HloInstruction*> padding_before(
      dnums.input_spatial_dimensions_size(), nullptr);
  for (int64_t spatial_dim_index = 0;
       spatial_dim_index < dnums.input_spatial_dimensions_size();
       ++spatial_dim_index) {
    int64_t input_spatial_dim =
        dnums.input_spatial_dimensions(spatial_dim_index);
    HloInstruction* operand_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(
            custom_call_conv->mutable_operand(0), {}, input_spatial_dim);
    if (operand_dynamic_size == nullptr) {
      continue;
    }

    input = PadWithScalar(input, input_spatial_dim, operand_dynamic_size, zero);
    const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
    DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
        operand_dynamic_size, window_dim.size(), window_dim.window_dilation(),
        window_dim.stride(), custom_call_conv->padding_type());
    padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
  }
  // The input feature dim can be dynamic too; pad its padded area with zeros.
  const int64_t input_feature_dim = dnums.input_feature_dimension();
  if (HloInstruction* input_feature_dynamic_size =
          dynamic_dimension_inference->GetDynamicSize(
              custom_call_conv->mutable_operand(0), {}, input_feature_dim)) {
    input = PadWithScalar(input, input_feature_dim, input_feature_dynamic_size,
                          zero);
  }

  if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
    input = RewriteInputWithDynamicPadding(
        custom_call_conv, input, zero, absl::MakeSpan(padding_before), &window,
        [&](int64_t dim) { return dnums.input_spatial_dimensions(dim); });
  }

  HloInstruction* static_conv =
      custom_call_conv->AddInstruction(HloInstruction::CreateConvolve(
          custom_call_conv->shape(), input, kernel,
          custom_call_conv->feature_group_count(),
          custom_call_conv->batch_group_count(), window,
          custom_call_conv->convolution_dimension_numbers(),
          custom_call_conv->precision_config()));
  TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
  TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
      custom_call_conv, static_conv, {}));
  return true;
}

StatusOr<bool> RewriteDynamicConvolutionKernelGrad(
    HloInstruction* custom_call_conv,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* activations = custom_call_conv->mutable_operand(0);
  HloInstruction* gradients = custom_call_conv->mutable_operand(1);
  TF_RET_CHECK(activations->shape().is_dynamic());
  TF_RET_CHECK(gradients->shape().is_dynamic());
  Window window = custom_call_conv->window();
  auto dnums = custom_call_conv->convolution_dimension_numbers();
  HloInstruction* zero =
      custom_call_conv->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(custom_call_conv->shape().element_type())));
  std::vector<HloInstruction*> padding_before(
      dnums.input_spatial_dimensions_size(), nullptr);
  for (int64_t spatial_dim_index = 0;
       spatial_dim_index < dnums.input_spatial_dimensions_size();
       ++spatial_dim_index) {
    int64_t input_spatial_dim =
        dnums.input_spatial_dimensions(spatial_dim_index);
    int64_t kernel_spatial_dim =
        dnums.kernel_spatial_dimensions(spatial_dim_index);
    HloInstruction* activations_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(
            custom_call_conv->mutable_operand(0), {}, input_spatial_dim);
    if (activations_dynamic_size != nullptr) {
      activations = PadWithScalar(activations, input_spatial_dim,
                                  activations_dynamic_size, zero);
    }

    HloInstruction* gradients_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(
            custom_call_conv->mutable_operand(1), {}, kernel_spatial_dim);
    if (gradients_dynamic_size != nullptr) {
      gradients = PadWithScalar(gradients, kernel_spatial_dim,
                                gradients_dynamic_size, zero);
    }
    if (activations_dynamic_size == nullptr ||
        gradients_dynamic_size == nullptr) {
      TF_RET_CHECK(activations_dynamic_size == nullptr &&
                   gradients_dynamic_size == nullptr);
      continue;
    }
    int64_t output_spatial_dim =
        dnums.output_spatial_dimensions(spatial_dim_index);
    const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
    DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
        activations_dynamic_size, /*window_size=*/
        custom_call_conv->shape().dimensions(output_spatial_dim),
        /*window_dilation=*/window_dim.stride(),
        /*window_stride=*/window_dim.window_dilation(),
        custom_call_conv->padding_type());
    padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
  }

  // We only need to pad the input feature on the lhs with 0 -- it's
  // mathematically equivalent to padding both lhs and rhs with 0.
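  // (Each product in the contracted feature dimension contains the lhs factor,
  // so zeroing the lhs padding already zeroes every contribution from the
  // padded area.)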
  const int64_t input_feature_dim = dnums.input_feature_dimension();
  if (HloInstruction* input_feature_dynamic_size =
          dynamic_dimension_inference->GetDynamicSize(
              custom_call_conv->mutable_operand(0), {}, input_feature_dim)) {
    activations = PadWithScalar(activations, input_feature_dim,
                                input_feature_dynamic_size, zero);
  }

  if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
    activations = RewriteInputWithDynamicPadding(
        custom_call_conv, activations, zero, absl::MakeSpan(padding_before),
        &window,
        [&](int64_t dim) { return dnums.input_spatial_dimensions(dim); });
  }

  HloInstruction* static_conv =
      custom_call_conv->AddInstruction(HloInstruction::CreateConvolve(
          custom_call_conv->shape(), activations, gradients,
          custom_call_conv->feature_group_count(),
          custom_call_conv->batch_group_count(), window,
          custom_call_conv->convolution_dimension_numbers(),
          custom_call_conv->precision_config()));
  TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
  TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
      custom_call_conv, static_conv, {}));
  return true;
}

StatusOr<bool> RewriteDynamicReduceWindowSamePadding(
    HloInstruction* hlo,
    DynamicDimensionInference* dynamic_dimension_inference) {
  if (hlo->shape().IsTuple()) {
    // TODO(b/73062247): variadic reduce window is not yet supported here.
    return Unimplemented("DynamicReduceWindowSamePadding not yet supported.");
  }
  HloInstruction* input = hlo->mutable_operand(0);
  HloInstruction* init = hlo->mutable_operand(1);
  int64_t rank = hlo->shape().rank();
  Window window = hlo->window();
  std::vector<HloInstruction*> padding_before(hlo->shape().rank(), nullptr);
  for (int64_t dim_index = 0; dim_index < rank; ++dim_index) {
    HloInstruction* operand_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(hlo->mutable_operand(0), {},
                                                    dim_index);
    if (operand_dynamic_size == nullptr) {
      continue;
    }
    const WindowDimension& window_dim = window.dimensions(dim_index);
    if (window_util::IsTrivialWindowDimension(window_dim)) {
      continue;
    }
    input = PadWithScalar(input, dim_index, operand_dynamic_size, init);

    DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
        operand_dynamic_size, window_dim.size(), window_dim.window_dilation(),
        window_dim.stride(), PaddingType::PADDING_SAME);
    padding_before[dim_index] = dynamic_window_dims.padding_before;
  }

  input = RewriteInputWithDynamicPadding(
      hlo, input, init, absl::MakeSpan(padding_before), &window,
      [](int64_t dim) { return dim; });

  HloInstruction* rewritten =
      hlo->AddInstruction(HloInstruction::CreateReduceWindow(
          hlo->shape(), input, init, window, hlo->called_computations()[0]));
  TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rewritten));
  TF_RETURN_IF_ERROR(
      dynamic_dimension_inference->ForwardDynamicSize(hlo, rewritten, {}));
  return true;
}

StatusOr<bool> RewriteDynamicSelectAndScatterSamePadding(
    HloInstruction* hlo,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* input = hlo->mutable_operand(0);
  HloInstruction* source = hlo->mutable_operand(1);
  HloInstruction* init = hlo->mutable_operand(2);
  TF_ASSIGN_OR_RETURN(HloInstruction * input_padding_value,
                      ChooseIdentityValue(hlo, /*operand_number=*/0));
  int64_t rank = hlo->shape().rank();
  Window window = hlo->window();
  std::vector<HloInstruction*> padding_before(hlo->shape().rank(), nullptr);
  for (int64_t dim_index = 0; dim_index < rank; ++dim_index) {
    const WindowDimension& window_dim = window.dimensions(dim_index);
    if (window_util::IsTrivialWindowDimension(window_dim)) {
      continue;
    }
    HloInstruction* operand_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(hlo->mutable_operand(0), {},
                                                    dim_index);
    if (operand_dynamic_size == nullptr) {
      continue;
    }

    input = PadWithScalar(input, dim_index, operand_dynamic_size,
                          input_padding_value);

    HloInstruction* source_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(hlo->mutable_operand(1), {},
                                                    dim_index);
    if (source_dynamic_size == nullptr) {
      continue;
    }
    source = PadWithScalar(source, dim_index, source_dynamic_size, init);

    DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
        operand_dynamic_size, window_dim.size(), window_dim.window_dilation(),
        window_dim.stride(), PaddingType::PADDING_SAME);
    padding_before[dim_index] = dynamic_window_dims.padding_before;
  }

  input = RewriteInputWithDynamicPadding(
      hlo, input, input_padding_value, absl::MakeSpan(padding_before), &window,
      [](int64_t dim) { return dim; });

  // RewriteInputWithDynamicPadding adds padding to the input. However, that
  // padding should not be materialized in select and scatter's output, so we
  // need to slice it out using dynamic slice. To prevent the dynamic slice
  // from going OOB, we first add some high padding to the output to leave
  // enough space.
  HloInstruction* rewritten =
      hlo->AddInstruction(HloInstruction::CreateSelectAndScatter(
          input->shape(), input, hlo->called_computations()[0], window, source,
          init, hlo->called_computations()[1]));
  std::vector<HloInstruction*> start_indices(
      input->shape().rank(), hlo->AddInstruction(HloInstruction::CreateConstant(
                                 LiteralUtil::Zero(S32))));
  PaddingConfig padding_configs;
  for (int64_t dim_index = 0; dim_index < rank; ++dim_index) {
    PaddingConfig::PaddingConfigDimension padding_dim;
    if (padding_before[dim_index] != nullptr) {
      const WindowDimension& window_dim = window.dimensions(dim_index);
      const int64_t dilated_window_size = window_util::DilatedBound(
          window_dim.size(), window_dim.window_dilation());
      padding_dim.set_edge_padding_high(dilated_window_size);
      start_indices[dim_index] = padding_before[dim_index];
    }
    *padding_configs.add_dimensions() = padding_dim;
  }
  HloInstruction* padded =
      MakePadHlo(rewritten, init, padding_configs).ValueOrDie();
  rewritten = hlo->AddInstruction(HloInstruction::CreateDynamicSlice(
      hlo->shape(), padded, start_indices, hlo->shape().dimensions()));
  TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rewritten));
  TF_RETURN_IF_ERROR(
      dynamic_dimension_inference->ForwardDynamicSize(hlo, rewritten, {}));
  return true;
}

StatusOr<bool> RewriteDynamicConcat(
    HloInstruction* concat,
    DynamicDimensionInference* dynamic_dimension_inference) {
  const int64_t concat_dim = concat->concatenate_dimension();
  if (dynamic_dimension_inference->GetDynamicSize(concat, {}, concat_dim) ==
      nullptr) {
    // Concat dimension is not dynamic -- no rewrite needed.
    return false;
  }
  std::vector<HloInstruction*> offsets;
  for (int64_t i = 0; i < concat->shape().dimensions_size(); ++i) {
    offsets.push_back(concat->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0))));
  }
  HloInstruction* rewritten_concat = concat;
  // Keep track of previous users before the rewrite so that we can update
  // their operands later.
  auto prev_users = concat->users();
  for (int64_t i = 0; i < concat->operand_count(); ++i) {
    // Rewrite the concat by dynamic-update-slicing each operand into the
    // concat dimension.
    HloInstruction* operand = concat->mutable_operand(i);
    rewritten_concat =
        concat->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
            rewritten_concat->shape(), rewritten_concat, operand, offsets));
    // Advance the offset along the concat dimension by the size of the
    // operand's concat dimension (dynamic size if present, static otherwise).
    HloInstruction* dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(operand, {}, concat_dim);
    if (dynamic_size == nullptr) {
      HloInstruction* static_size = concat->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
              operand->shape().dimensions(concat_dim))));
      offsets[concat_dim] = concat->AddInstruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeScalarShape(S32), HloOpcode::kAdd, offsets[concat_dim],
          static_size));
    } else {
      offsets[concat_dim] = concat->AddInstruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeScalarShape(S32), HloOpcode::kAdd, offsets[concat_dim],
          dynamic_size));
    }
  }
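  // For illustration (a sketch, not the exact HLO produced): a two-way concat
  // along a dynamic dimension becomes, roughly,
  //   buf0 = dynamic-update-slice(concat, op0, /*offset=*/0, ...)
  //   buf1 = dynamic-update-slice(buf0, op1, /*offset=*/size(op0), ...)
  // where size(op0) is op0's (possibly dynamic) concat-dimension size.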
  TF_RETURN_IF_ERROR(concat->ReplaceUsesWith(prev_users, rewritten_concat));
  TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
      concat, rewritten_concat, {}));
  return true;
}

StatusOr<bool> RewriteDynamicSort(
    HloInstruction* hlo,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* dynamic_size = nullptr;
  HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
  int64_t sort_dim = sort->sort_dimension();
  // Find the dynamic dimension in the operand.
  for (auto* operand : sort->operands()) {
    if (dynamic_size == nullptr) {
      dynamic_size =
          dynamic_dimension_inference->GetDynamicSize(operand, {}, sort_dim);
    }
  }

  if (dynamic_size == nullptr) {
    // Not a dynamic sort, ignore.
    return false;
  }

  Shape operand_shape =
      ShapeUtil::ChangeElementType(sort->operand(0)->shape(), S32);
  HloInstruction* iota =
      hlo->AddInstruction(HloInstruction::CreateIota(operand_shape, sort_dim));
  HloInstruction* dynamic_size_broadcasted = hlo->AddInstruction(
      HloInstruction::CreateBroadcast(operand_shape, dynamic_size, {}));
  HloInstruction* lt = hlo->AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::ChangeElementType(operand_shape, PRED), iota,
      dynamic_size_broadcasted, ComparisonDirection::kLt));
  sort->AppendOperand(lt);

  const int64_t param_number_before_rewritten =
      sort->called_computations()[0]->num_parameters();
  auto new_param_0 = HloInstruction::CreateParameter(
      param_number_before_rewritten, ShapeUtil::MakeScalarShape(PRED),
      "inbound_lhs");
  auto new_param_1 = HloInstruction::CreateParameter(
      param_number_before_rewritten + 1, ShapeUtil::MakeScalarShape(PRED),
      "inbound_rhs");
  std::vector<const HloInstruction*> extra_parameters{new_param_0.get(),
                                                      new_param_1.get()};
  HloComputation* sort_comp = sort->parent()->parent()->AddEmbeddedComputation(
      sort->called_computations()[0]->CloneWithReplacements(
          /*replacements=*/nullptr, extra_parameters));
  auto inbound_lhs =
      sort_comp->parameter_instruction(param_number_before_rewritten);
  auto inbound_rhs =
      sort_comp->parameter_instruction(param_number_before_rewritten + 1);
  sort->ReplaceCalledComputations(
      [&](HloComputation* comp) { return sort_comp; });
  // new_root = inbound_lhs & (old_root | !inbound_rhs)
  // Select the lhs if it is in bounds and the rhs is out of bounds, or if the
  // original comparator returns true.
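  // As a quick sanity check of the formula above (derived from it, not extra
  // behavior):
  //   lhs in, rhs out  -> true   (in-bounds elements sort before padding)
  //   lhs out, any rhs -> false  (padding never sorts before anything)
  //   lhs in, rhs in   -> whatever the original comparator returns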
  auto out_of_bound_rhs = sort_comp->AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeScalarShape(PRED), HloOpcode::kNot, inbound_rhs));
  auto sort_comp_or_out_of_bound_rhs =
      sort_comp->AddInstruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeScalarShape(PRED), HloOpcode::kOr,
          sort_comp->root_instruction(), out_of_bound_rhs));

  auto new_root = sort_comp->AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeScalarShape(PRED), HloOpcode::kAnd, inbound_lhs,
      sort_comp_or_out_of_bound_rhs));
  sort_comp->set_root_instruction(new_root);
  Shape compare_shape =
      ShapeUtil::ChangeElementType(sort->operand(0)->shape(), PRED);
  if (sort->shape().IsTuple()) {
    // For a sort that already returns a tuple, simply add another result to
    // the tuple.
    *sort->mutable_shape()->add_tuple_shapes() =
        ShapeUtil::ChangeElementType(operand_shape, PRED);
  } else {
    auto sort_users = sort->users();
    auto sort_clone = hlo->AddInstruction(sort->Clone());
    *sort_clone->mutable_shape() = ShapeUtil::MakeTupleShape(
        {sort->shape(), ShapeUtil::ChangeElementType(operand_shape, PRED)});
    auto rewritten_sort = hlo->AddInstruction(
        HloInstruction::CreateGetTupleElement(sort->shape(), sort_clone, 0));
    for (HloInstruction* user : sort_users) {
      TF_RETURN_IF_ERROR(sort->ReplaceUseWith(user, rewritten_sort));
    }
    TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
        sort, rewritten_sort, {}));
    if (hlo->parent()->root_instruction() == sort) {
      hlo->parent()->set_root_instruction(rewritten_sort);
    }
  }

  return true;
}

StatusOr<bool> RewriteDynamicBinaryOp(
    HloInstruction* binary,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* operand_0 = binary->mutable_operand(0);
  HloInstruction* operand_1 = binary->mutable_operand(1);

  TF_RET_CHECK(operand_0->shape().rank() == operand_1->shape().rank());
  auto dims_0 = dynamic_dimension_inference->GetDynamicSizes(operand_0, {});
  auto dims_1 = dynamic_dimension_inference->GetDynamicSizes(operand_1, {});
  bool changed = false;
  for (int64_t i = 0; i < dims_0.size(); ++i) {
    HloInstruction* dim_0 = dims_0[i];
    HloInstruction* dim_1 = dims_1[i];

    if (dim_0 != dim_1 && dim_0 != nullptr && dim_1 != nullptr) {
      changed = true;
      // It is possible that a dynamic dimension of one operand is size 1 while
      // the other is greater than one. According to implicit broadcast
      // semantics, we need to insert a broadcast in this case to make the
      // dynamic shapes match.

      // An implicit broadcast is inserted by slicing the small shape into a
      // size 1 slice, reshaping out the size 1 dimension, then broadcasting
      // to the full shape:
      //
      //  Input [2, <=5, 3]
      //   |
      //  Slice [2, 1, 3]
      //   |
      //  Reshape [2, 3]
      //   |
      //  Broadcast [2, 5, 3]
      auto rewrite_operand = [&](HloInstruction* pred,
                                 HloInstruction* operand) -> HloInstruction* {
        Shape static_shape = operand->shape();
        static_shape.clear_dynamic_dimensions();
        pred = binary->AddInstruction(HloInstruction::CreateBroadcast(
            ShapeUtil::ChangeElementType(static_shape, PRED), pred, {}));
        Shape slice_shape = static_shape;
        slice_shape.set_dimensions(i, 1);
        std::vector<int64_t> start_indices(slice_shape.rank(), 0);
        std::vector<int64_t> strides(slice_shape.rank(), 1);
        HloInstruction* slice = binary->AddInstruction(
            HloInstruction::CreateSlice(slice_shape, operand, start_indices,
                                        slice_shape.dimensions(), strides));
        Shape reshape_shape = ShapeUtil::DeleteDimension(i, slice_shape);
        HloInstruction* reshape = binary->AddInstruction(
            HloInstruction::CreateReshape(reshape_shape, slice));
        std::vector<int64_t> broadcast_dims;
        broadcast_dims.reserve(static_shape.rank() - 1);
        // Broadcast to all dims except for i.
        for (int64_t j = 0; j < static_shape.rank(); ++j) {
          if (j != i) {
            broadcast_dims.push_back(j);
          }
        }

        HloInstruction* broadcast = binary->parent()->AddInstruction(
            HloInstruction::CreateBroadcast(static_shape, reshape,
                                            broadcast_dims),
            "implicit_broadcast");

        // Use a select instead of a conditional, as elementwise operations
        // promote more fusion.
        HloInstruction* select =
            binary->AddInstruction(HloInstruction::CreateTernary(
                static_shape, HloOpcode::kSelect, pred, broadcast, operand));
        return select;
      };

      HloInstruction* one = binary->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::One(S32)));
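      // The predicates below implement, per operand (a restatement of the
      // logic that follows, not extra behavior):
      //   needs_broadcast = (my_size == 1) && (my_size < other_size)
      // i.e. implicit broadcast only fires when this operand's dynamic size
      // is 1 and the other operand's is larger.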
      auto operand_0_needs_broadcast = binary->parent()->AddInstruction(
          HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_0,
                                        dim_1, ComparisonDirection::kLt),
          "lhs_less_than_rhs");
      auto is_one = binary->parent()->AddInstruction(
          HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_0,
                                        one, ComparisonDirection::kEq),
          "lhs_is_one");
      operand_0_needs_broadcast = binary->parent()->AddInstruction(
          HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
                                       HloOpcode::kAnd, is_one,
                                       operand_0_needs_broadcast),
          "lhs_needs_implicit_broadcast");
      operand_0 = rewrite_operand(operand_0_needs_broadcast, operand_0);

      auto operand_1_needs_broadcast = binary->parent()->AddInstruction(
          HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_1,
                                        dim_0, ComparisonDirection::kLt),
          "rhs_less_than_lhs");
      is_one = binary->parent()->AddInstruction(
          HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_1,
                                        one, ComparisonDirection::kEq),
          "rhs_is_one");
      operand_1_needs_broadcast = binary->parent()->AddInstruction(
          HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
                                       HloOpcode::kAnd, is_one,
                                       operand_1_needs_broadcast),
          "rhs_needs_implicit_broadcast");
      operand_1 = rewrite_operand(operand_1_needs_broadcast, operand_1);
    }
  }
  if (changed) {
    TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(0, operand_0));
    TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(1, operand_1));
  }
  return changed;
}

StatusOr<bool> RewriteDynamicUpdateSlice(
    HloInstruction* hlo,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloDynamicUpdateSliceInstruction* dus =
      Cast<HloDynamicUpdateSliceInstruction>(hlo);
  // Suppose we have a base area that we want to update:
  // +------------------------+
  // |                        |
  // |                  base  |
  // |                        |
  // +------------------------+
  //
  // A partial update with dynamic padding looks like this:
  //
  //           +------+-------+
  //           |update|padding|
  //           +------+-------+
  //
  // We don't want the padding to overwrite the base area:
  //
  // +------------------------+
  // |         +------+-------+
  // |<-begin->|update|padding| (what we want to avoid)
  // |         +------+-------+
  // +------------------------+
  //
  // Instead we want to keep the base area untouched except for the update
  // region:
  //
  // +------------------------+
  // |         +------+       |
  // |<-begin->|update| base  | (what we want)
  // |         +------+       |
  // +------------------------+
  //
  // We do this by dynamic slicing the base area out first with the same begin
  // index:
  //
  //           +--------------+
  // <-begin-> |         base |
  //           +--------------+
  //
  // Then replace the update's padding part with base:
  //
  //           +------+-------+
  //           |update| base  |
  //           +------+-------+
  //
  // Then do the DUS.
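  //
  // In HLO-ish pseudocode, the rewrite below is roughly (illustrative):
  //   base_slice = dynamic-slice(base, begin...)      // same begin indices
  //   mask       = iota < broadcast(dynamic_size)     // true inside update
  //   update'    = select(mask, update, base_slice)   // padding := base
  //   result     = dynamic-update-slice(base, update', begin...)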

  HloInstruction* update = dus->mutable_operand(1);
  HloInstruction* base = dus->mutable_operand(0);
  std::vector<HloInstruction*> dynamic_dims_in_partial_update(
      update->shape().rank(), nullptr);
  bool needs_rewrite = false;
  for (int64_t i = 0; i < update->shape().rank(); ++i) {
    if (update->shape().dimensions(i) < base->shape().dimensions(i)) {
      HloInstruction* dynamic_dim =
          dynamic_dimension_inference->GetDynamicSize(update, {}, i);

      if (dynamic_dim != nullptr) {
        dynamic_dims_in_partial_update[i] = dynamic_dim;
        needs_rewrite = true;
      }
    }
  }

  if (!needs_rewrite) {
    return false;
  }
  std::vector<HloInstruction*> indices;
  indices.reserve(dus->operand_count() - 2);
  for (int64_t i = 2; i < dus->operand_count(); ++i) {
    indices.push_back(dus->mutable_operand(i));
  }
  HloInstruction* base_slice =
      dus->AddInstruction(HloInstruction::CreateDynamicSlice(
          update->shape(), base, indices, update->shape().dimensions()));

  for (int64_t i = 0; i < dynamic_dims_in_partial_update.size(); ++i) {
    HloInstruction* dynamic_dim = dynamic_dims_in_partial_update[i];
    if (dynamic_dim != nullptr) {
      Shape mask_shape_int = ShapeUtil::ChangeElementType(update->shape(), S32);
      Shape mask_shape_pred =
          ShapeUtil::ChangeElementType(update->shape(), PRED);
      // Generate the mask using iota and dynamic_dim.
      HloInstruction* iota =
          dus->AddInstruction(HloInstruction::CreateIota(mask_shape_int, i));
      HloInstruction* broadcast_dim = dus->AddInstruction(
          HloInstruction::CreateBroadcast(mask_shape_int, dynamic_dim, {}));
      HloInstruction* pred = dus->AddInstruction(HloInstruction::CreateCompare(
          mask_shape_pred, iota, broadcast_dim, ComparisonDirection::kLt));
      // Update `update` to include base.
      update = dus->AddInstruction(HloInstruction::CreateTernary(
          update->shape(), HloOpcode::kSelect, pred, update, base_slice));
    }
  }
  TF_RETURN_IF_ERROR(dus->ReplaceOperandWith(1, update));

  return true;
}

StatusOr<bool> RewriteDynamicReshape(
    HloInstruction* reshape,
    DynamicDimensionInference* dynamic_dimension_inference) {
  bool changed = false;
  HloInstruction* operand = reshape->mutable_operand(0);
  std::vector<HloInstruction*> input_dynamic_dims;
  for (int64_t dim = 0; dim < operand->shape().dimensions_size(); ++dim) {
    input_dynamic_dims.push_back(
        dynamic_dimension_inference->GetDynamicSize(operand, {}, dim));
  }

  std::vector<HloInstruction*> output_dynamic_dims;
  for (int64_t dim = 0; dim < reshape->shape().dimensions_size(); ++dim) {
    output_dynamic_dims.push_back(
        dynamic_dimension_inference->GetDynamicSize(reshape, {}, dim));
  }

  auto common_factors = CommonFactors(operand->shape().dimensions(),
                                      reshape->shape().dimensions());
  // Scan first to see if we need to decompose the reshape into a
  // flatten-unflatten pair.
  bool need_flatten_unflatten = false;
  auto is_dynamic_dimension = [&](int64_t dim) {
    HloInstruction* operand_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(reshape, {}, dim);
    return operand_dynamic_size != nullptr ||
           reshape->shape().is_dynamic_dimension(dim);
  };

  auto should_skip_common_factor_group = [&](DimensionVector input_dims,
                                             DimensionVector output_dims) {
    if (input_dims.empty() || output_dims.empty()) {
      return true;
    }
    if (absl::c_none_of(output_dims, is_dynamic_dimension)) {
      // Don't need to rewrite any group without dynamic dimensions.
      VLOG(2) << "All dimensions are static in this common factor group";
      return true;
    }
    if (input_dims.size() == 1 && output_dims.size() == 1) {
      // The dimension is unchanged. No rewrite needed.
      return true;
    }
    return false;
  };

  for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
    auto start = common_factors[i];
    auto end = common_factors[i + 1];
    DimensionVector input_dims;
    DimensionVector output_dims;
    for (int64_t dim = start.first; dim < end.first; ++dim) {
      input_dims.push_back(dim);
    }
    for (int64_t dim = start.second; dim < end.second; ++dim) {
      output_dims.push_back(dim);
    }
    if (should_skip_common_factor_group(input_dims, output_dims)) {
      continue;
    }
    if (input_dims.size() > 1 && output_dims.size() > 1) {
      need_flatten_unflatten = true;
      break;
    }
  }
  if (need_flatten_unflatten) {
    VLOG(2) << "Rewrite dynamic reshape to flatten-unflatten pair. "
            << reshape->ToString();
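    // For example (illustrative): reshape f32[2,3,5] -> f32[3,10], with a
    // dynamic dimension somewhere in the group, has one common factor group
    // spanning multiple input and output dimensions, so it is decomposed into
    // reshape f32[2,3,5] -> f32[30] (flatten) followed by
    // reshape f32[30] -> f32[3,10] (unflatten).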
    int64_t num_elements = ShapeUtil::ElementsIn(operand->shape());
    Shape flattened_shape =
        ShapeUtil::MakeShape(operand->shape().element_type(), {num_elements});
    HloInstruction* flatten = operand->AddInstruction(
        HloInstruction::CreateReshape(flattened_shape, operand));

    HloInstruction* dynamic_size =
        operand->AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::CreateR0<int32_t>(num_elements)));
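    // The flattened dynamic size is the full element count rescaled for each
    // dynamic dimension: size = num_elements / static_dim * dynamic_dim.
    // For instance (illustrative), an operand f32[2,<=5] whose dynamic
    // dimension is actually 3 gives 10 / 5 * 3 = 6 elements.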
    for (int64_t i = 0; i < operand->shape().rank(); i++) {
      HloInstruction* dynamic_dim_size =
          dynamic_dimension_inference->GetDynamicSize(operand, {}, i);
      if (dynamic_dim_size != nullptr) {
        HloInstruction* static_dim_size = operand->AddInstruction(
            HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
                operand->shape().dimensions(i))));
        dynamic_size = operand->AddInstruction(HloInstruction::CreateBinary(
            dynamic_size->shape(), HloOpcode::kDivide, dynamic_size,
            static_dim_size));
        dynamic_size = operand->AddInstruction(HloInstruction::CreateBinary(
            dynamic_size->shape(), HloOpcode::kMultiply, dynamic_size,
            dynamic_dim_size));
      }
    }
    dynamic_dimension_inference->SetDynamicSize(flatten, {}, 0, dynamic_size);

    HloInstruction* unflatten = reshape->AddInstruction(
        HloInstruction::CreateReshape(reshape->shape(), flatten));
    TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
        reshape, unflatten, {}));

    TF_ASSIGN_OR_RETURN(
        bool changed_unused,
        RewriteDynamicReshape(flatten, dynamic_dimension_inference));
    TF_ASSIGN_OR_RETURN(
        changed_unused,
        RewriteDynamicReshape(unflatten, dynamic_dimension_inference));
    TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(unflatten));

    return true;
  }

  // Find the common_factors group that each input dimension belongs to.
  for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
    auto start = common_factors[i];
    auto end = common_factors[i + 1];
    DimensionVector input_dims;
    DimensionVector output_dims;
    for (int64_t dim = start.first; dim < end.first; ++dim) {
      input_dims.push_back(dim);
    }
    for (int64_t dim = start.second; dim < end.second; ++dim) {
      output_dims.push_back(dim);
    }

    VLOG(2) << "input_dims: " << VectorString(input_dims);
    VLOG(2) << "output_dims: " << VectorString(output_dims);

    if (should_skip_common_factor_group(input_dims, output_dims)) {
      continue;
    }
    if (input_dims.size() > 1 && output_dims.size() > 1) {
      return InternalError(
          "Should be handled by decomposing reshape into "
          "flatten-unflatten pair. %s",
          reshape->ToString());
    }

    TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicReshapeSingleGroup(
                                    reshape, input_dims, output_dims,
                                    absl::MakeSpan(input_dynamic_dims),
                                    absl::MakeSpan(output_dynamic_dims),
                                    dynamic_dimension_inference));
    changed |= c;
  }

  if (reshape->opcode() == HloOpcode::kDynamicReshape) {
    auto* static_reshape =
        reshape->AddInstruction(HloInstruction::CreateReshape(
            reshape->shape(), reshape->mutable_operand(0)));
    TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(static_reshape));
    TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
        reshape, static_reshape, {}));
    changed = true;
  }

  return changed;
}
// Inserts a pad-to-static after `inst` if `inst` has dynamic dimensions in
// it. Recurses into tuple instructions.
StatusOr<HloInstruction*> InsertPadToStaticOnInstruction(HloInstruction* inst) {
  if (inst->shape().is_static()) {
    return inst;
  }
  if (!inst->shape().IsTuple()) {
    // The output shape of pad-to-static is a tuple. The 0th element is the
    // data output, which is the same as the input shape but without dynamic
    // dimensions; the i-th element is the dynamic size of the (i-1)-th input
    // dimension.
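    // For example (illustrative): PadToStatic on f32[<=10,5] produces the
    // tuple (f32[10,5], s32[], s32[]), where the two s32[] scalars carry the
    // runtime sizes of dimensions 0 and 1.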
    Shape data_output_shape = inst->shape();  // 0th element.
    data_output_shape.clear_dynamic_dimensions();
    Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
    for (int64_t i = 0; i < inst->shape().rank(); ++i) {
      ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
                                    &output_shape);
    }
    HloInstruction* pad_to_static =
        inst->AddInstruction(HloInstruction::CreateCustomCall(
            output_shape, {inst}, "PadToStatic", ""));
    HloInstruction* data_output =
        inst->AddInstruction(HloInstruction::CreateGetTupleElement(
            data_output_shape, pad_to_static, 0));
    return data_output;
  }

  TF_RET_CHECK(inst->shape().IsTuple());
  std::vector<HloInstruction*> static_tuple_elements;
  for (int64_t i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
    // For each tuple element, if it is static, pass it through. If it is
    // dynamic, recursively call this function again.
    HloInstruction* gte =
        inst->AddInstruction(HloInstruction::CreateGetTupleElement(
            inst->shape().tuple_shapes(i), inst, i));

    if (gte->shape().is_static()) {
      static_tuple_elements.push_back(gte);
    } else {
      TF_ASSIGN_OR_RETURN(HloInstruction * static_gte,
                          InsertPadToStaticOnInstruction(gte));
      static_tuple_elements.push_back(static_gte);
    }
  }

  return inst->AddInstruction(
      HloInstruction::CreateTuple(static_tuple_elements));
}

// Inserts PadToStatic for parameters and custom-calls which "materialize"
// dynamic outputs given only static inputs.
Status InsertPadToStaticAfterModuleInputs(HloModule* module) {
  std::vector<HloInstruction*> params;
  HloComputation* entry = module->entry_computation();
  for (HloComputation* comp : module->MakeNonfusionComputationsSorted()) {
    for (HloInstruction* instr : comp->instructions()) {
      if (!instr->shape().is_static() &&
          ((instr->opcode() == HloOpcode::kParameter && comp == entry) ||
           instr->opcode() == HloOpcode::kCustomCall) &&
          absl::c_all_of(instr->operands(), [&](HloInstruction* operand) {
            return operand->shape().is_static();
          })) {
        VLOG(1) << "Inserting PadToStatic for instruction: "
                << instr->ToString();
        auto users = instr->users();
        TF_ASSIGN_OR_RETURN(HloInstruction * instr_static,
                            InsertPadToStaticOnInstruction(instr));
        for (auto* user : users) {
          TF_RETURN_IF_ERROR(instr->ReplaceUseWith(user, instr_static));
        }
        if (instr == entry->root_instruction()) {
          module->entry_computation()->set_root_instruction(instr_static);
        }
      }
    }
  }
  return OkStatus();
}

// Removes all dynamic shapes between pad-to-static and slice-to-dynamic.
//
// After this visitor runs, the entry computation looks like:
//  Param(dynamic)
//   |
//   GTE (dynamic)
//   |
//  PadToStatic(static)
//   |
//  .... regular computation with static shapes.
//   |
//  SliceToDynamic(dynamic)
//   |
// ROOT tuple (dynamic)
class DynamicShapeRemovingVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit DynamicShapeRemovingVisitor(
      const DynamicPadderOptions::OpSupportsDynamismHandler&
          op_supports_dynamism_handler,
      DynamicDimensionInference* dynamic_dimension_inference)
      : op_supports_dynamism_handler_(op_supports_dynamism_handler),
        dynamic_dimension_inference_(dynamic_dimension_inference) {}

  Status DefaultAction(HloInstruction* hlo) override;

  Status HandleCustomCall(HloInstruction* hlo) override;

  Status HandleTuple(HloInstruction* hlo) override;
  Status HandleGetTupleElement(HloInstruction* hlo) override;

  Status HandleParameter(HloInstruction* hlo) override;
  static Status Run(HloComputation* computation,
                    const DynamicPadderOptions::OpSupportsDynamismHandler&
                        op_supports_dynamism_handler,
                    DynamicDimensionInference* dynamic_shape_inference,
                    bool require_dynamic_output) {
    DynamicShapeRemovingVisitor visitor(op_supports_dynamism_handler,
                                        dynamic_shape_inference);
    TF_RETURN_IF_ERROR(computation->Accept(&visitor));
    // If the output is required to be in dynamic form, insert a
    // static-to-dynamic conversion at the root.
    if (require_dynamic_output) {
      HloInstruction* root = computation->root_instruction();
      if (dynamic_shape_inference->HasDynamicDimension(root)) {
        TF_ASSIGN_OR_RETURN(HloInstruction * new_root,
                            visitor.ConvertToDynamic(root));
        computation->set_root_instruction(new_root);
      }
    }
    return OkStatus();
  }

 private:
  // If the tensor produced by `inst` is in dynamic form, converts it to
  // static form and returns the new instruction.
  StatusOr<HloInstruction*> ConvertToStatic(HloInstruction* inst);

  // If the tensor produced by `inst` is in static form, converts it to
  // dynamic form and returns the new instruction.
  StatusOr<HloInstruction*> ConvertToDynamic(HloInstruction* inst);

  const DynamicPadderOptions::OpSupportsDynamismHandler&
      op_supports_dynamism_handler_;

  DynamicDimensionInference* dynamic_dimension_inference_;
};
StatusOr<HloInstruction*> DynamicShapeRemovingVisitor::ConvertToDynamic(
    HloInstruction* inst) {
  const Shape& shape = inst->shape();
  if (shape.IsTuple()) {
    std::vector<HloInstruction*> dynamic_operands;
    for (int64_t i = 0; i < shape.tuple_shapes_size(); ++i) {
      auto gte = inst->AddInstruction(HloInstruction::CreateGetTupleElement(
          shape.tuple_shapes(i), inst, i));
      if (dynamic_dimension_inference_->HasDynamicDimension(inst, {i})) {
        TF_RETURN_IF_ERROR(dynamic_dimension_inference_->Update(gte));
        TF_ASSIGN_OR_RETURN(auto dynamic, ConvertToDynamic(gte));
        dynamic_operands.push_back(dynamic);
      } else {
        dynamic_operands.push_back(gte);
      }
    }
    return inst->AddInstruction(HloInstruction::CreateTuple(dynamic_operands));
  } else {
    // Collect the data input as well as the dimension sizes, and feed them to
    // a slice-to-dynamic custom call to create a dynamic tensor.
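    // For example (illustrative): with data f32[10,5] and a runtime size
    // s32[] for dimension 0, SliceToDynamic(data, size0, size1) yields
    // f32[<=10,5].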
    Shape output_shape = shape;  // 0th element.
    CHECK(output_shape.is_static());
    std::vector<HloInstruction*> slice_operand;
    slice_operand.push_back(inst);
    for (int64_t i = 0; i < output_shape.dimensions_size(); ++i) {
      auto dimension_size =
          dynamic_dimension_inference_->GetDynamicSize(inst, {}, i);
      if (dimension_size == nullptr) {
        dimension_size = inst->AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::CreateR0<int32_t>(output_shape.dimensions(i))));
      } else {
        output_shape.set_dynamic_dimension(i, true);
      }
      slice_operand.push_back(dimension_size);
    }
    return inst->AddInstruction(HloInstruction::CreateCustomCall(
        output_shape, slice_operand, "SliceToDynamic"));
  }
}

StatusOr<HloInstruction*> DynamicShapeRemovingVisitor::ConvertToStatic(
    HloInstruction* inst) {
  const Shape& shape = inst->shape();
  CHECK(shape.is_dynamic());
  if (shape.IsTuple()) {
    std::vector<HloInstruction*> static_operands;
    for (int64_t i = 0; i < shape.tuple_shapes_size(); ++i) {
      auto gte = inst->AddInstruction(HloInstruction::CreateGetTupleElement(
          shape.tuple_shapes(i), inst, i));
      TF_RETURN_IF_ERROR(dynamic_dimension_inference_->Update(gte));
      if (shape.tuple_shapes(i).is_dynamic()) {
        TF_ASSIGN_OR_RETURN(auto static_inst, ConvertToStatic(gte));
        static_operands.push_back(static_inst);
      } else {
        static_operands.push_back(gte);
      }
    }
    return inst->AddInstruction(HloInstruction::CreateTuple(static_operands));
  } else {
    // The output shape of pad-to-static is a tuple. The 0th element is the
    // data output, which is the same as the input shape but without dynamic
    // dimensions; the i-th element is the dynamic size of the (i-1)-th input
    // dimension.
    Shape data_output_shape = shape;  // 0th element.
    data_output_shape.clear_dynamic_dimensions();
    Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
    for (int64_t i = 0; i < shape.rank(); ++i) {
      ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
                                    &output_shape);
    }
    HloInstruction* pad_to_static =
        inst->AddInstruction(HloInstruction::CreateCustomCall(
            output_shape, {inst}, "PadToStatic", ""));
    HloInstruction* data_output =
        inst->AddInstruction(HloInstruction::CreateGetTupleElement(
            data_output_shape, pad_to_static, 0));
    return data_output;
  }
}

Status DynamicShapeRemovingVisitor::DefaultAction(HloInstruction* hlo) {
  const bool input_is_dynamic = absl::c_any_of(
      hlo->operands(),
      [](const HloInstruction* hlo) { return hlo->shape().is_dynamic(); });

  // By default, ops don't support dynamic lowering.
  OpDynamismSupport op_support = OpDynamismSupport::kNoSupport;
  if (op_supports_dynamism_handler_) {
    op_support = op_supports_dynamism_handler_(hlo);
  }
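  // In summary, the cases handled below (a restatement of the branches that
  // follow):
  //   static input,  no support -> strip dynamism from the output shape.
  //   dynamic input, no support -> ConvertToStatic each dynamic operand.
  //   static input,  required   -> ConvertToDynamic each operand with an
  //                                inferred dynamic dimension.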
  if (op_support == OpDynamismSupport::kNoSupport) {
    for (auto* sub_computation : hlo->called_computations()) {
      for (auto* param : sub_computation->parameter_instructions()) {
        param->mutable_shape()->clear_dynamic_dimensions();
      }
    }
  }
  // If the input to an op is static and the op doesn't support dynamic
  // output, remove dynamism in the output -- dynamic_padder should have
  // rewritten it to support static shapes.
  if (!input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) {
    hlo->mutable_shape()->clear_dynamic_dimensions();
    return OkStatus();
  }

  // The op doesn't support dynamic tensors: for each operand, rewrite the
  // dynamic input into a static input using pad-to-static.
  if (input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) {
    VLOG(1) << "op doesn't support dynamic tensor: " << hlo->ToString();
    for (int64_t i = 0; i < hlo->operand_count(); ++i) {
      if (hlo->operand(i)->shape().is_dynamic()) {
        TF_ASSIGN_OR_RETURN(auto static_operand,
                            ConvertToStatic(hlo->mutable_operand(i)));
        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, static_operand));
      }
    }
    // This op doesn't support dynamic lowering, so the op has to be static.
    hlo->mutable_shape()->clear_dynamic_dimensions();
    return OkStatus();
  }

  // If the op requires dynamic tensors and the input is static -- construct
  // a dynamic tensor from the static tensor to feed it.
  if (!input_is_dynamic && op_support == OpDynamismSupport::kRequired) {
    VLOG(1) << "op doesn't support static tensor: " << hlo->ToString();
    for (int64_t i = 0; i < hlo->operand_count(); ++i) {
      auto operand = hlo->mutable_operand(i);
      if (dynamic_dimension_inference_->HasDynamicDimension(operand)) {
        TF_ASSIGN_OR_RETURN(auto dynamic_operand,
                            ConvertToDynamic(hlo->mutable_operand(i)));
        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, dynamic_operand));
      }
    }
    return OkStatus();
  }

  return OkStatus();
}

Status DynamicShapeRemovingVisitor::HandleGetTupleElement(HloInstruction* hlo) {
  *hlo->mutable_shape() =
      hlo->operand(0)->shape().tuple_shapes(hlo->tuple_index());
  return OkStatus();
}

Status DynamicShapeRemovingVisitor::HandleTuple(HloInstruction* hlo) {
  for (int64_t i = 0; i < hlo->operand_count(); ++i) {
    *hlo->mutable_shape()->mutable_tuple_shapes(i) = hlo->operand(i)->shape();
  }
  return OkStatus();
}

Status DynamicShapeRemovingVisitor::HandleParameter(HloInstruction* hlo) {
  return OkStatus();
}
Status DynamicShapeRemovingVisitor::HandleCustomCall(HloInstruction* hlo) {
  if (hlo->custom_call_target() == "SliceToDynamic" ||
      hlo->custom_call_target() == "PadToStatic") {
    // These ops are created specifically to handle dynamic tensors, so by
    // their nature they support dynamic lowering.
    return OkStatus();
  }

  return DefaultAction(hlo);
}

}  // namespace
StatusOr<bool> DynamicPadder::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  VLOG(2) << "Pre DynamicPadder HLO:";
  XLA_VLOG_LINES(2, module->ToString());
  // Removes dynamic dimensions on parameters if there is already a binding
  // for them. We do this because we have two different APIs to express a
  // dynamic dimension:
  //
  // 1. Dynamic dimension as specified directly in the shape -- needed for
  // PyTorch.
  //
  // 2. Dynamic dimension using a dynamic parameter binding object. This
  // is needed for TensorFlow.
  //
  // For case 1, we will insert a "pad-to-static" instruction at the
  // beginning of xla execution, to make it into a static layout.
  //
  // For case 2, since it already has a static layout, we remove the
  // dynamic dimension.
  //
  // TODO(b/145140571): Convert all API invocations to case 1.
  //
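  // For example (illustrative): a binding that maps an s32[] parameter to
  // dimension 0 of an f32[10] parameter lets that parameter behave like
  // f32[<=10] without the shape itself being annotated as dynamic.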
  TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().ForEachBinding(
      [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
          const DynamicParameterBinding::DynamicDimension& dynamic_dimension)
          -> Status {
        HloInstruction* parameter =
            module->entry_computation()->parameter_instruction(
                dynamic_dimension.parameter_num);
        ShapeUtil::UpdateDynamicDimension(parameter->mutable_shape(),
                                          dynamic_dimension.parameter_index,
                                          dynamic_dimension.dimension, false);
        return OkStatus();
      }));

  TF_RETURN_IF_ERROR(InsertPadToStaticAfterModuleInputs(module));
  TF_ASSIGN_OR_RETURN(
      DynamicDimensionInference dynamic_dimension_inference,
      DynamicDimensionInference::Run(module, options_.custom_call_handler,
                                     options_.shape_check_mode,
                                     options_.assertion_generator));

  std::vector<HloComputation*> computations =
      module->MakeComputationPostOrder(execution_threads);

  for (HloComputation* computation : computations) {
    for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
      OpDynamismSupport has_dynamism_support = OpDynamismSupport::kNoSupport;
      if (options_.op_supports_dynamism_handler != nullptr) {
        has_dynamism_support = options_.op_supports_dynamism_handler(inst);
      }
      // This op supports dynamic lowering; no padding is required.
      if (has_dynamism_support != OpDynamismSupport::kNoSupport) {
        continue;
      }
      if (inst->opcode() == HloOpcode::kConcatenate) {
        TF_ASSIGN_OR_RETURN(
            bool c, RewriteDynamicConcat(inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }
      if (inst->opcode() == HloOpcode::kReverse) {
        TF_ASSIGN_OR_RETURN(bool c,
                            RewriteReverse(inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }
      if (inst->opcode() == HloOpcode::kSort) {
        TF_ASSIGN_OR_RETURN(
            bool c, RewriteDynamicSort(inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }
      if (inst->opcode() == HloOpcode::kReshape ||
          inst->opcode() == HloOpcode::kDynamicReshape) {
        TF_ASSIGN_OR_RETURN(
            bool c, RewriteDynamicReshape(inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }

      // Elementwise binary ops with dynamic shapes have implicit broadcast
      // semantics.
      if (inst->IsElementwiseBinary()) {
        TF_ASSIGN_OR_RETURN(
            bool c, RewriteDynamicBinaryOp(inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }

      if (inst->opcode() == HloOpcode::kDynamicUpdateSlice) {
        TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicUpdateSlice(
                                        inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }

      if (inst->IsCustomCall("DynamicConvolutionInputGrad")) {
        TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicConvolutionInputGrad(
                                        inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }

      if (inst->IsCustomCall("DynamicConvolutionForward")) {
        TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicConvolutionForward(
                                        inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }

      if (inst->IsCustomCall("DynamicConvolutionKernelGrad")) {
        TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicConvolutionKernelGrad(
                                        inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }

      if (inst->IsCustomCall("DynamicReduceWindowSamePadding")) {
        TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicReduceWindowSamePadding(
                                        inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }

      if (inst->IsCustomCall("DynamicSelectAndScatterSamePadding")) {
        TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicSelectAndScatterSamePadding(
                                        inst, &dynamic_dimension_inference));
        changed |= c;
        continue;
      }
      for (int64_t operand_num = 0; operand_num < inst->operand_count();
           ++operand_num) {
        HloInstruction* original_operand = inst->mutable_operand(operand_num);
        HloInstruction* operand = original_operand;
        if (!operand->shape().IsArray()) {
          continue;
        }

        for (int64_t input_dim = 0; input_dim < operand->shape().rank();
             ++input_dim) {
          HloInstruction* operand_dynamic_size =
              dynamic_dimension_inference.GetDynamicSize(original_operand, {},
                                                         input_dim);
          if (operand_dynamic_size == nullptr) {
            continue;
          }
          VLOG(2) << "Has dynamic dimension of operand " << operand_num
                  << " @" << input_dim;

          if (ShouldSkipPadOnOperand(inst, operand_num, input_dim)) {
            continue;
          }

          TF_ASSIGN_OR_RETURN(HloInstruction * identity_value,
                              ChooseIdentityValue(inst, operand_num));
          if (identity_value == nullptr) {
            continue;
          }
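          // Pad the dynamic region with the op's identity value so the padded
          // elements cannot change the result. For example (illustrative):
          // zeros for a reduce-sum, the highest representable value for a
          // reduce-min, the lowest for a reduce-max.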
          HloInstruction* padded = PadWithScalar(
              operand, input_dim, operand_dynamic_size, identity_value);
          TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, padded));
          operand = inst->mutable_operand(operand_num);
          changed = true;
        }
      }
    }
  }

  // There are ops that only support dynamic lowering and ops that only
  // support static lowering. Add dynamic<->static tensor conversions around
  // the boundary between those ops, as well as at the root instruction.
  computations = module->MakeComputationPostOrder(execution_threads);
  // Walk the computations in reverse post order so that if a caller doesn't
  // support dynamic tensors (while, etc.), its called computations are
  // changed to take only static tensors.
  for (auto it = computations.rbegin(); it != computations.rend(); ++it) {
    HloComputation* computation = *it;
    // If slice_dynamic_output is set and this is the entry computation, we
    // need the output tensor to be in dynamic form.
    bool require_dynamic_output = options_.slice_dynamic_output &&
                                  computation == module->entry_computation();
    changed |= require_dynamic_output;
    TF_RETURN_IF_ERROR(DynamicShapeRemovingVisitor::Run(
        computation, options_.op_supports_dynamism_handler,
        &dynamic_dimension_inference,
        /*require_dynamic_output=*/require_dynamic_output));
  }

  if (changed) {
    dynamic_padding_gauge->GetCell()->Set(changed);
    module->set_is_dynamic(true);
  }

  for (auto* computation : module->computations(execution_threads)) {
    for (auto instruction : computation->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(
          bool c, ReplaceGetSize(instruction, &dynamic_dimension_inference));
      changed |= c;
    }
  }

  for (auto* computation : module->computations(execution_threads)) {
    for (auto instruction : computation->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(bool c, ReplaceSetSize(instruction));
      changed |= c;

      TF_ASSIGN_OR_RETURN(c, ReplaceSetBound(instruction));
      changed |= c;
    }
  }

  HloDCE dce;
  TF_ASSIGN_OR_RETURN(bool c, dce.Run(module, execution_threads));
  changed |= c;

  VLOG(2) << "Post DynamicPadder HLO:";
  XLA_VLOG_LINES(2, module->ToString());
  return changed;
}

}  // namespace xla