xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/test_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tests/test_utils.h"
17 
18 #include <cmath>
19 #include <memory>
20 
21 #include "absl/base/casts.h"
22 #include "tensorflow/compiler/xla/literal_util.h"
23 #include "tensorflow/compiler/xla/primitive_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
28 #include "tensorflow/compiler/xla/service/transfer_manager.h"
29 
30 namespace xla {
31 
32 namespace {
33 
34 template <typename FloatT, typename GeneratorT>
PopulateWithRandomFloatingPointData(Literal * literal,std::minstd_rand0 * engine)35 void PopulateWithRandomFloatingPointData(Literal* literal,
36                                          std::minstd_rand0* engine) {
37   std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f);
38   for (FloatT& value : literal->data<FloatT>()) {
39     value = static_cast<FloatT>(generator(*engine));
40   }
41 }
42 
43 // Populates a floating point literal with random floating points sampled from a
44 // uniform-log distribution spanning approximately the entire range of the
45 // representable floating point.
46 template <typename FloatT>
PopulateWithRandomFullRangeFloatingPointData(Literal * literal,std::minstd_rand0 * engine)47 void PopulateWithRandomFullRangeFloatingPointData(Literal* literal,
48                                                   std::minstd_rand0* engine) {
49   constexpr float kSpecialValueProbability = 1e-6;
50   constexpr float kSpecialValues[] = {+0.F,
51                                       -0.F,
52                                       1.F,
53                                       -1.F,
54                                       std::numeric_limits<float>::infinity(),
55                                       -std::numeric_limits<float>::infinity()};
56   constexpr int kNumSpecialValues = sizeof(kSpecialValues) / sizeof(float);
57   std::uniform_real_distribution<float> special_value_gen(0, 1);
58 
59   // Generates floating points with a log-uniform distribution. This causes the
60   // exponent of the floating point to have a uniform distribution.
61   int min_exp, max_exp;
62   if (std::is_same<FloatT, bfloat16>()) {
63     min_exp = std::numeric_limits<float>::min_exponent;
64     max_exp = std::numeric_limits<float>::max_exponent;
65   } else {
66     min_exp = std::numeric_limits<FloatT>::min_exponent;
67     max_exp = std::numeric_limits<FloatT>::max_exponent;
68   }
69   std::uniform_real_distribution<double> generator(min_exp - 1, max_exp - 1);
70 
71   for (FloatT& value : literal->data<FloatT>()) {
72     // Each special value has a kSpecialValueProbability chance to be generated
73     // instead of sampling using the normal distributions.
74     if (special_value_gen(*engine) <
75         kSpecialValueProbability * kNumSpecialValues) {
76       value =
77           static_cast<FloatT>(kSpecialValues[(*engine)() % kNumSpecialValues]);
78     } else {
79       float sign = ((*engine)() % 2 == 0) ? 1 : -1;
80       value = static_cast<FloatT>(pow(2, generator(*engine)) * sign);
81     }
82   }
83 }
84 
85 template <typename FloatT>
86 void PopulateWithIntNext(Literal* literal);
87 
88 template <>
PopulateWithIntNext(Literal * literal)89 void PopulateWithIntNext<half>(Literal* literal) {
90   // Duplicates may be generated if we don't have enough bits.
91   uint16_t next_value = 0;
92   for (half& value : literal->data<half>()) {
93     // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
94     // the sign bit. We could be less wasteful, but this is best-effort anyway.
95     uint16_t exponent_msb = next_value & 0x4000;
96     value = Eigen::numext::bit_cast<half, uint16_t>((next_value & 0xBFFF) |
97                                                     (exponent_msb << 1));
98     next_value++;
99   }
100 }
101 
102 template <>
PopulateWithIntNext(Literal * literal)103 void PopulateWithIntNext<bfloat16>(Literal* literal) {
104   // Duplicates may be generated if we don't have enough bits.
105   // Start at 0x80 rather than 0 to avoid denormals.
106   uint16_t next_value = 0x80;
107   for (bfloat16& value : literal->data<bfloat16>()) {
108     // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
109     // the sign bit. We could be less wasteful, but this is best-effort anyway.
110     uint16_t exponent_msb = next_value & 0x4000;
111     value = Eigen::numext::bit_cast<bfloat16, uint16_t>((next_value & 0xBFFF) |
112                                                         (exponent_msb << 1));
113     next_value++;
114   }
115 }
116 
117 template <typename FloatT>
PopulateWithNextAfter(Literal * literal)118 void PopulateWithNextAfter(Literal* literal) {
119   // Duplicates may be generated if the number of elements in the literal
120   // exceeds the number of positive values supported by the type.
121   float next_value = std::numeric_limits<float>::min();
122   for (float& value : literal->data<float>()) {
123     value = next_value;
124     next_value = std::nextafter(next_value, std::numeric_limits<float>::max());
125   }
126 }
127 
128 template <typename FloatT,
129           typename std::enable_if<std::is_same<bfloat16, FloatT>::value ||
130                                       std::is_same<half, FloatT>::value,
131                                   int>::type = 0>
PopulateWithNoDuplicateData(Literal * literal,std::minstd_rand0 * engine)132 void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
133   PopulateWithIntNext<FloatT>(literal);
134   std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
135                *engine);
136 }
137 
138 template <typename FloatT,
139           typename std::enable_if<!std::is_same<bfloat16, FloatT>::value &&
140                                       !std::is_same<half, FloatT>::value,
141                                   int>::type = 0>
PopulateWithNoDuplicateData(Literal * literal,std::minstd_rand0 * engine)142 void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
143   PopulateWithNextAfter<FloatT>(literal);
144   std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
145                *engine);
146 }
147 
148 template <typename FloatT>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)149 void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine,
150                                    bool no_duplicates, bool use_large_range) {
151   CHECK(engine != nullptr);
152   CHECK_EQ(literal->shape().element_type(),
153            primitive_util::NativeToPrimitiveType<FloatT>());
154   if (no_duplicates) {
155     PopulateWithNoDuplicateData<FloatT>(literal, engine);
156   } else if (use_large_range) {
157     PopulateWithRandomFullRangeFloatingPointData<FloatT>(literal, engine);
158   } else {
159     PopulateWithRandomFloatingPointData<FloatT, FloatT>(literal, engine);
160   }
161 }
162 
163 template <typename ComplexT>
PopulateWithComplexData(Literal * result,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)164 void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine,
165                              bool no_duplicates, bool use_large_range) {
166   using InnerFloatT = typename ComplexT::value_type;
167   CHECK(engine != nullptr);
168   CHECK_EQ(result->shape().element_type(),
169            primitive_util::NativeToPrimitiveType<ComplexT>());
170   Shape floating_point_shape = ShapeUtil::ChangeElementType(
171       result->shape(), primitive_util::NativeToPrimitiveType<InnerFloatT>());
172   Literal real_lit(floating_point_shape);
173   Literal imaginary_lit(floating_point_shape);
174 
175   PopulateWithFloatingPointData<InnerFloatT>(&real_lit, engine, no_duplicates,
176                                              use_large_range);
177   PopulateWithFloatingPointData<InnerFloatT>(&imaginary_lit, engine,
178                                              no_duplicates, use_large_range);
179 
180   absl::Span<const InnerFloatT> real_data = real_lit.data<InnerFloatT>();
181   absl::Span<const InnerFloatT> imaginary_data =
182       imaginary_lit.data<InnerFloatT>();
183   absl::Span<ComplexT> result_data = result->data<ComplexT>();
184   for (int i = 0; i < real_lit.data<InnerFloatT>().size(); i++) {
185     result_data[i] = ComplexT(real_data[i], imaginary_data[i]);
186   }
187 }
188 
189 template <>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)190 void PopulateWithFloatingPointData<half>(Literal* literal,
191                                          std::minstd_rand0* engine,
192                                          bool no_duplicates,
193                                          bool use_large_range) {
194   CHECK(engine != nullptr);
195   CHECK_EQ(literal->shape().element_type(),
196            primitive_util::NativeToPrimitiveType<half>());
197   if (no_duplicates) {
198     PopulateWithNoDuplicateData<half>(literal, engine);
199   } else if (use_large_range) {
200     PopulateWithRandomFullRangeFloatingPointData<half>(literal, engine);
201   } else {
202     PopulateWithRandomFloatingPointData<half, float>(literal, engine);
203   }
204 }
205 
206 template <>
PopulateWithFloatingPointData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)207 void PopulateWithFloatingPointData<bfloat16>(Literal* literal,
208                                              std::minstd_rand0* engine,
209                                              bool no_duplicates,
210                                              bool use_large_range) {
211   CHECK(engine != nullptr);
212   CHECK_EQ(literal->shape().element_type(),
213            primitive_util::NativeToPrimitiveType<bfloat16>());
214   if (no_duplicates) {
215     PopulateWithNoDuplicateData<bfloat16>(literal, engine);
216   } else if (use_large_range) {
217     PopulateWithRandomFullRangeFloatingPointData<bfloat16>(literal, engine);
218   } else {
219     PopulateWithRandomFloatingPointData<bfloat16, float>(literal, engine);
220   }
221 }
222 
// Maps an integer element type to the integer type used to instantiate
// std::uniform_int_distribution for it. uniform_int_distribution is not
// defined for 8-bit integers, so int8_t maps to int16_t and uint8_t maps to
// uint16_t; every other type maps to itself.
template <typename IntT>
struct RngT {
  using type = typename std::conditional<
      std::is_same<IntT, int8_t>::value, int16_t,
      typename std::conditional<std::is_same<IntT, uint8_t>::value, uint16_t,
                                IntT>::type>::type;
};
239 
240 template <typename IntT>
PopulateWithRandomIntegralData(Literal * literal,std::minstd_rand0 * engine,bool no_duplicates)241 void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
242                                     bool no_duplicates) {
243   CHECK(engine != nullptr);
244   CHECK_EQ(literal->shape().element_type(),
245            primitive_util::NativeToPrimitiveType<IntT>());
246   if (no_duplicates && ShapeUtil::ElementsIn(literal->shape()) <
247                            std::numeric_limits<IntT>::max()) {
248     std::iota(literal->data<IntT>().begin(), literal->data<IntT>().end(), 0);
249     std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
250                  *engine);
251   } else {
252     std::uniform_int_distribution<typename RngT<IntT>::type> generator(
253         std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
254     for (IntT& value : literal->data<IntT>()) {
255       value = generator(*engine);
256     }
257   }
258 }
259 
260 // Similar to MakeFakeLiteral but takes a random number generator engine to
261 // enable reusing the engine across randomly generated literals. 'no_duplicates'
262 // indicates that there should be no duplicate values in each generated
263 // array. This is uniqueness is best-effort only. Some types (half and bfloat16)
264 // are not supported and uniqueness cannot be guaranteed if the number of
265 // elements exceeds the number of different values supported by the type.
MakeFakeLiteralInternal(const Shape & shape,std::minstd_rand0 * engine,bool no_duplicates,bool use_large_range)266 StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
267                                           std::minstd_rand0* engine,
268                                           bool no_duplicates,
269                                           bool use_large_range) {
270   if (shape.IsTuple()) {
271     std::vector<Literal> elements;
272     const auto& shape_tuple_shapes = shape.tuple_shapes();
273     elements.reserve(shape_tuple_shapes.size());
274     for (const Shape& element_shape : shape_tuple_shapes) {
275       TF_ASSIGN_OR_RETURN(Literal element, MakeFakeLiteralInternal(
276                                                element_shape, engine,
277                                                no_duplicates, use_large_range));
278       elements.push_back(std::move(element));
279     }
280     return LiteralUtil::MakeTupleOwned(std::move(elements));
281   }
282   if (engine == nullptr) {
283     return Literal::CreateFromShape(shape);
284   }
285   // Clear tiles/element size in shape's layout before using it for creating
286   // literal.
287   Shape new_shape = shape;
288   new_shape.mutable_layout()->clear_tiles();
289   new_shape.mutable_layout()->set_element_size_in_bits(0);
290   Literal literal(new_shape);
291   switch (shape.element_type()) {
292     case BF16:
293       PopulateWithFloatingPointData<bfloat16>(&literal, engine, no_duplicates,
294                                               use_large_range);
295       break;
296     case F16:
297       PopulateWithFloatingPointData<half>(&literal, engine, no_duplicates,
298                                           use_large_range);
299       break;
300     case F32:
301       PopulateWithFloatingPointData<float>(&literal, engine, no_duplicates,
302                                            use_large_range);
303       break;
304     case F64:
305       PopulateWithFloatingPointData<double>(&literal, engine, no_duplicates,
306                                             use_large_range);
307       break;
308     case S8:
309       PopulateWithRandomIntegralData<int8_t>(&literal, engine, no_duplicates);
310       break;
311     case U8:
312       PopulateWithRandomIntegralData<uint8_t>(&literal, engine, no_duplicates);
313       break;
314     case S16:
315       PopulateWithRandomIntegralData<int16_t>(&literal, engine, no_duplicates);
316       break;
317     case U16:
318       PopulateWithRandomIntegralData<uint16_t>(&literal, engine, no_duplicates);
319       break;
320     case S32:
321       PopulateWithRandomIntegralData<int32_t>(&literal, engine, no_duplicates);
322       break;
323     case U32:
324       PopulateWithRandomIntegralData<uint32_t>(&literal, engine, no_duplicates);
325       break;
326     case S64:
327       PopulateWithRandomIntegralData<int64_t>(&literal, engine, no_duplicates);
328       break;
329     case U64:
330       PopulateWithRandomIntegralData<uint64_t>(&literal, engine, no_duplicates);
331       break;
332     case C64:
333       PopulateWithComplexData<complex64>(&literal, engine, no_duplicates,
334                                          use_large_range);
335       break;
336     case C128:
337       PopulateWithComplexData<complex128>(&literal, engine, no_duplicates,
338                                           use_large_range);
339       break;
340     case PRED: {
341       std::uniform_int_distribution<int> generator(0, 1);
342       TF_CHECK_OK(
343           literal.Populate<bool>([&](absl::Span<const int64_t> /*indices*/) {
344             return generator(*engine);
345           }));
346       break;
347     }
348     default:
349       return Unimplemented("Unsupported type for fake literal generation: %s",
350                            ShapeUtil::HumanString(shape));
351   }
352   return std::move(literal);
353 }
354 
355 template <typename IntT>
PopulateWithRandomIntegralDataWithBounds(Literal * literal,std::minstd_rand0 * engine,IntT min,IntT max)356 void PopulateWithRandomIntegralDataWithBounds(Literal* literal,
357                                               std::minstd_rand0* engine,
358                                               IntT min, IntT max) {
359   CHECK(engine != nullptr);
360   CHECK_EQ(literal->shape().element_type(),
361            primitive_util::NativeToPrimitiveType<IntT>());
362   std::uniform_int_distribution<typename RngT<IntT>::type> generator(min, max);
363   for (IntT& value : literal->data<IntT>()) {
364     value = generator(*engine);
365   }
366 }
367 
368 // Same as MakeFakeLiteralInternal but generates random numbers in the given
369 // range [min, max]. Currently this works only for INT types.
MakeFakeLiteralInternalWithBounds(const Shape & shape,std::minstd_rand0 * engine,int64_t min,int64_t max,bool is_sorted)370 StatusOr<Literal> MakeFakeLiteralInternalWithBounds(const Shape& shape,
371                                                     std::minstd_rand0* engine,
372                                                     int64_t min, int64_t max,
373                                                     bool is_sorted) {
374   if (shape.IsTuple()) {
375     std::vector<Literal> elements;
376     const auto& shape_tuple_shapes = shape.tuple_shapes();
377     elements.reserve(shape_tuple_shapes.size());
378     for (const Shape& element_shape : shape_tuple_shapes) {
379       TF_ASSIGN_OR_RETURN(Literal element,
380                           MakeFakeLiteralInternalWithBounds(
381                               element_shape, engine, min, max, is_sorted));
382       elements.push_back(std::move(element));
383     }
384     return LiteralUtil::MakeTupleOwned(std::move(elements));
385   }
386   if (engine == nullptr) {
387     return Literal::CreateFromShape(shape);
388   }
389   // Clear tiles/element size in shape's layout before using it for creating
390   // literal.
391   Shape new_shape = shape;
392   new_shape.mutable_layout()->clear_tiles();
393   new_shape.mutable_layout()->set_element_size_in_bits(0);
394   Literal literal(new_shape);
395   switch (shape.element_type()) {
396     case S8:
397       PopulateWithRandomIntegralDataWithBounds<int8_t>(
398           &literal, engine, static_cast<int8_t>(min), static_cast<int8_t>(max));
399       if (is_sorted) {
400         std::sort(literal.data<int8_t>().begin(), literal.data<int8_t>().end());
401       }
402       break;
403     case U8:
404       PopulateWithRandomIntegralDataWithBounds<uint8_t>(
405           &literal, engine, static_cast<uint8_t>(min),
406           static_cast<uint8_t>(max));
407       if (is_sorted) {
408         std::sort(literal.data<uint8_t>().begin(),
409                   literal.data<uint8_t>().end());
410       }
411       break;
412     case S16:
413       PopulateWithRandomIntegralDataWithBounds<int16_t>(
414           &literal, engine, static_cast<int16_t>(min),
415           static_cast<int16_t>(max));
416       if (is_sorted) {
417         std::sort(literal.data<int16_t>().begin(),
418                   literal.data<int16_t>().end());
419       }
420       break;
421     case U16:
422       PopulateWithRandomIntegralDataWithBounds<uint16_t>(
423           &literal, engine, static_cast<uint16_t>(min),
424           static_cast<uint16_t>(max));
425       if (is_sorted) {
426         std::sort(literal.data<uint16_t>().begin(),
427                   literal.data<uint16_t>().end());
428       }
429       break;
430     case S32:
431       PopulateWithRandomIntegralDataWithBounds<int32_t>(
432           &literal, engine, static_cast<int32_t>(min),
433           static_cast<int32_t>(max));
434       if (is_sorted) {
435         std::sort(literal.data<int32_t>().begin(),
436                   literal.data<int32_t>().end());
437       }
438       break;
439     case U32:
440       PopulateWithRandomIntegralDataWithBounds<uint32_t>(
441           &literal, engine, static_cast<uint32_t>(min),
442           static_cast<uint32_t>(max));
443       if (is_sorted) {
444         std::sort(literal.data<uint32_t>().begin(),
445                   literal.data<uint32_t>().end());
446       }
447       break;
448     case S64:
449       PopulateWithRandomIntegralDataWithBounds<int64_t>(
450           &literal, engine, static_cast<int64_t>(min),
451           static_cast<int64_t>(max));
452       if (is_sorted) {
453         std::sort(literal.data<int64_t>().begin(),
454                   literal.data<int64_t>().end());
455       }
456       break;
457     case U64:
458       PopulateWithRandomIntegralDataWithBounds<uint64_t>(
459           &literal, engine, static_cast<uint64_t>(min),
460           static_cast<uint64_t>(max));
461       if (is_sorted) {
462         std::sort(literal.data<uint64_t>().begin(),
463                   literal.data<uint64_t>().end());
464       }
465       break;
466     default:
467       return Unimplemented(
468           "Unsupported type for fake random literal generation with bounds: %s",
469           ShapeUtil::HumanString(shape));
470   }
471   return std::move(literal);
472 }
473 
// The kind of constant a computation requires as its initial value, as
// determined by GetInitValue().
enum class ConstantType {
  kUnknown,  // No specific constant could be determined.
  kZero,     // The computation's root is an add of its two parameters.
  kOne,      // The computation's root is a multiply of its two parameters.
};
475 
476 // Return the constant type required by this computation, if known.
GetInitValue(const HloComputation & computation)477 ConstantType GetInitValue(const HloComputation& computation) {
478   // TODO(b/77635120): Add init values, for min, max, and their arg variants.
479   const HloInstruction* const root = computation.root_instruction();
480   if (computation.num_parameters() != 2 || root->operand_count() != 2 ||
481       root->operand(0)->opcode() != HloOpcode::kParameter ||
482       root->operand(1)->opcode() != HloOpcode::kParameter ||
483       root->operand(0) == root->operand(1)) {
484     return ConstantType::kUnknown;
485   }
486 
487   switch (root->opcode()) {
488     case HloOpcode::kAdd:
489       return ConstantType::kZero;
490     case HloOpcode::kMultiply:
491       return ConstantType::kOne;
492     default:
493       return ConstantType::kUnknown;
494   }
495 }
496 
497 // Reduce, ReduceWindow, and SelectAndScatter ops may need a non-random
498 // initialization value.
NeedsInitValue(const HloUse & use)499 bool NeedsInitValue(const HloUse& use) {
500   const HloInstruction* const instruction = use.instruction;
501   const HloOpcode opcode = instruction->opcode();
502   const int64_t op_num = use.operand_number;
503   return ((opcode == HloOpcode::kReduceWindow && op_num == 1) ||
504           (opcode == HloOpcode::kSelectAndScatter && op_num == 2) ||
505           (opcode == HloOpcode::kReduce &&
506            op_num >= instruction->operand_count() / 2));
507 }
508 
509 // Generate random values that are constrained to the input_shape minus the
510 // output_shape so as not to produce wrapping slices, for instance.
MakeRandomIndex(int64_t index_bound,std::minstd_rand0 * engine)511 Literal MakeRandomIndex(int64_t index_bound, std::minstd_rand0* engine) {
512   std::uniform_int_distribution<int32_t> generator(0, index_bound);
513   return LiteralUtil::CreateR0<int32_t>(generator(*engine));
514 }
515 
516 // Returns true if `dest' is reachable from `src' through data-formatting and
517 // custom call instructions within the same computation.
ReachableViaDataFormatting(const HloInstruction * src,const HloInstruction * dest)518 bool ReachableViaDataFormatting(const HloInstruction* src,
519                                 const HloInstruction* dest) {
520   if (src == dest) {
521     return true;
522   }
523   switch (dest->opcode()) {
524     case HloOpcode::kReshape:
525     case HloOpcode::kTranspose:
526     case HloOpcode::kCopy:
527     case HloOpcode::kSlice:
528       break;
529     case HloOpcode::kCustomCall:
530       if (dest->custom_call_target() == "AssumeGatherIndicesInBound") {
531         break;
532       }
533       return false;
534     default:
535       return false;
536   }
537   for (const auto* operand : dest->operands()) {
538     if (ReachableViaDataFormatting(src, operand)) {
539       return true;
540     }
541   }
542   return false;
543 }
544 
545 // Use dataflow analysis on each parameter to see if there are uses that would
546 // be problematic when generating input data.  Returns the list of instructions
547 // that correspond to their uses.
548 //
549 // Should be paired with the CreateLiteralForConstrainedUses() function below.
FindConstrainedUses(const HloDataflowAnalysis & dataflow,const HloInstruction & param)550 std::vector<HloInstruction*> FindConstrainedUses(
551     const HloDataflowAnalysis& dataflow, const HloInstruction& param) {
552   std::vector<HloInstruction*> constrained_uses;
553   for (const auto& pair : dataflow.GetInstructionValueSet(&param)) {
554     const HloValue& value = dataflow.GetUniqueValueAt(&param, pair.first);
555     for (const HloUse& use : value.GetUses()) {
556       HloInstruction* instruction = use.instruction;
557       const HloOpcode opcode = instruction->opcode();
558       const int64_t op_num = use.operand_number;
559       if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) ||
560           (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) {
561         constrained_uses.push_back(instruction);
562       } else if ((opcode == HloOpcode::kGather ||
563                   opcode == HloOpcode::kScatter) &&
564                  op_num == 1) {
565         constrained_uses.push_back(instruction);
566       } else if (opcode == HloOpcode::kFusion) {
567         const HloInstruction* const to_analyze =
568             instruction->fused_parameter(op_num);
569         auto fused_uses = FindConstrainedUses(dataflow, *to_analyze);
570         constrained_uses.insert(constrained_uses.end(), fused_uses.begin(),
571                                 fused_uses.end());
572       } else if (NeedsInitValue(use)) {
573         constrained_uses.push_back(instruction);
574       } else if (opcode == HloOpcode::kConvert ||
575                  opcode == HloOpcode::kReducePrecision) {
576         auto converted_uses = FindConstrainedUses(dataflow, *instruction);
577         constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
578                                 converted_uses.end());
579       } else if (opcode == HloOpcode::kSort &&
580                  instruction->operand_count() >= 2 && op_num == 0) {
581         // Operand 0 of sort is the array of keys used for key/value
582         // (two-operand) kSort instructions. Since sort stability is not
583         // guaranteed, constrain keys of key-value sort not to have duplicates,
584         // since otherwise the value order may legitimately differ.
585         constrained_uses.push_back(instruction);
586       }
587     }
588   }
589 
590   for (auto* instruction : param.parent()->instructions()) {
591     const HloOpcode opcode = instruction->opcode();
592     if (opcode == HloOpcode::kGather || opcode == HloOpcode::kScatter) {
593       if (instruction->operand(1) == &param) {
594         // Above already covers this case.
595         continue;
596       }
597       if (ReachableViaDataFormatting(&param, instruction->operand(1))) {
598         constrained_uses.push_back(instruction);
599       }
600     }
601   }
602   return constrained_uses;
603 }
604 
605 // Given a parameter, generate a random Literal to use as input if there exist
606 // no constrained uses in the dataflow graph.  If such constraints exist,
607 // generate a constrained literal (either bounded in the case of indices, or
608 // zero in the case of init_values for reductions).
CreateLiteralForConstrainedUses(const absl::Span<HloInstruction * const> constrained_uses,const HloInstruction & param,const Shape & param_shape,std::minstd_rand0 * engine,bool use_large_range)609 StatusOr<Literal> CreateLiteralForConstrainedUses(
610     const absl::Span<HloInstruction* const> constrained_uses,
611     const HloInstruction& param, const Shape& param_shape,
612     std::minstd_rand0* engine, bool use_large_range) {
613   int64_t index_bound = INT64_MAX;
614   bool no_duplicates = false;
615   bool needs_constant = false;
616   bool needs_sorted_indices = false;
617   ConstantType constant_type = ConstantType::kUnknown;
618   for (HloInstruction* use : constrained_uses) {
619     switch (use->opcode()) {
620       case HloOpcode::kDynamicSlice:
621       case HloOpcode::kDynamicUpdateSlice: {
622         const Shape& indexed_shape = use->operand(0)->shape();
623         const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
624                                        ? use->shape()
625                                        : use->operand(1)->shape();
626         const int64_t first_index =
627             Cast<HloDynamicIndexInstruction>(use)->first_index_operand_number();
628         for (int64_t operand = first_index; operand < use->operand_count();
629              ++operand) {
630           if (use->operand(operand) == &param) {
631             index_bound = std::min(
632                 index_bound,
633                 ShapeUtil::GetDimension(indexed_shape, operand - first_index) -
634                     ShapeUtil::GetDimension(slice_shape,
635                                             operand - first_index));
636           }
637         }
638         break;
639       }
640       case HloOpcode::kGather:
641       case HloOpcode::kScatter: {
642         const Shape& operand_shape = use->operand(0)->shape();
643         auto index_map = use->opcode() == HloOpcode::kGather
644                              ? use->gather_dimension_numbers().start_index_map()
645                              : use->scatter_dimension_numbers()
646                                    .scatter_dims_to_operand_dims();
647         for (const auto dim_in_operand : index_map) {
648           index_bound = std::min(index_bound,
649                                  operand_shape.dimensions(dim_in_operand) - 1);
650         }
651         if (use->opcode() == HloOpcode::kScatter) {
652           needs_sorted_indices |=
653               Cast<const HloScatterInstruction>(use)->indices_are_sorted();
654         } else {
655           needs_sorted_indices |=
656               Cast<const HloGatherInstruction>(use)->indices_are_sorted();
657         }
658         break;
659       }
660       case HloOpcode::kReduce:
661       case HloOpcode::kReduceWindow:
662         needs_constant = true;
663         constant_type = GetInitValue(*use->to_apply());
664         break;
665 
666       case HloOpcode::kSelectAndScatter:
667         needs_constant = true;
668         constant_type = GetInitValue(*use->scatter());
669         break;
670 
671       case HloOpcode::kSort:
672         no_duplicates = true;
673         break;
674 
675       default:
676         return Unimplemented(
677             "Constrained operand generation not implemented for %s.",
678             use->ToString());
679     }
680   }
681   int constraint_count = 0;
682   constraint_count += no_duplicates ? 1 : 0;
683   constraint_count += (index_bound != INT64_MAX) ? 1 : 0;
684   constraint_count += needs_constant ? 1 : 0;
685   if (constraint_count > 1) {
686     return Unimplemented("Conflicting operand generation constraints.");
687   }
688   if (index_bound != INT64_MAX) {
689     return MakeFakeLiteralInternalWithBounds(param_shape, engine, 0,
690                                              index_bound, needs_sorted_indices);
691   } else if (needs_constant) {
692     switch (constant_type) {
693       case ConstantType::kZero:
694         return LiteralUtil::Zero(param_shape.element_type());
695       case ConstantType::kOne:
696         return LiteralUtil::One(param_shape.element_type());
697       case ConstantType::kUnknown:
698         // We want the identity element for the computation, but we don't really
699         // know what it is - so any value we generate will be just as wrong.
700         return MakeFakeLiteralInternal(param_shape, engine,
701                                        /*no_duplicates=*/false,
702                                        use_large_range);
703     }
704   } else {
705     return MakeFakeLiteralInternal(param_shape, engine, no_duplicates,
706                                    use_large_range);
707   }
708 }
709 
710 // Given a module entry parameter, use the dataflow analysis to see if a
711 // special case literal must be created, or if we can generate fake data.
MakeConstrainedArgument(const HloDataflowAnalysis & dataflow,const HloInstruction & param,const Shape & param_shape,std::minstd_rand0 * engine,bool use_large_range)712 StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
713                                           const HloInstruction& param,
714                                           const Shape& param_shape,
715                                           std::minstd_rand0* engine,
716                                           bool use_large_range) {
717   const auto constrained_uses = FindConstrainedUses(dataflow, param);
718   return CreateLiteralForConstrainedUses(constrained_uses, param, param_shape,
719                                          engine, use_large_range);
720 }
721 
722 }  // namespace
723 
MakeFakeLiteral(const Shape & shape,bool pseudo_random,bool use_large_range)724 StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random,
725                                   bool use_large_range) {
726   auto engine = pseudo_random ? std::make_unique<std::minstd_rand0>() : nullptr;
727   return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false,
728                                  use_large_range);
729 }
730 
MakeFakeArguments(const HloModule * module,bool pseudo_random,bool use_large_range)731 StatusOr<std::vector<Literal>> MakeFakeArguments(const HloModule* module,
732                                                  bool pseudo_random,
733                                                  bool use_large_range) {
734   auto engine = pseudo_random ? std::make_unique<std::minstd_rand0>() : nullptr;
735   return MakeFakeArguments(module, engine.get(), use_large_range);
736 }
737 
MakeFakeArguments(const HloModule * module,std::minstd_rand0 * engine,bool use_large_range)738 StatusOr<std::vector<Literal>> MakeFakeArguments(const HloModule* module,
739                                                  std::minstd_rand0* engine,
740                                                  bool use_large_range) {
741   TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
742   const auto params = module->entry_computation()->parameter_instructions();
743   std::vector<Literal> arguments(params.size());
744   for (int i = 0; i < params.size(); ++i) {
745     const HloModuleConfig& module_config = module->config();
746     const Shape& param_shape = (module_config.has_entry_computation_layout() &&
747                                 module_config.entry_computation_layout()
748                                     .parameter_layout(i)
749                                     .shape()
750                                     .is_static())
751                                    ? module_config.entry_computation_layout()
752                                          .parameter_layout(i)
753                                          .shape()
754                                    : params[i]->shape();
755 
756     TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument(
757                                           *dataflow, *params[i], param_shape,
758                                           engine, use_large_range));
759   }
760   return std::move(arguments);
761 }
762 
VerifyHloModule(HloModule * const module,bool layout_sensitive,bool allow_mixed_precision)763 Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
764                        bool allow_mixed_precision) {
765   return HloVerifier(/*layout_sensitive=*/layout_sensitive,
766                      /*allow_mixed_precision=*/allow_mixed_precision)
767       .Run(module)
768       .status();
769 }
770 
CreateCanonicalDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs)771 std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
772                                                       HloInstruction* lhs,
773                                                       HloInstruction* rhs) {
774   CHECK_LE(lhs->shape().rank(), 2);
775   CHECK_LE(rhs->shape().rank(), 2);
776   PrecisionConfig precision_config;
777   precision_config.mutable_operand_precision()->Resize(
778       2, PrecisionConfig::DEFAULT);
779   DotDimensionNumbers dot_dimension_numbers;
780   dot_dimension_numbers.add_lhs_contracting_dimensions(
781       lhs->shape().rank() > 1 ? 1 : 0);
782   dot_dimension_numbers.add_rhs_contracting_dimensions(0);
783   return std::make_unique<HloDotInstruction>(
784       shape, lhs, rhs, dot_dimension_numbers, precision_config);
785 }
786 }  // namespace xla
787