xref: /aosp_15_r20/hardware/interfaces/neuralnetworks/aidl/vts/functional/CompilationCachingTests.cpp (revision 4d7e907c777eeecc4c5bd7cf640a754fac206ff7)
1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_aidl_hal_test"
18 
19 #include <android-base/logging.h>
20 #include <android/binder_auto_utils.h>
21 #include <android/binder_interface_utils.h>
22 #include <android/binder_status.h>
23 #include <fcntl.h>
24 #include <ftw.h>
25 #include <gtest/gtest.h>
26 #include <unistd.h>
27 
28 #include <cstdio>
29 #include <cstdlib>
30 #include <iterator>
31 #include <random>
32 #include <thread>
33 
34 #include "Callbacks.h"
35 #include "GeneratedTestHarness.h"
36 #include "TestHarness.h"
37 #include "Utils.h"
38 #include "VtsHalNeuralnetworks.h"
39 
40 // Forward declaration of the mobilenet generated test models in
41 // frameworks/ml/nn/runtime/test/generated/.
42 namespace generated_tests::mobilenet_224_gender_basic_fixed {
43 const test_helper::TestModel& get_test_model();
44 }  // namespace generated_tests::mobilenet_224_gender_basic_fixed
45 
46 namespace generated_tests::mobilenet_quantized {
47 const test_helper::TestModel& get_test_model();
48 }  // namespace generated_tests::mobilenet_quantized
49 
50 namespace aidl::android::hardware::neuralnetworks::vts::functional {
51 
52 using namespace test_helper;
53 using implementation::PreparedModelCallback;
54 
55 namespace float32_model {
56 
57 constexpr auto get_test_model = generated_tests::mobilenet_224_gender_basic_fixed::get_test_model;
58 
59 }  // namespace float32_model
60 
61 namespace quant8_model {
62 
63 constexpr auto get_test_model = generated_tests::mobilenet_quantized::get_test_model;
64 
65 }  // namespace quant8_model
66 
67 namespace {
68 
// How createCacheFds() opens a cache file.
enum class AccessMode {
    READ_WRITE,  // O_RDWR, creating the file if it does not exist.
    READ_ONLY,   // O_RDONLY, the file must already exist.
    WRITE_ONLY,  // O_WRONLY, creating the file if it does not exist.
};
70 
71 // Creates cache handles based on provided file groups.
72 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
createCacheFds(const std::vector<std::string> & files,const std::vector<AccessMode> & mode,std::vector<ndk::ScopedFileDescriptor> * fds)73 void createCacheFds(const std::vector<std::string>& files, const std::vector<AccessMode>& mode,
74                     std::vector<ndk::ScopedFileDescriptor>* fds) {
75     fds->clear();
76     fds->reserve(files.size());
77     for (uint32_t i = 0; i < files.size(); i++) {
78         const auto& file = files[i];
79         int fd;
80         if (mode[i] == AccessMode::READ_ONLY) {
81             fd = open(file.c_str(), O_RDONLY);
82         } else if (mode[i] == AccessMode::WRITE_ONLY) {
83             fd = open(file.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
84         } else if (mode[i] == AccessMode::READ_WRITE) {
85             fd = open(file.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
86         } else {
87             FAIL();
88         }
89         ASSERT_GE(fd, 0);
90         fds->emplace_back(fd);
91     }
92 }
93 
createCacheFds(const std::vector<std::string> & files,AccessMode mode,std::vector<ndk::ScopedFileDescriptor> * fds)94 void createCacheFds(const std::vector<std::string>& files, AccessMode mode,
95                     std::vector<ndk::ScopedFileDescriptor>* fds) {
96     createCacheFds(files, std::vector<AccessMode>(files.size(), mode), fds);
97 }
98 
99 // Create a chain of broadcast operations. The second operand is always constant tensor [1].
100 // For simplicity, activation scalar is shared. The second operand is not shared
101 // in the model to let driver maintain a non-trivial size of constant data and the corresponding
102 // data locations in cache.
103 //
104 //                --------- activation --------
105 //                ↓      ↓      ↓             ↓
106 // E.g. input -> ADD -> ADD -> ADD -> ... -> ADD -> output
107 //                ↑      ↑      ↑             ↑
108 //               [1]    [1]    [1]           [1]
109 //
110 // This function assumes the operation is either ADD or MUL.
111 template <typename CppType, TestOperandType operandType>
createLargeTestModelImpl(TestOperationType op,uint32_t len)112 TestModel createLargeTestModelImpl(TestOperationType op, uint32_t len) {
113     EXPECT_TRUE(op == TestOperationType::ADD || op == TestOperationType::MUL);
114 
115     // Model operations and operands.
116     std::vector<TestOperation> operations(len);
117     std::vector<TestOperand> operands(len * 2 + 2);
118 
119     // The activation scalar, value = 0.
120     operands[0] = {
121             .type = TestOperandType::INT32,
122             .dimensions = {},
123             .numberOfConsumers = len,
124             .scale = 0.0f,
125             .zeroPoint = 0,
126             .lifetime = TestOperandLifeTime::CONSTANT_COPY,
127             .data = TestBuffer::createFromVector<int32_t>({0}),
128     };
129 
130     // The buffer value of the constant second operand. The logical value is always 1.0f.
131     CppType bufferValue;
132     // The scale of the first and second operand.
133     float scale1, scale2;
134     if (operandType == TestOperandType::TENSOR_FLOAT32) {
135         bufferValue = 1.0f;
136         scale1 = 0.0f;
137         scale2 = 0.0f;
138     } else if (op == TestOperationType::ADD) {
139         bufferValue = 1;
140         scale1 = 1.0f;
141         scale2 = 1.0f;
142     } else {
143         // To satisfy the constraint on quant8 MUL: input0.scale * input1.scale < output.scale,
144         // set input1 to have scale = 0.5f and bufferValue = 2, i.e. 1.0f in floating point.
145         bufferValue = 2;
146         scale1 = 1.0f;
147         scale2 = 0.5f;
148     }
149 
150     for (uint32_t i = 0; i < len; i++) {
151         const uint32_t firstInputIndex = i * 2 + 1;
152         const uint32_t secondInputIndex = firstInputIndex + 1;
153         const uint32_t outputIndex = secondInputIndex + 1;
154 
155         // The first operation input.
156         operands[firstInputIndex] = {
157                 .type = operandType,
158                 .dimensions = {1},
159                 .numberOfConsumers = 1,
160                 .scale = scale1,
161                 .zeroPoint = 0,
162                 .lifetime = (i == 0 ? TestOperandLifeTime::MODEL_INPUT
163                                     : TestOperandLifeTime::TEMPORARY_VARIABLE),
164                 .data = (i == 0 ? TestBuffer::createFromVector<CppType>({1}) : TestBuffer()),
165         };
166 
167         // The second operation input, value = 1.
168         operands[secondInputIndex] = {
169                 .type = operandType,
170                 .dimensions = {1},
171                 .numberOfConsumers = 1,
172                 .scale = scale2,
173                 .zeroPoint = 0,
174                 .lifetime = TestOperandLifeTime::CONSTANT_COPY,
175                 .data = TestBuffer::createFromVector<CppType>({bufferValue}),
176         };
177 
178         // The operation. All operations share the same activation scalar.
179         // The output operand is created as an input in the next iteration of the loop, in the case
180         // of all but the last member of the chain; and after the loop as a model output, in the
181         // case of the last member of the chain.
182         operations[i] = {
183                 .type = op,
184                 .inputs = {firstInputIndex, secondInputIndex, /*activation scalar*/ 0},
185                 .outputs = {outputIndex},
186         };
187     }
188 
189     // For TestOperationType::ADD, output = 1 + 1 * len = len + 1
190     // For TestOperationType::MUL, output = 1 * 1 ^ len = 1
191     CppType outputResult = static_cast<CppType>(op == TestOperationType::ADD ? len + 1u : 1u);
192 
193     // The model output.
194     operands.back() = {
195             .type = operandType,
196             .dimensions = {1},
197             .numberOfConsumers = 0,
198             .scale = scale1,
199             .zeroPoint = 0,
200             .lifetime = TestOperandLifeTime::MODEL_OUTPUT,
201             .data = TestBuffer::createFromVector<CppType>({outputResult}),
202     };
203 
204     return {
205             .main = {.operands = std::move(operands),
206                      .operations = std::move(operations),
207                      .inputIndexes = {1},
208                      .outputIndexes = {len * 2 + 1}},
209             .isRelaxed = false,
210     };
211 }
212 
213 }  // namespace
214 
215 // Tag for the compilation caching tests.
216 class CompilationCachingTestBase : public testing::Test {
217   protected:
CompilationCachingTestBase(std::shared_ptr<IDevice> device,OperandType type)218     CompilationCachingTestBase(std::shared_ptr<IDevice> device, OperandType type)
219         : kDevice(std::move(device)), kOperandType(type) {}
220 
SetUp()221     void SetUp() override {
222         testing::Test::SetUp();
223         ASSERT_NE(kDevice.get(), nullptr);
224         const bool deviceIsResponsive =
225                 ndk::ScopedAStatus::fromStatus(AIBinder_ping(kDevice->asBinder().get())).isOk();
226         ASSERT_TRUE(deviceIsResponsive);
227 
228         // Create cache directory. The cache directory and a temporary cache file is always created
229         // to test the behavior of prepareModelFromCache, even when caching is not supported.
230 #ifdef __ANDROID__
231         char cacheDirTemp[] = "/data/local/tmp/TestCompilationCachingXXXXXX";
232 #else   // __ANDROID__
233         char cacheDirTemp[] = "/tmp/TestCompilationCachingXXXXXX";
234 #endif  // __ANDROID__
235         char* cacheDir = mkdtemp(cacheDirTemp);
236         ASSERT_NE(cacheDir, nullptr);
237         mCacheDir = cacheDir;
238         mCacheDir.push_back('/');
239 
240         NumberOfCacheFiles numCacheFiles;
241         const auto ret = kDevice->getNumberOfCacheFilesNeeded(&numCacheFiles);
242         ASSERT_TRUE(ret.isOk());
243 
244         mNumModelCache = numCacheFiles.numModelCache;
245         mNumDataCache = numCacheFiles.numDataCache;
246         ASSERT_GE(mNumModelCache, 0) << "Invalid numModelCache: " << mNumModelCache;
247         ASSERT_GE(mNumDataCache, 0) << "Invalid numDataCache: " << mNumDataCache;
248         mIsCachingSupported = mNumModelCache > 0 || mNumDataCache > 0;
249 
250         // Create empty cache files.
251         mTmpCache = mCacheDir + "tmp";
252         for (uint32_t i = 0; i < mNumModelCache; i++) {
253             mModelCache.push_back({mCacheDir + "model" + std::to_string(i)});
254         }
255         for (uint32_t i = 0; i < mNumDataCache; i++) {
256             mDataCache.push_back({mCacheDir + "data" + std::to_string(i)});
257         }
258         // Placeholder handles, use AccessMode::WRITE_ONLY for createCacheFds to create files.
259         std::vector<ndk::ScopedFileDescriptor> modelHandle, dataHandle, tmpHandle;
260         createCacheFds(mModelCache, AccessMode::WRITE_ONLY, &modelHandle);
261         createCacheFds(mDataCache, AccessMode::WRITE_ONLY, &dataHandle);
262         createCacheFds({mTmpCache}, AccessMode::WRITE_ONLY, &tmpHandle);
263 
264         if (!mIsCachingSupported) {
265             LOG(INFO) << "NN VTS: Early termination of test because vendor service does not "
266                          "support compilation caching.";
267             std::cout << "[          ]   Early termination of test because vendor service does not "
268                          "support compilation caching."
269                       << std::endl;
270         }
271     }
272 
TearDown()273     void TearDown() override {
274         // If the test passes, remove the tmp directory.  Otherwise, keep it for debugging purposes.
275         if (!testing::Test::HasFailure()) {
276             // Recursively remove the cache directory specified by mCacheDir.
277             auto callback = [](const char* entry, const struct stat*, int, struct FTW*) {
278                 return remove(entry);
279             };
280             nftw(mCacheDir.c_str(), callback, 128, FTW_DEPTH | FTW_MOUNT | FTW_PHYS);
281         }
282         testing::Test::TearDown();
283     }
284 
285     // Model and examples creators. According to kOperandType, the following methods will return
286     // either float32 model/examples or the quant8 variant.
createTestModel()287     TestModel createTestModel() {
288         if (kOperandType == OperandType::TENSOR_FLOAT32) {
289             return float32_model::get_test_model();
290         } else {
291             return quant8_model::get_test_model();
292         }
293     }
294 
createLargeTestModel(OperationType op,uint32_t len)295     TestModel createLargeTestModel(OperationType op, uint32_t len) {
296         if (kOperandType == OperandType::TENSOR_FLOAT32) {
297             return createLargeTestModelImpl<float, TestOperandType::TENSOR_FLOAT32>(
298                     static_cast<TestOperationType>(op), len);
299         } else {
300             return createLargeTestModelImpl<uint8_t, TestOperandType::TENSOR_QUANT8_ASYMM>(
301                     static_cast<TestOperationType>(op), len);
302         }
303     }
304 
305     // See if the service can handle the model.
isModelFullySupported(const Model & model)306     bool isModelFullySupported(const Model& model) {
307         std::vector<bool> supportedOps;
308         const auto supportedCall = kDevice->getSupportedOperations(model, &supportedOps);
309         EXPECT_TRUE(supportedCall.isOk());
310         EXPECT_EQ(supportedOps.size(), model.main.operations.size());
311         if (!supportedCall.isOk() || supportedOps.size() != model.main.operations.size()) {
312             return false;
313         }
314         return std::all_of(supportedOps.begin(), supportedOps.end(),
315                            [](bool valid) { return valid; });
316     }
317 
saveModelToCache(const Model & model,const std::vector<ndk::ScopedFileDescriptor> & modelCache,const std::vector<ndk::ScopedFileDescriptor> & dataCache,std::shared_ptr<IPreparedModel> * preparedModel=nullptr)318     void saveModelToCache(const Model& model,
319                           const std::vector<ndk::ScopedFileDescriptor>& modelCache,
320                           const std::vector<ndk::ScopedFileDescriptor>& dataCache,
321                           std::shared_ptr<IPreparedModel>* preparedModel = nullptr) {
322         if (preparedModel != nullptr) *preparedModel = nullptr;
323 
324         // Launch prepare model.
325         std::shared_ptr<PreparedModelCallback> preparedModelCallback =
326                 ndk::SharedRefBase::make<PreparedModelCallback>();
327         std::vector<uint8_t> cacheToken(std::begin(mToken), std::end(mToken));
328         const auto prepareLaunchStatus = kDevice->prepareModel(
329                 model, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority, kNoDeadline,
330                 modelCache, dataCache, cacheToken, preparedModelCallback);
331         ASSERT_TRUE(prepareLaunchStatus.isOk());
332 
333         // Retrieve prepared model.
334         preparedModelCallback->wait();
335         ASSERT_EQ(preparedModelCallback->getStatus(), ErrorStatus::NONE);
336         if (preparedModel != nullptr) {
337             *preparedModel = preparedModelCallback->getPreparedModel();
338         }
339     }
340 
checkEarlyTermination(ErrorStatus status)341     bool checkEarlyTermination(ErrorStatus status) {
342         if (status == ErrorStatus::GENERAL_FAILURE) {
343             LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
344                          "save the prepared model that it does not support.";
345             std::cout << "[          ]   Early termination of test because vendor service cannot "
346                          "save the prepared model that it does not support."
347                       << std::endl;
348             return true;
349         }
350         return false;
351     }
352 
checkEarlyTermination(const Model & model)353     bool checkEarlyTermination(const Model& model) {
354         if (!isModelFullySupported(model)) {
355             LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
356                          "prepare model that it does not support.";
357             std::cout << "[          ]   Early termination of test because vendor service cannot "
358                          "prepare model that it does not support."
359                       << std::endl;
360             return true;
361         }
362         return false;
363     }
364 
365     // If fallbackModel is not provided, call prepareModelFromCache.
366     // If fallbackModel is provided, and prepareModelFromCache returns GENERAL_FAILURE,
367     // then prepareModel(fallbackModel) will be called.
368     // This replicates the behaviour of the runtime when loading a model from cache.
369     // NNAPI Shim depends on this behaviour and may try to load the model from cache in
370     // prepareModel (shim needs model information when loading from cache).
prepareModelFromCache(const std::vector<ndk::ScopedFileDescriptor> & modelCache,const std::vector<ndk::ScopedFileDescriptor> & dataCache,std::shared_ptr<IPreparedModel> * preparedModel,ErrorStatus * status,const Model * fallbackModel=nullptr)371     void prepareModelFromCache(const std::vector<ndk::ScopedFileDescriptor>& modelCache,
372                                const std::vector<ndk::ScopedFileDescriptor>& dataCache,
373                                std::shared_ptr<IPreparedModel>* preparedModel, ErrorStatus* status,
374                                const Model* fallbackModel = nullptr) {
375         // Launch prepare model from cache.
376         std::shared_ptr<PreparedModelCallback> preparedModelCallback =
377                 ndk::SharedRefBase::make<PreparedModelCallback>();
378         std::vector<uint8_t> cacheToken(std::begin(mToken), std::end(mToken));
379         auto prepareLaunchStatus = kDevice->prepareModelFromCache(
380                 kNoDeadline, modelCache, dataCache, cacheToken, preparedModelCallback);
381 
382         // The shim does not support prepareModelFromCache() properly, but it
383         // will still attempt to create a model from cache when modelCache or
384         // dataCache is provided in prepareModel(). Instead of failing straight
385         // away, we try to utilize that other code path when fallbackModel is
386         // set. Note that we cannot verify whether the returned model was
387         // actually prepared from cache in that case.
388         if (!prepareLaunchStatus.isOk() &&
389             prepareLaunchStatus.getExceptionCode() == EX_SERVICE_SPECIFIC &&
390             static_cast<ErrorStatus>(prepareLaunchStatus.getServiceSpecificError()) ==
391                     ErrorStatus::GENERAL_FAILURE &&
392             mIsCachingSupported && fallbackModel != nullptr) {
393             preparedModelCallback = ndk::SharedRefBase::make<PreparedModelCallback>();
394             prepareLaunchStatus = kDevice->prepareModel(
395                     *fallbackModel, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority,
396                     kNoDeadline, modelCache, dataCache, cacheToken, preparedModelCallback);
397         }
398 
399         ASSERT_TRUE(prepareLaunchStatus.isOk() ||
400                     prepareLaunchStatus.getExceptionCode() == EX_SERVICE_SPECIFIC)
401                 << "prepareLaunchStatus: " << prepareLaunchStatus.getDescription();
402         if (!prepareLaunchStatus.isOk()) {
403             *preparedModel = nullptr;
404             *status = static_cast<ErrorStatus>(prepareLaunchStatus.getServiceSpecificError());
405             return;
406         }
407 
408         // Retrieve prepared model.
409         preparedModelCallback->wait();
410         *status = preparedModelCallback->getStatus();
411         *preparedModel = preparedModelCallback->getPreparedModel();
412     }
413 
414     // Replicate behaviour of runtime when loading model from cache.
415     // Test if prepareModelFromCache behaves correctly when faced with bad
416     // arguments. If prepareModelFromCache is not supported (GENERAL_FAILURE),
417     // it attempts to call prepareModel with same arguments, which is expected either
418     // to not support the model (GENERAL_FAILURE) or return a valid model.
verifyModelPreparationBehaviour(const std::vector<ndk::ScopedFileDescriptor> & modelCache,const std::vector<ndk::ScopedFileDescriptor> & dataCache,const Model * model,const TestModel & testModel)419     void verifyModelPreparationBehaviour(const std::vector<ndk::ScopedFileDescriptor>& modelCache,
420                                          const std::vector<ndk::ScopedFileDescriptor>& dataCache,
421                                          const Model* model, const TestModel& testModel) {
422         std::shared_ptr<IPreparedModel> preparedModel;
423         ErrorStatus status;
424 
425         // Verify that prepareModelFromCache fails either due to bad
426         // arguments (INVALID_ARGUMENT) or GENERAL_FAILURE if not supported.
427         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status,
428                               /*fallbackModel=*/nullptr);
429         if (status != ErrorStatus::INVALID_ARGUMENT) {
430             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
431         }
432         ASSERT_EQ(preparedModel, nullptr);
433 
434         // If caching is not supported, attempt calling prepareModel.
435         if (status == ErrorStatus::GENERAL_FAILURE) {
436             // Fallback with prepareModel should succeed regardless of cache files
437             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status,
438                                   /*fallbackModel=*/model);
439             // Unless caching is not supported?
440             if (status != ErrorStatus::GENERAL_FAILURE) {
441                 // But if it is, we should see a valid model.
442                 ASSERT_EQ(status, ErrorStatus::NONE);
443                 ASSERT_NE(preparedModel, nullptr);
444                 EvaluatePreparedModel(kDevice, preparedModel, testModel,
445                                       /*testKind=*/TestKind::GENERAL);
446             }
447         }
448     }
449 
450     // Absolute path to the temporary cache directory.
451     std::string mCacheDir;
452 
453     // Groups of file paths for model and data cache in the tmp cache directory, initialized with
454     // size = mNum{Model|Data}Cache. The outer vector corresponds to handles and the inner vector is
455     // for fds held by each handle.
456     std::vector<std::string> mModelCache;
457     std::vector<std::string> mDataCache;
458 
459     // A separate temporary file path in the tmp cache directory.
460     std::string mTmpCache;
461 
462     uint8_t mToken[static_cast<uint32_t>(IDevice::BYTE_SIZE_OF_CACHE_TOKEN)] = {};
463     uint32_t mNumModelCache;
464     uint32_t mNumDataCache;
465     bool mIsCachingSupported;
466 
467     const std::shared_ptr<IDevice> kDevice;
468     // The primary data type of the testModel.
469     const OperandType kOperandType;
470 };
471 
472 using CompilationCachingTestParam = std::tuple<NamedDevice, OperandType>;
473 
474 // A parameterized fixture of CompilationCachingTestBase. Every test will run twice, with the first
475 // pass running with float32 models and the second pass running with quant8 models.
476 class CompilationCachingTest : public CompilationCachingTestBase,
477                                public testing::WithParamInterface<CompilationCachingTestParam> {
478   protected:
CompilationCachingTest()479     CompilationCachingTest()
480         : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
481                                      std::get<OperandType>(GetParam())) {}
482 };
483 
// Saves a compilation to cache, loads it back from cache and checks that the loaded model
// produces the expected results.
TEST_P(CompilationCachingTest, CacheSavingAndRetrieval) {
    // Build the test model; stop here if the driver cannot run all of it.
    const TestModel testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Step 1: compile the model with the cache files attached.
    {
        std::vector<ndk::ScopedFileDescriptor> modelFds, dataFds;
        createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelFds);
        createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataFds);
        saveModelToCache(model, modelFds, dataFds);
    }

    // Step 2: reopen the cache files and request the prepared model from cache.
    std::shared_ptr<IPreparedModel> preparedModel;
    {
        ErrorStatus status;
        std::vector<ndk::ScopedFileDescriptor> modelFds, dataFds;
        createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelFds);
        createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataFds);
        prepareModelFromCache(modelFds, dataFds, &preparedModel, &status,
                              /*fallbackModel=*/&model);

        // Without caching support the request has to fail.
        if (!mIsCachingSupported) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            ASSERT_EQ(preparedModel, nullptr);
            return;
        }
        // With caching support a GENERAL_FAILURE ends the test early.
        if (checkEarlyTermination(status)) {
            ASSERT_EQ(preparedModel, nullptr);
            return;
        }
        ASSERT_EQ(status, ErrorStatus::NONE);
        ASSERT_NE(preparedModel, nullptr);
    }

    // Step 3: execute and verify results.
    EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
}
524 
TEST_P(CompilationCachingTest,CacheSavingAndRetrievalNonZeroOffset)525 TEST_P(CompilationCachingTest, CacheSavingAndRetrievalNonZeroOffset) {
526     // Create test HIDL model and compile.
527     const TestModel& testModel = createTestModel();
528     const Model model = createModel(testModel);
529     if (checkEarlyTermination(model)) return;
530     std::shared_ptr<IPreparedModel> preparedModel = nullptr;
531 
532     // Save the compilation to cache.
533     {
534         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
535         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
536         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
537         uint8_t placeholderBytes[] = {0, 0};
538         // Write a placeholder integer to the cache.
539         // The driver should be able to handle non-empty cache and non-zero fd offset.
540         for (uint32_t i = 0; i < modelCache.size(); i++) {
541             ASSERT_EQ(write(modelCache[i].get(), &placeholderBytes, sizeof(placeholderBytes)),
542                       sizeof(placeholderBytes));
543         }
544         for (uint32_t i = 0; i < dataCache.size(); i++) {
545             ASSERT_EQ(write(dataCache[i].get(), &placeholderBytes, sizeof(placeholderBytes)),
546                       sizeof(placeholderBytes));
547         }
548         saveModelToCache(model, modelCache, dataCache);
549     }
550 
551     // Retrieve preparedModel from cache.
552     {
553         preparedModel = nullptr;
554         ErrorStatus status;
555         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
556         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
557         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
558         uint8_t placeholderByte = 0;
559         // Advance the offset of each handle by one byte.
560         // The driver should be able to handle non-zero fd offset.
561         for (uint32_t i = 0; i < modelCache.size(); i++) {
562             ASSERT_GE(read(modelCache[i].get(), &placeholderByte, 1), 0);
563         }
564         for (uint32_t i = 0; i < dataCache.size(); i++) {
565             ASSERT_GE(read(dataCache[i].get(), &placeholderByte, 1), 0);
566         }
567         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status,
568                               /*fallbackModel=*/&model);
569         if (!mIsCachingSupported) {
570             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
571             ASSERT_EQ(preparedModel, nullptr);
572             return;
573         } else if (checkEarlyTermination(status)) {
574             ASSERT_EQ(preparedModel, nullptr);
575             return;
576         } else {
577             ASSERT_EQ(status, ErrorStatus::NONE);
578             ASSERT_NE(preparedModel, nullptr);
579         }
580     }
581 
582     // Execute and verify results.
583     EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
584 }
585 
TEST_P(CompilationCachingTest,SaveToCacheInvalidNumCache)586 TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) {
587     // Create test HIDL model and compile.
588     const TestModel& testModel = createTestModel();
589     const Model model = createModel(testModel);
590     if (checkEarlyTermination(model)) return;
591 
592     // Test with number of model cache files greater than mNumModelCache.
593     {
594         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
595         // Pass an additional cache file for model cache.
596         mModelCache.push_back({mTmpCache});
597         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
598         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
599         mModelCache.pop_back();
600         std::shared_ptr<IPreparedModel> preparedModel = nullptr;
601         saveModelToCache(model, modelCache, dataCache, &preparedModel);
602         ASSERT_NE(preparedModel, nullptr);
603         // Execute and verify results.
604         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
605         // Check if prepareModelFromCache fails.
606         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
607     }
608 
609     // Test with number of model cache files smaller than mNumModelCache.
610     if (mModelCache.size() > 0) {
611         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
612         // Pop out the last cache file.
613         auto tmp = mModelCache.back();
614         mModelCache.pop_back();
615         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
616         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
617         mModelCache.push_back(tmp);
618         std::shared_ptr<IPreparedModel> preparedModel = nullptr;
619         saveModelToCache(model, modelCache, dataCache, &preparedModel);
620         ASSERT_NE(preparedModel, nullptr);
621         // Execute and verify results.
622         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
623         // Check if prepareModelFromCache fails.
624         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
625     }
626 
627     // Test with number of data cache files greater than mNumDataCache.
628     {
629         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
630         // Pass an additional cache file for data cache.
631         mDataCache.push_back({mTmpCache});
632         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
633         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
634         mDataCache.pop_back();
635         std::shared_ptr<IPreparedModel> preparedModel = nullptr;
636         saveModelToCache(model, modelCache, dataCache, &preparedModel);
637         ASSERT_NE(preparedModel, nullptr);
638         // Execute and verify results.
639         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
640         // Check if prepareModelFromCache fails.
641         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
642     }
643 
644     // Test with number of data cache files smaller than mNumDataCache.
645     if (mDataCache.size() > 0) {
646         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
647         // Pop out the last cache file.
648         auto tmp = mDataCache.back();
649         mDataCache.pop_back();
650         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
651         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
652         mDataCache.push_back(tmp);
653         std::shared_ptr<IPreparedModel> preparedModel = nullptr;
654         saveModelToCache(model, modelCache, dataCache, &preparedModel);
655         ASSERT_NE(preparedModel, nullptr);
656         // Execute and verify results.
657         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
658         // Check if prepareModelFromCache fails.
659         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
660     }
661 }
662 
TEST_P(CompilationCachingTest,PrepareModelFromCacheInvalidNumCache)663 TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumCache) {
664     // Create test HIDL model and compile.
665     const TestModel& testModel = createTestModel();
666     const Model model = createModel(testModel);
667     if (checkEarlyTermination(model)) return;
668 
669     // Save the compilation to cache.
670     {
671         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
672         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
673         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
674         saveModelToCache(model, modelCache, dataCache);
675     }
676 
677     // Test with number of model cache files greater than mNumModelCache.
678     {
679         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
680         mModelCache.push_back({mTmpCache});
681         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
682         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
683         mModelCache.pop_back();
684 
685         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
686     }
687 
688     // Test with number of model cache files smaller than mNumModelCache.
689     if (mModelCache.size() > 0) {
690         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
691         auto tmp = mModelCache.back();
692         mModelCache.pop_back();
693         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
694         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
695         mModelCache.push_back(tmp);
696 
697         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
698     }
699 
700     // Test with number of data cache files greater than mNumDataCache.
701     {
702         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
703         mDataCache.push_back({mTmpCache});
704         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
705         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
706         mDataCache.pop_back();
707 
708         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
709     }
710 
711     // Test with number of data cache files smaller than mNumDataCache.
712     if (mDataCache.size() > 0) {
713         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
714         auto tmp = mDataCache.back();
715         mDataCache.pop_back();
716         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
717         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
718         mDataCache.push_back(tmp);
719 
720         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
721     }
722 }
723 
TEST_P(CompilationCachingTest, SaveToCacheInvalidAccessMode) {
    // Build the model under test; stop here if checkEarlyTermination() says so.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) {
        return;
    }
    // Baseline access modes: every cache file is opened read-write.
    std::vector<AccessMode> modelModesReadWrite(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataModesReadWrite(mNumDataCache, AccessMode::READ_WRITE);

    // Open exactly one model cache file read-only per iteration and try to save with it.
    for (uint32_t index = 0; index < mNumModelCache; ++index) {
        std::vector<AccessMode> modelModes = modelModesReadWrite;
        modelModes[index] = AccessMode::READ_ONLY;

        std::vector<ndk::ScopedFileDescriptor> modelFds, dataFds;
        createCacheFds(mModelCache, modelModes, &modelFds);
        createCacheFds(mDataCache, dataModesReadWrite, &dataFds);

        std::shared_ptr<IPreparedModel> compiled = nullptr;
        saveModelToCache(model, modelFds, dataFds, &compiled);
        ASSERT_NE(compiled, nullptr);
        // The freshly prepared model must still execute with correct results.
        EvaluatePreparedModel(kDevice, compiled, testModel, /*testKind=*/TestKind::GENERAL);
        // Then check how preparing from this cache behaves.
        verifyModelPreparationBehaviour(modelFds, dataFds, &model, testModel);
    }

    // Same procedure with exactly one data cache file opened read-only per iteration.
    for (uint32_t index = 0; index < mNumDataCache; ++index) {
        std::vector<AccessMode> dataModes = dataModesReadWrite;
        dataModes[index] = AccessMode::READ_ONLY;

        std::vector<ndk::ScopedFileDescriptor> modelFds, dataFds;
        createCacheFds(mModelCache, modelModesReadWrite, &modelFds);
        createCacheFds(mDataCache, dataModes, &dataFds);

        std::shared_ptr<IPreparedModel> compiled = nullptr;
        saveModelToCache(model, modelFds, dataFds, &compiled);
        ASSERT_NE(compiled, nullptr);
        // The freshly prepared model must still execute with correct results.
        EvaluatePreparedModel(kDevice, compiled, testModel, /*testKind=*/TestKind::GENERAL);
        // Then check how preparing from this cache behaves.
        verifyModelPreparationBehaviour(modelFds, dataFds, &model, testModel);
    }
}
764 
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidAccessMode) {
    // Build the model under test; stop here if checkEarlyTermination() says so.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) {
        return;
    }
    // Baseline access modes: every cache file is opened read-write.
    std::vector<AccessMode> modelModesReadWrite(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataModesReadWrite(mNumDataCache, AccessMode::READ_WRITE);

    // Store a compilation using read-write descriptors for every cache file.
    {
        std::vector<ndk::ScopedFileDescriptor> modelFds, dataFds;
        createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelFds);
        createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataFds);
        saveModelToCache(model, modelFds, dataFds);
    }

    // Open exactly one model cache file write-only per iteration and verify the preparation.
    for (uint32_t index = 0; index < mNumModelCache; ++index) {
        std::vector<AccessMode> modelModes = modelModesReadWrite;
        modelModes[index] = AccessMode::WRITE_ONLY;

        std::vector<ndk::ScopedFileDescriptor> modelFds, dataFds;
        createCacheFds(mModelCache, modelModes, &modelFds);
        createCacheFds(mDataCache, dataModesReadWrite, &dataFds);

        verifyModelPreparationBehaviour(modelFds, dataFds, &model, testModel);
    }

    // Same procedure with exactly one data cache file opened write-only per iteration.
    for (uint32_t index = 0; index < mNumDataCache; ++index) {
        std::vector<AccessMode> dataModes = dataModesReadWrite;
        dataModes[index] = AccessMode::WRITE_ONLY;

        std::vector<ndk::ScopedFileDescriptor> modelFds, dataFds;
        createCacheFds(mModelCache, modelModesReadWrite, &modelFds);
        createCacheFds(mDataCache, dataModes, &dataFds);

        verifyModelPreparationBehaviour(modelFds, dataFds, &model, testModel);
    }
}
802 
803 // Copy file contents between files.
804 // The vector sizes must match.
copyCacheFiles(const std::vector<std::string> & from,const std::vector<std::string> & to)805 static void copyCacheFiles(const std::vector<std::string>& from,
806                            const std::vector<std::string>& to) {
807     constexpr size_t kBufferSize = 1000000;
808     uint8_t buffer[kBufferSize];
809 
810     ASSERT_EQ(from.size(), to.size());
811     for (uint32_t i = 0; i < from.size(); i++) {
812         int fromFd = open(from[i].c_str(), O_RDONLY);
813         int toFd = open(to[i].c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
814         ASSERT_GE(fromFd, 0);
815         ASSERT_GE(toFd, 0);
816 
817         ssize_t readBytes;
818         while ((readBytes = read(fromFd, &buffer, kBufferSize)) > 0) {
819             ASSERT_EQ(write(toFd, &buffer, readBytes), readBytes);
820         }
821         ASSERT_GE(readBytes, 0);
822 
823         close(fromFd);
824         close(toFd);
825     }
826 }
827 
828 // Number of operations in the large test model.
829 constexpr uint32_t kLargeModelSize = 100;
// Number of times each TOCTOU test repeats its loop; those tests are probabilistic.
830 constexpr uint32_t kNumIterationsTOCTOU = 100;
831 
TEST_P(CompilationCachingTest,SaveToCache_TOCTOU)832 TEST_P(CompilationCachingTest, SaveToCache_TOCTOU) {
833     if (!mIsCachingSupported) return;
834 
835     // Create test models and check if fully supported by the service.
836     const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
837     const Model modelMul = createModel(testModelMul);
838     if (checkEarlyTermination(modelMul)) return;
839     const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
840     const Model modelAdd = createModel(testModelAdd);
841     if (checkEarlyTermination(modelAdd)) return;
842 
843     // Save the modelMul compilation to cache.
844     auto modelCacheMul = mModelCache;
845     for (auto& cache : modelCacheMul) {
846         cache.append("_mul");
847     }
848     {
849         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
850         createCacheFds(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
851         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
852         saveModelToCache(modelMul, modelCache, dataCache);
853     }
854 
855     // Use a different token for modelAdd.
856     mToken[0]++;
857 
858     // This test is probabilistic, so we run it multiple times.
859     for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
860         // Save the modelAdd compilation to cache.
861         {
862             std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
863             createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
864             createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
865 
866             // Spawn a thread to copy the cache content concurrently while saving to cache.
867             std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
868             saveModelToCache(modelAdd, modelCache, dataCache);
869             thread.join();
870         }
871 
872         // Retrieve preparedModel from cache.
873         {
874             std::shared_ptr<IPreparedModel> preparedModel = nullptr;
875             ErrorStatus status;
876             std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
877             createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
878             createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
879             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status,
880                                   /*fallbackModel=*/nullptr);
881 
882             // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
883             // the prepared model must be executed with the correct result and not crash.
884             if (status != ErrorStatus::NONE) {
885                 ASSERT_EQ(preparedModel, nullptr);
886             } else {
887                 ASSERT_NE(preparedModel, nullptr);
888                 EvaluatePreparedModel(kDevice, preparedModel, testModelAdd,
889                                       /*testKind=*/TestKind::GENERAL);
890             }
891         }
892     }
893 }
894 
TEST_P(CompilationCachingTest,PrepareFromCache_TOCTOU)895 TEST_P(CompilationCachingTest, PrepareFromCache_TOCTOU) {
896     if (!mIsCachingSupported) return;
897 
898     // Create test models and check if fully supported by the service.
899     const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
900     const Model modelMul = createModel(testModelMul);
901     if (checkEarlyTermination(modelMul)) return;
902     const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
903     const Model modelAdd = createModel(testModelAdd);
904     if (checkEarlyTermination(modelAdd)) return;
905 
906     // Save the modelMul compilation to cache.
907     auto modelCacheMul = mModelCache;
908     for (auto& cache : modelCacheMul) {
909         cache.append("_mul");
910     }
911     {
912         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
913         createCacheFds(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
914         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
915         saveModelToCache(modelMul, modelCache, dataCache);
916     }
917 
918     // Use a different token for modelAdd.
919     mToken[0]++;
920 
921     // This test is probabilistic, so we run it multiple times.
922     for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
923         // Save the modelAdd compilation to cache.
924         {
925             std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
926             createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
927             createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
928             saveModelToCache(modelAdd, modelCache, dataCache);
929         }
930 
931         // Retrieve preparedModel from cache.
932         {
933             std::shared_ptr<IPreparedModel> preparedModel = nullptr;
934             ErrorStatus status;
935             std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
936             createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
937             createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
938 
939             // Spawn a thread to copy the cache content concurrently while preparing from cache.
940             std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
941             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status,
942                                   /*fallbackModel=*/nullptr);
943             thread.join();
944 
945             // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
946             // the prepared model must be executed with the correct result and not crash.
947             if (status != ErrorStatus::NONE) {
948                 ASSERT_EQ(preparedModel, nullptr);
949             } else {
950                 ASSERT_NE(preparedModel, nullptr);
951                 EvaluatePreparedModel(kDevice, preparedModel, testModelAdd,
952                                       /*testKind=*/TestKind::GENERAL);
953             }
954         }
955     }
956 }
957 
TEST_P(CompilationCachingTest, ReplaceSecuritySensitiveCache) {
    if (!mIsCachingSupported) {
        return;
    }

    // Build a large MUL model and a large ADD model; stop if checkEarlyTermination() says so
    // for either of them.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) {
        return;
    }
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) {
        return;
    }

    // The MUL compilation goes to its own set of model cache files, named "<file>_mul".
    auto mulModelFiles = mModelCache;
    for (auto& fileName : mulModelFiles) {
        fileName += "_mul";
    }
    {
        std::vector<ndk::ScopedFileDescriptor> modelFds, dataFds;
        createCacheFds(mulModelFiles, AccessMode::READ_WRITE, &modelFds);
        createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataFds);
        saveModelToCache(modelMul, modelFds, dataFds);
    }

    // modelAdd is cached under a different token.
    ++mToken[0];

    // Store the ADD compilation in the regular cache files.
    {
        std::vector<ndk::ScopedFileDescriptor> modelFds, dataFds;
        createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelFds);
        createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataFds);
        saveModelToCache(modelAdd, modelFds, dataFds);
    }

    // Swap in the MUL model cache content underneath the ADD token.
    copyCacheFiles(mulModelFiles, mModelCache);

    // Preparing from the tampered cache has to be rejected with GENERAL_FAILURE.
    {
        std::shared_ptr<IPreparedModel> fromCache = nullptr;
        ErrorStatus status;
        std::vector<ndk::ScopedFileDescriptor> modelFds, dataFds;
        createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelFds);
        createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataFds);
        prepareModelFromCache(modelFds, dataFds, &fromCache, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(fromCache, nullptr);
    }
}
1007 
1008 // TODO(b/179270601): restore kNamedDeviceChoices.
// Operand types that both caching test suites in this file are instantiated with.
1009 static const auto kOperandTypeChoices =
1010         testing::Values(OperandType::TENSOR_FLOAT32, OperandType::TENSOR_QUANT8_ASYMM);
1011 
printCompilationCachingTest(const testing::TestParamInfo<CompilationCachingTestParam> & info)1012 std::string printCompilationCachingTest(
1013         const testing::TestParamInfo<CompilationCachingTestParam>& info) {
1014     const auto& [namedDevice, operandType] = info.param;
1015     const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1016     return gtestCompliantName(getName(namedDevice) + "_" + type);
1017 }
1018 
// Instantiate CompilationCachingTest for every combination of a device from getNamedDevices()
// and an operand type from kOperandTypeChoices. Instances are named by
// printCompilationCachingTest. The suite is explicitly allowed to end up without instances.
1019 GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CompilationCachingTest);
1020 INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, CompilationCachingTest,
1021                          testing::Combine(testing::ValuesIn(getNamedDevices()),
1022                                           kOperandTypeChoices),
1023                          printCompilationCachingTest);
1024 
// Parameter of the security tests: the device, the operand type, and a uint32_t that the
// fixture uses as the seed of its random generator.
1025 using CompilationCachingSecurityTestParam = std::tuple<NamedDevice, OperandType, uint32_t>;
1026 
1027 class CompilationCachingSecurityTest
1028     : public CompilationCachingTestBase,
1029       public testing::WithParamInterface<CompilationCachingSecurityTestParam> {
1030   protected:
CompilationCachingSecurityTest()1031     CompilationCachingSecurityTest()
1032         : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
1033                                      std::get<OperandType>(GetParam())) {}
1034 
SetUp()1035     void SetUp() {
1036         CompilationCachingTestBase::SetUp();
1037         generator.seed(kSeed);
1038     }
1039 
1040     // Get a random integer within a closed range [lower, upper].
1041     template <typename T>
getRandomInt(T lower,T upper)1042     T getRandomInt(T lower, T upper) {
1043         std::uniform_int_distribution<T> dis(lower, upper);
1044         return dis(generator);
1045     }
1046 
1047     // Randomly flip one single bit of the cache entry.
flipOneBitOfCache(const std::string & filename,bool * skip)1048     void flipOneBitOfCache(const std::string& filename, bool* skip) {
1049         FILE* pFile = fopen(filename.c_str(), "r+");
1050         ASSERT_EQ(fseek(pFile, 0, SEEK_END), 0);
1051         long int fileSize = ftell(pFile);
1052         if (fileSize == 0) {
1053             fclose(pFile);
1054             *skip = true;
1055             return;
1056         }
1057         ASSERT_EQ(fseek(pFile, getRandomInt(0l, fileSize - 1), SEEK_SET), 0);
1058         int readByte = fgetc(pFile);
1059         ASSERT_NE(readByte, EOF);
1060         ASSERT_EQ(fseek(pFile, -1, SEEK_CUR), 0);
1061         ASSERT_NE(fputc(static_cast<uint8_t>(readByte) ^ (1U << getRandomInt(0, 7)), pFile), EOF);
1062         fclose(pFile);
1063         *skip = false;
1064     }
1065 
1066     // Randomly append bytes to the cache entry.
appendBytesToCache(const std::string & filename,bool * skip)1067     void appendBytesToCache(const std::string& filename, bool* skip) {
1068         FILE* pFile = fopen(filename.c_str(), "a");
1069         uint32_t appendLength = getRandomInt(1, 256);
1070         for (uint32_t i = 0; i < appendLength; i++) {
1071             ASSERT_NE(fputc(getRandomInt<uint16_t>(0, 255), pFile), EOF);
1072         }
1073         fclose(pFile);
1074         *skip = false;
1075     }
1076 
1077     enum class ExpectedResult { GENERAL_FAILURE, NOT_CRASH };
1078 
1079     // Test if the driver behaves as expected when given corrupted cache or token.
1080     // The modifier will be invoked after save to cache but before prepare from cache.
1081     // The modifier accepts one pointer argument "skip" as the returning value, indicating
1082     // whether the test should be skipped or not.
testCorruptedCache(ExpectedResult expected,std::function<void (bool *)> modifier)1083     void testCorruptedCache(ExpectedResult expected, std::function<void(bool*)> modifier) {
1084         const TestModel& testModel = createTestModel();
1085         const Model model = createModel(testModel);
1086         if (checkEarlyTermination(model)) return;
1087 
1088         // Save the compilation to cache.
1089         {
1090             std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
1091             createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
1092             createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
1093             saveModelToCache(model, modelCache, dataCache);
1094         }
1095 
1096         bool skip = false;
1097         modifier(&skip);
1098         if (skip) return;
1099 
1100         // Retrieve preparedModel from cache.
1101         {
1102             std::shared_ptr<IPreparedModel> preparedModel = nullptr;
1103             ErrorStatus status;
1104             std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
1105             createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
1106             createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
1107             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1108 
1109             switch (expected) {
1110                 case ExpectedResult::GENERAL_FAILURE:
1111                     ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
1112                     ASSERT_EQ(preparedModel, nullptr);
1113                     break;
1114                 case ExpectedResult::NOT_CRASH:
1115                     ASSERT_EQ(preparedModel == nullptr, status != ErrorStatus::NONE);
1116                     break;
1117                 default:
1118                     FAIL();
1119             }
1120         }
1121     }
1122 
1123     const uint32_t kSeed = std::get<uint32_t>(GetParam());
1124     std::mt19937 generator;
1125 };
1126 
TEST_P(CompilationCachingSecurityTest, CorruptedModelCache) {
    if (!mIsCachingSupported) {
        return;
    }
    // Flip one bit in each model cache file in turn; preparing from cache must then fail with
    // GENERAL_FAILURE.
    for (uint32_t fileIndex = 0; fileIndex < mNumModelCache; ++fileIndex) {
        const auto corruptFile = [this, fileIndex](bool* skip) {
            flipOneBitOfCache(mModelCache[fileIndex], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corruptFile);
    }
}
1134 
TEST_P(CompilationCachingSecurityTest, WrongLengthModelCache) {
    if (!mIsCachingSupported) {
        return;
    }
    // Append random bytes to each model cache file in turn; preparing from cache must then
    // fail with GENERAL_FAILURE.
    for (uint32_t fileIndex = 0; fileIndex < mNumModelCache; ++fileIndex) {
        const auto extendFile = [this, fileIndex](bool* skip) {
            appendBytesToCache(mModelCache[fileIndex], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, extendFile);
    }
}
1142 
TEST_P(CompilationCachingSecurityTest, CorruptedDataCache) {
    if (!mIsCachingSupported) {
        return;
    }
    // Flip one bit in each data cache file in turn; preparing from cache may succeed or fail
    // (NOT_CRASH expectation).
    for (uint32_t fileIndex = 0; fileIndex < mNumDataCache; ++fileIndex) {
        const auto corruptFile = [this, fileIndex](bool* skip) {
            flipOneBitOfCache(mDataCache[fileIndex], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, corruptFile);
    }
}
1150 
TEST_P(CompilationCachingSecurityTest, WrongLengthDataCache) {
    if (!mIsCachingSupported) {
        return;
    }
    // Append random bytes to each data cache file in turn; preparing from cache may succeed or
    // fail (NOT_CRASH expectation).
    for (uint32_t fileIndex = 0; fileIndex < mNumDataCache; ++fileIndex) {
        const auto extendFile = [this, fileIndex](bool* skip) {
            appendBytesToCache(mDataCache[fileIndex], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, extendFile);
    }
}
1158 
TEST_P(CompilationCachingSecurityTest, WrongToken) {
    if (!mIsCachingSupported) {
        return;
    }
    // Between save and prepare, toggle one random bit of one random token byte; preparing
    // from cache must then fail with GENERAL_FAILURE.
    const auto corruptToken = [this](bool* skip) {
        const uint32_t lastIndex = static_cast<uint32_t>(IDevice::BYTE_SIZE_OF_CACHE_TOKEN) - 1;
        const uint32_t byteIndex = getRandomInt(0u, lastIndex);
        mToken[byteIndex] ^= (1U << getRandomInt(0, 7));
        *skip = false;
    };
    testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corruptToken);
}
1169 
printCompilationCachingSecurityTest(const testing::TestParamInfo<CompilationCachingSecurityTestParam> & info)1170 std::string printCompilationCachingSecurityTest(
1171         const testing::TestParamInfo<CompilationCachingSecurityTestParam>& info) {
1172     const auto& [namedDevice, operandType, seed] = info.param;
1173     const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1174     return gtestCompliantName(getName(namedDevice) + "_" + type + "_" + std::to_string(seed));
1175 }
1176 
// Instantiate CompilationCachingSecurityTest for every combination of a device from
// getNamedDevices(), an operand type from kOperandTypeChoices and a seed in [0, 10).
// Instances are named by printCompilationCachingSecurityTest. The suite is explicitly allowed
// to end up without instances.
1177 GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CompilationCachingSecurityTest);
1178 INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, CompilationCachingSecurityTest,
1179                          testing::Combine(testing::ValuesIn(getNamedDevices()), kOperandTypeChoices,
1180                                           testing::Range(0U, 10U)),
1181                          printCompilationCachingSecurityTest);
1182 
1183 }  // namespace aidl::android::hardware::neuralnetworks::vts::functional
1184