#include <benchmark/benchmark.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>

#include <cstdint>
#include <vector>
8
stateful_conv1d(benchmark::State & state)9 static void stateful_conv1d(benchmark::State& state) {
10 const size_t input_channels = static_cast<size_t>(state.range(0));
11 const size_t output_channels = static_cast<size_t>(state.range(1));
12 const size_t kernel = static_cast<size_t>(state.range(2));
13 const size_t batch_size = static_cast<size_t>(state.range(3));
14 const size_t width = static_cast<size_t>(state.range(4));
15 const bool optimized = static_cast<bool>(state.range(5));
16
17 torch::jit::Module m("m");
18 m.register_parameter("weight_1", torch::rand({output_channels, input_channels, kernel}), false);
19 m.register_parameter("bias_1", torch::rand({output_channels}), false);
20 m.register_parameter("weight_2", torch::rand({output_channels, output_channels, kernel}), false);
21 m.register_parameter("bias_2", torch::rand({output_channels}), false);
22 m.register_parameter("weight_3", torch::rand({output_channels, output_channels, kernel}), false);
23 m.register_parameter("bias_3", torch::rand({output_channels}), false);
24 m.register_parameter("weight_4", torch::rand({output_channels, output_channels, kernel}), false);
25 m.register_parameter("bias_4", torch::rand({output_channels}), false);
26
27 m.define(R"(
28 def forward(self, x):
29 x = torch.conv1d(x, self.weight_1, self.bias_1, 1, 0, 1, 1)
30 x = torch.conv1d(x, self.weight_2, self.bias_2, 1, 0, 1, 1)
31 x = torch.conv1d(x, self.weight_3, self.bias_3, 1, 0, 1, 1)
32 x = torch.conv1d(x, self.weight_4, self.bias_4, 1, 0, 1, 1)
33 return x
34 )");
35
36 std::vector<std::vector<torch::jit::IValue>> inputs;
37 for (const auto i : c10::irange(10)) {
38 inputs.emplace_back(
39 {torch::jit::IValue(torch::rand({batch_size, input_channels, width}))});
40 }
41
42 auto m_cloned = m.clone();
43 torch::jit::transformConv1dToConv2d(m_cloned);
44 auto m_optimized = torch::jit::optimizeForMobile(m_cloned);
45 torch::jit::IValue output;
46
47 if (!optimized) {
48 for (auto _ : state) {
49 for (const auto& input : inputs) {
50 output = m.forward(input);
51 }
52 }
53 } else {
54 for (auto _ : state) {
55 for (const auto& input : inputs) {
56 output = m_optimized.forward(input);
57 }
58 }
59 }
60 }
61
GenerateSizes(benchmark::internal::Benchmark * b)62 static void GenerateSizes(benchmark::internal::Benchmark* b) {
63 b->ArgNames({"Input Channels",
64 "Output Channels",
65 "Kernel",
66 "Batch Size",
67 "Width",
68 "Optimized"});
69
70 for (size_t input_channels = 32; input_channels < 256; input_channels *= 2) {
71 for (size_t output_channels = 32; output_channels < 256; output_channels *= 2) {
72 for (const auto kernel : c10::irange(3, 8)) {
73 for (const auto batch_size : c10::irange(1, 5)) {
74 for (size_t width = 32; width < 256; width *= 2) {
75 b->Args({input_channels, output_channels, kernel, batch_size, width, true});
76 b->Args({input_channels, output_channels, kernel, batch_size, width, false});
77 }
78 }
79 }
80 }
81 }
82 }
83
84 BENCHMARK(stateful_conv1d)->Apply(GenerateSizes);
85 BENCHMARK_MAIN();
86