1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/tests/test_utils.h"
17
18 #include <cmath>
19 #include <memory>
20
21 #include "absl/base/casts.h"
22 #include "tensorflow/compiler/xla/literal_util.h"
23 #include "tensorflow/compiler/xla/primitive_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
28 #include "tensorflow/compiler/xla/service/transfer_manager.h"
29
30 namespace xla {
31
32 namespace {
33
34 template <typename FloatT, typename GeneratorT>
PopulateWithRandomFloatingPointData(Literal * literal,std::minstd_rand0 * engine)35 void PopulateWithRandomFloatingPointData(Literal* literal,
36 std::minstd_rand0* engine) {
37 std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f);
38 for (FloatT& value : literal->data<FloatT>()) {
39 value = static_cast<FloatT>(generator(*engine));
40 }
41 }
42
43 // Populates a floating point literal with random floating points sampled from a
44 // uniform-log distribution spanning approximately the entire range of the
45 // representable floating point.
46 template <typename FloatT>
PopulateWithRandomFullRangeFloatingPointData(Literal * literal,std::minstd_rand0 * engine)47 void PopulateWithRandomFullRangeFloatingPointData(Literal* literal,
48 std::minstd_rand0* engine) {
49 constexpr float kSpecialValueProbability = 1e-6;
50 constexpr float kSpecialValues[] = {+0.F,
51 -0.F,
52 1.F,
53 -1.F,
54 std::numeric_limits<float>::infinity(),
55 -std::numeric_limits<float>::infinity()};
56 constexpr int kNumSpecialValues = sizeof(kSpecialValues) / sizeof(float);
57 std::uniform_real_distribution<float> special_value_gen(0, 1);
58
59 // Generates floating points with a log-uniform distribution. This causes the
60 // exponent of the floating point to have a uniform distribution.
61 int min_exp, max_exp;
62 if (std::is_same<FloatT, bfloat16>()) {
63 min_exp = std::numeric_limits<float>::min_exponent;
64 max_exp = std::numeric_limits<float>::max_exponent;
65 } else {
66 min_exp = std::numeric_limits<FloatT>::min_exponent;
67 max_exp = std::numeric_limits<FloatT>::max_exponent;
68 }
69 std::uniform_real_distribution<double> generator(min_exp - 1, max_exp - 1);
70
71 for (FloatT& value : literal->data<FloatT>()) {
72 // Each special value has a kSpecialValueProbability chance to be generated
73 // instead of sampling using the normal distributions.
74 if (special_value_gen(*engine) <
75 kSpecialValueProbability * kNumSpecialValues) {
76 value =
77 static_cast<FloatT>(kSpecialValues[(*engine)() % kNumSpecialValues]);
78 } else {
79 float sign = ((*engine)() % 2 == 0) ? 1 : -1;
80 value = static_cast<FloatT>(pow(2, generator(*engine)) * sign);
81 }
82 }
83 }
84
// Fills `literal` with bit patterns derived from an incrementing integer
// counter. Only declared here; the behavior is provided by the explicit
// specializations for the individual 16-bit floating point types.
template <typename FloatT>
void PopulateWithIntNext(Literal* literal);
87
88 template <>
PopulateWithIntNext(Literal * literal)89 void PopulateWithIntNext<half>(Literal* literal) {
90 // Duplicates may be generated if we don't have enough bits.
91 uint16_t next_value = 0;
92 for (half& value : literal->data<half>()) {
93 // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
94 // the sign bit. We could be less wasteful, but this is best-effort anyway.
95 uint16_t exponent_msb = next_value & 0x4000;
96 value = Eigen::numext::bit_cast<half, uint16_t>((next_value & 0xBFFF) |
97 (exponent_msb << 1));
98 next_value++;
99 }
100 }
101
102 template <>
PopulateWithIntNext(Literal * literal)103 void PopulateWithIntNext<bfloat16>(Literal* literal) {
104 // Duplicates may be generated if we don't have enough bits.
105 // Start at 0x80 rather than 0 to avoid denormals.
106 uint16_t next_value = 0x80;
107 for (bfloat16& value : literal->data<bfloat16>()) {
108 // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
109 // the sign bit. We could be less wasteful, but this is best-effort anyway.
110 uint16_t exponent_msb = next_value & 0x4000;
111 value = Eigen::numext::bit_cast<bfloat16, uint16_t>((next_value & 0xBFFF) |
112 (exponent_msb << 1));
113 next_value++;
114 }
115 }
116
117 template <typename FloatT>
PopulateWithNextAfter(Literal * literal)118 void PopulateWithNextAfter(Literal* literal) {
119 // Duplicates may be generated if the number of elements in the literal
120 // exceeds the number of positive values supported by the type.
121 float next_value = std::numeric_limits<float>::min();
122 for (float& value : literal->data<float>()) {
123 value = next_value;
124 next_value = std::nextafter(next_value, std::numeric_limits<float>::max());
125 }
126 }
127
128 template <typename FloatT,
129 typename std::enable_if<std::is_same<bfloat16, FloatT>::value ||
130 std::is_same<half, FloatT>::value,
131 int>::type = 0>
PopulateWithNoDuplicateData(Literal * literal,std::minstd_rand0 * engine)132 void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
133 PopulateWithIntNext<FloatT>(literal);
134 std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
135 *engine);
136 }
137
138 template <typename FloatT,
139 typename std::enable_if<!std::is_same<bfloat16, FloatT>::value &&
140 !std::is_same<half, FloatT>::value,
141 int>::type = 0>
PopulateWithNoDuplicateData(Literal * literal,std::minstd_rand0 * engine)142 void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
143 PopulateWithNextAfter<FloatT>(literal);
144 std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
145 *engine);
146 }
147
148 template <typename FloatT>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)149 void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine,
150 bool no_duplicates, bool use_large_range) {
151 CHECK(engine != nullptr);
152 CHECK_EQ(literal->shape().element_type(),
153 primitive_util::NativeToPrimitiveType<FloatT>());
154 if (no_duplicates) {
155 PopulateWithNoDuplicateData<FloatT>(literal, engine);
156 } else if (use_large_range) {
157 PopulateWithRandomFullRangeFloatingPointData<FloatT>(literal, engine);
158 } else {
159 PopulateWithRandomFloatingPointData<FloatT, FloatT>(literal, engine);
160 }
161 }
162
163 template <typename ComplexT>
PopulateWithComplexData(Literal * result,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)164 void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine,
165 bool no_duplicates, bool use_large_range) {
166 using InnerFloatT = typename ComplexT::value_type;
167 CHECK(engine != nullptr);
168 CHECK_EQ(result->shape().element_type(),
169 primitive_util::NativeToPrimitiveType<ComplexT>());
170 Shape floating_point_shape = ShapeUtil::ChangeElementType(
171 result->shape(), primitive_util::NativeToPrimitiveType<InnerFloatT>());
172 Literal real_lit(floating_point_shape);
173 Literal imaginary_lit(floating_point_shape);
174
175 PopulateWithFloatingPointData<InnerFloatT>(&real_lit, engine, no_duplicates,
176 use_large_range);
177 PopulateWithFloatingPointData<InnerFloatT>(&imaginary_lit, engine,
178 no_duplicates, use_large_range);
179
180 absl::Span<const InnerFloatT> real_data = real_lit.data<InnerFloatT>();
181 absl::Span<const InnerFloatT> imaginary_data =
182 imaginary_lit.data<InnerFloatT>();
183 absl::Span<ComplexT> result_data = result->data<ComplexT>();
184 for (int i = 0; i < real_lit.data<InnerFloatT>().size(); i++) {
185 result_data[i] = ComplexT(real_data[i], imaginary_data[i]);
186 }
187 }
188
189 template <>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)190 void PopulateWithFloatingPointData<half>(Literal* literal,
191 std::minstd_rand0* engine,
192 bool no_duplicates,
193 bool use_large_range) {
194 CHECK(engine != nullptr);
195 CHECK_EQ(literal->shape().element_type(),
196 primitive_util::NativeToPrimitiveType<half>());
197 if (no_duplicates) {
198 PopulateWithNoDuplicateData<half>(literal, engine);
199 } else if (use_large_range) {
200 PopulateWithRandomFullRangeFloatingPointData<half>(literal, engine);
201 } else {
202 PopulateWithRandomFloatingPointData<half, float>(literal, engine);
203 }
204 }
205
206 template <>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)207 void PopulateWithFloatingPointData<bfloat16>(Literal* literal,
208 std::minstd_rand0* engine,
209 bool no_duplicates,
210 bool use_large_range) {
211 CHECK(engine != nullptr);
212 CHECK_EQ(literal->shape().element_type(),
213 primitive_util::NativeToPrimitiveType<bfloat16>());
214 if (no_duplicates) {
215 PopulateWithNoDuplicateData<bfloat16>(literal, engine);
216 } else if (use_large_range) {
217 PopulateWithRandomFullRangeFloatingPointData<bfloat16>(literal, engine);
218 } else {
219 PopulateWithRandomFloatingPointData<bfloat16, float>(literal, engine);
220 }
221 }
222
// Maps an integer type to the type that should be used as the result type of
// std::uniform_int_distribution when generating values of that type.
//
// uniform_int_distribution is not defined for 8-bit integers, so those are
// mapped to the 16-bit type of the same signedness. Every other type maps to
// itself.
template <typename IntT>
struct RngT {
  using type = IntT;
};

