1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/prelu.h"
17
18 #include <algorithm>
19 #include <any>
20 #include <cstdint>
21 #include <cstring>
22 #include <memory>
23 #include <string>
24 #include <variant>
25 #include <vector>
26
27 #include "absl/memory/memory.h"
28 #include "tensorflow/lite/delegates/gpu/common/convert.h"
29 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
30 #include "tensorflow/lite/delegates/gpu/common/shape.h"
31 #include "tensorflow/lite/delegates/gpu/common/status.h"
32 #include "tensorflow/lite/delegates/gpu/common/types.h"
33
34 namespace tflite {
35 namespace gpu {
36 namespace gl {
37 namespace {
38
39 class PReLULinearAlpha : public NodeShader {
40 public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const41 absl::Status GenerateCode(const GenerationContext& ctx,
42 GeneratedCode* generated_code) const final {
43 const auto& attr = std::any_cast<const PReLUAttributes&>(ctx.op_attr);
44 auto alpha = std::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.alpha);
45 if (!alpha) {
46 return absl::InvalidArgumentError("Alpha is missing");
47 }
48 if (alpha->shape.v != ctx.output_shapes[0][3]) {
49 return absl::InvalidArgumentError(
50 "Alpha shape does not match the number of channels.");
51 }
52
53 *generated_code = GeneratedCode{
54 /*parameters=*/{},
55 /*objects=*/{{"alpha", MakeReadonlyObject(alpha->data)}},
56 /*shared_variables=*/{},
57 // Declare workload explicitly because shader depends on
58 // gid.z.
59 /*workload=*/
60 uint3(static_cast<int>(ctx.output_shapes[0][2]),
61 static_cast<int>(ctx.output_shapes[0][1]),
62 DivideRoundUp(static_cast<int>(ctx.output_shapes[0][3]), 4)),
63 /*workgroup=*/uint3(),
64 /*source_code=*/
65 "value_0 = max(value_0, 0.0) + $alpha[gid.z]$ * min(value_0, "
66 "0.0);",
67 /*input=*/IOStructure::AUTO,
68 /*output=*/IOStructure::AUTO,
69 };
70 return absl::OkStatus();
71 }
72 };
73
74 class PReLUFull : public NodeShader {
75 public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const76 absl::Status GenerateCode(const GenerationContext& ctx,
77 GeneratedCode* generated_code) const final {
78 const auto& attr = std::any_cast<const PReLUAttributes&>(ctx.op_attr);
79 auto alpha = std::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr.alpha);
80 if (!alpha) {
81 return absl::InvalidArgumentError("Alpha is missing");
82 }
83 if (alpha->shape.h != ctx.output_shapes[0][1] ||
84 alpha->shape.w != ctx.output_shapes[0][2] ||
85 alpha->shape.c != ctx.output_shapes[0][3]) {
86 return absl::InvalidArgumentError(
87 "Alpha shape does not match input shape.");
88 }
89
90 ObjectSize obj_size =
91 uint3(static_cast<int>(ctx.output_shapes[0][2]),
92 static_cast<int>(ctx.output_shapes[0][1]),
93 DivideRoundUp(static_cast<int>(ctx.output_shapes[0][3]), 4));
94
95 *generated_code = GeneratedCode{
96 /*parameters=*/{},
97 /*objects=*/
98 {{"alpha", MakeReadonlyObject(obj_size, ConvertToPHWC4(*alpha))}},
99 /*shared_variables=*/{},
100 // Declare workload explicitly because shader depends on
101 // gid.z.
102 /*workload=*/
103 uint3(static_cast<int>(ctx.output_shapes[0][2]),
104 static_cast<int>(ctx.output_shapes[0][1]),
105 DivideRoundUp(static_cast<int>(ctx.output_shapes[0][3]), 4)),
106 /*workgroup=*/uint3(),
107 /*source_code=*/
108 "value_0 = max(value_0, 0.0) + $alpha[gid.x, gid.y, gid.z]$ "
109 "* min(value_0, 0.0);",
110 /*input=*/IOStructure::AUTO,
111 /*output=*/IOStructure::AUTO,
112 };
113 return absl::OkStatus();
114 }
115 };
116
117 class PReLU : public NodeShader {
118 public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const119 absl::Status GenerateCode(const GenerationContext& ctx,
120 GeneratedCode* generated_code) const final {
121 const auto& attr = std::any_cast<const PReLUAttributes&>(ctx.op_attr);
122 auto* alpha = std::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr.alpha);
123 return alpha ? full_.GenerateCode(ctx, generated_code)
124 : linear_.GenerateCode(ctx, generated_code);
125 }
126
127 private:
128 PReLULinearAlpha linear_;
129 PReLUFull full_;
130 };
131
132 } // namespace
133
NewPReLUNodeShader()134 std::unique_ptr<NodeShader> NewPReLUNodeShader() {
135 return std::make_unique<PReLU>();
136 }
137
138 } // namespace gl
139 } // namespace gpu
140 } // namespace tflite
141