xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/tasks/reduce.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/reduce.h"
17 
18 #include <set>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/strings/substitute.h"
23 #include "tensorflow/lite/delegates/gpu/common/status.h"
24 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
25 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
26 #include "tensorflow/lite/delegates/gpu/common/util.h"
27 
28 namespace tflite {
29 namespace gpu {
30 
31 namespace {
GetMaximumWGTotalSize(const GpuInfo & gpu_info)32 int GetMaximumWGTotalSize(const GpuInfo& gpu_info) {
33   // total_wg_size must be power of 2 and >= 4;
34   int total_wg_size = 256;
35   if (gpu_info.IsAdreno() && gpu_info.adreno_info.IsAdreno3xx()) {
36     total_wg_size = 128;
37   }
38   if (gpu_info.IsMali()) {
39     const MaliInfo& mali_info = gpu_info.mali_info;
40     if (mali_info.IsMaliT6xx() || mali_info.IsMaliT7xx() ||
41         mali_info.IsMaliT8xx()) {
42       total_wg_size = 32;
43     } else {
44       total_wg_size = 64;
45     }
46   }
47   return total_wg_size;
48 }
49 
HasAxis(const std::vector<Axis> & axis,Axis a)50 bool HasAxis(const std::vector<Axis>& axis, Axis a) {
51   for (const auto& a2 : axis) {
52     if (a2 == a) {
53       return true;
54     }
55   }
56   return false;
57 }
58 
MakeOp(OperationType op_type,const std::string & a,const std::string & b)59 std::string MakeOp(OperationType op_type, const std::string& a,
60                    const std::string& b) {
61   if (op_type == OperationType::REDUCE_SUM || op_type == OperationType::MEAN) {
62     return "((" + a + ") + (" + b + "))";
63   } else if (op_type == OperationType::REDUCE_PRODUCT) {
64     return "((" + a + ") * (" + b + "))";
65   } else if (op_type == OperationType::REDUCE_MAXIMUM) {
66     return "max(" + a + ", " + b + ")";
67   } else if (op_type == OperationType::REDUCE_MINIMUM) {
68     return "min(" + a + ", " + b + ")";
69   }
70   return "UnsupportedOperation";
71 }
72 
73 // max_total_wg_size is pot
GetMaximumPossibleWGSize(const std::vector<int> & ordered_sizes,int max_total_wg_size)74 int3 GetMaximumPossibleWGSize(const std::vector<int>& ordered_sizes,
75                               int max_total_wg_size) {
76   int3 wg_size = int3(1, 1, 1);
77   int wg_size_total = 1;
78   for (int i = ordered_sizes.size() - 1; i >= 0; i--) {
79     const int wg_index = ordered_sizes.size() - 1 - i;
80     if (wg_index >= 3) {
81       return wg_size;
82     }
83     while (ordered_sizes[i] >= wg_size[wg_index] * 2) {
84       wg_size_total *= 2;
85       if (wg_size_total > max_total_wg_size) {
86         return wg_size;
87       }
88       wg_size[wg_index] *= 2;
89     }
90   }
91   return wg_size;
92 }
93 
GetSizesFromShape(const std::set<Axis> & axis,const BHWC & shape)94 std::map<Axis, int> GetSizesFromShape(const std::set<Axis>& axis,
95                                       const BHWC& shape) {
96   std::map<Axis, int> result;
97   for (auto a : axis) {
98     result[a] = shape.get(a);
99   }
100   return result;
101 }
102 
GetSizesFromShape(const std::set<Axis> & axis,const BHWDC & shape)103 std::map<Axis, int> GetSizesFromShape(const std::set<Axis>& axis,
104                                       const BHWDC& shape) {
105   std::map<Axis, int> result;
106   for (auto a : axis) {
107     result[a] = shape.get(a);
108   }
109   return result;
110 }
111 
GetAccumType(DataType src_type)112 DataType GetAccumType(DataType src_type) {
113   if (src_type == DataType::FLOAT32 || src_type == DataType::FLOAT16) {
114     return DataType::FLOAT32;
115   } else if (src_type == DataType::INT32 || src_type == DataType::INT16 ||
116              src_type == DataType::INT8) {
117     return DataType::INT32;
118   } else if (src_type == DataType::UINT32 || src_type == DataType::UINT16 ||
119              src_type == DataType::UINT8) {
120     return DataType::UINT32;
121   } else {
122     return src_type;
123   }
124 }
125 
126 }  // namespace
127 
Reduce(const std::map<Axis,int> & axis_to_reduce,OperationType op_type,const OperationDef & definition,const GpuInfo & gpu_info)128 Reduce::Reduce(const std::map<Axis, int>& axis_to_reduce, OperationType op_type,
129                const OperationDef& definition, const GpuInfo& gpu_info)
130     : GPUOperation(definition) {
131   std::vector<Axis> ordered_axis_to_reduce;
132   std::vector<int> ordered_sizes;
133   for (const auto& a :
134        {Axis::CHANNELS, Axis::DEPTH, Axis::HEIGHT, Axis::WIDTH, Axis::BATCH}) {
135     auto it = axis_to_reduce.find(a);
136     if (it != axis_to_reduce.end()) {
137       ordered_axis_to_reduce.push_back(it->first);
138       int reduction_size = it->second;
139       if (a == Axis::CHANNELS) {
140         reduction_size = DivideRoundUp(reduction_size, 4);
141       }
142       ordered_sizes.push_back(reduction_size);
143     }
144   }
145   const int max_total_wg_size = GetMaximumWGTotalSize(gpu_info);
146   int3 current_wg_size =
147       GetMaximumPossibleWGSize(ordered_sizes, max_total_wg_size);
148   int current_wg_size_total =
149       current_wg_size.x * current_wg_size.y * current_wg_size.z;
150   int threshold = max_total_wg_size / 4;
151   if (gpu_info.IsApple()) {
152     threshold = 16;
153   }
154   if (current_wg_size_total < threshold) {
155     use_wg_reduction_ = false;
156   } else {
157     use_wg_reduction_ = true;
158     work_group_size_ = current_wg_size;
159   }
160   code_ = GetReduceKernelCode(definition_, gpu_info, work_group_size_,
161                               ordered_axis_to_reduce, op_type);
162 }
163 
Reduce(Reduce && operation)164 Reduce::Reduce(Reduce&& operation)
165     : GPUOperation(std::move(operation)),
166       use_wg_reduction_(operation.use_wg_reduction_) {}
167 
operator =(Reduce && operation)168 Reduce& Reduce::operator=(Reduce&& operation) {
169   if (this != &operation) {
170     use_wg_reduction_ = operation.use_wg_reduction_;
171     GPUOperation::operator=(std::move(operation));
172   }
173   return *this;
174 }
175 
GetReduceKernelCode(const OperationDef & op_def,const GpuInfo & gpu_info,const int3 & work_group_size,const std::vector<Axis> & axis_to_reduce,OperationType op_type)176 std::string Reduce::GetReduceKernelCode(const OperationDef& op_def,
177                                         const GpuInfo& gpu_info,
178                                         const int3& work_group_size,
179                                         const std::vector<Axis>& axis_to_reduce,
180                                         OperationType op_type) {
181   AddSrcTensor("src_tensor", op_def.src_tensors[0]);
182   AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
183   args_.AddFloat("inv_multiplier_1");
184   args_.AddFloat("inv_multiplier_2");
185 
186   std::set<Axis> axis_to_leave;
187   const std::vector<Axis> all_axis = {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH,
188                                       Axis::CHANNELS, Axis::BATCH};
189   for (const auto& a : all_axis) {
190     if (op_def.dst_tensors[0].HasAxis(a)) {
191       if (!HasAxis(axis_to_reduce, a)) {
192         axis_to_leave.insert(a);
193       }
194     }
195   }
196   const bool channels_reductin = HasAxis(axis_to_reduce, Axis::CHANNELS);
197   int wg_dims = 0;
198   if (use_wg_reduction_) {
199     if (work_group_size.y == 1 && work_group_size.z == 1) {
200       wg_dims = 1;
201     } else if (work_group_size.z == 1) {
202       wg_dims = 2;
203     } else {
204       wg_dims = 3;
205     }
206   }
207 
208   auto get_global_id = [&](int i) {
209     if (use_wg_reduction_) {
210       return "GROUP_ID_" + std::to_string(i);
211     } else {
212       return "GLOBAL_ID_" + std::to_string(i);
213     }
214   };
215 
216   auto accum_type = GetAccumType(op_def.src_tensors[0].GetDataType());
217   const std::string accum_type_decl =
218       GetTypeDeclaration(gpu_info, accum_type, 4);
219   std::string read_as_template;
220   if (accum_type == DataType::FLOAT32) {
221     read_as_template = "<float>";
222   } else if (accum_type == DataType::INT32) {
223     read_as_template = "<int>";
224   } else if (accum_type == DataType::UINT32) {
225     read_as_template = "<uint>";
226   }
227 
228   std::string c;
229   const std::string wg_x = std::to_string(work_group_size.x);
230   const std::string wg_y = std::to_string(work_group_size.y);
231   const std::string wg_z = std::to_string(work_group_size.z);
232   const int wg_total_size =
233       work_group_size.x * work_group_size.y * work_group_size.z;
234   c += "MAIN_FUNCTION($0) {\n";
235   if (use_wg_reduction_) {
236     c += "  __local " + accum_type_decl + " accum[" +
237          std::to_string(wg_total_size) + "];\n";
238     if (wg_dims == 1) {
239       c += "  int local_x = LOCAL_ID_0;\n";
240       c += "  int local_id = local_x;\n";
241     } else if (wg_dims == 2) {
242       c += "  int local_x = LOCAL_ID_0;\n";
243       c += "  int local_y = LOCAL_ID_1;\n";
244       c += "  int local_id = local_y * " + wg_x + " + local_x;\n";
245     } else if (wg_dims == 3) {
246       c += "  int local_x = LOCAL_ID_0;\n";
247       c += "  int local_y = LOCAL_ID_1;\n";
248       c += "  int local_z = LOCAL_ID_2;\n";
249       c += "  int local_id = (local_z * " + wg_y + " + local_y) * " + wg_x +
250            " + local_x;\n";
251     }
252   }
253   if (axis_to_leave.count(Axis::WIDTH)) {
254     if (axis_to_leave.count(Axis::BATCH)) {
255       c += "  int linear_id = " + get_global_id(0) + ";\n";
256       c += "  int DST_X = linear_id / args.dst_tensor.Batch();\n";
257       c += "  int DST_B = linear_id % args.dst_tensor.Batch();\n";
258     } else {
259       c += "  int DST_X = " + get_global_id(0) + ";\n";
260     }
261   } else if (axis_to_leave.count(Axis::BATCH)) {
262     c += "  int DST_B = " + get_global_id(0) + ";\n";
263   }
264   if (axis_to_leave.count(Axis::HEIGHT)) {
265     if (axis_to_leave.count(Axis::DEPTH)) {
266       c += "  int linear_id = " + get_global_id(1) + ";\n";
267       c += "  int DST_Y = linear_id % args.dst_tensor.Height();\n";
268       c += "  int DST_Z = linear_id / args.dst_tensor.Height();\n";
269     } else {
270       c += "  int DST_Y = " + get_global_id(1) + ";\n";
271     }
272   } else if (axis_to_leave.count(Axis::DEPTH)) {
273     c += "  int DST_Z = " + get_global_id(1) + ";\n";
274   }
275   if (axis_to_leave.count(Axis::CHANNELS)) {
276     c += "  int DST_S = " + get_global_id(2) + ";\n";
277   }
278   std::map<Axis, std::string> axis_to_selector = {
279       {Axis::BATCH, "Batch()"},     {Axis::WIDTH, "Width()"},
280       {Axis::HEIGHT, "Height()"},   {Axis::DEPTH, "Depth()"},
281       {Axis::CHANNELS, "Slices()"},
282   };
283   std::map<Axis, std::string> axis_to_coord = {
284       {Axis::BATCH, "B"}, {Axis::WIDTH, "X"},    {Axis::HEIGHT, "Y"},
285       {Axis::DEPTH, "Z"}, {Axis::CHANNELS, "S"},
286   };
287   std::string dst_check;
288   for (auto& axis : axis_to_leave) {
289     if (!dst_check.empty()) {
290       dst_check += " || ";
291     }
292     dst_check += "DST_" + axis_to_coord[axis] + " >= args.dst_tensor." +
293                  axis_to_selector[axis];
294   }
295   if (!dst_check.empty()) {
296     c += "  if (" + dst_check + ") return;\n";
297   }
298   std::map<Axis, std::string> src_coords;
299   for (const auto& a : all_axis) {
300     if (op_def.dst_tensors[0].HasAxis(a) && !HasAxis(axis_to_reduce, a)) {
301       src_coords[a] = "DST_" + axis_to_coord[a];
302     } else {
303       src_coords[a] = "0";
304     }
305   }
306   std::string src_coordinates;
307   for (const auto& a : all_axis) {
308     if (op_def.src_tensors[0].HasAxis(a)) {
309       if (!src_coordinates.empty()) {
310         src_coordinates += ", ";
311       }
312       src_coordinates += src_coords[a];
313     }
314   }
315   if (op_type == OperationType::REDUCE_SUM || op_type == OperationType::MEAN) {
316     c += "  " + accum_type_decl +
317          " reducer = " + GetZeroValue(gpu_info, accum_type, 4) + ";\n";
318   } else if (op_type == OperationType::REDUCE_PRODUCT) {
319     c += "  " + accum_type_decl +
320          " reducer = " + GetOneValue(gpu_info, accum_type, 4) + ";\n";
321   } else if (op_type == OperationType::REDUCE_MAXIMUM ||
322              op_type == OperationType::REDUCE_MINIMUM) {
323     c += "  " + accum_type_decl + " reducer = args.src_tensor.Read" +
324          read_as_template + "(" + src_coordinates + ");\n";
325     if (channels_reductin) {
326       c += "  reducer.y = reducer.x;\n";
327       c += "  reducer.z = reducer.x;\n";
328       c += "  reducer.w = reducer.x;\n";
329     }
330   }
331   const std::vector<std::string> local_ids = {"local_x", "local_y", "local_z"};
332   const std::vector<std::string> local_sizes = {wg_x, wg_y, wg_z};
333   for (const auto& axis : axis_to_reduce) {
334     if (axis == Axis::CHANNELS) {
335       c += "  " + accum_type_decl + " mask;\n";
336       const std::string one_or_zero_value =
337           GetOneValue(gpu_info, accum_type, 1) + " : " +
338           GetZeroValue(gpu_info, accum_type, 1);
339       c += "  mask.x = (args.src_tensor.Slices() - 1) * 4 + 0 < "
340            "args.src_tensor.Channels() ? " +
341            one_or_zero_value + ";\n";
342       c += "  mask.y = (args.src_tensor.Slices() - 1) * 4 + 1 < "
343            "args.src_tensor.Channels() ? " +
344            one_or_zero_value + ";\n";
345       c += "  mask.z = (args.src_tensor.Slices() - 1) * 4 + 2 < "
346            "args.src_tensor.Channels() ? " +
347            one_or_zero_value + ";\n";
348       c += "  mask.w = (args.src_tensor.Slices() - 1) * 4 + 3 < "
349            "args.src_tensor.Channels() ? " +
350            one_or_zero_value + ";\n";
351     }
352   }
353   for (int i = 0; i < axis_to_reduce.size(); ++i) {
354     const auto& axis = axis_to_reduce[i];
355     const int index = axis_to_reduce.size() - 1 - i;
356     const std::string first = index < wg_dims ? local_ids[index] : "0";
357     const std::string step = index < wg_dims ? local_sizes[index] : "1";
358     const std::string src_coord = "SRC_" + axis_to_coord[axis];
359     src_coords[axis] = src_coord;
360     c += "  for (int " + src_coord + " = " + first + "; " + src_coord +
361          " < args.src_tensor." + axis_to_selector[axis] + "; " + src_coord +
362          " += " + step + ") {\n";
363     if (axis == Axis::CHANNELS) {
364       c += "    bool last = SRC_S == args.src_tensor.Slices() - 1;\n";
365       c += "    " + accum_type_decl +
366            " mask_a = last ? mask : " + GetOneValue(gpu_info, accum_type, 4) +
367            ";\n";
368       if (op_type == OperationType::REDUCE_PRODUCT ||
369           op_type == OperationType::REDUCE_MAXIMUM ||
370           op_type == OperationType::REDUCE_MINIMUM) {
371         c += "    " + accum_type_decl +
372              " mask_b = " + GetOneValue(gpu_info, accum_type, 4) +
373              " - mask_a;\n";
374       }
375     }
376   }
377   src_coordinates = "";
378   for (const auto& a : all_axis) {
379     if (op_def.src_tensors[0].HasAxis(a)) {
380       if (!src_coordinates.empty()) {
381         src_coordinates += ", ";
382       }
383       src_coordinates += src_coords[a];
384     }
385   }
386   c += "    " + accum_type_decl + " src_val = args.src_tensor.Read" +
387        read_as_template + "(" + src_coordinates + ");\n";
388   if (channels_reductin) {
389     if (op_type == OperationType::REDUCE_SUM ||
390         op_type == OperationType::MEAN) {
391       c += "    src_val = src_val * mask_a;\n";
392     } else if (op_type == OperationType::REDUCE_PRODUCT) {
393       c += "    src_val = src_val * mask_a + mask_b;\n";
394     } else if (op_type == OperationType::REDUCE_MAXIMUM ||
395                op_type == OperationType::REDUCE_MINIMUM) {
396       c += "    src_val = src_val * mask_a + mask_b * src_val.x;\n";
397     }
398   }
399   c += "    reducer = " + MakeOp(op_type, "reducer", "src_val") + ";\n";
400   for (int i = 0; i < axis_to_reduce.size(); ++i) {
401     c += "  }\n";
402   }
403   if (op_type == OperationType::MEAN) {
404     c += "  reducer *= args.inv_multiplier_1;\n";
405   }
406   if (use_wg_reduction_) {
407     c += "  accum[local_id] = reducer;\n";
408     c += "  LOCAL_MEM_BARRIER;\n";
409     const int total_size =
410         work_group_size.x * work_group_size.y * work_group_size.z;
411     int offset = 1;
412     int reminder = total_size / 4;
413     for (; reminder >= 8; reminder /= 4, offset *= 4) {
414       c += "  if (local_id < " + std::to_string(reminder) + ") {\n";
415       c += "    int t = local_id * " + std::to_string(offset * 4) + ";\n";
416       c += "    " + accum_type_decl + " sum = accum[t + " +
417            std::to_string(offset) + "];\n";
418       c += "    sum = " +
419            MakeOp(op_type, "sum",
420                   "accum[t + " + std::to_string(offset * 2) + "]") +
421            ";\n";
422       c += "    sum = " +
423            MakeOp(op_type, "sum",
424                   "accum[t + " + std::to_string(offset * 3) + "]") +
425            ";\n";
426       c += "    accum[t] = " + MakeOp(op_type, "accum[t]", "sum") + ";\n";
427       c += "  }\n";
428       c += "  LOCAL_MEM_BARRIER;\n";
429     }
430     c += "  reducer = accum[0];\n";
431     reminder *= 4;
432     for (int i = 1; i < reminder; ++i) {
433       c += "  reducer = " +
434            MakeOp(op_type, "reducer",
435                   "accum[" + std::to_string(offset * i) + "]") +
436            ";\n";
437     }
438     if (op_type == OperationType::MEAN) {
439       c += "  reducer *= args.inv_multiplier_2;\n";
440     }
441   }
442   if (channels_reductin) {
443     if (op_type == OperationType::REDUCE_SUM ||
444         op_type == OperationType::MEAN) {
445       c += "  reducer.x += reducer.y + reducer.z + reducer.w;\n";
446     } else if (op_type == OperationType::REDUCE_PRODUCT) {
447       c += "  reducer.x *= reducer.y * reducer.z * reducer.w;\n";
448     } else if (op_type == OperationType::REDUCE_MAXIMUM) {
449       c += "  reducer.x = max(reducer.x, reducer.y);\n";
450       c += "  reducer.x = max(reducer.x, reducer.z);\n";
451       c += "  reducer.x = max(reducer.x, reducer.w);\n";
452     } else if (op_type == OperationType::REDUCE_MINIMUM) {
453       c += "  reducer.x = min(reducer.x, reducer.y);\n";
454       c += "  reducer.x = min(reducer.x, reducer.z);\n";
455       c += "  reducer.x = min(reducer.x, reducer.w);\n";
456     }
457   }
458   const std::string conversion = GetTypeConversion(
459       gpu_info, accum_type, op_def.src_tensors[0].GetDataType(), 4);
460   c += "  args.src_tensor::type result = " +
461        absl::Substitute(conversion, "reducer") + ";\n";
462   std::string dst_coordinates;
463   for (const auto& a : all_axis) {
464     if (op_def.dst_tensors[0].HasAxis(a)) {
465       if (!dst_coordinates.empty()) {
466         dst_coordinates += ", ";
467       }
468       if (axis_to_leave.count(a)) {
469         dst_coordinates += "DST_" + axis_to_coord[a];
470       } else {
471         dst_coordinates += "0";
472       }
473     }
474   }
475   c += "  args.dst_tensor.Write(result, " + dst_coordinates + ");\n";
476   c += "}\n";
477   return c;
478 }
479 
BindArguments(ArgumentsBinder * args)480 absl::Status Reduce::BindArguments(ArgumentsBinder* args) {
481   const double total_src_elements = 1.0 * src_[0]->Batch() * src_[0]->Width() *
482                                     src_[0]->Height() * src_[0]->Depth() *
483                                     src_[0]->Channels();
484   const double total_dst_elements = 1.0 * dst_[0]->Batch() * dst_[0]->Width() *
485                                     dst_[0]->Height() * dst_[0]->Depth() *
486                                     dst_[0]->Channels();
487   const double reduction_size = total_src_elements / total_dst_elements;
488   if (use_wg_reduction_) {
489     const double size_0 =
490         work_group_size_.x * work_group_size_.y * work_group_size_.z;
491     const double size_1 = reduction_size / size_0;
492     RETURN_IF_ERROR(args->SetFloat("inv_multiplier_1", 1.0 / size_1));
493     RETURN_IF_ERROR(args->SetFloat("inv_multiplier_2", 1.0 / size_0));
494   } else {
495     RETURN_IF_ERROR(args->SetFloat("inv_multiplier_1", 1.0 / reduction_size));
496     RETURN_IF_ERROR(args->SetFloat("inv_multiplier_2", 1.0));
497   }
498   return absl::OkStatus();
499 }
500 
GetGridSize() const501 int3 Reduce::GetGridSize() const {
502   int grid_x = dst_[0]->Width() * dst_[0]->Batch();
503   int grid_y = dst_[0]->Height() * dst_[0]->Depth();
504   int grid_z = dst_[0]->Slices();
505   if (use_wg_reduction_) {
506     grid_x *= work_group_size_.x;
507     grid_y *= work_group_size_.y;
508     grid_z *= work_group_size_.z;
509   }
510   return int3(grid_x, grid_y, grid_z);
511 }
512 
GetPossibleKernelWorkGroups(TuningType tuning_type,const GpuInfo & gpu_info,const KernelInfo & kernel_info,std::vector<int3> * work_groups) const513 void Reduce::GetPossibleKernelWorkGroups(TuningType tuning_type,
514                                          const GpuInfo& gpu_info,
515                                          const KernelInfo& kernel_info,
516                                          std::vector<int3>* work_groups) const {
517   if (use_wg_reduction_) {
518     work_groups->push_back(work_group_size_);
519   } else {
520     GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
521                           work_groups);
522   }
523 }
524 
CreateReduce(const std::set<Axis> & axis_to_reduce,const BHWC & src_shape,OperationType op_type,const OperationDef & definition,const GpuInfo & gpu_info)525 Reduce CreateReduce(const std::set<Axis>& axis_to_reduce, const BHWC& src_shape,
526                     OperationType op_type, const OperationDef& definition,
527                     const GpuInfo& gpu_info) {
528   return Reduce(GetSizesFromShape(axis_to_reduce, src_shape), op_type,
529                 definition, gpu_info);
530 }
531 
CreateReduce(const std::set<Axis> & axis_to_reduce,const BHWDC & src_shape,OperationType op_type,const OperationDef & definition,const GpuInfo & gpu_info)532 Reduce CreateReduce(const std::set<Axis>& axis_to_reduce,
533                     const BHWDC& src_shape, OperationType op_type,
534                     const OperationDef& definition, const GpuInfo& gpu_info) {
535   return Reduce(GetSizesFromShape(axis_to_reduce, src_shape), op_type,
536                 definition, gpu_info);
537 }
538 
539 }  // namespace gpu
540 }  // namespace tflite
541