xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/reference_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
18 
19 #include <array>
20 #include <functional>
21 #include <memory>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/array2d.h"
27 #include "tensorflow/compiler/xla/array3d.h"
28 #include "tensorflow/compiler/xla/array4d.h"
29 #include "tensorflow/compiler/xla/client/padding.h"
30 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 
34 namespace xla {
35 
36 // Utility class for reference implementations of linear algebra routines.
37 class ReferenceUtil {
38  public:
39   // Returns the result of a transpose operation on the input matrix.
40   template <typename T>
TransposeArray2D(const Array2D<T> & operand)41   static std::unique_ptr<Array2D<T>> TransposeArray2D(
42       const Array2D<T>& operand) {
43     auto result =
44         std::make_unique<Array2D<T>>(operand.width(), operand.height());
45     for (int64_t w = 0; w < operand.width(); ++w) {
46       for (int64_t h = 0; h < operand.height(); ++h) {
47         (*result)(w, h) = operand(h, w);
48       }
49     }
50 
51     return result;
52   }
53 
54   // Returns the result of a matrix multiply `lhs x rhs`.
55   template <typename T>
MatmulArray2D(const Array2D<T> & lhs,const Array2D<T> & rhs)56   static std::unique_ptr<Array2D<T>> MatmulArray2D(const Array2D<T>& lhs,
57                                                    const Array2D<T>& rhs) {
58     return HloEvaluator::MatmulArray2D(lhs, rhs);
59   }
60 
61   // Converts the input operand to use f64 values instead of f32 values.
62   static std::unique_ptr<Array2D<double>> Array2DF32ToF64(
63       const Array2D<float>& input);
64 
65   // Returns the result of a convolution `lhs <conv> rhs`, with the default
66   // convolution dimension numbers returned from
67   // ComputationBuilder::CreateDefaultConvDimensionNumbers().
68   static std::unique_ptr<Array4D<float>> ConvArray4D(
69       const Array4D<float>& lhs, const Array4D<float>& rhs,
70       std::pair<int64_t, int64_t> kernel_stride, Padding padding);
71 
72   // Returns the result of a convolution `lhs <conv> rhs`, with the given
73   // convolution dimension numbers.
74   static std::unique_ptr<Array4D<float>> ConvArray4DGeneralDimensions(
75       const Array4D<float>& lhs, const Array4D<float>& rhs,
76       std::pair<int64_t, int64_t> kernel_stride, Padding padding,
77       ConvolutionDimensionNumbers dimension_numbers);
78 
79   // Returns the result of a convolution `lhs <conv> rhs`, with the given
80   // dilation factors.
81   static std::unique_ptr<Array4D<float>> ConvArray4DGeneralDimensionsDilated(
82       const Array4D<float>& lhs, const Array4D<float>& rhs,
83       std::pair<int64_t, int64_t> kernel_stride, Padding padding,
84       std::pair<int64_t, int64_t> lhs_dilation,
85       std::pair<int64_t, int64_t> rhs_dilation,
86       ConvolutionDimensionNumbers dnums);
87 
88   // Returns the result of a convolution `lhs <conv> rhs`, with the default
89   // convolution dimension numbers returned from
90   // ComputationBuilder::CreateDefaultConvDimensionNumbers().
91   static std::unique_ptr<Array3D<float>> ConvArray3D(const Array3D<float>& lhs,
92                                                      const Array3D<float>& rhs,
93                                                      int64_t kernel_stride,
94                                                      Padding padding);
95 
96   // Returns the result of a convolution `lhs <conv> rhs`.
97   static std::unique_ptr<Array3D<float>> ConvArray3DGeneralDimensionsDilated(
98       const Array3D<float>& lhs, const Array3D<float>& rhs,
99       int64_t kernel_stride, Padding padding, int64_t lhs_dilation,
100       int64_t rhs_dilation, const ConvolutionDimensionNumbers& dnums);
101 
  // Returns the result of a separable convolution with the given parameters.
  // kernel_stride and padding apply to the depthwise convolution during
  // the separable convolution. pointwise_weights.depth() must be equal to
  // input.depth() * depthwise_weights.planes().
106   static std::unique_ptr<Array4D<float>> SeparableConvArray4D(
107       const Array4D<float>& input, const Array4D<float>& depthwise_weights,
108       const Array4D<float>& pointwise_weights,
109       std::pair<int64_t, int64_t> kernel_stride, Padding padding);
110 
111   // Returns the result of reducing a matrix to a column vector. init is the
112   // initial value for the reduce operation, and reduce_function is the function
113   // to apply for each reduction step.
114   static std::unique_ptr<std::vector<float>> ReduceToColArray2D(
115       const Array2D<float>& matrix, float init,
116       const std::function<float(float, float)>& reduce_function);
117 
118   // Returns the result of reducing a matrix to a row vector. init is the
119   // initial value for the reduce operation, and reduce_function is the function
120   // to apply for each reduction step.
121   static std::unique_ptr<std::vector<float>> ReduceToRowArray2D(
122       const Array2D<float>& matrix, float init,
123       const std::function<float(float, float)>& reduce_function);
124 
125   // Performs a R2=>R1 reduction by reducing away the dimension specified in
126   // 'dimension_to_reduce'.
127   template <typename T>
ReduceR2ToR1(const Array2D<T> & input,int dimension_to_reduce,T init,const std::function<T (T,T)> & freduce)128   static std::vector<T> ReduceR2ToR1(const Array2D<T>& input,
129                                      int dimension_to_reduce, T init,
130                                      const std::function<T(T, T)>& freduce) {
131     std::vector<T> result(dimension_to_reduce == 0 ? input.n2() : input.n1(),
132                           init);
133     for (int i0 = 0; i0 < input.n1(); ++i0) {
134       for (int i1 = 0; i1 < input.n2(); ++i1) {
135         int output = dimension_to_reduce == 0 ? i1 : i0;
136         result[output] = freduce(result[output], input(i0, i1));
137       }
138     }
139     return result;
140   }
141 
142   // Returns the result of reducing the 4D array to a vector, reducing away
143   // the dimensions specified in dims.
144   static std::vector<float> Reduce4DTo1D(
145       const Array4D<float>& array, float init, absl::Span<const int64_t> dims,
146       const std::function<float(float, float)>& reduce_function);
147 
148   // Broadcast 1D dimension to 4D, from the dimension `broadcast_from_dim`.
149   static std::unique_ptr<Array4D<float>> Broadcast1DTo4D(
150       const std::vector<float>& array, const std::vector<int64_t>& bounds,
151       int64_t broadcast_from_dim);
152 
153   // Returns the result of reducing the 3D array to a 2D array, reducing away
154   // the dimensions specified in dims.
155   static std::unique_ptr<Array2D<float>> Reduce3DTo2D(
156       const Array3D<float>& array, float init, absl::Span<const int64_t> dims,
157       const std::function<float(float, float)>& reduce_function);
158 
159   // Applies map_function to each element in the input (2D array) and returns
160   // the result.
161   static std::unique_ptr<Array2D<float>> MapArray2D(
162       const Array2D<float>& matrix,
163       const std::function<float(float)>& map_function);
164 
  // Applies map_function to each pair of corresponding elements in the two
  // input arrays and returns the result.
167   static std::unique_ptr<Array2D<float>> MapArray2D(
168       const Array2D<float>& lhs, const Array2D<float>& rhs,
169       const std::function<float(float, float)>& map_function);
170 
171   // Applies map_function to each element in the input (3D array) and returns
172   // the result.
173   static std::unique_ptr<Array3D<float>> MapArray3D(
174       const Array3D<float>& array,
175       const std::function<float(float)>& map_function);
176 
  // Applies map_function to each pair of corresponding elements in the two
  // input arrays and returns the result.
179   static std::unique_ptr<Array3D<float>> MapArray3D(
180       const Array3D<float>& lhs, const Array3D<float>& rhs,
181       const std::function<float(float, float)>& map_function);
182 
183   // Number of windows in a given dimension. Calculation taken from
184   // xla::MakePadding().
185   static int64_t WindowCount(int64_t unpadded_width, int64_t window_len,
186                              int64_t stride, Padding padding);
187 
188   // Windowed reductions with Add as the function to apply.
189   static std::unique_ptr<std::vector<float>> ReduceWindow1DAdd(
190       absl::Span<const float> operand, float init,
191       absl::Span<const int64_t> window, absl::Span<const int64_t> stride,
192       Padding padding);
193   static std::unique_ptr<Array3D<float>> ReduceWindow3DAdd(
194       const Array3D<float>& operand, float init,
195       absl::Span<const int64_t> window, absl::Span<const int64_t> stride,
196       Padding padding);
197   static std::unique_ptr<Array4D<float>> ReduceWindow4DAdd(
198       const Array4D<float>& operand, float init,
199       absl::Span<const int64_t> window, absl::Span<const int64_t> stride,
200       Padding padding);
201 
202   // Windowed reductions with a generic reduce function.
203   static std::unique_ptr<std::vector<float>> ReduceWindow1DGeneric(
204       absl::Span<const float> operand, float init,
205       const std::function<float(float, float)>& reduce_func,
206       absl::Span<const int64_t> window, absl::Span<const int64_t> stride,
207       absl::Span<const std::pair<int64_t, int64_t>> padding);
208   static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
209       const Array4D<float>& operand, float init,
210       const std::function<float(float, float)>& reduce_func,
211       absl::Span<const int64_t> window, absl::Span<const int64_t> stride,
212       Padding padding);
213   // With arbitrary padding.
214   static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
215       const Array4D<float>& operand, float init,
216       const std::function<float(float, float)>& reduce_func,
217       absl::Span<const int64_t> window, absl::Span<const int64_t> stride,
218       absl::Span<const std::pair<int64_t, int64_t>> padding);
219 
220   // Batch normalize data.
221   static std::unique_ptr<Array4D<float>> BatchNorm4D(
222       const Array4D<float>& input, const Array4D<float>& mean,
223       const Array4D<float>& var, const Array4D<float>& scale,
224       const Array4D<float>& offset, float epsilon);
225 
226   // Performs select and scatter with Greater Than or equal as the select, plus
227   // as the scatter, and Same Padding.
228   // TODO(b/74533103) Switch tests to evaluator and remove this implementation.
229   static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus(
230       const Array4D<float>& operand, const Array4D<float>& source, float init,
231       absl::Span<const int64_t> window, absl::Span<const int64_t> stride,
232       bool same_padding);
233 
234   // Concatenates the lhs and rhs arrays along the concatenate_dimension.
235   // E.g. if concatenate_dimension is 0, the "n1"/height dimension is
236   // concatenated, so the arrays are stacked on top of each other.
237   template <typename T>
Concat2D(const Array2D<T> & lhs,const Array2D<T> & rhs,int concatenate_dimension)238   static std::unique_ptr<Array2D<T>> Concat2D(const Array2D<T>& lhs,
239                                               const Array2D<T>& rhs,
240                                               int concatenate_dimension) {
241     CHECK(0 <= concatenate_dimension && concatenate_dimension < 2);
242     auto result = std::make_unique<Array2D<T>>(
243         concatenate_dimension == 0 ? lhs.n1() + rhs.n1() : lhs.n1(),
244         concatenate_dimension == 1 ? lhs.n2() + rhs.n2() : lhs.n2());
245     for (int64_t i0 = 0; i0 < result->n1(); ++i0) {
246       for (int64_t i1 = 0; i1 < result->n2(); ++i1) {
247         // If we exceed the bounds of the LHS, draw from the RHS, where the
248         // result index is adjusted by the number of values present in the LHS.
249         (*result)(i0, i1) = i0 < lhs.n1() && i1 < lhs.n2()
250                                 ? lhs(i0, i1)
251                                 : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
252                                       i1 >= lhs.n2() ? i1 - lhs.n2() : i1);
253       }
254     }
255     return result;
256   }
257 
258   // Concatenates the lhs and rhs 3D arrays along the concatenate_dimension. lhs
259   // and rhs must have the same dimensions except for the concatenate dimension.
260   template <typename T>
Concat3D(const Array3D<T> & lhs,const Array3D<T> & rhs,int concatenate_dimension)261   static std::unique_ptr<Array3D<T>> Concat3D(const Array3D<T>& lhs,
262                                               const Array3D<T>& rhs,
263                                               int concatenate_dimension) {
264     CHECK(0 <= concatenate_dimension && concatenate_dimension < 3);
265     const int64_t lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3()};
266     const int64_t rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()};
267     int64_t out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()};
268     for (int i = 0; i < 3; ++i) {
269       if (i != concatenate_dimension) {
270         out_dims[i] = lhs_dims[i];
271         CHECK_EQ(lhs_dims[i], rhs_dims[i]);
272       } else {
273         out_dims[i] = lhs_dims[i] + rhs_dims[i];
274       }
275     }
276     auto result =
277         std::make_unique<Array3D<T>>(out_dims[0], out_dims[1], out_dims[2]);
278     for (int64_t i0 = 0; i0 < result->n1(); ++i0) {
279       for (int64_t i1 = 0; i1 < result->n2(); ++i1) {
280         for (int64_t i2 = 0; i2 < result->n3(); ++i2) {
281           (*result)(i0, i1, i2) =
282               i0 < lhs.n1() && i1 < lhs.n2() && i2 < lhs.n3()
283                   ? lhs(i0, i1, i2)
284                   : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
285                         i1 >= lhs.n2() ? i1 - lhs.n2() : i1,
286                         i2 >= lhs.n3() ? i2 - lhs.n3() : i2);
287         }
288       }
289     }
290     return result;
291   }
292 
293   // Concatenates the lhs and rhs 4D arrays along the concatenate_dimension. lhs
294   // and rhs must have the same dimensions except for the concatenate dimension.
295   template <typename T>
Concat4D(const Array4D<T> & lhs,const Array4D<T> & rhs,int concatenate_dimension)296   static std::unique_ptr<Array4D<T>> Concat4D(const Array4D<T>& lhs,
297                                               const Array4D<T>& rhs,
298                                               int concatenate_dimension) {
299     CHECK(0 <= concatenate_dimension && concatenate_dimension < 4);
300     const int64_t lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()};
301     const int64_t rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
302     int64_t out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
303     for (int i = 0; i < 4; ++i) {
304       if (i != concatenate_dimension) {
305         out_dims[i] = lhs_dims[i];
306         CHECK_EQ(lhs_dims[i], rhs_dims[i]);
307       } else {
308         out_dims[i] = lhs_dims[i] + rhs_dims[i];
309       }
310     }
311     auto result = std::make_unique<Array4D<T>>(out_dims[0], out_dims[1],
312                                                out_dims[2], out_dims[3]);
313     for (int64_t i0 = 0; i0 < result->n1(); ++i0) {
314       for (int64_t i1 = 0; i1 < result->n2(); ++i1) {
315         for (int64_t i2 = 0; i2 < result->n3(); ++i2) {
316           for (int64_t i3 = 0; i3 < result->n4(); ++i3) {
317             (*result)(i0, i1, i2, i3) =
318                 i0 < lhs.n1() && i1 < lhs.n2() && i2 < lhs.n3() && i3 < lhs.n4()
319                     ? lhs(i0, i1, i2, i3)
320                     : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
321                           i1 >= lhs.n2() ? i1 - lhs.n2() : i1,
322                           i2 >= lhs.n3() ? i2 - lhs.n3() : i2,
323                           i3 >= lhs.n4() ? i3 - lhs.n4() : i3);
324           }
325         }
326       }
327     }
328     return result;
329   }
330 
331   // Slices with index clamping
332   template <typename T>
ClampSlice1D(absl::Span<const T> input,int64_t start,int64_t size)333   static std::vector<T> ClampSlice1D(absl::Span<const T> input, int64_t start,
334                                      int64_t size) {
335     start = std::min<int64_t>(std::max<int64_t>(0, start), input.size() - size);
336     std::vector<T> result;
337     for (int64_t i = 0; i < size; ++i) {
338       result.push_back(input[(start + i)]);
339     }
340     return result;
341   }
342 
343   // Slices the input array given starting indices, limit indices, and strides
344   // in each dimension.
345   template <typename T>
Slice2D(const Array2D<T> & input,std::array<int64_t,2> starts,std::array<int64_t,2> limits,std::array<int64_t,2> strides)346   static std::unique_ptr<Array2D<T>> Slice2D(const Array2D<T>& input,
347                                              std::array<int64_t, 2> starts,
348                                              std::array<int64_t, 2> limits,
349                                              std::array<int64_t, 2> strides) {
350     CHECK_LE(starts[0], input.n1());
351     CHECK_LE(starts[1], input.n2());
352     CHECK_LE(limits[0], input.n1());
353     CHECK_LE(limits[1], input.n2());
354     CHECK_GE(strides[0], 1);
355     CHECK_GE(strides[1], 1);
356     auto result = std::make_unique<Array2D<T>>(
357         CeilOfRatio(limits[0] - starts[0], strides[0]),
358         CeilOfRatio(limits[1] - starts[1], strides[1]));
359     for (int64_t i0 = 0; i0 < result->n1(); ++i0) {
360       for (int64_t i1 = 0; i1 < result->n2(); ++i1) {
361         (*result)(i0, i1) =
362             input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1]);
363       }
364     }
365     return result;
366   }
367 
368   template <typename T>
Slice3D(const Array3D<T> & input,std::array<int64_t,3> starts,std::array<int64_t,3> limits,std::array<int64_t,3> strides)369   static std::unique_ptr<Array3D<T>> Slice3D(const Array3D<T>& input,
370                                              std::array<int64_t, 3> starts,
371                                              std::array<int64_t, 3> limits,
372                                              std::array<int64_t, 3> strides) {
373     CHECK_LE(starts[0], input.n1());
374     CHECK_LE(starts[1], input.n2());
375     CHECK_LE(starts[2], input.n3());
376     CHECK_LE(limits[0], input.n1());
377     CHECK_LE(limits[1], input.n2());
378     CHECK_LE(limits[2], input.n3());
379     CHECK_GE(strides[0], 1);
380     CHECK_GE(strides[1], 1);
381     CHECK_GE(strides[2], 1);
382     auto result = std::make_unique<Array3D<T>>(
383         CeilOfRatio(limits[0] - starts[0], strides[0]),
384         CeilOfRatio(limits[1] - starts[1], strides[1]),
385         CeilOfRatio(limits[2] - starts[2], strides[2]));
386 
387     for (int64_t i0 = 0; i0 < result->n1(); ++i0) {
388       for (int64_t i1 = 0; i1 < result->n2(); ++i1) {
389         for (int64_t i2 = 0; i2 < result->n3(); ++i2) {
390           (*result)(i0, i1, i2) =
391               input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1],
392                     starts[2] + i2 * strides[2]);
393         }
394       }
395     }
396     return result;
397   }
398 
399   template <typename T>
Slice4D(const Array4D<T> & input,std::array<int64_t,4> starts,std::array<int64_t,4> limits,std::array<int64_t,4> strides)400   static std::unique_ptr<Array4D<T>> Slice4D(const Array4D<T>& input,
401                                              std::array<int64_t, 4> starts,
402                                              std::array<int64_t, 4> limits,
403                                              std::array<int64_t, 4> strides) {
404     CHECK_LE(starts[0], input.n1());
405     CHECK_LE(starts[1], input.n2());
406     CHECK_LE(starts[2], input.n3());
407     CHECK_LE(starts[3], input.n4());
408     CHECK_LE(limits[0], input.n1());
409     CHECK_LE(limits[1], input.n2());
410     CHECK_LE(limits[2], input.n3());
411     CHECK_LE(limits[3], input.n4());
412     CHECK_GE(strides[0], 1);
413     CHECK_GE(strides[1], 1);
414     CHECK_GE(strides[2], 1);
415     CHECK_GE(strides[3], 1);
416     auto result = std::make_unique<Array4D<T>>(
417         CeilOfRatio(limits[0] - starts[0], strides[0]),
418         CeilOfRatio(limits[1] - starts[1], strides[1]),
419         CeilOfRatio(limits[2] - starts[2], strides[2]),
420         CeilOfRatio(limits[3] - starts[3], strides[3]));
421     for (int64_t i0 = 0; i0 < result->n1(); ++i0) {
422       for (int64_t i1 = 0; i1 < result->n2(); ++i1) {
423         for (int64_t i2 = 0; i2 < result->n3(); ++i2) {
424           for (int64_t i3 = 0; i3 < result->n4(); ++i3) {
425             (*result)(i0, i1, i2, i3) =
426                 input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1],
427                       starts[2] + i2 * strides[2], starts[3] + i3 * strides[3]);
428           }
429         }
430       }
431     }
432     return result;
433   }
434 
435   // Applies map_function to each element in the input (2D array) and returns
436   // the result.
437   // (row, column) index of each element is also provided as arguments to
438   // map_function.
439   static std::unique_ptr<Array2D<float>> MapWithIndexArray2D(
440       const Array2D<float>& matrix,
441       const std::function<float(float, int64_t, int64_t)>& map_function);
442 
443   // Applies map_function to each element in the input (4D array) and returns
444   // the result.
445   template <typename F>
MapArray4D(const Array4D<float> & input,F && map_function)446   static std::unique_ptr<Array4D<float>> MapArray4D(const Array4D<float>& input,
447                                                     F&& map_function) {
448     return MapWithIndexArray4D(
449         input, [&](float value, int64_t, int64_t, int64_t, int64_t) {
450           return map_function(value);
451         });
452   }
453 
454   // Applies map_function to each element in the input (4D array) and returns
455   // the result.
456   // (plane, depth, height, width) index of each element is also provided as
457   // arguments to map_function.
458   template <typename F>
MapWithIndexArray4D(const Array4D<float> & input,F && map_function)459   static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
460       const Array4D<float>& input, F&& map_function) {
461     auto result = std::make_unique<Array4D<float>>(
462         input.planes(), input.depth(), input.height(), input.width());
463     for (int64_t plane = 0; plane < input.planes(); ++plane) {
464       for (int64_t depth = 0; depth < input.depth(); ++depth) {
465         for (int64_t height = 0; height < input.height(); ++height) {
466           for (int64_t width = 0; width < input.width(); ++width) {
467             (*result)(plane, depth, height, width) =
468                 map_function(input(plane, depth, height, width), plane, depth,
469                              height, width);
470           }
471         }
472       }
473     }
474     return result;
475   }
476 
477   // Applies map_function to each pair of elements in the input lhs and rhs
478   // (4D array) and returns the result.
479   template <typename F>
MapArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,F && map_function)480   static std::unique_ptr<Array4D<float>> MapArray4D(const Array4D<float>& lhs,
481                                                     const Array4D<float>& rhs,
482                                                     F&& map_function) {
483     return MapWithIndexArray4D(
484         lhs, rhs,
485         [&](float lhs, float rhs, int64_t, int64_t, int64_t, int64_t) {
486           return map_function(lhs, rhs);
487         });
488   }
489 
490   // Applies map_function to each pair of element in lhs and rhs (4D array) and
491   // returns the result.
492   // (plane, depth, height, width) index of each element is also provided as
493   // arguments to map_function.
494   template <typename F>
MapWithIndexArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,F && map_function)495   static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
496       const Array4D<float>& lhs, const Array4D<float>& rhs, F&& map_function) {
497     auto result = std::make_unique<Array4D<float>>(lhs.planes(), lhs.depth(),
498                                                    lhs.height(), lhs.width());
499     for (int64_t plane = 0; plane < lhs.planes(); ++plane) {
500       for (int64_t depth = 0; depth < lhs.depth(); ++depth) {
501         for (int64_t height = 0; height < lhs.height(); ++height) {
502           for (int64_t width = 0; width < lhs.width(); ++width) {
503             (*result)(plane, depth, height, width) = map_function(
504                 lhs(plane, depth, height, width),
505                 rhs(plane, depth, height, width), plane, depth, height, width);
506           }
507         }
508       }
509     }
510     return result;
511   }
512 
513   // Returns the result of a 2D pad on an input matrix.
514   template <typename NativeT>
PadArray2D(const Array2D<NativeT> & operand,const PaddingConfig & padding,const NativeT pad)515   static std::unique_ptr<Array2D<NativeT>> PadArray2D(
516       const Array2D<NativeT>& operand, const PaddingConfig& padding,
517       const NativeT pad) {
518     int64_t in0 = operand.n1();
519     int64_t high_padding0 = padding.dimensions(0).edge_padding_high();
520     int64_t low_padding0 = padding.dimensions(0).edge_padding_low();
521     int64_t interior_padding0 = padding.dimensions(0).interior_padding();
522     int64_t out0 =
523         in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0;
524 
525     int64_t in1 = operand.n2();
526     int64_t high_padding1 = padding.dimensions(1).edge_padding_high();
527     int64_t low_padding1 = padding.dimensions(1).edge_padding_low();
528     int64_t interior_padding1 = padding.dimensions(1).interior_padding();
529     int64_t out1 =
530         in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1;
531 
532     auto result = std::make_unique<Array2D<NativeT>>(out0, out1);
533     result->Fill(pad);
534     int64_t o0 = low_padding0;
535     for (int64_t i0 = 0; i0 < in0; ++i0) {
536       int64_t o1 = low_padding1;
537       for (int64_t i1 = 0; i1 < in1; ++i1) {
538         if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) {
539           (*result)(o0, o1) = operand(i0, i1);
540         }
541         o1 += interior_padding1 + 1;
542       }
543       o0 += interior_padding0 + 1;
544     }
545     return result;
546   }
547 
548   // Returns the result of a 3D pad on an input matrix.
549   template <typename NativeT>
PadArray3D(const Array3D<NativeT> & operand,const PaddingConfig & padding,const NativeT pad)550   static Array3D<NativeT> PadArray3D(const Array3D<NativeT>& operand,
551                                      const PaddingConfig& padding,
552                                      const NativeT pad) {
553     CHECK_EQ(padding.dimensions_size(), 3);
554 
555     const int64_t input_bounds[] = {operand.n1(), operand.n2(), operand.n3()};
556     int64_t pad_low[3];
557     int64_t pad_high[3];
558     int64_t pad_interior[3];
559     int64_t output_bounds[3];
560     for (int64_t i = 0; i < 3; ++i) {
561       pad_low[i] = padding.dimensions(i).edge_padding_low();
562       pad_high[i] = padding.dimensions(i).edge_padding_high();
563       CHECK_LE(0, pad_low[i]);
564       CHECK_LE(0, pad_high[i]);
565       CHECK_LE(0, padding.dimensions(i).interior_padding())
566           << "not implemented";
567       pad_interior[i] = padding.dimensions(i).interior_padding();
568 
569       output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
570                          (input_bounds[i] - 1) * pad_interior[i];
571     }
572 
573     Array3D<NativeT> result(output_bounds[0], output_bounds[1],
574                             output_bounds[2]);
575     int indices[] = {0, 0, 0};
576     for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) {
577       for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) {
578         for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) {
579           NativeT* value = &result(indices[0], indices[1], indices[2]);
580           bool value_padded = false;
581           for (int i = 0; i < 3; ++i) {
582             bool in_low_padding = indices[i] < pad_low[i];
583             bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
584             if (in_low_padding || in_high_padding) {
585               *value = pad;
586               value_padded = true;
587             }
588             if (pad_interior[i] &&
589                 (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
590               *value = pad;
591               value_padded = true;
592             }
593           }
594           if (value_padded) {
595             continue;
596           }
597           *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
598                            (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
599                            (indices[2] - pad_low[2]) / (pad_interior[2] + 1));
600         }
601       }
602     }
603     return result;
604   }
605 
606   // Returns the result of a 4D pad on an input array.
607   template <typename NativeT>
PadArray4D(const Array4D<NativeT> & operand,const PaddingConfig & padding,const NativeT pad)608   static Array4D<NativeT> PadArray4D(const Array4D<NativeT>& operand,
609                                      const PaddingConfig& padding,
610                                      const NativeT pad) {
611     CHECK_EQ(padding.dimensions_size(), 4);
612 
613     const int64_t input_bounds[] = {operand.n1(), operand.n2(), operand.n3(),
614                                     operand.n4()};
615     int64_t pad_low[4];
616     int64_t pad_high[4];
617     int64_t pad_interior[4];
618     int64_t output_bounds[4];
619     for (int64_t i = 0; i < 4; ++i) {
620       pad_low[i] = padding.dimensions(i).edge_padding_low();
621       pad_high[i] = padding.dimensions(i).edge_padding_high();
622       CHECK_LE(0, padding.dimensions(i).interior_padding())
623           << "not implemented";
624       pad_interior[i] = padding.dimensions(i).interior_padding();
625 
626       output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
627                          (input_bounds[i] - 1) * pad_interior[i];
628     }
629 
630     Array4D<NativeT> result(output_bounds[0], output_bounds[1],
631                             output_bounds[2], output_bounds[3]);
632     result.Each([&](absl::Span<const int64_t> indices, NativeT* value) {
633       for (int i = 0; i < 4; ++i) {
634         bool in_low_padding = indices[i] < pad_low[i];
635         bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
636         if (in_low_padding || in_high_padding) {
637           *value = pad;
638           return;
639         }
640         if (pad_interior[i] &&
641             (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
642           *value = pad;
643           return;
644         }
645       }
646       *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
647                        (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
648                        (indices[2] - pad_low[2]) / (pad_interior[2] + 1),
649                        (indices[3] - pad_low[3]) / (pad_interior[3] + 1));
650     });
651     return result;
652   }
653 
654   // ApplyElementwise2D(f, x, y, ...) returns the Array2D formed by running
655   // f(x[i], y[i], ...) for each array element in the Array2Ds x, y, ....
656   //
657   // The given arrays must have the same size and element type, and the return
658   // type of f must be implicitly convertible to the arrays' element type.
659   //
660   // Example usage:
661   //
662   //   Array2D<float> x, y, z = ...;
663   //   std::unique_ptr<Array2D> result = ReferenceUtil::ApplyElementwise2D(
664   //     [](float a, float b, float c) { return a * b + c; }, x, y, z);
665   //
666   template <typename F, typename T1, typename... Ts>
ApplyElementwise2D(F && f,const Array2D<T1> & array1,const Array2D<Ts> &...arrays)667   static std::unique_ptr<Array2D<T1>> ApplyElementwise2D(
668       F&& f, const Array2D<T1>& array1, const Array2D<Ts>&... arrays) {
669     AssertSameSize2D(array1, arrays...);
670     auto result = std::make_unique<Array2D<T1>>(array1.n1(), array1.n2());
671     for (int64_t i = 0; i < array1.n1(); ++i) {
672       for (int64_t j = 0; j < array1.n2(); ++j) {
673         (*result)(i, j) = f(array1(i, j), arrays(i, j)...);
674       }
675     }
676     return result;
677   }
678 
679  private:
680   template <typename T1, typename T2, typename... Ts>
AssertSameSize2D(const Array2D<T1> & array1,const Array2D<T2> & array2,const Array2D<Ts> &...arrays)681   static void AssertSameSize2D(const Array2D<T1>& array1,
682                                const Array2D<T2>& array2,
683                                const Array2D<Ts>&... arrays) {
684     static_assert(std::is_same<T1, T2>::value, "Args must be same type.");
685     CHECK_EQ(array1.n1(), array2.n1());
686     CHECK_EQ(array1.n2(), array2.n2());
687     AssertSameSize2D(array2, arrays...);
688   }
689 
690   // Recursive base case for AssertSameSize2D.
691   template <typename Array1>
AssertSameSize2D(const Array1 & array1)692   static void AssertSameSize2D(const Array1& array1) {}
693 
694   ReferenceUtil(const ReferenceUtil&) = delete;
695   ReferenceUtil& operator=(const ReferenceUtil&) = delete;
696 };
697 
698 }  // namespace xla
699 
700 #endif  // TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
701