/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/reduce.h"

#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"

namespace tflite {
namespace gpu {

namespace {
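// Upper bound on the total work-group size (x * y * z) used for the
// shared-memory reduction, chosen per GPU family.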
int GetMaximumWGTotalSize(const GpuInfo& gpu_info) {
  // total_wg_size must be a power of 2 and >= 4.
  int total_wg_size = 256;
  if (gpu_info.IsAdreno() && gpu_info.adreno_info.IsAdreno3xx()) {
    total_wg_size = 128;
  }
  if (gpu_info.IsMali()) {
    const MaliInfo& mali_info = gpu_info.mali_info;
    if (mali_info.IsMaliT6xx() || mali_info.IsMaliT7xx() ||
        mali_info.IsMaliT8xx()) {
      total_wg_size = 32;
    } else {
      total_wg_size = 64;
    }
  }
  return total_wg_size;
}

bool HasAxis(const std::vector<Axis>& axis, Axis a) {
  for (const auto& a2 : axis) {
    if (a2 == a) {
      return true;
    }
  }
  return false;
}

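// Returns the kernel-source expression that combines two partial results for
// the given reduction type.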
std::string MakeOp(OperationType op_type, const std::string& a,
                   const std::string& b) {
  if (op_type == OperationType::REDUCE_SUM || op_type == OperationType::MEAN) {
    return "((" + a + ") + (" + b + "))";
  } else if (op_type == OperationType::REDUCE_PRODUCT) {
    return "((" + a + ") * (" + b + "))";
  } else if (op_type == OperationType::REDUCE_MAXIMUM) {
    return "max(" + a + ", " + b + ")";
  } else if (op_type == OperationType::REDUCE_MINIMUM) {
    return "min(" + a + ", " + b + ")";
  }
  return "UnsupportedOperation";
}

// max_total_wg_size is a power of two.
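// Greedily doubles the work-group dimensions, mapping the last entry of
// ordered_sizes to x, the previous one to y, and so on, until either the axis
// size or the total budget max_total_wg_size is reached.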
int3 GetMaximumPossibleWGSize(const std::vector<int>& ordered_sizes,
                              int max_total_wg_size) {
  int3 wg_size = int3(1, 1, 1);
  int wg_size_total = 1;
  for (int i = ordered_sizes.size() - 1; i >= 0; i--) {
    const int wg_index = ordered_sizes.size() - 1 - i;
    if (wg_index >= 3) {
      return wg_size;
    }
    while (ordered_sizes[i] >= wg_size[wg_index] * 2) {
      wg_size_total *= 2;
      if (wg_size_total > max_total_wg_size) {
        return wg_size;
      }
      wg_size[wg_index] *= 2;
    }
  }
  return wg_size;
}

std::map<Axis, int> GetSizesFromShape(const std::set<Axis>& axis,
                                      const BHWC& shape) {
  std::map<Axis, int> result;
  for (auto a : axis) {
    result[a] = shape.get(a);
  }
  return result;
}

std::map<Axis, int> GetSizesFromShape(const std::set<Axis>& axis,
                                      const BHWDC& shape) {
  std::map<Axis, int> result;
  for (auto a : axis) {
    result[a] = shape.get(a);
  }
  return result;
}

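// Reductions over 16- and 8-bit inputs accumulate in the corresponding 32-bit
// type to avoid overflow and precision loss.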
DataType GetAccumType(DataType src_type) {
  if (src_type == DataType::FLOAT32 || src_type == DataType::FLOAT16) {
    return DataType::FLOAT32;
  } else if (src_type == DataType::INT32 || src_type == DataType::INT16 ||
             src_type == DataType::INT8) {
    return DataType::INT32;
  } else if (src_type == DataType::UINT32 || src_type == DataType::UINT16 ||
             src_type == DataType::UINT8) {
    return DataType::UINT32;
  } else {
    return src_type;
  }
}

}  // namespace

Reduce::Reduce(const std::map<Axis, int>& axis_to_reduce, OperationType op_type,
               const OperationDef& definition, const GpuInfo& gpu_info)
    : GPUOperation(definition) {
  std::vector<Axis> ordered_axis_to_reduce;
  std::vector<int> ordered_sizes;
  for (const auto& a :
       {Axis::CHANNELS, Axis::DEPTH, Axis::HEIGHT, Axis::WIDTH, Axis::BATCH}) {
    auto it = axis_to_reduce.find(a);
    if (it != axis_to_reduce.end()) {
      ordered_axis_to_reduce.push_back(it->first);
      int reduction_size = it->second;
      if (a == Axis::CHANNELS) {
        reduction_size = DivideRoundUp(reduction_size, 4);
      }
      ordered_sizes.push_back(reduction_size);
    }
  }
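  // Choose between a plain per-thread reduction and a shared-memory
  // work-group reduction. The work-group path is taken only when the reduced
  // axes provide enough parallelism to fill a reasonably sized work group.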
  const int max_total_wg_size = GetMaximumWGTotalSize(gpu_info);
  int3 current_wg_size =
      GetMaximumPossibleWGSize(ordered_sizes, max_total_wg_size);
  int current_wg_size_total =
      current_wg_size.x * current_wg_size.y * current_wg_size.z;
  int threshold = max_total_wg_size / 4;
  if (gpu_info.IsApple()) {
    threshold = 16;
  }
  if (current_wg_size_total < threshold) {
    use_wg_reduction_ = false;
  } else {
    use_wg_reduction_ = true;
    work_group_size_ = current_wg_size;
  }
  code_ = GetReduceKernelCode(definition_, gpu_info, work_group_size_,
                              ordered_axis_to_reduce, op_type);
}

Reduce::Reduce(Reduce&& operation)
    : GPUOperation(std::move(operation)),
      use_wg_reduction_(operation.use_wg_reduction_) {}

Reduce& Reduce::operator=(Reduce&& operation) {
  if (this != &operation) {
    use_wg_reduction_ = operation.use_wg_reduction_;
    GPUOperation::operator=(std::move(operation));
  }
  return *this;
}

std::string Reduce::GetReduceKernelCode(const OperationDef& op_def,
                                        const GpuInfo& gpu_info,
                                        const int3& work_group_size,
                                        const std::vector<Axis>& axis_to_reduce,
                                        OperationType op_type) {
  AddSrcTensor("src_tensor", op_def.src_tensors[0]);
  AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
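  // Reciprocals of the reduced element counts; only MEAN uses them. The first
  // normalizes each thread's partial sum, the second the work-group combine
  // step (see BindArguments).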
  args_.AddFloat("inv_multiplier_1");
  args_.AddFloat("inv_multiplier_2");

  std::set<Axis> axis_to_leave;
  const std::vector<Axis> all_axis = {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH,
                                      Axis::CHANNELS, Axis::BATCH};
  for (const auto& a : all_axis) {
    if (op_def.dst_tensors[0].HasAxis(a)) {
      if (!HasAxis(axis_to_reduce, a)) {
        axis_to_leave.insert(a);
      }
    }
  }
  const bool channels_reduction = HasAxis(axis_to_reduce, Axis::CHANNELS);
  int wg_dims = 0;
  if (use_wg_reduction_) {
    if (work_group_size.y == 1 && work_group_size.z == 1) {
      wg_dims = 1;
    } else if (work_group_size.z == 1) {
      wg_dims = 2;
    } else {
      wg_dims = 3;
    }
  }

  auto get_global_id = [&](int i) {
    if (use_wg_reduction_) {
      return "GROUP_ID_" + std::to_string(i);
    } else {
      return "GLOBAL_ID_" + std::to_string(i);
    }
  };

  auto accum_type = GetAccumType(op_def.src_tensors[0].GetDataType());
  const std::string accum_type_decl =
      GetTypeDeclaration(gpu_info, accum_type, 4);
  std::string read_as_template;
  if (accum_type == DataType::FLOAT32) {
    read_as_template = "<float>";
  } else if (accum_type == DataType::INT32) {
    read_as_template = "<int>";
  } else if (accum_type == DataType::UINT32) {
    read_as_template = "<uint>";
  }

  std::string c;
  const std::string wg_x = std::to_string(work_group_size.x);
  const std::string wg_y = std::to_string(work_group_size.y);
  const std::string wg_z = std::to_string(work_group_size.z);
  const int wg_total_size =
      work_group_size.x * work_group_size.y * work_group_size.z;
  c += "MAIN_FUNCTION($0) {\n";
  if (use_wg_reduction_) {
    c += " __local " + accum_type_decl + " accum[" +
         std::to_string(wg_total_size) + "];\n";
    if (wg_dims == 1) {
      c += " int local_x = LOCAL_ID_0;\n";
      c += " int local_id = local_x;\n";
    } else if (wg_dims == 2) {
      c += " int local_x = LOCAL_ID_0;\n";
      c += " int local_y = LOCAL_ID_1;\n";
      c += " int local_id = local_y * " + wg_x + " + local_x;\n";
    } else if (wg_dims == 3) {
      c += " int local_x = LOCAL_ID_0;\n";
      c += " int local_y = LOCAL_ID_1;\n";
      c += " int local_z = LOCAL_ID_2;\n";
      c += " int local_id = (local_z * " + wg_y + " + local_y) * " + wg_x +
           " + local_x;\n";
    }
  }
  if (axis_to_leave.count(Axis::WIDTH)) {
    if (axis_to_leave.count(Axis::BATCH)) {
      c += " int linear_id = " + get_global_id(0) + ";\n";
      c += " int DST_X = linear_id / args.dst_tensor.Batch();\n";
      c += " int DST_B = linear_id % args.dst_tensor.Batch();\n";
    } else {
      c += " int DST_X = " + get_global_id(0) + ";\n";
    }
  } else if (axis_to_leave.count(Axis::BATCH)) {
    c += " int DST_B = " + get_global_id(0) + ";\n";
  }
  if (axis_to_leave.count(Axis::HEIGHT)) {
    if (axis_to_leave.count(Axis::DEPTH)) {
      c += " int linear_id = " + get_global_id(1) + ";\n";
      c += " int DST_Y = linear_id % args.dst_tensor.Height();\n";
      c += " int DST_Z = linear_id / args.dst_tensor.Height();\n";
    } else {
      c += " int DST_Y = " + get_global_id(1) + ";\n";
    }
  } else if (axis_to_leave.count(Axis::DEPTH)) {
    c += " int DST_Z = " + get_global_id(1) + ";\n";
  }
  if (axis_to_leave.count(Axis::CHANNELS)) {
    c += " int DST_S = " + get_global_id(2) + ";\n";
  }
  std::map<Axis, std::string> axis_to_selector = {
      {Axis::BATCH, "Batch()"},   {Axis::WIDTH, "Width()"},
      {Axis::HEIGHT, "Height()"}, {Axis::DEPTH, "Depth()"},
      {Axis::CHANNELS, "Slices()"},
  };
  std::map<Axis, std::string> axis_to_coord = {
      {Axis::BATCH, "B"}, {Axis::WIDTH, "X"}, {Axis::HEIGHT, "Y"},
      {Axis::DEPTH, "Z"}, {Axis::CHANNELS, "S"},
  };
  std::string dst_check;
  for (auto& axis : axis_to_leave) {
    if (!dst_check.empty()) {
      dst_check += " || ";
    }
    dst_check += "DST_" + axis_to_coord[axis] + " >= args.dst_tensor." +
                 axis_to_selector[axis];
  }
  if (!dst_check.empty()) {
    c += " if (" + dst_check + ") return;\n";
  }
  std::map<Axis, std::string> src_coords;
  for (const auto& a : all_axis) {
    if (op_def.dst_tensors[0].HasAxis(a) && !HasAxis(axis_to_reduce, a)) {
      src_coords[a] = "DST_" + axis_to_coord[a];
    } else {
      src_coords[a] = "0";
    }
  }
  std::string src_coordinates;
  for (const auto& a : all_axis) {
    if (op_def.src_tensors[0].HasAxis(a)) {
      if (!src_coordinates.empty()) {
        src_coordinates += ", ";
      }
      src_coordinates += src_coords[a];
    }
  }
  if (op_type == OperationType::REDUCE_SUM || op_type == OperationType::MEAN) {
    c += " " + accum_type_decl +
         " reducer = " + GetZeroValue(gpu_info, accum_type, 4) + ";\n";
  } else if (op_type == OperationType::REDUCE_PRODUCT) {
    c += " " + accum_type_decl +
         " reducer = " + GetOneValue(gpu_info, accum_type, 4) + ";\n";
  } else if (op_type == OperationType::REDUCE_MAXIMUM ||
             op_type == OperationType::REDUCE_MINIMUM) {
    c += " " + accum_type_decl + " reducer = args.src_tensor.Read" +
         read_as_template + "(" + src_coordinates + ");\n";
    if (channels_reduction) {
      c += " reducer.y = reducer.x;\n";
      c += " reducer.z = reducer.x;\n";
      c += " reducer.w = reducer.x;\n";
    }
  }
  const std::vector<std::string> local_ids = {"local_x", "local_y", "local_z"};
  const std::vector<std::string> local_sizes = {wg_x, wg_y, wg_z};
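  // When reducing over channels, the last slice may be only partially used;
  // 'mask' marks which of its four lanes hold real channels so that the
  // padded lanes can be neutralized inside the loop below.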
  for (const auto& axis : axis_to_reduce) {
    if (axis == Axis::CHANNELS) {
      c += " " + accum_type_decl + " mask;\n";
      const std::string one_or_zero_value =
          GetOneValue(gpu_info, accum_type, 1) + " : " +
          GetZeroValue(gpu_info, accum_type, 1);
      c += " mask.x = (args.src_tensor.Slices() - 1) * 4 + 0 < "
           "args.src_tensor.Channels() ? " +
           one_or_zero_value + ";\n";
      c += " mask.y = (args.src_tensor.Slices() - 1) * 4 + 1 < "
           "args.src_tensor.Channels() ? " +
           one_or_zero_value + ";\n";
      c += " mask.z = (args.src_tensor.Slices() - 1) * 4 + 2 < "
           "args.src_tensor.Channels() ? " +
           one_or_zero_value + ";\n";
      c += " mask.w = (args.src_tensor.Slices() - 1) * 4 + 3 < "
           "args.src_tensor.Channels() ? " +
           one_or_zero_value + ";\n";
    }
  }
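  // Emit one nested loop per reduced axis. Axes mapped to a work-group
  // dimension start at the local id and step by the work-group size, so the
  // iterations are distributed over the threads of the group.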
  for (int i = 0; i < axis_to_reduce.size(); ++i) {
    const auto& axis = axis_to_reduce[i];
    const int index = axis_to_reduce.size() - 1 - i;
    const std::string first = index < wg_dims ? local_ids[index] : "0";
    const std::string step = index < wg_dims ? local_sizes[index] : "1";
    const std::string src_coord = "SRC_" + axis_to_coord[axis];
    src_coords[axis] = src_coord;
    c += " for (int " + src_coord + " = " + first + "; " + src_coord +
         " < args.src_tensor." + axis_to_selector[axis] + "; " + src_coord +
         " += " + step + ") {\n";
    if (axis == Axis::CHANNELS) {
      c += " bool last = SRC_S == args.src_tensor.Slices() - 1;\n";
      c += " " + accum_type_decl +
           " mask_a = last ? mask : " + GetOneValue(gpu_info, accum_type, 4) +
           ";\n";
      if (op_type == OperationType::REDUCE_PRODUCT ||
          op_type == OperationType::REDUCE_MAXIMUM ||
          op_type == OperationType::REDUCE_MINIMUM) {
        c += " " + accum_type_decl +
             " mask_b = " + GetOneValue(gpu_info, accum_type, 4) +
             " - mask_a;\n";
      }
    }
  }
  src_coordinates = "";
  for (const auto& a : all_axis) {
    if (op_def.src_tensors[0].HasAxis(a)) {
      if (!src_coordinates.empty()) {
        src_coordinates += ", ";
      }
      src_coordinates += src_coords[a];
    }
  }
  c += " " + accum_type_decl + " src_val = args.src_tensor.Read" +
       read_as_template + "(" + src_coordinates + ");\n";
  if (channels_reduction) {
    if (op_type == OperationType::REDUCE_SUM ||
        op_type == OperationType::MEAN) {
      c += " src_val = src_val * mask_a;\n";
    } else if (op_type == OperationType::REDUCE_PRODUCT) {
      c += " src_val = src_val * mask_a + mask_b;\n";
    } else if (op_type == OperationType::REDUCE_MAXIMUM ||
               op_type == OperationType::REDUCE_MINIMUM) {
      c += " src_val = src_val * mask_a + mask_b * src_val.x;\n";
    }
  }
  c += " reducer = " + MakeOp(op_type, "reducer", "src_val") + ";\n";
  for (int i = 0; i < axis_to_reduce.size(); ++i) {
    c += " }\n";
  }
  if (op_type == OperationType::MEAN) {
    c += " reducer *= args.inv_multiplier_1;\n";
  }
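  // Work-group reduction: each thread publishes its partial result to local
  // memory, the partials are then combined with a 4-way tree reduction, and
  // the few entries that remain are folded in serially.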
  if (use_wg_reduction_) {
    c += " accum[local_id] = reducer;\n";
    c += " LOCAL_MEM_BARRIER;\n";
    const int total_size =
        work_group_size.x * work_group_size.y * work_group_size.z;
    int offset = 1;
    int remainder = total_size / 4;
    for (; remainder >= 8; remainder /= 4, offset *= 4) {
      c += " if (local_id < " + std::to_string(remainder) + ") {\n";
      c += " int t = local_id * " + std::to_string(offset * 4) + ";\n";
      c += " " + accum_type_decl + " sum = accum[t + " +
           std::to_string(offset) + "];\n";
      c += " sum = " +
           MakeOp(op_type, "sum",
                  "accum[t + " + std::to_string(offset * 2) + "]") +
           ";\n";
      c += " sum = " +
           MakeOp(op_type, "sum",
                  "accum[t + " + std::to_string(offset * 3) + "]") +
           ";\n";
      c += " accum[t] = " + MakeOp(op_type, "accum[t]", "sum") + ";\n";
      c += " }\n";
      c += " LOCAL_MEM_BARRIER;\n";
    }
    c += " reducer = accum[0];\n";
    remainder *= 4;
    for (int i = 1; i < remainder; ++i) {
      c += " reducer = " +
           MakeOp(op_type, "reducer",
                  "accum[" + std::to_string(offset * i) + "]") +
           ";\n";
    }
    if (op_type == OperationType::MEAN) {
      c += " reducer *= args.inv_multiplier_2;\n";
    }
  }
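  // If channels were reduced, collapse the four lanes of the vector
  // accumulator into its .x component.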
  if (channels_reduction) {
    if (op_type == OperationType::REDUCE_SUM ||
        op_type == OperationType::MEAN) {
      c += " reducer.x += reducer.y + reducer.z + reducer.w;\n";
    } else if (op_type == OperationType::REDUCE_PRODUCT) {
      c += " reducer.x *= reducer.y * reducer.z * reducer.w;\n";
    } else if (op_type == OperationType::REDUCE_MAXIMUM) {
      c += " reducer.x = max(reducer.x, reducer.y);\n";
      c += " reducer.x = max(reducer.x, reducer.z);\n";
      c += " reducer.x = max(reducer.x, reducer.w);\n";
    } else if (op_type == OperationType::REDUCE_MINIMUM) {
      c += " reducer.x = min(reducer.x, reducer.y);\n";
      c += " reducer.x = min(reducer.x, reducer.z);\n";
      c += " reducer.x = min(reducer.x, reducer.w);\n";
    }
  }
  const std::string conversion = GetTypeConversion(
      gpu_info, accum_type, op_def.src_tensors[0].GetDataType(), 4);
  c += " args.src_tensor::type result = " +
       absl::Substitute(conversion, "reducer") + ";\n";
  std::string dst_coordinates;
  for (const auto& a : all_axis) {
    if (op_def.dst_tensors[0].HasAxis(a)) {
      if (!dst_coordinates.empty()) {
        dst_coordinates += ", ";
      }
      if (axis_to_leave.count(a)) {
        dst_coordinates += "DST_" + axis_to_coord[a];
      } else {
        dst_coordinates += "0";
      }
    }
  }
  c += " args.dst_tensor.Write(result, " + dst_coordinates + ");\n";
  c += "}\n";
  return c;
}

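// Sets the MEAN normalization factors. With the work-group reduction the
// division is split in two: per-thread partial sums are scaled by 1 / size_1
// and the combined work-group result by 1 / size_0, where size_0 * size_1
// equals the total number of reduced elements.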
absl::Status Reduce::BindArguments(ArgumentsBinder* args) {
  const double total_src_elements = 1.0 * src_[0]->Batch() * src_[0]->Width() *
                                    src_[0]->Height() * src_[0]->Depth() *
                                    src_[0]->Channels();
  const double total_dst_elements = 1.0 * dst_[0]->Batch() * dst_[0]->Width() *
                                    dst_[0]->Height() * dst_[0]->Depth() *
                                    dst_[0]->Channels();
  const double reduction_size = total_src_elements / total_dst_elements;
  if (use_wg_reduction_) {
    const double size_0 =
        work_group_size_.x * work_group_size_.y * work_group_size_.z;
    const double size_1 = reduction_size / size_0;
    RETURN_IF_ERROR(args->SetFloat("inv_multiplier_1", 1.0 / size_1));
    RETURN_IF_ERROR(args->SetFloat("inv_multiplier_2", 1.0 / size_0));
  } else {
    RETURN_IF_ERROR(args->SetFloat("inv_multiplier_1", 1.0 / reduction_size));
    RETURN_IF_ERROR(args->SetFloat("inv_multiplier_2", 1.0));
  }
  return absl::OkStatus();
}

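// With the work-group reduction the grid is scaled up by the work-group size
// in every dimension, so each work group still maps to a single destination
// element.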
int3 Reduce::GetGridSize() const {
  int grid_x = dst_[0]->Width() * dst_[0]->Batch();
  int grid_y = dst_[0]->Height() * dst_[0]->Depth();
  int grid_z = dst_[0]->Slices();
  if (use_wg_reduction_) {
    grid_x *= work_group_size_.x;
    grid_y *= work_group_size_.y;
    grid_z *= work_group_size_.z;
  }
  return int3(grid_x, grid_y, grid_z);
}

void Reduce::GetPossibleKernelWorkGroups(TuningType tuning_type,
                                         const GpuInfo& gpu_info,
                                         const KernelInfo& kernel_info,
                                         std::vector<int3>* work_groups) const {
  if (use_wg_reduction_) {
    work_groups->push_back(work_group_size_);
  } else {
    GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
                          work_groups);
  }
}

Reduce CreateReduce(const std::set<Axis>& axis_to_reduce, const BHWC& src_shape,
                    OperationType op_type, const OperationDef& definition,
                    const GpuInfo& gpu_info) {
  return Reduce(GetSizesFromShape(axis_to_reduce, src_shape), op_type,
                definition, gpu_info);
}

Reduce CreateReduce(const std::set<Axis>& axis_to_reduce,
                    const BHWDC& src_shape, OperationType op_type,
                    const OperationDef& definition, const GpuInfo& gpu_info) {
  return Reduce(GetSizesFromShape(axis_to_reduce, src_shape), op_type,
                definition, gpu_info);
}

}  // namespace gpu
}  // namespace tflite