xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <mutex>
#include <sstream>

#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/script.h>

namespace torch::jit::mobile {

class MobileModelRunner {
  std::shared_ptr<torch::jit::mobile::Module> module_;

 public:
  explicit MobileModelRunner(std::string const& file_path) {
    module_ = std::make_shared<torch::jit::mobile::Module>(
        torch::jit::_load_for_mobile(file_path));
  }

  MobileModelRunner(
      std::string const& file_path,
      uint64_t module_load_options) {
    std::unordered_map<std::string, std::string> extra_files;
    module_ = std::make_shared<torch::jit::mobile::Module>(
        torch::jit::_load_for_mobile(
            file_path,
            at::Device(at::DeviceType::CPU, 0),
            extra_files,
            module_load_options));
  }

  MobileModelRunner(std::stringstream oss) {
    module_ = std::make_shared<torch::jit::mobile::Module>(
        torch::jit::_load_for_mobile(oss, at::Device(at::DeviceType::CPU, 0)));
  }
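
  // Usage sketch (illustrative only; the .ptl path below is hypothetical and
  // not part of this header): construct a runner from a serialized mobile
  // module, fetch its bundled inputs, and run "forward" over them.
  //
  //   MobileModelRunner runner("/data/local/tmp/model.ptl");
  //   auto inputs = runner.get_all_bundled_inputs();
  //   auto outputs = runner.run_with_inputs(inputs);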

  /**
   * Returns true if the list of operators passed in contains a Metal GPU
   * operator, and false otherwise.
   */
  static bool set_has_metal_gpu_operators(std::set<std::string> const& op_list);

  /**
   * Fetches the set of root operators from the file "extra/mobile_info.json"
   * within the .ptl archive at location file_path.
   *
   * An exception is thrown if:
   *
   * 1. The file at file_path does not exist, or
   * 2. The contents of extra/mobile_info.json are not valid JSON, or
   * 3. The file extra/mobile_info.json does not exist, or
   * 4. The JSON is malformed in some way and the operator list cannot be
   *    extracted correctly.
   */
  static std::set<std::string> get_operators_from_mobile_info_json(
      std::string const& file_path);
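
  // Sketch of how these two static helpers might be combined (the .ptl path
  // is hypothetical): read the root operator list recorded in
  // extra/mobile_info.json and check it for Metal GPU operators. The call can
  // throw for the reasons listed above, so it is wrapped in a try/catch.
  //
  //   try {
  //     std::set<std::string> ops =
  //         MobileModelRunner::get_operators_from_mobile_info_json(
  //             "/data/local/tmp/model.ptl");
  //     bool uses_metal = MobileModelRunner::set_has_metal_gpu_operators(ops);
  //   } catch (const std::exception& e) {
  //     // A missing file or malformed JSON ends up here.
  //   }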

  static std::vector<std::vector<at::IValue>> ivalue_to_bundled_inputs(
      const c10::IValue& bundled_inputs);

  static std::unordered_map<std::string, std::string>
  ivalue_to_bundled_inputs_map(const c10::IValue& bundled_inputs);

  /**
   * Fetches all the bundled inputs of the loaded mobile model.
   *
   * A bundled input itself is of type std::vector<at::IValue>, and the
   * elements of this vector<> are the arguments that the "forward"
   * method of the model accepts, i.e. each at::IValue is a single
   * argument to the model's "forward" method.
   *
   * Each element of the outer vector is one bundled input. For models with
   * bundled inputs, the outermost vector will have size > 0.
   */
  std::vector<std::vector<at::IValue>> get_all_bundled_inputs();
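
  // Illustrative sketch of consuming the nested structure described above
  // (`runner` is assumed to be a MobileModelRunner constructed earlier): each
  // element of the outer vector is one bundled input, and each inner
  // at::IValue is one argument to "forward".
  //
  //   for (const std::vector<at::IValue>& bundled_input :
  //        runner.get_all_bundled_inputs()) {
  //     // bundled_input.size() == number of arguments "forward" takes
  //   }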

  /**
   * Fetches all the bundled inputs for all functions of the loaded mobile
   * model.
   *
   * The mapping is from function names, e.g. "forward", to the bundled
   * inputs for that function.
   *
   * A bundled input itself is of type std::vector<at::IValue>, and the
   * elements of this vector<> are the arguments that the corresponding
   * method of the model accepts, i.e. each at::IValue in the entry for
   * "forward" is a single argument to the model's "forward" method.
   *
   * The outer vector of each value holds a bundled input. For models with
   * bundled inputs, the outermost vector will have size > 0.
   */
  std::unordered_map<std::string, std::vector<std::vector<at::IValue>>>
  get_many_functions_bundled_inputs();

  /**
   * Returns true if the model provides get_bundled_inputs_functions_and_info().
   */
  bool has_new_style_bundled_inputs() const {
    return module_->find_method("get_bundled_inputs_functions_and_info") !=
        std::nullopt;
  }
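
  // Sketch of choosing between the two bundled-input accessors at runtime
  // (illustrative; `runner` is assumed to be a MobileModelRunner and the
  // variable names are arbitrary):
  //
  //   std::unordered_map<std::string, std::vector<std::vector<at::IValue>>>
  //       inputs_by_function;
  //   if (runner.has_new_style_bundled_inputs()) {
  //     inputs_by_function = runner.get_many_functions_bundled_inputs();
  //   } else {
  //     inputs_by_function["forward"] = runner.get_all_bundled_inputs();
  //   }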

  /**
   * For each tensor in bundled inputs, call the user-provided function 'func'.
   */
  void for_each_tensor_in_bundled_inputs(
      std::function<void(const ::at::Tensor&)> const& func);
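
  // A minimal sketch: visit every tensor in the bundled inputs with a lambda,
  // e.g. to tally their total byte size (the accumulation shown is just an
  // example; `runner` is assumed to exist).
  //
  //   size_t total_bytes = 0;
  //   runner.for_each_tensor_in_bundled_inputs([&](const at::Tensor& t) {
  //     total_bytes += t.nbytes();
  //   });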

  /**
   * Get the root operators directly called by this model's bytecode.
   */
  std::set<std::string> get_root_operators() {
    return torch::jit::mobile::_export_operator_list(*module_);
  }
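
  // Sketch: dump the root operators seen in the bytecode, e.g. to build a
  // selective-build allowlist (printing is illustrative and would require
  // <iostream>; `runner` is assumed to exist).
  //
  //   for (const std::string& op : runner.get_root_operators()) {
  //     std::cout << op << std::endl;
  //   }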

  /**
   * Runs the model against all of the provided inputs using the model's
   * "forward" method. Returns a std::vector<at::IValue>, where each element
   * of the returned vector is one of the return values from calling forward().
   */
  std::vector<at::IValue> run_with_inputs(
      std::vector<std::vector<at::IValue>> const& bundled_inputs);

  /**
   * Runs the model against all of the provided inputs for the specified
   * function. Returns a std::vector<at::IValue>, where each element
   * of the returned vector is one of the return values from calling the
   * method named "function_name" on this model.
   */
  std::vector<at::IValue> run_with_inputs(
      const std::string& function_name,
      std::vector<std::vector<at::IValue>> const& bundled_inputs) const;
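
  // Sketch of running a method other than "forward" ("encode" is a
  // hypothetical method name, and `runner` / `encode_inputs` are assumed to
  // exist):
  //
  //   std::vector<at::IValue> results =
  //       runner.run_with_inputs("encode", encode_inputs);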

  /**
   * Attempts to run every function in the passed-in list, if it exists on the
   * model. All functions should require no arguments.
   */
  void run_argless_functions(const std::vector<std::string>& functions);
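
  // Sketch ("warmup" and "reset_state" are hypothetical zero-argument
  // methods; names not present on the model are simply skipped):
  //
  //   runner.run_argless_functions({"warmup", "reset_state"});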
};

} // namespace torch::jit::mobile