1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "Conversions.h"
18
19 #include <android-base/logging.h>
20 #include <android/hardware/neuralnetworks/1.3/types.h>
21 #include <nnapi/OperandTypes.h>
22 #include <nnapi/OperationTypes.h>
23 #include <nnapi/Result.h>
24 #include <nnapi/SharedMemory.h>
25 #include <nnapi/TypeUtils.h>
26 #include <nnapi/Types.h>
27 #include <nnapi/Validation.h>
28 #include <nnapi/hal/1.0/Conversions.h>
29 #include <nnapi/hal/1.2/Conversions.h>
30 #include <nnapi/hal/CommonUtils.h>
31
32 #include <algorithm>
33 #include <chrono>
34 #include <functional>
35 #include <iterator>
36 #include <limits>
37 #include <type_traits>
38 #include <utility>
39
40 #include "Utils.h"
41
namespace {

// Converts an unsigned nanosecond count into std::chrono::nanoseconds. A count that exceeds what
// the signed representation can hold saturates to std::chrono::nanoseconds::max() rather than
// wrapping around.
std::chrono::nanoseconds makeNanosFromUint64(uint64_t nanoseconds) {
    using Rep = std::chrono::nanoseconds::rep;
    using Wide = std::common_type_t<Rep, uint64_t>;
    constexpr Wide kLimit = static_cast<Wide>(std::chrono::nanoseconds::max().count());
    const Wide value = static_cast<Wide>(nanoseconds);
    const Wide clamped = value > kLimit ? kLimit : value;
    return std::chrono::nanoseconds{static_cast<Rep>(clamped)};
}

// Converts std::chrono::nanoseconds into an unsigned nanosecond count. Negative durations map to
// 0; every non-negative duration keeps its exact count.
uint64_t makeUint64FromNanos(std::chrono::nanoseconds nanoseconds) {
    using Wide = std::common_type_t<std::chrono::nanoseconds::rep, uint64_t>;
    if (nanoseconds >= std::chrono::nanoseconds::zero()) {
        const Wide value = static_cast<Wide>(nanoseconds.count());
        return static_cast<uint64_t>(std::min<Wide>(value, std::numeric_limits<uint64_t>::max()));
    }
    return 0;
}

// Casts an enumerator to its underlying integral type (the C++23 std::to_underlying idiom).
template <typename Type>
constexpr auto underlyingType(Type value) -> std::underlying_type_t<Type> {
    using Underlying = std::underlying_type_t<Type>;
    return static_cast<Underlying>(value);
}

}  // namespace
67
68 namespace android::nn {
69 namespace {
70
71 using hardware::hidl_vec;
72
73 template <typename Input>
74 using UnvalidatedConvertOutput =
75 std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
76
77 template <typename Type>
unvalidatedConvert(const hidl_vec<Type> & arguments)78 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
79 const hidl_vec<Type>& arguments) {
80 std::vector<UnvalidatedConvertOutput<Type>> canonical;
81 canonical.reserve(arguments.size());
82 for (const auto& argument : arguments) {
83 canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
84 }
85 return canonical;
86 }
87
88 template <typename Type>
validatedConvert(const Type & halObject)89 GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
90 auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
91 NN_TRY(hal::V1_3::utils::compliantVersion(canonical));
92 return canonical;
93 }
94
95 template <typename Type>
validatedConvert(const hidl_vec<Type> & arguments)96 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
97 const hidl_vec<Type>& arguments) {
98 std::vector<UnvalidatedConvertOutput<Type>> canonical;
99 canonical.reserve(arguments.size());
100 for (const auto& argument : arguments) {
101 canonical.push_back(NN_TRY(validatedConvert(argument)));
102 }
103 return canonical;
104 }
105
106 } // anonymous namespace
107
unvalidatedConvert(const hal::V1_3::OperandType & operandType)108 GeneralResult<OperandType> unvalidatedConvert(const hal::V1_3::OperandType& operandType) {
109 return static_cast<OperandType>(operandType);
110 }
111
unvalidatedConvert(const hal::V1_3::OperationType & operationType)112 GeneralResult<OperationType> unvalidatedConvert(const hal::V1_3::OperationType& operationType) {
113 return static_cast<OperationType>(operationType);
114 }
115
unvalidatedConvert(const hal::V1_3::Priority & priority)116 GeneralResult<Priority> unvalidatedConvert(const hal::V1_3::Priority& priority) {
117 return static_cast<Priority>(priority);
118 }
119
unvalidatedConvert(const hal::V1_3::Capabilities & capabilities)120 GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_3::Capabilities& capabilities) {
121 const bool validOperandTypes = std::all_of(
122 capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
123 [](const hal::V1_3::Capabilities::OperandPerformance& operandPerformance) {
124 return validatedConvert(operandPerformance.type).has_value();
125 });
126 if (!validOperandTypes) {
127 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
128 << "Invalid OperandType when unvalidatedConverting OperandPerformance in "
129 "Capabilities";
130 }
131
132 auto operandPerformance = NN_TRY(unvalidatedConvert(capabilities.operandPerformance));
133 auto table =
134 NN_TRY(Capabilities::OperandPerformanceTable::create(std::move(operandPerformance)));
135
136 const auto relaxedFloat32toFloat16PerformanceScalar =
137 NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar));
138 const auto relaxedFloat32toFloat16PerformanceTensor =
139 NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor));
140 const auto ifPerformance = NN_TRY(unvalidatedConvert(capabilities.ifPerformance));
141 const auto whilePerformance = NN_TRY(unvalidatedConvert(capabilities.whilePerformance));
142 return Capabilities{
143 .relaxedFloat32toFloat16PerformanceScalar = relaxedFloat32toFloat16PerformanceScalar,
144 .relaxedFloat32toFloat16PerformanceTensor = relaxedFloat32toFloat16PerformanceTensor,
145 .operandPerformance = std::move(table),
146 .ifPerformance = ifPerformance,
147 .whilePerformance = whilePerformance,
148 };
149 }
150
unvalidatedConvert(const hal::V1_3::Capabilities::OperandPerformance & operandPerformance)151 GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
152 const hal::V1_3::Capabilities::OperandPerformance& operandPerformance) {
153 const auto type = NN_TRY(unvalidatedConvert(operandPerformance.type));
154 const auto info = NN_TRY(unvalidatedConvert(operandPerformance.info));
155 return Capabilities::OperandPerformance{
156 .type = type,
157 .info = info,
158 };
159 }
160
unvalidatedConvert(const hal::V1_3::Operation & operation)161 GeneralResult<Operation> unvalidatedConvert(const hal::V1_3::Operation& operation) {
162 const auto type = NN_TRY(unvalidatedConvert(operation.type));
163 return Operation{
164 .type = type,
165 .inputs = operation.inputs,
166 .outputs = operation.outputs,
167 };
168 }
169
unvalidatedConvert(const hal::V1_3::OperandLifeTime & operandLifeTime)170 GeneralResult<Operand::LifeTime> unvalidatedConvert(
171 const hal::V1_3::OperandLifeTime& operandLifeTime) {
172 return static_cast<Operand::LifeTime>(operandLifeTime);
173 }
174
unvalidatedConvert(const hal::V1_3::Operand & operand)175 GeneralResult<Operand> unvalidatedConvert(const hal::V1_3::Operand& operand) {
176 const auto type = NN_TRY(unvalidatedConvert(operand.type));
177 const auto lifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
178 const auto location = NN_TRY(unvalidatedConvert(operand.location));
179 auto extraParams = NN_TRY(unvalidatedConvert(operand.extraParams));
180 return Operand{
181 .type = type,
182 .dimensions = operand.dimensions,
183 .scale = operand.scale,
184 .zeroPoint = operand.zeroPoint,
185 .lifetime = lifetime,
186 .location = location,
187 .extraParams = std::move(extraParams),
188 };
189 }
190
unvalidatedConvert(const hal::V1_3::Model & model)191 GeneralResult<Model> unvalidatedConvert(const hal::V1_3::Model& model) {
192 auto main = NN_TRY(unvalidatedConvert(model.main));
193 auto referenced = NN_TRY(unvalidatedConvert(model.referenced));
194 auto operandValues = NN_TRY(unvalidatedConvert(model.operandValues));
195 auto pools = NN_TRY(unvalidatedConvert(model.pools));
196 auto extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix));
197 return Model{
198 .main = std::move(main),
199 .referenced = std::move(referenced),
200 .operandValues = std::move(operandValues),
201 .pools = std::move(pools),
202 .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
203 .extensionNameToPrefix = std::move(extensionNameToPrefix),
204 };
205 }
206
unvalidatedConvert(const hal::V1_3::Subgraph & subgraph)207 GeneralResult<Model::Subgraph> unvalidatedConvert(const hal::V1_3::Subgraph& subgraph) {
208 auto operations = NN_TRY(unvalidatedConvert(subgraph.operations));
209
210 // Verify number of consumers.
211 const auto numberOfConsumers =
212 NN_TRY(countNumberOfConsumers(subgraph.operands.size(), operations));
213 CHECK(subgraph.operands.size() == numberOfConsumers.size());
214 for (size_t i = 0; i < subgraph.operands.size(); ++i) {
215 if (subgraph.operands[i].numberOfConsumers != numberOfConsumers[i]) {
216 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
217 << "Invalid numberOfConsumers for operand " << i << ", expected "
218 << numberOfConsumers[i] << " but found "
219 << subgraph.operands[i].numberOfConsumers;
220 }
221 }
222
223 auto operands = NN_TRY(unvalidatedConvert(subgraph.operands));
224 return Model::Subgraph{
225 .operands = std::move(operands),
226 .operations = std::move(operations),
227 .inputIndexes = subgraph.inputIndexes,
228 .outputIndexes = subgraph.outputIndexes,
229 };
230 }
231
unvalidatedConvert(const hal::V1_3::BufferDesc & bufferDesc)232 GeneralResult<BufferDesc> unvalidatedConvert(const hal::V1_3::BufferDesc& bufferDesc) {
233 return BufferDesc{.dimensions = bufferDesc.dimensions};
234 }
235
unvalidatedConvert(const hal::V1_3::BufferRole & bufferRole)236 GeneralResult<BufferRole> unvalidatedConvert(const hal::V1_3::BufferRole& bufferRole) {
237 return BufferRole{
238 .modelIndex = bufferRole.modelIndex,
239 .ioIndex = bufferRole.ioIndex,
240 .probability = bufferRole.frequency,
241 };
242 }
243
unvalidatedConvert(const hal::V1_3::Request & request)244 GeneralResult<Request> unvalidatedConvert(const hal::V1_3::Request& request) {
245 auto inputs = NN_TRY(unvalidatedConvert(request.inputs));
246 auto outputs = NN_TRY(unvalidatedConvert(request.outputs));
247 auto pools = NN_TRY(unvalidatedConvert(request.pools));
248 return Request{
249 .inputs = std::move(inputs),
250 .outputs = std::move(outputs),
251 .pools = std::move(pools),
252 };
253 }
254
unvalidatedConvert(const hal::V1_3::Request::MemoryPool & memoryPool)255 GeneralResult<Request::MemoryPool> unvalidatedConvert(
256 const hal::V1_3::Request::MemoryPool& memoryPool) {
257 using Discriminator = hal::V1_3::Request::MemoryPool::hidl_discriminator;
258 switch (memoryPool.getDiscriminator()) {
259 case Discriminator::hidlMemory:
260 return unvalidatedConvert(memoryPool.hidlMemory());
261 case Discriminator::token:
262 return static_cast<Request::MemoryDomainToken>(memoryPool.token());
263 }
264 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
265 << "Invalid Request::MemoryPool discriminator "
266 << underlyingType(memoryPool.getDiscriminator());
267 }
268
// Converts a HAL 1.3 OptionalTimePoint into the canonical OptionalTimePoint.
//
// The HAL count is interpreted as a std::chrono::steady_clock time point, while the canonical
// TimePoint is measured on nn::Clock (called "boot time" below). The deadline is therefore
// rebased: the time remaining on the steady clock is added to the current nn::Clock time. The
// result saturates at nn::TimePoint::max() on overflow and is clamped from below to the zero
// time point.
//
// Returns an empty optional for the `none` discriminator, and GENERAL_FAILURE for a discriminator
// outside the enumerated cases.
GeneralResult<OptionalTimePoint> unvalidatedConvert(
        const hal::V1_3::OptionalTimePoint& optionalTimePoint) {
    using Discriminator = hal::V1_3::OptionalTimePoint::hidl_discriminator;
    switch (optionalTimePoint.getDiscriminator()) {
        case Discriminator::none:
            return {};
        case Discriminator::nanosecondsSinceEpoch: {
            // Sample both clocks back to back so the offset between them is taken at (nearly)
            // the same instant.
            const auto currentSteadyTime = std::chrono::steady_clock::now();
            const auto currentBootTime = Clock::now();

            // makeNanosFromUint64 saturates counts that do not fit in a signed nanosecond value.
            const auto timeSinceEpoch =
                    makeNanosFromUint64(optionalTimePoint.nanosecondsSinceEpoch());
            const auto steadyTimePoint = std::chrono::steady_clock::time_point{timeSinceEpoch};

            // Both steadyTimePoint and currentSteadyTime are guaranteed to be non-negative, so this
            // subtraction will never overflow or underflow.
            const auto timeRemaining = steadyTimePoint - currentSteadyTime;

            // currentBootTime is guaranteed to be non-negative, so this code only protects against
            // an overflow.
            nn::TimePoint bootTimePoint;
            constexpr auto kZeroNano = std::chrono::nanoseconds::zero();
            constexpr auto kMaxTime = nn::TimePoint::max();
            if (timeRemaining > kZeroNano && currentBootTime > kMaxTime - timeRemaining) {
                bootTimePoint = kMaxTime;
            } else {
                bootTimePoint = currentBootTime + timeRemaining;
            }

            // A deadline already in the past may rebase to a negative time point; clamp it to zero.
            constexpr auto kZeroTime = nn::TimePoint{};
            return std::max(bootTimePoint, kZeroTime);
        }
    }
    // Reached only when the discriminator holds a value outside the enumerated cases.
    return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
           << "Invalid OptionalTimePoint discriminator "
           << underlyingType(optionalTimePoint.getDiscriminator());
}
306
unvalidatedConvert(const hal::V1_3::OptionalTimeoutDuration & optionalTimeoutDuration)307 GeneralResult<OptionalDuration> unvalidatedConvert(
308 const hal::V1_3::OptionalTimeoutDuration& optionalTimeoutDuration) {
309 using Discriminator = hal::V1_3::OptionalTimeoutDuration::hidl_discriminator;
310 switch (optionalTimeoutDuration.getDiscriminator()) {
311 case Discriminator::none:
312 return {};
313 case Discriminator::nanoseconds:
314 return Duration(optionalTimeoutDuration.nanoseconds());
315 }
316 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
317 << "Invalid OptionalTimeoutDuration discriminator "
318 << underlyingType(optionalTimeoutDuration.getDiscriminator());
319 }
320
unvalidatedConvert(const hal::V1_3::ErrorStatus & status)321 GeneralResult<ErrorStatus> unvalidatedConvert(const hal::V1_3::ErrorStatus& status) {
322 switch (status) {
323 case hal::V1_3::ErrorStatus::NONE:
324 case hal::V1_3::ErrorStatus::DEVICE_UNAVAILABLE:
325 case hal::V1_3::ErrorStatus::GENERAL_FAILURE:
326 case hal::V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
327 case hal::V1_3::ErrorStatus::INVALID_ARGUMENT:
328 case hal::V1_3::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
329 case hal::V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
330 case hal::V1_3::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
331 case hal::V1_3::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
332 return static_cast<ErrorStatus>(status);
333 }
334 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
335 << "Invalid ErrorStatus " << underlyingType(status);
336 }
337
convert(const hal::V1_3::Priority & priority)338 GeneralResult<Priority> convert(const hal::V1_3::Priority& priority) {
339 return validatedConvert(priority);
340 }
341
convert(const hal::V1_3::Capabilities & capabilities)342 GeneralResult<Capabilities> convert(const hal::V1_3::Capabilities& capabilities) {
343 return validatedConvert(capabilities);
344 }
345
convert(const hal::V1_3::Model & model)346 GeneralResult<Model> convert(const hal::V1_3::Model& model) {
347 return validatedConvert(model);
348 }
349
convert(const hal::V1_3::BufferDesc & bufferDesc)350 GeneralResult<BufferDesc> convert(const hal::V1_3::BufferDesc& bufferDesc) {
351 return validatedConvert(bufferDesc);
352 }
353
convert(const hal::V1_3::Request & request)354 GeneralResult<Request> convert(const hal::V1_3::Request& request) {
355 return validatedConvert(request);
356 }
357
convert(const hal::V1_3::OptionalTimePoint & optionalTimePoint)358 GeneralResult<OptionalTimePoint> convert(const hal::V1_3::OptionalTimePoint& optionalTimePoint) {
359 return validatedConvert(optionalTimePoint);
360 }
361
convert(const hal::V1_3::OptionalTimeoutDuration & optionalTimeoutDuration)362 GeneralResult<OptionalDuration> convert(
363 const hal::V1_3::OptionalTimeoutDuration& optionalTimeoutDuration) {
364 return validatedConvert(optionalTimeoutDuration);
365 }
366
convert(const hal::V1_3::ErrorStatus & errorStatus)367 GeneralResult<ErrorStatus> convert(const hal::V1_3::ErrorStatus& errorStatus) {
368 return validatedConvert(errorStatus);
369 }
370
convert(const hardware::hidl_handle & handle)371 GeneralResult<SharedHandle> convert(const hardware::hidl_handle& handle) {
372 return validatedConvert(handle);
373 }
374
convert(const hardware::hidl_vec<hal::V1_3::BufferRole> & bufferRoles)375 GeneralResult<std::vector<BufferRole>> convert(
376 const hardware::hidl_vec<hal::V1_3::BufferRole>& bufferRoles) {
377 return validatedConvert(bufferRoles);
378 }
379
380 } // namespace android::nn
381
382 namespace android::hardware::neuralnetworks::V1_3::utils {
383 namespace {
384
385 using utils::unvalidatedConvert;
386
unvalidatedConvert(const nn::Capabilities::PerformanceInfo & performanceInfo)387 nn::GeneralResult<V1_0::PerformanceInfo> unvalidatedConvert(
388 const nn::Capabilities::PerformanceInfo& performanceInfo) {
389 return V1_0::utils::unvalidatedConvert(performanceInfo);
390 }
391
unvalidatedConvert(const nn::DataLocation & dataLocation)392 nn::GeneralResult<V1_0::DataLocation> unvalidatedConvert(const nn::DataLocation& dataLocation) {
393 return V1_0::utils::unvalidatedConvert(dataLocation);
394 }
395
unvalidatedConvert(const nn::Model::OperandValues & operandValues)396 nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
397 const nn::Model::OperandValues& operandValues) {
398 return V1_0::utils::unvalidatedConvert(operandValues);
399 }
400
unvalidatedConvert(const nn::SharedHandle & handle)401 nn::GeneralResult<hidl_handle> unvalidatedConvert(const nn::SharedHandle& handle) {
402 return V1_0::utils::unvalidatedConvert(handle);
403 }
404
unvalidatedConvert(const nn::SharedMemory & memory)405 nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory) {
406 return V1_0::utils::unvalidatedConvert(memory);
407 }
408
unvalidatedConvert(const nn::Request::Argument & argument)409 nn::GeneralResult<V1_0::RequestArgument> unvalidatedConvert(const nn::Request::Argument& argument) {
410 return V1_0::utils::unvalidatedConvert(argument);
411 }
412
unvalidatedConvert(const nn::Operand::ExtraParams & extraParams)413 nn::GeneralResult<V1_2::Operand::ExtraParams> unvalidatedConvert(
414 const nn::Operand::ExtraParams& extraParams) {
415 return V1_2::utils::unvalidatedConvert(extraParams);
416 }
417
unvalidatedConvert(const nn::ExtensionNameAndPrefix & extensionNameAndPrefix)418 nn::GeneralResult<V1_2::Model::ExtensionNameAndPrefix> unvalidatedConvert(
419 const nn::ExtensionNameAndPrefix& extensionNameAndPrefix) {
420 return V1_2::utils::unvalidatedConvert(extensionNameAndPrefix);
421 }
422
423 template <typename Input>
424 using UnvalidatedConvertOutput =
425 std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
426
427 template <typename Type>
unvalidatedConvert(const std::vector<Type> & arguments)428 nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
429 const std::vector<Type>& arguments) {
430 hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
431 for (size_t i = 0; i < arguments.size(); ++i) {
432 halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
433 }
434 return halObject;
435 }
436
makeMemoryPool(const nn::SharedMemory & memory)437 nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::SharedMemory& memory) {
438 Request::MemoryPool ret;
439 ret.hidlMemory(NN_TRY(unvalidatedConvert(memory)));
440 return ret;
441 }
442
makeMemoryPool(const nn::Request::MemoryDomainToken & token)443 nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::Request::MemoryDomainToken& token) {
444 Request::MemoryPool ret;
445 ret.token(underlyingType(token));
446 return ret;
447 }
448
makeMemoryPool(const nn::SharedBuffer &)449 nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::SharedBuffer& /*buffer*/) {
450 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE) << "Unable to make memory pool from IBuffer";
451 }
452
453 using utils::unvalidatedConvert;
454
455 template <typename Type>
validatedConvert(const Type & canonical)456 nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
457 NN_TRY(compliantVersion(canonical));
458 return unvalidatedConvert(canonical);
459 }
460
461 template <typename Type>
validatedConvert(const std::vector<Type> & arguments)462 nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> validatedConvert(
463 const std::vector<Type>& arguments) {
464 hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
465 for (size_t i = 0; i < arguments.size(); ++i) {
466 halObject[i] = NN_TRY(validatedConvert(arguments[i]));
467 }
468 return halObject;
469 }
470
471 } // anonymous namespace
472
// The three enum conversions below are plain casts into the HAL 1.3 enums. They rely on the
// canonical enums sharing the HAL numeric values and do no range checking themselves.
nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
    const auto halValue = static_cast<OperandType>(operandType);
    return halValue;
}

nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
    const auto halValue = static_cast<OperationType>(operationType);
    return halValue;
}

nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority) {
    const auto halValue = static_cast<Priority>(priority);
    return halValue;
}
484
// Converts canonical Capabilities into HAL 1.3 Capabilities.
//
// Operand-performance entries whose operand type fails compliantVersion are silently dropped
// rather than reported as errors; all remaining fields are converted as-is.
nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities) {
    const auto& allPerformances = capabilities.operandPerformance.asVector();
    std::vector<nn::Capabilities::OperandPerformance> supportedPerformances;
    supportedPerformances.reserve(allPerformances.size());
    for (const auto& entry : allPerformances) {
        if (compliantVersion(entry.type).has_value()) {
            supportedPerformances.push_back(entry);
        }
    }

    const auto scalarPerformance =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar));
    const auto tensorPerformance =
            NN_TRY(unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor));
    auto halPerformances = NN_TRY(unvalidatedConvert(supportedPerformances));
    const auto ifPerf = NN_TRY(unvalidatedConvert(capabilities.ifPerformance));
    const auto whilePerf = NN_TRY(unvalidatedConvert(capabilities.whilePerformance));

    return Capabilities{
            .relaxedFloat32toFloat16PerformanceScalar = scalarPerformance,
            .relaxedFloat32toFloat16PerformanceTensor = tensorPerformance,
            .operandPerformance = std::move(halPerformances),
            .ifPerformance = ifPerf,
            .whilePerformance = whilePerf,
    };
}
510
// Converts one canonical OperandPerformance entry (operand type plus performance info) to HAL 1.3.
nn::GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
        const nn::Capabilities::OperandPerformance& operandPerformance) {
    const auto halType = NN_TRY(unvalidatedConvert(operandPerformance.type));
    const auto halInfo = NN_TRY(unvalidatedConvert(operandPerformance.info));
    return Capabilities::OperandPerformance{.type = halType, .info = halInfo};
}
520
// Converts a canonical Operation to HAL 1.3; the operand index lists are copied unchanged.
nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
    const auto halType = NN_TRY(unvalidatedConvert(operation.type));
    Operation result{
            .type = halType,
            .inputs = operation.inputs,
            .outputs = operation.outputs,
    };
    return result;
}
529
// Casts a canonical operand lifetime to the HAL 1.3 OperandLifeTime. The POINTER lifetime has no
// HAL representation and is rejected with INVALID_ARGUMENT.
nn::GeneralResult<OperandLifeTime> unvalidatedConvert(
        const nn::Operand::LifeTime& operandLifeTime) {
    if (operandLifeTime != nn::Operand::LifeTime::POINTER) {
        return static_cast<OperandLifeTime>(operandLifeTime);
    }
    return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
           << "Model cannot be unvalidatedConverted because it contains pointer-based memory";
}
538
// Converts a canonical Operand to HAL 1.3. numberOfConsumers is written as 0 here; the Subgraph
// conversion fills in the real counts once all operations are known.
nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
    const auto halType = NN_TRY(unvalidatedConvert(operand.type));
    const auto halLifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
    const auto halLocation = NN_TRY(unvalidatedConvert(operand.location));
    auto halExtraParams = NN_TRY(unvalidatedConvert(operand.extraParams));

    Operand result{
            .type = halType,
            .dimensions = operand.dimensions,
            .numberOfConsumers = 0,
            .scale = operand.scale,
            .zeroPoint = operand.zeroPoint,
            .lifetime = halLifetime,
            .location = halLocation,
            .extraParams = std::move(halExtraParams),
    };
    return result;
}
555
// Converts a canonical Model to HAL 1.3. Models holding pointer-based data (as reported by
// hal::utils::hasNoPointerData) cannot be expressed in HIDL and are rejected with
// INVALID_ARGUMENT.
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
    const bool containsPointerData = !hal::utils::hasNoPointerData(model);
    if (containsPointerData) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Model cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    auto mainSubgraph = NN_TRY(unvalidatedConvert(model.main));
    auto referencedSubgraphs = NN_TRY(unvalidatedConvert(model.referenced));
    auto constantValues = NN_TRY(unvalidatedConvert(model.operandValues));
    auto memoryPools = NN_TRY(unvalidatedConvert(model.pools));
    auto prefixes = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix));

    return Model{
            .main = std::move(mainSubgraph),
            .referenced = std::move(referencedSubgraphs),
            .operandValues = std::move(constantValues),
            .pools = std::move(memoryPools),
            .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
            .extensionNameToPrefix = std::move(prefixes),
    };
}
576
// Converts a canonical Model::Subgraph to HAL 1.3. The canonical subgraph has no consumer counts,
// so they are derived from the operations and written into the converted HAL operands.
nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgraph) {
    auto operands = NN_TRY(unvalidatedConvert(subgraph.operands));

    const auto numberOfConsumers =
            NN_TRY(countNumberOfConsumers(operands.size(), subgraph.operations));
    CHECK(operands.size() == numberOfConsumers.size());
    size_t index = 0;
    for (auto& halOperand : operands) {
        halOperand.numberOfConsumers = numberOfConsumers[index];
        ++index;
    }

    auto halOperations = NN_TRY(unvalidatedConvert(subgraph.operations));
    return Subgraph{
            .operands = std::move(operands),
            .operations = std::move(halOperations),
            .inputIndexes = subgraph.inputIndexes,
            .outputIndexes = subgraph.outputIndexes,
    };
}
596
// Converts a canonical BufferDesc to HAL 1.3; only the dimensions are carried.
nn::GeneralResult<BufferDesc> unvalidatedConvert(const nn::BufferDesc& bufferDesc) {
    BufferDesc result{.dimensions = bufferDesc.dimensions};
    return result;
}
600
// Converts a canonical BufferRole to HAL 1.3. Canonical `probability` maps to HAL `frequency`.
nn::GeneralResult<BufferRole> unvalidatedConvert(const nn::BufferRole& bufferRole) {
    BufferRole result{
            .modelIndex = bufferRole.modelIndex,
            .ioIndex = bufferRole.ioIndex,
            .frequency = bufferRole.probability,
    };
    return result;
}
608
// Converts a canonical Request to HAL 1.3. Requests holding pointer-based data (as reported by
// hal::utils::hasNoPointerData) are rejected with INVALID_ARGUMENT.
nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request) {
    const bool containsPointerData = !hal::utils::hasNoPointerData(request);
    if (containsPointerData) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Request cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    auto inputArguments = NN_TRY(unvalidatedConvert(request.inputs));
    auto outputArguments = NN_TRY(unvalidatedConvert(request.outputs));
    auto memoryPools = NN_TRY(unvalidatedConvert(request.pools));

    return Request{
            .inputs = std::move(inputArguments),
            .outputs = std::move(outputArguments),
            .pools = std::move(memoryPools),
    };
}
624
// Converts a canonical Request::MemoryPool by dispatching on its active alternative to the
// matching makeMemoryPool overload.
nn::GeneralResult<Request::MemoryPool> unvalidatedConvert(
        const nn::Request::MemoryPool& memoryPool) {
    const auto toHalPool = [](const auto& alternative) { return makeMemoryPool(alternative); };
    return std::visit(toHalPool, memoryPool);
}
629
// Converts a canonical OptionalTimePoint (measured on nn::Clock, called "boot time" here) into a
// HAL 1.3 OptionalTimePoint expressed as a std::chrono::steady_clock nanosecond count.
//
// An empty optional maps to the default (`none`) HAL union, without sampling either clock. A
// negative canonical time point is rejected. Otherwise the remaining time is rebased onto the
// steady clock, saturating at the steady clock's maximum on overflow. The result is encoded with
// makeUint64FromNanos, which maps a negative steady time to 0.
nn::GeneralResult<OptionalTimePoint> unvalidatedConvert(
        const nn::OptionalTimePoint& optionalTimePoint) {
    OptionalTimePoint ret;
    if (!optionalTimePoint.has_value()) {
        return ret;
    }

    const auto bootTimePoint = optionalTimePoint.value();
    if (bootTimePoint < nn::TimePoint{}) {
        return NN_ERROR() << "Trying to cast invalid time point";
    }

    // Clocks are only read when a deadline is present. They are sampled back to back so the
    // offset between them is taken at (nearly) the same instant.
    const auto currentSteadyTime = std::chrono::steady_clock::now();
    const auto currentBootTime = nn::Clock::now();

    // Both bootTimePoint and currentBootTime are guaranteed to be non-negative, so this
    // subtraction will never overflow or underflow.
    const auto timeRemaining = bootTimePoint - currentBootTime;

    // currentSteadyTime is guaranteed to be non-negative, so this code only protects against an
    // overflow.
    std::chrono::steady_clock::time_point steadyTimePoint;
    constexpr auto kZeroNano = std::chrono::nanoseconds::zero();
    constexpr auto kMaxTime = std::chrono::steady_clock::time_point::max();
    if (timeRemaining > kZeroNano && currentSteadyTime > kMaxTime - timeRemaining) {
        steadyTimePoint = kMaxTime;
    } else {
        steadyTimePoint = currentSteadyTime + timeRemaining;
    }

    const uint64_t count = makeUint64FromNanos(steadyTimePoint.time_since_epoch());
    ret.nanosecondsSinceEpoch(count);
    return ret;
}
663
// Converts a canonical OptionalDuration into a HAL 1.3 OptionalTimeoutDuration.
//
// An empty optional maps to the default (`none`) HAL union. A present duration is encoded with
// makeUint64FromNanos, so a negative Duration becomes 0. Previously the implicit signed-to-
// unsigned conversion turned it into an enormous timeout. This matches how the time-point
// conversion encodes the same HAL field type.
nn::GeneralResult<OptionalTimeoutDuration> unvalidatedConvert(
        const nn::OptionalDuration& optionalTimeoutDuration) {
    OptionalTimeoutDuration ret;
    if (optionalTimeoutDuration.has_value()) {
        const uint64_t count = makeUint64FromNanos(optionalTimeoutDuration.value());
        ret.nanoseconds(count);
    }
    return ret;
}
673
// Converts a canonical ErrorStatus to HAL 1.3. Statuses that exist in HAL 1.3 cast directly;
// anything else is reported to the HAL as GENERAL_FAILURE (not as a conversion error).
nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& errorStatus) {
    using CanonicalStatus = nn::ErrorStatus;
    switch (errorStatus) {
        case CanonicalStatus::NONE:
        case CanonicalStatus::DEVICE_UNAVAILABLE:
        case CanonicalStatus::GENERAL_FAILURE:
        case CanonicalStatus::OUTPUT_INSUFFICIENT_SIZE:
        case CanonicalStatus::INVALID_ARGUMENT:
        case CanonicalStatus::MISSED_DEADLINE_TRANSIENT:
        case CanonicalStatus::MISSED_DEADLINE_PERSISTENT:
        case CanonicalStatus::RESOURCE_EXHAUSTED_TRANSIENT:
        case CanonicalStatus::RESOURCE_EXHAUSTED_PERSISTENT: {
            const auto halStatus = static_cast<ErrorStatus>(errorStatus);
            return halStatus;
        }
        default:
            break;
    }
    return ErrorStatus::GENERAL_FAILURE;
}
690
convert(const nn::Priority & priority)691 nn::GeneralResult<Priority> convert(const nn::Priority& priority) {
692 return validatedConvert(priority);
693 }
694
convert(const nn::Capabilities & capabilities)695 nn::GeneralResult<Capabilities> convert(const nn::Capabilities& capabilities) {
696 return validatedConvert(capabilities);
697 }
698
convert(const nn::Model & model)699 nn::GeneralResult<Model> convert(const nn::Model& model) {
700 return validatedConvert(model);
701 }
702
convert(const nn::BufferDesc & bufferDesc)703 nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc) {
704 return validatedConvert(bufferDesc);
705 }
706
convert(const nn::Request & request)707 nn::GeneralResult<Request> convert(const nn::Request& request) {
708 return validatedConvert(request);
709 }
710
convert(const nn::OptionalTimePoint & optionalTimePoint)711 nn::GeneralResult<OptionalTimePoint> convert(const nn::OptionalTimePoint& optionalTimePoint) {
712 return validatedConvert(optionalTimePoint);
713 }
714
convert(const nn::OptionalDuration & optionalTimeoutDuration)715 nn::GeneralResult<OptionalTimeoutDuration> convert(
716 const nn::OptionalDuration& optionalTimeoutDuration) {
717 return validatedConvert(optionalTimeoutDuration);
718 }
719
convert(const nn::ErrorStatus & errorStatus)720 nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& errorStatus) {
721 return validatedConvert(errorStatus);
722 }
723
convert(const nn::SharedHandle & handle)724 nn::GeneralResult<hidl_handle> convert(const nn::SharedHandle& handle) {
725 return validatedConvert(handle);
726 }
727
convert(const nn::SharedMemory & memory)728 nn::GeneralResult<hidl_memory> convert(const nn::SharedMemory& memory) {
729 return validatedConvert(memory);
730 }
731
convert(const std::vector<nn::BufferRole> & bufferRoles)732 nn::GeneralResult<hidl_vec<BufferRole>> convert(const std::vector<nn::BufferRole>& bufferRoles) {
733 return validatedConvert(bufferRoles);
734 }
735
convert(const nn::DeviceStatus & deviceStatus)736 nn::GeneralResult<V1_0::DeviceStatus> convert(const nn::DeviceStatus& deviceStatus) {
737 return V1_2::utils::convert(deviceStatus);
738 }
739
convert(const nn::ExecutionPreference & executionPreference)740 nn::GeneralResult<V1_1::ExecutionPreference> convert(
741 const nn::ExecutionPreference& executionPreference) {
742 return V1_2::utils::convert(executionPreference);
743 }
744
convert(const std::vector<nn::Extension> & extensions)745 nn::GeneralResult<hidl_vec<V1_2::Extension>> convert(const std::vector<nn::Extension>& extensions) {
746 return V1_2::utils::convert(extensions);
747 }
748
convert(const std::vector<nn::SharedHandle> & handles)749 nn::GeneralResult<hidl_vec<hidl_handle>> convert(const std::vector<nn::SharedHandle>& handles) {
750 return V1_2::utils::convert(handles);
751 }
752
convert(const std::vector<nn::OutputShape> & outputShapes)753 nn::GeneralResult<hidl_vec<V1_2::OutputShape>> convert(
754 const std::vector<nn::OutputShape>& outputShapes) {
755 return V1_2::utils::convert(outputShapes);
756 }
757
convert(const nn::DeviceType & deviceType)758 nn::GeneralResult<V1_2::DeviceType> convert(const nn::DeviceType& deviceType) {
759 return V1_2::utils::convert(deviceType);
760 }
761
convert(const nn::MeasureTiming & measureTiming)762 nn::GeneralResult<V1_2::MeasureTiming> convert(const nn::MeasureTiming& measureTiming) {
763 return V1_2::utils::convert(measureTiming);
764 }
765
convert(const nn::Timing & timing)766 nn::GeneralResult<V1_2::Timing> convert(const nn::Timing& timing) {
767 return V1_2::utils::convert(timing);
768 }
769
convertSyncFences(const std::vector<nn::SyncFence> & syncFences)770 nn::GeneralResult<hidl_vec<hidl_handle>> convertSyncFences(
771 const std::vector<nn::SyncFence>& syncFences) {
772 std::vector<nn::SharedHandle> handles;
773 handles.reserve(syncFences.size());
774 std::transform(syncFences.begin(), syncFences.end(), std::back_inserter(handles),
775 [](const nn::SyncFence& syncFence) { return syncFence.getSharedHandle(); });
776 return convert(handles);
777 }
778
779 } // namespace android::hardware::neuralnetworks::V1_3::utils
780