xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/softmax.h"
17 
18 #include <any>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "tensorflow/lite/delegates/gpu/common/shape.h"
26 #include "tensorflow/lite/delegates/gpu/common/status.h"
27 #include "tensorflow/lite/delegates/gpu/common/types.h"
28 #include "tensorflow/lite/delegates/gpu/common/util.h"
29 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
30 
31 namespace tflite {
32 namespace gpu {
33 namespace gl {
34 namespace {
35 
GetMask(int num_channels)36 float4 GetMask(int num_channels) {
37   float4 mask(0.0f);
38   const int remainder = num_channels % 4 == 0 ? 4 : num_channels % 4;
39   for (int i = 0; i < remainder; ++i) mask[i] = 1.0f;
40   return mask;
41 }
42 
43 class Softmax : public NodeShader {
44  public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const45   absl::Status GenerateCode(const GenerationContext& ctx,
46                             GeneratedCode* generated_code) const final {
47     const auto& attr = std::any_cast<const SoftmaxAttributes&>(ctx.op_attr);
48     if (ctx.input_shapes[0] != ctx.output_shapes[0]) {
49       return absl::InvalidArgumentError(
50           "Input and output shapes do not match.");
51     }
52     if (attr.axis != Axis::CHANNELS) {
53       return absl::UnimplementedError(
54           "Softmax is only supported for channels axis.");
55     }
56     return ctx.input_shapes[0][1] == 1 && ctx.input_shapes[0][2] == 1
57                ? GenerateCodeFor1x1(ctx, generated_code)
58                : GenerateCodeGeneral(ctx, generated_code);
59   }
60 
61  private:
GenerateCodeFor1x1(const GenerationContext & ctx,GeneratedCode * generated_code) const62   absl::Status GenerateCodeFor1x1(const GenerationContext& ctx,
63                                   GeneratedCode* generated_code) const {
64     const int depth = DivideRoundUp(ctx.output_shapes[0][3], 4);
65     std::vector<Variable> shared_variables = {
66         {"partial_sum", std::vector<float4>(8)},
67     };
68     std::vector<Variable> uniform_parameters = {
69         {"depth", depth},
70         {"mask", GetMask(ctx.output_shapes[0][3])},
71     };
72     std::string source_code = R"(
73   highp vec4 kOnes = vec4(1.0);
74   int tid = int(gl_LocalInvocationID.x);
75   highp vec4 maxx4 = $input_data_0[0, 0, 0]$;
76   maxx4.y = maxx4.x;
77   maxx4.z = maxx4.x;
78   maxx4.w = maxx4.x;
79   for (int s = tid; s < $depth$; s += 32) {
80     highp vec4 mask_a = s == $depth$ - 1 ? $mask$ : kOnes;
81     highp vec4 mask_b = kOnes - mask_a;
82     highp vec4 src = $input_data_0[0, 0, s]$;
83     src = src * mask_a + mask_b * src.x;
84     maxx4 = max(maxx4, src);
85   }
86   highp float maximum = max(maxx4.x, maxx4.y);
87   maximum = max(maximum, maxx4.z);
88   maximum = max(maximum, maxx4.w);
89   partial_sum[tid / 4][tid % 4] = maximum;
90 
91   memoryBarrierShared();
92   barrier();
93 
94   if (tid == 0) {
95     maxx4 = max(partial_sum[0], partial_sum[1]);
96     maxx4 = max(maxx4, partial_sum[2]);
97     maxx4 = max(maxx4, partial_sum[3]);
98     maxx4 = max(maxx4, partial_sum[4]);
99     maxx4 = max(maxx4, partial_sum[5]);
100     maxx4 = max(maxx4, partial_sum[6]);
101     maxx4 = max(maxx4, partial_sum[7]);
102     maximum = max(maxx4.x, maxx4.y);
103     maximum = max(maximum, maxx4.z);
104     maximum = max(maximum, maxx4.w);
105     partial_sum[0][0] = maximum;
106   }
107 
108   memoryBarrierShared();
109   barrier();
110 
111   maximum = partial_sum[0][0];
112 
113   highp float sum = 0.0;
114   for (int s = tid; s < $depth$; s += 32) {
115     highp vec4 mask_temp = s == $depth$ - 1 ? $mask$ : kOnes;
116     highp vec4 src = $input_data_0[0, 0, s]$ - vec4(maximum);
117     sum += dot(mask_temp, exp(src));
118   }
119 
120   memoryBarrierShared();
121   barrier();
122 
123   partial_sum[tid / 4][tid % 4] = sum;
124 
125   memoryBarrierShared();
126   barrier();
127 
128   if (tid == 0) {
129     sum = dot(kOnes, partial_sum[0]);
130     sum += dot(kOnes, partial_sum[1]);
131     sum += dot(kOnes, partial_sum[2]);
132     sum += dot(kOnes, partial_sum[3]);
133     sum += dot(kOnes, partial_sum[4]);
134     sum += dot(kOnes, partial_sum[5]);
135     sum += dot(kOnes, partial_sum[6]);
136     sum += dot(kOnes, partial_sum[7]);
137     partial_sum[0][0] = 1.0 / sum;
138   }
139 
140   memoryBarrierShared();
141   barrier();
142 
143   sum = partial_sum[0][0];
144 
145   int dst_s = int(gl_GlobalInvocationID.x);
146   if (dst_s < $depth$) {
147     highp vec4 src = $input_data_0[0, 0, dst_s]$ - vec4(maximum);
148     highp vec4 temp = exp(src) * sum;
149     $output_data_0[0, 0, dst_s] = temp$;
150   }
151 )";
152 
153     *generated_code = {
154         /*parameters=*/std::move(uniform_parameters),
155         /*objects=*/{},
156         /*shared_variables=*/std::move(shared_variables),
157         /*workload=*/uint3(depth, 1, 1),
158         /*workgroup=*/uint3(32, 1, 1),
159         /*source_code=*/std::move(source_code),
160         /*input=*/IOStructure::ONLY_DEFINITIONS,
161         /*output=*/IOStructure::ONLY_DEFINITIONS,
162     };
163     return absl::OkStatus();
164   }
165 
GenerateCodeGeneral(const GenerationContext & ctx,GeneratedCode * generated_code) const166   absl::Status GenerateCodeGeneral(const GenerationContext& ctx,
167                                    GeneratedCode* generated_code) const {
168     std::vector<Variable> parameters = {
169         {"src_depth",
170          DivideRoundUp(static_cast<int>(ctx.output_shapes[0][3]), 4)},
171         {"mask", GetMask(ctx.output_shapes[0][3])},
172     };
173 
174     std::string source_code = R"(
175   highp vec4 kOnes = vec4(1.0);
176   highp float sum = 0.0;
177   highp float maximum = $input_data_0[gid.x, gid.y, 0]$.x;
178   for (int d = 0; d < $src_depth$; ++d) {
179     highp vec4 mask_a = d == $src_depth$ - 1 ? $mask$ : kOnes;
180     highp vec4 mask_b = kOnes - mask_a;
181     highp vec4 src = $input_data_0[gid.x, gid.y, d]$;
182     src = src * mask_a + mask_b * src.x;
183     maximum = max(maximum, src.x);
184     maximum = max(maximum, src.y);
185     maximum = max(maximum, src.z);
186     maximum = max(maximum, src.w);
187   }
188   for (int d = 0; d < $src_depth$; ++d) {
189     highp vec4 mask_temp = d == $src_depth$ - 1 ? $mask$ : kOnes;
190     highp vec4 src = $input_data_0[gid.x, gid.y, d]$ - vec4(maximum);
191     sum += dot(mask_temp, exp(src));
192   }
193   for (int d = 0; d < $src_depth$; ++d) {
194     highp vec4 src = $input_data_0[gid.x, gid.y, d]$ - vec4(maximum);
195     highp vec4 temp_sum = exp(src) / sum;
196     $output_data_0[gid.x, gid.y, d] = temp_sum$;
197   }
198 )";
199     *generated_code = {
200         /*parameters=*/std::move(parameters),
201         /*objects=*/{},
202         /*shared_variables=*/{},
203         /*workload=*/
204         uint3(static_cast<int>(ctx.output_shapes[0][2]),
205               static_cast<int>(ctx.output_shapes[0][1]), 1),
206         /*workgroup=*/uint3(),
207         /*source_code=*/std::move(source_code),
208         /*input=*/IOStructure::ONLY_DEFINITIONS,
209         /*output=*/IOStructure::ONLY_DEFINITIONS,
210     };
211     return absl::OkStatus();
212   }
213 };
214 
215 }  // namespace
216 
NewSoftmaxNodeShader()217 std::unique_ptr<NodeShader> NewSoftmaxNodeShader() {
218   return std::make_unique<Softmax>();
219 }
220 
221 }  // namespace gl
222 }  // namespace gpu
223 }  // namespace tflite
224