1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_header_parser.h"
16 
17 #include <string>
18 
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include "tensorflow/lite/c/c_api_types.h"
22 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/embedded_chessboard_jpeg.h"
23 
24 namespace tflite {
25 namespace acceleration {
26 namespace decode_jpeg_kernel {
27 
PrintTo(const Status & status,std::ostream * os)28 void PrintTo(const Status& status, std::ostream* os) {
29   *os << "{ code: " + std::to_string(status.code) + ", error_message: '" +
30              status.error_message + "'}";
31 }
32 
33 }  // namespace decode_jpeg_kernel
34 }  // namespace acceleration
35 }  // namespace tflite
36 
37 namespace {
38 
39 using ::testing::AllOf;
40 using ::testing::Eq;
41 using ::testing::Field;
42 using ::testing::Matcher;
43 
44 using tflite::acceleration::decode_jpeg_kernel::JpegHeader;
45 using tflite::acceleration::decode_jpeg_kernel::ReadJpegHeader;
46 
JpegHeaderEq(const JpegHeader & expected)47 Matcher<JpegHeader> JpegHeaderEq(const JpegHeader& expected) {
48   return AllOf(
49       Field(&JpegHeader::channels, Eq(expected.channels)),
50       Field(&JpegHeader::height, Eq(expected.height)),
51       Field(&JpegHeader::width, Eq(expected.width)),
52       Field(&JpegHeader::bits_per_sample, Eq(expected.bits_per_sample)));
53 }
54 
55 using tflite::acceleration::decode_jpeg_kernel::Status;
56 
StatusEq(const Status & expected)57 Matcher<Status> StatusEq(const Status& expected) {
58   return AllOf(Field(&Status::code, Eq(expected.code)),
59                Field(&Status::error_message, Eq(expected.error_message)));
60 }
61 
// Known header values of the embedded chessboard JPEG fixture
// (g_tflite_acceleration_chessboard_jpeg); ShouldParseValidJpgImage expects
// ReadJpegHeader to report exactly these.
constexpr int kChessboardImgHeight = 300;
constexpr int kChessboardImgWidth = 250;
constexpr int kChessboardImgChannels = 3;
65 
// Parses the embedded chessboard JPEG and checks the header against the
// fixture's known dimensions.
TEST(ReadJpegHeader, ShouldParseValidJpgImage) {
  const char* const chessboard_bytes =
      reinterpret_cast<const char*>(g_tflite_acceleration_chessboard_jpeg);
  const tflite::StringRef chessboard_image{
      chessboard_bytes, g_tflite_acceleration_chessboard_jpeg_len};
  // Make sure the fixture holds more than a handful of bytes before parsing.
  ASSERT_GT(chessboard_image.len, 4);

  JpegHeader parsed_header;
  ASSERT_THAT(ReadJpegHeader(chessboard_image, &parsed_header),
              StatusEq({kTfLiteOk, ""}));

  // bits_per_sample is not spelled out, so its expected value is whatever
  // aggregate initialization of JpegHeader leaves in it.
  // NOTE(review): confirm that default matches the fixture's sample precision.
  const JpegHeader expected_header{kChessboardImgHeight, kChessboardImgWidth,
                                   kChessboardImgChannels};
  EXPECT_THAT(parsed_header, JpegHeaderEq(expected_header));
}
79 
// Arbitrary non-JPEG bytes must be rejected with kTfLiteError and the generic
// "Not a valid JPEG image." message.
TEST(ReadJpegHeader, ShouldFailForInvalidJpegImage) {
  const std::string invalid_image = "invalid image content";
  const int invalid_image_len = static_cast<int>(invalid_image.size());
  const tflite::StringRef invalid_image_ref{invalid_image.c_str(),
                                            invalid_image_len};

  JpegHeader header;
  EXPECT_THAT(ReadJpegHeader(invalid_image_ref, &header),
              StatusEq({kTfLiteError, "Not a valid JPEG image."}));
}
90 
// A zero-length buffer must be rejected with kTfLiteError and the generic
// "Not a valid JPEG image." message.
TEST(ReadJpegHeader, ShouldFailForEmptyJpegImage) {
  const tflite::StringRef empty_image_ref{"", 0};
  JpegHeader header;

  EXPECT_THAT(ReadJpegHeader(empty_image_ref, &header),
              StatusEq({kTfLiteError, "Not a valid JPEG image."}));
}
99 
// Rewrites the chessboard JPEG's header via BuildImageWithNewHeader, then
// checks that ReadJpegHeader on the rebuilt bytes reports exactly the new
// header values.
TEST(ApplyHeaderToImage, ReturnsNewImageWithDifferentHeader) {
  // Bring the function in explicitly, consistent with the file's other
  // using-declarations for decode_jpeg_kernel names. Otherwise the call only
  // resolves through argument-dependent lookup on JpegHeader's namespace.
  using tflite::acceleration::decode_jpeg_kernel::BuildImageWithNewHeader;

  const tflite::StringRef chessboard_image{
      reinterpret_cast<const char*>(g_tflite_acceleration_chessboard_jpeg),
      g_tflite_acceleration_chessboard_jpeg_len};

  JpegHeader new_header{
      .height = 20, .width = 30, .channels = 1, .bits_per_sample = 3};

  // Receives the complete rewritten image produced by BuildImageWithNewHeader.
  std::string new_image_data;

  ASSERT_THAT(
      BuildImageWithNewHeader(chessboard_image, new_header, new_image_data),
      StatusEq({kTfLiteOk, ""}));

  const tflite::StringRef altered_image{
      new_image_data.c_str(), static_cast<int>(new_image_data.size())};
  JpegHeader header;
  ASSERT_THAT(ReadJpegHeader(altered_image, &header),
              StatusEq({kTfLiteOk, ""}));
  EXPECT_THAT(header, JpegHeaderEq(new_header));
}
121 
122 }  // namespace
123