1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
17
18 #include <string>
19 #include <vector>
20
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_format.h"
23 #include "tensorflow_lite_support/cc/port/status_macros.h"
24
25 namespace tflite {
26 namespace task {
27 namespace vision {
28 namespace {
29
30 using ::tflite::support::StatusOr;
31
// Bytes per pixel of an interleaved RGBA plane. Used below both as the pixel
// stride and, multiplied by the frame width, as the row stride.
constexpr int kRgbaChannels = 4;
// Bytes per pixel of an interleaved RGB plane. Used below both as the pixel
// stride and, multiplied by the frame width, as the row stride.
constexpr int kRgbChannels = 3;
// Pixel stride of single-channel planes: the grayscale plane and the first
// plane of the NV12/NV21 buffers created in this file.
constexpr int kGrayChannel = 1;
35
36 // Creates a FrameBuffer from raw NV12 buffer and passing arguments.
CreateFromNV12RawBuffer(const uint8 * input,FrameBuffer::Dimension dimension,FrameBuffer::Orientation orientation,const absl::Time timestamp)37 std::unique_ptr<FrameBuffer> CreateFromNV12RawBuffer(
38 const uint8* input, FrameBuffer::Dimension dimension,
39 FrameBuffer::Orientation orientation, const absl::Time timestamp) {
40 const std::vector<FrameBuffer::Plane> planes_nv12 = {
41 {input, /*stride=*/{dimension.width, kGrayChannel}},
42 {input + dimension.Size(), /*stride=*/{dimension.width, 2}}};
43 return FrameBuffer::Create(planes_nv12, dimension, FrameBuffer::Format::kNV12,
44 orientation, timestamp);
45 }
46
47 // Creates a FrameBuffer from raw NV21 buffer and passing arguments.
CreateFromNV21RawBuffer(const uint8 * input,FrameBuffer::Dimension dimension,FrameBuffer::Orientation orientation,const absl::Time timestamp)48 std::unique_ptr<FrameBuffer> CreateFromNV21RawBuffer(
49 const uint8* input, FrameBuffer::Dimension dimension,
50 FrameBuffer::Orientation orientation, const absl::Time timestamp) {
51 FrameBuffer::Plane input_plane = {/*buffer=*/input,
52 /*stride=*/{dimension.width, kGrayChannel}};
53 return FrameBuffer::Create({input_plane}, dimension,
54 FrameBuffer::Format::kNV21, orientation,
55 timestamp);
56 }
57
58 // Indicates whether the given buffers have the same dimensions.
AreBufferDimsEqual(const FrameBuffer & buffer1,const FrameBuffer & buffer2)59 bool AreBufferDimsEqual(const FrameBuffer& buffer1,
60 const FrameBuffer& buffer2) {
61 return buffer1.dimension() == buffer2.dimension();
62 }
63
64 // Indicates whether the given buffers formats are compatible. Same formats are
65 // compatible and all YUV family formats (e.g. NV21, NV12, YV12, YV21, etc) are
66 // compatible.
AreBufferFormatsCompatible(const FrameBuffer & buffer1,const FrameBuffer & buffer2)67 bool AreBufferFormatsCompatible(const FrameBuffer& buffer1,
68 const FrameBuffer& buffer2) {
69 switch (buffer1.format()) {
70 case FrameBuffer::Format::kRGBA:
71 case FrameBuffer::Format::kRGB:
72 return (buffer2.format() == FrameBuffer::Format::kRGBA ||
73 buffer2.format() == FrameBuffer::Format::kRGB);
74 case FrameBuffer::Format::kNV12:
75 case FrameBuffer::Format::kNV21:
76 case FrameBuffer::Format::kYV12:
77 case FrameBuffer::Format::kYV21:
78 return (buffer2.format() == FrameBuffer::Format::kNV12 ||
79 buffer2.format() == FrameBuffer::Format::kNV21 ||
80 buffer2.format() == FrameBuffer::Format::kYV12 ||
81 buffer2.format() == FrameBuffer::Format::kYV21);
82 case FrameBuffer::Format::kGRAY:
83 default:
84 return buffer1.format() == buffer2.format();
85 }
86 }
87
88 } // namespace
89
90 // Miscellaneous Methods
91 // -----------------------------------------------------------------
GetFrameBufferByteSize(FrameBuffer::Dimension dimension,FrameBuffer::Format format)92 int GetFrameBufferByteSize(FrameBuffer::Dimension dimension,
93 FrameBuffer::Format format) {
94 switch (format) {
95 case FrameBuffer::Format::kNV12:
96 case FrameBuffer::Format::kNV21:
97 case FrameBuffer::Format::kYV12:
98 case FrameBuffer::Format::kYV21:
99 return /*y plane*/ dimension.Size() +
100 /*uv plane*/ ((static_cast<float>(dimension.width + 1) / 2) *
101 (static_cast<float>(dimension.height + 1) / 2) * 2);
102 case FrameBuffer::Format::kRGB:
103 return dimension.Size() * 3;
104 case FrameBuffer::Format::kRGBA:
105 return dimension.Size() * 4;
106 case FrameBuffer::Format::kGRAY:
107 return dimension.Size();
108 default:
109 return 0;
110 }
111 }
112
GetPixelStrides(FrameBuffer::Format format)113 StatusOr<int> GetPixelStrides(FrameBuffer::Format format) {
114 switch (format) {
115 case FrameBuffer::Format::kGRAY:
116 return kGrayPixelBytes;
117 case FrameBuffer::Format::kRGB:
118 return kRgbPixelBytes;
119 case FrameBuffer::Format::kRGBA:
120 return kRgbaPixelBytes;
121 default:
122 return absl::InvalidArgumentError(absl::StrFormat(
123 "GetPixelStrides does not support format: %i.", format));
124 }
125 }
126
GetUvRawBuffer(const FrameBuffer & buffer)127 StatusOr<const uint8*> GetUvRawBuffer(const FrameBuffer& buffer) {
128 if (buffer.format() != FrameBuffer::Format::kNV12 &&
129 buffer.format() != FrameBuffer::Format::kNV21) {
130 return absl::InvalidArgumentError(
131 "Only support getting biplanar UV buffer from NV12/NV21 frame buffer.");
132 }
133 ASSIGN_OR_RETURN(FrameBuffer::YuvData yuv_data,
134 FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
135 const uint8* uv_buffer = buffer.format() == FrameBuffer::Format::kNV12
136 ? yuv_data.u_buffer
137 : yuv_data.v_buffer;
138 return uv_buffer;
139 }
140
GetUvPlaneDimension(FrameBuffer::Dimension dimension,FrameBuffer::Format format)141 StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension(
142 FrameBuffer::Dimension dimension, FrameBuffer::Format format) {
143 if (dimension.width <= 0 || dimension.height <= 0) {
144 return absl::InvalidArgumentError(
145 absl::StrFormat("Invalid input dimension: {%d, %d}.", dimension.width,
146 dimension.height));
147 }
148 switch (format) {
149 case FrameBuffer::Format::kNV12:
150 case FrameBuffer::Format::kNV21:
151 case FrameBuffer::Format::kYV12:
152 case FrameBuffer::Format::kYV21:
153 return FrameBuffer::Dimension{(dimension.width + 1) / 2,
154 (dimension.height + 1) / 2};
155 default:
156 return absl::InvalidArgumentError(
157 absl::StrFormat("Input format is not YUV-like: %i.", format));
158 }
159 }
160
GetCropDimension(int x0,int x1,int y0,int y1)161 FrameBuffer::Dimension GetCropDimension(int x0, int x1, int y0, int y1) {
162 return {x1 - x0 + 1, y1 - y0 + 1};
163 }
164
165 // Validation Methods
166 // -----------------------------------------------------------------
167
ValidateBufferPlaneMetadata(const FrameBuffer & buffer)168 absl::Status ValidateBufferPlaneMetadata(const FrameBuffer& buffer) {
169 if (buffer.plane_count() < 1) {
170 return absl::InvalidArgumentError(
171 "There must be at least 1 plane specified.");
172 }
173
174 for (int i = 0; i < buffer.plane_count(); i++) {
175 if (buffer.plane(i).stride.row_stride_bytes == 0 ||
176 buffer.plane(i).stride.pixel_stride_bytes == 0) {
177 return absl::InvalidArgumentError("Invalid stride information.");
178 }
179 }
180
181 return absl::OkStatus();
182 }
183
ValidateBufferFormat(const FrameBuffer & buffer)184 absl::Status ValidateBufferFormat(const FrameBuffer& buffer) {
185 switch (buffer.format()) {
186 case FrameBuffer::Format::kGRAY:
187 case FrameBuffer::Format::kRGB:
188 case FrameBuffer::Format::kRGBA:
189 if (buffer.plane_count() == 1) return absl::OkStatus();
190 return absl::InvalidArgumentError(
191 "Plane count must be 1 for grayscale and RGB[a] buffers.");
192 case FrameBuffer::Format::kNV21:
193 case FrameBuffer::Format::kNV12:
194 case FrameBuffer::Format::kYV21:
195 case FrameBuffer::Format::kYV12:
196 return absl::OkStatus();
197 default:
198 return absl::InternalError(
199 absl::StrFormat("Unsupported buffer format: %i.", buffer.format()));
200 }
201 }
202
ValidateBufferFormats(const FrameBuffer & buffer1,const FrameBuffer & buffer2)203 absl::Status ValidateBufferFormats(const FrameBuffer& buffer1,
204 const FrameBuffer& buffer2) {
205 RETURN_IF_ERROR(ValidateBufferFormat(buffer1));
206 RETURN_IF_ERROR(ValidateBufferFormat(buffer2));
207 return absl::OkStatus();
208 }
209
ValidateResizeBufferInputs(const FrameBuffer & buffer,const FrameBuffer & output_buffer)210 absl::Status ValidateResizeBufferInputs(const FrameBuffer& buffer,
211 const FrameBuffer& output_buffer) {
212 bool valid_format = false;
213 switch (buffer.format()) {
214 case FrameBuffer::Format::kGRAY:
215 case FrameBuffer::Format::kRGB:
216 case FrameBuffer::Format::kNV12:
217 case FrameBuffer::Format::kNV21:
218 case FrameBuffer::Format::kYV12:
219 case FrameBuffer::Format::kYV21:
220 valid_format = (buffer.format() == output_buffer.format());
221 break;
222 case FrameBuffer::Format::kRGBA:
223 valid_format = (output_buffer.format() == FrameBuffer::Format::kRGBA ||
224 output_buffer.format() == FrameBuffer::Format::kRGB);
225 break;
226 default:
227 return absl::InternalError(
228 absl::StrFormat("Unsupported buffer format: %i.", buffer.format()));
229 }
230 if (!valid_format) {
231 return absl::InvalidArgumentError(
232 "Input and output buffer formats must match.");
233 }
234 return ValidateBufferFormats(buffer, output_buffer);
235 }
236
ValidateRotateBufferInputs(const FrameBuffer & buffer,const FrameBuffer & output_buffer,int angle_deg)237 absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer,
238 const FrameBuffer& output_buffer,
239 int angle_deg) {
240 if (!AreBufferFormatsCompatible(buffer, output_buffer)) {
241 return absl::InvalidArgumentError(
242 "Input and output buffer formats must match.");
243 }
244
245 const bool is_dimension_change = (angle_deg / 90) % 2 == 1;
246 const bool are_dimensions_rotated =
247 (buffer.dimension().width == output_buffer.dimension().height) &&
248 (buffer.dimension().height == output_buffer.dimension().width);
249 const bool are_dimensions_equal =
250 buffer.dimension() == output_buffer.dimension();
251
252 if (angle_deg >= 360 || angle_deg <= 0 || angle_deg % 90 != 0) {
253 return absl::InvalidArgumentError(
254 "Rotation angle must be between 0 and 360, in multiples of 90 "
255 "degrees.");
256 } else if ((is_dimension_change && !are_dimensions_rotated) ||
257 (!is_dimension_change && !are_dimensions_equal)) {
258 return absl::InvalidArgumentError(
259 "Output buffer has invalid dimensions for rotation.");
260 }
261 return absl::OkStatus();
262 }
263
ValidateCropBufferInputs(const FrameBuffer & buffer,const FrameBuffer & output_buffer,int x0,int y0,int x1,int y1)264 absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer,
265 const FrameBuffer& output_buffer, int x0,
266 int y0, int x1, int y1) {
267 if (!AreBufferFormatsCompatible(buffer, output_buffer)) {
268 return absl::InvalidArgumentError(
269 "Input and output buffer formats must match.");
270 }
271
272 bool is_buffer_size_valid =
273 ((x1 < buffer.dimension().width) && y1 < buffer.dimension().height);
274 bool are_points_valid = (x0 >= 0) && (y0 >= 0) && (x1 >= x0) && (y1 >= y0);
275
276 if (!is_buffer_size_valid || !are_points_valid) {
277 return absl::InvalidArgumentError("Invalid crop coordinates.");
278 }
279 return absl::OkStatus();
280 }
281
ValidateFlipBufferInputs(const FrameBuffer & buffer,const FrameBuffer & output_buffer)282 absl::Status ValidateFlipBufferInputs(const FrameBuffer& buffer,
283 const FrameBuffer& output_buffer) {
284 if (!AreBufferFormatsCompatible(buffer, output_buffer)) {
285 return absl::InvalidArgumentError(
286 "Input and output buffer formats must match.");
287 }
288 return AreBufferDimsEqual(buffer, output_buffer)
289 ? absl::OkStatus()
290 : absl::InvalidArgumentError(
291 "Input and output buffers must have the same dimensions.");
292 }
293
ValidateConvertFormats(FrameBuffer::Format from_format,FrameBuffer::Format to_format)294 absl::Status ValidateConvertFormats(FrameBuffer::Format from_format,
295 FrameBuffer::Format to_format) {
296 if (from_format == to_format) {
297 return absl::InvalidArgumentError("Formats must be different.");
298 }
299
300 switch (from_format) {
301 case FrameBuffer::Format::kGRAY:
302 return absl::InvalidArgumentError(
303 "Grayscale format does not convert to other formats.");
304 case FrameBuffer::Format::kRGB:
305 if (to_format == FrameBuffer::Format::kRGBA) {
306 return absl::InvalidArgumentError(
307 "RGB format does not convert to RGBA");
308 }
309 return absl::OkStatus();
310 case FrameBuffer::Format::kRGBA:
311 case FrameBuffer::Format::kNV12:
312 case FrameBuffer::Format::kNV21:
313 case FrameBuffer::Format::kYV12:
314 case FrameBuffer::Format::kYV21:
315 return absl::OkStatus();
316 default:
317 return absl::InternalError(
318 absl::StrFormat("Unsupported buffer format: %i.", from_format));
319 }
320 }
321
322 // Creation Methods
323 // -----------------------------------------------------------------
324
325 // Creates a FrameBuffer from raw RGBA buffer and passing arguments.
CreateFromRgbaRawBuffer(const uint8 * input,FrameBuffer::Dimension dimension,FrameBuffer::Orientation orientation,const absl::Time timestamp)326 std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer(
327 const uint8* input, FrameBuffer::Dimension dimension,
328 FrameBuffer::Orientation orientation, const absl::Time timestamp) {
329 FrameBuffer::Plane input_plane = {
330 /*buffer=*/input,
331 /*stride=*/{dimension.width * kRgbaChannels, kRgbaChannels}};
332 return FrameBuffer::Create({input_plane}, dimension,
333 FrameBuffer::Format::kRGBA, orientation,
334 timestamp);
335 }
336
337 // Creates a FrameBuffer from raw RGB buffer and passing arguments.
CreateFromRgbRawBuffer(const uint8 * input,FrameBuffer::Dimension dimension,FrameBuffer::Orientation orientation,const absl::Time timestamp)338 std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer(
339 const uint8* input, FrameBuffer::Dimension dimension,
340 FrameBuffer::Orientation orientation, const absl::Time timestamp) {
341 FrameBuffer::Plane input_plane = {
342 /*buffer=*/input,
343 /*stride=*/{dimension.width * kRgbChannels, kRgbChannels}};
344 return FrameBuffer::Create({input_plane}, dimension,
345 FrameBuffer::Format::kRGB, orientation, timestamp);
346 }
347
348 // Creates a FrameBuffer from raw grayscale buffer and passing arguments.
CreateFromGrayRawBuffer(const uint8 * input,FrameBuffer::Dimension dimension,FrameBuffer::Orientation orientation,const absl::Time timestamp)349 std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer(
350 const uint8* input, FrameBuffer::Dimension dimension,
351 FrameBuffer::Orientation orientation, const absl::Time timestamp) {
352 FrameBuffer::Plane input_plane = {/*buffer=*/input,
353 /*stride=*/{dimension.width, kGrayChannel}};
354 return FrameBuffer::Create({input_plane}, dimension,
355 FrameBuffer::Format::kGRAY, orientation,
356 timestamp);
357 }
358
359 // Creates a FrameBuffer from raw YUV buffer and passing arguments.
CreateFromYuvRawBuffer(const uint8 * y_plane,const uint8 * u_plane,const uint8 * v_plane,FrameBuffer::Format format,FrameBuffer::Dimension dimension,int row_stride_y,int row_stride_uv,int pixel_stride_uv,FrameBuffer::Orientation orientation,const absl::Time timestamp)360 StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
361 const uint8* y_plane, const uint8* u_plane, const uint8* v_plane,
362 FrameBuffer::Format format, FrameBuffer::Dimension dimension,
363 int row_stride_y, int row_stride_uv, int pixel_stride_uv,
364 FrameBuffer::Orientation orientation, const absl::Time timestamp) {
365 const int pixel_stride_y = 1;
366 std::vector<FrameBuffer::Plane> planes;
367 if (format == FrameBuffer::Format::kNV21 ||
368 format == FrameBuffer::Format::kYV12) {
369 planes = {{y_plane, /*stride=*/{row_stride_y, pixel_stride_y}},
370 {v_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}},
371 {u_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}}};
372 } else if (format == FrameBuffer::Format::kNV12 ||
373 format == FrameBuffer::Format::kYV21) {
374 planes = {{y_plane, /*stride=*/{row_stride_y, pixel_stride_y}},
375 {u_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}},
376 {v_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}}};
377 } else {
378 return absl::InvalidArgumentError(
379 absl::StrFormat("Input format is not YUV-like: %i.", format));
380 }
381 return FrameBuffer::Create(planes, dimension, format, orientation, timestamp);
382 }
383
CreateFromRawBuffer(const uint8 * buffer,FrameBuffer::Dimension dimension,const FrameBuffer::Format target_format,FrameBuffer::Orientation orientation,absl::Time timestamp)384 StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer(
385 const uint8* buffer, FrameBuffer::Dimension dimension,
386 const FrameBuffer::Format target_format,
387 FrameBuffer::Orientation orientation, absl::Time timestamp) {
388 switch (target_format) {
389 case FrameBuffer::Format::kNV12:
390 return CreateFromNV12RawBuffer(buffer, dimension, orientation, timestamp);
391 case FrameBuffer::Format::kNV21:
392 return CreateFromNV21RawBuffer(buffer, dimension, orientation, timestamp);
393 case FrameBuffer::Format::kYV12: {
394 ASSIGN_OR_RETURN(const FrameBuffer::Dimension uv_dimension,
395 GetUvPlaneDimension(dimension, target_format));
396 return CreateFromYuvRawBuffer(
397 /*y_plane=*/buffer,
398 /*u_plane=*/buffer + dimension.Size() + uv_dimension.Size(),
399 /*v_plane=*/buffer + dimension.Size(), target_format, dimension,
400 /*row_stride_y=*/dimension.width, uv_dimension.width,
401 /*pixel_stride_uv=*/1, orientation, timestamp);
402 }
403 case FrameBuffer::Format::kYV21: {
404 ASSIGN_OR_RETURN(const FrameBuffer::Dimension uv_dimension,
405 GetUvPlaneDimension(dimension, target_format));
406 return CreateFromYuvRawBuffer(
407 /*y_plane=*/buffer, /*u_plane=*/buffer + dimension.Size(),
408 /*v_plane=*/buffer + dimension.Size() + uv_dimension.Size(),
409 target_format, dimension, /*row_stride_y=*/dimension.width,
410 uv_dimension.width,
411 /*pixel_stride_uv=*/1, orientation, timestamp);
412 }
413 case FrameBuffer::Format::kRGBA:
414 return CreateFromRgbaRawBuffer(buffer, dimension, orientation, timestamp);
415 case FrameBuffer::Format::kRGB:
416 return CreateFromRgbRawBuffer(buffer, dimension, orientation, timestamp);
417 case FrameBuffer::Format::kGRAY:
418 return CreateFromGrayRawBuffer(buffer, dimension, orientation, timestamp);
419 default:
420
421 return absl::InternalError(
422 absl::StrFormat("Unsupported buffer format: %i.", target_format));
423 }
424 }
425
426 } // namespace vision
427 } // namespace task
428 } // namespace tflite
429