1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/TensorHandleFactoryRegistry.hpp> 7*89c4ff92SAndroid Build Coastguard Worker #include <cl/ClBackend.hpp> 8*89c4ff92SAndroid Build Coastguard Worker #include <cl/ClTensorHandleFactory.hpp> 9*89c4ff92SAndroid Build Coastguard Worker #include <cl/ClImportTensorHandleFactory.hpp> 10*89c4ff92SAndroid Build Coastguard Worker #include <cl/test/ClContextControlFixture.hpp> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 15*89c4ff92SAndroid Build Coastguard Worker 16*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("ClBackendTests") 17*89c4ff92SAndroid Build Coastguard Worker { 18*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ClRegisterTensorHandleFactoriesMatchingImportFactoryId") 19*89c4ff92SAndroid Build Coastguard Worker { 20*89c4ff92SAndroid Build Coastguard Worker auto clBackend = std::make_unique<ClBackend>(); 21*89c4ff92SAndroid Build Coastguard Worker TensorHandleFactoryRegistry registry; 22*89c4ff92SAndroid Build Coastguard Worker clBackend->RegisterTensorHandleFactories(registry); 23*89c4ff92SAndroid Build Coastguard Worker 24*89c4ff92SAndroid Build Coastguard Worker // When calling RegisterTensorHandleFactories, CopyAndImportFactoryPair is registered 25*89c4ff92SAndroid Build Coastguard Worker // Get ClImportTensorHandleFactory id as the matching import factory id 26*89c4ff92SAndroid Build Coastguard Worker CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) == 27*89c4ff92SAndroid Build Coastguard Worker ClImportTensorHandleFactory::GetIdStatic())); 28*89c4ff92SAndroid Build Coastguard Worker } 29*89c4ff92SAndroid Build Coastguard Worker 30*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ClRegisterTensorHandleFactoriesWithMemorySourceFlagsMatchingImportFactoryId") 31*89c4ff92SAndroid Build Coastguard Worker { 32*89c4ff92SAndroid Build Coastguard Worker auto clBackend = std::make_unique<ClBackend>(); 33*89c4ff92SAndroid Build Coastguard Worker TensorHandleFactoryRegistry registry; 34*89c4ff92SAndroid Build Coastguard Worker clBackend->RegisterTensorHandleFactories(registry, 35*89c4ff92SAndroid Build Coastguard Worker static_cast<MemorySourceFlags>(MemorySource::Malloc), 36*89c4ff92SAndroid Build Coastguard Worker static_cast<MemorySourceFlags>(MemorySource::Malloc)); 37*89c4ff92SAndroid Build Coastguard Worker 38*89c4ff92SAndroid Build Coastguard Worker // When calling RegisterTensorHandleFactories with MemorySourceFlags, CopyAndImportFactoryPair is registered 39*89c4ff92SAndroid Build Coastguard Worker // Get ClImportTensorHandleFactory id as the matching import factory id 40*89c4ff92SAndroid Build Coastguard Worker CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) == 41*89c4ff92SAndroid Build Coastguard Worker ClImportTensorHandleFactory::GetIdStatic())); 42*89c4ff92SAndroid Build Coastguard Worker } 43*89c4ff92SAndroid Build Coastguard Worker 44*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryMatchingImportFactoryId") 45*89c4ff92SAndroid Build Coastguard Worker { 46*89c4ff92SAndroid Build Coastguard Worker auto clBackend = std::make_unique<ClBackend>(); 47*89c4ff92SAndroid Build Coastguard Worker TensorHandleFactoryRegistry registry; 48*89c4ff92SAndroid Build Coastguard Worker clBackend->CreateWorkloadFactory(registry); 49*89c4ff92SAndroid Build Coastguard Worker 50*89c4ff92SAndroid Build Coastguard Worker // When calling CreateWorkloadFactory, CopyAndImportFactoryPair is registered 51*89c4ff92SAndroid Build Coastguard Worker // Get ClImportTensorHandleFactory id as the matching import factory id 52*89c4ff92SAndroid Build Coastguard Worker CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) == 53*89c4ff92SAndroid Build Coastguard Worker ClImportTensorHandleFactory::GetIdStatic())); 54*89c4ff92SAndroid Build Coastguard Worker } 55*89c4ff92SAndroid Build Coastguard Worker 56*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryWithOptionsMatchingImportFactoryId") 57*89c4ff92SAndroid Build Coastguard Worker { 58*89c4ff92SAndroid Build Coastguard Worker auto clBackend = std::make_unique<ClBackend>(); 59*89c4ff92SAndroid Build Coastguard Worker TensorHandleFactoryRegistry registry; 60*89c4ff92SAndroid Build Coastguard Worker ModelOptions modelOptions; 61*89c4ff92SAndroid Build Coastguard Worker clBackend->CreateWorkloadFactory(registry, modelOptions); 62*89c4ff92SAndroid Build Coastguard Worker 63*89c4ff92SAndroid Build Coastguard Worker // When calling CreateWorkloadFactory with ModelOptions, CopyAndImportFactoryPair is registered 64*89c4ff92SAndroid Build Coastguard Worker // Get ClImportTensorHandleFactory id as the matching import factory id 65*89c4ff92SAndroid Build Coastguard Worker CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) == 66*89c4ff92SAndroid Build Coastguard Worker ClImportTensorHandleFactory::GetIdStatic())); 67*89c4ff92SAndroid Build Coastguard Worker } 68*89c4ff92SAndroid Build Coastguard Worker 69*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryWitMemoryFlagsMatchingImportFactoryId") 70*89c4ff92SAndroid Build Coastguard Worker { 71*89c4ff92SAndroid Build Coastguard Worker auto clBackend = std::make_unique<ClBackend>(); 72*89c4ff92SAndroid Build Coastguard Worker TensorHandleFactoryRegistry registry; 73*89c4ff92SAndroid Build Coastguard Worker ModelOptions modelOptions; 74*89c4ff92SAndroid Build Coastguard Worker clBackend->CreateWorkloadFactory(registry, modelOptions, 75*89c4ff92SAndroid Build Coastguard Worker static_cast<MemorySourceFlags>(MemorySource::Malloc), 76*89c4ff92SAndroid Build Coastguard Worker static_cast<MemorySourceFlags>(MemorySource::Malloc)); 77*89c4ff92SAndroid Build Coastguard Worker 78*89c4ff92SAndroid Build Coastguard Worker // When calling CreateWorkloadFactory with ModelOptions and MemorySourceFlags, 79*89c4ff92SAndroid Build Coastguard Worker // CopyAndImportFactoryPair is registered 80*89c4ff92SAndroid Build Coastguard Worker // Get ClImportTensorHandleFactory id as the matching import factory id 81*89c4ff92SAndroid Build Coastguard Worker CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) == 82*89c4ff92SAndroid Build Coastguard Worker ClImportTensorHandleFactory::GetIdStatic())); 83*89c4ff92SAndroid Build Coastguard Worker } 84*89c4ff92SAndroid Build Coastguard Worker } 85