1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
17 
18 #include <string>
19 #include <vector>
20 
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_format.h"
23 #include "tensorflow_lite_support/cc/port/status_macros.h"
24 
25 namespace tflite {
26 namespace task {
27 namespace vision {
28 namespace {
29 
30 using ::tflite::support::StatusOr;
31 
// Per-pixel channel counts. In this file they serve both as the pixel stride
// of a plane and as the factor applied to the image width to obtain the row
// stride of a tightly packed plane.
constexpr int kRgbaChannels{4};
constexpr int kRgbChannels{3};
constexpr int kGrayChannel{1};
35 
36 // Creates a FrameBuffer from raw NV12 buffer and passing arguments.
CreateFromNV12RawBuffer(const uint8 * input,FrameBuffer::Dimension dimension,FrameBuffer::Orientation orientation,const absl::Time timestamp)37 std::unique_ptr<FrameBuffer> CreateFromNV12RawBuffer(
38     const uint8* input, FrameBuffer::Dimension dimension,
39     FrameBuffer::Orientation orientation, const absl::Time timestamp) {
40   const std::vector<FrameBuffer::Plane> planes_nv12 = {
41       {input, /*stride=*/{dimension.width, kGrayChannel}},
42       {input + dimension.Size(), /*stride=*/{dimension.width, 2}}};
43   return FrameBuffer::Create(planes_nv12, dimension, FrameBuffer::Format::kNV12,
44                              orientation, timestamp);
45 }
46 
47 // Creates a FrameBuffer from raw NV21 buffer and passing arguments.
CreateFromNV21RawBuffer(const uint8 * input,FrameBuffer::Dimension dimension,FrameBuffer::Orientation orientation,const absl::Time timestamp)48 std::unique_ptr<FrameBuffer> CreateFromNV21RawBuffer(
49     const uint8* input, FrameBuffer::Dimension dimension,
50     FrameBuffer::Orientation orientation, const absl::Time timestamp) {
51   FrameBuffer::Plane input_plane = {/*buffer=*/input,
52                                     /*stride=*/{dimension.width, kGrayChannel}};
53   return FrameBuffer::Create({input_plane}, dimension,
54                              FrameBuffer::Format::kNV21, orientation,
55                              timestamp);
56 }
57 
58 // Indicates whether the given buffers have the same dimensions.
AreBufferDimsEqual(const FrameBuffer & buffer1,const FrameBuffer & buffer2)59 bool AreBufferDimsEqual(const FrameBuffer& buffer1,
60                         const FrameBuffer& buffer2) {
61   return buffer1.dimension() == buffer2.dimension();
62 }
63 
64 // Indicates whether the given buffers formats are compatible. Same formats are
65 // compatible and all YUV family formats (e.g. NV21, NV12, YV12, YV21, etc) are
66 // compatible.
AreBufferFormatsCompatible(const FrameBuffer & buffer1,const FrameBuffer & buffer2)67 bool AreBufferFormatsCompatible(const FrameBuffer& buffer1,
68                                 const FrameBuffer& buffer2) {
69   switch (buffer1.format()) {
70     case FrameBuffer::Format::kRGBA:
71     case FrameBuffer::Format::kRGB:
72       return (buffer2.format() == FrameBuffer::Format::kRGBA ||
73               buffer2.format() == FrameBuffer::Format::kRGB);
74     case FrameBuffer::Format::kNV12:
75     case FrameBuffer::Format::kNV21:
76     case FrameBuffer::Format::kYV12:
77     case FrameBuffer::Format::kYV21:
78       return (buffer2.format() == FrameBuffer::Format::kNV12 ||
79               buffer2.format() == FrameBuffer::Format::kNV21 ||
80               buffer2.format() == FrameBuffer::Format::kYV12 ||
81               buffer2.format() == FrameBuffer::Format::kYV21);
82     case FrameBuffer::Format::kGRAY:
83     default:
84       return buffer1.format() == buffer2.format();
85   }
86 }
87 
88 }  // namespace
89 
90 // Miscellaneous Methods
91 // -----------------------------------------------------------------
GetFrameBufferByteSize(FrameBuffer::Dimension dimension,FrameBuffer::Format format)92 int GetFrameBufferByteSize(FrameBuffer::Dimension dimension,
93                            FrameBuffer::Format format) {
94   switch (format) {
95     case FrameBuffer::Format::kNV12:
96     case FrameBuffer::Format::kNV21:
97     case FrameBuffer::Format::kYV12:
98     case FrameBuffer::Format::kYV21:
99       return /*y plane*/ dimension.Size() +
100              /*uv plane*/ ((static_cast<float>(dimension.width + 1) / 2) *
101                            (static_cast<float>(dimension.height + 1) / 2) * 2);
102     case FrameBuffer::Format::kRGB:
103       return dimension.Size() * 3;
104     case FrameBuffer::Format::kRGBA:
105       return dimension.Size() * 4;
106     case FrameBuffer::Format::kGRAY:
107       return dimension.Size();
108     default:
109       return 0;
110   }
111 }
112 
GetPixelStrides(FrameBuffer::Format format)113 StatusOr<int> GetPixelStrides(FrameBuffer::Format format) {
114   switch (format) {
115     case FrameBuffer::Format::kGRAY:
116       return kGrayPixelBytes;
117     case FrameBuffer::Format::kRGB:
118       return kRgbPixelBytes;
119     case FrameBuffer::Format::kRGBA:
120       return kRgbaPixelBytes;
121     default:
122       return absl::InvalidArgumentError(absl::StrFormat(
123           "GetPixelStrides does not support format: %i.", format));
124   }
125 }
126 
GetUvRawBuffer(const FrameBuffer & buffer)127 StatusOr<const uint8*> GetUvRawBuffer(const FrameBuffer& buffer) {
128   if (buffer.format() != FrameBuffer::Format::kNV12 &&
129       buffer.format() != FrameBuffer::Format::kNV21) {
130     return absl::InvalidArgumentError(
131         "Only support getting biplanar UV buffer from NV12/NV21 frame buffer.");
132   }
133   ASSIGN_OR_RETURN(FrameBuffer::YuvData yuv_data,
134                    FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
135   const uint8* uv_buffer = buffer.format() == FrameBuffer::Format::kNV12
136                                ? yuv_data.u_buffer
137                                : yuv_data.v_buffer;
138   return uv_buffer;
139 }
140 
GetUvPlaneDimension(FrameBuffer::Dimension dimension,FrameBuffer::Format format)141 StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension(
142     FrameBuffer::Dimension dimension, FrameBuffer::Format format) {
143   if (dimension.width <= 0 || dimension.height <= 0) {
144     return absl::InvalidArgumentError(
145         absl::StrFormat("Invalid input dimension: {%d, %d}.", dimension.width,
146                         dimension.height));
147   }
148   switch (format) {
149     case FrameBuffer::Format::kNV12:
150     case FrameBuffer::Format::kNV21:
151     case FrameBuffer::Format::kYV12:
152     case FrameBuffer::Format::kYV21:
153       return FrameBuffer::Dimension{(dimension.width + 1) / 2,
154                                     (dimension.height + 1) / 2};
155     default:
156       return absl::InvalidArgumentError(
157           absl::StrFormat("Input format is not YUV-like: %i.", format));
158   }
159 }
160 
GetCropDimension(int x0,int x1,int y0,int y1)161 FrameBuffer::Dimension GetCropDimension(int x0, int x1, int y0, int y1) {
162   return {x1 - x0 + 1, y1 - y0 + 1};
163 }
164 
165 // Validation Methods
166 // -----------------------------------------------------------------
167 
ValidateBufferPlaneMetadata(const FrameBuffer & buffer)168 absl::Status ValidateBufferPlaneMetadata(const FrameBuffer& buffer) {
169   if (buffer.plane_count() < 1) {
170     return absl::InvalidArgumentError(
171         "There must be at least 1 plane specified.");
172   }
173 
174   for (int i = 0; i < buffer.plane_count(); i++) {
175     if (buffer.plane(i).stride.row_stride_bytes == 0 ||
176         buffer.plane(i).stride.pixel_stride_bytes == 0) {
177       return absl::InvalidArgumentError("Invalid stride information.");
178     }
179   }
180 
181   return absl::OkStatus();
182 }
183 
ValidateBufferFormat(const FrameBuffer & buffer)184 absl::Status ValidateBufferFormat(const FrameBuffer& buffer) {
185   switch (buffer.format()) {
186     case FrameBuffer::Format::kGRAY:
187     case FrameBuffer::Format::kRGB:
188     case FrameBuffer::Format::kRGBA:
189       if (buffer.plane_count() == 1) return absl::OkStatus();
190       return absl::InvalidArgumentError(
191           "Plane count must be 1 for grayscale and RGB[a] buffers.");
192     case FrameBuffer::Format::kNV21:
193     case FrameBuffer::Format::kNV12:
194     case FrameBuffer::Format::kYV21:
195     case FrameBuffer::Format::kYV12:
196       return absl::OkStatus();
197     default:
198       return absl::InternalError(
199           absl::StrFormat("Unsupported buffer format: %i.", buffer.format()));
200   }
201 }
202 
ValidateBufferFormats(const FrameBuffer & buffer1,const FrameBuffer & buffer2)203 absl::Status ValidateBufferFormats(const FrameBuffer& buffer1,
204                                    const FrameBuffer& buffer2) {
205   RETURN_IF_ERROR(ValidateBufferFormat(buffer1));
206   RETURN_IF_ERROR(ValidateBufferFormat(buffer2));
207   return absl::OkStatus();
208 }
209 
ValidateResizeBufferInputs(const FrameBuffer & buffer,const FrameBuffer & output_buffer)210 absl::Status ValidateResizeBufferInputs(const FrameBuffer& buffer,
211                                         const FrameBuffer& output_buffer) {
212   bool valid_format = false;
213   switch (buffer.format()) {
214     case FrameBuffer::Format::kGRAY:
215     case FrameBuffer::Format::kRGB:
216     case FrameBuffer::Format::kNV12:
217     case FrameBuffer::Format::kNV21:
218     case FrameBuffer::Format::kYV12:
219     case FrameBuffer::Format::kYV21:
220       valid_format = (buffer.format() == output_buffer.format());
221       break;
222     case FrameBuffer::Format::kRGBA:
223       valid_format = (output_buffer.format() == FrameBuffer::Format::kRGBA ||
224                       output_buffer.format() == FrameBuffer::Format::kRGB);
225       break;
226     default:
227       return absl::InternalError(
228           absl::StrFormat("Unsupported buffer format: %i.", buffer.format()));
229   }
230   if (!valid_format) {
231     return absl::InvalidArgumentError(
232         "Input and output buffer formats must match.");
233   }
234   return ValidateBufferFormats(buffer, output_buffer);
235 }
236 
ValidateRotateBufferInputs(const FrameBuffer & buffer,const FrameBuffer & output_buffer,int angle_deg)237 absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer,
238                                         const FrameBuffer& output_buffer,
239                                         int angle_deg) {
240   if (!AreBufferFormatsCompatible(buffer, output_buffer)) {
241     return absl::InvalidArgumentError(
242         "Input and output buffer formats must match.");
243   }
244 
245   const bool is_dimension_change = (angle_deg / 90) % 2 == 1;
246   const bool are_dimensions_rotated =
247       (buffer.dimension().width == output_buffer.dimension().height) &&
248       (buffer.dimension().height == output_buffer.dimension().width);
249   const bool are_dimensions_equal =
250       buffer.dimension() == output_buffer.dimension();
251 
252   if (angle_deg >= 360 || angle_deg <= 0 || angle_deg % 90 != 0) {
253     return absl::InvalidArgumentError(
254         "Rotation angle must be between 0 and 360, in multiples of 90 "
255         "degrees.");
256   } else if ((is_dimension_change && !are_dimensions_rotated) ||
257              (!is_dimension_change && !are_dimensions_equal)) {
258     return absl::InvalidArgumentError(
259         "Output buffer has invalid dimensions for rotation.");
260   }
261   return absl::OkStatus();
262 }
263 
ValidateCropBufferInputs(const FrameBuffer & buffer,const FrameBuffer & output_buffer,int x0,int y0,int x1,int y1)264 absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer,
265                                       const FrameBuffer& output_buffer, int x0,
266                                       int y0, int x1, int y1) {
267   if (!AreBufferFormatsCompatible(buffer, output_buffer)) {
268     return absl::InvalidArgumentError(
269         "Input and output buffer formats must match.");
270   }
271 
272   bool is_buffer_size_valid =
273       ((x1 < buffer.dimension().width) && y1 < buffer.dimension().height);
274   bool are_points_valid = (x0 >= 0) && (y0 >= 0) && (x1 >= x0) && (y1 >= y0);
275 
276   if (!is_buffer_size_valid || !are_points_valid) {
277     return absl::InvalidArgumentError("Invalid crop coordinates.");
278   }
279   return absl::OkStatus();
280 }
281 
ValidateFlipBufferInputs(const FrameBuffer & buffer,const FrameBuffer & output_buffer)282 absl::Status ValidateFlipBufferInputs(const FrameBuffer& buffer,
283                                       const FrameBuffer& output_buffer) {
284   if (!AreBufferFormatsCompatible(buffer, output_buffer)) {
285     return absl::InvalidArgumentError(
286         "Input and output buffer formats must match.");
287   }
288   return AreBufferDimsEqual(buffer, output_buffer)
289              ? absl::OkStatus()
290              : absl::InvalidArgumentError(
291                    "Input and output buffers must have the same dimensions.");
292 }
293 
ValidateConvertFormats(FrameBuffer::Format from_format,FrameBuffer::Format to_format)294 absl::Status ValidateConvertFormats(FrameBuffer::Format from_format,
295                                     FrameBuffer::Format to_format) {
296   if (from_format == to_format) {
297     return absl::InvalidArgumentError("Formats must be different.");
298   }
299 
300   switch (from_format) {
301     case FrameBuffer::Format::kGRAY:
302       return absl::InvalidArgumentError(
303           "Grayscale format does not convert to other formats.");
304     case FrameBuffer::Format::kRGB:
305       if (to_format == FrameBuffer::Format::kRGBA) {
306         return absl::InvalidArgumentError(
307             "RGB format does not convert to RGBA");
308       }
309       return absl::OkStatus();
310     case FrameBuffer::Format::kRGBA:
311     case FrameBuffer::Format::kNV12:
312     case FrameBuffer::Format::kNV21:
313     case FrameBuffer::Format::kYV12:
314     case FrameBuffer::Format::kYV21:
315       return absl::OkStatus();
316     default:
317       return absl::InternalError(
318           absl::StrFormat("Unsupported buffer format: %i.", from_format));
319   }
320 }
321 
322 // Creation Methods
323 // -----------------------------------------------------------------
324 
325 // Creates a FrameBuffer from raw RGBA buffer and passing arguments.
CreateFromRgbaRawBuffer(const uint8 * input,FrameBuffer::Dimension dimension,FrameBuffer::Orientation orientation,const absl::Time timestamp)326 std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer(
327     const uint8* input, FrameBuffer::Dimension dimension,
328     FrameBuffer::Orientation orientation, const absl::Time timestamp) {
329   FrameBuffer::Plane input_plane = {
330       /*buffer=*/input,
331       /*stride=*/{dimension.width * kRgbaChannels, kRgbaChannels}};
332   return FrameBuffer::Create({input_plane}, dimension,
333                              FrameBuffer::Format::kRGBA, orientation,
334                              timestamp);
335 }
336 
337 // Creates a FrameBuffer from raw RGB buffer and passing arguments.
CreateFromRgbRawBuffer(const uint8 * input,FrameBuffer::Dimension dimension,FrameBuffer::Orientation orientation,const absl::Time timestamp)338 std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer(
339     const uint8* input, FrameBuffer::Dimension dimension,
340     FrameBuffer::Orientation orientation, const absl::Time timestamp) {
341   FrameBuffer::Plane input_plane = {
342       /*buffer=*/input,
343       /*stride=*/{dimension.width * kRgbChannels, kRgbChannels}};
344   return FrameBuffer::Create({input_plane}, dimension,
345                              FrameBuffer::Format::kRGB, orientation, timestamp);
346 }
347 
348 // Creates a FrameBuffer from raw grayscale buffer and passing arguments.
CreateFromGrayRawBuffer(const uint8 * input,FrameBuffer::Dimension dimension,FrameBuffer::Orientation orientation,const absl::Time timestamp)349 std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer(
350     const uint8* input, FrameBuffer::Dimension dimension,
351     FrameBuffer::Orientation orientation, const absl::Time timestamp) {
352   FrameBuffer::Plane input_plane = {/*buffer=*/input,
353                                     /*stride=*/{dimension.width, kGrayChannel}};
354   return FrameBuffer::Create({input_plane}, dimension,
355                              FrameBuffer::Format::kGRAY, orientation,
356                              timestamp);
357 }
358 
359 // Creates a FrameBuffer from raw YUV buffer and passing arguments.
CreateFromYuvRawBuffer(const uint8 * y_plane,const uint8 * u_plane,const uint8 * v_plane,FrameBuffer::Format format,FrameBuffer::Dimension dimension,int row_stride_y,int row_stride_uv,int pixel_stride_uv,FrameBuffer::Orientation orientation,const absl::Time timestamp)360 StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
361     const uint8* y_plane, const uint8* u_plane, const uint8* v_plane,
362     FrameBuffer::Format format, FrameBuffer::Dimension dimension,
363     int row_stride_y, int row_stride_uv, int pixel_stride_uv,
364     FrameBuffer::Orientation orientation, const absl::Time timestamp) {
365   const int pixel_stride_y = 1;
366   std::vector<FrameBuffer::Plane> planes;
367   if (format == FrameBuffer::Format::kNV21 ||
368       format == FrameBuffer::Format::kYV12) {
369     planes = {{y_plane, /*stride=*/{row_stride_y, pixel_stride_y}},
370               {v_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}},
371               {u_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}}};
372   } else if (format == FrameBuffer::Format::kNV12 ||
373              format == FrameBuffer::Format::kYV21) {
374     planes = {{y_plane, /*stride=*/{row_stride_y, pixel_stride_y}},
375               {u_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}},
376               {v_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}}};
377   } else {
378     return absl::InvalidArgumentError(
379         absl::StrFormat("Input format is not YUV-like: %i.", format));
380   }
381   return FrameBuffer::Create(planes, dimension, format, orientation, timestamp);
382 }
383 
CreateFromRawBuffer(const uint8 * buffer,FrameBuffer::Dimension dimension,const FrameBuffer::Format target_format,FrameBuffer::Orientation orientation,absl::Time timestamp)384 StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer(
385     const uint8* buffer, FrameBuffer::Dimension dimension,
386     const FrameBuffer::Format target_format,
387     FrameBuffer::Orientation orientation, absl::Time timestamp) {
388   switch (target_format) {
389     case FrameBuffer::Format::kNV12:
390       return CreateFromNV12RawBuffer(buffer, dimension, orientation, timestamp);
391     case FrameBuffer::Format::kNV21:
392       return CreateFromNV21RawBuffer(buffer, dimension, orientation, timestamp);
393     case FrameBuffer::Format::kYV12: {
394       ASSIGN_OR_RETURN(const FrameBuffer::Dimension uv_dimension,
395                        GetUvPlaneDimension(dimension, target_format));
396       return CreateFromYuvRawBuffer(
397           /*y_plane=*/buffer,
398           /*u_plane=*/buffer + dimension.Size() + uv_dimension.Size(),
399           /*v_plane=*/buffer + dimension.Size(), target_format, dimension,
400           /*row_stride_y=*/dimension.width, uv_dimension.width,
401           /*pixel_stride_uv=*/1, orientation, timestamp);
402     }
403     case FrameBuffer::Format::kYV21: {
404       ASSIGN_OR_RETURN(const FrameBuffer::Dimension uv_dimension,
405                        GetUvPlaneDimension(dimension, target_format));
406       return CreateFromYuvRawBuffer(
407           /*y_plane=*/buffer, /*u_plane=*/buffer + dimension.Size(),
408           /*v_plane=*/buffer + dimension.Size() + uv_dimension.Size(),
409           target_format, dimension, /*row_stride_y=*/dimension.width,
410           uv_dimension.width,
411           /*pixel_stride_uv=*/1, orientation, timestamp);
412     }
413     case FrameBuffer::Format::kRGBA:
414       return CreateFromRgbaRawBuffer(buffer, dimension, orientation, timestamp);
415     case FrameBuffer::Format::kRGB:
416       return CreateFromRgbRawBuffer(buffer, dimension, orientation, timestamp);
417     case FrameBuffer::Format::kGRAY:
418       return CreateFromGrayRawBuffer(buffer, dimension, orientation, timestamp);
419     default:
420 
421       return absl::InternalError(
422           absl::StrFormat("Unsupported buffer format: %i.", target_format));
423   }
424 }
425 
426 }  // namespace vision
427 }  // namespace task
428 }  // namespace tflite
429