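"""Microbenchmarks for quantized EmbeddingBag weight prepack (float -> packed)
and unpack (packed -> float) operators."""
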
import operator_benchmark as op_bench

import torch

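# Embedding table shapes to sweep; the tags let the benchmark runner's tag
# filter select the quick ("short") or exhaustive ("long") set. The three-dim
# configs add a leading batch dimension.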
embeddingbag_conversion_short_configs = op_bench.cross_product_configs(
    num_embeddings=(80,), embedding_dim=(128, 256, 512), tags=("short",)
)

embeddingbag_conversion_long_configs = op_bench.cross_product_configs(
    num_embeddings=(100, 120, 1000),
    embedding_dim=(16, 64, 128, 256, 512, 1024, 2048),
    tags=("long",),
)

embeddingbag_conversion_three_dim_configs = op_bench.cross_product_configs(
    num_embeddings=(80,),
    embedding_dim=(128, 256, 512),
    batch_size=(10,),
    tags=("short",),
)

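# Prepack ops quantize a float weight matrix into a fused row-wise format
# (8/4/2 bits per value plus per-row quantization parameters); unpack ops
# recover float weights from the packed representation.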
conversion_ops = op_bench.op_list(
    attrs=(
        ("qembeddingbag_byte_prepack", torch.ops.quantized.embedding_bag_byte_prepack),
        ("qembeddingbag_4bit_prepack", torch.ops.quantized.embedding_bag_4bit_prepack),
        ("qembeddingbag_2bit_prepack", torch.ops.quantized.embedding_bag_2bit_prepack),
    ),
    attr_names=("op_name", "op_func"),
)

unpack_ops = op_bench.op_list(
    attrs=(
        ("qembeddingbag_byte_unpack", torch.ops.quantized.embedding_bag_byte_unpack),
        ("qembeddingbag_4bit_unpack", torch.ops.quantized.embedding_bag_4bit_unpack),
        ("qembeddingbag_2bit_unpack", torch.ops.quantized.embedding_bag_2bit_unpack),
    ),
    attr_names=("op_name", "op_func"),
)


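# Prepack benchmark: feed a random 2-D float weight matrix (torch.rand + 1
# gives values in [1, 2)) to the prepack op under test.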
class EmbeddingBagFloatToFusedBase(op_bench.TorchBenchmarkBase):
    def init(self, num_embeddings, embedding_dim, op_func):
        self.inputs = {
            "weight": torch.rand(num_embeddings, embedding_dim, dtype=torch.float) + 1
        }
        self.op_func = op_func

    def forward(self, weight):
        return self.op_func(weight)


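# Same as EmbeddingBagFloatToFusedBase, with a leading batch dimension.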
class EmbeddingBagThreeDimFloatToFusedBase(op_bench.TorchBenchmarkBase):
    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        self.inputs = {
            "weight": torch.rand(
                batch_size, num_embeddings, embedding_dim, dtype=torch.float
            )
            + 1
        }
        self.op_func = op_func

    def forward(self, weight):
        return self.op_func(weight)


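# Unpack benchmark: the uint8 input has embedding_dim + 8 columns per row,
# leaving room for the per-row quantization parameters that the byte prepack
# format appends (the 4-bit/2-bit unpack ops reuse the same input shape here).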
class EmbeddingBagFusedToFloatBase(op_bench.TorchBenchmarkBase):
    def init(self, num_embeddings, embedding_dim, op_func):
        weight = torch.randn(num_embeddings, embedding_dim + 8, dtype=torch.float)
        self.inputs = {"packed_weight": weight.to(torch.uint8)}
        self.op_func = op_func

    def forward(self, packed_weight):
        return self.op_func(packed_weight)


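# Same as EmbeddingBagFusedToFloatBase, with a leading batch dimension.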
class EmbeddingBagThreeDimFusedToFloatBase(op_bench.TorchBenchmarkBase):
    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        weight = torch.randn(
            batch_size, num_embeddings, embedding_dim + 8, dtype=torch.float
        )
        self.inputs = {"packed_weight": weight.to(torch.uint8)}
        self.op_func = op_func

    def forward(self, packed_weight):
        return self.op_func(packed_weight)


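# Register the benchmarks: 2-D prepack/unpack over the short and long configs,
# and 3-D prepack/unpack over the three-dim configs.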
op_bench.generate_pt_tests_from_op_list(
    conversion_ops,
    embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
    EmbeddingBagFloatToFusedBase,
)
op_bench.generate_pt_tests_from_op_list(
    unpack_ops,
    embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
    EmbeddingBagFusedToFloatBase,
)
op_bench.generate_pt_tests_from_op_list(
    conversion_ops,
    embeddingbag_conversion_three_dim_configs,
    EmbeddingBagThreeDimFloatToFusedBase,
)
op_bench.generate_pt_tests_from_op_list(
    unpack_ops,
    embeddingbag_conversion_three_dim_configs,
    EmbeddingBagThreeDimFusedToFloatBase,
)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()