/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization.h"

#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>

#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"

namespace tflite {
namespace gpu {

namespace {

absl::Status CheckIfValidNodeOfType(const Node* node,
                                    OperationType required_type) {
  if (node == nullptr) {
    return absl::NotFoundError("Invalid node.");
  }
  if (OperationTypeFromString(node->operation.type) != required_type) {
    return absl::NotFoundError("Type mismatch.");
  }
  return absl::OkStatus();
}

absl::Status GetElementwiseScalarValue(const Node* node, float* result) {
  auto attr =
      absl::any_cast<ElementwiseAttributes>(node->operation.attributes);
  const float* value = absl::get_if<float>(&attr.param);
  if (!value) {
    return absl::NotFoundError("Not a scalar value inside attributes.");
  }
  *result = *value;
  return absl::OkStatus();
}

absl::Status GetNextSingleNode(const GraphFloat32& graph, const Node& node,
                               OperationType next_type, Node** next_node) {
  auto consumers = graph.FindConsumers(graph.FindOutputs(node.id)[0]->id);
  if (consumers.size() != 1) {
    return absl::NotFoundError("Not a single consumer.");
  }
  RETURN_IF_ERROR(CheckIfValidNodeOfType(consumers[0], next_type));
  *next_node = consumers[0];
  return absl::OkStatus();
}

std::string GetReduceCode(const std::string& src_value,
                          const std::string& dst_value, int3 work_group_size,
                          bool two_step) {
  int reduction_size = work_group_size.z;
  std::string mem_name = work_group_size.x * work_group_size.y != 1
                             ? "shared_mem[LOCAL_ID_1][LOCAL_ID_0]"
                             : "shared_mem";
  if (reduction_size <= 8) {
    std::string result;
    result += "  { // reduction\n";
    result += "    " + mem_name + "[local_id] = " + src_value + ";\n";
    result += "    LOCAL_MEM_BARRIER;\n";
    result += "    " + dst_value + " = " + mem_name + "[0];\n";
    for (int i = 1; i < reduction_size; ++i) {
      result += "    " + dst_value + " += " + mem_name + "[" +
                std::to_string(i) + "];\n";
    }
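    // With two-step normalization this reduction is emitted twice on the same
    // shared memory (see GetNormalizationCode), so an extra barrier keeps the
    // second pass from overwriting values other threads are still reading.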
    if (two_step) {
      result += "    LOCAL_MEM_BARRIER;\n";
    }
    result += "  }\n";
    return result;
  } else {
    // In the reduction step add upper half of the still-to-be-summed vector to
    // the lower half, while taking care of odd sizes and rounding. E.g.:
    // Number of items still to be summed before: 5
    // Local memory before: [a, b, c, d, e];
    // Local memory after: [a+d, b+e, c, d, e];
    // Threads doing work: id < 2 = floor(5/2)
    // Offset to the added items: 3 = ceil(5/2)
    // Number of items still to be summed after: 3 = ceil(5/2)
    return absl::Substitute(R"(
  { // reduction, all threads inside workgroup must execute this code
    $3[local_id] = $1;
    LOCAL_MEM_BARRIER;
    // The number of items that still need to be summed
    int reduction_size = $0;
    while (reduction_size > 1) {
      int active_thread_limit = reduction_size / 2;
      int offset = (reduction_size + 1) / 2;
      if (local_id < active_thread_limit) {
        $1 += $3[local_id + offset];
        $3[local_id] = $1;
      }
      LOCAL_MEM_BARRIER;
      reduction_size = offset;
    }
    $2 = $3[0];
  }
)",
                            reduction_size, src_value, dst_value, mem_name);
  }
}

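// When the channel count is not a multiple of 4, the last slice carries
// padding lanes; this zeroes them so they do not contribute to the sums.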
std::string ZeroClampVec4Code(const std::string& slice_name,
                              const std::string& channels_name,
                              const std::string& value_name) {
  return absl::Substitute(R"(
  // no need to check first element, always valid
  if ($0 * 4 + 1 >= $1) { $2.y = 0.0f; }
  if ($0 * 4 + 2 >= $1) { $2.z = 0.0f; }
  if ($0 * 4 + 3 >= $1) { $2.w = 0.0f; }
)",
                          slice_name, channels_name, value_name);
}
}  // namespace

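// Picks a work group so that the reduction over channel slices runs along the
// Z dimension, within per-vendor work group size limits. `two_step` selects
// between a two-pass mean/variance kernel and a fused single pass that
// accumulates sums and sums of squares together.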
MeanStdDevNormalization::MeanStdDevNormalization(const OperationDef& definition,
                                                 const GpuInfo& gpu_info,
                                                 const BHWC& shape,
                                                 float variance_bias,
                                                 bool two_step)
    : GPUOperation(definition) {
  const int tensor_slices = DivideRoundUp(shape.c, 4);
  int desired_work_group_size = gpu_info.GetMaxWorkGroupSizeForZ();
  if (gpu_info.IsMali()) {
    // Don't use more than 64 work items per work group on ARM Mali. Mali
    // implements local memory in global memory, so larger work groups carry a
    // severe performance penalty.
    desired_work_group_size = 64;
  }
  if (gpu_info.IsAdreno()) {
    AdrenoInfo info = gpu_info.adreno_info;
    desired_work_group_size = 256;
    if (info.IsAdreno3xx()) {
      if (info.adreno_gpu == AdrenoGpu::kAdreno320 ||
          info.adreno_gpu == AdrenoGpu::kAdreno330) {
        desired_work_group_size = 128;
      } else {
        desired_work_group_size = 64;
      }
    } else if (info.IsAdreno4xx()) {
      if (info.adreno_gpu == AdrenoGpu::kAdreno430) {
        desired_work_group_size = 256;
      } else {
        desired_work_group_size = 128;
      }
    } else if (info.IsAdreno5xx()) {
      if (info.adreno_gpu == AdrenoGpu::kAdreno530 ||
          info.adreno_gpu == AdrenoGpu::kAdreno540) {
        desired_work_group_size = 256;
      } else {
        desired_work_group_size = 128;
      }
    }
  }
  if (gpu_info.IsPowerVR()) {
    desired_work_group_size = 64;
  }
  if (gpu_info.IsApple()) {
    desired_work_group_size = 64;
  }
  if (gpu_info.IsAMD()) {
    desired_work_group_size = 512;
  }
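  // For tensors with a single spatial position the whole work group reduces
  // along Z; otherwise the work group also tiles X/Y and only the Z lanes of
  // each (x, y) position cooperate in the reduction.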
  if (shape.w * shape.h == 1) {
    desired_work_group_size =
        std::min(desired_work_group_size, gpu_info.GetMaxWorkGroupSizeForZ());
    while (desired_work_group_size >= tensor_slices * 2) {
      desired_work_group_size /= 2;
    }
    work_group_size_.x = 1;
    work_group_size_.y = 1;
    work_group_size_.z = desired_work_group_size;
  } else {
    if (tensor_slices >= 16) {
      work_group_size_.z = 8;
    } else if (tensor_slices >= 10) {
      work_group_size_.z = 4;
    } else {
      std::map<int, int> slices_to_group_size = {
          {1, 1}, {2, 2}, {3, 3}, {4, 4}, {5, 3},
          {6, 3}, {7, 4}, {8, 4}, {9, 3},
      };
      work_group_size_.z = slices_to_group_size[tensor_slices];
    }
    desired_work_group_size =
        std::min(desired_work_group_size, gpu_info.GetMaxWorkGroupTotalSize());
    work_group_size_.x = 1;
    work_group_size_.y =
        desired_work_group_size / AlignByN(work_group_size_.z, 4);
    while (work_group_size_.y > work_group_size_.x) {
      work_group_size_.y /= 2;
      work_group_size_.x *= 2;
    }
  }
  args_.AddFloat("variance_bias", variance_bias);
  args_.AddFloat("inv_ch_count", 1.0f / shape.c);
  code_ = GetNormalizationCode(gpu_info, shape.c % 4 == 0, two_step);
}

std::string MeanStdDevNormalization::GetNormalizationCode(
    const GpuInfo& gpu_info, bool channels_x4, bool two_step) {
  AddSrcTensor("src_tensor", definition_.src_tensors[0]);
  AddDstTensor("dst_tensor", definition_.dst_tensors[0]);

  std::string c;
  if (gpu_info.IsApiOpenCl()) {
    c += "__attribute__((reqd_work_group_size(" +
         std::to_string(work_group_size_.x) + ", " +
         std::to_string(work_group_size_.y) + ", " +
         std::to_string(work_group_size_.z) + ")))\n";
  }
  c += "MAIN_FUNCTION($0) {\n";
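  // Each thread keeps a private partial result that is then reduced through
  // local memory: a single float per thread for the two-step variant (sum
  // only), or a float2 carrying (sum, sum of squares) for the fused pass.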
  std::string accum_type = two_step ? "float" : "float2";
  if (work_group_size_.x * work_group_size_.y == 1) {
    c += "__local " + accum_type + " shared_mem[" +
         std::to_string(work_group_size_.z) + "];\n";
  } else {
    // Indexed as shared_mem[LOCAL_ID_1][LOCAL_ID_0][local_id] in
    // GetReduceCode, so the extents are ordered y, x, z.
    c += "__local " + accum_type + " shared_mem[" +
         std::to_string(work_group_size_.y) + "][" +
         std::to_string(work_group_size_.x) + "][" +
         std::to_string(work_group_size_.z) + "];\n";
  }
  if (definition_.dst_tensors[0].HasAxis(Axis::BATCH)) {
    c += "  int linear_id = GLOBAL_ID_0;\n";
    c += "  int X = linear_id / args.dst_tensor.Batch();\n";
    c += "  int B = linear_id % args.dst_tensor.Batch();\n";
    c += "  args.src_tensor.SetBatchRef(B);\n";
    c += "  args.dst_tensor.SetBatchRef(B);\n";
  } else {
    c += "  int X = GLOBAL_ID_0;\n";
  }
  c += "  int Y = GLOBAL_ID_1;\n";
  if (!two_step) {
    c += "  float4 private_sum4_sq = INIT_FLOAT4(0.0f);\n";
  }
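  // Threads mapped outside the tensor still have to reach every barrier in
  // the reductions below, so they read from a clamped (x, y) position instead
  // of returning early; they exit only after the shared memory phase.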
  c += R"(
  float4 private_sum4 = INIT_FLOAT4(0.0f);
  int local_id = LOCAL_ID_2;
  int reduction_group_size = GROUP_SIZE_2;
  for (int S = local_id; S < args.src_tensor.Slices(); S += reduction_group_size) {
    int x_clamped = min(X, args.src_tensor.Width() - 1);
    int y_clamped = min(Y, args.src_tensor.Height() - 1);
    float4 t = args.src_tensor.Read<float>(x_clamped, y_clamped, S);)";
  if (!channels_x4) {
    c += ZeroClampVec4Code("S", "args.src_tensor.Channels()", "t");
  }
  if (two_step) {
    c += "    private_sum4 += t;\n";
    c += "  }\n";
    c += "  float private_sum = dot(private_sum4, INIT_FLOAT4(1.0f));\n";
    c += "  float sum;\n";
  } else {
    c += "    private_sum4 += t;\n";
    c += "    private_sum4_sq += t * t;\n";
    c += "  }\n";
    c += "  float2 private_sum;\n";
    c += "  private_sum.x = dot(private_sum4, INIT_FLOAT4(1.0f));\n";
    c += "  private_sum.y = dot(private_sum4_sq, INIT_FLOAT4(1.0f));\n";
    c += "  float2 sum;\n";
  }
  c += GetReduceCode("private_sum", "sum", work_group_size_, two_step);
  if (two_step) {
    c += R"(
  // Calculate the mean
  float mean = sum * args.inv_ch_count;
  // Calculate the squared sum of the difference from the mean.
  float4 private_sum_diff_sq4 = INIT_FLOAT4(0.0f);
  for (int S = local_id; S < args.src_tensor.Slices(); S += reduction_group_size) {
    int x_clamped = min(X, args.src_tensor.Width() - 1);
    int y_clamped = min(Y, args.src_tensor.Height() - 1);
    float4 t = args.src_tensor.Read<float>(x_clamped, y_clamped, S);
    float4 diff = t - mean;)";
    if (!channels_x4) {
      c += ZeroClampVec4Code("S", "args.src_tensor.Channels()", "diff");
    }
    c += R"(
    private_sum_diff_sq4 += diff * diff;
  }
  // Reduce
  float private_sum_diff_sq = dot(private_sum_diff_sq4, INIT_FLOAT4(1.0f));
  float sum_diff_sq;
)";
    c += GetReduceCode("private_sum_diff_sq", "sum_diff_sq", work_group_size_,
                       two_step);
    c += "  float variance = sum_diff_sq * args.inv_ch_count;\n";
  } else {
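    // Single pass: recover the variance with the identity
    //   Var(x) = E[x^2] - (E[x])^2,
    // trading the second reduction for somewhat lower numerical robustness.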
    c += "  float mean = sum.x * args.inv_ch_count;\n";
    c += "  float mean_sq = sum.y * args.inv_ch_count;\n";
    c += "  float variance = mean_sq - mean * mean;\n";
  }
  c += R"(
  // no more shared memory usage, 'useless' threads can exit now
  if (X >= args.dst_tensor.Width()) { return; }
  if (Y >= args.dst_tensor.Height()) { return; }
  // Calculate 1/stddev (with the 'regularizing constant' as in tensor_utils.cc)
  float stddev_inv = rsqrt(variance + args.variance_bias);
  // Calculate (t-mean)/stddev for each element
  for (int S = local_id; S < args.src_tensor.Slices(); S += reduction_group_size) {
    float4 t = args.src_tensor.Read<float>(X, Y, S);
    FLT4 result = TO_FLT4((t - mean) * stddev_inv);
    args.dst_tensor.Write(result, X, Y, S);
  }
})";
  return c;
}

int3 MeanStdDevNormalization::GetGridSize() const {
  // To avoid dealing with global reductions, we restrict the grid size to the
  // work group size in the Z dimension.
  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
  const int grid_y = dst_[0]->Height();
  const int grid_z = work_group_size_.z;
  return int3(grid_x, grid_y, grid_z);
}

MeanStdDevNormalization CreateMeanStdDevNormalization(
    const OperationDef& definition, const GpuInfo& gpu_info, const BHWC& shape,
    float variance_bias, bool two_step) {
  return MeanStdDevNormalization(definition, gpu_info, shape, variance_bias,
                                 two_step);
}

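// Matches the decomposed layer-normalization subgraph
//   MEAN(channels) -> SUB -> SQUARE -> MEAN(channels) -> ADD(eps) -> RSQRT
// with the SUB and RSQRT outputs joined by a final MUL, i.e.
//   (x - mean(x)) * rsqrt(mean((x - mean(x))^2) + eps),
// and replaces it with a single MeanStdDevNormalization operation.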
absl::Status TryMeanStdDevNormalization(
    const GpuInfo& gpu_info, CalculationsPrecision precision,
    const GraphFloat32& graph, NodeId first_node_id,
    const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
    std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) {
  Node* first_mean_node = graph.GetNode(first_node_id);
  RETURN_IF_ERROR(CheckIfValidNodeOfType(first_mean_node, OperationType::MEAN));
  auto first_mean_attr =
      absl::any_cast<MeanAttributes>(first_mean_node->operation.attributes);
  if (first_mean_attr.dims != std::set<Axis>{Axis::CHANNELS}) {
    return absl::NotFoundError("MeanStdDevNormalization not suitable.");
  }
  Node* sub_node;
  RETURN_IF_ERROR(GetNextSingleNode(graph, *first_mean_node, OperationType::SUB,
                                    &sub_node));
  auto sub_inputs = graph.FindInputs(sub_node->id);
  if (sub_inputs.size() != 2) {
    return absl::NotFoundError("MeanStdDevNormalization not suitable.");
  } else {
    // checking structure
    //      input
    //      /    \
    //     |    mean
    //      \    /
    //    subtraction
    Node* sub_first_parent = graph.FindProducer(sub_inputs[0]->id);
    Node* sub_second_parent = graph.FindProducer(sub_inputs[1]->id);
    if (sub_second_parent != first_mean_node) {
      return absl::NotFoundError("MeanStdDevNormalization not suitable.");
    }
    auto mean_inputs = graph.FindInputs(first_mean_node->id);
    Node* mean_parent = graph.FindProducer(mean_inputs[0]->id);
    if (mean_parent != sub_first_parent) {
      return absl::NotFoundError("MeanStdDevNormalization not suitable.");
    }
  }
  auto sub_output = graph.FindOutputs(sub_node->id)[0]->id;
  auto consumers = graph.FindConsumers(sub_output);
  if (consumers.size() != 2) {
    return absl::NotFoundError("MeanStdDevNormalization not suitable.");
  }
  // The SUB output feeds both the SQUARE branch and the final MUL; the two
  // consumers can appear in either order, so try both assignments.
  Node* square_node = consumers[0];
  Node* sub_child_mul_node = consumers[1];
  if (!CheckIfValidNodeOfType(square_node, OperationType::SQUARE).ok()) {
    square_node = consumers[1];
    sub_child_mul_node = consumers[0];
  }
  RETURN_IF_ERROR(CheckIfValidNodeOfType(square_node, OperationType::SQUARE));
  RETURN_IF_ERROR(
      CheckIfValidNodeOfType(sub_child_mul_node, OperationType::MUL));
  Node* second_mean_node;
  RETURN_IF_ERROR(GetNextSingleNode(graph, *square_node, OperationType::MEAN,
                                    &second_mean_node));
  auto second_mean_attr =
      absl::any_cast<MeanAttributes>(second_mean_node->operation.attributes);
  if (second_mean_attr.dims != std::set<Axis>{Axis::CHANNELS}) {
    return absl::NotFoundError("MeanStdDevNormalization not suitable.");
  }
  Node* add_node;
  RETURN_IF_ERROR(GetNextSingleNode(graph, *second_mean_node,
                                    OperationType::ADD, &add_node));
  float add_value;
  RETURN_IF_ERROR(GetElementwiseScalarValue(add_node, &add_value));
  Node* rsqrt_node;
  RETURN_IF_ERROR(
      GetNextSingleNode(graph, *add_node, OperationType::RSQRT, &rsqrt_node));
  Node* mul_node;
  RETURN_IF_ERROR(
      GetNextSingleNode(graph, *rsqrt_node, OperationType::MUL, &mul_node));
  if (sub_child_mul_node != mul_node) {
    return absl::NotFoundError("MeanStdDevNormalization not suitable.");
  }

  OperationDef op_def;
  op_def.precision = precision;
  auto input_id = graph.FindInputs(first_mean_node->id)[0]->id;
  auto it = tensor_descriptors.find(input_id);
  if (it != tensor_descriptors.end()) {
    op_def.src_tensors.push_back(it->second);
  }
  auto output_id = graph.FindOutputs(mul_node->id)[0]->id;
  it = tensor_descriptors.find(output_id);
  if (it != tensor_descriptors.end()) {
    op_def.dst_tensors.push_back(it->second);
  }

  auto subgraph_inputs = graph.FindInputs(first_mean_node->id);
  auto subgraph_outputs = graph.FindOutputs(mul_node->id);
  std::unique_ptr<GPUOperation>* gpu_op =
      InitSingleOpSubgraph(subgraph_inputs, subgraph_outputs, gpu_subgraph);
  *gpu_op =
      std::make_unique<MeanStdDevNormalization>(CreateMeanStdDevNormalization(
          op_def, gpu_info, subgraph_inputs[0]->tensor.shape, add_value,
          /*two_step*/ false));

  consumed_nodes->insert(first_mean_node->id);
  consumed_nodes->insert(sub_node->id);
  consumed_nodes->insert(square_node->id);
  consumed_nodes->insert(second_mean_node->id);
  consumed_nodes->insert(add_node->id);
  consumed_nodes->insert(rsqrt_node->id);
  consumed_nodes->insert(mul_node->id);

  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite