1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/resize.h"
17
18 #include <algorithm>
19 #include <any>
20 #include <cstdint>
21 #include <cstring>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 #include "absl/memory/memory.h"
28 #include "tensorflow/lite/delegates/gpu/common/operations.h"
29 #include "tensorflow/lite/delegates/gpu/common/status.h"
30 #include "tensorflow/lite/delegates/gpu/common/types.h"
31
32 namespace tflite {
33 namespace gpu {
34 namespace gl {
35 namespace {
36
37 class Resize : public NodeShader {
38 public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const39 absl::Status GenerateCode(const GenerationContext& ctx,
40 GeneratedCode* generated_code) const final {
41 const auto& attr = std::any_cast<const Resize2DAttributes&>(ctx.op_attr);
42
43 if (ctx.input_shapes[0][2] > ctx.output_shapes[0][2] ||
44 ctx.input_shapes[0][1] > ctx.output_shapes[0][1]) {
45 return absl::InvalidArgumentError("Output size is less than input size.");
46 }
47 if (ctx.output_shapes[0][2] != attr.new_shape.w ||
48 ctx.output_shapes[0][1] != attr.new_shape.h) {
49 return absl::InvalidArgumentError(
50 "Output size does not match new_size in attributes.");
51 }
52 if (ctx.input_shapes[0][3] != ctx.output_shapes[0][3]) {
53 return absl::InvalidArgumentError("Input/output channels mismatch.");
54 }
55 if (ctx.input_shapes[0][1] == 1 && ctx.input_shapes[0][2] == 1) {
56 // Copy a single element from input.
57 *generated_code = {
58 /*parameters=*/{},
59 /*objects=*/{},
60 /*shared_variables=*/{},
61 /*workload=*/uint3(),
62 /*workgroup=*/uint3(),
63 /*source_code=*/"value_0 = $input_data_0[0, 0, gid.z]$;",
64 /*input=*/IOStructure::ONLY_DEFINITIONS,
65 /*output=*/IOStructure::AUTO,
66 };
67 return absl::OkStatus();
68 }
69 std::vector<Variable> parameters = {
70 {"input_data_0_h", static_cast<int>(ctx.input_shapes[0][1])},
71 {"input_data_0_w", static_cast<int>(ctx.input_shapes[0][2])},
72 {"scale_factor",
73 float2(CalculateResizeScale(ctx.input_shapes[0][2],
74 ctx.output_shapes[0][2], attr),
75 CalculateResizeScale(ctx.input_shapes[0][1],
76 ctx.output_shapes[0][1], attr))},
77 };
78
79 std::string source;
80 if (attr.type == SamplingType::BILINEAR) {
81 if (attr.half_pixel_centers) {
82 source = "vec2 coord = (vec2(gid.xy) + 0.5) * $scale_factor$ - 0.5;";
83 } else {
84 source = "vec2 coord = vec2(gid.xy) * $scale_factor$;";
85 }
86 source += R"(
87 vec2 coord_floor = floor(coord);
88 ivec2 icoord_floor = ivec2(coord_floor);
89 ivec2 borders = ivec2($input_data_0_w$, $input_data_0_h$) - ivec2(1, 1);
90 ivec4 st;
91 st.xy = max(icoord_floor, ivec2(0, 0));
92 st.zw = min(icoord_floor + ivec2(1, 1), borders);
93
94 vec2 t = coord - coord_floor; // interpolating factors
95
96 vec4 tex11 = $input_data_0[st.x, st.y, gid.z]$;
97 vec4 tex21 = $input_data_0[st.z, st.y, gid.z]$;
98 vec4 tex12 = $input_data_0[st.x, st.w, gid.z]$;
99 vec4 tex22 = $input_data_0[st.z, st.w, gid.z]$;
100
101 value_0 = mix(mix(tex11, tex21, t.x), mix(tex12, tex22, t.x), t.y);)";
102 } else if (attr.type == SamplingType::NEAREST) {
103 std::string fxc;
104 std::string fyc;
105 if (attr.half_pixel_centers) {
106 fxc = "(float(gid.x) + 0.5) * $scale_factor.x$";
107 fyc = "(float(gid.y) + 0.5) * $scale_factor.y$";
108 } else {
109 fxc = "float(gid.x) * $scale_factor.x$";
110 fyc = "float(gid.y) * $scale_factor.y$";
111 }
112 if (attr.align_corners) {
113 fxc += " + 0.5";
114 fyc += " + 0.5";
115 }
116 source += " ivec2 coord;\n";
117 source += " coord.x = int(" + fxc + ");\n";
118 source += " coord.y = int(" + fyc + ");\n";
119 source += " coord.x = max(0, coord.x);\n";
120 source += " coord.y = max(0, coord.y);\n";
121 source += " coord.x = min(coord.x, $input_data_0_w$ - 1);\n";
122 source += " coord.y = min(coord.y, $input_data_0_h$ - 1);\n";
123 source += R"(
124 value_0 = $input_data_0[coord.x, coord.y, gid.z]$;
125 )";
126 } else {
127 return absl::InvalidArgumentError("Unknown sampling type");
128 }
129 *generated_code = {
130 /*parameters=*/std::move(parameters),
131 /*objects=*/{},
132 /*shared_variables=*/{},
133 /*workload=*/uint3(),
134 /*workgroup=*/uint3(),
135 /*source_code=*/std::move(source),
136 /*input=*/IOStructure::ONLY_DEFINITIONS,
137 /*output=*/IOStructure::AUTO,
138 };
139 return absl::OkStatus();
140 }
141 };
142
143 } // namespace
144
NewResizeNodeShader()145 std::unique_ptr<NodeShader> NewResizeNodeShader() {
146 return std::make_unique<Resize>();
147 }
148
149 } // namespace gl
150 } // namespace gpu
151 } // namespace tflite
152