xref: /aosp_15_r20/hardware/interfaces/neuralnetworks/aidl/vts/functional/ValidateModel.cpp (revision 4d7e907c777eeecc4c5bd7cf640a754fac206ff7)
1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_aidl_hal_test"
18 
19 #include <aidl/android/hardware/common/NativeHandle.h>
20 #include <android/binder_auto_utils.h>
21 #include <android/binder_enums.h>
22 #include <android/binder_interface_utils.h>
23 #include <nnapi/TypeUtils.h>
24 #include <nnapi/hal/aidl/Conversions.h>
25 #include <nnapi/hal/aidl/Utils.h>
26 
27 #include <optional>
28 #include <type_traits>
29 #include <utility>
30 
31 #include "Callbacks.h"
32 #include "GeneratedTestHarness.h"
33 #include "Utils.h"
34 #include "VtsHalNeuralnetworks.h"
35 
36 namespace aidl::android::hardware::neuralnetworks::vts::functional {
37 
38 using common::NativeHandle;
39 using implementation::PreparedModelCallback;
40 
41 using PrepareModelMutation = std::function<void(Model*, ExecutionPreference*, Priority*)>;
42 
43 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
44 
validateGetSupportedOperations(const std::shared_ptr<IDevice> & device,const std::string & message,const Model & model)45 static void validateGetSupportedOperations(const std::shared_ptr<IDevice>& device,
46                                            const std::string& message, const Model& model) {
47     SCOPED_TRACE(message + " [getSupportedOperations]");
48 
49     std::vector<bool> supported;
50     const auto retStatus = device->getSupportedOperations(model, &supported);
51 
52     ASSERT_FALSE(retStatus.isOk());
53     ASSERT_EQ(retStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
54     ASSERT_EQ(static_cast<ErrorStatus>(retStatus.getServiceSpecificError()),
55               ErrorStatus::INVALID_ARGUMENT);
56 }
57 
validatePrepareModel(const std::shared_ptr<IDevice> & device,const std::string & message,const Model & model,ExecutionPreference preference,Priority priority)58 static void validatePrepareModel(const std::shared_ptr<IDevice>& device, const std::string& message,
59                                  const Model& model, ExecutionPreference preference,
60                                  Priority priority) {
61     SCOPED_TRACE(message + " [prepareModel]");
62 
63     std::shared_ptr<PreparedModelCallback> preparedModelCallback =
64             ndk::SharedRefBase::make<PreparedModelCallback>();
65     const auto prepareLaunchStatus =
66             device->prepareModel(model, preference, priority, kNoDeadline, {}, {}, kEmptyCacheToken,
67                                  preparedModelCallback);
68     ASSERT_FALSE(prepareLaunchStatus.isOk());
69     ASSERT_EQ(prepareLaunchStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
70     ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus.getServiceSpecificError()),
71               ErrorStatus::INVALID_ARGUMENT);
72 
73     preparedModelCallback->wait();
74     ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
75     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, prepareReturnStatus);
76     std::shared_ptr<IPreparedModel> preparedModel = preparedModelCallback->getPreparedModel();
77     ASSERT_EQ(nullptr, preparedModel.get());
78 }
79 
validatePrepareModelWithConfig(const std::shared_ptr<IDevice> & device,const std::string & message,const Model & model,ExecutionPreference preference,Priority priority)80 static void validatePrepareModelWithConfig(const std::shared_ptr<IDevice>& device,
81                                            const std::string& message, const Model& model,
82                                            ExecutionPreference preference, Priority priority) {
83     SCOPED_TRACE(message + " [prepareModelWithConfig]");
84 
85     std::shared_ptr<PreparedModelCallback> preparedModelCallback =
86             ndk::SharedRefBase::make<PreparedModelCallback>();
87     const auto prepareLaunchStatus = device->prepareModelWithConfig(
88             model, {preference, priority, kNoDeadline, {}, {}, kEmptyCacheTokenArray, {}, {}},
89             preparedModelCallback);
90     ASSERT_FALSE(prepareLaunchStatus.isOk());
91     ASSERT_EQ(prepareLaunchStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
92     ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus.getServiceSpecificError()),
93               ErrorStatus::INVALID_ARGUMENT);
94 
95     preparedModelCallback->wait();
96     ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
97     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, prepareReturnStatus);
98     std::shared_ptr<IPreparedModel> preparedModel = preparedModelCallback->getPreparedModel();
99     ASSERT_EQ(nullptr, preparedModel.get());
100 }
101 
validExecutionPreference(ExecutionPreference preference)102 static bool validExecutionPreference(ExecutionPreference preference) {
103     return preference == ExecutionPreference::LOW_POWER ||
104            preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
105            preference == ExecutionPreference::SUSTAINED_SPEED;
106 }
107 
validExecutionPriority(Priority priority)108 static bool validExecutionPriority(Priority priority) {
109     return priority == Priority::LOW || priority == Priority::MEDIUM || priority == Priority::HIGH;
110 }
111 
112 // Primary validation function. This function will take a valid model, apply a
113 // mutation to invalidate the model, the execution preference, or the priority,
114 // then pass these to supportedOperations and/or prepareModel if that method is
115 // called with an invalid argument.
validate(const std::shared_ptr<IDevice> & device,const std::string & message,const Model & originalModel,const PrepareModelMutation & mutate)116 static void validate(const std::shared_ptr<IDevice>& device, const std::string& message,
117                      const Model& originalModel, const PrepareModelMutation& mutate) {
118     Model model = utils::clone(originalModel).value();
119     ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER;
120     Priority priority = kDefaultPriority;
121     mutate(&model, &preference, &priority);
122 
123     if (validExecutionPreference(preference) && validExecutionPriority(priority)) {
124         validateGetSupportedOperations(device, message, model);
125     }
126 
127     validatePrepareModel(device, message, model, preference, priority);
128 
129     int32_t aidlVersion;
130     ASSERT_TRUE(device->getInterfaceVersion(&aidlVersion).isOk());
131     if (aidlVersion >= kMinAidlLevelForFL8) {
132         // prepareModelWithConfig must satisfy all requirements enforced by prepareModel.
133         validatePrepareModelWithConfig(device, message, model, preference, priority);
134     }
135 }
136 
addOperand(Model * model)137 static uint32_t addOperand(Model* model) {
138     model->main.operands.push_back({
139             .type = OperandType::INT32,
140             .dimensions = {},
141             .scale = 0.0f,
142             .zeroPoint = 0,
143             .lifetime = OperandLifeTime::SUBGRAPH_INPUT,
144             .location = {.poolIndex = 0, .offset = 0, .length = 0},
145     });
146     return model->main.operands.size() - 1;
147 }
148 
addOperand(Model * model,OperandLifeTime lifetime)149 static uint32_t addOperand(Model* model, OperandLifeTime lifetime) {
150     uint32_t index = addOperand(model);
151     model->main.operands[index].lifetime = lifetime;
152     return index;
153 }
154 
155 // If we introduce a CONSTANT_COPY for an operand of size operandSize,
156 // how much will this increase the size of the model?  This assumes
157 // that we can (re)use all of model.operandValues for the operand
158 // value.
constantCopyExtraSize(const Model & model,size_t operandSize)159 static size_t constantCopyExtraSize(const Model& model, size_t operandSize) {
160     const size_t operandValuesSize = model.operandValues.size();
161     return (operandValuesSize < operandSize) ? (operandSize - operandValuesSize) : 0;
162 }
163 
164 // Highly specialized utility routine for converting an operand to
165 // CONSTANT_COPY lifetime.
166 //
167 // Expects that:
168 // - operand has a known size
169 // - operand->lifetime has already been set to CONSTANT_COPY
170 // - operand->location has been zeroed out
171 //
172 // Does the following:
173 // - initializes operand->location to point to the beginning of model->operandValues
174 // - resizes model->operandValues (if necessary) to be large enough for the operand
175 //   value, padding it with zeroes on the end
176 //
177 // Potential problem:
178 // By changing the operand to CONSTANT_COPY lifetime, this function is effectively initializing the
179 // operand with unspecified (but deterministic) data. This means that the model may be invalidated
180 // in two ways: not only is the lifetime of CONSTANT_COPY invalid, but the operand's value in the
181 // graph may also be invalid (e.g., if the operand is used as an activation code and has an invalid
182 // value). For now, this should be fine because it just means we're not testing what we think we're
183 // testing in certain cases; but we can handwave this and assume we're probabilistically likely to
184 // exercise the validation code over the span of the entire test set and operand space.
185 //
186 // Aborts if the specified operand type is an extension type or OEM type.
becomeConstantCopy(Model * model,Operand * operand)187 static void becomeConstantCopy(Model* model, Operand* operand) {
188     // sizeOfData will abort if the specified type is an extension type or OEM type.
189     const size_t sizeOfOperand = sizeOfData(*operand);
190     EXPECT_NE(sizeOfOperand, size_t(0));
191     operand->location.poolIndex = 0;
192     operand->location.offset = 0;
193     operand->location.length = sizeOfOperand;
194     if (model->operandValues.size() < sizeOfOperand) {
195         model->operandValues.resize(sizeOfOperand);
196     }
197 }
198 
// The sizeForBinder() functions estimate the size of a value's representation when it
// is sent over binder. The estimate probably runs a little low:
// - binder metadata (e.g. the encoded length of a vector) is not counted, but "big"
//   payloads such as vector contents are;
// - inter-field and end-of-struct padding is not treated methodically -- native (C++)
//   padding may or may not contribute, and the padding the binder encoding itself
//   needs (if any) is not modelled.
//
// Non-metadata is assumed to use a fixed-length encoding: for example, a uint32_t is
// always encoded in sizeof(uint32_t) bytes, whatever its magnitude.

