"""Microbenchmarks for quantized EmbeddingBag prepack/unpack operators."""

import operator_benchmark as op_bench

import torch


embeddingbag_conversion_short_configs = op_bench.cross_product_configs(
    num_embeddings=(80,), embedding_dim=(128, 256, 512), tags=("short",)
)

embeddingbag_conversion_long_configs = op_bench.cross_product_configs(
    num_embeddings=(100, 120, 1000),
    embedding_dim=(16, 64, 128, 256, 512, 1024, 2048),
    tags=("long",),
)

embeddingbag_conversion_three_dim_configs = op_bench.cross_product_configs(
    num_embeddings=(80,),
    embedding_dim=(128, 256, 512),
    batch_size=(10,),
    tags=("short",),
)

# Prepack ops: quantize a float weight matrix into the fused byte/4-bit/2-bit
# embedding_bag format.
conversion_ops = op_bench.op_list(
    attrs=(
        ("qembeddingbag_byte_prepack", torch.ops.quantized.embedding_bag_byte_prepack),
        ("qembeddingbag_4bit_prepack", torch.ops.quantized.embedding_bag_4bit_prepack),
        ("qembeddingbag_2bit_prepack", torch.ops.quantized.embedding_bag_2bit_prepack),
    ),
    attr_names=("op_name", "op_func"),
)

# Unpack ops: dequantize a fused packed weight back to float.
unpack_ops = op_bench.op_list(
    attrs=(
        ("qembeddingbag_byte_unpack", torch.ops.quantized.embedding_bag_byte_unpack),
        ("qembeddingbag_4bit_unpack", torch.ops.quantized.embedding_bag_4bit_unpack),
        ("qembeddingbag_2bit_unpack", torch.ops.quantized.embedding_bag_2bit_unpack),
    ),
    attr_names=("op_name", "op_func"),
)


class EmbeddingBagFloatToFusedBase(op_bench.TorchBenchmarkBase):
    def init(self, num_embeddings, embedding_dim, op_func):
        # Random float weights shifted into [1, 2) so every entry is nonzero.
        self.inputs = {
            "weight": torch.rand(num_embeddings, embedding_dim, dtype=torch.float) + 1
        }
        self.op_func = op_func

    def forward(self, weight):
        return self.op_func(weight)


class EmbeddingBagThreeDimFloatToFusedBase(op_bench.TorchBenchmarkBase):
    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        self.inputs = {
            "weight": torch.rand(
                batch_size, num_embeddings, embedding_dim, dtype=torch.float
            )
            + 1
        }
        self.op_func = op_func

    def forward(self, weight):
        return self.op_func(weight)


class EmbeddingBagFusedToFloatBase(op_bench.TorchBenchmarkBase):
    def init(self, num_embeddings, embedding_dim, op_func):
        # 8 extra bytes per row mimic the per-row scale/zero-point metadata
        # stored in the packed byte format.
        weight = torch.randn(num_embeddings, embedding_dim + 8, dtype=torch.float)
        self.inputs = {"packed_weight": weight.to(torch.uint8)}
        self.op_func = op_func

    def forward(self, packed_weight):
        return self.op_func(packed_weight)


class EmbeddingBagThreeDimFusedToFloatBase(op_bench.TorchBenchmarkBase):
    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        weight = torch.randn(
            batch_size, num_embeddings, embedding_dim + 8, dtype=torch.float
        )
        self.inputs = {"packed_weight": weight.to(torch.uint8)}
        self.op_func = op_func

    def forward(self, packed_weight):
        return self.op_func(packed_weight)


op_bench.generate_pt_tests_from_op_list(
    conversion_ops,
    embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
    EmbeddingBagFloatToFusedBase,
)
op_bench.generate_pt_tests_from_op_list(
    unpack_ops,
    embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
    EmbeddingBagFusedToFloatBase,
)
op_bench.generate_pt_tests_from_op_list(
    conversion_ops,
    embeddingbag_conversion_three_dim_configs,
    EmbeddingBagThreeDimFloatToFusedBase,
)
op_bench.generate_pt_tests_from_op_list(
    unpack_ops,
    embeddingbag_conversion_three_dim_configs,
    EmbeddingBagThreeDimFusedToFloatBase,
)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()