1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_TESTING_TF_DRIVER_H_ 16 #define TENSORFLOW_LITE_TESTING_TF_DRIVER_H_ 17 18 #include <cstdint> 19 #include <string> 20 #include <unordered_map> 21 #include <vector> 22 23 #include "absl/container/flat_hash_map.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/framework/tensor_shape.h" 26 #include "tensorflow/core/framework/types.h" 27 #include "tensorflow/core/platform/logging.h" 28 #include "tensorflow/core/public/session.h" 29 #include "tensorflow/lite/testing/split.h" 30 #include "tensorflow/lite/testing/test_runner.h" 31 32 namespace tflite { 33 namespace testing { 34 35 // A test runner that feeds inputs into Tensorflow and generates outputs. 36 class TfDriver : public TestRunner { 37 public: 38 explicit TfDriver(const std::vector<string>& input_layer, 39 const std::vector<string>& input_layer_type, 40 const std::vector<string>& input_layer_shape, 41 const std::vector<string>& output_layer); ~TfDriver()42 ~TfDriver() override {} 43 44 void LoadModel(const string& bin_file_path) override; LoadModel(const string & bin_file_path,const string &)45 void LoadModel(const string& bin_file_path, const string&) override { 46 // Input output specifications are now provided by constructor. 47 // TODO(b/205171855): Support TfDriver to load from SavedModel instead of 48 // GraphDef. 49 LoadModel(bin_file_path); 50 } 51 52 void ReshapeTensor(const string& name, const string& csv_values) override; 53 void ResetTensor(const std::string& name) override; 54 string ReadOutput(const string& name) override; 55 void Invoke(const std::vector<std::pair<string, string>>& inputs) override; CheckResults(const std::vector<std::pair<string,string>> & expected_outputs,const std::vector<std::pair<string,string>> & expected_output_shapes)56 bool CheckResults( 57 const std::vector<std::pair<string, string>>& expected_outputs, 58 const std::vector<std::pair<string, string>>& expected_output_shapes) 59 override { 60 return true; 61 } GetOutputNames()62 std::vector<string> GetOutputNames() override { return output_names_; } 63 64 // no-op. SetInput will overwrite existing data . AllocateTensors()65 void AllocateTensors() override {} 66 67 protected: 68 void SetInput(const string& values_as_string, tensorflow::Tensor*); 69 string ReadOutput(const tensorflow::Tensor& tensor); 70 71 private: 72 std::unique_ptr<tensorflow::Session> session_; 73 std::vector<int> input_ids_; 74 std::vector<string> input_names_; 75 absl::flat_hash_map<string, int> input_name_to_id_; 76 std::vector<std::vector<int64_t>> input_shapes_; 77 std::vector<tensorflow::DataType> input_types_; 78 std::unordered_map<string, tensorflow::Tensor> input_tensors_; 79 80 std::vector<int> output_ids_; 81 std::vector<string> output_names_; 82 absl::flat_hash_map<string, int> output_name_to_id_; 83 std::vector<::tensorflow::Tensor> output_tensors_; 84 }; 85 86 } // namespace testing 87 } // namespace tflite 88 89 #endif // TENSORFLOW_LITE_TESTING_TF_DRIVER_H_ 90