xref: /aosp_15_r20/hardware/interfaces/neuralnetworks/1.3/vts/functional/ValidateModel.cpp (revision 4d7e907c777eeecc4c5bd7cf640a754fac206ff7)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18 
19 #include <android/hardware/neuralnetworks/1.1/types.h>
20 #include <android/hardware/neuralnetworks/1.3/types.h>
21 #include "1.0/Utils.h"
22 #include "1.3/Callbacks.h"
23 #include "1.3/Utils.h"
24 #include "GeneratedTestHarness.h"
25 #include "VtsHalNeuralnetworks.h"
26 
27 #include <optional>
28 #include <type_traits>
29 #include <utility>
30 
31 namespace android::hardware::neuralnetworks::V1_3::vts::functional {
32 
33 using implementation::PreparedModelCallback;
34 using V1_0::DataLocation;
35 using V1_1::ExecutionPreference;
36 using V1_2::SymmPerChannelQuantParams;
37 using HidlToken =
38         hidl_array<uint8_t, static_cast<uint32_t>(V1_2::Constant::BYTE_SIZE_OF_CACHE_TOKEN)>;
39 
40 using PrepareModelMutation = std::function<void(Model*, ExecutionPreference*, Priority*)>;
41 
42 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
43 
validateGetSupportedOperations(const sp<IDevice> & device,const std::string & message,const Model & model)44 static void validateGetSupportedOperations(const sp<IDevice>& device, const std::string& message,
45                                            const Model& model) {
46     SCOPED_TRACE(message + " [getSupportedOperations_1_3]");
47 
48     Return<void> ret = device->getSupportedOperations_1_3(
49             model, [&](ErrorStatus status, const hidl_vec<bool>&) {
50                 EXPECT_EQ(ErrorStatus::INVALID_ARGUMENT, status);
51             });
52     EXPECT_TRUE(ret.isOk());
53 }
54 
validatePrepareModel(const sp<IDevice> & device,const std::string & message,const Model & model,ExecutionPreference preference,Priority priority)55 static void validatePrepareModel(const sp<IDevice>& device, const std::string& message,
56                                  const Model& model, ExecutionPreference preference,
57                                  Priority priority) {
58     SCOPED_TRACE(message + " [prepareModel_1_3]");
59 
60     sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
61     Return<ErrorStatus> prepareLaunchStatus =
62             device->prepareModel_1_3(model, preference, priority, {}, hidl_vec<hidl_handle>(),
63                                      hidl_vec<hidl_handle>(), HidlToken(), preparedModelCallback);
64     ASSERT_TRUE(prepareLaunchStatus.isOk());
65     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, static_cast<ErrorStatus>(prepareLaunchStatus));
66 
67     preparedModelCallback->wait();
68     ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
69     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, prepareReturnStatus);
70     sp<IPreparedModel> preparedModel = getPreparedModel_1_3(preparedModelCallback);
71     ASSERT_EQ(nullptr, preparedModel.get());
72 }
73 
validExecutionPreference(ExecutionPreference preference)74 static bool validExecutionPreference(ExecutionPreference preference) {
75     return preference == ExecutionPreference::LOW_POWER ||
76            preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
77            preference == ExecutionPreference::SUSTAINED_SPEED;
78 }
79 
validExecutionPriority(Priority priority)80 static bool validExecutionPriority(Priority priority) {
81     return priority == Priority::LOW || priority == Priority::MEDIUM || priority == Priority::HIGH;
82 }
83 
84 // Primary validation function. This function will take a valid model, apply a
85 // mutation to invalidate the model, the execution preference, or the priority,
86 // then pass these to supportedOperations and/or prepareModel if that method is
87 // called with an invalid argument.
validate(const sp<IDevice> & device,const std::string & message,const Model & originalModel,const PrepareModelMutation & mutate)88 static void validate(const sp<IDevice>& device, const std::string& message,
89                      const Model& originalModel, const PrepareModelMutation& mutate) {
90     Model model = originalModel;
91     ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER;
92     Priority priority = kDefaultPriority;
93     mutate(&model, &preference, &priority);
94 
95     if (validExecutionPreference(preference) && validExecutionPriority(priority)) {
96         validateGetSupportedOperations(device, message, model);
97     }
98 
99     validatePrepareModel(device, message, model, preference, priority);
100 }
101 
addOperand(Model * model)102 static uint32_t addOperand(Model* model) {
103     return hidl_vec_push_back(&model->main.operands,
104                               {
105                                       .type = OperandType::INT32,
106                                       .dimensions = {},
107                                       .numberOfConsumers = 0,
108                                       .scale = 0.0f,
109                                       .zeroPoint = 0,
110                                       .lifetime = OperandLifeTime::SUBGRAPH_INPUT,
111                                       .location = {.poolIndex = 0, .offset = 0, .length = 0},
112                               });
113 }
114 
addOperand(Model * model,OperandLifeTime lifetime)115 static uint32_t addOperand(Model* model, OperandLifeTime lifetime) {
116     uint32_t index = addOperand(model);
117     model->main.operands[index].numberOfConsumers = 1;
118     model->main.operands[index].lifetime = lifetime;
119     return index;
120 }
121 
122 // If we introduce a CONSTANT_COPY for an operand of size operandSize,
123 // how much will this increase the size of the model?  This assumes
124 // that we can (re)use all of model.operandValues for the operand
125 // value.
constantCopyExtraSize(const Model & model,size_t operandSize)126 static size_t constantCopyExtraSize(const Model& model, size_t operandSize) {
127     const size_t operandValuesSize = model.operandValues.size();
128     return (operandValuesSize < operandSize) ? (operandSize - operandValuesSize) : 0;
129 }
130 
131 // Highly specialized utility routine for converting an operand to
132 // CONSTANT_COPY lifetime.
133 //
134 // Expects that:
135 // - operand has a known size
136 // - operand->lifetime has already been set to CONSTANT_COPY
137 // - operand->location has been zeroed out
138 //
139 // Does the following:
140 // - initializes operand->location to point to the beginning of model->operandValues
141 // - resizes model->operandValues (if necessary) to be large enough for the operand
142 //   value, padding it with zeroes on the end
143 //
144 // Potential problem:
145 // By changing the operand to CONSTANT_COPY lifetime, this function is effectively initializing the
146 // operand with unspecified (but deterministic) data. This means that the model may be invalidated
147 // in two ways: not only is the lifetime of CONSTANT_COPY invalid, but the operand's value in the
148 // graph may also be invalid (e.g., if the operand is used as an activation code and has an invalid
149 // value). For now, this should be fine because it just means we're not testing what we think we're
150 // testing in certain cases; but we can handwave this and assume we're probabilistically likely to
151 // exercise the validation code over the span of the entire test set and operand space.
152 //
153 // Aborts if the specified operand type is an extension type or OEM type.
becomeConstantCopy(Model * model,Operand * operand)154 static void becomeConstantCopy(Model* model, Operand* operand) {
155     // sizeOfData will abort if the specified type is an extension type or OEM type.
156     const size_t sizeOfOperand = sizeOfData(*operand);
157     EXPECT_NE(sizeOfOperand, size_t(0));
158     operand->location.poolIndex = 0;
159     operand->location.offset = 0;
160     operand->location.length = sizeOfOperand;
161     if (model->operandValues.size() < sizeOfOperand) {
162         model->operandValues.resize(sizeOfOperand);
163     }
164 }
165 
166 // The sizeForBinder() functions estimate the size of the
167 // representation of a value when sent to binder.  It's probably a bit
168 // of an under-estimate, because we don't know the size of the
169 // metadata in the binder format (e.g., representation of the size of
170 // a vector); but at least it adds up "big" things like vector
171 // contents.  However, it doesn't treat inter-field or end-of-struct
172 // padding in a methodical way -- there's no attempt to be consistent
173 // in whether or not padding in the native (C++) representation
174 // contributes to the estimated size for the binder representation;
175 // and there's no attempt to understand what padding (if any) is
176 // needed in the binder representation.
177 //
178 // This assumes that non-metadata uses a fixed length encoding (e.g.,
179 // a uint32_t is always encoded in sizeof(uint32_t) bytes, rather than
180 // using an encoding whose length is related to the magnitude of the
181 // encoded value).
182 
template <typename Type>
static size_t sizeForBinder(const Type& val) {
    // Fallback for plain values: only trivially copyable types are accepted, and
    // each is estimated as exactly its in-memory byte size.
    static_assert(std::is_trivially_copyable_v<std::remove_reference_t<Type>>,
                  "expected a trivially copyable type");
    const size_t rawBytes = sizeof(val);
    return rawBytes;
}
189 
190 template <typename Type>
sizeForBinder(const hidl_vec<Type> & vec)191 static size_t sizeForBinder(const hidl_vec<Type>& vec) {
192     return std::accumulate(vec.begin(), vec.end(), 0,
193                            [](size_t acc, const Type& x) { return acc + sizeForBinder(x); });
194 }
195 
196 template <>
sizeForBinder(const SymmPerChannelQuantParams & symmPerChannelQuantParams)197 size_t sizeForBinder(const SymmPerChannelQuantParams& symmPerChannelQuantParams) {
198     size_t size = 0;
199 
200     size += sizeForBinder(symmPerChannelQuantParams.scales);
201     size += sizeForBinder(symmPerChannelQuantParams.channelDim);
202 
203     return size;
204 }
205 
206 template <>
sizeForBinder(const V1_2::Operand::ExtraParams & extraParams)207 size_t sizeForBinder(const V1_2::Operand::ExtraParams& extraParams) {
208     using Discriminator = V1_2::Operand::ExtraParams::hidl_discriminator;
209     switch (extraParams.getDiscriminator()) {
210         case Discriminator::none:
211             return 0;
212         case Discriminator::channelQuant:
213             return sizeForBinder(extraParams.channelQuant());
214         case Discriminator::extension:
215             return sizeForBinder(extraParams.extension());
216     }
217     LOG(FATAL) << "Unrecognized extraParams enum: "
218                << static_cast<int>(extraParams.getDiscriminator());
219     return 0;
220 }
221 
222 template <>
sizeForBinder(const Operand & operand)223 size_t sizeForBinder(const Operand& operand) {
224     size_t size = 0;
225 
226     size += sizeForBinder(operand.type);
227     size += sizeForBinder(operand.dimensions);
228     size += sizeForBinder(operand.numberOfConsumers);
229     size += sizeForBinder(operand.scale);
230     size += sizeForBinder(operand.zeroPoint);
231     size += sizeForBinder(operand.lifetime);
232     size += sizeForBinder(operand.location);
233     size += sizeForBinder(operand.extraParams);
234 
235     return size;
236 }
237 
238 template <>
sizeForBinder(const Operation & operation)239 size_t sizeForBinder(const Operation& operation) {
240     size_t size = 0;
241 
242     size += sizeForBinder(operation.type);
243     size += sizeForBinder(operation.inputs);
244     size += sizeForBinder(operation.outputs);
245 
246     return size;
247 }
248 
249 template <>
sizeForBinder(const hidl_string & name)250 size_t sizeForBinder(const hidl_string& name) {
251     return name.size();
252 }
253 
254 template <>
sizeForBinder(const hidl_memory & memory)255 size_t sizeForBinder(const hidl_memory& memory) {
256     // This is just a guess.
257 
258     size_t size = 0;
259 
260     if (const native_handle_t* handle = memory.handle()) {
261         size += sizeof(*handle);
262         size += sizeof(handle->data[0] * (handle->numFds + handle->numInts));
263     }
264     size += sizeForBinder(memory.name());
265 
266     return size;
267 }
268 
269 template <>
sizeForBinder(const Subgraph & subgraph)270 size_t sizeForBinder(const Subgraph& subgraph) {
271     size_t size = 0;
272 
273     size += sizeForBinder(subgraph.operands);
274     size += sizeForBinder(subgraph.operations);
275     size += sizeForBinder(subgraph.inputIndexes);
276     size += sizeForBinder(subgraph.outputIndexes);
277 
278     return size;
279 }
280 
281 template <>
sizeForBinder(const V1_2::Model::ExtensionNameAndPrefix & extensionNameToPrefix)282 size_t sizeForBinder(const V1_2::Model::ExtensionNameAndPrefix& extensionNameToPrefix) {
283     size_t size = 0;
284 
285     size += sizeForBinder(extensionNameToPrefix.name);
286     size += sizeForBinder(extensionNameToPrefix.prefix);
287 
288     return size;
289 }
290 
291 template <>
sizeForBinder(const Model & model)292 size_t sizeForBinder(const Model& model) {
293     size_t size = 0;
294 
295     size += sizeForBinder(model.main);
296     size += sizeForBinder(model.referenced);
297     size += sizeForBinder(model.operandValues);
298     size += sizeForBinder(model.pools);
299     size += sizeForBinder(model.relaxComputationFloat32toFloat16);
300     size += sizeForBinder(model.extensionNameToPrefix);
301 
302     return size;
303 }
304 
// Returns true if a representation of the given (estimated) size is too large to
// be sent safely over binder.
//
// https://developer.android.com/reference/android/os/TransactionTooLargeException.html
//
//     "The Binder transaction buffer has a limited fixed size,
//     currently 1Mb, which is shared by all transactions in progress
//     for the process."
//
// Three things complicate whether our representation fits under that limit:
// - Our representation size is only approximate (see sizeForBinder()).
// - This object may not be the only occupant of the Binder transaction buffer
//   (although our VTS test suite should not be putting multiple objects in the
//   buffer at once).
// - IBinder.MAX_IPC_SIZE recommends limiting a transaction to 64 * 1024 bytes.
// So we are very conservative: the representation size must be no larger than
// half of the recommended limit.
//
// A representation that still fits within the transaction buffer on its own,
// but exceeds the buffer size when combined with other transactions, may cause
// intermittent HAL transport errors.
static bool exceedsBinderSizeLimit(size_t representationSize) {
    // There is no C++ API to retrieve the value of the Java variable IBinder.MAX_IPC_SIZE,
    // so the recommended limit is spelled out here.
    constexpr size_t kMaxIPCSize = 64 * 1024;
    constexpr size_t kHalfMaxIPCSize = kMaxIPCSize / 2;
    return kHalfMaxIPCSize < representationSize;
}
330 
331 ///////////////////////// VALIDATE EXECUTION ORDER ////////////////////////////
332 
mutateExecutionOrderTest(const sp<IDevice> & device,const Model & model)333 static void mutateExecutionOrderTest(const sp<IDevice>& device, const Model& model) {
334     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
335         const Operation& operationObj = model.main.operations[operation];
336         for (uint32_t input : operationObj.inputs) {
337             if (model.main.operands[input].lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
338                 model.main.operands[input].lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
339                 // This operation reads an operand written by some
340                 // other operation.  Move this operation to the
341                 // beginning of the sequence, ensuring that it reads
342                 // the operand before that operand is written, thereby
343                 // violating execution order rules.
344                 const std::string message = "mutateExecutionOrderTest: operation " +
345                                             std::to_string(operation) + " is a reader";
346                 validate(device, message, model,
347                          [operation](Model* model, ExecutionPreference*, Priority*) {
348                              auto& operations = model->main.operations;
349                              std::rotate(operations.begin(), operations.begin() + operation,
350                                          operations.begin() + operation + 1);
351                          });
352                 break;  // only need to do this once per operation
353             }
354         }
355         for (uint32_t output : operationObj.outputs) {
356             if (model.main.operands[output].numberOfConsumers > 0) {
357                 // This operation writes an operand read by some other
358                 // operation.  Move this operation to the end of the
359                 // sequence, ensuring that it writes the operand after
360                 // that operand is read, thereby violating execution
361                 // order rules.
362                 const std::string message = "mutateExecutionOrderTest: operation " +
363                                             std::to_string(operation) + " is a writer";
364                 validate(device, message, model,
365                          [operation](Model* model, ExecutionPreference*, Priority*) {
366                              auto& operations = model->main.operations;
367                              std::rotate(operations.begin() + operation,
368                                          operations.begin() + operation + 1, operations.end());
369                          });
370                 break;  // only need to do this once per operation
371             }
372         }
373     }
374 }
375 
376 ///////////////////////// VALIDATE MODEL OPERAND TYPE /////////////////////////
377 
378 static const uint32_t invalidOperandTypes[] = {
379         static_cast<uint32_t>(OperandTypeRange::FUNDAMENTAL_MIN) - 1,
380         static_cast<uint32_t>(OperandTypeRange::FUNDAMENTAL_MAX) + 1,
381         static_cast<uint32_t>(OperandTypeRange::OEM_MIN) - 1,
382         static_cast<uint32_t>(OperandTypeRange::OEM_MAX) + 1,
383 };
384 
mutateOperandTypeTest(const sp<IDevice> & device,const Model & model)385 static void mutateOperandTypeTest(const sp<IDevice>& device, const Model& model) {
386     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
387         for (uint32_t invalidOperandType : invalidOperandTypes) {
388             const std::string message = "mutateOperandTypeTest: operand " +
389                                         std::to_string(operand) + " set to value " +
390                                         std::to_string(invalidOperandType);
391             validate(device, message, model,
392                      [operand, invalidOperandType](Model* model, ExecutionPreference*, Priority*) {
393                          model->main.operands[operand].type =
394                                  static_cast<OperandType>(invalidOperandType);
395                      });
396         }
397     }
398 }
399 
400 ///////////////////////// VALIDATE OPERAND RANK /////////////////////////
401 
getInvalidRank(OperandType type)402 static uint32_t getInvalidRank(OperandType type) {
403     switch (type) {
404         case OperandType::FLOAT16:
405         case OperandType::FLOAT32:
406         case OperandType::INT32:
407         case OperandType::UINT32:
408         case OperandType::BOOL:
409             return 1;
410         case OperandType::TENSOR_BOOL8:
411         case OperandType::TENSOR_FLOAT16:
412         case OperandType::TENSOR_FLOAT32:
413         case OperandType::TENSOR_INT32:
414         case OperandType::TENSOR_QUANT8_ASYMM:
415         case OperandType::TENSOR_QUANT8_SYMM:
416         case OperandType::TENSOR_QUANT16_ASYMM:
417         case OperandType::TENSOR_QUANT16_SYMM:
418         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
419             return 0;
420         default:
421             return 0;
422     }
423 }
424 
mutateOperandRankTest(const sp<IDevice> & device,const Model & model)425 static void mutateOperandRankTest(const sp<IDevice>& device, const Model& model) {
426     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
427         const uint32_t invalidRank = getInvalidRank(model.main.operands[operand].type);
428         if (invalidRank == 0) {
429             continue;
430         }
431         const std::string message = "mutateOperandRankTest: operand " + std::to_string(operand) +
432                                     " has rank of " + std::to_string(invalidRank);
433         validate(device, message, model,
434                  [operand, invalidRank](Model* model, ExecutionPreference*, Priority*) {
435                      model->main.operands[operand].dimensions =
436                              std::vector<uint32_t>(invalidRank, 0);
437                  });
438     }
439 }
440 
441 ///////////////////////// VALIDATE OPERAND SCALE /////////////////////////
442 
getInvalidScale(OperandType type)443 static float getInvalidScale(OperandType type) {
444     switch (type) {
445         case OperandType::FLOAT16:
446         case OperandType::FLOAT32:
447         case OperandType::INT32:
448         case OperandType::UINT32:
449         case OperandType::BOOL:
450         case OperandType::TENSOR_BOOL8:
451         case OperandType::TENSOR_FLOAT16:
452         case OperandType::TENSOR_FLOAT32:
453         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
454         case OperandType::SUBGRAPH:
455             return 1.0f;
456         case OperandType::TENSOR_INT32:
457             return -1.0f;
458         case OperandType::TENSOR_QUANT8_SYMM:
459         case OperandType::TENSOR_QUANT8_ASYMM:
460         case OperandType::TENSOR_QUANT16_ASYMM:
461         case OperandType::TENSOR_QUANT16_SYMM:
462             return 0.0f;
463         default:
464             return 0.0f;
465     }
466 }
467 
mutateOperandScaleTest(const sp<IDevice> & device,const Model & model)468 static void mutateOperandScaleTest(const sp<IDevice>& device, const Model& model) {
469     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
470         const float invalidScale = getInvalidScale(model.main.operands[operand].type);
471         const std::string message = "mutateOperandScaleTest: operand " + std::to_string(operand) +
472                                     " has scale of " + std::to_string(invalidScale);
473         validate(device, message, model,
474                  [operand, invalidScale](Model* model, ExecutionPreference*, Priority*) {
475                      model->main.operands[operand].scale = invalidScale;
476                  });
477     }
478 }
479 
480 ///////////////////////// VALIDATE OPERAND ZERO POINT /////////////////////////
481 
getInvalidZeroPoints(OperandType type)482 static std::vector<int32_t> getInvalidZeroPoints(OperandType type) {
483     switch (type) {
484         case OperandType::FLOAT16:
485         case OperandType::FLOAT32:
486         case OperandType::INT32:
487         case OperandType::UINT32:
488         case OperandType::BOOL:
489         case OperandType::TENSOR_BOOL8:
490         case OperandType::TENSOR_FLOAT16:
491         case OperandType::TENSOR_FLOAT32:
492         case OperandType::TENSOR_INT32:
493         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
494         case OperandType::SUBGRAPH:
495             return {1};
496         case OperandType::TENSOR_QUANT8_ASYMM:
497             return {-1, 256};
498         case OperandType::TENSOR_QUANT8_SYMM:
499             return {-129, -1, 1, 128};
500         case OperandType::TENSOR_QUANT16_ASYMM:
501             return {-1, 65536};
502         case OperandType::TENSOR_QUANT16_SYMM:
503             return {-32769, -1, 1, 32768};
504         default:
505             return {};
506     }
507 }
508 
mutateOperandZeroPointTest(const sp<IDevice> & device,const Model & model)509 static void mutateOperandZeroPointTest(const sp<IDevice>& device, const Model& model) {
510     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
511         const std::vector<int32_t> invalidZeroPoints =
512                 getInvalidZeroPoints(model.main.operands[operand].type);
513         for (int32_t invalidZeroPoint : invalidZeroPoints) {
514             const std::string message = "mutateOperandZeroPointTest: operand " +
515                                         std::to_string(operand) + " has zero point of " +
516                                         std::to_string(invalidZeroPoint);
517             validate(device, message, model,
518                      [operand, invalidZeroPoint](Model* model, ExecutionPreference*, Priority*) {
519                          model->main.operands[operand].zeroPoint = invalidZeroPoint;
520                      });
521         }
522     }
523 }
524 
525 ///////////////////////// VALIDATE OPERAND LIFETIME /////////////////////////////////////////////
526 
getInvalidLifeTimes(const Model & model,size_t modelSize,const Operand & operand)527 static std::vector<OperandLifeTime> getInvalidLifeTimes(const Model& model, size_t modelSize,
528                                                         const Operand& operand) {
529     // TODO: Support OperandLifeTime::CONSTANT_REFERENCE as an invalid lifetime
530     // TODO: Support OperandLifeTime::NO_VALUE as an invalid lifetime
531 
532     // Ways to get an invalid lifetime:
533     // - change whether a lifetime means an operand should have a writer
534     std::vector<OperandLifeTime> ret;
535     switch (operand.lifetime) {
536         case OperandLifeTime::SUBGRAPH_OUTPUT:
537         case OperandLifeTime::TEMPORARY_VARIABLE:
538             ret = {
539                     OperandLifeTime::SUBGRAPH_INPUT,
540                     OperandLifeTime::CONSTANT_COPY,
541             };
542             break;
543         case OperandLifeTime::CONSTANT_COPY:
544         case OperandLifeTime::CONSTANT_REFERENCE:
545         case OperandLifeTime::SUBGRAPH_INPUT:
546             ret = {
547                     OperandLifeTime::TEMPORARY_VARIABLE,
548                     OperandLifeTime::SUBGRAPH_OUTPUT,
549             };
550             break;
551         case OperandLifeTime::NO_VALUE:
552             // Not enough information to know whether
553             // TEMPORARY_VARIABLE or CONSTANT_COPY would be invalid --
554             // is this operand written (then CONSTANT_COPY would be
555             // invalid) or not (then TEMPORARY_VARIABLE would be
556             // invalid)?
557             break;
558         case OperandLifeTime::SUBGRAPH:
559             break;
560         default:
561             ADD_FAILURE();
562             break;
563     }
564 
565     const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
566     if (!operandSize ||
567         exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
568         // Unknown size or too-large size
569         ret.erase(std::remove(ret.begin(), ret.end(), OperandLifeTime::CONSTANT_COPY), ret.end());
570     }
571 
572     return ret;
573 }
574 
mutateOperandLifeTimeTest(const sp<IDevice> & device,const Model & model)575 static void mutateOperandLifeTimeTest(const sp<IDevice>& device, const Model& model) {
576     const size_t modelSize = sizeForBinder(model);
577     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
578         const std::vector<OperandLifeTime> invalidLifeTimes =
579                 getInvalidLifeTimes(model, modelSize, model.main.operands[operand]);
580         for (OperandLifeTime invalidLifeTime : invalidLifeTimes) {
581             const std::string message = "mutateOperandLifetimeTest: operand " +
582                                         std::to_string(operand) + " has lifetime " +
583                                         toString(invalidLifeTime) + " instead of lifetime " +
584                                         toString(model.main.operands[operand].lifetime);
585             validate(device, message, model,
586                      [operand, invalidLifeTime](Model* model, ExecutionPreference*, Priority*) {
587                          static const DataLocation kZeroDataLocation = {};
588                          Operand& operandObj = model->main.operands[operand];
589                          switch (operandObj.lifetime) {
590                              case OperandLifeTime::SUBGRAPH_INPUT: {
591                                  hidl_vec_remove(&model->main.inputIndexes, uint32_t(operand));
592                                  break;
593                              }
594                              case OperandLifeTime::SUBGRAPH_OUTPUT: {
595                                  hidl_vec_remove(&model->main.outputIndexes, uint32_t(operand));
596                                  break;
597                              }
598                              default:
599                                  break;
600                          }
601                          operandObj.lifetime = invalidLifeTime;
602                          operandObj.location = kZeroDataLocation;
603                          switch (invalidLifeTime) {
604                              case OperandLifeTime::CONSTANT_COPY: {
605                                  becomeConstantCopy(model, &operandObj);
606                                  break;
607                              }
608                              case OperandLifeTime::SUBGRAPH_INPUT:
609                                  hidl_vec_push_back(&model->main.inputIndexes, uint32_t(operand));
610                                  break;
611                              case OperandLifeTime::SUBGRAPH_OUTPUT:
612                                  hidl_vec_push_back(&model->main.outputIndexes, uint32_t(operand));
613                                  break;
614                              default:
615                                  break;
616                          }
617                      });
618         }
619     }
620 }
621 
622 ///////////////////////// VALIDATE OPERAND INPUT-or-OUTPUT //////////////////////////////////////
623 
getInputOutputLifeTime(const Model & model,size_t modelSize,const Operand & operand)624 static std::optional<OperandLifeTime> getInputOutputLifeTime(const Model& model, size_t modelSize,
625                                                              const Operand& operand) {
626     // Ways to get an invalid lifetime (with respect to model inputIndexes and outputIndexes):
627     // - change whether a lifetime means an operand is a model input, a model output, or neither
628     // - preserve whether or not a lifetime means an operand should have a writer
629     switch (operand.lifetime) {
630         case OperandLifeTime::CONSTANT_COPY:
631         case OperandLifeTime::CONSTANT_REFERENCE:
632             return OperandLifeTime::SUBGRAPH_INPUT;
633         case OperandLifeTime::SUBGRAPH_INPUT: {
634             const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
635             if (!operandSize ||
636                 exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
637                 // Unknown size or too-large size
638                 break;
639             }
640             return OperandLifeTime::CONSTANT_COPY;
641         }
642         case OperandLifeTime::SUBGRAPH_OUTPUT:
643             return OperandLifeTime::TEMPORARY_VARIABLE;
644         case OperandLifeTime::TEMPORARY_VARIABLE:
645             return OperandLifeTime::SUBGRAPH_OUTPUT;
646         case OperandLifeTime::NO_VALUE:
647             // Not enough information to know whether
648             // TEMPORARY_VARIABLE or CONSTANT_COPY would be an
649             // appropriate choice -- is this operand written (then
650             // TEMPORARY_VARIABLE would be appropriate) or not (then
651             // CONSTANT_COPY would be appropriate)?
652             break;
653         case OperandLifeTime::SUBGRAPH:
654             break;
655         default:
656             ADD_FAILURE();
657             break;
658     }
659 
660     return std::nullopt;
661 }
662 
mutateOperandInputOutputTest(const sp<IDevice> & device,const Model & model)663 static void mutateOperandInputOutputTest(const sp<IDevice>& device, const Model& model) {
664     const size_t modelSize = sizeForBinder(model);
665     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
666         const std::optional<OperandLifeTime> changedLifeTime =
667                 getInputOutputLifeTime(model, modelSize, model.main.operands[operand]);
668         if (changedLifeTime) {
669             const std::string message = "mutateOperandInputOutputTest: operand " +
670                                         std::to_string(operand) + " has lifetime " +
671                                         toString(*changedLifeTime) + " instead of lifetime " +
672                                         toString(model.main.operands[operand].lifetime);
673             validate(device, message, model,
674                      [operand, changedLifeTime](Model* model, ExecutionPreference*, Priority*) {
675                          static const DataLocation kZeroDataLocation = {};
676                          Operand& operandObj = model->main.operands[operand];
677                          operandObj.lifetime = *changedLifeTime;
678                          operandObj.location = kZeroDataLocation;
679                          if (*changedLifeTime == OperandLifeTime::CONSTANT_COPY) {
680                              becomeConstantCopy(model, &operandObj);
681                          }
682                      });
683         }
684     }
685 }
686 
687 ///////////////////////// VALIDATE OPERAND NUMBER OF CONSUMERS //////////////////////////////////
688 
// Returns consumer counts adjacent to `numberOfConsumers` (one fewer and one more), neither of
// which equals the true count. When the true count is zero, only "one more" is returned, so the
// unsigned value never wraps below zero.
static std::vector<uint32_t> getInvalidNumberOfConsumers(uint32_t numberOfConsumers) {
    std::vector<uint32_t> candidates;
    if (numberOfConsumers > 0) {
        candidates.push_back(numberOfConsumers - 1);
    }
    candidates.push_back(numberOfConsumers + 1);
    return candidates;
}
696 
mutateOperandNumberOfConsumersTest(const sp<IDevice> & device,const Model & model)697 static void mutateOperandNumberOfConsumersTest(const sp<IDevice>& device, const Model& model) {
698     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
699         const std::vector<uint32_t> invalidNumberOfConsumersVec =
700                 getInvalidNumberOfConsumers(model.main.operands[operand].numberOfConsumers);
701         for (uint32_t invalidNumberOfConsumers : invalidNumberOfConsumersVec) {
702             const std::string message =
703                     "mutateOperandNumberOfConsumersTest: operand " + std::to_string(operand) +
704                     " numberOfConsumers = " + std::to_string(invalidNumberOfConsumers);
705             validate(device, message, model,
706                      [operand, invalidNumberOfConsumers](Model* model, ExecutionPreference*,
707                                                          Priority*) {
708                          model->main.operands[operand].numberOfConsumers = invalidNumberOfConsumers;
709                      });
710         }
711     }
712 }
713 
714 ///////////////////////// VALIDATE OPERAND NUMBER OF WRITERS ////////////////////////////////////
715 
mutateOperandAddWriterTest(const sp<IDevice> & device,const Model & model)716 static void mutateOperandAddWriterTest(const sp<IDevice>& device, const Model& model) {
717     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
718         for (size_t badOutputNum = 0;
719              badOutputNum < model.main.operations[operation].outputs.size(); ++badOutputNum) {
720             const uint32_t outputOperandIndex =
721                     model.main.operations[operation].outputs[badOutputNum];
722             const std::string message = "mutateOperandAddWriterTest: operation " +
723                                         std::to_string(operation) + " writes to " +
724                                         std::to_string(outputOperandIndex);
725             // We'll insert a copy of the operation, all of whose
726             // OTHER output operands are newly-created -- i.e.,
727             // there'll only be a duplicate write of ONE of that
728             // operation's output operands.
729             validate(device, message, model,
730                      [operation, badOutputNum](Model* model, ExecutionPreference*, Priority*) {
731                          Operation newOperation = model->main.operations[operation];
732                          for (uint32_t input : newOperation.inputs) {
733                              ++model->main.operands[input].numberOfConsumers;
734                          }
735                          for (size_t outputNum = 0; outputNum < newOperation.outputs.size();
736                               ++outputNum) {
737                              if (outputNum == badOutputNum) continue;
738 
739                              Operand operandValue =
740                                      model->main.operands[newOperation.outputs[outputNum]];
741                              operandValue.numberOfConsumers = 0;
742                              if (operandValue.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
743                                  operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
744                              } else {
745                                  ASSERT_EQ(operandValue.lifetime,
746                                            OperandLifeTime::TEMPORARY_VARIABLE);
747                              }
748                              newOperation.outputs[outputNum] =
749                                      hidl_vec_push_back(&model->main.operands, operandValue);
750                          }
751                          // Where do we insert the extra writer (a new
752                          // operation)?  It has to be later than all the
753                          // writers of its inputs.  The easiest thing to do
754                          // is to insert it at the end of the operation
755                          // sequence.
756                          hidl_vec_push_back(&model->main.operations, newOperation);
757                      });
758         }
759     }
760 }
761 
762 ///////////////////////// VALIDATE EXTRA ??? /////////////////////////
763 
764 // TODO: Operand::location
765 
766 ///////////////////////// VALIDATE OPERATION OPERAND TYPE /////////////////////////
767 
mutateOperand(Operand * operand,OperandType type)768 static void mutateOperand(Operand* operand, OperandType type) {
769     Operand newOperand = *operand;
770     newOperand.type = type;
771     switch (type) {
772         case OperandType::FLOAT16:
773         case OperandType::FLOAT32:
774         case OperandType::INT32:
775         case OperandType::UINT32:
776         case OperandType::BOOL:
777             newOperand.dimensions = hidl_vec<uint32_t>();
778             newOperand.scale = 0.0f;
779             newOperand.zeroPoint = 0;
780             break;
781         case OperandType::TENSOR_BOOL8:
782         case OperandType::TENSOR_FLOAT16:
783         case OperandType::TENSOR_FLOAT32:
784             newOperand.dimensions =
785                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
786             newOperand.scale = 0.0f;
787             newOperand.zeroPoint = 0;
788             break;
789         case OperandType::TENSOR_INT32:
790             newOperand.dimensions =
791                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
792             newOperand.zeroPoint = 0;
793             break;
794         case OperandType::TENSOR_QUANT8_ASYMM:
795         case OperandType::TENSOR_QUANT8_SYMM:
796         case OperandType::TENSOR_QUANT16_ASYMM:
797         case OperandType::TENSOR_QUANT16_SYMM:
798             newOperand.dimensions =
799                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
800             newOperand.scale = operand->scale != 0.0f ? operand->scale : 1.0f;
801             break;
802         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
803             newOperand.dimensions =
804                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
805             newOperand.scale = 0.0f;
806             newOperand.zeroPoint = 0;
807 
808             SymmPerChannelQuantParams channelQuant;
809             channelQuant.channelDim = 0;
810             channelQuant.scales = hidl_vec<float>(
811                     operand->dimensions.size() > 0 ? static_cast<size_t>(operand->dimensions[0])
812                                                    : 0);
813             for (size_t i = 0; i < channelQuant.scales.size(); ++i) {
814                 channelQuant.scales[i] = 1.0f;
815             }
816             newOperand.extraParams.channelQuant(std::move(channelQuant));
817         } break;
818         case OperandType::OEM:
819         case OperandType::TENSOR_OEM_BYTE:
820         default:
821             break;
822     }
823     *operand = newOperand;
824 }
825 
mutateOperationOperandTypeSkip(size_t operand,OperandType type,const Model & model)826 static bool mutateOperationOperandTypeSkip(size_t operand, OperandType type, const Model& model) {
827     // Do not test OEM types
828     if (type == model.main.operands[operand].type || type == OperandType::OEM ||
829         type == OperandType::TENSOR_OEM_BYTE) {
830         return true;
831     }
832     for (const Operation& operation : model.main.operations) {
833         // Skip mutateOperationOperandTypeTest for the following operations.
834         // - LSH_PROJECTION's second argument is allowed to have any type.
835         // - ARGMIN and ARGMAX's first argument can be any of
836         // TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
837         // - CAST's argument can be any of TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
838         // - RANDOM_MULTINOMIAL's argument can be either TENSOR_FLOAT16 or TENSOR_FLOAT32.
839         // - DEQUANTIZE input can be any of
840         // TENSOR_(QUANT8_ASYMM|QUANT8_ASYMM_SIGNED|QUANT8_SYMM|QUANT8_SYMM_PER_CHANNEL),
841         // output can be of either TENSOR_FLOAT16 or TENSOR_FLOAT32.
842         // - QUANTIZE input can be either TENSOR_FLOAT16 or TENSOR_FLOAT32
843         // - CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
844         // - DEPTHWISE_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
845         // - GROUPED_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
846         // - TRANSPOSE_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
847         // - AXIS_ALIGNED_BBOX_TRANSFORM bounding boxes (arg 1) can be of
848         //     TENSOR_QUANT8_ASYMM or TENSOR_QUANT8_ASYMM_SIGNED.
849         // - RANK's input can have any TENSOR_* type.
850         switch (operation.type) {
851             case OperationType::LSH_PROJECTION: {
852                 if (operand == operation.inputs[1]) {
853                     return true;
854                 }
855             } break;
856             case OperationType::CAST:
857             case OperationType::ARGMAX:
858             case OperationType::ARGMIN: {
859                 if (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32 ||
860                     type == OperandType::TENSOR_INT32 || type == OperandType::TENSOR_QUANT8_ASYMM ||
861                     type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
862                     return true;
863                 }
864             } break;
865             case OperationType::QUANTIZE: {
866                 if (operand == operation.inputs[0] &&
867                     (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
868                     return true;
869                 }
870                 if (operand == operation.outputs[0] &&
871                     (type == OperandType::TENSOR_QUANT8_ASYMM ||
872                      type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)) {
873                     return true;
874                 }
875             } break;
876             case OperationType::RANDOM_MULTINOMIAL: {
877                 if (operand == operation.inputs[0] &&
878                     (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
879                     return true;
880                 }
881             } break;
882             case OperationType::DEQUANTIZE: {
883                 if (operand == operation.inputs[0] &&
884                     (type == OperandType::TENSOR_QUANT8_ASYMM ||
885                      type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED ||
886                      type == OperandType::TENSOR_QUANT8_SYMM ||
887                      type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)) {
888                     return true;
889                 }
890                 if (operand == operation.outputs[0] &&
891                     (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
892                     return true;
893                 }
894             } break;
895             case OperationType::TRANSPOSE_CONV_2D:
896             case OperationType::GROUPED_CONV_2D:
897             case OperationType::DEPTHWISE_CONV_2D:
898             case OperationType::CONV_2D: {
899                 if (operand == operation.inputs[1] &&
900                     (type == OperandType::TENSOR_QUANT8_ASYMM ||
901                      type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)) {
902                     return true;
903                 }
904             } break;
905             case OperationType::AXIS_ALIGNED_BBOX_TRANSFORM: {
906                 if (operand == operation.inputs[1] &&
907                     (type == OperandType::TENSOR_QUANT8_ASYMM ||
908                      type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)) {
909                     return true;
910                 }
911             } break;
912             case OperationType::RANK: {
913                 if (operand == operation.inputs[0] &&
914                     (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32 ||
915                      type == OperandType::TENSOR_INT32 ||
916                      type == OperandType::TENSOR_QUANT8_ASYMM ||
917                      type == OperandType::TENSOR_QUANT16_SYMM ||
918                      type == OperandType::TENSOR_BOOL8 ||
919                      type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL ||
920                      type == OperandType::TENSOR_QUANT16_ASYMM ||
921                      type == OperandType::TENSOR_QUANT8_SYMM ||
922                      type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)) {
923                     return true;
924                 }
925             } break;
926             default:
927                 break;
928         }
929     }
930     return false;
931 }
932 
mutateOperationOperandTypeTest(const sp<IDevice> & device,const Model & model)933 static void mutateOperationOperandTypeTest(const sp<IDevice>& device, const Model& model) {
934     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
935         for (OperandType invalidOperandType : hidl_enum_range<OperandType>{}) {
936             if (mutateOperationOperandTypeSkip(operand, invalidOperandType, model)) {
937                 continue;
938             }
939             const std::string message = "mutateOperationOperandTypeTest: operand " +
940                                         std::to_string(operand) + " set to type " +
941                                         toString(invalidOperandType);
942             validate(device, message, model,
943                      [operand, invalidOperandType](Model* model, ExecutionPreference*, Priority*) {
944                          mutateOperand(&model->main.operands[operand], invalidOperandType);
945                      });
946         }
947     }
948 }
949 
950 ///////////////////////// VALIDATE MODEL OPERATION TYPE /////////////////////////
951 
952 static const uint32_t invalidOperationTypes[] = {
953         static_cast<uint32_t>(OperationTypeRange::FUNDAMENTAL_MAX) + 1,
954         static_cast<uint32_t>(OperationTypeRange::OEM_MIN) - 1,
955         static_cast<uint32_t>(OperationTypeRange::OEM_MAX) + 1,
956 };
957 
mutateOperationTypeTest(const sp<IDevice> & device,const Model & model)958 static void mutateOperationTypeTest(const sp<IDevice>& device, const Model& model) {
959     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
960         for (uint32_t invalidOperationType : invalidOperationTypes) {
961             const std::string message = "mutateOperationTypeTest: operation " +
962                                         std::to_string(operation) + " set to value " +
963                                         std::to_string(invalidOperationType);
964             validate(device, message, model,
965                      [operation, invalidOperationType](Model* model, ExecutionPreference*,
966                                                        Priority*) {
967                          model->main.operations[operation].type =
968                                  static_cast<OperationType>(invalidOperationType);
969                      });
970         }
971     }
972 }
973 
974 ///////////////////////// VALIDATE MODEL OPERATION INPUT OPERAND INDEX /////////////////////////
975 
mutateOperationInputOperandIndexTest(const sp<IDevice> & device,const Model & model)976 static void mutateOperationInputOperandIndexTest(const sp<IDevice>& device, const Model& model) {
977     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
978         const uint32_t invalidOperand = model.main.operands.size();
979         for (size_t input = 0; input < model.main.operations[operation].inputs.size(); ++input) {
980             const std::string message = "mutateOperationInputOperandIndexTest: operation " +
981                                         std::to_string(operation) + " input " +
982                                         std::to_string(input);
983             validate(device, message, model,
984                      [operation, input, invalidOperand](Model* model, ExecutionPreference*,
985                                                         Priority*) {
986                          model->main.operations[operation].inputs[input] = invalidOperand;
987                      });
988         }
989     }
990 }
991 
992 ///////////////////////// VALIDATE MODEL OPERATION OUTPUT OPERAND INDEX /////////////////////////
993 
mutateOperationOutputOperandIndexTest(const sp<IDevice> & device,const Model & model)994 static void mutateOperationOutputOperandIndexTest(const sp<IDevice>& device, const Model& model) {
995     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
996         const uint32_t invalidOperand = model.main.operands.size();
997         for (size_t output = 0; output < model.main.operations[operation].outputs.size();
998              ++output) {
999             const std::string message = "mutateOperationOutputOperandIndexTest: operation " +
1000                                         std::to_string(operation) + " output " +
1001                                         std::to_string(output);
1002             validate(device, message, model,
1003                      [operation, output, invalidOperand](Model* model, ExecutionPreference*,
1004                                                          Priority*) {
1005                          model->main.operations[operation].outputs[output] = invalidOperand;
1006                      });
1007         }
1008     }
1009 }
1010 
1011 ///////////////////////// VALIDATE MODEL OPERANDS WRITTEN ///////////////////////////////////////
1012 
mutateOperationRemoveWriteTest(const sp<IDevice> & device,const Model & model)1013 static void mutateOperationRemoveWriteTest(const sp<IDevice>& device, const Model& model) {
1014     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1015         for (size_t outputNum = 0; outputNum < model.main.operations[operation].outputs.size();
1016              ++outputNum) {
1017             const uint32_t outputOperandIndex = model.main.operations[operation].outputs[outputNum];
1018             if (model.main.operands[outputOperandIndex].numberOfConsumers > 0) {
1019                 const std::string message = "mutateOperationRemoveWriteTest: operation " +
1020                                             std::to_string(operation) + " writes to " +
1021                                             std::to_string(outputOperandIndex);
1022                 validate(device, message, model,
1023                          [operation, outputNum](Model* model, ExecutionPreference*, Priority*) {
1024                              uint32_t& outputOperandIndex =
1025                                      model->main.operations[operation].outputs[outputNum];
1026                              Operand operandValue = model->main.operands[outputOperandIndex];
1027                              operandValue.numberOfConsumers = 0;
1028                              if (operandValue.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
1029                                  operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
1030                              } else {
1031                                  ASSERT_EQ(operandValue.lifetime,
1032                                            OperandLifeTime::TEMPORARY_VARIABLE);
1033                              }
1034                              outputOperandIndex =
1035                                      hidl_vec_push_back(&model->main.operands, operandValue);
1036                          });
1037             }
1038         }
1039     }
1040 }
1041 
1042 ///////////////////////// REMOVE OPERAND FROM EVERYTHING /////////////////////////
1043 
removeValueAndDecrementGreaterValues(hidl_vec<uint32_t> * vec,uint32_t value)1044 static void removeValueAndDecrementGreaterValues(hidl_vec<uint32_t>* vec, uint32_t value) {
1045     if (vec) {
1046         // remove elements matching "value"
1047         auto last = std::remove(vec->begin(), vec->end(), value);
1048         vec->resize(std::distance(vec->begin(), last));
1049 
1050         // decrement elements exceeding "value"
1051         std::transform(vec->begin(), vec->end(), vec->begin(),
1052                        [value](uint32_t v) { return v > value ? v-- : v; });
1053     }
1054 }
1055 
removeOperand(Model * model,uint32_t index)1056 static void removeOperand(Model* model, uint32_t index) {
1057     hidl_vec_removeAt(&model->main.operands, index);
1058     for (Operation& operation : model->main.operations) {
1059         removeValueAndDecrementGreaterValues(&operation.inputs, index);
1060         removeValueAndDecrementGreaterValues(&operation.outputs, index);
1061     }
1062     removeValueAndDecrementGreaterValues(&model->main.inputIndexes, index);
1063     removeValueAndDecrementGreaterValues(&model->main.outputIndexes, index);
1064 }
1065 
removeOperandSkip(size_t operandIndex,const Model & model)1066 static bool removeOperandSkip(size_t operandIndex, const Model& model) {
1067     const Operand& operand = model.main.operands[operandIndex];
1068     if (operand.numberOfConsumers == 0) {
1069         // Removing an unused operand has no effect.
1070         return true;
1071     }
1072     for (const Operation& operation : model.main.operations) {
1073         // Skip removeOperandTest for the following operations.
1074         // - SPLIT's outputs are not checked during prepareModel.
1075         if (operation.type == OperationType::SPLIT) {
1076             for (const size_t index : operation.outputs) {
1077                 if (index == operandIndex) {
1078                     return true;
1079                 }
1080             }
1081         }
1082         // BIDIRECTIONAL_SEQUENCE_LSTM and BIDIRECTIONAL_SEQUENCE_RNN can have
1083         // either one, two, three or four outputs depending on their
1084         // mergeOutputs parameter and if state outputs are provided.
1085         // UNIDIRECTIONAL_SEQUENCE_LSTM and UNIDIRECTIONAL_SEQUENCE_RNN can have
1086         // either one or three outputs depending on whether state outputs are
1087         // provided.
1088         if (operation.type == OperationType::UNIDIRECTIONAL_SEQUENCE_LSTM ||
1089             operation.type == OperationType::UNIDIRECTIONAL_SEQUENCE_RNN ||
1090             operation.type == OperationType::BIDIRECTIONAL_SEQUENCE_LSTM ||
1091             operation.type == OperationType::BIDIRECTIONAL_SEQUENCE_RNN) {
1092             for (const size_t index : operation.outputs) {
1093                 if (index == operandIndex) {
1094                     return true;
1095                 }
1096             }
1097         }
1098     }
1099     return false;
1100 }
1101 
removeOperandTest(const sp<IDevice> & device,const Model & model)1102 static void removeOperandTest(const sp<IDevice>& device, const Model& model) {
1103     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
1104         if (removeOperandSkip(operand, model)) {
1105             continue;
1106         }
1107         const std::string message = "removeOperandTest: operand " + std::to_string(operand);
1108         validate(device, message, model, [operand](Model* model, ExecutionPreference*, Priority*) {
1109             removeOperand(model, operand);
1110         });
1111     }
1112 }
1113 
1114 ///////////////////////// REMOVE OPERATION /////////////////////////
1115 
removeOperation(Model * model,uint32_t index)1116 static void removeOperation(Model* model, uint32_t index) {
1117     for (uint32_t operand : model->main.operations[index].inputs) {
1118         model->main.operands[operand].numberOfConsumers--;
1119     }
1120     hidl_vec_removeAt(&model->main.operations, index);
1121 }
1122 
removeOperationTest(const sp<IDevice> & device,const Model & model)1123 static void removeOperationTest(const sp<IDevice>& device, const Model& model) {
1124     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1125         const std::string message = "removeOperationTest: operation " + std::to_string(operation);
1126         validate(device, message, model,
1127                  [operation](Model* model, ExecutionPreference*, Priority*) {
1128                      removeOperation(model, operation);
1129                  });
1130     }
1131 }
1132 
1133 ///////////////////////// REMOVE OPERATION INPUT /////////////////////////
1134 
removeOperationInputSkip(const Operation & op,size_t input)1135 static bool removeOperationInputSkip(const Operation& op, size_t input) {
1136     // Skip removeOperationInputTest for the following operations.
1137     // - CONCATENATION has at least 2 inputs, with the last element being INT32.
1138     // - CONV_2D, DEPTHWISE_CONV_2D, MAX_POOL_2D, AVERAGE_POOL_2D, L2_POOL_2D, RESIZE_BILINEAR,
1139     //   SPACE_TO_DEPTH, SPACE_TO_DEPTH, SPACE_TO_BATCH_ND, BATCH_TO_SPACE_ND can have an optional
1140     //   layout parameter.
1141     //   RESIZE_BILINEAR and RESIZE_NEAREST_NEIGHBOR can have optional
1142     //   align_corners and half_pixel_centers parameters.
1143     // - L2_NORMALIZATION, LOCAL_RESPONSE_NORMALIZATION, SOFTMAX can have an optional axis
1144     //   parameter.
1145     switch (op.type) {
1146         case OperationType::CONCATENATION: {
1147             if (op.inputs.size() > 2 && input != op.inputs.size() - 1) {
1148                 return true;
1149             }
1150         } break;
1151         case OperationType::DEPTHWISE_CONV_2D: {
1152             if ((op.inputs.size() == 12 && input == 11) || (op.inputs.size() == 9 && input == 8)) {
1153                 return true;
1154             }
1155         } break;
1156         case OperationType::CONV_2D:
1157         case OperationType::AVERAGE_POOL_2D:
1158         case OperationType::MAX_POOL_2D:
1159         case OperationType::L2_POOL_2D: {
1160             if ((op.inputs.size() == 11 && input == 10) || (op.inputs.size() == 8 && input == 7)) {
1161                 return true;
1162             }
1163         } break;
1164         case OperationType::RESIZE_BILINEAR: {
1165             if (op.inputs.size() >= 4 && input >= 3) {
1166                 return true;
1167             }
1168         } break;
1169         case OperationType::RESIZE_NEAREST_NEIGHBOR: {
1170             if (op.inputs.size() >= 5 && input >= 3) {
1171                 return true;
1172             }
1173         } break;
1174         case OperationType::SPACE_TO_DEPTH:
1175         case OperationType::DEPTH_TO_SPACE:
1176         case OperationType::BATCH_TO_SPACE_ND: {
1177             if (op.inputs.size() == 3 && input == 2) {
1178                 return true;
1179             }
1180         } break;
1181         case OperationType::SPACE_TO_BATCH_ND: {
1182             if (op.inputs.size() == 4 && input == 3) {
1183                 return true;
1184             }
1185         } break;
1186         case OperationType::L2_NORMALIZATION: {
1187             if (op.inputs.size() == 2 && input == 1) {
1188                 return true;
1189             }
1190         } break;
1191         case OperationType::LOCAL_RESPONSE_NORMALIZATION: {
1192             if (op.inputs.size() == 6 && input == 5) {
1193                 return true;
1194             }
1195         } break;
1196         case OperationType::SOFTMAX: {
1197             if (op.inputs.size() == 3 && input == 2) {
1198                 return true;
1199             }
1200         } break;
1201         default:
1202             break;
1203     }
1204     return false;
1205 }
1206 
removeOperationInputTest(const sp<IDevice> & device,const Model & model)1207 static void removeOperationInputTest(const sp<IDevice>& device, const Model& model) {
1208     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1209         for (size_t input = 0; input < model.main.operations[operation].inputs.size(); ++input) {
1210             const Operation& op = model.main.operations[operation];
1211             if (removeOperationInputSkip(op, input)) {
1212                 continue;
1213             }
1214             const std::string message = "removeOperationInputTest: operation " +
1215                                         std::to_string(operation) + ", input " +
1216                                         std::to_string(input);
1217             validate(device, message, model,
1218                      [operation, input](Model* model, ExecutionPreference*, Priority*) {
1219                          uint32_t operand = model->main.operations[operation].inputs[input];
1220                          model->main.operands[operand].numberOfConsumers--;
1221                          hidl_vec_removeAt(&model->main.operations[operation].inputs, input);
1222                      });
1223         }
1224     }
1225 }
1226 
1227 ///////////////////////// REMOVE OPERATION OUTPUT /////////////////////////
1228 
removeOperationOutputTest(const sp<IDevice> & device,const Model & model)1229 static void removeOperationOutputTest(const sp<IDevice>& device, const Model& model) {
1230     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1231         for (size_t output = 0; output < model.main.operations[operation].outputs.size();
1232              ++output) {
1233             const std::string message = "removeOperationOutputTest: operation " +
1234                                         std::to_string(operation) + ", output " +
1235                                         std::to_string(output);
1236             validate(device, message, model,
1237                      [operation, output](Model* model, ExecutionPreference*, Priority*) {
1238                          hidl_vec_removeAt(&model->main.operations[operation].outputs, output);
1239                      });
1240         }
1241     }
1242 }
1243 
1244 ///////////////////////// MODEL VALIDATION /////////////////////////
1245 
1246 // TODO: remove model input
1247 // TODO: remove model output
1248 // TODO: add unused operation
1249 
1250 ///////////////////////// ADD OPERATION INPUT /////////////////////////
1251 
addOperationInputSkip(const Operation & op)1252 static bool addOperationInputSkip(const Operation& op) {
1253     // Skip addOperationInputTest for the following operations.
1254     // - L2_NORMALIZATION, LOCAL_RESPONSE_NORMALIZATION, SOFTMAX can have an optional INT32 axis
1255     //   parameter.
1256     if ((op.type == OperationType::L2_NORMALIZATION && op.inputs.size() == 1) ||
1257         (op.type == OperationType::LOCAL_RESPONSE_NORMALIZATION && op.inputs.size() == 5) ||
1258         (op.type == OperationType::SOFTMAX && op.inputs.size() == 2) ||
1259         (op.type == OperationType::RESIZE_BILINEAR && op.inputs.size() < 6) ||
1260         (op.type == OperationType::RESIZE_NEAREST_NEIGHBOR && op.inputs.size() < 6)) {
1261         return true;
1262     }
1263     return false;
1264 }
1265 
addOperationInputTest(const sp<IDevice> & device,const Model & model)1266 static void addOperationInputTest(const sp<IDevice>& device, const Model& model) {
1267     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1268         if (addOperationInputSkip(model.main.operations[operation])) {
1269             continue;
1270         }
1271         const std::string message = "addOperationInputTest: operation " + std::to_string(operation);
1272         validate(device, message, model,
1273                  [operation](Model* model, ExecutionPreference*, Priority*) {
1274                      uint32_t index = addOperand(model, OperandLifeTime::SUBGRAPH_INPUT);
1275                      hidl_vec_push_back(&model->main.operations[operation].inputs, index);
1276                      hidl_vec_push_back(&model->main.inputIndexes, index);
1277                  });
1278     }
1279 }
1280 
1281 ///////////////////////// ADD OPERATION OUTPUT /////////////////////////
1282 
addOperationOutputTest(const sp<IDevice> & device,const Model & model)1283 static void addOperationOutputTest(const sp<IDevice>& device, const Model& model) {
1284     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1285         const std::string message =
1286                 "addOperationOutputTest: operation " + std::to_string(operation);
1287         validate(device, message, model,
1288                  [operation](Model* model, ExecutionPreference*, Priority*) {
1289                      uint32_t index = addOperand(model, OperandLifeTime::SUBGRAPH_OUTPUT);
1290                      hidl_vec_push_back(&model->main.operations[operation].outputs, index);
1291                      hidl_vec_push_back(&model->main.outputIndexes, index);
1292                  });
1293     }
1294 }
1295 
1296 ///////////////////////// VALIDATE EXECUTION PREFERENCE /////////////////////////
1297 
1298 static const int32_t invalidExecutionPreferences[] = {
1299         static_cast<int32_t>(ExecutionPreference::LOW_POWER) - 1,        // lower bound
1300         static_cast<int32_t>(ExecutionPreference::SUSTAINED_SPEED) + 1,  // upper bound
1301 };
1302 
mutateExecutionPreferenceTest(const sp<IDevice> & device,const Model & model)1303 static void mutateExecutionPreferenceTest(const sp<IDevice>& device, const Model& model) {
1304     for (int32_t invalidPreference : invalidExecutionPreferences) {
1305         const std::string message =
1306                 "mutateExecutionPreferenceTest: preference " + std::to_string(invalidPreference);
1307         validate(device, message, model,
1308                  [invalidPreference](Model*, ExecutionPreference* preference, Priority*) {
1309                      *preference = static_cast<ExecutionPreference>(invalidPreference);
1310                  });
1311     }
1312 }
1313 
1314 ///////////////////////// VALIDATE PRIORITY /////////////////////////
1315 
1316 static const int32_t invalidPriorities[] = {
1317         static_cast<int32_t>(Priority::LOW) - 1,   // lower bound
1318         static_cast<int32_t>(Priority::HIGH) + 1,  // upper bound
1319 };
1320 
mutateExecutionPriorityTest(const sp<IDevice> & device,const Model & model)1321 static void mutateExecutionPriorityTest(const sp<IDevice>& device, const Model& model) {
1322     for (int32_t invalidPriority : invalidPriorities) {
1323         const std::string message =
1324                 "mutatePriorityTest: priority " + std::to_string(invalidPriority);
1325         validate(device, message, model,
1326                  [invalidPriority](Model*, ExecutionPreference*, Priority* priority) {
1327                      *priority = static_cast<Priority>(invalidPriority);
1328                  });
1329     }
1330 }
1331 
1332 ////////////////////////// ENTRY POINT //////////////////////////////
1333 
validateModel(const sp<IDevice> & device,const Model & model)1334 void validateModel(const sp<IDevice>& device, const Model& model) {
1335     mutateExecutionOrderTest(device, model);
1336     mutateOperandTypeTest(device, model);
1337     mutateOperandRankTest(device, model);
1338     mutateOperandScaleTest(device, model);
1339     mutateOperandZeroPointTest(device, model);
1340     mutateOperandLifeTimeTest(device, model);
1341     mutateOperandInputOutputTest(device, model);
1342     mutateOperandNumberOfConsumersTest(device, model);
1343     mutateOperandAddWriterTest(device, model);
1344     mutateOperationOperandTypeTest(device, model);
1345     mutateOperationTypeTest(device, model);
1346     mutateOperationInputOperandIndexTest(device, model);
1347     mutateOperationOutputOperandIndexTest(device, model);
1348     mutateOperationRemoveWriteTest(device, model);
1349     removeOperandTest(device, model);
1350     removeOperationTest(device, model);
1351     removeOperationInputTest(device, model);
1352     removeOperationOutputTest(device, model);
1353     addOperationInputTest(device, model);
1354     addOperationOutputTest(device, model);
1355     mutateExecutionPreferenceTest(device, model);
1356     mutateExecutionPriorityTest(device, model);
1357 }
1358 
1359 }  // namespace android::hardware::neuralnetworks::V1_3::vts::functional
1360