xref: /aosp_15_r20/external/pytorch/aten/src/ATen/benchmarks/stateful_conv1d.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <benchmark/benchmark.h>
2 #include <c10/util/irange.h>
3 #include <torch/csrc/jit/passes/xnnpack_rewrite.h>
4 #include <torch/csrc/autograd/generated/variable_factories.h>
5 #include <torch/csrc/jit/api/module.h>
6 
7 #include <vector>
8 
stateful_conv1d(benchmark::State & state)9 static void stateful_conv1d(benchmark::State& state) {
10   const size_t input_channels = static_cast<size_t>(state.range(0));
11   const size_t output_channels = static_cast<size_t>(state.range(1));
12   const size_t kernel = static_cast<size_t>(state.range(2));
13   const size_t batch_size = static_cast<size_t>(state.range(3));
14   const size_t width = static_cast<size_t>(state.range(4));
15   const bool optimized = static_cast<bool>(state.range(5));
16 
17   torch::jit::Module m("m");
18   m.register_parameter("weight_1", torch::rand({output_channels, input_channels, kernel}), false);
19   m.register_parameter("bias_1", torch::rand({output_channels}), false);
20   m.register_parameter("weight_2", torch::rand({output_channels, output_channels, kernel}), false);
21   m.register_parameter("bias_2", torch::rand({output_channels}), false);
22   m.register_parameter("weight_3", torch::rand({output_channels, output_channels, kernel}), false);
23   m.register_parameter("bias_3", torch::rand({output_channels}), false);
24   m.register_parameter("weight_4", torch::rand({output_channels, output_channels, kernel}), false);
25   m.register_parameter("bias_4", torch::rand({output_channels}), false);
26 
27   m.define(R"(
28     def forward(self, x):
29       x = torch.conv1d(x, self.weight_1, self.bias_1, 1, 0, 1, 1)
30       x = torch.conv1d(x, self.weight_2, self.bias_2, 1, 0, 1, 1)
31       x = torch.conv1d(x, self.weight_3, self.bias_3, 1, 0, 1, 1)
32       x = torch.conv1d(x, self.weight_4, self.bias_4, 1, 0, 1, 1)
33       return x
34   )");
35 
36   std::vector<std::vector<torch::jit::IValue>> inputs;
37   for (const auto i : c10::irange(10)) {
38     inputs.emplace_back(
39         {torch::jit::IValue(torch::rand({batch_size, input_channels, width}))});
40   }
41 
42   auto m_cloned = m.clone();
43   torch::jit::transformConv1dToConv2d(m_cloned);
44   auto m_optimized = torch::jit::optimizeForMobile(m_cloned);
45   torch::jit::IValue output;
46 
47   if (!optimized) {
48     for (auto _ : state) {
49       for (const auto& input : inputs) {
50         output = m.forward(input);
51       }
52     }
53   } else {
54     for (auto _ : state) {
55       for (const auto& input : inputs) {
56         output = m_optimized.forward(input);
57       }
58     }
59   }
60 }
61 
GenerateSizes(benchmark::internal::Benchmark * b)62 static void GenerateSizes(benchmark::internal::Benchmark* b) {
63   b->ArgNames({"Input Channels",
64                "Output Channels",
65                "Kernel",
66                "Batch Size",
67                "Width",
68                "Optimized"});
69 
70   for (size_t input_channels = 32; input_channels < 256; input_channels *= 2) {
71     for (size_t output_channels = 32; output_channels < 256; output_channels *= 2) {
72       for (const auto kernel : c10::irange(3, 8)) {
73         for (const auto batch_size : c10::irange(1, 5)) {
74           for (size_t width = 32; width < 256; width *= 2) {
75             b->Args({input_channels, output_channels, kernel, batch_size, width, true});
76             b->Args({input_channels, output_channels, kernel, batch_size, width, false});
77           }
78         }
79       }
80     }
81   }
82 }
83 
// Register the benchmark with the full argument grid and emit main().
BENCHMARK(stateful_conv1d)->Apply(GenerateSizes);
BENCHMARK_MAIN();
86