// Signed 8-bit values are generated through a signed 16-bit distribution.
template <>
struct RngT<int8_t> { using type = int16_t; };

// Unsigned 8-bit values are generated through an unsigned 16-bit distribution.
template <>
struct RngT<uint8_t> { using type = uint16_t; };
239
240 template <typename IntT>
PopulateWithRandomIntegralData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates)241 void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
242 bool no_duplicates) {
243 CHECK(engine != nullptr);
244 CHECK_EQ(literal->shape().element_type(),
245 primitive_util::NativeToPrimitiveType<IntT>());
246 if (no_duplicates && ShapeUtil::ElementsIn(literal->shape()) <
247 std::numeric_limits<IntT>::max()) {
248 std::iota(literal->data<IntT>().begin(), literal->data<IntT>().end(), 0);
249 std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
250 *engine);
251 } else {
252 std::uniform_int_distribution<typename RngT<IntT>::type> generator(
253 std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
254 for (IntT& value : literal->data<IntT>()) {
255 value = generator(*engine);
256 }
257 }
258 }
259
260 // Similar to MakeFakeLiteral but takes a random number generator engine to
261 // enable reusing the engine across randomly generated literals. 'no_duplicates'
262 // indicates that there should be no duplicate values in each generated
263 // array. This is uniqueness is best-effort only. Some types (half and bfloat16)
264 // are not supported and uniqueness cannot be guaranteed if the number of
265 // elements exceeds the number of different values supported by the type.
MakeFakeLiteralInternal(const Shape & shape,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)266 StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
267 std::minstd_rand0* engine,
268 bool no_duplicates,
269 bool use_large_range) {
270 if (shape.IsTuple()) {
271 std::vector<Literal> elements;
272 const auto& shape_tuple_shapes = shape.tuple_shapes();
273 elements.reserve(shape_tuple_shapes.size());
274 for (const Shape& element_shape : shape_tuple_shapes) {
275 TF_ASSIGN_OR_RETURN(Literal element, MakeFakeLiteralInternal(
276 element_shape, engine,
277 no_duplicates, use_large_range));
278 elements.push_back(std::move(element));
279 }
280 return LiteralUtil::MakeTupleOwned(std::move(elements));
281 }
282 if (engine == nullptr) {
283 return Literal::CreateFromShape(shape);
284 }
285 // Clear tiles/element size in shape's layout before using it for creating
286 // literal.
287 Shape new_shape = shape;
288 new_shape.mutable_layout()->clear_tiles();
289 new_shape.mutable_layout()->set_element_size_in_bits(0);
290 Literal literal(new_shape);
291 switch (shape.element_type()) {
292 case BF16:
293 PopulateWithFloatingPointData<bfloat16>(&literal, engine, no_duplicates,
294 use_large_range);
295 break;
296 case F16:
297 PopulateWithFloatingPointData<half>(&literal, engine, no_duplicates,
298 use_large_range);
299 break;
300 case F32:
301 PopulateWithFloatingPointData<float>(&literal, engine, no_duplicates,
302 use_large_range);
303 break;
304 case F64:
305 PopulateWithFloatingPointData<double>(&literal, engine, no_duplicates,
306 use_large_range);
307 break;
308 case S8:
309 PopulateWithRandomIntegralData<int8_t>(&literal, engine, no_duplicates);
310 break;
311 case U8:
312 PopulateWithRandomIntegralData<uint8_t>(&literal, engine, no_duplicates);
313 break;
314 case S16:
315 PopulateWithRandomIntegralData<int16_t>(&literal, engine, no_duplicates);
316 break;
317 case U16:
318 PopulateWithRandomIntegralData<uint16_t>(&literal, engine, no_duplicates);
319 break;
320 case S32:
321 PopulateWithRandomIntegralData<int32_t>(&literal, engine, no_duplicates);
322 break;
323 case U32:
324 PopulateWithRandomIntegralData<uint32_t>(&literal, engine, no_duplicates);
325 break;
326 case S64:
327 PopulateWithRandomIntegralData<int64_t>(&literal, engine, no_duplicates);
328 break;
329 case U64:
330 PopulateWithRandomIntegralData<uint64_t>(&literal, engine, no_duplicates);
331 break;
332 case C64:
333 PopulateWithComplexData<complex64>(&literal, engine, no_duplicates,
334 use_large_range);
335 break;
336 case C128:
337 PopulateWithComplexData<complex128>(&literal, engine, no_duplicates,
338 use_large_range);
339 break;
340 case PRED: {
341 std::uniform_int_distribution<int> generator(0, 1);
342 TF_CHECK_OK(
343 literal.Populate<bool>([&](absl::Span<const int64_t> /*indices*/) {
344 return generator(*engine);
345 }));
346 break;
347 }
348 default:
349 return Unimplemented("Unsupported type for fake literal generation: %s",
350 ShapeUtil::HumanString(shape));
351 }
352 return std::move(literal);
353 }
354
355 template <typename IntT>
PopulateWithRandomIntegralDataWithBounds(Literal * literal,std::minstd_rand0 * engine,IntT min,IntT max)356 void PopulateWithRandomIntegralDataWithBounds(Literal* literal,
357 std::minstd_rand0* engine,
358 IntT min, IntT max) {
359 CHECK(engine != nullptr);
360 CHECK_EQ(literal->shape().element_type(),
361 primitive_util::NativeToPrimitiveType<IntT>());
362 std::uniform_int_distribution<typename RngT<IntT>::type> generator(min, max);
363 for (IntT& value : literal->data<IntT>()) {
364 value = generator(*engine);
365 }
366 }
367
368 // Same as MakeFakeLiteralInternal but generates random numbers in the given
369 // range [min, max]. Currently this works only for INT types.
MakeFakeLiteralInternalWithBounds(const Shape & shape,std::minstd_rand0 * engine,int64_t min,int64_t max,bool is_sorted)370 StatusOr<Literal> MakeFakeLiteralInternalWithBounds(const Shape& shape,
371 std::minstd_rand0* engine,
372 int64_t min, int64_t max,
373 bool is_sorted) {
374 if (shape.IsTuple()) {
375 std::vector<Literal> elements;
376 const auto& shape_tuple_shapes = shape.tuple_shapes();
377 elements.reserve(shape_tuple_shapes.size());
378 for (const Shape& element_shape : shape_tuple_shapes) {
379 TF_ASSIGN_OR_RETURN(Literal element,
380 MakeFakeLiteralInternalWithBounds(
381 element_shape, engine, min, max, is_sorted));
382 elements.push_back(std::move(element));
383 }
384 return LiteralUtil::MakeTupleOwned(std::move(elements));
385 }
386 if (engine == nullptr) {
387 return Literal::CreateFromShape(shape);
388 }
389 // Clear tiles/element size in shape's layout before using it for creating
390 // literal.
391 Shape new_shape = shape;
392 new_shape.mutable_layout()->clear_tiles();
393 new_shape.mutable_layout()->set_element_size_in_bits(0);
394 Literal literal(new_shape);
395 switch (shape.element_type()) {
396 case S8:
397 PopulateWithRandomIntegralDataWithBounds<int8_t>(
398 &literal, engine, static_cast<int8_t>(min), static_cast<int8_t>(max));
399 if (is_sorted) {
400 std::sort(literal.data<int8_t>().begin(), literal.data<int8_t>().end());
401 }
402 break;
403 case U8:
404 PopulateWithRandomIntegralDataWithBounds<uint8_t>(
405 &literal, engine, static_cast<uint8_t>(min),
406 static_cast<uint8_t>(max));
407 if (is_sorted) {
408 std::sort(literal.data<uint8_t>().begin(),
409 literal.data<uint8_t>().end());
410 }
411 break;
412 case S16:
413 PopulateWithRandomIntegralDataWithBounds<int16_t>(
414 &literal, engine, static_cast<int16_t>(min),
415 static_cast<int16_t>(max));
416 if (is_sorted) {
417 std::sort(literal.data<int16_t>().begin(),
418 literal.data<int16_t>().end());
419 }
420 break;
421 case U16:
422 PopulateWithRandomIntegralDataWithBounds<uint16_t>(
423 &literal, engine, static_cast<uint16_t>(min),
424 static_cast<uint16_t>(max));
425 if (is_sorted) {
426 std::sort(literal.data<uint16_t>().begin(),
427 literal.data<uint16_t>().end());
428 }
429 break;
430 case S32:
431 PopulateWithRandomIntegralDataWithBounds<int32_t>(
432 &literal, engine, static_cast<int32_t>(min),
433 static_cast<int32_t>(max));
434 if (is_sorted) {
435 std::sort(literal.data<int32_t>().begin(),
436 literal.data<int32_t>().end());
437 }
438 break;
439 case U32:
440 PopulateWithRandomIntegralDataWithBounds<uint32_t>(
441 &literal, engine, static_cast<uint32_t>(min),
442 static_cast<uint32_t>(max));
443 if (is_sorted) {
444 std::sort(literal.data<uint32_t>().begin(),
445 literal.data<uint32_t>().end());
446 }
447 break;
448 case S64:
449 PopulateWithRandomIntegralDataWithBounds<int64_t>(
450 &literal, engine, static_cast<int64_t>(min),
451 static_cast<int64_t>(max));
452 if (is_sorted) {
453 std::sort(literal.data<int64_t>().begin(),
454 literal.data<int64_t>().end());
455 }
456 break;
457 case U64:
458 PopulateWithRandomIntegralDataWithBounds<uint64_t>(
459 &literal, engine, static_cast<uint64_t>(min),
460 static_cast<uint64_t>(max));
461 if (is_sorted) {
462 std::sort(literal.data<uint64_t>().begin(),
463 literal.data<uint64_t>().end());
464 }
465 break;
466 default:
467 return Unimplemented(
468 "Unsupported type for fake random literal generation with bounds: %s",
469 ShapeUtil::HumanString(shape));
470 }
471 return std::move(literal);
472 }
473
// Kind of constant a computation needs as its initial value.
enum class ConstantType {
  kUnknown,  // The required constant could not be determined.
  kZero,     // The value zero.
  kOne,      // The value one.
};
475
476 // Return the constant type required by this computation, if known.
GetInitValue(const HloComputation & computation)477 ConstantType GetInitValue(const HloComputation& computation) {
478 // TODO(b/77635120): Add init values, for min, max, and their arg variants.
479 const HloInstruction* const root = computation.root_instruction();
480 if (computation.num_parameters() != 2 || root->operand_count() != 2 ||
481 root->operand(0)->opcode() != HloOpcode::kParameter ||
482 root->operand(1)->opcode() != HloOpcode::kParameter ||
483 root->operand(0) == root->operand(1)) {
484 return ConstantType::kUnknown;
485 }
486
487 switch (root->opcode()) {
488 case HloOpcode::kAdd:
489 return ConstantType::kZero;
490 case HloOpcode::kMultiply:
491 return ConstantType::kOne;
492 default:
493 return ConstantType::kUnknown;
494 }
495 }
496
497 // Reduce, ReduceWindow, and SelectAndScatter ops may need a non-random
498 // initialization value.
NeedsInitValue(const HloUse & use)499 bool NeedsInitValue(const HloUse& use) {
500 const HloInstruction* const instruction = use.instruction;
501 const HloOpcode opcode = instruction->opcode();
502 const int64_t op_num = use.operand_number;
503 return ((opcode == HloOpcode::kReduceWindow && op_num == 1) ||
504 (opcode == HloOpcode::kSelectAndScatter && op_num == 2) ||
505 (opcode == HloOpcode::kReduce &&
506 op_num >= instruction->operand_count() / 2));
507 }
508
509 // Generate random values that are constrained to the input_shape minus the
510 // output_shape so as not to produce wrapping slices, for instance.
MakeRandomIndex(int64_t index_bound,std::minstd_rand0 * engine)511 Literal MakeRandomIndex(int64_t index_bound, std::minstd_rand0* engine) {
512 std::uniform_int_distribution<int32_t> generator(0, index_bound);
513 return LiteralUtil::CreateR0<int32_t>(generator(*engine));
514 }
515
516 // Returns true if `dest' is reachable from `src' through data-formatting and
517 // custom call instructions within the same computation.
ReachableViaDataFormatting(const HloInstruction * src,const HloInstruction * dest)518 bool ReachableViaDataFormatting(const HloInstruction* src,
519 const HloInstruction* dest) {
520 if (src == dest) {
521 return true;
522 }
523 switch (dest->opcode()) {
524 case HloOpcode::kReshape:
525 case HloOpcode::kTranspose:
526 case HloOpcode::kCopy:
527 case HloOpcode::kSlice:
528 break;
529 case HloOpcode::kCustomCall:
530 if (dest->custom_call_target() == "AssumeGatherIndicesInBound") {
531 break;
532 }
533 return false;
534 default:
535 return false;
536 }
537 for (const auto* operand : dest->operands()) {
538 if (ReachableViaDataFormatting(src, operand)) {
539 return true;
540 }
541 }
542 return false;
543 }
544
545 // Use dataflow analysis on each parameter to see if there are uses that would
546 // be problematic when generating input data. Returns the list of instructions
547 // that correspond to their uses.
548 //
549 // Should be paired with the CreateLiteralForConstrainedUses() function below.
FindConstrainedUses(const HloDataflowAnalysis & dataflow,const HloInstruction & param)550 std::vector<HloInstruction*> FindConstrainedUses(
551 const HloDataflowAnalysis& dataflow, const HloInstruction& param) {
552 std::vector<HloInstruction*> constrained_uses;
553 for (const auto& pair : dataflow.GetInstructionValueSet(¶m)) {
554 const HloValue& value = dataflow.GetUniqueValueAt(¶m, pair.first);
555 for (const HloUse& use : value.GetUses()) {
556 HloInstruction* instruction = use.instruction;
557 const HloOpcode opcode = instruction->opcode();
558 const int64_t op_num = use.operand_number;
559 if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) ||
560 (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) {
561 constrained_uses.push_back(instruction);
562 } else if ((opcode == HloOpcode::kGather ||
563 opcode == HloOpcode::kScatter) &&
564 op_num == 1) {
565 constrained_uses.push_back(instruction);
566 } else if (opcode == HloOpcode::kFusion) {
567 const HloInstruction* const to_analyze =
568 instruction->fused_parameter(op_num);
569 auto fused_uses = FindConstrainedUses(dataflow, *to_analyze);
570 constrained_uses.insert(constrained_uses.end(), fused_uses.begin(),
571 fused_uses.end());
572 } else if (NeedsInitValue(use)) {
573 constrained_uses.push_back(instruction);
574 } else if (opcode == HloOpcode::kConvert ||
575 opcode == HloOpcode::kReducePrecision) {
576 auto converted_uses = FindConstrainedUses(dataflow, *instruction);
577 constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
578 converted_uses.end());
579 } else if (opcode == HloOpcode::kSort &&
580 instruction->operand_count() >= 2 && op_num == 0) {
581 // Operand 0 of sort is the array of keys used for key/value
582 // (two-operand) kSort instructions. Since sort stability is not
583 // guaranteed, constrain keys of key-value sort not to have duplicates,
584 // since otherwise the value order may legitimately differ.
585 constrained_uses.push_back(instruction);
586 }
587 }
588 }
589
590 for (auto* instruction : param.parent()->instructions()) {
591 const HloOpcode opcode = instruction->opcode();
592 if (opcode == HloOpcode::kGather || opcode == HloOpcode::kScatter) {
593 if (instruction->operand(1) == ¶m) {
594 // Above already covers this case.
595 continue;
596 }
597 if (ReachableViaDataFormatting(¶m, instruction->operand(1))) {
598 constrained_uses.push_back(instruction);
599 }
600 }
601 }
602 return constrained_uses;
603 }
604
605 // Given a parameter, generate a random Literal to use as input if there exist
606 // no constrained uses in the dataflow graph. If such constraints exist,
607 // generate a constrained literal (either bounded in the case of indices, or
608 // zero in the case of init_values for reductions).
CreateLiteralForConstrainedUses(const absl::Span<HloInstruction * const> constrained_uses,const HloInstruction & param,const Shape & param_shape,std::minstd_rand0 * engine,bool use_large_range)609 StatusOr<Literal> CreateLiteralForConstrainedUses(
610 const absl::Span<HloInstruction* const> constrained_uses,
611 const HloInstruction& param, const Shape& param_shape,
612 std::minstd_rand0* engine, bool use_large_range) {
613 int64_t index_bound = INT64_MAX;
614 bool no_duplicates = false;
615 bool needs_constant = false;
616 bool needs_sorted_indices = false;
617 ConstantType constant_type = ConstantType::kUnknown;
618 for (HloInstruction* use : constrained_uses) {
619 switch (use->opcode()) {
620 case HloOpcode::kDynamicSlice:
621 case HloOpcode::kDynamicUpdateSlice: {
622 const Shape& indexed_shape = use->operand(0)->shape();
623 const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
624 ? use->shape()
625 : use->operand(1)->shape();
626 const int64_t first_index =
627 Cast<HloDynamicIndexInstruction>(use)->first_index_operand_number();
628 for (int64_t operand = first_index; operand < use->operand_count();
629 ++operand) {
630 if (use->operand(operand) == ¶m) {
631 index_bound = std::min(
632 index_bound,
633 ShapeUtil::GetDimension(indexed_shape, operand - first_index) -
634 ShapeUtil::GetDimension(slice_shape,
635 operand - first_index));
636 }
637 }
638 break;
639 }
640 case HloOpcode::kGather:
641 case HloOpcode::kScatter: {
642 const Shape& operand_shape = use->operand(0)->shape();
643 auto index_map = use->opcode() == HloOpcode::kGather
644 ? use->gather_dimension_numbers().start_index_map()
645 : use->scatter_dimension_numbers()
646 .scatter_dims_to_operand_dims();
647 for (const auto dim_in_operand : index_map) {
648 index_bound = std::min(index_bound,
649 operand_shape.dimensions(dim_in_operand) - 1);
650 }
651 if (use->opcode() == HloOpcode::kScatter) {
652 needs_sorted_indices |=
653 Cast<const HloScatterInstruction>(use)->indices_are_sorted();
654 } else {
655 needs_sorted_indices |=
656 Cast<const HloGatherInstruction>(use)->indices_are_sorted();
657 }
658 break;
659 }
660 case HloOpcode::kReduce:
661 case HloOpcode::kReduceWindow:
662 needs_constant = true;
663 constant_type = GetInitValue(*use->to_apply());
664 break;
665
666 case HloOpcode::kSelectAndScatter:
667 needs_constant = true;
668 constant_type = GetInitValue(*use->scatter());
669 break;
670
671 case HloOpcode::kSort:
672 no_duplicates = true;
673 break;
674
675 default:
676 return Unimplemented(
677 "Constrained operand generation not implemented for %s.",
678 use->ToString());
679 }
680 }
681 int constraint_count = 0;
682 constraint_count += no_duplicates ? 1 : 0;
683 constraint_count += (index_bound != INT64_MAX) ? 1 : 0;
684 constraint_count += needs_constant ? 1 : 0;
685 if (constraint_count > 1) {
686 return Unimplemented("Conflicting operand generation constraints.");
687 }
688 if (index_bound != INT64_MAX) {
689 return MakeFakeLiteralInternalWithBounds(param_shape, engine, 0,
690 index_bound, needs_sorted_indices);
691 } else if (needs_constant) {
692 switch (constant_type) {
693 case ConstantType::kZero:
694 return LiteralUtil::Zero(param_shape.element_type());
695 case ConstantType::kOne:
696 return LiteralUtil::One(param_shape.element_type());
697 case ConstantType::kUnknown:
698 // We want the identity element for the computation, but we don't really
699 // know what it is - so any value we generate will be just as wrong.
700 return MakeFakeLiteralInternal(param_shape, engine,
701 /*no_duplicates=*/false,
702 use_large_range);
703 }
704 } else {
705 return MakeFakeLiteralInternal(param_shape, engine, no_duplicates,
706 use_large_range);
707 }
708 }
709
710 // Given a module entry parameter, use the dataflow analysis to see if a
711 // special case literal must be created, or if we can generate fake data.
MakeConstrainedArgument(const HloDataflowAnalysis & dataflow,const HloInstruction & param,const Shape & param_shape,std::minstd_rand0 * engine,bool use_large_range)712 StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
713 const HloInstruction& param,
714 const Shape& param_shape,
715 std::minstd_rand0* engine,
716 bool use_large_range) {
717 const auto constrained_uses = FindConstrainedUses(dataflow, param);
718 return CreateLiteralForConstrainedUses(constrained_uses, param, param_shape,
719 engine, use_large_range);
720 }
721
722 } // namespace
723
MakeFakeLiteral(const Shape & shape,bool pseudo_random,bool use_large_range)724 StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random,
725 bool use_large_range) {
726 auto engine = pseudo_random ? std::make_unique<std::minstd_rand0>() : nullptr;
727 return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false,
728 use_large_range);
729 }
730
MakeFakeArguments(const HloModule * module,bool pseudo_random,bool use_large_range)731 StatusOr<std::vector<Literal>> MakeFakeArguments(const HloModule* module,
732 bool pseudo_random,
733 bool use_large_range) {
734 auto engine = pseudo_random ? std::make_unique<std::minstd_rand0>() : nullptr;
735 return MakeFakeArguments(module, engine.get(), use_large_range);
736 }
737
MakeFakeArguments(const HloModule * module,std::minstd_rand0 * engine,bool use_large_range)738 StatusOr<std::vector<Literal>> MakeFakeArguments(const HloModule* module,
739 std::minstd_rand0* engine,
740 bool use_large_range) {
741 TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
742 const auto params = module->entry_computation()->parameter_instructions();
743 std::vector<Literal> arguments(params.size());
744 for (int i = 0; i < params.size(); ++i) {
745 const HloModuleConfig& module_config = module->config();
746 const Shape& param_shape = (module_config.has_entry_computation_layout() &&
747 module_config.entry_computation_layout()
748 .parameter_layout(i)
749 .shape()
750 .is_static())
751 ? module_config.entry_computation_layout()
752 .parameter_layout(i)
753 .shape()
754 : params[i]->shape();
755
756 TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument(
757 *dataflow, *params[i], param_shape,
758 engine, use_large_range));
759 }
760 return std::move(arguments);
761 }
762
VerifyHloModule(HloModule * const module,bool layout_sensitive,bool allow_mixed_precision)763 Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
764 bool allow_mixed_precision) {
765 return HloVerifier(/*layout_sensitive=*/layout_sensitive,
766 /*allow_mixed_precision=*/allow_mixed_precision)
767 .Run(module)
768 .status();
769 }
770
CreateCanonicalDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs)771 std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
772 HloInstruction* lhs,
773 HloInstruction* rhs) {
774 CHECK_LE(lhs->shape().rank(), 2);
775 CHECK_LE(rhs->shape().rank(), 2);
776 PrecisionConfig precision_config;
777 precision_config.mutable_operand_precision()->Resize(
778 2, PrecisionConfig::DEFAULT);
779 DotDimensionNumbers dot_dimension_numbers;
780 dot_dimension_numbers.add_lhs_contracting_dimensions(
781 lhs->shape().rank() > 1 ? 1 : 0);
782 dot_dimension_numbers.add_rhs_contracting_dimensions(0);
783 return std::make_unique<HloDotInstruction>(
784 shape, lhs, rhs, dot_dimension_numbers, precision_config);
785 }
786 } // namespace xla
787