1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/softmax.h"
17
18 #include <any>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/memory/memory.h"
25 #include "tensorflow/lite/delegates/gpu/common/shape.h"
26 #include "tensorflow/lite/delegates/gpu/common/status.h"
27 #include "tensorflow/lite/delegates/gpu/common/types.h"
28 #include "tensorflow/lite/delegates/gpu/common/util.h"
29 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
30
31 namespace tflite {
32 namespace gpu {
33 namespace gl {
34 namespace {
35
GetMask(int num_channels)36 float4 GetMask(int num_channels) {
37 float4 mask(0.0f);
38 const int remainder = num_channels % 4 == 0 ? 4 : num_channels % 4;
39 for (int i = 0; i < remainder; ++i) mask[i] = 1.0f;
40 return mask;
41 }
42
43 class Softmax : public NodeShader {
44 public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const45 absl::Status GenerateCode(const GenerationContext& ctx,
46 GeneratedCode* generated_code) const final {
47 const auto& attr = std::any_cast<const SoftmaxAttributes&>(ctx.op_attr);
48 if (ctx.input_shapes[0] != ctx.output_shapes[0]) {
49 return absl::InvalidArgumentError(
50 "Input and output shapes do not match.");
51 }
52 if (attr.axis != Axis::CHANNELS) {
53 return absl::UnimplementedError(
54 "Softmax is only supported for channels axis.");
55 }
56 return ctx.input_shapes[0][1] == 1 && ctx.input_shapes[0][2] == 1
57 ? GenerateCodeFor1x1(ctx, generated_code)
58 : GenerateCodeGeneral(ctx, generated_code);
59 }
60
61 private:
GenerateCodeFor1x1(const GenerationContext & ctx,GeneratedCode * generated_code) const62 absl::Status GenerateCodeFor1x1(const GenerationContext& ctx,
63 GeneratedCode* generated_code) const {
64 const int depth = DivideRoundUp(ctx.output_shapes[0][3], 4);
65 std::vector<Variable> shared_variables = {
66 {"partial_sum", std::vector<float4>(8)},
67 };
68 std::vector<Variable> uniform_parameters = {
69 {"depth", depth},
70 {"mask", GetMask(ctx.output_shapes[0][3])},
71 };
72 std::string source_code = R"(
73 highp vec4 kOnes = vec4(1.0);
74 int tid = int(gl_LocalInvocationID.x);
75 highp vec4 maxx4 = $input_data_0[0, 0, 0]$;
76 maxx4.y = maxx4.x;
77 maxx4.z = maxx4.x;
78 maxx4.w = maxx4.x;
79 for (int s = tid; s < $depth$; s += 32) {
80 highp vec4 mask_a = s == $depth$ - 1 ? $mask$ : kOnes;
81 highp vec4 mask_b = kOnes - mask_a;
82 highp vec4 src = $input_data_0[0, 0, s]$;
83 src = src * mask_a + mask_b * src.x;
84 maxx4 = max(maxx4, src);
85 }
86 highp float maximum = max(maxx4.x, maxx4.y);
87 maximum = max(maximum, maxx4.z);
88 maximum = max(maximum, maxx4.w);
89 partial_sum[tid / 4][tid % 4] = maximum;
90
91 memoryBarrierShared();
92 barrier();
93
94 if (tid == 0) {
95 maxx4 = max(partial_sum[0], partial_sum[1]);
96 maxx4 = max(maxx4, partial_sum[2]);
97 maxx4 = max(maxx4, partial_sum[3]);
98 maxx4 = max(maxx4, partial_sum[4]);
99 maxx4 = max(maxx4, partial_sum[5]);
100 maxx4 = max(maxx4, partial_sum[6]);
101 maxx4 = max(maxx4, partial_sum[7]);
102 maximum = max(maxx4.x, maxx4.y);
103 maximum = max(maximum, maxx4.z);
104 maximum = max(maximum, maxx4.w);
105 partial_sum[0][0] = maximum;
106 }
107
108 memoryBarrierShared();
109 barrier();
110
111 maximum = partial_sum[0][0];
112
113 highp float sum = 0.0;
114 for (int s = tid; s < $depth$; s += 32) {
115 highp vec4 mask_temp = s == $depth$ - 1 ? $mask$ : kOnes;
116 highp vec4 src = $input_data_0[0, 0, s]$ - vec4(maximum);
117 sum += dot(mask_temp, exp(src));
118 }
119
120 memoryBarrierShared();
121 barrier();
122
123 partial_sum[tid / 4][tid % 4] = sum;
124
125 memoryBarrierShared();
126 barrier();
127
128 if (tid == 0) {
129 sum = dot(kOnes, partial_sum[0]);
130 sum += dot(kOnes, partial_sum[1]);
131 sum += dot(kOnes, partial_sum[2]);
132 sum += dot(kOnes, partial_sum[3]);
133 sum += dot(kOnes, partial_sum[4]);
134 sum += dot(kOnes, partial_sum[5]);
135 sum += dot(kOnes, partial_sum[6]);
136 sum += dot(kOnes, partial_sum[7]);
137 partial_sum[0][0] = 1.0 / sum;
138 }
139
140 memoryBarrierShared();
141 barrier();
142
143 sum = partial_sum[0][0];
144
145 int dst_s = int(gl_GlobalInvocationID.x);
146 if (dst_s < $depth$) {
147 highp vec4 src = $input_data_0[0, 0, dst_s]$ - vec4(maximum);
148 highp vec4 temp = exp(src) * sum;
149 $output_data_0[0, 0, dst_s] = temp$;
150 }
151 )";
152
153 *generated_code = {
154 /*parameters=*/std::move(uniform_parameters),
155 /*objects=*/{},
156 /*shared_variables=*/std::move(shared_variables),
157 /*workload=*/uint3(depth, 1, 1),
158 /*workgroup=*/uint3(32, 1, 1),
159 /*source_code=*/std::move(source_code),
160 /*input=*/IOStructure::ONLY_DEFINITIONS,
161 /*output=*/IOStructure::ONLY_DEFINITIONS,
162 };
163 return absl::OkStatus();
164 }
165
GenerateCodeGeneral(const GenerationContext & ctx,GeneratedCode * generated_code) const166 absl::Status GenerateCodeGeneral(const GenerationContext& ctx,
167 GeneratedCode* generated_code) const {
168 std::vector<Variable> parameters = {
169 {"src_depth",
170 DivideRoundUp(static_cast<int>(ctx.output_shapes[0][3]), 4)},
171 {"mask", GetMask(ctx.output_shapes[0][3])},
172 };
173
174 std::string source_code = R"(
175 highp vec4 kOnes = vec4(1.0);
176 highp float sum = 0.0;
177 highp float maximum = $input_data_0[gid.x, gid.y, 0]$.x;
178 for (int d = 0; d < $src_depth$; ++d) {
179 highp vec4 mask_a = d == $src_depth$ - 1 ? $mask$ : kOnes;
180 highp vec4 mask_b = kOnes - mask_a;
181 highp vec4 src = $input_data_0[gid.x, gid.y, d]$;
182 src = src * mask_a + mask_b * src.x;
183 maximum = max(maximum, src.x);
184 maximum = max(maximum, src.y);
185 maximum = max(maximum, src.z);
186 maximum = max(maximum, src.w);
187 }
188 for (int d = 0; d < $src_depth$; ++d) {
189 highp vec4 mask_temp = d == $src_depth$ - 1 ? $mask$ : kOnes;
190 highp vec4 src = $input_data_0[gid.x, gid.y, d]$ - vec4(maximum);
191 sum += dot(mask_temp, exp(src));
192 }
193 for (int d = 0; d < $src_depth$; ++d) {
194 highp vec4 src = $input_data_0[gid.x, gid.y, d]$ - vec4(maximum);
195 highp vec4 temp_sum = exp(src) / sum;
196 $output_data_0[gid.x, gid.y, d] = temp_sum$;
197 }
198 )";
199 *generated_code = {
200 /*parameters=*/std::move(parameters),
201 /*objects=*/{},
202 /*shared_variables=*/{},
203 /*workload=*/
204 uint3(static_cast<int>(ctx.output_shapes[0][2]),
205 static_cast<int>(ctx.output_shapes[0][1]), 1),
206 /*workgroup=*/uint3(),
207 /*source_code=*/std::move(source_code),
208 /*input=*/IOStructure::ONLY_DEFINITIONS,
209 /*output=*/IOStructure::ONLY_DEFINITIONS,
210 };
211 return absl::OkStatus();
212 }
213 };
214
215 } // namespace
216
NewSoftmaxNodeShader()217 std::unique_ptr<NodeShader> NewSoftmaxNodeShader() {
218 return std::make_unique<Softmax>();
219 }
220
221 } // namespace gl
222 } // namespace gpu
223 } // namespace tflite
224