xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/throughput_benchmark.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/utils/throughput_benchmark.h>
2 
3 #include <pybind11/pybind11.h>
4 #include <torch/csrc/jit/python/pybind_utils.h>
5 #include <torch/csrc/utils/pybind.h>
6 
7 namespace torch::throughput_benchmark {
8 
operator <<(std::ostream & os,const BenchmarkExecutionStats & value)9 std::ostream& operator<<(
10     std::ostream& os,
11     const BenchmarkExecutionStats& value) {
12   return os << "Average latency / iter (ms): " << value.latency_avg_ms
13             << "\n Total number of iters: " << value.num_iters;
14 }
15 
addInput(py::args args,py::kwargs kwargs)16 void ThroughputBenchmark::addInput(py::args args, py::kwargs kwargs) {
17   CHECK(script_module_.initialized() ^ module_.initialized());
18   if (script_module_.initialized()) {
19     script_module_.addInput(std::move(args), std::move(kwargs));
20   } else {
21     CHECK(module_.initialized());
22     module_.addInput(std::move(args), std::move(kwargs));
23   }
24 }
25 
// Runs the wrapped model a single time on (args, kwargs) and returns the
// result as a Python object.
py::object ThroughputBenchmark::runOnce(
    const py::args& args,
    const py::kwargs& kwargs) {
  // Exactly one of the two wrapped benchmarks is expected to be initialized.
  CHECK(script_module_.initialized() ^ module_.initialized());

  if (!script_module_.initialized()) {
    CHECK(module_.initialized());
    return module_.runOnce(args, kwargs);
  }

  c10::IValue output;
  {
    // The GIL is released only for the duration of the script module call;
    // the guard goes out of scope before the result is converted below.
    pybind11::gil_scoped_release release_gil;
    output = script_module_.runOnce(args, kwargs);
  }
  return jit::toPyObject(std::move(output));
}
42 
ThroughputBenchmark(const jit::Module & script_module)43 ThroughputBenchmark::ThroughputBenchmark(const jit::Module& script_module)
44     : script_module_(script_module) {}
45 
ThroughputBenchmark(py::object module)46 ThroughputBenchmark::ThroughputBenchmark(py::object module)
47     : module_(std::move(module)) {}
48 
benchmark(const BenchmarkConfig & config) const49 BenchmarkExecutionStats ThroughputBenchmark::benchmark(
50     const BenchmarkConfig& config) const {
51   CHECK(script_module_.initialized() ^ module_.initialized());
52   // Main benchmark thread doesn't hold the GIL after scheduling worker threads
53   // But for now we don't release it as we will be implicitly manipulating with
54   // py::object ref. counts in the case of nn.Module benchmarking.
55   if (script_module_.initialized()) {
56     return script_module_.benchmark(config);
57   } else {
58     CHECK(module_.initialized());
59     TORCH_WARN(
60         "Starting benchmark on an nn.Module. This can be slow due "
61         "to Python GIL.For proper inference simulation you might want to switch to "
62         "a ScriptModule instead");
63     return module_.benchmark(config);
64   }
65 }
66 
67 namespace detail {
68 
69 template <>
runOnce(ScriptModuleInput && input) const70 void ScriptModuleBenchmark::runOnce(ScriptModuleInput&& input) const {
71   CHECK(initialized_);
72   // TODO: provide guarantees that compiler won't optimize this out
73   model_.get_method("forward").function()(std::move(input));
74 }
75 
76 template <>
runOnce(const py::args & args,const py::kwargs & kwargs) const77 ScriptModuleOutput ScriptModuleBenchmark::runOnce(
78     const py::args& args,
79     const py::kwargs& kwargs) const {
80   CHECK(initialized_);
81   auto& function = model_.get_method("forward").function();
82   ScriptModuleInput stack = jit::createStackForSchema(
83       function.getSchema(), args, kwargs, model_._ivalue());
84   return function(std::move(stack));
85 }
86 
87 template <>
runOnce(ModuleInput && input) const88 void ModuleBenchmark::runOnce(ModuleInput&& input) const {
89   CHECK(initialized_);
90   pybind11::gil_scoped_acquire gil_guard;
91   model_(*input.args, **input.kwargs);
92 }
93 
94 template <>
runOnce(const py::args & args,const py::kwargs & kwargs) const95 ModuleOutput ModuleBenchmark::runOnce(
96     const py::args& args,
97     const py::kwargs& kwargs) const {
98   CHECK(initialized_);
99   pybind11::gil_scoped_acquire gil_guard;
100   return model_(*args, **kwargs);
101 }
102 
103 template <>
addInput(py::args && args,py::kwargs && kwargs)104 void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) {
105   jit::Stack stack = jit::createStackForSchema(
106       model_.get_method("forward").function().getSchema(),
107       args,
108       kwargs,
109       model_._ivalue());
110   inputs_.emplace_back(std::move(stack));
111 }
112 
113 template <>
addInput(ScriptModuleInput && input)114 void ScriptModuleBenchmark::addInput(ScriptModuleInput&& input) {
115   input.insert(input.begin(), model_._ivalue());
116   inputs_.emplace_back(std::move(input));
117 }
118 
119 template <>
addInput(py::args && args,py::kwargs && kwargs)120 void ModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) {
121   inputs_.emplace_back(std::move(args), std::move(kwargs));
122 }
123 
124 template <>
cloneInput(const ModuleInput & input)125 ModuleInput cloneInput<ModuleInput>(const ModuleInput& input) {
126   pybind11::gil_scoped_acquire gil_guard;
127   py::args args = input.args;
128   py::kwargs kwargs = input.kwargs;
129   return {std::move(args), std::move(kwargs)};
130 }
131 
132 template <>
cloneInput(const ScriptModuleInput & input)133 ScriptModuleInput cloneInput<ScriptModuleInput>(
134     const ScriptModuleInput& input) {
135   return input;
136 }
137 
138 } // namespace detail
139 
140 } // namespace torch::throughput_benchmark
141