xref: /aosp_15_r20/external/armnn/tests/Cifar10Database.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include "Cifar10Database.hpp"
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <fstream>
11*89c4ff92SAndroid Build Coastguard Worker #include <vector>
12*89c4ff92SAndroid Build Coastguard Worker 
// Size in bytes of one record in the CIFAR-10 binary batch file:
// a single label byte followed by three 32x32 colour planes (R, G, B), one byte per pixel.
constexpr unsigned int g_kCifar10ImageByteSize = 1u + (3u * (32u * 32u));
14*89c4ff92SAndroid Build Coastguard Worker 
Cifar10Database(const std::string & binaryFileDirectory,bool rgbPack)15*89c4ff92SAndroid Build Coastguard Worker Cifar10Database::Cifar10Database(const std::string& binaryFileDirectory, bool rgbPack)
16*89c4ff92SAndroid Build Coastguard Worker     : m_BinaryDirectory(binaryFileDirectory), m_RgbPack(rgbPack)
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker }
19*89c4ff92SAndroid Build Coastguard Worker 
GetTestCaseData(unsigned int testCaseId)20*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Cifar10Database::TTestCaseData> Cifar10Database::GetTestCaseData(unsigned int testCaseId)
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned char> I(g_kCifar10ImageByteSize);
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker     std::string fullpath = m_BinaryDirectory + std::string("test_batch.bin");
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker     std::ifstream fileStream(fullpath, std::ios::binary);
27*89c4ff92SAndroid Build Coastguard Worker     if (!fileStream.is_open())
28*89c4ff92SAndroid Build Coastguard Worker     {
29*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(fatal) << "Failed to load " << fullpath;
30*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
31*89c4ff92SAndroid Build Coastguard Worker     }
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     fileStream.seekg(testCaseId * g_kCifar10ImageByteSize, std::ios_base::beg);
34*89c4ff92SAndroid Build Coastguard Worker     fileStream.read(reinterpret_cast<char*>(&I[0]), g_kCifar10ImageByteSize);
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker     if (!fileStream.good())
37*89c4ff92SAndroid Build Coastguard Worker     {
38*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(fatal) << "Failed to read " << fullpath;
39*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
40*89c4ff92SAndroid Build Coastguard Worker     }
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputImageData;
44*89c4ff92SAndroid Build Coastguard Worker     inputImageData.resize(g_kCifar10ImageByteSize - 1);
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker     unsigned int step;
47*89c4ff92SAndroid Build Coastguard Worker     unsigned int countR_o;
48*89c4ff92SAndroid Build Coastguard Worker     unsigned int countG_o;
49*89c4ff92SAndroid Build Coastguard Worker     unsigned int countB_o;
50*89c4ff92SAndroid Build Coastguard Worker     unsigned int countR = 1;
51*89c4ff92SAndroid Build Coastguard Worker     unsigned int countG = 1 + 32 * 32;
52*89c4ff92SAndroid Build Coastguard Worker     unsigned int countB = 1 + 2 * 32 * 32;
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker     if (m_RgbPack)
55*89c4ff92SAndroid Build Coastguard Worker     {
56*89c4ff92SAndroid Build Coastguard Worker         countR_o = 0;
57*89c4ff92SAndroid Build Coastguard Worker         countG_o = 1;
58*89c4ff92SAndroid Build Coastguard Worker         countB_o = 2;
59*89c4ff92SAndroid Build Coastguard Worker         step = 3;
60*89c4ff92SAndroid Build Coastguard Worker     }
61*89c4ff92SAndroid Build Coastguard Worker     else
62*89c4ff92SAndroid Build Coastguard Worker     {
63*89c4ff92SAndroid Build Coastguard Worker         countR_o = 0;
64*89c4ff92SAndroid Build Coastguard Worker         countG_o = 32 * 32;
65*89c4ff92SAndroid Build Coastguard Worker         countB_o = 2 * 32 * 32;
66*89c4ff92SAndroid Build Coastguard Worker         step = 1;
67*89c4ff92SAndroid Build Coastguard Worker     }
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int h = 0; h < 32; h++)
70*89c4ff92SAndroid Build Coastguard Worker     {
71*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int w = 0; w < 32; w++)
72*89c4ff92SAndroid Build Coastguard Worker         {
73*89c4ff92SAndroid Build Coastguard Worker             // Static_cast of unsigned char is safe with float
74*89c4ff92SAndroid Build Coastguard Worker             inputImageData[countR_o] = static_cast<float>(I[countR++]);
75*89c4ff92SAndroid Build Coastguard Worker             inputImageData[countG_o] = static_cast<float>(I[countG++]);
76*89c4ff92SAndroid Build Coastguard Worker             inputImageData[countB_o] = static_cast<float>(I[countB++]);
77*89c4ff92SAndroid Build Coastguard Worker 
78*89c4ff92SAndroid Build Coastguard Worker             countR_o += step;
79*89c4ff92SAndroid Build Coastguard Worker             countG_o += step;
80*89c4ff92SAndroid Build Coastguard Worker             countB_o += step;
81*89c4ff92SAndroid Build Coastguard Worker         }
82*89c4ff92SAndroid Build Coastguard Worker     }
83*89c4ff92SAndroid Build Coastguard Worker 
84*89c4ff92SAndroid Build Coastguard Worker     const unsigned int label = armnn::numeric_cast<unsigned int>(I[0]);
85*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<TTestCaseData>(label, std::move(inputImageData));
86*89c4ff92SAndroid Build Coastguard Worker }
87