xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/scatter_expander.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/scatter_expander.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "tensorflow/compiler/xla/literal_util.h"
20 #include "tensorflow/compiler/xla/service/call_inliner.h"
21 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
26 #include "tensorflow/compiler/xla/service/hlo_module.h"
27 #include "tensorflow/compiler/xla/service/while_util.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 
30 namespace xla {
31 
32 // Transposes the given scatter_indices such that the index_vector_dim becomes
33 // the most-minor dimension.
TransposeIndexVectorDimToLast(HloInstruction * scatter_indices,int64_t index_vector_dim)34 static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
35     HloInstruction* scatter_indices, int64_t index_vector_dim) {
36   const Shape& scatter_indices_shape = scatter_indices->shape();
37 
38   if (scatter_indices_shape.dimensions_size() == index_vector_dim) {
39     return scatter_indices;
40   }
41 
42   if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) {
43     return scatter_indices;
44   }
45 
46   std::vector<int64_t> permutation;
47   permutation.reserve(scatter_indices_shape.dimensions_size());
48   for (int64_t i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
49     if (i != index_vector_dim) {
50       permutation.push_back(i);
51     }
52   }
53   permutation.push_back(index_vector_dim);
54   return MakeTransposeHlo(scatter_indices, permutation);
55 }
56 
57 // Canonicalizes the scatter_indices tensor in order to keep them uniform while
58 // performing the scatter operation.
CanonicalizeScatterIndices(HloInstruction * scatter_indices,int64_t index_vector_dim)59 static StatusOr<HloInstruction*> CanonicalizeScatterIndices(
60     HloInstruction* scatter_indices, int64_t index_vector_dim) {
61   // Transpose the non-index-vector dimensions to the front.
62   TF_ASSIGN_OR_RETURN(
63       HloInstruction * transposed_scatter_indices,
64       TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim));
65   if (scatter_indices->shape().rank() == index_vector_dim + 1 &&
66       scatter_indices->shape().dimensions(index_vector_dim) == 1) {
67     auto new_shape =
68         ShapeUtil::DeleteDimension(index_vector_dim, scatter_indices->shape());
69     TF_ASSIGN_OR_RETURN(scatter_indices,
70                         MakeReshapeHlo(new_shape, scatter_indices));
71   }
72   bool indices_are_scalar =
73       index_vector_dim == scatter_indices->shape().dimensions_size();
74 
75   // The number of dimensions in scatter_indices that are index dimensions.
76   const int64_t index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1;
77 
78   // If there is only one index (i.e. scatter_indices has rank 1 and this
79   // scatter is really just a dynamic update slice) add a leading degenerate
80   // dimension for uniformity.  Otherwise create a "collapsed" leading dimension
81   // that subsumes all of the non-index-vector dimensions.
82   const Shape& shape = transposed_scatter_indices->shape();
83   if (shape.dimensions_size() == index_dims_in_scatter_indices) {
84     return PrependDegenerateDims(transposed_scatter_indices, 1);
85   } else {
86     // Collapse all but the dimensions (0 or 1) in scatter_indices containing
87     // the index vectors.
88     return CollapseFirstNDims(
89         transposed_scatter_indices,
90         shape.dimensions_size() - index_dims_in_scatter_indices);
91   }
92 }
93 
94 // Permutes the `updates` tensor such that all the scatter dims appear in the
95 // major dimensions and all the window dimensions appear in the minor
96 // dimensions.
PermuteScatterAndWindowDims(HloInstruction * updates,absl::Span<const int64_t> update_window_dims)97 static StatusOr<HloInstruction*> PermuteScatterAndWindowDims(
98     HloInstruction* updates, absl::Span<const int64_t> update_window_dims) {
99   std::vector<int64_t> permutation;
100   const int64_t updates_rank = updates->shape().rank();
101   permutation.reserve(updates_rank);
102 
103   for (int64_t i = 0; i < updates_rank; ++i) {
104     bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i);
105     if (is_scatter_dim) {
106       permutation.push_back(i);
107     }
108   }
109   for (auto window_dim : update_window_dims) {
110     permutation.push_back(window_dim);
111   }
112 
113   return MakeTransposeHlo(updates, permutation);
114 }
115 
116 // Expands or contracts the scatter indices in the updates tensor.
AdjustScatterDims(const Shape & scatter_indices_shape,HloInstruction * updates,int64_t index_vector_dim)117 static StatusOr<HloInstruction*> AdjustScatterDims(
118     const Shape& scatter_indices_shape, HloInstruction* updates,
119     int64_t index_vector_dim) {
120   int64_t num_scatter_dims = scatter_indices_shape.dimensions_size();
121   if (index_vector_dim < scatter_indices_shape.dimensions_size()) {
122     --num_scatter_dims;
123   }
124   if (num_scatter_dims == 0) {
125     // If there are no scatter dims, this must be a dynamic-update-slice kind of
126     // scatter. In this case, we prepend a degenerate dimension to work
127     // uniformly in the while loop.
128     return PrependDegenerateDims(updates, 1);
129   }
130   return CollapseFirstNDims(updates, num_scatter_dims);
131 }
132 
133 // Expands an index vector from the scatter_indices tensor into a vector that
134 // can be used to dynamic-update-slice to perform the scatter update.
ExpandIndexVectorIntoOperandSpace(HloInstruction * index_vector,const ScatterDimensionNumbers & dim_numbers,int64_t operand_rank)135 static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
136     HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers,
137     int64_t operand_rank) {
138   HloComputation* computation = index_vector->parent();
139   const Shape& index_shape = index_vector->shape();
140 
141   // Scatter of a scalar. Return a zero-sized vector of indices.
142   if (operand_rank == 0) {
143     return computation->AddInstruction(HloInstruction::CreateConstant(
144         LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0})));
145   }
146 
147   HloInstruction* zero =
148       computation->AddInstruction(HloInstruction::CreateConstant(
149           LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));
150 
151   // We extract out individual components from the smaller index and concatenate
152   // them (interspersing zeros as needed) into the larger index.
153   std::vector<HloInstruction*> expanded_index_components;
154 
155   for (int i = 0; i < operand_rank; i++) {
156     int64_t index_vector_dim_index =
157         FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i);
158     if (index_vector_dim_index !=
159         dim_numbers.scatter_dims_to_operand_dims_size()) {
160       TF_ASSIGN_OR_RETURN(
161           HloInstruction * component_to_concat,
162           MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
163                        /*limit_indices=*/{index_vector_dim_index + 1},
164                        /*strides=*/{1}));
165       expanded_index_components.push_back(component_to_concat);
166     } else {
167       expanded_index_components.push_back(zero);
168     }
169   }
170 
171   return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
172 }
173 
CheckIndexValidity(HloComputation * computation,HloInstruction * index,absl::Span<const int64_t> operand_dims,absl::Span<const int64_t> window_sizes,HloModule * module)174 static StatusOr<HloInstruction*> CheckIndexValidity(
175     HloComputation* computation, HloInstruction* index,
176     absl::Span<const int64_t> operand_dims,
177     absl::Span<const int64_t> window_sizes, HloModule* module) {
178   DCHECK_NE(nullptr, module);
179   DCHECK_EQ(operand_dims.size(), window_sizes.size());
180 
181   // Valid range for the index: [0, operand_dims - window_sizes]
182 
183   // Check if the index has any negative values.
184   HloInstruction* zero_index = BroadcastZeros(
185       computation, index->shape().element_type(), index->shape().dimensions());
186   TF_ASSIGN_OR_RETURN(
187       HloInstruction * negative_index_check,
188       MakeCompareHlo(ComparisonDirection::kLe, zero_index, index));
189 
190   // Check if the index is OOB w.r.t. the operand dimensions and window sizes.
191   std::vector<int64_t> max_valid_index(operand_dims.size());
192   for (int i = 0; i < operand_dims.size(); ++i) {
193     max_valid_index[i] = operand_dims[i] - window_sizes[i];
194   }
195   TF_ASSIGN_OR_RETURN(
196       HloInstruction * max_valid_index_constant,
197       MakeR1ConstantHlo<int64_t>(computation, index->shape().element_type(),
198                                  max_valid_index));
199   TF_ASSIGN_OR_RETURN(HloInstruction * oob_index_check,
200                       MakeCompareHlo(ComparisonDirection::kGe,
201                                      max_valid_index_constant, index));
202 
203   // Combine the results of the two checks above.
204   TF_ASSIGN_OR_RETURN(
205       HloInstruction * valid_index,
206       MakeBinaryHlo(HloOpcode::kAnd, negative_index_check, oob_index_check));
207 
208   // Reduce the index validity check vector into a scalar predicate.
209   auto reduction_init = computation->AddInstruction(
210       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
211   TF_ASSIGN_OR_RETURN(
212       HloInstruction * valid_index_reduced,
213       MakeReduceHlo(valid_index, reduction_init, HloOpcode::kAnd, module));
214 
215   // Return a broadcasted value of the scalar predicate to the same size as the
216   // window.
217   return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes);
218 }
219 
CallAndGetOutput(HloComputation * original,int output_index)220 static StatusOr<HloComputation*> CallAndGetOutput(HloComputation* original,
221                                                   int output_index) {
222   HloInstruction* original_root = original->root_instruction();
223   if (!original_root->shape().IsTuple()) {
224     return original;
225   }
226   HloComputation* new_comp = [&] {
227     HloComputation::Builder builder(
228         absl::StrCat(original->name(), ".dup.", output_index));
229     for (int i = 0, n = original->num_parameters(); i < n; ++i) {
230       HloInstruction* original_param = original->parameter_instruction(i);
231       builder.AddInstruction(HloInstruction::CreateParameter(
232           i, original_param->shape(), original_param->name()));
233     }
234     return original->parent()->AddEmbeddedComputation(builder.Build());
235   }();
236   HloInstruction* call_original = new_comp->AddInstruction(
237       HloInstruction::CreateCall(original_root->shape(),
238                                  new_comp->parameter_instructions(), original));
239   new_comp->set_root_instruction(
240       new_comp->AddInstruction(
241           HloInstruction::CreateGetTupleElement(call_original, output_index)),
242       /*accept_different_shape=*/true);
243   TF_RETURN_IF_ERROR(CallInliner::Inline(call_original).status());
244   return new_comp;
245 }
246 
// Body of the while loop that performs the scatter operation using other HLOs.
//
// The loop state is laid out as (see the CHECK and the span slicing below):
//   [operand_0, ..., operand_{N-1}, scatter_indices, update_0, ..., update_{N-1}]
// where N == scatter->scatter_operand_count(). One iteration applies the
// update for the single scatter index selected by `induction_var` and returns
// the state with the operands replaced by their updated values; the indices
// and updates are passed through unchanged.
static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
    HloScatterInstruction* scatter, HloInstruction* induction_var,
    absl::Span<HloInstruction* const> loop_state) {
  const ScatterDimensionNumbers& dim_numbers =
      scatter->scatter_dimension_numbers();
  CHECK_EQ(loop_state.size(), scatter->operand_count());
  // Slice the flat loop state into its three logical sections.
  auto operands = loop_state.first(scatter->scatter_operand_count());
  HloInstruction* scatter_indices = loop_state[operands.size()];
  auto updates = loop_state.last(operands.size());

  // After CanonicalizeScatterIndices, rank-1 indices mean one scalar index
  // per row; rank-2 means one index vector per row.
  bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1;

  // Build a vector form of the induction variable of the while loop.
  HloInstruction* induction_var_as_vector =
      MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
                       /*result_shape_bounds=*/{1});

  // Pick the index to scatter from scatter_indices based on the induction_var
  // and transform that to an index into the `operand` space.
  HloInstruction* index_vector;
  if (has_scalar_indices) {
    TF_ASSIGN_OR_RETURN(
        index_vector,
        MakeDynamicSliceHlo(scatter_indices, induction_var_as_vector, {1}));
  } else {
    // Slice row `induction_var` out of the (num_indices, index_vector_size)
    // indices matrix, then squeeze away the leading degenerate dimension.
    TF_ASSIGN_OR_RETURN(
        HloInstruction * index_into_scatter_indices,
        PadVectorWithZeros(induction_var_as_vector,
                           /*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
    int index_vector_size = scatter_indices->shape().dimensions(1);
    TF_ASSIGN_OR_RETURN(
        HloInstruction * index_vector_2d,
        MakeDynamicSliceHlo(scatter_indices, index_into_scatter_indices,
                            {1, index_vector_size}));
    TF_ASSIGN_OR_RETURN(index_vector,
                        ElideDegenerateDims(index_vector_2d, {0}));
  }
  TF_ASSIGN_OR_RETURN(
      HloInstruction * scatter_slice_start,
      ExpandIndexVectorIntoOperandSpace(
          index_vector, dim_numbers, operands[0]->shape().dimensions_size()));

  // Extract the slice to be used to update from `updates` tensor for the
  // induction_var corresponding to this iteration of the while loop.
  TF_ASSIGN_OR_RETURN(
      HloInstruction * index_into_updates,
      PadVectorWithZeros(
          induction_var_as_vector, /*zeros_to_prepend=*/0,
          /*zeros_to_append=*/updates[0]->shape().dimensions_size() - 1));
  // Bounds of a single-row slice of the updates tensor.
  std::vector<int64_t> update_slice_bounds(
      updates[0]->shape().dimensions().begin(),
      updates[0]->shape().dimensions().end());
  update_slice_bounds[0] = 1;

  // map_operands = [operand_slice_0, ..., update_slice_0, ...]; this is the
  // operand order MakeMapHlo receives below.
  absl::InlinedVector<HloInstruction*, 2> map_operands(
      operands.size() + updates.size(), nullptr);
  auto operand_slices_to_update =
      absl::MakeSpan(map_operands).first(operands.size());
  auto update_slices_with_dims_inserted =
      absl::MakeSpan(map_operands).last(updates.size());
  absl::Span<const int64_t> actual_update_slice_dims;
  for (int i = 0, n = operands.size(); i < n; ++i) {
    HloInstruction* update = updates[i];
    TF_ASSIGN_OR_RETURN(
        HloInstruction * update_slice,
        MakeDynamicSliceHlo(update, index_into_updates, update_slice_bounds));
    TF_ASSIGN_OR_RETURN(HloInstruction * update_slice_for_scatter,
                        ElideDegenerateDims(update_slice, {0}));
    // Re-insert the window dimensions that were elided from the updates so
    // the slice shape matches the operand slice shape.
    TF_ASSIGN_OR_RETURN(
        HloInstruction * update_slice_with_dims_inserted,
        InsertDegenerateDims(update_slice_for_scatter,
                             dim_numbers.inserted_window_dims()));
    update_slices_with_dims_inserted[i] = update_slice_with_dims_inserted;
    // Note that the following transformation assumes that both DynamicSlice and
    // DynamicUpdateSlice follow the same semantics for OOB indices. For
    // example, if there are negative indices and DynamicSlice uses "clamping"
    // semantics, then the extracted data will be "shifted". Since
    // DynamicUpdateSlice also follows the same "clamping" semantics, writing
    // the update will also be "shifted" by exactly the same amount. So, this
    // transformation is correct as long as the semantics of handling OOB
    // indices remain the same in DynamicSlice and DynamicUpdateSlice.

    // Extract the slice to update from `operand` tensor.
    HloInstruction* operand = operands[i];
    const Shape& update_slice_shape = update_slice_with_dims_inserted->shape();
    TF_ASSIGN_OR_RETURN(HloInstruction * operand_slice_to_update,
                        MakeDynamicSliceHlo(operand, scatter_slice_start,
                                            update_slice_shape.dimensions()));
    operand_slices_to_update[i] = operand_slice_to_update;
    // All outputs must share the same update-slice shape; record it from the
    // first operand and verify the rest against it.
    if (i == 0) {
      actual_update_slice_dims = update_slice_shape.dimensions();
    } else {
      TF_RET_CHECK(actual_update_slice_dims == update_slice_shape.dimensions());
    }
  }

  TF_ASSIGN_OR_RETURN(
      HloInstruction * is_index_valid,
      CheckIndexValidity(operands[0]->parent(), scatter_slice_start,
                         operands[0]->shape().dimensions(),
                         actual_update_slice_dims, scatter->GetModule()));

  // Write the updated value of the slice into `operand` tensor.
  std::vector<HloInstruction*> updated_loop_state;
  updated_loop_state.reserve(loop_state.size());
  for (int i = 0, n = operands.size(); i < n; ++i) {
    // Compute the new value for the slice to be updated in `operand` tensor by
    // combining the existing value and the update value using the update
    // computation.
    // NOTE: For scatters with N outputs, we currently have to duplicate the
    // Map computation N times because we don't support multioutput Map yet.
    TF_ASSIGN_OR_RETURN(HloComputation * to_apply,
                        CallAndGetOutput(scatter->to_apply(), i));
    TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand_slice,
                        MakeMapHlo(map_operands, to_apply));
    // Select the updated operand only if the index is valid. If not, select the
    // original value.
    TF_ASSIGN_OR_RETURN(HloInstruction * updates_to_apply,
                        MakeSelectHlo(is_index_valid, updated_operand_slice,
                                      operand_slices_to_update[i]));
    TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand,
                        MakeDynamicUpdateSliceHlo(operands[i], updates_to_apply,
                                                  scatter_slice_start));
    updated_loop_state.push_back(updated_operand);
  }
  // Indices and updates are loop-invariant; pass them through unchanged.
  updated_loop_state.push_back(scatter_indices);
  absl::c_copy(updates, std::back_inserter(updated_loop_state));

  return updated_loop_state;
}
378 
ScatterTripCount(const HloScatterInstruction * scatter)379 static int64_t ScatterTripCount(const HloScatterInstruction* scatter) {
380   // Compute the trip count for the while loop to be used for scatter. This
381   // should be the number of indices we should scatter into the operand.
382   const HloInstruction* scatter_indices = scatter->scatter_indices();
383   const Shape& scatter_indices_shape = scatter_indices->shape();
384   const ScatterDimensionNumbers& dim_numbers =
385       scatter->scatter_dimension_numbers();
386   int64_t scatter_loop_trip_count = 1;
387   for (int64_t i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
388     if (i != dim_numbers.index_vector_dim()) {
389       scatter_loop_trip_count *= scatter_indices_shape.dimensions(i);
390     }
391   }
392   return scatter_loop_trip_count;
393 }
394 
395 // High Level Algorithm.
396 //
397 // 1. Canonicalize the scatter_indices tensor such that it has rank 2, where
398 //    each row is an index into the operand.
399 // 2. Canonicalize the updates tensor such that it has rank `num_window_dims+1`
400 //    and the scatter dim is the most-major dimension.
401 // 3. Iterate over the set of indices in the canonicalized scatter_indices
402 //    tensor using a while loop, updating the operand for each such index. Each
403 //    iteration of this while loop performs the following:
404 //      a. Pick the index from scatter_indices for this iteration.
405 //      b. Transform this index into an index into the operand space.
406 //      c. Extract the slice to be used to update from the updates tensor.
407 //      d. Extract the slice to update from the operand tensor.
408 //      e. Compute the new value for the slice to update by combining the slices
409 //         from c. and d. using the update_computation of scatter.
410 //      f. Write the updated value of the slice into the operand tensor.
411 
ExpandInstruction(HloInstruction * inst)412 StatusOr<HloInstruction*> ScatterExpander::ExpandInstruction(
413     HloInstruction* inst) {
414   auto* scatter = Cast<HloScatterInstruction>(inst);
415   auto scatter_operands = scatter->scatter_operands();
416   HloInstruction* scatter_indices = scatter->scatter_indices();
417   auto scatter_updates = scatter->scatter_updates();
418   const ScatterDimensionNumbers& dim_numbers =
419       scatter->scatter_dimension_numbers();
420 
421   // If the updates tensors are empty, there is no need to update the operands.
422   // The operands can be forwarded.
423   if (ShapeUtil::IsZeroElementArray(scatter_updates[0]->shape())) {
424     if (scatter_operands.size() == 1) {
425       return scatter_operands[0];
426     }
427     return scatter->parent()->AddInstruction(
428         HloInstruction::CreateTuple(scatter_operands));
429   }
430 
431   // Compute the trip count for the while loop to be used for scatter. This
432   // should be the number of indices we should scatter into the operand.
433   int64_t scatter_loop_trip_count = ScatterTripCount(scatter);
434   if (!IsInt32(scatter_loop_trip_count)) {
435     return Unimplemented(
436         "Scatter operations with more than 2147483647 scatter indices are not "
437         "supported. This error occurred for %s.",
438         scatter->ToString());
439   }
440 
441   // Canonicalize the scatter_indices, after which the size of its most-major
442   // dimension must be same as the while loop trip count.
443   TF_ASSIGN_OR_RETURN(HloInstruction * canonical_scatter_indices,
444                       CanonicalizeScatterIndices(
445                           scatter_indices, dim_numbers.index_vector_dim()));
446   CHECK_EQ(scatter_loop_trip_count,
447            canonical_scatter_indices->shape().dimensions(0));
448 
449   // Canonicalize the updates, after which the size of its most-major dimension
450   // must be same as the while loop trip count.
451   std::vector<HloInstruction*> adjusted_canonical_updates;
452   adjusted_canonical_updates.reserve(scatter_updates.size());
453   for (HloInstruction* update : scatter_updates) {
454     TF_ASSIGN_OR_RETURN(
455         HloInstruction * canonical_update,
456         PermuteScatterAndWindowDims(update, dim_numbers.update_window_dims()));
457     TF_ASSIGN_OR_RETURN(
458         HloInstruction * adjusted_canonical_update,
459         AdjustScatterDims(scatter_indices->shape(), canonical_update,
460                           dim_numbers.index_vector_dim()));
461     CHECK_EQ(scatter_loop_trip_count,
462              adjusted_canonical_update->shape().dimensions(0));
463     adjusted_canonical_updates.push_back(adjusted_canonical_update);
464   }
465 
466   // The while loop that implements the scatter operation.
467   std::vector<HloInstruction*> loop_state;
468   loop_state.reserve(scatter->operand_count());
469   absl::c_copy(scatter_operands, std::back_inserter(loop_state));
470   loop_state.push_back(canonical_scatter_indices);
471   absl::c_copy(adjusted_canonical_updates, std::back_inserter(loop_state));
472   StatusOr<std::vector<HloInstruction*>> scatter_loop_result_status =
473       WhileUtil::MakeCountedLoop(
474           scatter->parent(), scatter_loop_trip_count, loop_state,
475           [scatter](HloInstruction* induction_var,
476                     const std::vector<HloInstruction*>& loop_state) {
477             return ScatterLoopBody(scatter, induction_var, loop_state);
478           },
479           scatter->metadata());
480   TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> scatter_loop_result,
481                       scatter_loop_result_status);
482   auto results =
483       absl::MakeSpan(scatter_loop_result).first(scatter_operands.size());
484   return MaybeMakeTuple(results);
485 }
486 
InstructionMatchesPattern(HloInstruction * inst)487 bool ScatterExpander::InstructionMatchesPattern(HloInstruction* inst) {
488   auto* scatter = DynCast<HloScatterInstruction>(inst);
489   return scatter &&
490          (mode_ == kEliminateAllScatters || ScatterTripCount(scatter) == 1);
491 }
492 
493 }  // namespace xla
494