xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/gl/kernels/resize.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/resize.h"
17 
18 #include <algorithm>
19 #include <any>
20 #include <cstdint>
21 #include <cstring>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/memory/memory.h"
28 #include "tensorflow/lite/delegates/gpu/common/operations.h"
29 #include "tensorflow/lite/delegates/gpu/common/status.h"
30 #include "tensorflow/lite/delegates/gpu/common/types.h"
31 
32 namespace tflite {
33 namespace gpu {
34 namespace gl {
35 namespace {
36 
37 class Resize : public NodeShader {
38  public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const39   absl::Status GenerateCode(const GenerationContext& ctx,
40                             GeneratedCode* generated_code) const final {
41     const auto& attr = std::any_cast<const Resize2DAttributes&>(ctx.op_attr);
42 
43     if (ctx.input_shapes[0][2] > ctx.output_shapes[0][2] ||
44         ctx.input_shapes[0][1] > ctx.output_shapes[0][1]) {
45       return absl::InvalidArgumentError("Output size is less than input size.");
46     }
47     if (ctx.output_shapes[0][2] != attr.new_shape.w ||
48         ctx.output_shapes[0][1] != attr.new_shape.h) {
49       return absl::InvalidArgumentError(
50           "Output size does not match new_size in attributes.");
51     }
52     if (ctx.input_shapes[0][3] != ctx.output_shapes[0][3]) {
53       return absl::InvalidArgumentError("Input/output channels mismatch.");
54     }
55     if (ctx.input_shapes[0][1] == 1 && ctx.input_shapes[0][2] == 1) {
56       // Copy a single element from input.
57       *generated_code = {
58           /*parameters=*/{},
59           /*objects=*/{},
60           /*shared_variables=*/{},
61           /*workload=*/uint3(),
62           /*workgroup=*/uint3(),
63           /*source_code=*/"value_0 = $input_data_0[0, 0, gid.z]$;",
64           /*input=*/IOStructure::ONLY_DEFINITIONS,
65           /*output=*/IOStructure::AUTO,
66       };
67       return absl::OkStatus();
68     }
69     std::vector<Variable> parameters = {
70         {"input_data_0_h", static_cast<int>(ctx.input_shapes[0][1])},
71         {"input_data_0_w", static_cast<int>(ctx.input_shapes[0][2])},
72         {"scale_factor",
73          float2(CalculateResizeScale(ctx.input_shapes[0][2],
74                                      ctx.output_shapes[0][2], attr),
75                 CalculateResizeScale(ctx.input_shapes[0][1],
76                                      ctx.output_shapes[0][1], attr))},
77     };
78 
79     std::string source;
80     if (attr.type == SamplingType::BILINEAR) {
81       if (attr.half_pixel_centers) {
82         source = "vec2 coord = (vec2(gid.xy) + 0.5) * $scale_factor$ - 0.5;";
83       } else {
84         source = "vec2 coord = vec2(gid.xy) * $scale_factor$;";
85       }
86       source += R"(
87       vec2 coord_floor = floor(coord);
88       ivec2 icoord_floor = ivec2(coord_floor);
89       ivec2 borders = ivec2($input_data_0_w$, $input_data_0_h$) - ivec2(1, 1);
90       ivec4 st;
91       st.xy = max(icoord_floor, ivec2(0, 0));
92       st.zw = min(icoord_floor + ivec2(1, 1), borders);
93 
94       vec2 t = coord - coord_floor; // interpolating factors
95 
96       vec4 tex11 = $input_data_0[st.x, st.y, gid.z]$;
97       vec4 tex21 = $input_data_0[st.z, st.y, gid.z]$;
98       vec4 tex12 = $input_data_0[st.x, st.w, gid.z]$;
99       vec4 tex22 = $input_data_0[st.z, st.w, gid.z]$;
100 
101       value_0 = mix(mix(tex11, tex21, t.x), mix(tex12, tex22, t.x), t.y);)";
102     } else if (attr.type == SamplingType::NEAREST) {
103       std::string fxc;
104       std::string fyc;
105       if (attr.half_pixel_centers) {
106         fxc = "(float(gid.x) + 0.5) * $scale_factor.x$";
107         fyc = "(float(gid.y) + 0.5) * $scale_factor.y$";
108       } else {
109         fxc = "float(gid.x) * $scale_factor.x$";
110         fyc = "float(gid.y) * $scale_factor.y$";
111       }
112       if (attr.align_corners) {
113         fxc += " + 0.5";
114         fyc += " + 0.5";
115       }
116       source += "  ivec2 coord;\n";
117       source += "  coord.x = int(" + fxc + ");\n";
118       source += "  coord.y = int(" + fyc + ");\n";
119       source += "  coord.x = max(0, coord.x);\n";
120       source += "  coord.y = max(0, coord.y);\n";
121       source += "  coord.x = min(coord.x, $input_data_0_w$ - 1);\n";
122       source += "  coord.y = min(coord.y, $input_data_0_h$ - 1);\n";
123       source += R"(
124       value_0 = $input_data_0[coord.x, coord.y, gid.z]$;
125       )";
126     } else {
127       return absl::InvalidArgumentError("Unknown sampling type");
128     }
129     *generated_code = {
130         /*parameters=*/std::move(parameters),
131         /*objects=*/{},
132         /*shared_variables=*/{},
133         /*workload=*/uint3(),
134         /*workgroup=*/uint3(),
135         /*source_code=*/std::move(source),
136         /*input=*/IOStructure::ONLY_DEFINITIONS,
137         /*output=*/IOStructure::AUTO,
138     };
139     return absl::OkStatus();
140   }
141 };
142 
143 }  // namespace
144 
NewResizeNodeShader()145 std::unique_ptr<NodeShader> NewResizeNodeShader() {
146   return std::make_unique<Resize>();
147 }
148 
149 }  // namespace gl
150 }  // namespace gpu
151 }  // namespace tflite
152