#include <torch/csrc/utils/throughput_benchmark.h>

#include <pybind11/pybind11.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>

namespace torch::throughput_benchmark {

std::ostream& operator<<(
    std::ostream& os,
    const BenchmarkExecutionStats& value) {
  return os << "Average latency / iter (ms): " << value.latency_avg_ms
            << "\n Total number of iters: " << value.num_iters;
}

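// Exactly one of script_module_ / module_ is initialized; route the input to
// whichever backend this benchmark was constructed with.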
void ThroughputBenchmark::addInput(py::args args, py::kwargs kwargs) {
  CHECK(script_module_.initialized() ^ module_.initialized());
  if (script_module_.initialized()) {
    script_module_.addInput(std::move(args), std::move(kwargs));
  } else {
    CHECK(module_.initialized());
    module_.addInput(std::move(args), std::move(kwargs));
  }
}

py::object ThroughputBenchmark::runOnce(
    const py::args& args,
    const py::kwargs& kwargs) {
  CHECK(script_module_.initialized() ^ module_.initialized());
  if (script_module_.initialized()) {
    c10::IValue result;
    {
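      // TorchScript execution below doesn't touch Python state, so the GIL can
      // be released for the duration of the call.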
      pybind11::gil_scoped_release no_gil_guard;
      result = script_module_.runOnce(args, kwargs);
    }
    return jit::toPyObject(std::move(result));
  } else {
    CHECK(module_.initialized());
    return module_.runOnce(args, kwargs);
  }
}

ThroughputBenchmark::ThroughputBenchmark(const jit::Module& script_module)
    : script_module_(script_module) {}

ThroughputBenchmark::ThroughputBenchmark(py::object module)
    : module_(std::move(module)) {}

BenchmarkExecutionStats ThroughputBenchmark::benchmark(
    const BenchmarkConfig& config) const {
  CHECK(script_module_.initialized() ^ module_.initialized());
  // The main benchmark thread doesn't hold the GIL after scheduling worker
  // threads, but for now we don't release it here because we implicitly
  // manipulate py::object reference counts when benchmarking an nn.Module.
  if (script_module_.initialized()) {
    return script_module_.benchmark(config);
  } else {
    CHECK(module_.initialized());
    TORCH_WARN(
        "Starting benchmark on an nn.Module. This can be slow due to the "
        "Python GIL. For proper inference simulation you might want to "
        "switch to a ScriptModule instead.");
    return module_.benchmark(config);
  }
}

namespace detail {

template <>
void ScriptModuleBenchmark::runOnce(ScriptModuleInput&& input) const {
  CHECK(initialized_);
  // TODO: provide guarantees that compiler won't optimize this out
  model_.get_method("forward").function()(std::move(input));
}

template <>
ScriptModuleOutput ScriptModuleBenchmark::runOnce(
    const py::args& args,
    const py::kwargs& kwargs) const {
  CHECK(initialized_);
  auto& function = model_.get_method("forward").function();
  ScriptModuleInput stack = jit::createStackForSchema(
      function.getSchema(), args, kwargs, model_._ivalue());
  return function(std::move(stack));
}

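// Unlike the ScriptModule path above, calling into a Python nn.Module requires
// holding the GIL.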
template <>
void ModuleBenchmark::runOnce(ModuleInput&& input) const {
  CHECK(initialized_);
  pybind11::gil_scoped_acquire gil_guard;
  model_(*input.args, **input.kwargs);
}

template <>
ModuleOutput ModuleBenchmark::runOnce(
    const py::args& args,
    const py::kwargs& kwargs) const {
  CHECK(initialized_);
  pybind11::gil_scoped_acquire gil_guard;
  return model_(*args, **kwargs);
}

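// Python args/kwargs are bound to the forward() schema and converted to an
// IValue stack once, when the input is registered.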
template <>
void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) {
  jit::Stack stack = jit::createStackForSchema(
      model_.get_method("forward").function().getSchema(),
      args,
      kwargs,
      model_._ivalue());
  inputs_.emplace_back(std::move(stack));
}

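// A raw IValue stack bypasses schema binding, so the module's "self" IValue is
// prepended manually here.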
template <>
void ScriptModuleBenchmark::addInput(ScriptModuleInput&& input) {
  input.insert(input.begin(), model_._ivalue());
  inputs_.emplace_back(std::move(input));
}

template <>
void ModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) {
  inputs_.emplace_back(std::move(args), std::move(kwargs));
}

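// Cloning inputs copies py::object handles, which adjusts CPython reference
// counts and therefore requires the GIL.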
template <>
ModuleInput cloneInput<ModuleInput>(const ModuleInput& input) {
  pybind11::gil_scoped_acquire gil_guard;
  py::args args = input.args;
  py::kwargs kwargs = input.kwargs;
  return {std::move(args), std::move(kwargs)};
}

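// ScriptModule inputs are plain IValue stacks, so a straightforward copy is
// enough.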
template <>
ScriptModuleInput cloneInput<ScriptModuleInput>(
    const ScriptModuleInput& input) {
  return input;
}

} // namespace detail

} // namespace torch::throughput_benchmark