1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/scatter_expander.h"
17
18 #include "absl/algorithm/container.h"
19 #include "tensorflow/compiler/xla/literal_util.h"
20 #include "tensorflow/compiler/xla/service/call_inliner.h"
21 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
26 #include "tensorflow/compiler/xla/service/hlo_module.h"
27 #include "tensorflow/compiler/xla/service/while_util.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29
30 namespace xla {
31
32 // Transposes the given scatter_indices such that the index_vector_dim becomes
33 // the most-minor dimension.
TransposeIndexVectorDimToLast(HloInstruction * scatter_indices,int64_t index_vector_dim)34 static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
35 HloInstruction* scatter_indices, int64_t index_vector_dim) {
36 const Shape& scatter_indices_shape = scatter_indices->shape();
37
38 if (scatter_indices_shape.dimensions_size() == index_vector_dim) {
39 return scatter_indices;
40 }
41
42 if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) {
43 return scatter_indices;
44 }
45
46 std::vector<int64_t> permutation;
47 permutation.reserve(scatter_indices_shape.dimensions_size());
48 for (int64_t i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
49 if (i != index_vector_dim) {
50 permutation.push_back(i);
51 }
52 }
53 permutation.push_back(index_vector_dim);
54 return MakeTransposeHlo(scatter_indices, permutation);
55 }
56
57 // Canonicalizes the scatter_indices tensor in order to keep them uniform while
58 // performing the scatter operation.
CanonicalizeScatterIndices(HloInstruction * scatter_indices,int64_t index_vector_dim)59 static StatusOr<HloInstruction*> CanonicalizeScatterIndices(
60 HloInstruction* scatter_indices, int64_t index_vector_dim) {
61 // Transpose the non-index-vector dimensions to the front.
62 TF_ASSIGN_OR_RETURN(
63 HloInstruction * transposed_scatter_indices,
64 TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim));
65 if (scatter_indices->shape().rank() == index_vector_dim + 1 &&
66 scatter_indices->shape().dimensions(index_vector_dim) == 1) {
67 auto new_shape =
68 ShapeUtil::DeleteDimension(index_vector_dim, scatter_indices->shape());
69 TF_ASSIGN_OR_RETURN(scatter_indices,
70 MakeReshapeHlo(new_shape, scatter_indices));
71 }
72 bool indices_are_scalar =
73 index_vector_dim == scatter_indices->shape().dimensions_size();
74
75 // The number of dimensions in scatter_indices that are index dimensions.
76 const int64_t index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1;
77
78 // If there is only one index (i.e. scatter_indices has rank 1 and this
79 // scatter is really just a dynamic update slice) add a leading degenerate
80 // dimension for uniformity. Otherwise create a "collapsed" leading dimension
81 // that subsumes all of the non-index-vector dimensions.
82 const Shape& shape = transposed_scatter_indices->shape();
83 if (shape.dimensions_size() == index_dims_in_scatter_indices) {
84 return PrependDegenerateDims(transposed_scatter_indices, 1);
85 } else {
86 // Collapse all but the dimensions (0 or 1) in scatter_indices containing
87 // the index vectors.
88 return CollapseFirstNDims(
89 transposed_scatter_indices,
90 shape.dimensions_size() - index_dims_in_scatter_indices);
91 }
92 }
93
94 // Permutes the `updates` tensor such that all the scatter dims appear in the
95 // major dimensions and all the window dimensions appear in the minor
96 // dimensions.
PermuteScatterAndWindowDims(HloInstruction * updates,absl::Span<const int64_t> update_window_dims)97 static StatusOr<HloInstruction*> PermuteScatterAndWindowDims(
98 HloInstruction* updates, absl::Span<const int64_t> update_window_dims) {
99 std::vector<int64_t> permutation;
100 const int64_t updates_rank = updates->shape().rank();
101 permutation.reserve(updates_rank);
102
103 for (int64_t i = 0; i < updates_rank; ++i) {
104 bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i);
105 if (is_scatter_dim) {
106 permutation.push_back(i);
107 }
108 }
109 for (auto window_dim : update_window_dims) {
110 permutation.push_back(window_dim);
111 }
112
113 return MakeTransposeHlo(updates, permutation);
114 }
115
116 // Expands or contracts the scatter indices in the updates tensor.
AdjustScatterDims(const Shape & scatter_indices_shape,HloInstruction * updates,int64_t index_vector_dim)117 static StatusOr<HloInstruction*> AdjustScatterDims(
118 const Shape& scatter_indices_shape, HloInstruction* updates,
119 int64_t index_vector_dim) {
120 int64_t num_scatter_dims = scatter_indices_shape.dimensions_size();
121 if (index_vector_dim < scatter_indices_shape.dimensions_size()) {
122 --num_scatter_dims;
123 }
124 if (num_scatter_dims == 0) {
125 // If there are no scatter dims, this must be a dynamic-update-slice kind of
126 // scatter. In this case, we prepend a degenerate dimension to work
127 // uniformly in the while loop.
128 return PrependDegenerateDims(updates, 1);
129 }
130 return CollapseFirstNDims(updates, num_scatter_dims);
131 }
132
133 // Expands an index vector from the scatter_indices tensor into a vector that
134 // can be used to dynamic-update-slice to perform the scatter update.
ExpandIndexVectorIntoOperandSpace(HloInstruction * index_vector,const ScatterDimensionNumbers & dim_numbers,int64_t operand_rank)135 static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
136 HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers,
137 int64_t operand_rank) {
138 HloComputation* computation = index_vector->parent();
139 const Shape& index_shape = index_vector->shape();
140
141 // Scatter of a scalar. Return a zero-sized vector of indices.
142 if (operand_rank == 0) {
143 return computation->AddInstruction(HloInstruction::CreateConstant(
144 LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0})));
145 }
146
147 HloInstruction* zero =
148 computation->AddInstruction(HloInstruction::CreateConstant(
149 LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));
150
151 // We extract out individual components from the smaller index and concatenate
152 // them (interspersing zeros as needed) into the larger index.
153 std::vector<HloInstruction*> expanded_index_components;
154
155 for (int i = 0; i < operand_rank; i++) {
156 int64_t index_vector_dim_index =
157 FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i);
158 if (index_vector_dim_index !=
159 dim_numbers.scatter_dims_to_operand_dims_size()) {
160 TF_ASSIGN_OR_RETURN(
161 HloInstruction * component_to_concat,
162 MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
163 /*limit_indices=*/{index_vector_dim_index + 1},
164 /*strides=*/{1}));
165 expanded_index_components.push_back(component_to_concat);
166 } else {
167 expanded_index_components.push_back(zero);
168 }
169 }
170
171 return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
172 }
173
CheckIndexValidity(HloComputation * computation,HloInstruction * index,absl::Span<const int64_t> operand_dims,absl::Span<const int64_t> window_sizes,HloModule * module)174 static StatusOr<HloInstruction*> CheckIndexValidity(
175 HloComputation* computation, HloInstruction* index,
176 absl::Span<const int64_t> operand_dims,
177 absl::Span<const int64_t> window_sizes, HloModule* module) {
178 DCHECK_NE(nullptr, module);
179 DCHECK_EQ(operand_dims.size(), window_sizes.size());
180
181 // Valid range for the index: [0, operand_dims - window_sizes]
182
183 // Check if the index has any negative values.
184 HloInstruction* zero_index = BroadcastZeros(
185 computation, index->shape().element_type(), index->shape().dimensions());
186 TF_ASSIGN_OR_RETURN(
187 HloInstruction * negative_index_check,
188 MakeCompareHlo(ComparisonDirection::kLe, zero_index, index));
189
190 // Check if the index is OOB w.r.t. the operand dimensions and window sizes.
191 std::vector<int64_t> max_valid_index(operand_dims.size());
192 for (int i = 0; i < operand_dims.size(); ++i) {
193 max_valid_index[i] = operand_dims[i] - window_sizes[i];
194 }
195 TF_ASSIGN_OR_RETURN(
196 HloInstruction * max_valid_index_constant,
197 MakeR1ConstantHlo<int64_t>(computation, index->shape().element_type(),
198 max_valid_index));
199 TF_ASSIGN_OR_RETURN(HloInstruction * oob_index_check,
200 MakeCompareHlo(ComparisonDirection::kGe,
201 max_valid_index_constant, index));
202
203 // Combine the results of the two checks above.
204 TF_ASSIGN_OR_RETURN(
205 HloInstruction * valid_index,
206 MakeBinaryHlo(HloOpcode::kAnd, negative_index_check, oob_index_check));
207
208 // Reduce the index validity check vector into a scalar predicate.
209 auto reduction_init = computation->AddInstruction(
210 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
211 TF_ASSIGN_OR_RETURN(
212 HloInstruction * valid_index_reduced,
213 MakeReduceHlo(valid_index, reduction_init, HloOpcode::kAnd, module));
214
215 // Return a broadcasted value of the scalar predicate to the same size as the
216 // window.
217 return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes);
218 }
219
CallAndGetOutput(HloComputation * original,int output_index)220 static StatusOr<HloComputation*> CallAndGetOutput(HloComputation* original,
221 int output_index) {
222 HloInstruction* original_root = original->root_instruction();
223 if (!original_root->shape().IsTuple()) {
224 return original;
225 }
226 HloComputation* new_comp = [&] {
227 HloComputation::Builder builder(
228 absl::StrCat(original->name(), ".dup.", output_index));
229 for (int i = 0, n = original->num_parameters(); i < n; ++i) {
230 HloInstruction* original_param = original->parameter_instruction(i);
231 builder.AddInstruction(HloInstruction::CreateParameter(
232 i, original_param->shape(), original_param->name()));
233 }
234 return original->parent()->AddEmbeddedComputation(builder.Build());
235 }();
236 HloInstruction* call_original = new_comp->AddInstruction(
237 HloInstruction::CreateCall(original_root->shape(),
238 new_comp->parameter_instructions(), original));
239 new_comp->set_root_instruction(
240 new_comp->AddInstruction(
241 HloInstruction::CreateGetTupleElement(call_original, output_index)),
242 /*accept_different_shape=*/true);
243 TF_RETURN_IF_ERROR(CallInliner::Inline(call_original).status());
244 return new_comp;
245 }
246
247 // Body of the while loop that performs the scatter operation using other HLOs.
ScatterLoopBody(HloScatterInstruction * scatter,HloInstruction * induction_var,absl::Span<HloInstruction * const> loop_state)248 static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
249 HloScatterInstruction* scatter, HloInstruction* induction_var,
250 absl::Span<HloInstruction* const> loop_state) {
251 const ScatterDimensionNumbers& dim_numbers =
252 scatter->scatter_dimension_numbers();
253 CHECK_EQ(loop_state.size(), scatter->operand_count());
254 auto operands = loop_state.first(scatter->scatter_operand_count());
255 HloInstruction* scatter_indices = loop_state[operands.size()];
256 auto updates = loop_state.last(operands.size());
257
258 bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1;
259
260 // Build a vector form of the induction variable of the while loop.
261 HloInstruction* induction_var_as_vector =
262 MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
263 /*result_shape_bounds=*/{1});
264
265 // Pick the index to scatter from scatter_indices based on the induction_var
266 // and transform that to an index into the `operand` space.
267 HloInstruction* index_vector;
268 if (has_scalar_indices) {
269 TF_ASSIGN_OR_RETURN(
270 index_vector,
271 MakeDynamicSliceHlo(scatter_indices, induction_var_as_vector, {1}));
272 } else {
273 TF_ASSIGN_OR_RETURN(
274 HloInstruction * index_into_scatter_indices,
275 PadVectorWithZeros(induction_var_as_vector,
276 /*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
277 int index_vector_size = scatter_indices->shape().dimensions(1);
278 TF_ASSIGN_OR_RETURN(
279 HloInstruction * index_vector_2d,
280 MakeDynamicSliceHlo(scatter_indices, index_into_scatter_indices,
281 {1, index_vector_size}));
282 TF_ASSIGN_OR_RETURN(index_vector,
283 ElideDegenerateDims(index_vector_2d, {0}));
284 }
285 TF_ASSIGN_OR_RETURN(
286 HloInstruction * scatter_slice_start,
287 ExpandIndexVectorIntoOperandSpace(
288 index_vector, dim_numbers, operands[0]->shape().dimensions_size()));
289
290 // Extract the slice to be used to update from `updates` tensor for the
291 // induction_var corresponding to this iteration of the while loop.
292 TF_ASSIGN_OR_RETURN(
293 HloInstruction * index_into_updates,
294 PadVectorWithZeros(
295 induction_var_as_vector, /*zeros_to_prepend=*/0,
296 /*zeros_to_append=*/updates[0]->shape().dimensions_size() - 1));
297 std::vector<int64_t> update_slice_bounds(
298 updates[0]->shape().dimensions().begin(),
299 updates[0]->shape().dimensions().end());
300 update_slice_bounds[0] = 1;
301
302 absl::InlinedVector<HloInstruction*, 2> map_operands(
303 operands.size() + updates.size(), nullptr);
304 auto operand_slices_to_update =
305 absl::MakeSpan(map_operands).first(operands.size());
306 auto update_slices_with_dims_inserted =
307 absl::MakeSpan(map_operands).last(updates.size());
308 absl::Span<const int64_t> actual_update_slice_dims;
309 for (int i = 0, n = operands.size(); i < n; ++i) {
310 HloInstruction* update = updates[i];
311 TF_ASSIGN_OR_RETURN(
312 HloInstruction * update_slice,
313 MakeDynamicSliceHlo(update, index_into_updates, update_slice_bounds));
314 TF_ASSIGN_OR_RETURN(HloInstruction * update_slice_for_scatter,
315 ElideDegenerateDims(update_slice, {0}));
316 TF_ASSIGN_OR_RETURN(
317 HloInstruction * update_slice_with_dims_inserted,
318 InsertDegenerateDims(update_slice_for_scatter,
319 dim_numbers.inserted_window_dims()));
320 update_slices_with_dims_inserted[i] = update_slice_with_dims_inserted;
321 // Note that the following transformation assumes that both DynamicSlice and
322 // DynamicUpdateSlice follow the same semantics for OOB indices. For
323 // example, if there are negative indices and DynamicSlice uses "clamping"
324 // semantics, then the extracted data will be "shifted". Since
325 // DynamicUpdateSlice also follows the same "clamping" semantics, writing
326 // the update will also be "shifted" by exactly the same amount. So, this
327 // transformation is correct as long as the semantics of handling OOB
328 // indices remain the same in DynamicSlice and DynamicUpdateSlice.
329
330 // Extract the slice to update from `operand` tensor.
331 HloInstruction* operand = operands[i];
332 const Shape& update_slice_shape = update_slice_with_dims_inserted->shape();
333 TF_ASSIGN_OR_RETURN(HloInstruction * operand_slice_to_update,
334 MakeDynamicSliceHlo(operand, scatter_slice_start,
335 update_slice_shape.dimensions()));
336 operand_slices_to_update[i] = operand_slice_to_update;
337 if (i == 0) {
338 actual_update_slice_dims = update_slice_shape.dimensions();
339 } else {
340 TF_RET_CHECK(actual_update_slice_dims == update_slice_shape.dimensions());
341 }
342 }
343
344 TF_ASSIGN_OR_RETURN(
345 HloInstruction * is_index_valid,
346 CheckIndexValidity(operands[0]->parent(), scatter_slice_start,
347 operands[0]->shape().dimensions(),
348 actual_update_slice_dims, scatter->GetModule()));
349
350 // Write the updated value of the slice into `operand` tensor.
351 std::vector<HloInstruction*> updated_loop_state;
352 updated_loop_state.reserve(loop_state.size());
353 for (int i = 0, n = operands.size(); i < n; ++i) {
354 // Compute the new value for the slice to be updated in `operand` tensor by
355 // combining the existing value and the update value using the update
356 // computation.
357 // NOTE: For scatters with N outputs, we currently have duplicate the Map
358 // computation N times because we don't support multioutput Map yet.
359 TF_ASSIGN_OR_RETURN(HloComputation * to_apply,
360 CallAndGetOutput(scatter->to_apply(), i));
361 TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand_slice,
362 MakeMapHlo(map_operands, to_apply));
363 // Select the updated operand only if the index is valid. If not, select the
364 // original value.
365 TF_ASSIGN_OR_RETURN(HloInstruction * updates_to_apply,
366 MakeSelectHlo(is_index_valid, updated_operand_slice,
367 operand_slices_to_update[i]));
368 TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand,
369 MakeDynamicUpdateSliceHlo(operands[i], updates_to_apply,
370 scatter_slice_start));
371 updated_loop_state.push_back(updated_operand);
372 }
373 updated_loop_state.push_back(scatter_indices);
374 absl::c_copy(updates, std::back_inserter(updated_loop_state));
375
376 return updated_loop_state;
377 }
378
ScatterTripCount(const HloScatterInstruction * scatter)379 static int64_t ScatterTripCount(const HloScatterInstruction* scatter) {
380 // Compute the trip count for the while loop to be used for scatter. This
381 // should be the number of indices we should scatter into the operand.
382 const HloInstruction* scatter_indices = scatter->scatter_indices();
383 const Shape& scatter_indices_shape = scatter_indices->shape();
384 const ScatterDimensionNumbers& dim_numbers =
385 scatter->scatter_dimension_numbers();
386 int64_t scatter_loop_trip_count = 1;
387 for (int64_t i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
388 if (i != dim_numbers.index_vector_dim()) {
389 scatter_loop_trip_count *= scatter_indices_shape.dimensions(i);
390 }
391 }
392 return scatter_loop_trip_count;
393 }
394
395 // High Level Algorithm.
396 //
397 // 1. Canonicalize the scatter_indices tensor such that it has rank 2, where
398 // each row is an index into the operand.
399 // 2. Canonicalize the updates tensor such that it has rank `num_window_dims+1`
400 // and the scatter dim is the most-major dimension.
401 // 3. Iterate over the set of indices in the canonicalized scatter_indices
402 // tensor using a while loop, updating the operand for each such index. Each
403 // iteration of this while loop performs the following:
404 // a. Pick the index from scatter_indices for this iteration.
405 //  b. Transform this index into an index into the operand space.
406 // c. Extract the slice to be used to update from the updates tensor.
407 // d. Extract the slice to update from the operand tensor.
408 // e. Compute the new value for the slice to update by combining the slices
409 // from c. and d. using the update_computation of scatter.
410 // f. Write the updated value of the slice into the operand tensor.
411
ExpandInstruction(HloInstruction * inst)412 StatusOr<HloInstruction*> ScatterExpander::ExpandInstruction(
413 HloInstruction* inst) {
414 auto* scatter = Cast<HloScatterInstruction>(inst);
415 auto scatter_operands = scatter->scatter_operands();
416 HloInstruction* scatter_indices = scatter->scatter_indices();
417 auto scatter_updates = scatter->scatter_updates();
418 const ScatterDimensionNumbers& dim_numbers =
419 scatter->scatter_dimension_numbers();
420
421 // If the updates tensors are empty, there is no need to update the operands.
422 // The operands can be forwarded.
423 if (ShapeUtil::IsZeroElementArray(scatter_updates[0]->shape())) {
424 if (scatter_operands.size() == 1) {
425 return scatter_operands[0];
426 }
427 return scatter->parent()->AddInstruction(
428 HloInstruction::CreateTuple(scatter_operands));
429 }
430
431 // Compute the trip count for the while loop to be used for scatter. This
432 // should be the number of indices we should scatter into the operand.
433 int64_t scatter_loop_trip_count = ScatterTripCount(scatter);
434 if (!IsInt32(scatter_loop_trip_count)) {
435 return Unimplemented(
436 "Scatter operations with more than 2147483647 scatter indices are not "
437 "supported. This error occurred for %s.",
438 scatter->ToString());
439 }
440
441 // Canonicalize the scatter_indices, after which the size of its most-major
442 // dimension must be same as the while loop trip count.
443 TF_ASSIGN_OR_RETURN(HloInstruction * canonical_scatter_indices,
444 CanonicalizeScatterIndices(
445 scatter_indices, dim_numbers.index_vector_dim()));
446 CHECK_EQ(scatter_loop_trip_count,
447 canonical_scatter_indices->shape().dimensions(0));
448
449 // Canonicalize the updates, after which the size of its most-major dimension
450 // must be same as the while loop trip count.
451 std::vector<HloInstruction*> adjusted_canonical_updates;
452 adjusted_canonical_updates.reserve(scatter_updates.size());
453 for (HloInstruction* update : scatter_updates) {
454 TF_ASSIGN_OR_RETURN(
455 HloInstruction * canonical_update,
456 PermuteScatterAndWindowDims(update, dim_numbers.update_window_dims()));
457 TF_ASSIGN_OR_RETURN(
458 HloInstruction * adjusted_canonical_update,
459 AdjustScatterDims(scatter_indices->shape(), canonical_update,
460 dim_numbers.index_vector_dim()));
461 CHECK_EQ(scatter_loop_trip_count,
462 adjusted_canonical_update->shape().dimensions(0));
463 adjusted_canonical_updates.push_back(adjusted_canonical_update);
464 }
465
466 // The while loop that implements the scatter operation.
467 std::vector<HloInstruction*> loop_state;
468 loop_state.reserve(scatter->operand_count());
469 absl::c_copy(scatter_operands, std::back_inserter(loop_state));
470 loop_state.push_back(canonical_scatter_indices);
471 absl::c_copy(adjusted_canonical_updates, std::back_inserter(loop_state));
472 StatusOr<std::vector<HloInstruction*>> scatter_loop_result_status =
473 WhileUtil::MakeCountedLoop(
474 scatter->parent(), scatter_loop_trip_count, loop_state,
475 [scatter](HloInstruction* induction_var,
476 const std::vector<HloInstruction*>& loop_state) {
477 return ScatterLoopBody(scatter, induction_var, loop_state);
478 },
479 scatter->metadata());
480 TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> scatter_loop_result,
481 scatter_loop_result_status);
482 auto results =
483 absl::MakeSpan(scatter_loop_result).first(scatter_operands.size());
484 return MaybeMakeTuple(results);
485 }
486
InstructionMatchesPattern(HloInstruction * inst)487 bool ScatterExpander::InstructionMatchesPattern(HloInstruction* inst) {
488 auto* scatter = DynCast<HloScatterInstruction>(inst);
489 return scatter &&
490 (mode_ == kEliminateAllScatters || ScatterTripCount(scatter) == 1);
491 }
492
493 } // namespace xla
494