xref: /aosp_15_r20/hardware/interfaces/neuralnetworks/1.3/utils/src/Conversions.cpp (revision 4d7e907c777eeecc4c5bd7cf640a754fac206ff7)
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Conversions.h"
18 
19 #include <android-base/logging.h>
20 #include <android/hardware/neuralnetworks/1.3/types.h>
21 #include <nnapi/OperandTypes.h>
22 #include <nnapi/OperationTypes.h>
23 #include <nnapi/Result.h>
24 #include <nnapi/SharedMemory.h>
25 #include <nnapi/TypeUtils.h>
26 #include <nnapi/Types.h>
27 #include <nnapi/Validation.h>
28 #include <nnapi/hal/1.0/Conversions.h>
29 #include <nnapi/hal/1.2/Conversions.h>
30 #include <nnapi/hal/CommonUtils.h>
31 
32 #include <algorithm>
33 #include <chrono>
34 #include <functional>
35 #include <iterator>
36 #include <limits>
37 #include <type_traits>
38 #include <utility>
39 
40 #include "Utils.h"
41 
42 namespace {
43 
// Turns an unsigned nanosecond count into std::chrono::nanoseconds. Counts larger than
// std::chrono::nanoseconds::max() are clamped to that maximum.
std::chrono::nanoseconds makeNanosFromUint64(uint64_t nanoseconds) {
    using Rep = std::chrono::nanoseconds::rep;
    using Wide = std::common_type_t<Rep, uint64_t>;

    // Compare in the common type of the two representations.
    constexpr Rep kLargestRep = std::chrono::nanoseconds::max().count();
    const Wide ceiling = static_cast<Wide>(kLargestRep);
    const Wide requested = static_cast<Wide>(nanoseconds);
    const Wide clamped = requested < ceiling ? requested : ceiling;
    return std::chrono::nanoseconds{static_cast<Rep>(clamped)};
}
50 
// Turns a std::chrono::nanoseconds duration into an unsigned count. Negative durations map to
// 0; anything else is clamped to the uint64_t maximum.
uint64_t makeUint64FromNanos(std::chrono::nanoseconds nanoseconds) {
    using Rep = std::chrono::nanoseconds::rep;
    using Wide = std::common_type_t<Rep, uint64_t>;

    const bool isNegative = nanoseconds < std::chrono::nanoseconds::zero();
    if (isNegative) {
        return 0;
    }

    // The count is non-negative from here on, so widening it is safe.
    const Wide ceiling = std::numeric_limits<uint64_t>::max();
    const Wide ticks = static_cast<Wide>(nanoseconds.count());
    const Wide clamped = ticks < ceiling ? ticks : ceiling;
    return static_cast<uint64_t>(clamped);
}
60 
// Returns the enumerator `value` as its underlying integral type.
template <typename Type>
constexpr std::underlying_type_t<Type> underlyingType(Type value) {
    using Underlying = std::underlying_type_t<Type>;
    return static_cast<Underlying>(value);
}
65 
66 }  // namespace
67 
68 namespace android::nn {
69 namespace {
70 
71 using hardware::hidl_vec;
72 
// The decayed value type produced by calling unvalidatedConvert() on an `Input`.
template <typename Input>
using UnvalidatedConvertOutput =
        typename std::decay<decltype(unvalidatedConvert(std::declval<Input>()).value())>::type;
76 
77 template <typename Type>
unvalidatedConvert(const hidl_vec<Type> & arguments)78 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
79         const hidl_vec<Type>& arguments) {
80     std::vector<UnvalidatedConvertOutput<Type>> canonical;
81     canonical.reserve(arguments.size());
82     for (const auto& argument : arguments) {
83         canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
84     }
85     return canonical;
86 }
87 
88 template <typename Type>
validatedConvert(const Type & halObject)89 GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
90     auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
91     NN_TRY(hal::V1_3::utils::compliantVersion(canonical));
92     return canonical;
93 }
94 
95 template <typename Type>
validatedConvert(const hidl_vec<Type> & arguments)96 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
97         const hidl_vec<Type>& arguments) {
98     std::vector<UnvalidatedConvertOutput<Type>> canonical;
99     canonical.reserve(arguments.size());
100     for (const auto& argument : arguments) {
101         canonical.push_back(NN_TRY(validatedConvert(argument)));
102     }
103     return canonical;
104 }
105 
106 }  // anonymous namespace
107 
unvalidatedConvert(const hal::V1_3::OperandType & operandType)108 GeneralResult<OperandType> unvalidatedConvert(const hal::V1_3::OperandType& operandType) {
109     return static_cast<OperandType>(operandType);
110 }
111 
unvalidatedConvert(const hal::V1_3::OperationType & operationType)112 GeneralResult<OperationType> unvalidatedConvert(const hal::V1_3::OperationType& operationType) {
113     return static_cast<OperationType>(operationType);
114 }
115 
unvalidatedConvert(const hal::V1_3::Priority & priority)116 GeneralResult<Priority> unvalidatedConvert(const hal::V1_3::Priority& priority) {
117     return static_cast<Priority>(priority);
118 }
119 
unvalidatedConvert(const hal::V1_3::Capabilities & capabilities)120 GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_3::Capabilities& capabilities) {
121     const bool validOperandTypes = std::all_of(
122             capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
123             [](const hal::V1_3::Capabilities::OperandPerformance& operandPerformance) {
124                 return validatedConvert(operandPerformance.type).has_value();
125             });
126     if (!validOperandTypes) {
127         return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
128                << "Invalid OperandType when unvalidatedConverting OperandPerformance in "
129                   "Capabilities";
130     }
131 
132     auto operandPerformance = NN_TRY(unvalidatedConvert(capabilities.operandPerformance));
133     auto table =
134             NN_TRY(Capabilities::OperandPerformanceTable::create(std::move(operandPerformance)));
135 
136     const auto relaxedFloat32toFloat16PerformanceScalar =
137             NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar));
138     const auto relaxedFloat32toFloat16PerformanceTensor =
139             NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor));
140     const auto ifPerformance = NN_TRY(unvalidatedConvert(capabilities.ifPerformance));
141     const auto whilePerformance = NN_TRY(unvalidatedConvert(capabilities.whilePerformance));
142     return Capabilities{
143             .relaxedFloat32toFloat16PerformanceScalar = relaxedFloat32toFloat16PerformanceScalar,
144             .relaxedFloat32toFloat16PerformanceTensor = relaxedFloat32toFloat16PerformanceTensor,
145             .operandPerformance = std::move(table),
146             .ifPerformance = ifPerformance,
147             .whilePerformance = whilePerformance,
148     };
149 }
150 
unvalidatedConvert(const hal::V1_3::Capabilities::OperandPerformance & operandPerformance)151 GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
152         const hal::V1_3::Capabilities::OperandPerformance& operandPerformance) {
153     const auto type = NN_TRY(unvalidatedConvert(operandPerformance.type));
154     const auto info = NN_TRY(unvalidatedConvert(operandPerformance.info));
155     return Capabilities::OperandPerformance{
156             .type = type,
157             .info = info,
158     };
159 }
160 
unvalidatedConvert(const hal::V1_3::Operation & operation)161 GeneralResult<Operation> unvalidatedConvert(const hal::V1_3::Operation& operation) {
162     const auto type = NN_TRY(unvalidatedConvert(operation.type));
163     return Operation{
164             .type = type,
165             .inputs = operation.inputs,
166             .outputs = operation.outputs,
167     };
168 }
169 
unvalidatedConvert(const hal::V1_3::OperandLifeTime & operandLifeTime)170 GeneralResult<Operand::LifeTime> unvalidatedConvert(
171         const hal::V1_3::OperandLifeTime& operandLifeTime) {
172     return static_cast<Operand::LifeTime>(operandLifeTime);
173 }
174 
unvalidatedConvert(const hal::V1_3::Operand & operand)175 GeneralResult<Operand> unvalidatedConvert(const hal::V1_3::Operand& operand) {
176     const auto type = NN_TRY(unvalidatedConvert(operand.type));
177     const auto lifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
178     const auto location = NN_TRY(unvalidatedConvert(operand.location));
179     auto extraParams = NN_TRY(unvalidatedConvert(operand.extraParams));
180     return Operand{
181             .type = type,
182             .dimensions = operand.dimensions,
183             .scale = operand.scale,
184             .zeroPoint = operand.zeroPoint,
185             .lifetime = lifetime,
186             .location = location,
187             .extraParams = std::move(extraParams),
188     };
189 }
190 
unvalidatedConvert(const hal::V1_3::Model & model)191 GeneralResult<Model> unvalidatedConvert(const hal::V1_3::Model& model) {
192     auto main = NN_TRY(unvalidatedConvert(model.main));
193     auto referenced = NN_TRY(unvalidatedConvert(model.referenced));
194     auto operandValues = NN_TRY(unvalidatedConvert(model.operandValues));
195     auto pools = NN_TRY(unvalidatedConvert(model.pools));
196     auto extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix));
197     return Model{
198             .main = std::move(main),
199             .referenced = std::move(referenced),
200             .operandValues = std::move(operandValues),
201             .pools = std::move(pools),
202             .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
203             .extensionNameToPrefix = std::move(extensionNameToPrefix),
204     };
205 }
206 
unvalidatedConvert(const hal::V1_3::Subgraph & subgraph)207 GeneralResult<Model::Subgraph> unvalidatedConvert(const hal::V1_3::Subgraph& subgraph) {
208     auto operations = NN_TRY(unvalidatedConvert(subgraph.operations));
209 
210     // Verify number of consumers.
211     const auto numberOfConsumers =
212             NN_TRY(countNumberOfConsumers(subgraph.operands.size(), operations));
213     CHECK(subgraph.operands.size() == numberOfConsumers.size());
214     for (size_t i = 0; i < subgraph.operands.size(); ++i) {
215         if (subgraph.operands[i].numberOfConsumers != numberOfConsumers[i]) {
216             return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
217                    << "Invalid numberOfConsumers for operand " << i << ", expected "
218                    << numberOfConsumers[i] << " but found "
219                    << subgraph.operands[i].numberOfConsumers;
220         }
221     }
222 
223     auto operands = NN_TRY(unvalidatedConvert(subgraph.operands));
224     return Model::Subgraph{
225             .operands = std::move(operands),
226             .operations = std::move(operations),
227             .inputIndexes = subgraph.inputIndexes,
228             .outputIndexes = subgraph.outputIndexes,
229     };
230 }
231 
unvalidatedConvert(const hal::V1_3::BufferDesc & bufferDesc)232 GeneralResult<BufferDesc> unvalidatedConvert(const hal::V1_3::BufferDesc& bufferDesc) {
233     return BufferDesc{.dimensions = bufferDesc.dimensions};
234 }
235 
unvalidatedConvert(const hal::V1_3::BufferRole & bufferRole)236 GeneralResult<BufferRole> unvalidatedConvert(const hal::V1_3::BufferRole& bufferRole) {
237     return BufferRole{
238             .modelIndex = bufferRole.modelIndex,
239             .ioIndex = bufferRole.ioIndex,
240             .probability = bufferRole.frequency,
241     };
242 }
243 
unvalidatedConvert(const hal::V1_3::Request & request)244 GeneralResult<Request> unvalidatedConvert(const hal::V1_3::Request& request) {
245     auto inputs = NN_TRY(unvalidatedConvert(request.inputs));
246     auto outputs = NN_TRY(unvalidatedConvert(request.outputs));
247     auto pools = NN_TRY(unvalidatedConvert(request.pools));
248     return Request{
249             .inputs = std::move(inputs),
250             .outputs = std::move(outputs),
251             .pools = std::move(pools),
252     };
253 }
254 
unvalidatedConvert(const hal::V1_3::Request::MemoryPool & memoryPool)255 GeneralResult<Request::MemoryPool> unvalidatedConvert(
256         const hal::V1_3::Request::MemoryPool& memoryPool) {
257     using Discriminator = hal::V1_3::Request::MemoryPool::hidl_discriminator;
258     switch (memoryPool.getDiscriminator()) {
259         case Discriminator::hidlMemory:
260             return unvalidatedConvert(memoryPool.hidlMemory());
261         case Discriminator::token:
262             return static_cast<Request::MemoryDomainToken>(memoryPool.token());
263     }
264     return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
265            << "Invalid Request::MemoryPool discriminator "
266            << underlyingType(memoryPool.getDiscriminator());
267 }
268 
unvalidatedConvert(const hal::V1_3::OptionalTimePoint & optionalTimePoint)269 GeneralResult<OptionalTimePoint> unvalidatedConvert(
270         const hal::V1_3::OptionalTimePoint& optionalTimePoint) {
271     using Discriminator = hal::V1_3::OptionalTimePoint::hidl_discriminator;
272     switch (optionalTimePoint.getDiscriminator()) {
273         case Discriminator::none:
274             return {};
275         case Discriminator::nanosecondsSinceEpoch: {
276             const auto currentSteadyTime = std::chrono::steady_clock::now();
277             const auto currentBootTime = Clock::now();
278 
279             const auto timeSinceEpoch =
280                     makeNanosFromUint64(optionalTimePoint.nanosecondsSinceEpoch());
281             const auto steadyTimePoint = std::chrono::steady_clock::time_point{timeSinceEpoch};
282 
283             // Both steadyTimePoint and currentSteadyTime are guaranteed to be non-negative, so this
284             // subtraction will never overflow or underflow.
285             const auto timeRemaining = steadyTimePoint - currentSteadyTime;
286 
287             // currentBootTime is guaranteed to be non-negative, so this code only protects against
288             // an overflow.
289             nn::TimePoint bootTimePoint;
290             constexpr auto kZeroNano = std::chrono::nanoseconds::zero();
291             constexpr auto kMaxTime = nn::TimePoint::max();
292             if (timeRemaining > kZeroNano && currentBootTime > kMaxTime - timeRemaining) {
293                 bootTimePoint = kMaxTime;
294             } else {
295                 bootTimePoint = currentBootTime + timeRemaining;
296             }
297 
298             constexpr auto kZeroTime = nn::TimePoint{};
299             return std::max(bootTimePoint, kZeroTime);
300         }
301     }
302     return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
303            << "Invalid OptionalTimePoint discriminator "
304            << underlyingType(optionalTimePoint.getDiscriminator());
305 }
306 
unvalidatedConvert(const hal::V1_3::OptionalTimeoutDuration & optionalTimeoutDuration)307 GeneralResult<OptionalDuration> unvalidatedConvert(
308         const hal::V1_3::OptionalTimeoutDuration& optionalTimeoutDuration) {
309     using Discriminator = hal::V1_3::OptionalTimeoutDuration::hidl_discriminator;
310     switch (optionalTimeoutDuration.getDiscriminator()) {
311         case Discriminator::none:
312             return {};
313         case Discriminator::nanoseconds:
314             return Duration(optionalTimeoutDuration.nanoseconds());
315     }
316     return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
317            << "Invalid OptionalTimeoutDuration discriminator "
318            << underlyingType(optionalTimeoutDuration.getDiscriminator());
319 }
320 
unvalidatedConvert(const hal::V1_3::ErrorStatus & status)321 GeneralResult<ErrorStatus> unvalidatedConvert(const hal::V1_3::ErrorStatus& status) {
322     switch (status) {
323         case hal::V1_3::ErrorStatus::NONE:
324         case hal::V1_3::ErrorStatus::DEVICE_UNAVAILABLE:
325         case hal::V1_3::ErrorStatus::GENERAL_FAILURE:
326         case hal::V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
327         case hal::V1_3::ErrorStatus::INVALID_ARGUMENT:
328         case hal::V1_3::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
329         case hal::V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
330         case hal::V1_3::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
331         case hal::V1_3::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
332             return static_cast<ErrorStatus>(status);
333     }
334     return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
335            << "Invalid ErrorStatus " << underlyingType(status);
336 }
337 
convert(const hal::V1_3::Priority & priority)338 GeneralResult<Priority> convert(const hal::V1_3::Priority& priority) {
339     return validatedConvert(priority);
340 }
341 
convert(const hal::V1_3::Capabilities & capabilities)342 GeneralResult<Capabilities> convert(const hal::V1_3::Capabilities& capabilities) {
343     return validatedConvert(capabilities);
344 }
345 
convert(const hal::V1_3::Model & model)346 GeneralResult<Model> convert(const hal::V1_3::Model& model) {
347     return validatedConvert(model);
348 }
349 
convert(const hal::V1_3::BufferDesc & bufferDesc)350 GeneralResult<BufferDesc> convert(const hal::V1_3::BufferDesc& bufferDesc) {
351     return validatedConvert(bufferDesc);
352 }
353 
convert(const hal::V1_3::Request & request)354 GeneralResult<Request> convert(const hal::V1_3::Request& request) {
355     return validatedConvert(request);
356 }
357 
convert(const hal::V1_3::OptionalTimePoint & optionalTimePoint)358 GeneralResult<OptionalTimePoint> convert(const hal::V1_3::OptionalTimePoint& optionalTimePoint) {
359     return validatedConvert(optionalTimePoint);
360 }
361 
convert(const hal::V1_3::OptionalTimeoutDuration & optionalTimeoutDuration)362 GeneralResult<OptionalDuration> convert(
363         const hal::V1_3::OptionalTimeoutDuration& optionalTimeoutDuration) {
364     return validatedConvert(optionalTimeoutDuration);
365 }
366 
convert(const hal::V1_3::ErrorStatus & errorStatus)367 GeneralResult<ErrorStatus> convert(const hal::V1_3::ErrorStatus& errorStatus) {
368     return validatedConvert(errorStatus);
369 }
370 
convert(const hardware::hidl_handle & handle)371 GeneralResult<SharedHandle> convert(const hardware::hidl_handle& handle) {
372     return validatedConvert(handle);
373 }
374 
convert(const hardware::hidl_vec<hal::V1_3::BufferRole> & bufferRoles)375 GeneralResult<std::vector<BufferRole>> convert(
376         const hardware::hidl_vec<hal::V1_3::BufferRole>& bufferRoles) {
377     return validatedConvert(bufferRoles);
378 }
379 
380 }  // namespace android::nn
381 
382 namespace android::hardware::neuralnetworks::V1_3::utils {
383 namespace {
384 
385 using utils::unvalidatedConvert;
386 
// Forwards canonical PerformanceInfo conversion to the V1_0 utilities.
nn::GeneralResult<V1_0::PerformanceInfo> unvalidatedConvert(
        const nn::Capabilities::PerformanceInfo& performanceInfo) {
    auto halObject = V1_0::utils::unvalidatedConvert(performanceInfo);
    return halObject;
}
391 
// Forwards canonical DataLocation conversion to the V1_0 utilities.
nn::GeneralResult<V1_0::DataLocation> unvalidatedConvert(const nn::DataLocation& dataLocation) {
    auto halObject = V1_0::utils::unvalidatedConvert(dataLocation);
    return halObject;
}
395 
// Forwards canonical Model::OperandValues conversion to the V1_0 utilities.
nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
        const nn::Model::OperandValues& operandValues) {
    auto halObject = V1_0::utils::unvalidatedConvert(operandValues);
    return halObject;
}
400 
// Forwards canonical SharedHandle conversion to the V1_0 utilities.
nn::GeneralResult<hidl_handle> unvalidatedConvert(const nn::SharedHandle& handle) {
    auto halObject = V1_0::utils::unvalidatedConvert(handle);
    return halObject;
}
404 
// Forwards canonical SharedMemory conversion to the V1_0 utilities.
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory) {
    auto halObject = V1_0::utils::unvalidatedConvert(memory);
    return halObject;
}
408 
// Forwards canonical Request::Argument conversion to the V1_0 utilities.
nn::GeneralResult<V1_0::RequestArgument> unvalidatedConvert(const nn::Request::Argument& argument) {
    auto halObject = V1_0::utils::unvalidatedConvert(argument);
    return halObject;
}
412 
// Forwards canonical Operand::ExtraParams conversion to the V1_2 utilities.
nn::GeneralResult<V1_2::Operand::ExtraParams> unvalidatedConvert(
        const nn::Operand::ExtraParams& extraParams) {
    auto halObject = V1_2::utils::unvalidatedConvert(extraParams);
    return halObject;
}
417 
// Forwards canonical ExtensionNameAndPrefix conversion to the V1_2 utilities.
nn::GeneralResult<V1_2::Model::ExtensionNameAndPrefix> unvalidatedConvert(
        const nn::ExtensionNameAndPrefix& extensionNameAndPrefix) {
    auto halObject = V1_2::utils::unvalidatedConvert(extensionNameAndPrefix);
    return halObject;
}
422 
// The decayed value type produced by calling unvalidatedConvert() on an `Input`.
template <typename Input>
using UnvalidatedConvertOutput =
        typename std::decay<decltype(unvalidatedConvert(std::declval<Input>()).value())>::type;
