1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: mobile"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport torch 4*da0073e9SAndroid Build Coastguard Workerfrom torch.nn import functional as F 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests 7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck 8*da0073e9SAndroid Build Coastguard Workerimport io 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Workerclass TestMetalRewritePass(TestCase): 11*da0073e9SAndroid Build Coastguard Worker @staticmethod 12*da0073e9SAndroid Build Coastguard Worker def validate_transformed_module( 13*da0073e9SAndroid Build Coastguard Worker # To please flake 14*da0073e9SAndroid Build Coastguard Worker self, 15*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 16*da0073e9SAndroid Build Coastguard Worker data_shape, 17*da0073e9SAndroid Build Coastguard Worker prepack_removal=False, 18*da0073e9SAndroid Build Coastguard Worker fuse_clamping_ops=False): 19*da0073e9SAndroid Build Coastguard Worker module_instance = self 20*da0073e9SAndroid Build Coastguard Worker scripted_model = torch.jit.script(module_instance) 21*da0073e9SAndroid Build Coastguard Worker scripted_model.eval() 22*da0073e9SAndroid Build Coastguard Worker input_data = torch.normal(1, 20, size=data_shape) 23*da0073e9SAndroid Build Coastguard Worker ref_result = scripted_model(input_data) 24*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_metal_insert_prepacked_ops(scripted_model._c) 25*da0073e9SAndroid Build Coastguard Worker if fuse_clamping_ops or prepack_removal: 26*da0073e9SAndroid Build Coastguard Worker scripted_model._c = torch._C._freeze_module(scripted_model._c) 27*da0073e9SAndroid Build Coastguard Worker if fuse_clamping_ops: 28*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_metal_fuse_clamp_w_prepacked_conv(scripted_model._c) 29*da0073e9SAndroid Build Coastguard Worker if prepack_removal: 30*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_metal_fold_prepacking_ops(scripted_model._c) 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 33*da0073e9SAndroid Build Coastguard Worker torch.jit.save(scripted_model, buffer) 34*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 35*da0073e9SAndroid Build Coastguard Worker deserialized_scripted_model = torch.jit.load(buffer) 36*da0073e9SAndroid Build Coastguard Worker for pattern, v in pattern_count_map.items(): 37*da0073e9SAndroid Build Coastguard Worker if (v == 0): 38*da0073e9SAndroid Build Coastguard Worker FileCheck().check(pattern).run(deserialized_scripted_model.graph) 39*da0073e9SAndroid Build Coastguard Worker elif (v == -1): 40*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not(pattern).run(deserialized_scripted_model.graph) 41*da0073e9SAndroid Build Coastguard Worker else: 42*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph) 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker def test_conv(self): 45*da0073e9SAndroid Build Coastguard Worker # Conv params 46*da0073e9SAndroid Build Coastguard Worker batch_size = 2 47*da0073e9SAndroid Build Coastguard Worker input_channels_per_group = 6 48*da0073e9SAndroid Build Coastguard Worker height = 16 49*da0073e9SAndroid Build Coastguard Worker width = 16 50*da0073e9SAndroid Build Coastguard Worker output_channels_per_group = 6 51*da0073e9SAndroid Build Coastguard Worker groups = 4 52*da0073e9SAndroid Build Coastguard Worker kernel_h = kernel_w = 3 53*da0073e9SAndroid Build Coastguard Worker stride_h = stride_w = 1 54*da0073e9SAndroid Build Coastguard Worker pad_h = pad_w = 1 55*da0073e9SAndroid Build Coastguard Worker dilation = 1 56*da0073e9SAndroid Build Coastguard Worker input_channels = input_channels_per_group * groups 57*da0073e9SAndroid Build Coastguard Worker output_channels = output_channels_per_group * groups 58*da0073e9SAndroid Build Coastguard Worker kernels = (kernel_h, kernel_w) 59*da0073e9SAndroid Build Coastguard Worker strides = (stride_h, stride_w) 60*da0073e9SAndroid Build Coastguard Worker paddings = (pad_h, pad_w) 61*da0073e9SAndroid Build Coastguard Worker dilations = (dilation, dilation) 62*da0073e9SAndroid Build Coastguard Worker conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w) 63*da0073e9SAndroid Build Coastguard Worker conv_bias_shape = (output_channels) 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker class Conv2D(torch.nn.Module): 66*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 67*da0073e9SAndroid Build Coastguard Worker super().__init__() 68*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) 69*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) 70*da0073e9SAndroid Build Coastguard Worker self.strides = strides 71*da0073e9SAndroid Build Coastguard Worker self.paddings = paddings 72*da0073e9SAndroid Build Coastguard Worker self.dilations = dilations 73*da0073e9SAndroid Build Coastguard Worker self.groups = groups 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 76*da0073e9SAndroid Build Coastguard Worker return F.conv2d(x, self.weight, self.bias, 77*da0073e9SAndroid Build Coastguard Worker self.strides, self.paddings, self.dilations, self.groups) 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker data_shape = (batch_size, input_channels, height, width) 80*da0073e9SAndroid Build Coastguard Worker pattern_count_map = {"Tensor = aten::conv2d": -1, 81*da0073e9SAndroid Build Coastguard Worker "metal_prepack::conv2d_prepack": 1, 82*da0073e9SAndroid Build Coastguard Worker "metal_prepack::conv2d_run": 1} 83*da0073e9SAndroid Build Coastguard Worker TestMetalRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker class Conv2DRelu(torch.nn.Module): 86*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 87*da0073e9SAndroid Build Coastguard Worker super().__init__() 88*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) 89*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) 90*da0073e9SAndroid Build Coastguard Worker self.strides = strides 91*da0073e9SAndroid Build Coastguard Worker self.paddings = paddings 92*da0073e9SAndroid Build Coastguard Worker self.dilations = dilations 93*da0073e9SAndroid Build Coastguard Worker self.groups = groups 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 96*da0073e9SAndroid Build Coastguard Worker o = F.conv2d(x, self.weight, self.bias, 97*da0073e9SAndroid Build Coastguard Worker self.strides, self.paddings, self.dilations, self.groups) 98*da0073e9SAndroid Build Coastguard Worker o = F.relu(o) 99*da0073e9SAndroid Build Coastguard Worker return o 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker data_shape = (batch_size, input_channels, height, width) 102*da0073e9SAndroid Build Coastguard Worker pattern_count_map = {"Tensor = aten::conv2d": -1, 103*da0073e9SAndroid Build Coastguard Worker "metal_prepack::conv2d_prepack": 1, 104*da0073e9SAndroid Build Coastguard Worker "metal_prepack::conv2d_run": 1} 105*da0073e9SAndroid Build Coastguard Worker TestMetalRewritePass.validate_transformed_module( 106*da0073e9SAndroid Build Coastguard Worker Conv2DRelu(), pattern_count_map, data_shape) 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker pattern_count_map["aten::relu"] = 1 109*da0073e9SAndroid Build Coastguard Worker pattern_count_map["metal_prepack::conv2d_prepack"] = -1 110*da0073e9SAndroid Build Coastguard Worker TestMetalRewritePass.validate_transformed_module( 111*da0073e9SAndroid Build Coastguard Worker Conv2DRelu(), 112*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 113*da0073e9SAndroid Build Coastguard Worker data_shape, 114*da0073e9SAndroid Build Coastguard Worker prepack_removal=True) 115*da0073e9SAndroid Build Coastguard Worker pattern_count_map["aten::relu"] = -1 116*da0073e9SAndroid Build Coastguard Worker TestMetalRewritePass.validate_transformed_module( 117*da0073e9SAndroid Build Coastguard Worker Conv2DRelu(), 118*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 119*da0073e9SAndroid Build Coastguard Worker data_shape, 120*da0073e9SAndroid Build Coastguard Worker prepack_removal=True, 121*da0073e9SAndroid Build Coastguard Worker fuse_clamping_ops=True) 122*da0073e9SAndroid Build Coastguard Worker 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker class Conv2DHardtanh(torch.nn.Module): 125*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 126*da0073e9SAndroid Build Coastguard Worker super().__init__() 127*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) 128*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) 129*da0073e9SAndroid Build Coastguard Worker self.strides = strides 130*da0073e9SAndroid Build Coastguard Worker self.paddings = paddings 131*da0073e9SAndroid Build Coastguard Worker self.dilations = dilations 132*da0073e9SAndroid Build Coastguard Worker self.groups = groups 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 135*da0073e9SAndroid Build Coastguard Worker o = F.conv2d(x, self.weight, self.bias, 136*da0073e9SAndroid Build Coastguard Worker self.strides, self.paddings, self.dilations, self.groups) 137*da0073e9SAndroid Build Coastguard Worker o = F.hardtanh(o) 138*da0073e9SAndroid Build Coastguard Worker return o 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker data_shape = (batch_size, input_channels, height, width) 141*da0073e9SAndroid Build Coastguard Worker pattern_count_map = {"Tensor = aten::conv2d": -1, 142*da0073e9SAndroid Build Coastguard Worker "metal_prepack::conv2d_prepack": 1, 143*da0073e9SAndroid Build Coastguard Worker "metal_prepack::conv2d_run": 1} 144*da0073e9SAndroid Build Coastguard Worker TestMetalRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape) 145*da0073e9SAndroid Build Coastguard Worker pattern_count_map["aten::hardtanh"] = 1 146*da0073e9SAndroid Build Coastguard Worker pattern_count_map["metal_prepack::conv2d_prepack"] = -1 147*da0073e9SAndroid Build Coastguard Worker TestMetalRewritePass.validate_transformed_module( 148*da0073e9SAndroid Build Coastguard Worker Conv2DHardtanh(), 149*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 150*da0073e9SAndroid Build Coastguard Worker data_shape, 151*da0073e9SAndroid Build Coastguard Worker prepack_removal=True) 152*da0073e9SAndroid Build Coastguard Worker pattern_count_map["aten::hardtanh"] = -1 153*da0073e9SAndroid Build Coastguard Worker TestMetalRewritePass.validate_transformed_module( 154*da0073e9SAndroid Build Coastguard Worker Conv2DRelu(), 155*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 156*da0073e9SAndroid Build Coastguard Worker data_shape, 157*da0073e9SAndroid Build Coastguard Worker prepack_removal=True, 158*da0073e9SAndroid Build Coastguard Worker fuse_clamping_ops=True) 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 161*da0073e9SAndroid Build Coastguard Worker run_tests() 162