1 #pragma once 2 3 #include <mutex> 4 #include <sstream> 5 6 #include <torch/csrc/autograd/grad_mode.h> 7 #include <torch/csrc/jit/mobile/import.h> 8 #include <torch/csrc/jit/mobile/module.h> 9 #include <torch/csrc/jit/serialization/export.h> 10 #include <torch/script.h> 11 12 namespace torch::jit::mobile { 13 14 class MobileModelRunner { 15 std::shared_ptr<torch::jit::mobile::Module> module_; 16 17 public: MobileModelRunner(std::string const & file_path)18 explicit MobileModelRunner(std::string const& file_path) { 19 module_ = std::make_shared<torch::jit::mobile::Module>( 20 torch::jit::_load_for_mobile(file_path)); 21 } 22 MobileModelRunner(std::string const & file_path,uint64_t module_load_options)23 MobileModelRunner( 24 std::string const& file_path, 25 uint64_t module_load_options) { 26 std::unordered_map<std::string, std::string> extra_files; 27 module_ = std::make_shared<torch::jit::mobile::Module>( 28 torch::jit::_load_for_mobile( 29 file_path, 30 at::Device(at::DeviceType::CPU, 0), 31 extra_files, 32 module_load_options)); 33 } 34 MobileModelRunner(std::stringstream oss)35 MobileModelRunner(std::stringstream oss) { 36 module_ = std::make_shared<torch::jit::mobile::Module>( 37 torch::jit::_load_for_mobile(oss, at::Device(at::DeviceType::CPU, 0))); 38 } 39 40 /** 41 * Returns true if the list of operators passed in has a Metal GPU operator, 42 * and false otherwise. 43 * 44 */ 45 static bool set_has_metal_gpu_operators(std::set<std::string> const& op_list); 46 47 /** 48 * Fetches the set of root operators in the file "extra/mobile_info.json" 49 * within the .ptl archive at location file_path. 50 * 51 * An exception is thrown if: 52 * 53 * 1. The file at file_path does not exist, or 54 * 2. The contents of extra/mobile_info.json is not a JSON, or 55 * 3. The file extra/mobile_info.json does not exist, or 56 * 4. The JSON is malformed in some way and the operator list can not be 57 * extracted correctly. 58 * 59 */ 60 static std::set<std::string> get_operators_from_mobile_info_json( 61 std::string const& file_path); 62 63 static std::vector<std::vector<at::IValue>> ivalue_to_bundled_inputs( 64 const c10::IValue& bundled_inputs); 65 66 static std::unordered_map<std::string, std::string> 67 ivalue_to_bundled_inputs_map(const c10::IValue& bundled_inputs); 68 69 /** 70 * Fetches all the bundled inputs of the loaded mobile model. 71 * 72 * A bundled input itself is of type std::vector<at::IValue> and the 73 * elements of this vector<> are the arguments that the "forward" 74 * method of the model accepts. i.e. each of the at::IValue is a 75 * single argument to the model's "forward" method. 76 * 77 * The outer vector holds a bundled input. For models with bundled 78 * inputs, the outer most vector will have size > 0. 79 */ 80 std::vector<std::vector<at::IValue>> get_all_bundled_inputs(); 81 82 /** 83 * Fetches all the bundled inputs for all functions of the loaded mobile 84 * model. 85 * 86 * The mapping is from 'function_names' eg 'forward' to bundled inputs for 87 * that function 88 * 89 * A bundled input itself is of type std::vector<at::IValue> and the 90 * elements of this vector<> are the arguments that the corresponding 91 * method of the model accepts. i.e. each of the at::IValue in the entry 92 * for forward is a single argument to the model's "forward" method. 93 * 94 * The outer vector of each value holds a bundled input. For models with 95 * bundled inputs, the outer most vector will have size > 0. 96 */ 97 std::unordered_map<std::string, std::vector<std::vector<at::IValue>>> 98 get_many_functions_bundled_inputs(); 99 100 /** 101 * Returns true if a model possesses get_bundled_inputs_functions_and_info() 102 */ has_new_style_bundled_inputs()103 bool has_new_style_bundled_inputs() const { 104 return module_->find_method("get_bundled_inputs_functions_and_info") != 105 std::nullopt; 106 } 107 108 /** 109 * For each tensor in bundled inputs, call the user-provided function 'func'. 110 */ 111 void for_each_tensor_in_bundled_inputs( 112 std::function<void(const ::at::Tensor&)> const& func); 113 114 /** 115 * Get the root operators directly called by this model's Bytecode. 116 */ get_root_operators()117 std::set<std::string> get_root_operators() { 118 return torch::jit::mobile::_export_operator_list(*module_); 119 } 120 121 /** 122 * Runs the model against all of the provided inputs using the model's 123 * "forward" method. Returns an std::vector<at::IValue>, where each element 124 * of the returned vector is one of the return values from calling forward(). 125 */ 126 std::vector<at::IValue> run_with_inputs( 127 std::vector<std::vector<at::IValue>> const& bundled_inputs); 128 129 /** 130 * Runs the model against all of the provided inputs for all the specified 131 * function. Returns an std::vector<at::IValue>, where each element 132 * of the returned vector is one of the return values from calling the 133 * method named "function_name" on this model. 134 */ 135 std::vector<at::IValue> run_with_inputs( 136 const std::string& function_name, 137 std::vector<std::vector<at::IValue>> const& bundled_inputs) const; 138 139 /** 140 * Attempts to run all functions in the passed in list if they exist. All 141 * funcs should require no args 142 */ 143 void run_argless_functions(const std::vector<std::string>& functions); 144 }; 145 146 } // namespace torch::jit::mobile 147