426 
427 template <typename Type>
unvalidatedConvert(const std::vector<Type> & arguments)428 nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
429         const std::vector<Type>& arguments) {
430     hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
431     for (size_t i = 0; i < arguments.size(); ++i) {
432         halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
433     }
434     return halObject;
435 }
436 
// Builds a HAL Request::MemoryPool whose hidlMemory alternative is the converted `memory`.
nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::SharedMemory& memory) {
    Request::MemoryPool pool;
    auto halMemory = NN_TRY(unvalidatedConvert(memory));
    pool.hidlMemory(std::move(halMemory));
    return pool;
}
442 
// Builds a HAL Request::MemoryPool whose token alternative is the underlying integer of
// `token`.
nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::Request::MemoryDomainToken& token) {
    const auto rawToken = underlyingType(token);
    Request::MemoryPool pool;
    pool.token(rawToken);
    return pool;
}
448 
// A SharedBuffer has no HAL V1_3 memory pool form, so this overload always fails with
// GENERAL_FAILURE.
nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::SharedBuffer& /*buffer*/) {
    return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
           << "Unable to make memory pool from IBuffer";
}
452 
453 using utils::unvalidatedConvert;
454 
455 template <typename Type>
validatedConvert(const Type & canonical)456 nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
457     NN_TRY(compliantVersion(canonical));
458     return unvalidatedConvert(canonical);
459 }
460 
461 template <typename Type>
validatedConvert(const std::vector<Type> & arguments)462 nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> validatedConvert(
463         const std::vector<Type>& arguments) {
464     hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
465     for (size_t i = 0; i < arguments.size(); ++i) {
466         halObject[i] = NN_TRY(validatedConvert(arguments[i]));
467     }
468     return halObject;
469 }
470 
471 }  // anonymous namespace
472 
// Casts a canonical OperandType to the HAL V1_3 OperandType. No range check is done here.
nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
    const auto halValue = static_cast<OperandType>(operandType);
    return halValue;
}
476 
// Casts a canonical OperationType to the HAL V1_3 OperationType. No range check is done here.
nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
    const auto halValue = static_cast<OperationType>(operationType);
    return halValue;
}
480 
// Casts a canonical Priority to the HAL V1_3 Priority. No range check is done here.
nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority) {
    const auto halValue = static_cast<Priority>(priority);
    return halValue;
}
484 
// Converts canonical Capabilities to HAL V1_3 Capabilities.
//
// Operand performance entries whose type fails compliantVersion are dropped without an error;
// only the remaining entries are converted. The four stand-alone performance entries are
// converted individually.
nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities) {
    const auto& allEntries = capabilities.operandPerformance.asVector();

    std::vector<nn::Capabilities::OperandPerformance> compliantEntries;
    compliantEntries.reserve(allEntries.size());
    for (const auto& entry : allEntries) {
        if (compliantVersion(entry.type).has_value()) {
            compliantEntries.push_back(entry);
        }
    }

    const auto relaxedScalar =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar));
    const auto relaxedTensor =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor));
    auto halEntries = NN_TRY(unvalidatedConvert(compliantEntries));
    const auto ifPerf = NN_TRY(unvalidatedConvert(capabilities.ifPerformance));
    const auto whilePerf = NN_TRY(unvalidatedConvert(capabilities.whilePerformance));

    return Capabilities{
            .relaxedFloat32toFloat16PerformanceScalar = relaxedScalar,
            .relaxedFloat32toFloat16PerformanceTensor = relaxedTensor,
            .operandPerformance = std::move(halEntries),
            .ifPerformance = ifPerf,
            .whilePerformance = whilePerf,
    };
}
510 
// Converts one canonical OperandPerformance entry by converting its `type` and `info` members.
nn::GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
        const nn::Capabilities::OperandPerformance& operandPerformance) {
    const auto halType = NN_TRY(unvalidatedConvert(operandPerformance.type));
    const auto halInfo = NN_TRY(unvalidatedConvert(operandPerformance.info));
    return Capabilities::OperandPerformance{
            .type = halType,
            .info = halInfo,
    };
}
520 
// Converts a canonical Operation: the type is converted, the input and output index lists are
// copied across unchanged.
nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
    const auto halType = NN_TRY(unvalidatedConvert(operation.type));
    return Operation{
            .type = halType,
            .inputs = operation.inputs,
            .outputs = operation.outputs,
    };
}
529 
// Casts a canonical Operand::LifeTime to the HAL V1_3 OperandLifeTime. POINTER is rejected
// with INVALID_ARGUMENT instead of being cast.
nn::GeneralResult<OperandLifeTime> unvalidatedConvert(
        const nn::Operand::LifeTime& operandLifeTime) {
    const bool isPointer = operandLifeTime == nn::Operand::LifeTime::POINTER;
    if (!isPointer) {
        return static_cast<OperandLifeTime>(operandLifeTime);
    }
    return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
           << "Model cannot be unvalidatedConverted because it contains pointer-based memory";
}
538 
// Converts a canonical Operand. Type, lifetime, location and extra parameters go through their
// own conversions; dimensions, scale and zero point are copied as-is. numberOfConsumers is set
// to 0 here.
nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
    const auto halType = NN_TRY(unvalidatedConvert(operand.type));
    const auto halLifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
    const auto halLocation = NN_TRY(unvalidatedConvert(operand.location));
    auto halExtraParams = NN_TRY(unvalidatedConvert(operand.extraParams));
    return Operand{
            .type = halType,
            .dimensions = operand.dimensions,
            .numberOfConsumers = 0,
            .scale = operand.scale,
            .zeroPoint = operand.zeroPoint,
            .lifetime = halLifetime,
            .location = halLocation,
            .extraParams = std::move(halExtraParams),
    };
}
555 
// Converts a canonical Model to a HAL V1_3 Model.
//
// Fails with INVALID_ARGUMENT when hal::utils::hasNoPointerData(model) is false. Otherwise the
// main subgraph, referenced subgraphs, operand values, memory pools and extension prefixes are
// each converted, and the relaxed-computation flag is copied.
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
    const bool pointerFree = hal::utils::hasNoPointerData(model);
    if (!pointerFree) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Model cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    auto mainSubgraph = NN_TRY(unvalidatedConvert(model.main));
    auto referencedSubgraphs = NN_TRY(unvalidatedConvert(model.referenced));
    auto constantData = NN_TRY(unvalidatedConvert(model.operandValues));
    auto memoryPools = NN_TRY(unvalidatedConvert(model.pools));
    auto extensionPrefixes = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix));
    return Model{
            .main = std::move(mainSubgraph),
            .referenced = std::move(referencedSubgraphs),
            .operandValues = std::move(constantData),
            .pools = std::move(memoryPools),
            .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
            .extensionNameToPrefix = std::move(extensionPrefixes),
    };
}
576 
// Converts a canonical Model::Subgraph to a HAL V1_3 Subgraph.
//
// Operands are converted first. Each converted operand's numberOfConsumers is then overwritten
// with the matching entry from countNumberOfConsumers, which is computed from the canonical
// operations. Operations are converted last.
nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgraph) {
    auto operands = NN_TRY(unvalidatedConvert(subgraph.operands));

    // Fill in the consumer counts on the converted operands.
    const auto numberOfConsumers =
            NN_TRY(countNumberOfConsumers(operands.size(), subgraph.operations));
    CHECK(operands.size() == numberOfConsumers.size());
    size_t index = 0;
    for (auto& halOperand : operands) {
        halOperand.numberOfConsumers = numberOfConsumers[index];
        ++index;
    }

    auto halOperations = NN_TRY(unvalidatedConvert(subgraph.operations));
    return Subgraph{
            .operands = std::move(operands),
            .operations = std::move(halOperations),
            .inputIndexes = subgraph.inputIndexes,
            .outputIndexes = subgraph.outputIndexes,
    };
}
596 
// Converts a canonical BufferDesc by copying its dimensions.
nn::GeneralResult<BufferDesc> unvalidatedConvert(const nn::BufferDesc& bufferDesc) {
    return BufferDesc{
            .dimensions = bufferDesc.dimensions,
    };
}
600 
// Converts a canonical BufferRole. The canonical `probability` field becomes the HAL
// `frequency` field; the two indexes are copied.
nn::GeneralResult<BufferRole> unvalidatedConvert(const nn::BufferRole& bufferRole) {
    const auto& canonicalRole = bufferRole;
    return BufferRole{
            .modelIndex = canonicalRole.modelIndex,
            .ioIndex = canonicalRole.ioIndex,
            .frequency = canonicalRole.probability,
    };
}
608 
// Converts a canonical Request to a HAL V1_3 Request.
//
// Fails with INVALID_ARGUMENT when hal::utils::hasNoPointerData(request) is false. Otherwise
// inputs, outputs and memory pools are converted in that order.
nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request) {
    const bool pointerFree = hal::utils::hasNoPointerData(request);
    if (!pointerFree) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Request cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    auto halInputs = NN_TRY(unvalidatedConvert(request.inputs));
    auto halOutputs = NN_TRY(unvalidatedConvert(request.outputs));
    auto halPools = NN_TRY(unvalidatedConvert(request.pools));
    return Request{
            .inputs = std::move(halInputs),
            .outputs = std::move(halOutputs),
            .pools = std::move(halPools),
    };
}
624 
// Converts a canonical Request::MemoryPool by passing whichever alternative it holds to the
// matching makeMemoryPool overload.
nn::GeneralResult<Request::MemoryPool> unvalidatedConvert(
        const nn::Request::MemoryPool& memoryPool) {
    const auto toHalPool = [](const auto& alternative) { return makeMemoryPool(alternative); };
    return std::visit(toHalPool, memoryPool);
}
629 
// Converts a canonical OptionalTimePoint to a HAL V1_3 OptionalTimePoint.
//
// An empty optional gives a default-constructed HAL object. Otherwise the distance between the
// canonical time point and "now" on nn::Clock is applied to "now" on
// std::chrono::steady_clock, capped at steady_clock's maximum time point, and stored as a
// nanosecond count. A time point earlier than a default-constructed nn::TimePoint is an error.
nn::GeneralResult<OptionalTimePoint> unvalidatedConvert(
        const nn::OptionalTimePoint& optionalTimePoint) {
    // Both clocks are read up front, before the optional is inspected.
    const auto steadyNow = std::chrono::steady_clock::now();
    const auto bootNow = nn::Clock::now();

    OptionalTimePoint halTimePoint;
    if (!optionalTimePoint.has_value()) {
        return halTimePoint;
    }

    const auto bootDeadline = optionalTimePoint.value();
    if (bootDeadline < nn::TimePoint{}) {
        return NN_ERROR() << "Trying to cast invalid time point";
    }

    // bootDeadline and bootNow are both guaranteed to be non-negative, so the difference cannot
    // overflow or underflow.
    const auto remaining = bootDeadline - bootNow;

    // steadyNow is guaranteed to be non-negative, so only an overflow needs guarding. The `&&`
    // must keep short-circuiting: kMaxTime - remaining is only evaluated for a positive
    // `remaining`.
    constexpr auto kZeroNano = std::chrono::nanoseconds::zero();
    constexpr auto kMaxTime = std::chrono::steady_clock::time_point::max();
    const bool wouldOverflow = remaining > kZeroNano && steadyNow > kMaxTime - remaining;

    std::chrono::steady_clock::time_point steadyDeadline = kMaxTime;
    if (!wouldOverflow) {
        steadyDeadline = steadyNow + remaining;
    }

    const uint64_t nanosSinceEpoch = makeUint64FromNanos(steadyDeadline.time_since_epoch());
    halTimePoint.nanosecondsSinceEpoch(nanosSinceEpoch);
    return halTimePoint;
}
663 
// Converts an optional canonical duration into the HAL 1.3 OptionalTimeoutDuration.
//
// An empty optional yields a default-constructed OptionalTimeoutDuration. A present duration is
// stored through the nanoseconds() setter as an unsigned count.
//
// A negative duration is rejected with INVALID_ARGUMENT. Passing the raw count() straight to the
// setter would implicitly convert a negative value into an enormous unsigned timeout, which is
// the opposite of what a caller asking for "no time left" intends.
// NOTE(review): assumes nn::Duration has a signed representation (the sibling time point
// conversion compares against a zero time point, which suggests so) -- confirm in nnapi/Types.h.
nn::GeneralResult<OptionalTimeoutDuration> unvalidatedConvert(
        const nn::OptionalDuration& optionalTimeoutDuration) {
    OptionalTimeoutDuration ret;
    if (optionalTimeoutDuration.has_value()) {
        const auto count = optionalTimeoutDuration.value().count();
        if (count < 0) {
            return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
                   << "Trying to cast invalid time duration";
        }
        // The sign was checked above, so this conversion is value-preserving.
        ret.nanoseconds(static_cast<uint64_t>(count));
    }
    return ret;
}
673 
// Maps a canonical nn::ErrorStatus onto the HAL 1.3 ErrorStatus.
//
// The statuses listed in the switch are carried over with a static_cast; every other canonical
// status is reported as ErrorStatus::GENERAL_FAILURE.
nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& errorStatus) {
    switch (errorStatus) {
        case nn::ErrorStatus::NONE:
        case nn::ErrorStatus::DEVICE_UNAVAILABLE:
        case nn::ErrorStatus::GENERAL_FAILURE:
        case nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
        case nn::ErrorStatus::INVALID_ARGUMENT:
        case nn::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
        case nn::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
        case nn::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
        case nn::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
            // Listed statuses fall out of the switch and are cast below.
            break;
        default:
            // Anything not listed above is collapsed into a generic failure.
            return ErrorStatus::GENERAL_FAILURE;
    }
    return static_cast<ErrorStatus>(errorStatus);
}
690 
// Converts a canonical nn::Priority to the HAL 1.3 Priority by delegating to validatedConvert()
// (declared outside this chunk; by its name it validates before converting).
nn::GeneralResult<Priority> convert(const nn::Priority& priority) {
    return validatedConvert(priority);
}
694 
// Converts canonical nn::Capabilities to the HAL 1.3 Capabilities by delegating to
// validatedConvert() (declared outside this chunk; by its name it validates before converting).
nn::GeneralResult<Capabilities> convert(const nn::Capabilities& capabilities) {
    return validatedConvert(capabilities);
}
698 
// Converts a canonical nn::Model to the HAL 1.3 Model by delegating to validatedConvert()
// (declared outside this chunk; by its name it validates before converting).
nn::GeneralResult<Model> convert(const nn::Model& model) {
    return validatedConvert(model);
}
702 
// Converts a canonical nn::BufferDesc to the HAL 1.3 BufferDesc by delegating to
// validatedConvert() (declared outside this chunk; by its name it validates before converting).
nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc) {
    return validatedConvert(bufferDesc);
}
706 
// Converts a canonical nn::Request to the HAL 1.3 Request by delegating to validatedConvert()
// (declared outside this chunk; by its name it validates before converting).
nn::GeneralResult<Request> convert(const nn::Request& request) {
    return validatedConvert(request);
}
710 
// Converts a canonical nn::OptionalTimePoint to the HAL 1.3 OptionalTimePoint by delegating to
// validatedConvert() (declared outside this chunk; by its name it validates before converting).
nn::GeneralResult<OptionalTimePoint> convert(const nn::OptionalTimePoint& optionalTimePoint) {
    return validatedConvert(optionalTimePoint);
}
714 
// Converts a canonical nn::OptionalDuration to the HAL 1.3 OptionalTimeoutDuration by delegating
// to validatedConvert() (declared outside this chunk; by its name it validates before
// converting).
nn::GeneralResult<OptionalTimeoutDuration> convert(
        const nn::OptionalDuration& optionalTimeoutDuration) {
    return validatedConvert(optionalTimeoutDuration);
}
719 
// Converts a canonical nn::ErrorStatus to the HAL 1.3 ErrorStatus by delegating to
// validatedConvert() (declared outside this chunk; by its name it validates before converting).
nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& errorStatus) {
    return validatedConvert(errorStatus);
}
723 
// Converts a canonical nn::SharedHandle to a hidl_handle by delegating to validatedConvert()
// (declared outside this chunk; by its name it validates before converting).
nn::GeneralResult<hidl_handle> convert(const nn::SharedHandle& handle) {
    return validatedConvert(handle);
}
727 
// Converts a canonical nn::SharedMemory to a hidl_memory by delegating to validatedConvert()
// (declared outside this chunk; by its name it validates before converting).
nn::GeneralResult<hidl_memory> convert(const nn::SharedMemory& memory) {
    return validatedConvert(memory);
}
731 
convert(const std::vector<nn::BufferRole> & bufferRoles)732 nn::GeneralResult<hidl_vec<BufferRole>> convert(const std::vector<nn::BufferRole>& bufferRoles) {
733     return validatedConvert(bufferRoles);
734 }
735 
// Converts a canonical nn::DeviceStatus to V1_0::DeviceStatus by forwarding to the HAL 1.2
// conversion utilities.
nn::GeneralResult<V1_0::DeviceStatus> convert(const nn::DeviceStatus& deviceStatus) {
    return V1_2::utils::convert(deviceStatus);
}
739 
// Converts a canonical nn::ExecutionPreference to V1_1::ExecutionPreference by forwarding to the
// HAL 1.2 conversion utilities.
nn::GeneralResult<V1_1::ExecutionPreference> convert(
        const nn::ExecutionPreference& executionPreference) {
    return V1_2::utils::convert(executionPreference);
}
744 
convert(const std::vector<nn::Extension> & extensions)745 nn::GeneralResult<hidl_vec<V1_2::Extension>> convert(const std::vector<nn::Extension>& extensions) {
746     return V1_2::utils::convert(extensions);
747 }
748 
convert(const std::vector<nn::SharedHandle> & handles)749 nn::GeneralResult<hidl_vec<hidl_handle>> convert(const std::vector<nn::SharedHandle>& handles) {
750     return V1_2::utils::convert(handles);
751 }
752 
convert(const std::vector<nn::OutputShape> & outputShapes)753 nn::GeneralResult<hidl_vec<V1_2::OutputShape>> convert(
754         const std::vector<nn::OutputShape>& outputShapes) {
755     return V1_2::utils::convert(outputShapes);
756 }
757 
// Converts a canonical nn::DeviceType to V1_2::DeviceType by forwarding to the HAL 1.2
// conversion utilities.
nn::GeneralResult<V1_2::DeviceType> convert(const nn::DeviceType& deviceType) {
    return V1_2::utils::convert(deviceType);
}
761 
// Converts a canonical nn::MeasureTiming to V1_2::MeasureTiming by forwarding to the HAL 1.2
// conversion utilities.
nn::GeneralResult<V1_2::MeasureTiming> convert(const nn::MeasureTiming& measureTiming) {
    return V1_2::utils::convert(measureTiming);
}
765 
// Converts a canonical nn::Timing to V1_2::Timing by forwarding to the HAL 1.2 conversion
// utilities.
nn::GeneralResult<V1_2::Timing> convert(const nn::Timing& timing) {
    return V1_2::utils::convert(timing);
}
769 
convertSyncFences(const std::vector<nn::SyncFence> & syncFences)770 nn::GeneralResult<hidl_vec<hidl_handle>> convertSyncFences(
771         const std::vector<nn::SyncFence>& syncFences) {
772     std::vector<nn::SharedHandle> handles;
773     handles.reserve(syncFences.size());
774     std::transform(syncFences.begin(), syncFences.end(), std::back_inserter(handles),
775                    [](const nn::SyncFence& syncFence) { return syncFence.getSharedHandle(); });
776     return convert(handles);
777 }
778 
779 }  // namespace android::hardware::neuralnetworks::V1_3::utils
780