// Fallback for trivially copyable values (scalars, enums, plain structs): the estimate
// is simply the value's in-memory size.
template <typename Type>
static size_t sizeForBinder(const Type& val) {
    static_assert(std::is_trivially_copyable_v<std::remove_reference_t<Type>>,
                  "expected a trivially copyable type");
    return sizeof(val);
}

// A vector is estimated as the sum of its elements' estimates. The vector's own length
// metadata is not counted.
template <typename Type>
static size_t sizeForBinder(const std::vector<Type>& vec) {
    // std::accumulate uses the type of its init argument as the accumulator type. A
    // plain `0` would make every partial sum an int, narrowing (and possibly
    // truncating) the size_t values returned for each element.
    return std::accumulate(vec.begin(), vec.end(), size_t{0},
                           [](size_t acc, const Type& x) { return acc + sizeForBinder(x); });
}
228 
229 template <>
sizeForBinder(const SymmPerChannelQuantParams & symmPerChannelQuantParams)230 size_t sizeForBinder(const SymmPerChannelQuantParams& symmPerChannelQuantParams) {
231     size_t size = 0;
232 
233     size += sizeForBinder(symmPerChannelQuantParams.scales);
234     size += sizeForBinder(symmPerChannelQuantParams.channelDim);
235 
236     return size;
237 }
238 
239 template <>
sizeForBinder(const std::optional<OperandExtraParams> & optionalExtraParams)240 size_t sizeForBinder(const std::optional<OperandExtraParams>& optionalExtraParams) {
241     if (!optionalExtraParams.has_value()) {
242         return 0;
243     }
244     const auto& extraParams = optionalExtraParams.value();
245     using Tag = OperandExtraParams::Tag;
246     switch (extraParams.getTag()) {
247         case Tag::channelQuant:
248             return sizeForBinder(extraParams.get<Tag::channelQuant>());
249         case Tag::extension:
250             return sizeForBinder(extraParams.get<Tag::extension>());
251     }
252     LOG(FATAL) << "Unrecognized extraParams tag: " << static_cast<int>(extraParams.getTag());
253     return 0;
254 }
255 
256 template <>
sizeForBinder(const Operand & operand)257 size_t sizeForBinder(const Operand& operand) {
258     size_t size = 0;
259 
260     size += sizeForBinder(operand.type);
261     size += sizeForBinder(operand.dimensions);
262     size += sizeForBinder(operand.scale);
263     size += sizeForBinder(operand.zeroPoint);
264     size += sizeForBinder(operand.lifetime);
265     size += sizeForBinder(operand.location);
266     size += sizeForBinder(operand.extraParams);
267 
268     return size;
269 }
270 
271 template <>
sizeForBinder(const Operation & operation)272 size_t sizeForBinder(const Operation& operation) {
273     size_t size = 0;
274 
275     size += sizeForBinder(operation.type);
276     size += sizeForBinder(operation.inputs);
277     size += sizeForBinder(operation.outputs);
278 
279     return size;
280 }
281 
282 template <>
sizeForBinder(const std::string & name)283 size_t sizeForBinder(const std::string& name) {
284     return name.size();
285 }
286 
287 template <>
sizeForBinder(const Memory & memory)288 size_t sizeForBinder(const Memory& memory) {
289     // This is just a guess.
290 
291     size_t size = sizeof(Memory);
292 
293     // Only hardwareBuffer type memory has dynamic memory that needs to be accounted for (in the
294     // form of a NativeHandle type). The other other types of memory (MappableFile, Ashmem) use a
295     // single file descriptor (with metadata) instead.
296     if (memory.getTag() == Memory::Tag::hardwareBuffer) {
297         const NativeHandle& handle = memory.get<Memory::Tag::hardwareBuffer>().handle;
298         size += sizeof(decltype(handle.fds)::value_type) * handle.fds.size();
299         size += sizeof(decltype(handle.ints)::value_type) * handle.ints.size();
300     }
301 
302     return size;
303 }
304 
305 template <>
sizeForBinder(const Subgraph & subgraph)306 size_t sizeForBinder(const Subgraph& subgraph) {
307     size_t size = 0;
308 
309     size += sizeForBinder(subgraph.operands);
310     size += sizeForBinder(subgraph.operations);
311     size += sizeForBinder(subgraph.inputIndexes);
312     size += sizeForBinder(subgraph.outputIndexes);
313 
314     return size;
315 }
316 
317 template <>
sizeForBinder(const ExtensionNameAndPrefix & extensionNameToPrefix)318 size_t sizeForBinder(const ExtensionNameAndPrefix& extensionNameToPrefix) {
319     size_t size = 0;
320 
321     size += sizeForBinder(extensionNameToPrefix.name);
322     size += sizeForBinder(extensionNameToPrefix.prefix);
323 
324     return size;
325 }
326 
327 template <>
sizeForBinder(const Model & model)328 size_t sizeForBinder(const Model& model) {
329     size_t size = 0;
330 
331     size += sizeForBinder(model.main);
332     size += sizeForBinder(model.referenced);
333     size += sizeForBinder(model.operandValues);
334     size += sizeForBinder(model.pools);
335     size += sizeForBinder(model.relaxComputationFloat32toFloat16);
336     size += sizeForBinder(model.extensionNameToPrefix);
337 
338     return size;
339 }
340 
// https://developer.android.com/reference/android/os/TransactionTooLargeException.html
//
//     "The Binder transaction buffer has a limited fixed size,
//     currently 1Mb, which is shared by all transactions in progress
//     for the process."
//
// Will our representation fit under this limit? There are three complications:
// - Our representation size is only approximate (see sizeForBinder()).
// - This object may not be the only occupant of the Binder transaction buffer
//   (although our VTS test suite should not put multiple objects in the buffer at once).
// - IBinder.MAX_IPC_SIZE recommends limiting a transaction to 64 * 1024 bytes.
// So we are very conservative: the representation size must be no larger than half the
// recommended limit.
//
// If our representation grows large enough that it still fits within the transaction
// buffer but, combined with other transactions, exceeds the buffer size, we may see
// intermittent HAL transport errors.
//
// Returns true when |representationSize| is strictly greater than 32 KiB.
static bool exceedsBinderSizeLimit(size_t representationSize) {
    // There is no C++ API to retrieve the value of the Java variable IBinder.MAX_IPC_SIZE,
    // so its 64 KiB value is hard-coded here (and halved).
    constexpr size_t kHalfMaxIPCSize = (64 * 1024) / 2;
    return representationSize > kHalfMaxIPCSize;
}
366 
367 ///////////////////////// VALIDATE EXECUTION ORDER ////////////////////////////
368 
mutateExecutionOrderTest(const std::shared_ptr<IDevice> & device,const Model & model,const std::vector<uint32_t> & numberOfConsumers)369 static void mutateExecutionOrderTest(const std::shared_ptr<IDevice>& device, const Model& model,
370                                      const std::vector<uint32_t>& numberOfConsumers) {
371     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
372         const Operation& operationObj = model.main.operations[operation];
373         for (uint32_t input : operationObj.inputs) {
374             if (model.main.operands[input].lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
375                 model.main.operands[input].lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
376                 // This operation reads an operand written by some
377                 // other operation.  Move this operation to the
378                 // beginning of the sequence, ensuring that it reads
379                 // the operand before that operand is written, thereby
380                 // violating execution order rules.
381                 const std::string message = "mutateExecutionOrderTest: operation " +
382                                             std::to_string(operation) + " is a reader";
383                 validate(device, message, model,
384                          [operation](Model* model, ExecutionPreference*, Priority*) {
385                              auto& operations = model->main.operations;
386                              std::rotate(operations.begin(), operations.begin() + operation,
387                                          operations.begin() + operation + 1);
388                          });
389                 break;  // only need to do this once per operation
390             }
391         }
392         for (uint32_t output : operationObj.outputs) {
393             if (numberOfConsumers[output] > 0) {
394                 // This operation writes an operand read by some other
395                 // operation.  Move this operation to the end of the
396                 // sequence, ensuring that it writes the operand after
397                 // that operand is read, thereby violating execution
398                 // order rules.
399                 const std::string message = "mutateExecutionOrderTest: operation " +
400                                             std::to_string(operation) + " is a writer";
401                 validate(device, message, model,
402                          [operation](Model* model, ExecutionPreference*, Priority*) {
403                              auto& operations = model->main.operations;
404                              std::rotate(operations.begin() + operation,
405                                          operations.begin() + operation + 1, operations.end());
406                          });
407                 break;  // only need to do this once per operation
408             }
409         }
410     }
411 }
412 
413 ///////////////////////// VALIDATE MODEL OPERAND TYPE /////////////////////////
414 
415 static const int32_t invalidOperandTypes[] = {
416         -1,
417         static_cast<int32_t>(*(ndk::enum_range<OperandType>().end() - 1)) + 1,
418 };
419 
mutateOperandTypeTest(const std::shared_ptr<IDevice> & device,const Model & model)420 static void mutateOperandTypeTest(const std::shared_ptr<IDevice>& device, const Model& model) {
421     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
422         for (int32_t invalidOperandType : invalidOperandTypes) {
423             const std::string message = "mutateOperandTypeTest: operand " +
424                                         std::to_string(operand) + " set to value " +
425                                         std::to_string(invalidOperandType);
426             validate(device, message, model,
427                      [operand, invalidOperandType](Model* model, ExecutionPreference*, Priority*) {
428                          model->main.operands[operand].type =
429                                  static_cast<OperandType>(invalidOperandType);
430                      });
431         }
432     }
433 }
434 
435 ///////////////////////// VALIDATE OPERAND RANK /////////////////////////
436 
getInvalidRank(OperandType type)437 static uint32_t getInvalidRank(OperandType type) {
438     switch (type) {
439         case OperandType::FLOAT16:
440         case OperandType::FLOAT32:
441         case OperandType::INT32:
442         case OperandType::UINT32:
443         case OperandType::BOOL:
444             return 1;
445         case OperandType::TENSOR_BOOL8:
446         case OperandType::TENSOR_FLOAT16:
447         case OperandType::TENSOR_FLOAT32:
448         case OperandType::TENSOR_INT32:
449         case OperandType::TENSOR_QUANT8_ASYMM:
450         case OperandType::TENSOR_QUANT8_SYMM:
451         case OperandType::TENSOR_QUANT16_ASYMM:
452         case OperandType::TENSOR_QUANT16_SYMM:
453         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
454             return 0;
455         default:
456             return 0;
457     }
458 }
459 
mutateOperandRankTest(const std::shared_ptr<IDevice> & device,const Model & model)460 static void mutateOperandRankTest(const std::shared_ptr<IDevice>& device, const Model& model) {
461     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
462         const uint32_t invalidRank = getInvalidRank(model.main.operands[operand].type);
463         if (invalidRank == 0) {
464             continue;
465         }
466         const std::string message = "mutateOperandRankTest: operand " + std::to_string(operand) +
467                                     " has rank of " + std::to_string(invalidRank);
468         validate(device, message, model,
469                  [operand, invalidRank](Model* model, ExecutionPreference*, Priority*) {
470                      model->main.operands[operand].dimensions =
471                              std::vector<int32_t>(invalidRank, 0);
472                  });
473     }
474 }
475 
476 ///////////////////////// VALIDATE OPERAND SCALE /////////////////////////
477 
getInvalidScale(OperandType type)478 static float getInvalidScale(OperandType type) {
479     switch (type) {
480         case OperandType::FLOAT16:
481         case OperandType::FLOAT32:
482         case OperandType::INT32:
483         case OperandType::UINT32:
484         case OperandType::BOOL:
485         case OperandType::TENSOR_BOOL8:
486         case OperandType::TENSOR_FLOAT16:
487         case OperandType::TENSOR_FLOAT32:
488         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
489         case OperandType::SUBGRAPH:
490             return 1.0f;
491         case OperandType::TENSOR_INT32:
492             return -1.0f;
493         case OperandType::TENSOR_QUANT8_SYMM:
494         case OperandType::TENSOR_QUANT8_ASYMM:
495         case OperandType::TENSOR_QUANT16_ASYMM:
496         case OperandType::TENSOR_QUANT16_SYMM:
497             return 0.0f;
498         default:
499             return 0.0f;
500     }
501 }
502 
mutateOperandScaleTest(const std::shared_ptr<IDevice> & device,const Model & model)503 static void mutateOperandScaleTest(const std::shared_ptr<IDevice>& device, const Model& model) {
504     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
505         const float invalidScale = getInvalidScale(model.main.operands[operand].type);
506         const std::string message = "mutateOperandScaleTest: operand " + std::to_string(operand) +
507                                     " has scale of " + std::to_string(invalidScale);
508         validate(device, message, model,
509                  [operand, invalidScale](Model* model, ExecutionPreference*, Priority*) {
510                      model->main.operands[operand].scale = invalidScale;
511                  });
512     }
513 }
514 
515 ///////////////////////// VALIDATE OPERAND ZERO POINT /////////////////////////
516 
getInvalidZeroPoints(OperandType type)517 static std::vector<int32_t> getInvalidZeroPoints(OperandType type) {
518     switch (type) {
519         case OperandType::FLOAT16:
520         case OperandType::FLOAT32:
521         case OperandType::INT32:
522         case OperandType::UINT32:
523         case OperandType::BOOL:
524         case OperandType::TENSOR_BOOL8:
525         case OperandType::TENSOR_FLOAT16:
526         case OperandType::TENSOR_FLOAT32:
527         case OperandType::TENSOR_INT32:
528         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
529         case OperandType::SUBGRAPH:
530             return {1};
531         case OperandType::TENSOR_QUANT8_ASYMM:
532             return {-1, 256};
533         case OperandType::TENSOR_QUANT8_SYMM:
534             return {-129, -1, 1, 128};
535         case OperandType::TENSOR_QUANT16_ASYMM:
536             return {-1, 65536};
537         case OperandType::TENSOR_QUANT16_SYMM:
538             return {-32769, -1, 1, 32768};
539         default:
540             return {};
541     }
542 }
543 
mutateOperandZeroPointTest(const std::shared_ptr<IDevice> & device,const Model & model)544 static void mutateOperandZeroPointTest(const std::shared_ptr<IDevice>& device, const Model& model) {
545     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
546         const std::vector<int32_t> invalidZeroPoints =
547                 getInvalidZeroPoints(model.main.operands[operand].type);
548         for (int32_t invalidZeroPoint : invalidZeroPoints) {
549             const std::string message = "mutateOperandZeroPointTest: operand " +
550                                         std::to_string(operand) + " has zero point of " +
551                                         std::to_string(invalidZeroPoint);
552             validate(device, message, model,
553                      [operand, invalidZeroPoint](Model* model, ExecutionPreference*, Priority*) {
554                          model->main.operands[operand].zeroPoint = invalidZeroPoint;
555                      });
556         }
557     }
558 }
559 
560 ///////////////////////// VALIDATE OPERAND LIFETIME /////////////////////////////////////////////
561 
getInvalidLifeTimes(const Model & model,size_t modelSize,const Operand & operand)562 static std::vector<OperandLifeTime> getInvalidLifeTimes(const Model& model, size_t modelSize,
563                                                         const Operand& operand) {
564     // TODO: Support OperandLifeTime::CONSTANT_REFERENCE as an invalid lifetime
565     // TODO: Support OperandLifeTime::NO_VALUE as an invalid lifetime
566 
567     // Ways to get an invalid lifetime:
568     // - change whether a lifetime means an operand should have a writer
569     std::vector<OperandLifeTime> ret;
570     switch (operand.lifetime) {
571         case OperandLifeTime::SUBGRAPH_OUTPUT:
572         case OperandLifeTime::TEMPORARY_VARIABLE:
573             ret = {
574                     OperandLifeTime::SUBGRAPH_INPUT,
575                     OperandLifeTime::CONSTANT_COPY,
576             };
577             break;
578         case OperandLifeTime::CONSTANT_COPY:
579         case OperandLifeTime::CONSTANT_POOL:
580         case OperandLifeTime::SUBGRAPH_INPUT:
581             ret = {
582                     OperandLifeTime::TEMPORARY_VARIABLE,
583                     OperandLifeTime::SUBGRAPH_OUTPUT,
584             };
585             break;
586         case OperandLifeTime::NO_VALUE:
587             // Not enough information to know whether
588             // TEMPORARY_VARIABLE or CONSTANT_COPY would be invalid --
589             // is this operand written (then CONSTANT_COPY would be
590             // invalid) or not (then TEMPORARY_VARIABLE would be
591             // invalid)?
592             break;
593         case OperandLifeTime::SUBGRAPH:
594             break;
595         default:
596             ADD_FAILURE();
597             break;
598     }
599 
600     const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
601     if (!operandSize ||
602         exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
603         // Unknown size or too-large size
604         ret.erase(std::remove(ret.begin(), ret.end(), OperandLifeTime::CONSTANT_COPY), ret.end());
605     }
606 
607     return ret;
608 }
609 
mutateOperandLifeTimeTest(const std::shared_ptr<IDevice> & device,const Model & model)610 static void mutateOperandLifeTimeTest(const std::shared_ptr<IDevice>& device, const Model& model) {
611     const size_t modelSize = sizeForBinder(model);
612     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
613         const std::vector<OperandLifeTime> invalidLifeTimes =
614                 getInvalidLifeTimes(model, modelSize, model.main.operands[operand]);
615         for (OperandLifeTime invalidLifeTime : invalidLifeTimes) {
616             const std::string message = "mutateOperandLifetimeTest: operand " +
617                                         std::to_string(operand) + " has lifetime " +
618                                         toString(invalidLifeTime) + " instead of lifetime " +
619                                         toString(model.main.operands[operand].lifetime);
620             validate(device, message, model,
621                      [operand, invalidLifeTime](Model* model, ExecutionPreference*, Priority*) {
622                          static const DataLocation kZeroDataLocation = {};
623                          Operand& operandObj = model->main.operands[operand];
624                          switch (operandObj.lifetime) {
625                              case OperandLifeTime::SUBGRAPH_INPUT: {
626                                  auto& inputs = model->main.inputIndexes;
627                                  inputs.erase(std::remove(inputs.begin(), inputs.end(), operand),
628                                               inputs.end());
629                                  break;
630                              }
631                              case OperandLifeTime::SUBGRAPH_OUTPUT: {
632                                  auto& outputs = model->main.outputIndexes;
633                                  outputs.erase(std::remove(outputs.begin(), outputs.end(), operand),
634                                                outputs.end());
635                                  break;
636                              }
637                              default:
638                                  break;
639                          }
640                          operandObj.lifetime = invalidLifeTime;
641                          operandObj.location = kZeroDataLocation;
642                          switch (invalidLifeTime) {
643                              case OperandLifeTime::CONSTANT_COPY: {
644                                  becomeConstantCopy(model, &operandObj);
645                                  break;
646                              }
647                              case OperandLifeTime::SUBGRAPH_INPUT:
648                                  model->main.inputIndexes.push_back(operand);
649                                  break;
650                              case OperandLifeTime::SUBGRAPH_OUTPUT:
651                                  model->main.outputIndexes.push_back(operand);
652                                  break;
653                              default:
654                                  break;
655                          }
656                      });
657         }
658     }
659 }
660 
661 ///////////////////////// VALIDATE OPERAND INPUT-or-OUTPUT //////////////////////////////////////
662 
getInputOutputLifeTime(const Model & model,size_t modelSize,const Operand & operand)663 static std::optional<OperandLifeTime> getInputOutputLifeTime(const Model& model, size_t modelSize,
664                                                              const Operand& operand) {
665     // Ways to get an invalid lifetime (with respect to model inputIndexes and outputIndexes):
666     // - change whether a lifetime means an operand is a model input, a model output, or neither
667     // - preserve whether or not a lifetime means an operand should have a writer
668     switch (operand.lifetime) {
669         case OperandLifeTime::CONSTANT_COPY:
670         case OperandLifeTime::CONSTANT_POOL:
671             return OperandLifeTime::SUBGRAPH_INPUT;
672         case OperandLifeTime::SUBGRAPH_INPUT: {
673             const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
674             if (!operandSize ||
675                 exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
676                 // Unknown size or too-large size
677                 break;
678             }
679             return OperandLifeTime::CONSTANT_COPY;
680         }
681         case OperandLifeTime::SUBGRAPH_OUTPUT:
682             return OperandLifeTime::TEMPORARY_VARIABLE;
683         case OperandLifeTime::TEMPORARY_VARIABLE:
684             return OperandLifeTime::SUBGRAPH_OUTPUT;
685         case OperandLifeTime::NO_VALUE:
686             // Not enough information to know whether
687             // TEMPORARY_VARIABLE or CONSTANT_COPY would be an
688             // appropriate choice -- is this operand written (then
689             // TEMPORARY_VARIABLE would be appropriate) or not (then
690             // CONSTANT_COPY would be appropriate)?
691             break;
692         case OperandLifeTime::SUBGRAPH:
693             break;
694         default:
695             ADD_FAILURE();
696             break;
697     }
698 
699     return std::nullopt;
700 }
701 
mutateOperandInputOutputTest(const std::shared_ptr<IDevice> & device,const Model & model)702 static void mutateOperandInputOutputTest(const std::shared_ptr<IDevice>& device,
703                                          const Model& model) {
704     const size_t modelSize = sizeForBinder(model);
705     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
706         const std::optional<OperandLifeTime> changedLifeTime =
707                 getInputOutputLifeTime(model, modelSize, model.main.operands[operand]);
708         if (changedLifeTime) {
709             const std::string message = "mutateOperandInputOutputTest: operand " +
710                                         std::to_string(operand) + " has lifetime " +
711                                         toString(*changedLifeTime) + " instead of lifetime " +
712                                         toString(model.main.operands[operand].lifetime);
713             validate(device, message, model,
714                      [operand, changedLifeTime](Model* model, ExecutionPreference*, Priority*) {
715                          static const DataLocation kZeroDataLocation = {};
716                          Operand& operandObj = model->main.operands[operand];
717                          operandObj.lifetime = *changedLifeTime;
718                          operandObj.location = kZeroDataLocation;
719                          if (*changedLifeTime == OperandLifeTime::CONSTANT_COPY) {
720                              becomeConstantCopy(model, &operandObj);
721                          }
722                      });
723         }
724     }
725 }
726 
727 ///////////////////////// VALIDATE OPERAND NUMBER OF WRITERS ////////////////////////////////////
728 
mutateOperandAddWriterTest(const std::shared_ptr<IDevice> & device,const Model & model)729 static void mutateOperandAddWriterTest(const std::shared_ptr<IDevice>& device, const Model& model) {
730     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
731         for (size_t badOutputNum = 0;
732              badOutputNum < model.main.operations[operation].outputs.size(); ++badOutputNum) {
733             const uint32_t outputOperandIndex =
734                     model.main.operations[operation].outputs[badOutputNum];
735             const std::string message = "mutateOperandAddWriterTest: operation " +
736                                         std::to_string(operation) + " writes to " +
737                                         std::to_string(outputOperandIndex);
738             // We'll insert a copy of the operation, all of whose
739             // OTHER output operands are newly-created -- i.e.,
740             // there'll only be a duplicate write of ONE of that
741             // operation's output operands.
742             validate(device, message, model,
743                      [operation, badOutputNum](Model* model, ExecutionPreference*, Priority*) {
744                          Operation newOperation = model->main.operations[operation];
745                          for (size_t outputNum = 0; outputNum < newOperation.outputs.size();
746                               ++outputNum) {
747                              if (outputNum == badOutputNum) continue;
748 
749                              Operand operandValue =
750                                      model->main.operands[newOperation.outputs[outputNum]];
751                              if (operandValue.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
752                                  operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
753                              } else {
754                                  ASSERT_EQ(operandValue.lifetime,
755                                            OperandLifeTime::TEMPORARY_VARIABLE);
756                              }
757                              newOperation.outputs[outputNum] = model->main.operands.size();
758                              model->main.operands.push_back(operandValue);
759                          }
760                          // Where do we insert the extra writer (a new
761                          // operation)?  It has to be later than all the
762                          // writers of its inputs.  The easiest thing to do
763                          // is to insert it at the end of the operation
764                          // sequence.
765                          model->main.operations.push_back(newOperation);
766                      });
767         }
768     }
769 }
770 
771 ///////////////////////// VALIDATE EXTRA ??? /////////////////////////
772 
773 // TODO: Operand::location
774 
775 ///////////////////////// VALIDATE OPERATION OPERAND TYPE /////////////////////////
776 
mutateOperand(Operand * operand,OperandType type)777 static void mutateOperand(Operand* operand, OperandType type) {
778     Operand newOperand = *operand;
779     newOperand.type = type;
780     switch (type) {
781         case OperandType::FLOAT16:
782         case OperandType::FLOAT32:
783         case OperandType::INT32:
784         case OperandType::UINT32:
785         case OperandType::BOOL:
786             newOperand.dimensions = {};
787             newOperand.scale = 0.0f;
788             newOperand.zeroPoint = 0;
789             break;
790         case OperandType::TENSOR_BOOL8:
791         case OperandType::TENSOR_FLOAT16:
792         case OperandType::TENSOR_FLOAT32:
793             newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
794                                                                    : std::vector<int32_t>({1});
795             newOperand.scale = 0.0f;
796             newOperand.zeroPoint = 0;
797             break;
798         case OperandType::TENSOR_INT32:
799             newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
800                                                                    : std::vector<int32_t>({1});
801             newOperand.zeroPoint = 0;
802             break;
803         case OperandType::TENSOR_QUANT8_ASYMM:
804         case OperandType::TENSOR_QUANT8_SYMM:
805         case OperandType::TENSOR_QUANT16_ASYMM:
806         case OperandType::TENSOR_QUANT16_SYMM:
807             newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
808                                                                    : std::vector<int32_t>({1});
809             newOperand.scale = operand->scale != 0.0f ? operand->scale : 1.0f;
810             break;
811         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
812             newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
813                                                                    : std::vector<int32_t>({1});
814             newOperand.scale = 0.0f;
815             newOperand.zeroPoint = 0;
816 
817             SymmPerChannelQuantParams channelQuant;
818             channelQuant.channelDim = 0;
819             channelQuant.scales = std::vector<float>(
820                     operand->dimensions.size() > 0 ? static_cast<size_t>(operand->dimensions[0])
821                                                    : 0);
822             for (size_t i = 0; i < channelQuant.scales.size(); ++i) {
823                 channelQuant.scales[i] = 1.0f;
824             }
825             newOperand.extraParams->set<OperandExtraParams::Tag::channelQuant>(
826                     std::move(channelQuant));
827         } break;
828         default:
829             break;
830     }
831     *operand = newOperand;
832 }
833 
mutateOperationOperandTypeSkip(size_t operand,OperandType type,const Model & model)834 static bool mutateOperationOperandTypeSkip(size_t operand, OperandType type, const Model& model) {
835     if (type == model.main.operands[operand].type) {
836         return true;
837     }
838     for (const Operation& operation : model.main.operations) {
839         // Skip mutateOperationOperandTypeTest for the following operations.
840         // - LSH_PROJECTION's second argument is allowed to have any type.
841         // - ARGMIN and ARGMAX's first argument can be any of
842         // TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
843         // - CAST's argument can be any of TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
844         // - RANDOM_MULTINOMIAL's argument can be either TENSOR_FLOAT16 or TENSOR_FLOAT32.
845         // - DEQUANTIZE input can be any of
846         // TENSOR_(QUANT8_ASYMM|QUANT8_ASYMM_SIGNED|QUANT8_SYMM|QUANT8_SYMM_PER_CHANNEL),
847         // output can be of either TENSOR_FLOAT16 or TENSOR_FLOAT32.
848         // - QUANTIZE input can be either TENSOR_FLOAT16 or TENSOR_FLOAT32
849         // - CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
850         // - DEPTHWISE_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
851         // - GROUPED_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
852         // - TRANSPOSE_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
853         // - AXIS_ALIGNED_BBOX_TRANSFORM bounding boxes (arg 1) can be of
854         //     TENSOR_QUANT8_ASYMM or TENSOR_QUANT8_ASYMM_SIGNED.
855         // - RANK's input can have any TENSOR_* type.
856         switch (operation.type) {
857             case OperationType::LSH_PROJECTION: {
858                 if (operand == operation.inputs[1]) {
859                     return true;
860                 }
861             } break;
862             case OperationType::CAST:
863             case OperationType::ARGMAX:
864             case OperationType::ARGMIN: {
865                 if (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32 ||
866                     type == OperandType::TENSOR_INT32 || type == OperandType::TENSOR_QUANT8_ASYMM ||
867                     type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
868                     return true;
869                 }
870             } break;
871             case OperationType::QUANTIZE: {
872                 if (operand == operation.inputs[0] &&
873                     (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
874                     return true;
875                 }
876                 if (operand == operation.outputs[0] &&
877                     (type == OperandType::TENSOR_QUANT8_ASYMM ||
878                      type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)) {
879                     return true;
880                 }
881             } break;
882             case OperationType::RANDOM_MULTINOMIAL: {
883                 if (operand == operation.inputs[0] &&
884                     (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
885                     return true;
886                 }
887             } break;
888             case OperationType::DEQUANTIZE: {
889                 if (operand == operation.inputs[0] &&
890                     (type == OperandType::TENSOR_QUANT8_ASYMM ||
891                      type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED ||
892                      type == OperandType::TENSOR_QUANT8_SYMM ||
893                      type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)) {
894                     return true;
895                 }
896                 if (operand == operation.outputs[0] &&
897                     (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
898                     return true;
899                 }
900             } break;
901             case OperationType::TRANSPOSE_CONV_2D:
902             case OperationType::GROUPED_CONV_2D:
903             case OperationType::DEPTHWISE_CONV_2D:
904             case OperationType::CONV_2D: {
905                 if (operand == operation.inputs[1] &&
906                     (type == OperandType::TENSOR_QUANT8_ASYMM ||
907                      type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)) {
908                     return true;
909                 }
910             } break;
911             case OperationType::AXIS_ALIGNED_BBOX_TRANSFORM: {
912                 if (operand == operation.inputs[1] &&
913                     (type == OperandType::TENSOR_QUANT8_ASYMM ||
914                      type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)) {
915                     return true;
916                 }
917             } break;
918             case OperationType::RANK: {
919                 if (operand == operation.inputs[0] &&
920                     (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32 ||
921                      type == OperandType::TENSOR_INT32 ||
922                      type == OperandType::TENSOR_QUANT8_ASYMM ||
923                      type == OperandType::TENSOR_QUANT16_SYMM ||
924                      type == OperandType::TENSOR_BOOL8 ||
925                      type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL ||
926                      type == OperandType::TENSOR_QUANT16_ASYMM ||
927                      type == OperandType::TENSOR_QUANT8_SYMM ||
928                      type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)) {
929                     return true;
930                 }
931             } break;
932             default:
933                 break;
934         }
935     }
936     return false;
937 }
938 
mutateOperationOperandTypeTest(const std::shared_ptr<IDevice> & device,const Model & model)939 static void mutateOperationOperandTypeTest(const std::shared_ptr<IDevice>& device,
940                                            const Model& model) {
941     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
942         for (OperandType invalidOperandType : ndk::enum_range<OperandType>()) {
943             if (mutateOperationOperandTypeSkip(operand, invalidOperandType, model)) {
944                 continue;
945             }
946             const std::string message = "mutateOperationOperandTypeTest: operand " +
947                                         std::to_string(operand) + " set to type " +
948                                         toString(invalidOperandType);
949             validate(device, message, model,
950                      [operand, invalidOperandType](Model* model, ExecutionPreference*, Priority*) {
951                          mutateOperand(&model->main.operands[operand], invalidOperandType);
952                      });
953         }
954     }
955 }
956 
957 ///////////////////////// VALIDATE MODEL OPERATION TYPE /////////////////////////
958 
959 static const int32_t invalidOperationTypes[] = {
960         -1,
961         static_cast<int32_t>(*(ndk::enum_range<OperationType>().end() - 1)) + 1,
962 };
963 
mutateOperationTypeTest(const std::shared_ptr<IDevice> & device,const Model & model)964 static void mutateOperationTypeTest(const std::shared_ptr<IDevice>& device, const Model& model) {
965     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
966         for (int32_t invalidOperationType : invalidOperationTypes) {
967             const std::string message = "mutateOperationTypeTest: operation " +
968                                         std::to_string(operation) + " set to value " +
969                                         std::to_string(invalidOperationType);
970             validate(device, message, model,
971                      [operation, invalidOperationType](Model* model, ExecutionPreference*,
972                                                        Priority*) {
973                          model->main.operations[operation].type =
974                                  static_cast<OperationType>(invalidOperationType);
975                      });
976         }
977     }
978 }
979 
980 ///////////////////////// VALIDATE MODEL OPERATION INPUT OPERAND INDEX /////////////////////////
981 
mutateOperationInputOperandIndexTest(const std::shared_ptr<IDevice> & device,const Model & model)982 static void mutateOperationInputOperandIndexTest(const std::shared_ptr<IDevice>& device,
983                                                  const Model& model) {
984     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
985         const uint32_t invalidOperand = model.main.operands.size();
986         for (size_t input = 0; input < model.main.operations[operation].inputs.size(); ++input) {
987             const std::string message = "mutateOperationInputOperandIndexTest: operation " +
988                                         std::to_string(operation) + " input " +
989                                         std::to_string(input);
990             validate(device, message, model,
991                      [operation, input, invalidOperand](Model* model, ExecutionPreference*,
992                                                         Priority*) {
993                          model->main.operations[operation].inputs[input] = invalidOperand;
994                      });
995         }
996     }
997 }
998 
999 ///////////////////////// VALIDATE MODEL OPERATION OUTPUT OPERAND INDEX /////////////////////////
1000 
mutateOperationOutputOperandIndexTest(const std::shared_ptr<IDevice> & device,const Model & model)1001 static void mutateOperationOutputOperandIndexTest(const std::shared_ptr<IDevice>& device,
1002                                                   const Model& model) {
1003     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1004         const uint32_t invalidOperand = model.main.operands.size();
1005         for (size_t output = 0; output < model.main.operations[operation].outputs.size();
1006              ++output) {
1007             const std::string message = "mutateOperationOutputOperandIndexTest: operation " +
1008                                         std::to_string(operation) + " output " +
1009                                         std::to_string(output);
1010             validate(device, message, model,
1011                      [operation, output, invalidOperand](Model* model, ExecutionPreference*,
1012                                                          Priority*) {
1013                          model->main.operations[operation].outputs[output] = invalidOperand;
1014                      });
1015         }
1016     }
1017 }
1018 
1019 ///////////////////////// VALIDATE MODEL OPERANDS WRITTEN ///////////////////////////////////////
1020 
mutateOperationRemoveWriteTest(const std::shared_ptr<IDevice> & device,const Model & model,const std::vector<uint32_t> & numberOfConsumers)1021 static void mutateOperationRemoveWriteTest(const std::shared_ptr<IDevice>& device,
1022                                            const Model& model,
1023                                            const std::vector<uint32_t>& numberOfConsumers) {
1024     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1025         for (size_t outputNum = 0; outputNum < model.main.operations[operation].outputs.size();
1026              ++outputNum) {
1027             const uint32_t outputOperandIndex = model.main.operations[operation].outputs[outputNum];
1028             if (numberOfConsumers[outputOperandIndex] > 0) {
1029                 const std::string message = "mutateOperationRemoveWriteTest: operation " +
1030                                             std::to_string(operation) + " writes to " +
1031                                             std::to_string(outputOperandIndex);
1032                 validate(device, message, model,
1033                          [operation, outputNum](Model* model, ExecutionPreference*, Priority*) {
1034                              int32_t& outputOperandIndex =
1035                                      model->main.operations[operation].outputs[outputNum];
1036                              Operand operandValue = model->main.operands[outputOperandIndex];
1037                              if (operandValue.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
1038                                  operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
1039                              } else {
1040                                  ASSERT_EQ(operandValue.lifetime,
1041                                            OperandLifeTime::TEMPORARY_VARIABLE);
1042                              }
1043                              outputOperandIndex = model->main.operands.size();
1044                              model->main.operands.push_back(operandValue);
1045                          });
1046             }
1047         }
1048     }
1049 }
1050 
1051 ///////////////////////// REMOVE OPERAND FROM EVERYTHING /////////////////////////
1052 
// Removes every occurrence of operand index `value` from `*vec`, then decrements each remaining
// index greater than `value`. This keeps the list consistent with an operand vector from which
// element `value` has been erased. Does nothing when `vec` is null.
static void removeValueAndDecrementGreaterValues(std::vector<int32_t>* vec, uint32_t value) {
    if (vec == nullptr) {
        return;
    }

    // Remove elements matching "value".
    vec->erase(std::remove(vec->begin(), vec->end(), value), vec->end());

    // Decrement elements exceeding "value". Note: this must be `v - 1`; a post-decrement
    // (`v--`) would return the unchanged value and leave the indices pointing one past the
    // operand they used to reference.
    std::transform(vec->begin(), vec->end(), vec->begin(),
                   [value](uint32_t v) { return v > value ? v - 1 : v; });
}
1063 
removeOperand(Model * model,uint32_t index)1064 static void removeOperand(Model* model, uint32_t index) {
1065     model->main.operands.erase(model->main.operands.begin() + index);
1066     for (Operation& operation : model->main.operations) {
1067         removeValueAndDecrementGreaterValues(&operation.inputs, index);
1068         removeValueAndDecrementGreaterValues(&operation.outputs, index);
1069     }
1070     removeValueAndDecrementGreaterValues(&model->main.inputIndexes, index);
1071     removeValueAndDecrementGreaterValues(&model->main.outputIndexes, index);
1072 }
1073 
removeOperandSkip(size_t operandIndex,const Model & model,const std::vector<uint32_t> & numberOfConsumers)1074 static bool removeOperandSkip(size_t operandIndex, const Model& model,
1075                               const std::vector<uint32_t>& numberOfConsumers) {
1076     if (numberOfConsumers[operandIndex] == 0) {
1077         // Removing an unused operand has no effect.
1078         return true;
1079     }
1080     for (const Operation& operation : model.main.operations) {
1081         // Skip removeOperandTest for the following operations.
1082         // - SPLIT's outputs are not checked during prepareModel.
1083         if (operation.type == OperationType::SPLIT) {
1084             for (const size_t index : operation.outputs) {
1085                 if (index == operandIndex) {
1086                     return true;
1087                 }
1088             }
1089         }
1090         // BIDIRECTIONAL_SEQUENCE_LSTM and BIDIRECTIONAL_SEQUENCE_RNN can have
1091         // either one, two, three or four outputs depending on their
1092         // mergeOutputs parameter and if state outputs are provided.
1093         // UNIDIRECTIONAL_SEQUENCE_LSTM and UNIDIRECTIONAL_SEQUENCE_RNN can have
1094         // either one or three outputs depending on whether state outputs are
1095         // provided.
1096         if (operation.type == OperationType::UNIDIRECTIONAL_SEQUENCE_LSTM ||
1097             operation.type == OperationType::UNIDIRECTIONAL_SEQUENCE_RNN ||
1098             operation.type == OperationType::BIDIRECTIONAL_SEQUENCE_LSTM ||
1099             operation.type == OperationType::BIDIRECTIONAL_SEQUENCE_RNN) {
1100             for (const size_t index : operation.outputs) {
1101                 if (index == operandIndex) {
1102                     return true;
1103                 }
1104             }
1105         }
1106     }
1107     return false;
1108 }
1109 
removeOperandTest(const std::shared_ptr<IDevice> & device,const Model & model,const std::vector<uint32_t> & numberOfConsumers)1110 static void removeOperandTest(const std::shared_ptr<IDevice>& device, const Model& model,
1111                               const std::vector<uint32_t>& numberOfConsumers) {
1112     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
1113         if (removeOperandSkip(operand, model, numberOfConsumers)) {
1114             continue;
1115         }
1116         const std::string message = "removeOperandTest: operand " + std::to_string(operand);
1117         validate(device, message, model, [operand](Model* model, ExecutionPreference*, Priority*) {
1118             removeOperand(model, operand);
1119         });
1120     }
1121 }
1122 
1123 ///////////////////////// REMOVE OPERATION /////////////////////////
1124 
removeOperation(Model * model,uint32_t index)1125 static void removeOperation(Model* model, uint32_t index) {
1126     auto& operations = model->main.operations;
1127     operations.erase(operations.begin() + index);
1128 }
1129 
removeOperationTest(const std::shared_ptr<IDevice> & device,const Model & model)1130 static void removeOperationTest(const std::shared_ptr<IDevice>& device, const Model& model) {
1131     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1132         const std::string message = "removeOperationTest: operation " + std::to_string(operation);
1133         validate(device, message, model,
1134                  [operation](Model* model, ExecutionPreference*, Priority*) {
1135                      removeOperation(model, operation);
1136                  });
1137     }
1138 }
1139 
1140 ///////////////////////// REMOVE OPERATION INPUT /////////////////////////
1141 
removeOperationInputSkip(const Operation & op,size_t input)1142 static bool removeOperationInputSkip(const Operation& op, size_t input) {
1143     // Skip removeOperationInputTest for the following operations.
1144     // - CONCATENATION has at least 2 inputs, with the last element being INT32.
1145     // - CONV_2D, DEPTHWISE_CONV_2D, MAX_POOL_2D, AVERAGE_POOL_2D, L2_POOL_2D, RESIZE_BILINEAR,
1146     //   SPACE_TO_DEPTH, SPACE_TO_DEPTH, SPACE_TO_BATCH_ND, BATCH_TO_SPACE_ND can have an optional
1147     //   layout parameter.
1148     //   RESIZE_BILINEAR and RESIZE_NEAREST_NEIGHBOR can have optional
1149     //   align_corners and half_pixel_centers parameters.
1150     // - L2_NORMALIZATION, LOCAL_RESPONSE_NORMALIZATION, SOFTMAX can have an optional axis
1151     //   parameter.
1152     // - PACK has at least 2 inputs, with the first element being INT32.
1153     switch (op.type) {
1154         case OperationType::CONCATENATION: {
1155             if (op.inputs.size() > 2 && input != op.inputs.size() - 1) {
1156                 return true;
1157             }
1158         } break;
1159         case OperationType::DEPTHWISE_CONV_2D: {
1160             if ((op.inputs.size() == 12 && input == 11) || (op.inputs.size() == 9 && input == 8)) {
1161                 return true;
1162             }
1163         } break;
1164         case OperationType::CONV_2D:
1165         case OperationType::AVERAGE_POOL_2D:
1166         case OperationType::MAX_POOL_2D:
1167         case OperationType::L2_POOL_2D: {
1168             if ((op.inputs.size() == 11 && input == 10) || (op.inputs.size() == 8 && input == 7)) {
1169                 return true;
1170             }
1171         } break;
1172         case OperationType::RESIZE_BILINEAR: {
1173             if (op.inputs.size() >= 4 && input >= 3) {
1174                 return true;
1175             }
1176         } break;
1177         case OperationType::RESIZE_NEAREST_NEIGHBOR: {
1178             if (op.inputs.size() >= 5 && input >= 3) {
1179                 return true;
1180             }
1181         } break;
1182         case OperationType::SPACE_TO_DEPTH:
1183         case OperationType::DEPTH_TO_SPACE:
1184         case OperationType::BATCH_TO_SPACE_ND: {
1185             if (op.inputs.size() == 3 && input == 2) {
1186                 return true;
1187             }
1188         } break;
1189         case OperationType::SPACE_TO_BATCH_ND: {
1190             if (op.inputs.size() == 4 && input == 3) {
1191                 return true;
1192             }
1193         } break;
1194         case OperationType::L2_NORMALIZATION: {
1195             if (op.inputs.size() == 2 && input == 1) {
1196                 return true;
1197             }
1198         } break;
1199         case OperationType::LOCAL_RESPONSE_NORMALIZATION: {
1200             if (op.inputs.size() == 6 && input == 5) {
1201                 return true;
1202             }
1203         } break;
1204         case OperationType::SOFTMAX: {
1205             if (op.inputs.size() == 3 && input == 2) {
1206                 return true;
1207             }
1208         } break;
1209         case OperationType::PACK: {
1210             if (op.inputs.size() > 2 && input != 0) {
1211                 return true;
1212             }
1213         } break;
1214         default:
1215             break;
1216     }
1217     return false;
1218 }
1219 
removeOperationInputTest(const std::shared_ptr<IDevice> & device,const Model & model)1220 static void removeOperationInputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
1221     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1222         for (size_t input = 0; input < model.main.operations[operation].inputs.size(); ++input) {
1223             const Operation& op = model.main.operations[operation];
1224             if (removeOperationInputSkip(op, input)) {
1225                 continue;
1226             }
1227             const std::string message = "removeOperationInputTest: operation " +
1228                                         std::to_string(operation) + ", input " +
1229                                         std::to_string(input);
1230             validate(device, message, model,
1231                      [operation, input](Model* model, ExecutionPreference*, Priority*) {
1232                          auto& inputs = model->main.operations[operation].inputs;
1233                          inputs.erase(inputs.begin() + input);
1234                      });
1235         }
1236     }
1237 }
1238 
1239 ///////////////////////// REMOVE OPERATION OUTPUT /////////////////////////
1240 
removeOperationOutputTest(const std::shared_ptr<IDevice> & device,const Model & model)1241 static void removeOperationOutputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
1242     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1243         for (size_t output = 0; output < model.main.operations[operation].outputs.size();
1244              ++output) {
1245             const std::string message = "removeOperationOutputTest: operation " +
1246                                         std::to_string(operation) + ", output " +
1247                                         std::to_string(output);
1248             validate(device, message, model,
1249                      [operation, output](Model* model, ExecutionPreference*, Priority*) {
1250                          auto& outputs = model->main.operations[operation].outputs;
1251                          outputs.erase(outputs.begin() + output);
1252                      });
1253         }
1254     }
1255 }
1256 
1257 ///////////////////////// MODEL VALIDATION /////////////////////////
1258 
1259 // TODO: remove model input
1260 // TODO: remove model output
1261 // TODO: add unused operation
1262 
1263 ///////////////////////// ADD OPERATION INPUT /////////////////////////
1264 
addOperationInputSkip(const Operation & op)1265 static bool addOperationInputSkip(const Operation& op) {
1266     // Skip addOperationInputTest for the following operations.
1267     // - L2_NORMALIZATION, LOCAL_RESPONSE_NORMALIZATION, SOFTMAX can have an optional INT32 axis
1268     //   parameter.
1269     if ((op.type == OperationType::L2_NORMALIZATION && op.inputs.size() == 1) ||
1270         (op.type == OperationType::LOCAL_RESPONSE_NORMALIZATION && op.inputs.size() == 5) ||
1271         (op.type == OperationType::SOFTMAX && op.inputs.size() == 2) ||
1272         (op.type == OperationType::RESIZE_BILINEAR && op.inputs.size() < 6) ||
1273         (op.type == OperationType::RESIZE_NEAREST_NEIGHBOR && op.inputs.size() < 6)) {
1274         return true;
1275     }
1276     return false;
1277 }
1278 
addOperationInputTest(const std::shared_ptr<IDevice> & device,const Model & model)1279 static void addOperationInputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
1280     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1281         if (addOperationInputSkip(model.main.operations[operation])) {
1282             continue;
1283         }
1284         const std::string message = "addOperationInputTest: operation " + std::to_string(operation);
1285         validate(device, message, model,
1286                  [operation](Model* model, ExecutionPreference*, Priority*) {
1287                      uint32_t index = addOperand(model, OperandLifeTime::SUBGRAPH_INPUT);
1288                      model->main.operations[operation].inputs.push_back(index);
1289                      model->main.inputIndexes.push_back(index);
1290                  });
1291     }
1292 }
1293 
1294 ///////////////////////// ADD OPERATION OUTPUT /////////////////////////
1295 
addOperationOutputTest(const std::shared_ptr<IDevice> & device,const Model & model)1296 static void addOperationOutputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
1297     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1298         const std::string message =
1299                 "addOperationOutputTest: operation " + std::to_string(operation);
1300         validate(device, message, model,
1301                  [operation](Model* model, ExecutionPreference*, Priority*) {
1302                      uint32_t index = addOperand(model, OperandLifeTime::SUBGRAPH_OUTPUT);
1303                      model->main.operations[operation].outputs.push_back(index);
1304                      model->main.outputIndexes.push_back(index);
1305                  });
1306     }
1307 }
1308 
1309 ///////////////////////// VALIDATE EXECUTION PREFERENCE /////////////////////////
1310 
1311 static const int32_t invalidExecutionPreferences[] = {
1312         static_cast<int32_t>(ExecutionPreference::LOW_POWER) - 1,        // lower bound
1313         static_cast<int32_t>(ExecutionPreference::SUSTAINED_SPEED) + 1,  // upper bound
1314 };
1315 
mutateExecutionPreferenceTest(const std::shared_ptr<IDevice> & device,const Model & model)1316 static void mutateExecutionPreferenceTest(const std::shared_ptr<IDevice>& device,
1317                                           const Model& model) {
1318     for (int32_t invalidPreference : invalidExecutionPreferences) {
1319         const std::string message =
1320                 "mutateExecutionPreferenceTest: preference " + std::to_string(invalidPreference);
1321         validate(device, message, model,
1322                  [invalidPreference](Model*, ExecutionPreference* preference, Priority*) {
1323                      *preference = static_cast<ExecutionPreference>(invalidPreference);
1324                  });
1325     }
1326 }
1327 
1328 ///////////////////////// VALIDATE PRIORITY /////////////////////////
1329 
1330 static const int32_t invalidPriorities[] = {
1331         static_cast<int32_t>(Priority::LOW) - 1,   // lower bound
1332         static_cast<int32_t>(Priority::HIGH) + 1,  // upper bound
1333 };
1334 
mutateExecutionPriorityTest(const std::shared_ptr<IDevice> & device,const Model & model)1335 static void mutateExecutionPriorityTest(const std::shared_ptr<IDevice>& device,
1336                                         const Model& model) {
1337     for (int32_t invalidPriority : invalidPriorities) {
1338         const std::string message =
1339                 "mutatePriorityTest: priority " + std::to_string(invalidPriority);
1340         validate(device, message, model,
1341                  [invalidPriority](Model*, ExecutionPreference*, Priority* priority) {
1342                      *priority = static_cast<Priority>(invalidPriority);
1343                  });
1344     }
1345 }
1346 
1347 ////////////////////////// ENTRY POINT //////////////////////////////
1348 
validateModel(const std::shared_ptr<IDevice> & device,const Model & model)1349 void validateModel(const std::shared_ptr<IDevice>& device, const Model& model) {
1350     const auto numberOfConsumers =
1351             countNumberOfConsumers(model.main.operands.size(),
1352                                    nn::unvalidatedConvert(model.main.operations).value())
1353                     .value();
1354     mutateExecutionOrderTest(device, model, numberOfConsumers);
1355     mutateOperandTypeTest(device, model);
1356     mutateOperandRankTest(device, model);
1357     mutateOperandScaleTest(device, model);
1358     mutateOperandZeroPointTest(device, model);
1359     mutateOperandLifeTimeTest(device, model);
1360     mutateOperandInputOutputTest(device, model);
1361     mutateOperandAddWriterTest(device, model);
1362     mutateOperationOperandTypeTest(device, model);
1363     mutateOperationTypeTest(device, model);
1364     mutateOperationInputOperandIndexTest(device, model);
1365     mutateOperationOutputOperandIndexTest(device, model);
1366     mutateOperationRemoveWriteTest(device, model, numberOfConsumers);
1367     removeOperandTest(device, model, numberOfConsumers);
1368     removeOperationTest(device, model);
1369     removeOperationInputTest(device, model);
1370     removeOperationOutputTest(device, model);
1371     addOperationInputTest(device, model);
1372     addOperationOutputTest(device, model);
1373     mutateExecutionPreferenceTest(device, model);
1374     mutateExecutionPriorityTest(device, model);
1375 }
1376 
1377 }  // namespace aidl::android::hardware::neuralnetworks::vts::functional
1378