xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/flex/test_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_FLEX_TEST_UTIL_H_
17 #define TENSORFLOW_LITE_DELEGATES_FLEX_TEST_UTIL_H_
18 
#include <cstring>
#include <memory>
#include <vector>

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/lite/kernels/test_util.h"
21 
22 namespace tflite {
23 namespace flex {
24 namespace testing {
25 
// The set of TensorFlow ops that FlexModelTest::AddTfOp(TfOpType, ...) can
// add to a test model.
enum TfOpType {
  kUnpack,
  kIdentity,
  kAdd,
  kMul,
  kRfft,
  kImag,
  kLoopCond,
  // Represents an op that does not exist in TensorFlow.
  kNonExistent,
  // Represents a valid TensorFlow op whose NodeDef is incompatible.
  kIncompatibleNodeDef,
};
39 
40 // This class creates models with TF and TFLite ops. In order to use this class
41 // to test the Flex delegate, implement a function that calls
42 // interpreter->ModifyGraphWithDelegate.
43 class FlexModelTest : public ::testing::Test {
44  public:
FlexModelTest()45   FlexModelTest() {}
~FlexModelTest()46   ~FlexModelTest() override {}
47 
48   bool Invoke();
49 
50   // Sets the (typed) tensor's values at the given index.
51   template <typename T>
SetTypedValues(int tensor_index,const std::vector<T> & values)52   void SetTypedValues(int tensor_index, const std::vector<T>& values) {
53     memcpy(interpreter_->typed_tensor<T>(tensor_index), values.data(),
54            values.size() * sizeof(T));
55   }
56 
57   // Returns the (typed) tensor's values at the given index.
58   template <typename T>
GetTypedValues(int tensor_index)59   std::vector<T> GetTypedValues(int tensor_index) {
60     const TfLiteTensor* t = interpreter_->tensor(tensor_index);
61     const T* tdata = interpreter_->typed_tensor<T>(tensor_index);
62     return std::vector<T>(tdata, tdata + t->bytes / sizeof(T));
63   }
64 
65   // Sets the tensor's values at the given index.
SetValues(int tensor_index,const std::vector<float> & values)66   void SetValues(int tensor_index, const std::vector<float>& values) {
67     SetTypedValues<float>(tensor_index, values);
68   }
69   void SetStringValues(int tensor_index, const std::vector<string>& values);
70 
71   // Returns the tensor's values at the given index.
GetValues(int tensor_index)72   std::vector<float> GetValues(int tensor_index) {
73     return GetTypedValues<float>(tensor_index);
74   }
75   std::vector<string> GetStringValues(int tensor_index) const;
76 
77   // Sets the tensor's shape at the given index.
78   void SetShape(int tensor_index, const std::vector<int>& values);
79 
80   // Returns the tensor's shape at the given index.
81   std::vector<int> GetShape(int tensor_index);
82 
83   // Returns the tensor's type at the given index.
84   TfLiteType GetType(int tensor_index);
85 
86   // Returns if the tensor at the given index is dynamic.
87   bool IsDynamicTensor(int tensor_index);
88 
error_reporter()89   const TestErrorReporter& error_reporter() const { return error_reporter_; }
90 
91   // Adds `num_tensor` tensors to the model. `inputs` contains the indices of
92   // the input tensors and `outputs` contains the indices of the output
93   // tensors. All tensors are set to have `type` and `dims`.
94   void AddTensors(int num_tensors, const std::vector<int>& inputs,
95                   const std::vector<int>& outputs, TfLiteType type,
96                   const std::vector<int>& dims);
97 
98   // Set a constant tensor of the given shape, type and buffer at the given
99   // index.
100   void SetConstTensor(int tensor_index, const std::vector<int>& values,
101                       TfLiteType type, const char* buffer, size_t bytes);
102 
103   // Adds a TFLite Mul op. `inputs` contains the indices of the input tensors
104   // and `outputs` contains the indices of the output tensors.
105   void AddTfLiteMulOp(const std::vector<int>& inputs,
106                       const std::vector<int>& outputs);
107 
108   // Adds a TensorFlow op. `inputs` contains the indices of the
109   // input tensors and `outputs` contains the indices of the output tensors.
110   // This function is limited to the set of ops defined in TfOpType.
111   void AddTfOp(TfOpType op, const std::vector<int>& inputs,
112                const std::vector<int>& outputs);
113 
114  protected:
115   std::unique_ptr<Interpreter> interpreter_;
116   TestErrorReporter error_reporter_;
117   std::vector<int> tf_ops_;
118 
119  private:
120   // Helper method to add a TensorFlow op. tflite_names needs to start with
121   // "Flex" in order to work with the Flex delegate.
122   void AddTfOp(const char* tflite_name, const string& tf_name,
123                const string& nodedef_str, const std::vector<int>& inputs,
124                const std::vector<int>& outputs);
125 
126   std::vector<std::vector<uint8_t>> flexbuffers_;
127 
128   int next_op_index_ = 0;
129 };
130 
131 }  // namespace testing
132 }  // namespace flex
133 }  // namespace tflite
134 
135 #endif  // TENSORFLOW_LITE_DELEGATES_FLEX_TEST_UTIL_H_
136