xref: /aosp_15_r20/hardware/interfaces/drm/aidl/vts/drm_hal_common.cpp (revision 4d7e907c777eeecc4c5bd7cf640a754fac206ff7)
1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "drm_hal_common"
18 
#include <gtest/gtest.h>
#include <log/log.h>
#include <openssl/aes.h>
#include <sys/mman.h>
#include <unistd.h>

#include <algorithm>
#include <cstring>
#include <memory>
#include <random>

#include <aidlcommonsupport/NativeHandle.h>
#include <android/binder_manager.h>
#include <android/binder_process.h>
#include <android/sharedmem.h>
#include <cutils/native_handle.h>
#include <cutils/properties.h>

#include "drm_hal_clearkey_module.h"
#include "drm_hal_common.h"
34 
35 namespace aidl {
36 namespace android {
37 namespace hardware {
38 namespace drm {
39 namespace vts {
40 
41 namespace clearkeydrm = ::android::hardware::drm::V1_2::vts;
42 
43 using std::vector;
44 using ::aidl::android::hardware::common::Ashmem;
45 using ::aidl::android::hardware::drm::DecryptArgs;
46 using ::aidl::android::hardware::drm::DestinationBuffer;
47 using ::aidl::android::hardware::drm::EventType;
48 using ::aidl::android::hardware::drm::ICryptoPlugin;
49 using ::aidl::android::hardware::drm::IDrmPlugin;
50 using ::aidl::android::hardware::drm::KeyRequest;
51 using ::aidl::android::hardware::drm::KeyRequestType;
52 using ::aidl::android::hardware::drm::KeySetId;
53 using ::aidl::android::hardware::drm::KeyType;
54 using ::aidl::android::hardware::drm::KeyValue;
55 using ::aidl::android::hardware::drm::Mode;
56 using ::aidl::android::hardware::drm::Pattern;
57 using ::aidl::android::hardware::drm::ProvisionRequest;
58 using ::aidl::android::hardware::drm::ProvideProvisionResponseResult;
59 using ::aidl::android::hardware::drm::SecurityLevel;
60 using ::aidl::android::hardware::drm::Status;
61 using ::aidl::android::hardware::drm::SubSample;
62 using ::aidl::android::hardware::drm::Uuid;
63 
DrmErr(const::ndk::ScopedAStatus & ret)64 Status DrmErr(const ::ndk::ScopedAStatus& ret) {
65     return static_cast<Status>(ret.getServiceSpecificError());
66 }
67 
/**
 * Strips the interface prefix from a fully-qualified HAL instance name.
 *
 * @param fullname instance name such as "iface/instance"
 * @return everything after the first '/', or fullname unchanged when it has no '/'
 */
std::string HalBaseName(const std::string& fullname) {
    const std::string::size_type slash = fullname.find('/');
    return slash == std::string::npos ? fullname : fullname.substr(slash + 1);
}
74 }
75 
76 const char* kDrmIface = "android.hardware.drm.IDrmFactory";
77 const int MAX_OPEN_SESSION_ATTEMPTS = 3;
78 
/**
 * Builds a fully-qualified HAL instance name of the form "iface/basename".
 *
 * @param iface interface descriptor
 * @param basename instance name
 * @return iface and basename joined by a single '/'
 */
std::string HalFullName(const std::string& iface, const std::string& basename) {
    std::string joined;
    joined.reserve(iface.size() + 1 + basename.size());
    joined += iface;
    joined += '/';
    joined += basename;
    return joined;
}
81 }
82 
IsOk(const::ndk::ScopedAStatus & ret)83 testing::AssertionResult IsOk(const ::ndk::ScopedAStatus& ret) {
84     if (ret.isOk()) {
85         return testing::AssertionSuccess();
86     }
87     return testing::AssertionFailure() << "ex: " << ret.getExceptionCode()
88                                        << "; svc err: " << ret.getServiceSpecificError()
89                                        << "; desc: " << ret.getDescription();
90 }
91 
92 const char* kCallbackLostState = "LostState";
93 const char* kCallbackKeysChange = "KeysChange";
94 
95 drm_vts::VendorModules* DrmHalTest::gVendorModules = nullptr;
96 
97 /**
98  * DrmHalPluginListener
99  */
onEvent(EventType eventType,const vector<uint8_t> & sessionId,const vector<uint8_t> & data)100 ::ndk::ScopedAStatus DrmHalPluginListener::onEvent(
101         EventType eventType,
102         const vector<uint8_t>& sessionId,
103         const vector<uint8_t>& data) {
104     ListenerArgs args{};
105     args.eventType = eventType;
106     args.sessionId = sessionId;
107     args.data = data;
108     eventPromise.set_value(args);
109     return ::ndk::ScopedAStatus::ok();
110 }
111 
onExpirationUpdate(const vector<uint8_t> & sessionId,int64_t expiryTimeInMS)112 ::ndk::ScopedAStatus DrmHalPluginListener::onExpirationUpdate(
113         const vector<uint8_t>& sessionId,
114         int64_t expiryTimeInMS) {
115     ListenerArgs args{};
116     args.sessionId = sessionId;
117     args.expiryTimeInMS = expiryTimeInMS;
118     expirationUpdatePromise.set_value(args);
119     return ::ndk::ScopedAStatus::ok();
120 
121 }
122 
onSessionLostState(const vector<uint8_t> & sessionId)123 ::ndk::ScopedAStatus DrmHalPluginListener::onSessionLostState(const vector<uint8_t>& sessionId) {
124     ListenerArgs args{};
125     args.sessionId = sessionId;
126     sessionLostStatePromise.set_value(args);
127     return ::ndk::ScopedAStatus::ok();
128 }
129 
onKeysChange(const std::vector<uint8_t> & sessionId,const std::vector<::aidl::android::hardware::drm::KeyStatus> & keyStatusList,bool hasNewUsableKey)130 ::ndk::ScopedAStatus DrmHalPluginListener::onKeysChange(
131         const std::vector<uint8_t>& sessionId,
132         const std::vector<::aidl::android::hardware::drm::KeyStatus>& keyStatusList,
133         bool hasNewUsableKey) {
134     ListenerArgs args{};
135     args.sessionId = sessionId;
136     args.keyStatusList = keyStatusList;
137     args.hasNewUsableKey = hasNewUsableKey;
138     keysChangePromise.set_value(args);
139     return ::ndk::ScopedAStatus::ok();
140 }
141 
getListenerArgs(std::promise<ListenerArgs> & promise)142 ListenerArgs DrmHalPluginListener::getListenerArgs(std::promise<ListenerArgs>& promise) {
143     auto future = promise.get_future();
144     auto timeout = std::chrono::milliseconds(500);
145     EXPECT_EQ(future.wait_for(timeout), std::future_status::ready);
146     return future.get();
147 }
148 
getEventArgs()149 ListenerArgs DrmHalPluginListener::getEventArgs() {
150     return getListenerArgs(eventPromise);
151 }
152 
getExpirationUpdateArgs()153 ListenerArgs DrmHalPluginListener::getExpirationUpdateArgs() {
154     return getListenerArgs(expirationUpdatePromise);
155 }
156 
getSessionLostStateArgs()157 ListenerArgs DrmHalPluginListener::getSessionLostStateArgs() {
158     return getListenerArgs(sessionLostStatePromise);
159 }
160 
getKeysChangeArgs()161 ListenerArgs DrmHalPluginListener::getKeysChangeArgs() {
162     return getListenerArgs(keysChangePromise);
163 }
164 
getModuleForInstance(const std::string & instance)165 static DrmHalVTSVendorModule_V1* getModuleForInstance(const std::string& instance) {
166     if (instance.find("clearkey") != std::string::npos ||
167         instance.find("default") != std::string::npos) {
168         return new clearkeydrm::DrmHalVTSClearkeyModule();
169     }
170 
171     return static_cast<DrmHalVTSVendorModule_V1*>(
172             DrmHalTest::gVendorModules->getModuleByName(instance));
173 }
174 
175 /**
176  * DrmHalTest
177  */
178 
DrmHalTest()179 DrmHalTest::DrmHalTest() : vendorModule(getModuleForInstance(GetParamService())) {}
180 
SetUp()181 void DrmHalTest::SetUp() {
182     const ::testing::TestInfo* const test_info =
183             ::testing::UnitTest::GetInstance()->current_test_info();
184 
185     ALOGD("Running test %s.%s from (vendor) module %s", test_info->test_case_name(),
186           test_info->name(), GetParamService().c_str());
187 
188     auto svc = GetParamService();
189     const string drmInstance = HalFullName(kDrmIface, svc);
190 
191     if (!vendorModule) {
192         ASSERT_NE(drmInstance, HalFullName(kDrmIface, "widevine")) << "Widevine requires vendor module.";
193         ASSERT_NE(drmInstance, HalFullName(kDrmIface, "clearkey")) << "Clearkey requires vendor module.";
194         GTEST_SKIP() << "No vendor module installed";
195     }
196 
197     char bootloader_state[PROPERTY_VALUE_MAX] = {};
198     if (property_get("ro.boot.vbmeta.device_state", bootloader_state, "") != 0) {
199         if (!strcmp(bootloader_state, "unlocked")) {
200             GTEST_SKIP() << "Skip test because bootloader is unlocked";
201         }
202     }
203 
204     if (drmInstance.find("IDrmFactory") != std::string::npos) {
205         drmFactory = IDrmFactory::fromBinder(
206                 ::ndk::SpAIBinder(AServiceManager_waitForService(drmInstance.c_str())));
207         ASSERT_NE(drmFactory, nullptr);
208         drmPlugin = createDrmPlugin();
209         cryptoPlugin = createCryptoPlugin();
210     }
211 
212     ASSERT_EQ(HalBaseName(drmInstance), vendorModule->getServiceName());
213     contentConfigurations = vendorModule->getContentConfigurations();
214 
215     // If drm scheme not installed skip subsequent tests
216     bool result = isCryptoSchemeSupported(getAidlUUID(), SecurityLevel::SW_SECURE_CRYPTO, "cenc");
217     if (!result) {
218         if (GetParamUUID() == std::array<uint8_t, 16>()) {
219             GTEST_SKIP() << "vendor module drm scheme not supported";
220         } else {
221             FAIL() << "param scheme must be supported";
222         }
223     }
224 
225     ASSERT_NE(nullptr, drmPlugin.get())
226             << "Can't find " << vendorModule->getServiceName() << " drm aidl plugin";
227     ASSERT_NE(nullptr, cryptoPlugin.get())
228             << "Can't find " << vendorModule->getServiceName() << " crypto aidl plugin";
229 }
230 
createDrmPlugin()231 std::shared_ptr<::aidl::android::hardware::drm::IDrmPlugin> DrmHalTest::createDrmPlugin() {
232     if (drmFactory == nullptr) {
233         return nullptr;
234     }
235     std::string packageName("aidl.android.hardware.drm.test");
236     std::shared_ptr<::aidl::android::hardware::drm::IDrmPlugin> result;
237     auto ret = drmFactory->createDrmPlugin(getAidlUUID(), packageName, &result);
238     EXPECT_OK(ret) << "createDrmPlugin remote call failed";
239     return result;
240 }
241 
createCryptoPlugin()242 std::shared_ptr<::aidl::android::hardware::drm::ICryptoPlugin> DrmHalTest::createCryptoPlugin() {
243     if (drmFactory == nullptr) {
244         return nullptr;
245     }
246     vector<uint8_t> initVec;
247     std::shared_ptr<::aidl::android::hardware::drm::ICryptoPlugin> result;
248     auto ret = drmFactory->createCryptoPlugin(getAidlUUID(), initVec, &result);
249     EXPECT_OK(ret) << "createCryptoPlugin remote call failed";
250     return result;
251 }
252 
getAidlUUID()253 ::aidl::android::hardware::drm::Uuid DrmHalTest::getAidlUUID() {
254     return toAidlUuid(getUUID());
255 }
256 
getUUID()257 std::vector<uint8_t> DrmHalTest::getUUID() {
258     auto paramUUID = GetParamUUID();
259     if (paramUUID == std::array<uint8_t, 16>()) {
260         return getVendorUUID();
261     }
262     return std::vector(paramUUID.begin(), paramUUID.end());
263 }
264 
getVendorUUID()265 std::vector<uint8_t> DrmHalTest::getVendorUUID() {
266     if (vendorModule == nullptr) {
267         ALOGW("vendor module for %s not found", GetParamService().c_str());
268         return std::vector<uint8_t>(16);
269     }
270     return vendorModule->getUUID();
271 }
272 
isCryptoSchemeSupported(Uuid uuid,SecurityLevel level,std::string mime)273 bool DrmHalTest::isCryptoSchemeSupported(Uuid uuid, SecurityLevel level, std::string mime) {
274     if (drmFactory == nullptr) {
275         return false;
276     }
277     CryptoSchemes schemes{};
278     auto ret = drmFactory->getSupportedCryptoSchemes(&schemes);
279     EXPECT_OK(ret);
280     if (!ret.isOk() || !std::count(schemes.uuids.begin(), schemes.uuids.end(), uuid)) {
281         return false;
282     }
283     if (mime.empty()) {
284         EXPECT_THAT(level, AnyOf(Eq(SecurityLevel::DEFAULT), Eq(SecurityLevel::UNKNOWN)));
285         return true;
286     }
287     for (auto ct : schemes.mimeTypes) {
288         if (ct.mime != mime) {
289             continue;
290         }
291         if (level == SecurityLevel::DEFAULT || level == SecurityLevel::UNKNOWN) {
292             return true;
293         }
294         if (level <= ct.maxLevel && level >= ct.minLevel) {
295             return true;
296         }
297     }
298     return false;
299 }
300 
provision()301 void DrmHalTest::provision() {
302     std::string certificateType;
303     std::string certificateAuthority;
304     vector<uint8_t> provisionRequest;
305     std::string defaultUrl;
306     ProvisionRequest result;
307     auto ret = drmPlugin->getProvisionRequest(certificateType, certificateAuthority, &result);
308 
309     EXPECT_TXN(ret);
310     if (ret.isOk()) {
311         EXPECT_NE(result.request.size(), 0u);
312         provisionRequest = result.request;
313         defaultUrl = result.defaultUrl;
314     } else if (DrmErr(ret) == Status::ERROR_DRM_CANNOT_HANDLE) {
315         EXPECT_EQ(0u, result.request.size());
316     }
317 
318     if (provisionRequest.size() > 0) {
319         vector<uint8_t> response =
320                 vendorModule->handleProvisioningRequest(provisionRequest, defaultUrl);
321         ASSERT_NE(0u, response.size());
322 
323         ProvideProvisionResponseResult result;
324         auto ret = drmPlugin->provideProvisionResponse(response, &result);
325         EXPECT_TXN(ret);
326     }
327 }
328 
openSession(SecurityLevel level,Status * err)329 SessionId DrmHalTest::openSession(SecurityLevel level, Status* err) {
330     SessionId sessionId;
331     auto ret = drmPlugin->openSession(level, &sessionId);
332     EXPECT_TXN(ret);
333     *err = DrmErr(ret);
334     return sessionId;
335 }
336 
337 /**
338  * Helper method to open a session and verify that a non-empty
339  * session ID is returned
340  */
openSession()341 SessionId DrmHalTest::openSession() {
342     SessionId sessionId;
343 
344     int attmpt = 0;
345     while (attmpt++ < MAX_OPEN_SESSION_ATTEMPTS) {
346         auto ret = drmPlugin->openSession(SecurityLevel::DEFAULT, &sessionId);
347         if(DrmErr(ret) == Status::ERROR_DRM_NOT_PROVISIONED) {
348             provision();
349         } else {
350             EXPECT_OK(ret);
351             EXPECT_NE(0u, sessionId.size());
352             break;
353         }
354     }
355 
356     return sessionId;
357 }
358 
359 /**
360  * Helper method to close a session
361  */
closeSession(const SessionId & sessionId)362 void DrmHalTest::closeSession(const SessionId& sessionId) {
363     auto ret = drmPlugin->closeSession(sessionId);
364     EXPECT_OK(ret);
365 }
366 
getKeyRequest(const SessionId & sessionId,const DrmHalVTSVendorModule_V1::ContentConfiguration & configuration,const KeyType & type=KeyType::STREAMING)367 vector<uint8_t> DrmHalTest::getKeyRequest(
368         const SessionId& sessionId,
369         const DrmHalVTSVendorModule_V1::ContentConfiguration& configuration,
370         const KeyType& type = KeyType::STREAMING) {
371     KeyRequest result;
372     auto ret = drmPlugin->getKeyRequest(sessionId, configuration.initData, configuration.mimeType,
373                                         type, toAidlKeyedVector(configuration.optionalParameters),
374                                         &result);
375     EXPECT_OK(ret) << "Failed to get key request for configuration "
376                    << configuration.name << " for key type "
377                    << static_cast<int>(type);
378     if (type == KeyType::RELEASE) {
379         EXPECT_EQ(KeyRequestType::RELEASE, result.requestType);
380     } else {
381         EXPECT_EQ(KeyRequestType::INITIAL, result.requestType);
382     }
383     EXPECT_NE(result.request.size(), 0u) << "Expected key request size"
384                                             " to have length > 0 bytes";
385     return result.request;
386 }
387 
getContent(const KeyType & type) const388 DrmHalVTSVendorModule_V1::ContentConfiguration DrmHalTest::getContent(const KeyType& type) const {
389     for (const auto& config : contentConfigurations) {
390         if (type != KeyType::OFFLINE || config.policy.allowOffline) {
391             return config;
392         }
393     }
394     ADD_FAILURE() << "no content configurations found";
395     return {};
396 }
397 
/**
 * Passes a key response to the DRM plugin.
 *
 * @param sessionId session the response belongs to
 * @param keyResponse opaque response bytes (typically produced by the vendor module)
 * @return the key set id reported by the plugin (empty for non-offline keys or on failure)
 */
vector<uint8_t> DrmHalTest::provideKeyResponse(const SessionId& sessionId,
                                               const vector<uint8_t>& keyResponse) {
    KeySetId result;
    auto ret = drmPlugin->provideKeyResponse(sessionId, keyResponse, &result);
    // The old message ended with "for configuration " and never named one.
    EXPECT_OK(ret) << "Failure providing key response (" << keyResponse.size() << " bytes)";
    return result.keySetId;
}
404 }
405 
406 /**
407  * Helper method to load keys for subsequent decrypt tests.
408  * These tests use predetermined key request/response to
409  * avoid requiring a round trip to a license server.
410  */
loadKeys(const SessionId & sessionId,const DrmHalVTSVendorModule_V1::ContentConfiguration & configuration,const KeyType & type)411 vector<uint8_t> DrmHalTest::loadKeys(
412         const SessionId& sessionId,
413         const DrmHalVTSVendorModule_V1::ContentConfiguration& configuration, const KeyType& type) {
414     vector<uint8_t> keyRequest = getKeyRequest(sessionId, configuration, type);
415 
416     /**
417      * Get key response from vendor module
418      */
419     vector<uint8_t> keyResponse =
420             vendorModule->handleKeyRequest(keyRequest, configuration.serverUrl);
421     EXPECT_NE(keyResponse.size(), 0u) << "Expected key response size "
422                                          "to have length > 0 bytes";
423 
424     return provideKeyResponse(sessionId, keyResponse);
425 }
426 
loadKeys(const SessionId & sessionId,const KeyType & type)427 vector<uint8_t> DrmHalTest::loadKeys(const SessionId& sessionId, const KeyType& type) {
428     return loadKeys(sessionId, getContent(type), type);
429 }
430 
toStdArray(const vector<uint8_t> & vec)431 std::array<uint8_t, 16> DrmHalTest::toStdArray(const vector<uint8_t>& vec) {
432     EXPECT_EQ(16u, vec.size());
433     std::array<uint8_t, 16> arr;
434     std::copy_n(vec.begin(), vec.size(), arr.begin());
435     return arr;
436 }
437 
toAidlKeyedVector(const map<string,string> & params)438 KeyedVector DrmHalTest::toAidlKeyedVector(const map<string, string>& params) {
439     std::vector<KeyValue> stdKeyedVector;
440     for (auto it = params.begin(); it != params.end(); ++it) {
441         KeyValue keyValue;
442         keyValue.key = it->first;
443         keyValue.value = it->second;
444         stdKeyedVector.push_back(keyValue);
445     }
446     return KeyedVector(stdKeyedVector);
447 }
448 
449 /**
450  * getDecryptMemory allocates memory for decryption, then sets it
451  * as a shared buffer base in the crypto hal. An output SharedBuffer
452  * is updated via reference.
453  *
454  * @param size the size of the memory segment to allocate
455  * @param the index of the memory segment which will be used
456  * to refer to it for decryption.
457  */
getDecryptMemory(size_t size,size_t index,SharedBuffer & out)458 void DrmHalTest::getDecryptMemory(size_t size, size_t index, SharedBuffer& out) {
459     out.bufferId = static_cast<int32_t>(index);
460     out.offset = 0;
461     out.size = static_cast<int64_t>(size);
462 
463     int fd = ASharedMemory_create("drmVtsSharedMemory", size);
464     EXPECT_GE(fd, 0);
465     EXPECT_EQ(size, ASharedMemory_getSize(fd));
466     auto handle = native_handle_create(1, 0);
467     handle->data[0] = fd;
468     out.handle = ::android::makeToAidl(handle);
469 
470     EXPECT_OK(cryptoPlugin->setSharedBufferBase(out));
471     native_handle_delete(handle);
472 }
473 
fillRandom(const::aidl::android::hardware::drm::SharedBuffer & buf)474 uint8_t* DrmHalTest::fillRandom(const ::aidl::android::hardware::drm::SharedBuffer& buf) {
475     std::random_device rd;
476     std::mt19937 rand(rd());
477 
478     auto fd = buf.handle.fds[0].get();
479     size_t size = buf.size;
480     uint8_t* base = static_cast<uint8_t*>(
481             mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0));
482     EXPECT_NE(MAP_FAILED, base);
483     for (size_t i = 0; i < size / sizeof(uint32_t); i++) {
484         auto p = static_cast<uint32_t*>(static_cast<void*>(base));
485         p[i] = rand();
486     }
487     return base;
488 }
489 
decrypt(Mode mode,bool isSecure,const std::array<uint8_t,16> & keyId,uint8_t * iv,const vector<SubSample> & subSamples,const Pattern & pattern,const vector<uint8_t> & key,Status expectedStatus)490 uint32_t DrmHalTest::decrypt(Mode mode, bool isSecure, const std::array<uint8_t, 16>& keyId,
491                              uint8_t* iv, const vector<SubSample>& subSamples,
492                              const Pattern& pattern, const vector<uint8_t>& key,
493                              Status expectedStatus) {
494     const size_t kSegmentIndex = 0;
495 
496     uint8_t localIv[AES_BLOCK_SIZE];
497     memcpy(localIv, iv, AES_BLOCK_SIZE);
498     vector<uint8_t> ivVec(localIv, localIv + AES_BLOCK_SIZE);
499     vector<uint8_t> keyIdVec(keyId.begin(), keyId.end());
500 
501     int64_t totalSize = 0;
502     for (size_t i = 0; i < subSamples.size(); i++) {
503         totalSize += subSamples[i].numBytesOfClearData;
504         totalSize += subSamples[i].numBytesOfEncryptedData;
505     }
506 
507     // The first totalSize bytes of shared memory is the encrypted
508     // input, the second totalSize bytes (if exists) is the decrypted output.
509     size_t factor = expectedStatus == Status::ERROR_DRM_FRAME_TOO_LARGE ? 1 : 2;
510     SharedBuffer sourceBuffer;
511     getDecryptMemory(totalSize * factor, kSegmentIndex, sourceBuffer);
512     auto base = fillRandom(sourceBuffer);
513 
514     SharedBuffer sourceRange;
515     sourceRange.bufferId = kSegmentIndex;
516     sourceRange.offset = 0;
517     sourceRange.size = totalSize;
518 
519     SharedBuffer destRange;
520     destRange.bufferId = kSegmentIndex;
521     destRange.offset = totalSize;
522     destRange.size = totalSize;
523 
524     DecryptArgs args;
525     args.secure = isSecure;
526     args.keyId = keyIdVec;
527     args.iv = ivVec;
528     args.mode = mode;
529     args.pattern = pattern;
530     args.subSamples = subSamples;
531     args.source = std::move(sourceRange);
532     args.offset = 0;
533     args.destination = std::move(destRange);
534 
535     int32_t bytesWritten = 0;
536     auto ret = cryptoPlugin->decrypt(args, &bytesWritten);
537     EXPECT_TXN(ret);
538     EXPECT_EQ(expectedStatus, DrmErr(ret)) << "Unexpected decrypt status " << ret.getMessage();
539 
540     if (bytesWritten != totalSize) {
541         return bytesWritten;
542     }
543 
544     // generate reference vector
545     vector<uint8_t> reference(totalSize);
546 
547     memcpy(localIv, iv, AES_BLOCK_SIZE);
548     switch (mode) {
549         case Mode::UNENCRYPTED:
550             memcpy(&reference[0], base, totalSize);
551             break;
552         case Mode::AES_CTR:
553             aes_ctr_decrypt(&reference[0], base, localIv, subSamples, key);
554             break;
555         case Mode::AES_CBC:
556             aes_cbc_decrypt(&reference[0], base, localIv, subSamples, key);
557             break;
558         case Mode::AES_CBC_CTS:
559             ADD_FAILURE() << "AES_CBC_CTS mode not supported";
560             break;
561     }
562 
563     // compare reference to decrypted data which is at base + total size
564     EXPECT_EQ(0, memcmp(static_cast<void*>(&reference[0]), static_cast<void*>(base + totalSize),
565                         totalSize))
566             << "decrypt data mismatch";
567     munmap(base, totalSize * factor);
568     return totalSize;
569 }
570 
571 /**
572  * Decrypt a list of clear+encrypted subsamples using the specified key
573  * in AES-CTR mode
574  */
aes_ctr_decrypt(uint8_t * dest,uint8_t * src,uint8_t * iv,const vector<SubSample> & subSamples,const vector<uint8_t> & key)575 void DrmHalTest::aes_ctr_decrypt(uint8_t* dest, uint8_t* src, uint8_t* iv,
576                                  const vector<SubSample>& subSamples, const vector<uint8_t>& key) {
577     AES_KEY decryptionKey;
578     AES_set_encrypt_key(&key[0], 128, &decryptionKey);
579 
580     size_t offset = 0;
581     unsigned int blockOffset = 0;
582     uint8_t previousEncryptedCounter[AES_BLOCK_SIZE];
583     memset(previousEncryptedCounter, 0, AES_BLOCK_SIZE);
584 
585     for (size_t i = 0; i < subSamples.size(); i++) {
586         const SubSample& subSample = subSamples[i];
587 
588         if (subSample.numBytesOfClearData > 0) {
589             memcpy(dest + offset, src + offset, subSample.numBytesOfClearData);
590             offset += subSample.numBytesOfClearData;
591         }
592 
593         if (subSample.numBytesOfEncryptedData > 0) {
594             AES_ctr128_encrypt(src + offset, dest + offset, subSample.numBytesOfEncryptedData,
595                                &decryptionKey, iv, previousEncryptedCounter, &blockOffset);
596             offset += subSample.numBytesOfEncryptedData;
597         }
598     }
599 }
600 
601 /**
602  * Decrypt a list of clear+encrypted subsamples using the specified key
603  * in AES-CBC mode
604  */
aes_cbc_decrypt(uint8_t * dest,uint8_t * src,uint8_t * iv,const vector<SubSample> & subSamples,const vector<uint8_t> & key)605 void DrmHalTest::aes_cbc_decrypt(uint8_t* dest, uint8_t* src, uint8_t* iv,
606                                  const vector<SubSample>& subSamples, const vector<uint8_t>& key) {
607     AES_KEY decryptionKey;
608     AES_set_encrypt_key(&key[0], 128, &decryptionKey);
609 
610     size_t offset = 0;
611     for (size_t i = 0; i < subSamples.size(); i++) {
612         memcpy(dest + offset, src + offset, subSamples[i].numBytesOfClearData);
613         offset += subSamples[i].numBytesOfClearData;
614 
615         AES_cbc_encrypt(src + offset, dest + offset, subSamples[i].numBytesOfEncryptedData,
616                         &decryptionKey, iv, 0 /* decrypt */);
617         offset += subSamples[i].numBytesOfEncryptedData;
618     }
619 }
620 
621 /**
622  * Helper method to test decryption with invalid keys is returned
623  */
decryptWithInvalidKeys(vector<uint8_t> & invalidResponse,vector<uint8_t> & iv,const Pattern & noPattern,const vector<SubSample> & subSamples)624 void DrmHalClearkeyTest::decryptWithInvalidKeys(vector<uint8_t>& invalidResponse,
625                                                 vector<uint8_t>& iv, const Pattern& noPattern,
626                                                 const vector<SubSample>& subSamples) {
627     DrmHalVTSVendorModule_V1::ContentConfiguration content = getContent();
628     if (content.keys.empty()) {
629         FAIL() << "no keys";
630     }
631 
632     const auto& key = content.keys[0];
633     auto sessionId = openSession();
634     KeySetId result;
635     auto ret = drmPlugin->provideKeyResponse(sessionId, invalidResponse, &result);
636 
637     EXPECT_OK(ret);
638     EXPECT_EQ(0u, result.keySetId.size());
639 
640     EXPECT_OK(cryptoPlugin->setMediaDrmSession(sessionId));
641 
642     uint32_t byteCount =
643             decrypt(Mode::AES_CTR, key.isSecure, toStdArray(key.keyId), &iv[0], subSamples,
644                     noPattern, key.clearContentKey, Status::ERROR_DRM_NO_LICENSE);
645     EXPECT_EQ(0u, byteCount);
646 
647     closeSession(sessionId);
648 }
649 
650 }  // namespace vts
651 }  // namespace drm
652 }  // namespace hardware
653 }  // namespace android
654 }  // namespace aidl
655