xref: /aosp_15_r20/hardware/interfaces/neuralnetworks/1.2/vts/functional/ValidateModel.cpp (revision 4d7e907c777eeecc4c5bd7cf640a754fac206ff7)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18 
19 #include <android/hardware/neuralnetworks/1.1/types.h>
20 #include "1.0/Utils.h"
21 #include "1.2/Callbacks.h"
22 #include "1.2/Utils.h"
23 #include "GeneratedTestHarness.h"
24 #include "VtsHalNeuralnetworks.h"
25 
26 #include <optional>
27 #include <type_traits>
28 #include <utility>
29 
30 namespace android::hardware::neuralnetworks::V1_2::vts::functional {
31 
32 using implementation::PreparedModelCallback;
33 using V1_0::DataLocation;
34 using V1_0::ErrorStatus;
35 using V1_0::OperandLifeTime;
36 using V1_1::ExecutionPreference;
37 using HidlToken = hidl_array<uint8_t, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>;
38 
39 using PrepareModelMutation = std::function<void(Model*, ExecutionPreference*)>;
40 
41 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
42 
validateGetSupportedOperations(const sp<IDevice> & device,const std::string & message,const Model & model)43 static void validateGetSupportedOperations(const sp<IDevice>& device, const std::string& message,
44                                            const Model& model) {
45     SCOPED_TRACE(message + " [getSupportedOperations_1_2]");
46 
47     Return<void> ret = device->getSupportedOperations_1_2(
48             model, [&](ErrorStatus status, const hidl_vec<bool>&) {
49                 EXPECT_EQ(ErrorStatus::INVALID_ARGUMENT, status);
50             });
51     EXPECT_TRUE(ret.isOk());
52 }
53 
validatePrepareModel(const sp<IDevice> & device,const std::string & message,const Model & model,ExecutionPreference preference)54 static void validatePrepareModel(const sp<IDevice>& device, const std::string& message,
55                                  const Model& model, ExecutionPreference preference) {
56     SCOPED_TRACE(message + " [prepareModel_1_2]");
57 
58     sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
59     Return<ErrorStatus> prepareLaunchStatus =
60             device->prepareModel_1_2(model, preference, hidl_vec<hidl_handle>(),
61                                      hidl_vec<hidl_handle>(), HidlToken(), preparedModelCallback);
62     ASSERT_TRUE(prepareLaunchStatus.isOk());
63     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, static_cast<ErrorStatus>(prepareLaunchStatus));
64 
65     preparedModelCallback->wait();
66     ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
67     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, prepareReturnStatus);
68     sp<IPreparedModel> preparedModel = getPreparedModel_1_2(preparedModelCallback);
69     ASSERT_EQ(nullptr, preparedModel.get());
70 }
71 
validExecutionPreference(ExecutionPreference preference)72 static bool validExecutionPreference(ExecutionPreference preference) {
73     return preference == ExecutionPreference::LOW_POWER ||
74            preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
75            preference == ExecutionPreference::SUSTAINED_SPEED;
76 }
77 
78 // Primary validation function. This function will take a valid model, apply a
79 // mutation to invalidate either the model or the execution preference, then
80 // pass these to supportedOperations and/or prepareModel if that method is
81 // called with an invalid argument.
validate(const sp<IDevice> & device,const std::string & message,const Model & originalModel,const PrepareModelMutation & mutate)82 static void validate(const sp<IDevice>& device, const std::string& message,
83                      const Model& originalModel, const PrepareModelMutation& mutate) {
84     Model model = originalModel;
85     ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER;
86     mutate(&model, &preference);
87 
88     if (validExecutionPreference(preference)) {
89         validateGetSupportedOperations(device, message, model);
90     }
91 
92     validatePrepareModel(device, message, model, preference);
93 }
94 
addOperand(Model * model)95 static uint32_t addOperand(Model* model) {
96     return hidl_vec_push_back(&model->operands,
97                               {
98                                       .type = OperandType::INT32,
99                                       .dimensions = {},
100                                       .numberOfConsumers = 0,
101                                       .scale = 0.0f,
102                                       .zeroPoint = 0,
103                                       .lifetime = OperandLifeTime::MODEL_INPUT,
104                                       .location = {.poolIndex = 0, .offset = 0, .length = 0},
105                               });
106 }
107 
addOperand(Model * model,OperandLifeTime lifetime)108 static uint32_t addOperand(Model* model, OperandLifeTime lifetime) {
109     uint32_t index = addOperand(model);
110     model->operands[index].numberOfConsumers = 1;
111     model->operands[index].lifetime = lifetime;
112     return index;
113 }
114 
115 // If we introduce a CONSTANT_COPY for an operand of size operandSize,
116 // how much will this increase the size of the model?  This assumes
117 // that we can (re)use all of model.operandValues for the operand
118 // value.
constantCopyExtraSize(const Model & model,size_t operandSize)119 static size_t constantCopyExtraSize(const Model& model, size_t operandSize) {
120     const size_t operandValuesSize = model.operandValues.size();
121     return (operandValuesSize < operandSize) ? (operandSize - operandValuesSize) : 0;
122 }
123 
124 // Highly specialized utility routine for converting an operand to
125 // CONSTANT_COPY lifetime.
126 //
127 // Expects that:
128 // - operand has a known size
129 // - operand->lifetime has already been set to CONSTANT_COPY
130 // - operand->location has been zeroed out
131 //
132 // Does the following:
133 // - initializes operand->location to point to the beginning of model->operandValues
134 // - resizes model->operandValues (if necessary) to be large enough for the operand
135 //   value, padding it with zeroes on the end
136 //
137 // Potential problem:
138 // By changing the operand to CONSTANT_COPY lifetime, this function is effectively initializing the
139 // operand with unspecified (but deterministic) data. This means that the model may be invalidated
140 // in two ways: not only is the lifetime of CONSTANT_COPY invalid, but the operand's value in the
141 // graph may also be invalid (e.g., if the operand is used as an activation code and has an invalid
142 // value). For now, this should be fine because it just means we're not testing what we think we're
143 // testing in certain cases; but we can handwave this and assume we're probabilistically likely to
144 // exercise the validation code over the span of the entire test set and operand space.
145 //
146 // Aborts if the specified operand type is an extension type or OEM type.
becomeConstantCopy(Model * model,Operand * operand)147 static void becomeConstantCopy(Model* model, Operand* operand) {
148     // sizeOfData will abort if the specified type is an extension type or OEM type.
149     const size_t sizeOfOperand = sizeOfData(*operand);
150     EXPECT_NE(sizeOfOperand, size_t(0));
151     operand->location.poolIndex = 0;
152     operand->location.offset = 0;
153     operand->location.length = sizeOfOperand;
154     if (model->operandValues.size() < sizeOfOperand) {
155         model->operandValues.resize(sizeOfOperand);
156     }
157 }
158 
159 // The sizeForBinder() functions estimate the size of the
160 // representation of a value when sent to binder.  It's probably a bit
161 // of an under-estimate, because we don't know the size of the
162 // metadata in the binder format (e.g., representation of the size of
163 // a vector); but at least it adds up "big" things like vector
164 // contents.  However, it doesn't treat inter-field or end-of-struct
165 // padding in a methodical way -- there's no attempt to be consistent
166 // in whether or not padding in the native (C++) representation
167 // contributes to the estimated size for the binder representation;
168 // and there's no attempt to understand what padding (if any) is
169 // needed in the binder representation.
170 //
171 // This assumes that non-metadata uses a fixed length encoding (e.g.,
172 // a uint32_t is always encoded in sizeof(uint32_t) bytes, rather than
173 // using an encoding whose length is related to the magnitude of the
174 // encoded value).
175 
// Fallback estimate: a trivially copyable value is assumed to be sent to binder
// as its raw in-memory representation.
template <typename Type>
static size_t sizeForBinder(const Type& value) {
    static_assert(std::is_trivially_copyable_v<std::remove_reference_t<Type>>,
                  "expected a trivially copyable type");
    return sizeof(value);
}
182 
183 template <typename Type>
sizeForBinder(const hidl_vec<Type> & vec)184 static size_t sizeForBinder(const hidl_vec<Type>& vec) {
185     return std::accumulate(vec.begin(), vec.end(), 0,
186                            [](size_t acc, const Type& x) { return acc + sizeForBinder(x); });
187 }
188 
189 template <>
sizeForBinder(const SymmPerChannelQuantParams & symmPerChannelQuantParams)190 size_t sizeForBinder(const SymmPerChannelQuantParams& symmPerChannelQuantParams) {
191     size_t size = 0;
192 
193     size += sizeForBinder(symmPerChannelQuantParams.scales);
194     size += sizeForBinder(symmPerChannelQuantParams.channelDim);
195 
196     return size;
197 }
198 
199 template <>
sizeForBinder(const Operand::ExtraParams & extraParams)200 size_t sizeForBinder(const Operand::ExtraParams& extraParams) {
201     using Discriminator = Operand::ExtraParams::hidl_discriminator;
202     switch (extraParams.getDiscriminator()) {
203         case Discriminator::none:
204             return 0;
205         case Discriminator::channelQuant:
206             return sizeForBinder(extraParams.channelQuant());
207         case Discriminator::extension:
208             return sizeForBinder(extraParams.extension());
209     }
210     LOG(FATAL) << "Unrecognized extraParams enum: "
211                << static_cast<int>(extraParams.getDiscriminator());
212     return 0;
213 }
214 
215 template <>
sizeForBinder(const Operand & operand)216 size_t sizeForBinder(const Operand& operand) {
217     size_t size = 0;
218 
219     size += sizeForBinder(operand.type);
220     size += sizeForBinder(operand.dimensions);
221     size += sizeForBinder(operand.numberOfConsumers);
222     size += sizeForBinder(operand.scale);
223     size += sizeForBinder(operand.zeroPoint);
224     size += sizeForBinder(operand.lifetime);
225     size += sizeForBinder(operand.location);
226     size += sizeForBinder(operand.extraParams);
227 
228     return size;
229 }
230 
231 template <>
sizeForBinder(const Operation & operation)232 size_t sizeForBinder(const Operation& operation) {
233     size_t size = 0;
234 
235     size += sizeForBinder(operation.type);
236     size += sizeForBinder(operation.inputs);
237     size += sizeForBinder(operation.outputs);
238 
239     return size;
240 }
241 
242 template <>
sizeForBinder(const hidl_string & name)243 size_t sizeForBinder(const hidl_string& name) {
244     return name.size();
245 }
246 
247 template <>
sizeForBinder(const hidl_memory & memory)248 size_t sizeForBinder(const hidl_memory& memory) {
249     // This is just a guess.
250 
251     size_t size = 0;
252 
253     if (const native_handle_t* handle = memory.handle()) {
254         size += sizeof(*handle);
255         size += sizeof(handle->data[0] * (handle->numFds + handle->numInts));
256     }
257     size += sizeForBinder(memory.name());
258 
259     return size;
260 }
261 
262 template <>
sizeForBinder(const Model::ExtensionNameAndPrefix & extensionNameToPrefix)263 size_t sizeForBinder(const Model::ExtensionNameAndPrefix& extensionNameToPrefix) {
264     size_t size = 0;
265 
266     size += sizeForBinder(extensionNameToPrefix.name);
267     size += sizeForBinder(extensionNameToPrefix.prefix);
268 
269     return size;
270 }
271 
272 template <>
sizeForBinder(const Model & model)273 size_t sizeForBinder(const Model& model) {
274     size_t size = 0;
275 
276     size += sizeForBinder(model.operands);
277     size += sizeForBinder(model.operations);
278     size += sizeForBinder(model.inputIndexes);
279     size += sizeForBinder(model.outputIndexes);
280     size += sizeForBinder(model.operandValues);
281     size += sizeForBinder(model.pools);
282     size += sizeForBinder(model.relaxComputationFloat32toFloat16);
283     size += sizeForBinder(model.extensionNameToPrefix);
284 
285     return size;
286 }
287 
288 // https://developer.android.com/reference/android/os/TransactionTooLargeException.html
289 //
290 //     "The Binder transaction buffer has a limited fixed size,
291 //     currently 1Mb, which is shared by all transactions in progress
292 //     for the process."
293 //
294 // Will our representation fit under this limit?  There are three complications:
295 // - Our representation size is just approximate (see sizeForBinder()).
296 // - This object may not be the only occupant of the Binder transaction buffer
297 //   (although our VTS test suite should not be putting multiple objects in the
298 //   buffer at once).
299 // - IBinder.MAX_IPC_SIZE recommends limiting a transaction to 64 * 1024 bytes.
300 // So we'll be very conservative: We want the representation size to be no
301 // larger than half the recommended limit.
302 //
303 // If our representation grows large enough that it still fits within
304 // the transaction buffer but combined with other transactions may
305 // exceed the buffer size, then we may see intermittent HAL transport
306 // errors.
// Returns true when the estimated representation size exceeds half of the
// recommended binder transaction limit (IBinder.MAX_IPC_SIZE = 64 KiB).
static bool exceedsBinderSizeLimit(size_t representationSize) {
    // There is no C++ API to retrieve the value of the Java variable
    // IBinder.MAX_IPC_SIZE, so the limit is hard-coded here.
    constexpr size_t kHalfMaxIPCSize = 64 * 1024 / 2;
    return representationSize > kHalfMaxIPCSize;
}
313 
314 ///////////////////////// VALIDATE EXECUTION ORDER ////////////////////////////
315 
mutateExecutionOrderTest(const sp<IDevice> & device,const Model & model)316 static void mutateExecutionOrderTest(const sp<IDevice>& device, const Model& model) {
317     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
318         const Operation& operationObj = model.operations[operation];
319         for (uint32_t input : operationObj.inputs) {
320             if (model.operands[input].lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
321                 model.operands[input].lifetime == OperandLifeTime::MODEL_OUTPUT) {
322                 // This operation reads an operand written by some
323                 // other operation.  Move this operation to the
324                 // beginning of the sequence, ensuring that it reads
325                 // the operand before that operand is written, thereby
326                 // violating execution order rules.
327                 const std::string message = "mutateExecutionOrderTest: operation " +
328                                             std::to_string(operation) + " is a reader";
329                 validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
330                     auto& operations = model->operations;
331                     std::rotate(operations.begin(), operations.begin() + operation,
332                                 operations.begin() + operation + 1);
333                 });
334                 break;  // only need to do this once per operation
335             }
336         }
337         for (uint32_t output : operationObj.outputs) {
338             if (model.operands[output].numberOfConsumers > 0) {
339                 // This operation writes an operand read by some other
340                 // operation.  Move this operation to the end of the
341                 // sequence, ensuring that it writes the operand after
342                 // that operand is read, thereby violating execution
343                 // order rules.
344                 const std::string message = "mutateExecutionOrderTest: operation " +
345                                             std::to_string(operation) + " is a writer";
346                 validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
347                     auto& operations = model->operations;
348                     std::rotate(operations.begin() + operation, operations.begin() + operation + 1,
349                                 operations.end());
350                 });
351                 break;  // only need to do this once per operation
352             }
353         }
354     }
355 }
356 
357 ///////////////////////// VALIDATE MODEL OPERAND TYPE /////////////////////////
358 
359 static const uint32_t invalidOperandTypes[] = {
360         static_cast<uint32_t>(OperandTypeRange::FUNDAMENTAL_MIN) - 1,
361         static_cast<uint32_t>(OperandTypeRange::FUNDAMENTAL_MAX) + 1,
362         static_cast<uint32_t>(OperandTypeRange::OEM_MIN) - 1,
363         static_cast<uint32_t>(OperandTypeRange::OEM_MAX) + 1,
364 };
365 
mutateOperandTypeTest(const sp<IDevice> & device,const Model & model)366 static void mutateOperandTypeTest(const sp<IDevice>& device, const Model& model) {
367     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
368         for (uint32_t invalidOperandType : invalidOperandTypes) {
369             const std::string message = "mutateOperandTypeTest: operand " +
370                                         std::to_string(operand) + " set to value " +
371                                         std::to_string(invalidOperandType);
372             validate(device, message, model,
373                      [operand, invalidOperandType](Model* model, ExecutionPreference*) {
374                          model->operands[operand].type =
375                                  static_cast<OperandType>(invalidOperandType);
376                      });
377         }
378     }
379 }
380 
381 ///////////////////////// VALIDATE OPERAND RANK /////////////////////////
382 
getInvalidRank(OperandType type)383 static uint32_t getInvalidRank(OperandType type) {
384     switch (type) {
385         case OperandType::FLOAT16:
386         case OperandType::FLOAT32:
387         case OperandType::INT32:
388         case OperandType::UINT32:
389         case OperandType::BOOL:
390             return 1;
391         case OperandType::TENSOR_BOOL8:
392         case OperandType::TENSOR_FLOAT16:
393         case OperandType::TENSOR_FLOAT32:
394         case OperandType::TENSOR_INT32:
395         case OperandType::TENSOR_QUANT8_ASYMM:
396         case OperandType::TENSOR_QUANT8_SYMM:
397         case OperandType::TENSOR_QUANT16_ASYMM:
398         case OperandType::TENSOR_QUANT16_SYMM:
399         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
400             return 0;
401         default:
402             return 0;
403     }
404 }
405 
mutateOperandRankTest(const sp<IDevice> & device,const Model & model)406 static void mutateOperandRankTest(const sp<IDevice>& device, const Model& model) {
407     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
408         const uint32_t invalidRank = getInvalidRank(model.operands[operand].type);
409         if (invalidRank == 0) {
410             continue;
411         }
412         const std::string message = "mutateOperandRankTest: operand " + std::to_string(operand) +
413                                     " has rank of " + std::to_string(invalidRank);
414         validate(device, message, model,
415                  [operand, invalidRank](Model* model, ExecutionPreference*) {
416                      model->operands[operand].dimensions = std::vector<uint32_t>(invalidRank, 0);
417                  });
418     }
419 }
420 
421 ///////////////////////// VALIDATE OPERAND SCALE /////////////////////////
422 
getInvalidScale(OperandType type)423 static float getInvalidScale(OperandType type) {
424     switch (type) {
425         case OperandType::FLOAT16:
426         case OperandType::FLOAT32:
427         case OperandType::INT32:
428         case OperandType::UINT32:
429         case OperandType::BOOL:
430         case OperandType::TENSOR_BOOL8:
431         case OperandType::TENSOR_FLOAT16:
432         case OperandType::TENSOR_FLOAT32:
433         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
434             return 1.0f;
435         case OperandType::TENSOR_INT32:
436             return -1.0f;
437         case OperandType::TENSOR_QUANT8_SYMM:
438         case OperandType::TENSOR_QUANT8_ASYMM:
439         case OperandType::TENSOR_QUANT16_ASYMM:
440         case OperandType::TENSOR_QUANT16_SYMM:
441             return 0.0f;
442         default:
443             return 0.0f;
444     }
445 }
446 
mutateOperandScaleTest(const sp<IDevice> & device,const Model & model)447 static void mutateOperandScaleTest(const sp<IDevice>& device, const Model& model) {
448     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
449         const float invalidScale = getInvalidScale(model.operands[operand].type);
450         const std::string message = "mutateOperandScaleTest: operand " + std::to_string(operand) +
451                                     " has scale of " + std::to_string(invalidScale);
452         validate(device, message, model,
453                  [operand, invalidScale](Model* model, ExecutionPreference*) {
454                      model->operands[operand].scale = invalidScale;
455                  });
456     }
457 }
458 
459 ///////////////////////// VALIDATE OPERAND ZERO POINT /////////////////////////
460 
getInvalidZeroPoints(OperandType type)461 static std::vector<int32_t> getInvalidZeroPoints(OperandType type) {
462     switch (type) {
463         case OperandType::FLOAT16:
464         case OperandType::FLOAT32:
465         case OperandType::INT32:
466         case OperandType::UINT32:
467         case OperandType::BOOL:
468         case OperandType::TENSOR_BOOL8:
469         case OperandType::TENSOR_FLOAT16:
470         case OperandType::TENSOR_FLOAT32:
471         case OperandType::TENSOR_INT32:
472         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
473             return {1};
474         case OperandType::TENSOR_QUANT8_ASYMM:
475             return {-1, 256};
476         case OperandType::TENSOR_QUANT8_SYMM:
477             return {-129, -1, 1, 128};
478         case OperandType::TENSOR_QUANT16_ASYMM:
479             return {-1, 65536};
480         case OperandType::TENSOR_QUANT16_SYMM:
481             return {-32769, -1, 1, 32768};
482         default:
483             return {};
484     }
485 }
486 
mutateOperandZeroPointTest(const sp<IDevice> & device,const Model & model)487 static void mutateOperandZeroPointTest(const sp<IDevice>& device, const Model& model) {
488     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
489         const std::vector<int32_t> invalidZeroPoints =
490                 getInvalidZeroPoints(model.operands[operand].type);
491         for (int32_t invalidZeroPoint : invalidZeroPoints) {
492             const std::string message = "mutateOperandZeroPointTest: operand " +
493                                         std::to_string(operand) + " has zero point of " +
494                                         std::to_string(invalidZeroPoint);
495             validate(device, message, model,
496                      [operand, invalidZeroPoint](Model* model, ExecutionPreference*) {
497                          model->operands[operand].zeroPoint = invalidZeroPoint;
498                      });
499         }
500     }
501 }
502 
503 ///////////////////////// VALIDATE OPERAND LIFETIME /////////////////////////////////////////////
504 
getInvalidLifeTimes(const Model & model,size_t modelSize,const Operand & operand)505 static std::vector<OperandLifeTime> getInvalidLifeTimes(const Model& model, size_t modelSize,
506                                                         const Operand& operand) {
507     // TODO: Support OperandLifeTime::CONSTANT_REFERENCE as an invalid lifetime
508     // TODO: Support OperandLifeTime::NO_VALUE as an invalid lifetime
509 
510     // Ways to get an invalid lifetime:
511     // - change whether a lifetime means an operand should have a writer
512     std::vector<OperandLifeTime> ret;
513     switch (operand.lifetime) {
514         case OperandLifeTime::MODEL_OUTPUT:
515         case OperandLifeTime::TEMPORARY_VARIABLE:
516             ret = {
517                     OperandLifeTime::MODEL_INPUT,
518                     OperandLifeTime::CONSTANT_COPY,
519             };
520             break;
521         case OperandLifeTime::CONSTANT_COPY:
522         case OperandLifeTime::CONSTANT_REFERENCE:
523         case OperandLifeTime::MODEL_INPUT:
524             ret = {
525                     OperandLifeTime::TEMPORARY_VARIABLE,
526                     OperandLifeTime::MODEL_OUTPUT,
527             };
528             break;
529         case OperandLifeTime::NO_VALUE:
530             // Not enough information to know whether
531             // TEMPORARY_VARIABLE or CONSTANT_COPY would be invalid --
532             // is this operand written (then CONSTANT_COPY would be
533             // invalid) or not (then TEMPORARY_VARIABLE would be
534             // invalid)?
535             break;
536         default:
537             ADD_FAILURE();
538             break;
539     }
540 
541     const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
542     if (!operandSize ||
543         exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
544         // Unknown size or too-large size
545         ret.erase(std::remove(ret.begin(), ret.end(), OperandLifeTime::CONSTANT_COPY), ret.end());
546     }
547 
548     return ret;
549 }
550 
mutateOperandLifeTimeTest(const sp<IDevice> & device,const Model & model)551 static void mutateOperandLifeTimeTest(const sp<IDevice>& device, const Model& model) {
552     const size_t modelSize = sizeForBinder(model);
553     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
554         const std::vector<OperandLifeTime> invalidLifeTimes =
555                 getInvalidLifeTimes(model, modelSize, model.operands[operand]);
556         for (OperandLifeTime invalidLifeTime : invalidLifeTimes) {
557             const std::string message = "mutateOperandLifetimeTest: operand " +
558                                         std::to_string(operand) + " has lifetime " +
559                                         toString(invalidLifeTime) + " instead of lifetime " +
560                                         toString(model.operands[operand].lifetime);
561             validate(device, message, model,
562                      [operand, invalidLifeTime](Model* model, ExecutionPreference*) {
563                          static const DataLocation kZeroDataLocation = {};
564                          Operand& operandObj = model->operands[operand];
565                          switch (operandObj.lifetime) {
566                              case OperandLifeTime::MODEL_INPUT: {
567                                  hidl_vec_remove(&model->inputIndexes, uint32_t(operand));
568                                  break;
569                              }
570                              case OperandLifeTime::MODEL_OUTPUT: {
571                                  hidl_vec_remove(&model->outputIndexes, uint32_t(operand));
572                                  break;
573                              }
574                              default:
575                                  break;
576                          }
577                          operandObj.lifetime = invalidLifeTime;
578                          operandObj.location = kZeroDataLocation;
579                          switch (invalidLifeTime) {
580                              case OperandLifeTime::CONSTANT_COPY: {
581                                  becomeConstantCopy(model, &operandObj);
582                                  break;
583                              }
584                              case OperandLifeTime::MODEL_INPUT:
585                                  hidl_vec_push_back(&model->inputIndexes, uint32_t(operand));
586                                  break;
587                              case OperandLifeTime::MODEL_OUTPUT:
588                                  hidl_vec_push_back(&model->outputIndexes, uint32_t(operand));
589                                  break;
590                              default:
591                                  break;
592                          }
593                      });
594         }
595     }
596 }
597 
598 ///////////////////////// VALIDATE OPERAND INPUT-or-OUTPUT //////////////////////////////////////
599 
getInputOutputLifeTime(const Model & model,size_t modelSize,const Operand & operand)600 static std::optional<OperandLifeTime> getInputOutputLifeTime(const Model& model, size_t modelSize,
601                                                              const Operand& operand) {
602     // Ways to get an invalid lifetime (with respect to model inputIndexes and outputIndexes):
603     // - change whether a lifetime means an operand is a model input, a model output, or neither
604     // - preserve whether or not a lifetime means an operand should have a writer
605     switch (operand.lifetime) {
606         case OperandLifeTime::CONSTANT_COPY:
607         case OperandLifeTime::CONSTANT_REFERENCE:
608             return OperandLifeTime::MODEL_INPUT;
609         case OperandLifeTime::MODEL_INPUT: {
610             const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
611             if (!operandSize ||
612                 exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
613                 // Unknown size or too-large size
614                 break;
615             }
616             return OperandLifeTime::CONSTANT_COPY;
617         }
618         case OperandLifeTime::MODEL_OUTPUT:
619             return OperandLifeTime::TEMPORARY_VARIABLE;
620         case OperandLifeTime::TEMPORARY_VARIABLE:
621             return OperandLifeTime::MODEL_OUTPUT;
622         case OperandLifeTime::NO_VALUE:
623             // Not enough information to know whether
624             // TEMPORARY_VARIABLE or CONSTANT_COPY would be an
625             // appropriate choice -- is this operand written (then
626             // TEMPORARY_VARIABLE would be appropriate) or not (then
627             // CONSTANT_COPY would be appropriate)?
628             break;
629         default:
630             ADD_FAILURE();
631             break;
632     }
633 
634     return std::nullopt;
635 }
636 
mutateOperandInputOutputTest(const sp<IDevice> & device,const Model & model)637 static void mutateOperandInputOutputTest(const sp<IDevice>& device, const Model& model) {
638     const size_t modelSize = sizeForBinder(model);
639     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
640         const std::optional<OperandLifeTime> changedLifeTime =
641                 getInputOutputLifeTime(model, modelSize, model.operands[operand]);
642         if (changedLifeTime) {
643             const std::string message = "mutateOperandInputOutputTest: operand " +
644                                         std::to_string(operand) + " has lifetime " +
645                                         toString(*changedLifeTime) + " instead of lifetime " +
646                                         toString(model.operands[operand].lifetime);
647             validate(device, message, model,
648                      [operand, changedLifeTime](Model* model, ExecutionPreference*) {
649                          static const DataLocation kZeroDataLocation = {};
650                          Operand& operandObj = model->operands[operand];
651                          operandObj.lifetime = *changedLifeTime;
652                          operandObj.location = kZeroDataLocation;
653                          if (*changedLifeTime == OperandLifeTime::CONSTANT_COPY) {
654                              becomeConstantCopy(model, &operandObj);
655                          }
656                      });
657         }
658     }
659 }
660 
661 ///////////////////////// VALIDATE OPERAND NUMBER OF CONSUMERS //////////////////////////////////
662 
// Returns consumer counts that disagree with the true count by one.  An
// operand with zero consumers cannot go negative, so only the +1 case exists.
static std::vector<uint32_t> getInvalidNumberOfConsumers(uint32_t numberOfConsumers) {
    return numberOfConsumers == 0
                   ? std::vector<uint32_t>{1}
                   : std::vector<uint32_t>{numberOfConsumers - 1, numberOfConsumers + 1};
}
670 
mutateOperandNumberOfConsumersTest(const sp<IDevice> & device,const Model & model)671 static void mutateOperandNumberOfConsumersTest(const sp<IDevice>& device, const Model& model) {
672     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
673         const std::vector<uint32_t> invalidNumberOfConsumersVec =
674                 getInvalidNumberOfConsumers(model.operands[operand].numberOfConsumers);
675         for (uint32_t invalidNumberOfConsumers : invalidNumberOfConsumersVec) {
676             const std::string message =
677                     "mutateOperandNumberOfConsumersTest: operand " + std::to_string(operand) +
678                     " numberOfConsumers = " + std::to_string(invalidNumberOfConsumers);
679             validate(device, message, model,
680                      [operand, invalidNumberOfConsumers](Model* model, ExecutionPreference*) {
681                          model->operands[operand].numberOfConsumers = invalidNumberOfConsumers;
682                      });
683         }
684     }
685 }
686 
687 ///////////////////////// VALIDATE OPERAND NUMBER OF WRITERS ////////////////////////////////////
688 
// For each output of each operation, submits a model in which a duplicated
// copy of that operation also writes to that one output operand.  An operand
// with two writers is expected to make the model invalid.
static void mutateOperandAddWriterTest(const sp<IDevice>& device, const Model& model) {
    for (size_t operation = 0; operation < model.operations.size(); ++operation) {
        for (size_t badOutputNum = 0; badOutputNum < model.operations[operation].outputs.size();
             ++badOutputNum) {
            const uint32_t outputOperandIndex = model.operations[operation].outputs[badOutputNum];
            const std::string message = "mutateOperandAddWriterTest: operation " +
                                        std::to_string(operation) + " writes to " +
                                        std::to_string(outputOperandIndex);
            // We'll insert a copy of the operation, all of whose
            // OTHER output operands are newly-created -- i.e.,
            // there'll only be a duplicate write of ONE of that
            // operation's output operands.
            validate(device, message, model,
                     [operation, badOutputNum](Model* model, ExecutionPreference*) {
                         Operation newOperation = model->operations[operation];
                         // The duplicate operation consumes the same inputs
                         // again, so bump each input's consumer count.
                         for (uint32_t input : newOperation.inputs) {
                             ++model->operands[input].numberOfConsumers;
                         }
                         for (size_t outputNum = 0; outputNum < newOperation.outputs.size();
                              ++outputNum) {
                             if (outputNum == badOutputNum) continue;

                             // Clone the operand description for a fresh output.
                             // A MODEL_OUTPUT clone is downgraded to
                             // TEMPORARY_VARIABLE so the model's interface does
                             // not change; anything else must already be a
                             // TEMPORARY_VARIABLE.
                             Operand operandValue =
                                     model->operands[newOperation.outputs[outputNum]];
                             operandValue.numberOfConsumers = 0;
                             if (operandValue.lifetime == OperandLifeTime::MODEL_OUTPUT) {
                                 operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
                             } else {
                                 ASSERT_EQ(operandValue.lifetime,
                                           OperandLifeTime::TEMPORARY_VARIABLE);
                             }
                             newOperation.outputs[outputNum] =
                                     hidl_vec_push_back(&model->operands, operandValue);
                         }
                         // Where do we insert the extra writer (a new
                         // operation)?  It has to be later than all the
                         // writers of its inputs.  The easiest thing to do
                         // is to insert it at the end of the operation
                         // sequence.
                         hidl_vec_push_back(&model->operations, newOperation);
                     });
        }
    }
}
733 
734 ///////////////////////// VALIDATE EXTRA ??? /////////////////////////
735 
736 // TODO: Operand::location
737 
738 ///////////////////////// VALIDATE OPERATION OPERAND TYPE /////////////////////////
739 
mutateOperand(Operand * operand,OperandType type)740 static void mutateOperand(Operand* operand, OperandType type) {
741     Operand newOperand = *operand;
742     newOperand.type = type;
743     switch (type) {
744         case OperandType::FLOAT16:
745         case OperandType::FLOAT32:
746         case OperandType::INT32:
747         case OperandType::UINT32:
748         case OperandType::BOOL:
749             newOperand.dimensions = hidl_vec<uint32_t>();
750             newOperand.scale = 0.0f;
751             newOperand.zeroPoint = 0;
752             break;
753         case OperandType::TENSOR_BOOL8:
754         case OperandType::TENSOR_FLOAT16:
755         case OperandType::TENSOR_FLOAT32:
756             newOperand.dimensions =
757                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
758             newOperand.scale = 0.0f;
759             newOperand.zeroPoint = 0;
760             break;
761         case OperandType::TENSOR_INT32:
762             newOperand.dimensions =
763                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
764             newOperand.zeroPoint = 0;
765             break;
766         case OperandType::TENSOR_QUANT8_ASYMM:
767         case OperandType::TENSOR_QUANT8_SYMM:
768         case OperandType::TENSOR_QUANT16_ASYMM:
769         case OperandType::TENSOR_QUANT16_SYMM:
770             newOperand.dimensions =
771                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
772             newOperand.scale = operand->scale != 0.0f ? operand->scale : 1.0f;
773             break;
774         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
775             newOperand.dimensions =
776                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
777             newOperand.scale = 0.0f;
778             newOperand.zeroPoint = 0;
779 
780             SymmPerChannelQuantParams channelQuant;
781             channelQuant.channelDim = 0;
782             channelQuant.scales = hidl_vec<float>(
783                     operand->dimensions.size() > 0 ? static_cast<size_t>(operand->dimensions[0])
784                                                    : 0);
785             for (size_t i = 0; i < channelQuant.scales.size(); ++i) {
786                 channelQuant.scales[i] = 1.0f;
787             }
788             newOperand.extraParams.channelQuant(std::move(channelQuant));
789         } break;
790         case OperandType::OEM:
791         case OperandType::TENSOR_OEM_BYTE:
792         default:
793             break;
794     }
795     *operand = newOperand;
796 }
797 
// Returns true if retyping operand "operand" to "type" should NOT be expected
// to invalidate the model -- either because the change is a no-op or targets
// an OEM type, or because some operation in the model legitimately accepts
// several types at that argument position.
static bool mutateOperationOperandTypeSkip(size_t operand, OperandType type, const Model& model) {
    // Do not test OEM types
    if (type == model.operands[operand].type || type == OperandType::OEM ||
        type == OperandType::TENSOR_OEM_BYTE) {
        return true;
    }
    for (const Operation& operation : model.operations) {
        // Skip mutateOperationOperandTypeTest for the following operations.
        // - LSH_PROJECTION's second argument is allowed to have any type.
        // - ARGMIN and ARGMAX's first argument can be any of
        // TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
        // - CAST's argument can be any of TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
        // - RANDOM_MULTINOMIAL's argument can be either TENSOR_FLOAT16 or TENSOR_FLOAT32.
        // - DEQUANTIZE input can be any of
        // TENSOR_(QUANT8_ASYMM|QUANT8_SYMM|QUANT8_SYMM_PER_CHANNEL), output can
        // be of either TENSOR_FLOAT16 or TENSOR_FLOAT32.
        // - QUANTIZE input can be either TENSOR_FLOAT16 or TENSOR_FLOAT32
        // - CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
        // - DEPTHWISE_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
        // - GROUPED_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
        // - TRANSPOSE_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
        switch (operation.type) {
            case OperationType::LSH_PROJECTION: {
                if (operand == operation.inputs[1]) {
                    return true;
                }
            } break;
            case OperationType::CAST:
            case OperationType::ARGMAX:
            case OperationType::ARGMIN: {
                // NOTE: skips if ANY argument position matches; the comment
                // above says only the first argument is flexible.
                if (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32 ||
                    type == OperandType::TENSOR_INT32 || type == OperandType::TENSOR_QUANT8_ASYMM) {
                    return true;
                }
            } break;
            case OperationType::QUANTIZE:
            case OperationType::RANDOM_MULTINOMIAL: {
                if (operand == operation.inputs[0] &&
                    (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
                    return true;
                }
            } break;
            case OperationType::DEQUANTIZE: {
                if (operand == operation.inputs[0] &&
                    (type == OperandType::TENSOR_QUANT8_ASYMM ||
                     type == OperandType::TENSOR_QUANT8_SYMM ||
                     type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)) {
                    return true;
                }
                if (operand == operation.outputs[0] &&
                    (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
                    return true;
                }
            } break;
            case OperationType::TRANSPOSE_CONV_2D:
            case OperationType::GROUPED_CONV_2D:
            case OperationType::DEPTHWISE_CONV_2D:
            case OperationType::CONV_2D: {
                if (operand == operation.inputs[1] &&
                    (type == OperandType::TENSOR_QUANT8_ASYMM ||
                     type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)) {
                    return true;
                }
            } break;
            default:
                break;
        }
    }
    return false;
}
868 
mutateOperationOperandTypeTest(const sp<IDevice> & device,const Model & model)869 static void mutateOperationOperandTypeTest(const sp<IDevice>& device, const Model& model) {
870     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
871         for (OperandType invalidOperandType : hidl_enum_range<OperandType>{}) {
872             if (mutateOperationOperandTypeSkip(operand, invalidOperandType, model)) {
873                 continue;
874             }
875             const std::string message = "mutateOperationOperandTypeTest: operand " +
876                                         std::to_string(operand) + " set to type " +
877                                         toString(invalidOperandType);
878             validate(device, message, model,
879                      [operand, invalidOperandType](Model* model, ExecutionPreference*) {
880                          mutateOperand(&model->operands[operand], invalidOperandType);
881                      });
882         }
883     }
884 }
885 
886 ///////////////////////// VALIDATE MODEL OPERATION TYPE /////////////////////////
887 
// OperationType values adjacent to -- but outside -- the declared valid
// ranges: one past the fundamental range, and the values bracketing the OEM
// range.
static const uint32_t invalidOperationTypes[] = {
        static_cast<uint32_t>(OperationTypeRange::FUNDAMENTAL_MAX) + 1,
        static_cast<uint32_t>(OperationTypeRange::OEM_MIN) - 1,
        static_cast<uint32_t>(OperationTypeRange::OEM_MAX) + 1,
};
893 
mutateOperationTypeTest(const sp<IDevice> & device,const Model & model)894 static void mutateOperationTypeTest(const sp<IDevice>& device, const Model& model) {
895     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
896         for (uint32_t invalidOperationType : invalidOperationTypes) {
897             const std::string message = "mutateOperationTypeTest: operation " +
898                                         std::to_string(operation) + " set to value " +
899                                         std::to_string(invalidOperationType);
900             validate(device, message, model,
901                      [operation, invalidOperationType](Model* model, ExecutionPreference*) {
902                          model->operations[operation].type =
903                                  static_cast<OperationType>(invalidOperationType);
904                      });
905         }
906     }
907 }
908 
909 ///////////////////////// VALIDATE MODEL OPERATION INPUT OPERAND INDEX /////////////////////////
910 
mutateOperationInputOperandIndexTest(const sp<IDevice> & device,const Model & model)911 static void mutateOperationInputOperandIndexTest(const sp<IDevice>& device, const Model& model) {
912     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
913         const uint32_t invalidOperand = model.operands.size();
914         for (size_t input = 0; input < model.operations[operation].inputs.size(); ++input) {
915             const std::string message = "mutateOperationInputOperandIndexTest: operation " +
916                                         std::to_string(operation) + " input " +
917                                         std::to_string(input);
918             validate(device, message, model,
919                      [operation, input, invalidOperand](Model* model, ExecutionPreference*) {
920                          model->operations[operation].inputs[input] = invalidOperand;
921                      });
922         }
923     }
924 }
925 
926 ///////////////////////// VALIDATE MODEL OPERATION OUTPUT OPERAND INDEX /////////////////////////
927 
mutateOperationOutputOperandIndexTest(const sp<IDevice> & device,const Model & model)928 static void mutateOperationOutputOperandIndexTest(const sp<IDevice>& device, const Model& model) {
929     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
930         const uint32_t invalidOperand = model.operands.size();
931         for (size_t output = 0; output < model.operations[operation].outputs.size(); ++output) {
932             const std::string message = "mutateOperationOutputOperandIndexTest: operation " +
933                                         std::to_string(operation) + " output " +
934                                         std::to_string(output);
935             validate(device, message, model,
936                      [operation, output, invalidOperand](Model* model, ExecutionPreference*) {
937                          model->operations[operation].outputs[output] = invalidOperand;
938                      });
939         }
940     }
941 }
942 
943 ///////////////////////// VALIDATE MODEL OPERANDS WRITTEN ///////////////////////////////////////
944 
// For each operation output that some other operation consumes, redirects the
// write to a brand-new operand.  The original operand is left with a consumer
// but no writer, which is expected to make the model invalid.
static void mutateOperationRemoveWriteTest(const sp<IDevice>& device, const Model& model) {
    for (size_t operation = 0; operation < model.operations.size(); ++operation) {
        for (size_t outputNum = 0; outputNum < model.operations[operation].outputs.size();
             ++outputNum) {
            const uint32_t outputOperandIndex = model.operations[operation].outputs[outputNum];
            // Only interesting if some operation reads this operand; otherwise
            // removing the write leaves no dangling consumer behind.
            if (model.operands[outputOperandIndex].numberOfConsumers > 0) {
                const std::string message = "mutateOperationRemoveWriteTest: operation " +
                                            std::to_string(operation) + " writes to " +
                                            std::to_string(outputOperandIndex);
                validate(device, message, model,
                         [operation, outputNum](Model* model, ExecutionPreference*) {
                             uint32_t& outputOperandIndex =
                                     model->operations[operation].outputs[outputNum];
                             // The replacement operand clones the original but
                             // has no consumers; a MODEL_OUTPUT clone is
                             // downgraded to TEMPORARY_VARIABLE, and anything
                             // else must already be a TEMPORARY_VARIABLE.
                             Operand operandValue = model->operands[outputOperandIndex];
                             operandValue.numberOfConsumers = 0;
                             if (operandValue.lifetime == OperandLifeTime::MODEL_OUTPUT) {
                                 operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
                             } else {
                                 ASSERT_EQ(operandValue.lifetime,
                                           OperandLifeTime::TEMPORARY_VARIABLE);
                             }
                             // Rebind the operation's output slot to the new operand.
                             outputOperandIndex =
                                     hidl_vec_push_back(&model->operands, operandValue);
                         });
            }
        }
    }
}
973 
974 ///////////////////////// REMOVE OPERAND FROM EVERYTHING /////////////////////////
975 
removeValueAndDecrementGreaterValues(hidl_vec<uint32_t> * vec,uint32_t value)976 static void removeValueAndDecrementGreaterValues(hidl_vec<uint32_t>* vec, uint32_t value) {
977     if (vec) {
978         // remove elements matching "value"
979         auto last = std::remove(vec->begin(), vec->end(), value);
980         vec->resize(std::distance(vec->begin(), last));
981 
982         // decrement elements exceeding "value"
983         std::transform(vec->begin(), vec->end(), vec->begin(),
984                        [value](uint32_t v) { return v > value ? v-- : v; });
985     }
986 }
987 
removeOperand(Model * model,uint32_t index)988 static void removeOperand(Model* model, uint32_t index) {
989     hidl_vec_removeAt(&model->operands, index);
990     for (Operation& operation : model->operations) {
991         removeValueAndDecrementGreaterValues(&operation.inputs, index);
992         removeValueAndDecrementGreaterValues(&operation.outputs, index);
993     }
994     removeValueAndDecrementGreaterValues(&model->inputIndexes, index);
995     removeValueAndDecrementGreaterValues(&model->outputIndexes, index);
996 }
997 
removeOperandSkip(size_t operand,const Model & model)998 static bool removeOperandSkip(size_t operand, const Model& model) {
999     for (const Operation& operation : model.operations) {
1000         // Skip removeOperandTest for the following operations.
1001         // - SPLIT's outputs are not checked during prepareModel.
1002         if (operation.type == OperationType::SPLIT) {
1003             for (const size_t outOprand : operation.outputs) {
1004                 if (operand == outOprand) {
1005                     return true;
1006                 }
1007             }
1008         }
1009         // BIDIRECTIONAL_SEQUENCE_LSTM and BIDIRECTIONAL_SEQUENCE_RNN can have either one or two
1010         // outputs depending on their mergeOutputs parameter.
1011         if (operation.type == OperationType::BIDIRECTIONAL_SEQUENCE_LSTM ||
1012             operation.type == OperationType::BIDIRECTIONAL_SEQUENCE_RNN) {
1013             for (const size_t outOprand : operation.outputs) {
1014                 if (operand == outOprand) {
1015                     return true;
1016                 }
1017             }
1018         }
1019     }
1020     return false;
1021 }
1022 
removeOperandTest(const sp<IDevice> & device,const Model & model)1023 static void removeOperandTest(const sp<IDevice>& device, const Model& model) {
1024     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
1025         if (removeOperandSkip(operand, model)) {
1026             continue;
1027         }
1028         const std::string message = "removeOperandTest: operand " + std::to_string(operand);
1029         validate(device, message, model,
1030                  [operand](Model* model, ExecutionPreference*) { removeOperand(model, operand); });
1031     }
1032 }
1033 
1034 ///////////////////////// REMOVE OPERATION /////////////////////////
1035 
removeOperation(Model * model,uint32_t index)1036 static void removeOperation(Model* model, uint32_t index) {
1037     for (uint32_t operand : model->operations[index].inputs) {
1038         model->operands[operand].numberOfConsumers--;
1039     }
1040     hidl_vec_removeAt(&model->operations, index);
1041 }
1042 
removeOperationTest(const sp<IDevice> & device,const Model & model)1043 static void removeOperationTest(const sp<IDevice>& device, const Model& model) {
1044     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
1045         const std::string message = "removeOperationTest: operation " + std::to_string(operation);
1046         validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
1047             removeOperation(model, operation);
1048         });
1049     }
1050 }
1051 
1052 ///////////////////////// REMOVE OPERATION INPUT /////////////////////////
1053 
// Returns true if removing input "input" from operation "op" should NOT be
// expected to invalidate the model, because that input is an optional trailing
// argument.
//
// Skip removeOperationInputTest for the following operations.
// - CONCATENATION has at least 2 inputs, with the last element being INT32.
// - CONV_2D, DEPTHWISE_CONV_2D, MAX_POOL_2D, AVERAGE_POOL_2D, L2_POOL_2D, RESIZE_BILINEAR,
//   SPACE_TO_DEPTH, DEPTH_TO_SPACE, SPACE_TO_BATCH_ND, BATCH_TO_SPACE_ND can have an optional
//   layout parameter.
// - L2_NORMALIZATION, LOCAL_RESPONSE_NORMALIZATION, SOFTMAX can have an optional axis
//   parameter.
static bool removeOperationInputSkip(const Operation& op, size_t input) {
    switch (op.type) {
        case OperationType::CONCATENATION: {
            // Any tensor input except the trailing INT32 axis may be dropped.
            if (op.inputs.size() > 2 && input != op.inputs.size() - 1) {
                return true;
            }
        } break;
        case OperationType::DEPTHWISE_CONV_2D: {
            // Trailing layout argument in both explicit- and implicit-padding forms.
            if ((op.inputs.size() == 12 && input == 11) || (op.inputs.size() == 9 && input == 8)) {
                return true;
            }
        } break;
        case OperationType::CONV_2D:
        case OperationType::AVERAGE_POOL_2D:
        case OperationType::MAX_POOL_2D:
        case OperationType::L2_POOL_2D: {
            // Trailing layout argument in both explicit- and implicit-padding forms.
            if ((op.inputs.size() == 11 && input == 10) || (op.inputs.size() == 8 && input == 7)) {
                return true;
            }
        } break;
        case OperationType::RESIZE_BILINEAR: {
            if (op.inputs.size() == 4 && input == 3) {
                return true;
            }
        } break;
        case OperationType::SPACE_TO_DEPTH:
        case OperationType::DEPTH_TO_SPACE:
        case OperationType::BATCH_TO_SPACE_ND: {
            if (op.inputs.size() == 3 && input == 2) {
                return true;
            }
        } break;
        case OperationType::SPACE_TO_BATCH_ND: {
            if (op.inputs.size() == 4 && input == 3) {
                return true;
            }
        } break;
        case OperationType::L2_NORMALIZATION: {
            // Optional trailing axis argument.
            if (op.inputs.size() == 2 && input == 1) {
                return true;
            }
        } break;
        case OperationType::LOCAL_RESPONSE_NORMALIZATION: {
            // Optional trailing axis argument.
            if (op.inputs.size() == 6 && input == 5) {
                return true;
            }
        } break;
        case OperationType::SOFTMAX: {
            // Optional trailing axis argument.
            if (op.inputs.size() == 3 && input == 2) {
                return true;
            }
        } break;
        default:
            break;
    }
    return false;
}
1118 
removeOperationInputTest(const sp<IDevice> & device,const Model & model)1119 static void removeOperationInputTest(const sp<IDevice>& device, const Model& model) {
1120     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
1121         for (size_t input = 0; input < model.operations[operation].inputs.size(); ++input) {
1122             const Operation& op = model.operations[operation];
1123             if (removeOperationInputSkip(op, input)) {
1124                 continue;
1125             }
1126             const std::string message = "removeOperationInputTest: operation " +
1127                                         std::to_string(operation) + ", input " +
1128                                         std::to_string(input);
1129             validate(device, message, model,
1130                      [operation, input](Model* model, ExecutionPreference*) {
1131                          uint32_t operand = model->operations[operation].inputs[input];
1132                          model->operands[operand].numberOfConsumers--;
1133                          hidl_vec_removeAt(&model->operations[operation].inputs, input);
1134                      });
1135         }
1136     }
1137 }
1138 
1139 ///////////////////////// REMOVE OPERATION OUTPUT /////////////////////////
1140 
removeOperationOutputTest(const sp<IDevice> & device,const Model & model)1141 static void removeOperationOutputTest(const sp<IDevice>& device, const Model& model) {
1142     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
1143         for (size_t output = 0; output < model.operations[operation].outputs.size(); ++output) {
1144             const std::string message = "removeOperationOutputTest: operation " +
1145                                         std::to_string(operation) + ", output " +
1146                                         std::to_string(output);
1147             validate(device, message, model,
1148                      [operation, output](Model* model, ExecutionPreference*) {
1149                          hidl_vec_removeAt(&model->operations[operation].outputs, output);
1150                      });
1151         }
1152     }
1153 }
1154 
1155 ///////////////////////// MODEL VALIDATION /////////////////////////
1156 
1157 // TODO: remove model input
1158 // TODO: remove model output
1159 // TODO: add unused operation
1160 
1161 ///////////////////////// ADD OPERATION INPUT /////////////////////////
1162 
addOperationInputSkip(const Operation & op)1163 static bool addOperationInputSkip(const Operation& op) {
1164     // Skip addOperationInputTest for the following operations.
1165     // - L2_NORMALIZATION, LOCAL_RESPONSE_NORMALIZATION, SOFTMAX can have an optional INT32 axis
1166     //   parameter.
1167     if ((op.type == OperationType::L2_NORMALIZATION && op.inputs.size() == 1) ||
1168         (op.type == OperationType::LOCAL_RESPONSE_NORMALIZATION && op.inputs.size() == 5) ||
1169         (op.type == OperationType::SOFTMAX && op.inputs.size() == 2)) {
1170         return true;
1171     }
1172     return false;
1173 }
1174 
addOperationInputTest(const sp<IDevice> & device,const Model & model)1175 static void addOperationInputTest(const sp<IDevice>& device, const Model& model) {
1176     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
1177         if (addOperationInputSkip(model.operations[operation])) {
1178             continue;
1179         }
1180         const std::string message = "addOperationInputTest: operation " + std::to_string(operation);
1181         validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
1182             uint32_t index = addOperand(model, OperandLifeTime::MODEL_INPUT);
1183             hidl_vec_push_back(&model->operations[operation].inputs, index);
1184             hidl_vec_push_back(&model->inputIndexes, index);
1185         });
1186     }
1187 }
1188 
1189 ///////////////////////// ADD OPERATION OUTPUT /////////////////////////
1190 
addOperationOutputTest(const sp<IDevice> & device,const Model & model)1191 static void addOperationOutputTest(const sp<IDevice>& device, const Model& model) {
1192     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
1193         const std::string message =
1194                 "addOperationOutputTest: operation " + std::to_string(operation);
1195         validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
1196             uint32_t index = addOperand(model, OperandLifeTime::MODEL_OUTPUT);
1197             hidl_vec_push_back(&model->operations[operation].outputs, index);
1198             hidl_vec_push_back(&model->outputIndexes, index);
1199         });
1200     }
1201 }
1202 
1203 ///////////////////////// VALIDATE EXECUTION PREFERENCE /////////////////////////
1204 
// ExecutionPreference values just outside the declared enum range.
static const int32_t invalidExecutionPreferences[] = {
        static_cast<int32_t>(ExecutionPreference::LOW_POWER) - 1,        // lower bound
        static_cast<int32_t>(ExecutionPreference::SUSTAINED_SPEED) + 1,  // upper bound
};
1209 
mutateExecutionPreferenceTest(const sp<IDevice> & device,const Model & model)1210 static void mutateExecutionPreferenceTest(const sp<IDevice>& device, const Model& model) {
1211     for (int32_t invalidPreference : invalidExecutionPreferences) {
1212         const std::string message =
1213                 "mutateExecutionPreferenceTest: preference " + std::to_string(invalidPreference);
1214         validate(device, message, model,
1215                  [invalidPreference](Model*, ExecutionPreference* preference) {
1216                      *preference = static_cast<ExecutionPreference>(invalidPreference);
1217                  });
1218     }
1219 }
1220 
1221 ////////////////////////// ENTRY POINT //////////////////////////////
1222 
// Entry point for model validation: runs every mutation helper in this file
// (the first six are defined earlier in the file, before this chunk) against
// the given device and model.  Each helper corrupts one aspect of the model
// (or the execution preference) and checks the service's response via
// validate().
void validateModel(const sp<IDevice>& device, const Model& model) {
    mutateExecutionOrderTest(device, model);
    mutateOperandTypeTest(device, model);
    mutateOperandRankTest(device, model);
    mutateOperandScaleTest(device, model);
    mutateOperandZeroPointTest(device, model);
    mutateOperandLifeTimeTest(device, model);
    mutateOperandInputOutputTest(device, model);
    mutateOperandNumberOfConsumersTest(device, model);
    mutateOperandAddWriterTest(device, model);
    mutateOperationOperandTypeTest(device, model);
    mutateOperationTypeTest(device, model);
    mutateOperationInputOperandIndexTest(device, model);
    mutateOperationOutputOperandIndexTest(device, model);
    mutateOperationRemoveWriteTest(device, model);
    removeOperandTest(device, model);
    removeOperationTest(device, model);
    removeOperationInputTest(device, model);
    removeOperationOutputTest(device, model);
    addOperationInputTest(device, model);
    addOperationOutputTest(device, model);
    mutateExecutionPreferenceTest(device, model);
}
1246 
1247 }  // namespace android::hardware::neuralnetworks::V1_2::vts::functional
1248