1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: mobile"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport unittest 4*da0073e9SAndroid Build Coastguard Workerimport torch 5*da0073e9SAndroid Build Coastguard Workerfrom torch.nn import functional as F 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests 8*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck 9*da0073e9SAndroid Build Coastguard Workerimport io 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker@unittest.skipUnless(torch.is_vulkan_available(), 12*da0073e9SAndroid Build Coastguard Worker "Vulkan backend must be available for these tests.") 13*da0073e9SAndroid Build Coastguard Workerclass TestVulkanRewritePass(TestCase): 14*da0073e9SAndroid Build Coastguard Worker @staticmethod 15*da0073e9SAndroid Build Coastguard Worker def validate_transformed_module( 16*da0073e9SAndroid Build Coastguard Worker # To please flake 17*da0073e9SAndroid Build Coastguard Worker self, 18*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 19*da0073e9SAndroid Build Coastguard Worker data_shape, 20*da0073e9SAndroid Build Coastguard Worker prepack_removal=False, 21*da0073e9SAndroid Build Coastguard Worker fuse_clamping_ops=False): 22*da0073e9SAndroid Build Coastguard Worker module_instance = self 23*da0073e9SAndroid Build Coastguard Worker scripted_model = torch.jit.script(module_instance) 24*da0073e9SAndroid Build Coastguard Worker scripted_model.eval() 25*da0073e9SAndroid Build Coastguard Worker input_data = torch.normal(1, 20, size=data_shape) 26*da0073e9SAndroid Build Coastguard Worker ref_result = scripted_model(input_data) 27*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_vulkan_insert_prepacked_ops(scripted_model._c) 28*da0073e9SAndroid Build Coastguard Worker if fuse_clamping_ops or prepack_removal: 29*da0073e9SAndroid Build Coastguard Worker scripted_model._c = torch._C._freeze_module(scripted_model._c) 30*da0073e9SAndroid Build Coastguard Worker if fuse_clamping_ops: 31*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv(scripted_model._c) 32*da0073e9SAndroid Build Coastguard Worker if prepack_removal: 33*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_vulkan_fold_prepacking_ops(scripted_model._c) 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 36*da0073e9SAndroid Build Coastguard Worker torch.jit.save(scripted_model, buffer) 37*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 38*da0073e9SAndroid Build Coastguard Worker deserialized_scripted_model = torch.jit.load(buffer) 39*da0073e9SAndroid Build Coastguard Worker for pattern, v in pattern_count_map.items(): 40*da0073e9SAndroid Build Coastguard Worker if (v == 0): 41*da0073e9SAndroid Build Coastguard Worker FileCheck().check(pattern).run(deserialized_scripted_model.graph) 42*da0073e9SAndroid Build Coastguard Worker elif (v == -1): 43*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not(pattern).run(deserialized_scripted_model.graph) 44*da0073e9SAndroid Build Coastguard Worker else: 45*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph) 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker def test_conv(self): 48*da0073e9SAndroid Build Coastguard Worker # Conv params 49*da0073e9SAndroid Build Coastguard Worker batch_size = 2 50*da0073e9SAndroid Build Coastguard Worker input_channels_per_group = 6 51*da0073e9SAndroid Build Coastguard Worker height = 16 52*da0073e9SAndroid Build Coastguard Worker width = 16 53*da0073e9SAndroid Build Coastguard Worker output_channels_per_group = 6 54*da0073e9SAndroid Build Coastguard Worker groups = 4 55*da0073e9SAndroid Build Coastguard Worker kernel_h = kernel_w = 3 56*da0073e9SAndroid Build Coastguard Worker stride_h = stride_w = 1 57*da0073e9SAndroid Build Coastguard Worker pad_h = pad_w = 1 58*da0073e9SAndroid Build Coastguard Worker dilation = 1 59*da0073e9SAndroid Build Coastguard Worker input_channels = input_channels_per_group * groups 60*da0073e9SAndroid Build Coastguard Worker output_channels = output_channels_per_group * groups 61*da0073e9SAndroid Build Coastguard Worker kernels = (kernel_h, kernel_w) 62*da0073e9SAndroid Build Coastguard Worker strides = (stride_h, stride_w) 63*da0073e9SAndroid Build Coastguard Worker paddings = (pad_h, pad_w) 64*da0073e9SAndroid Build Coastguard Worker dilations = (dilation, dilation) 65*da0073e9SAndroid Build Coastguard Worker conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w) 66*da0073e9SAndroid Build Coastguard Worker conv_bias_shape = (output_channels) 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker class Conv2D(torch.nn.Module): 69*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 70*da0073e9SAndroid Build Coastguard Worker super().__init__() 71*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) 72*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) 73*da0073e9SAndroid Build Coastguard Worker self.strides = strides 74*da0073e9SAndroid Build Coastguard Worker self.paddings = paddings 75*da0073e9SAndroid Build Coastguard Worker self.dilations = dilations 76*da0073e9SAndroid Build Coastguard Worker self.groups = groups 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 79*da0073e9SAndroid Build Coastguard Worker return F.conv2d(x, self.weight, self.bias, 80*da0073e9SAndroid Build Coastguard Worker self.strides, self.paddings, self.dilations, self.groups) 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker data_shape = (batch_size, input_channels, height, width) 83*da0073e9SAndroid Build Coastguard Worker pattern_count_map = {"Tensor = aten::conv2d": -1, 84*da0073e9SAndroid Build Coastguard Worker "vulkan_prepack::conv2d_clamp_prepack": 1, 85*da0073e9SAndroid Build Coastguard Worker "vulkan_prepack::conv2d_clamp_run": 1} 86*da0073e9SAndroid Build Coastguard Worker TestVulkanRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape) 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Worker class Conv2DRelu(torch.nn.Module): 89*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 90*da0073e9SAndroid Build Coastguard Worker super().__init__() 91*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) 92*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) 93*da0073e9SAndroid Build Coastguard Worker self.strides = strides 94*da0073e9SAndroid Build Coastguard Worker self.paddings = paddings 95*da0073e9SAndroid Build Coastguard Worker self.dilations = dilations 96*da0073e9SAndroid Build Coastguard Worker self.groups = groups 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 99*da0073e9SAndroid Build Coastguard Worker o = F.conv2d(x, self.weight, self.bias, 100*da0073e9SAndroid Build Coastguard Worker self.strides, self.paddings, self.dilations, self.groups) 101*da0073e9SAndroid Build Coastguard Worker o = F.relu(o) 102*da0073e9SAndroid Build Coastguard Worker return o 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker data_shape = (batch_size, input_channels, height, width) 105*da0073e9SAndroid Build Coastguard Worker pattern_count_map = {"Tensor = aten::conv2d": -1, 106*da0073e9SAndroid Build Coastguard Worker "vulkan_prepack::conv2d_clamp_prepack": 1, 107*da0073e9SAndroid Build Coastguard Worker "vulkan_prepack::conv2d_clamp_run": 1} 108*da0073e9SAndroid Build Coastguard Worker TestVulkanRewritePass.validate_transformed_module( 109*da0073e9SAndroid Build Coastguard Worker Conv2DRelu(), pattern_count_map, data_shape) 110*da0073e9SAndroid Build Coastguard Worker 111*da0073e9SAndroid Build Coastguard Worker pattern_count_map["aten::relu"] = 1 112*da0073e9SAndroid Build Coastguard Worker pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1 113*da0073e9SAndroid Build Coastguard Worker TestVulkanRewritePass.validate_transformed_module( 114*da0073e9SAndroid Build Coastguard Worker Conv2DRelu(), 115*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 116*da0073e9SAndroid Build Coastguard Worker data_shape, 117*da0073e9SAndroid Build Coastguard Worker prepack_removal=True) 118*da0073e9SAndroid Build Coastguard Worker pattern_count_map["aten::relu"] = -1 119*da0073e9SAndroid Build Coastguard Worker TestVulkanRewritePass.validate_transformed_module( 120*da0073e9SAndroid Build Coastguard Worker Conv2DRelu(), 121*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 122*da0073e9SAndroid Build Coastguard Worker data_shape, 123*da0073e9SAndroid Build Coastguard Worker prepack_removal=True, 124*da0073e9SAndroid Build Coastguard Worker fuse_clamping_ops=True) 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker 127*da0073e9SAndroid Build Coastguard Worker class Conv2DHardtanh(torch.nn.Module): 128*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 129*da0073e9SAndroid Build Coastguard Worker super().__init__() 130*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) 131*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) 132*da0073e9SAndroid Build Coastguard Worker self.strides = strides 133*da0073e9SAndroid Build Coastguard Worker self.paddings = paddings 134*da0073e9SAndroid Build Coastguard Worker self.dilations = dilations 135*da0073e9SAndroid Build Coastguard Worker self.groups = groups 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 138*da0073e9SAndroid Build Coastguard Worker o = F.conv2d(x, self.weight, self.bias, 139*da0073e9SAndroid Build Coastguard Worker self.strides, self.paddings, self.dilations, self.groups) 140*da0073e9SAndroid Build Coastguard Worker o = F.hardtanh(o) 141*da0073e9SAndroid Build Coastguard Worker return o 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker data_shape = (batch_size, input_channels, height, width) 144*da0073e9SAndroid Build Coastguard Worker pattern_count_map = {"Tensor = aten::conv2d": -1, 145*da0073e9SAndroid Build Coastguard Worker "vulkan_prepack::conv2d_clamp_prepack": 1, 146*da0073e9SAndroid Build Coastguard Worker "vulkan_prepack::conv2d_clamp_run": 1} 147*da0073e9SAndroid Build Coastguard Worker TestVulkanRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape) 148*da0073e9SAndroid Build Coastguard Worker pattern_count_map["aten::hardtanh"] = 1 149*da0073e9SAndroid Build Coastguard Worker pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1 150*da0073e9SAndroid Build Coastguard Worker TestVulkanRewritePass.validate_transformed_module( 151*da0073e9SAndroid Build Coastguard Worker Conv2DHardtanh(), 152*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 153*da0073e9SAndroid Build Coastguard Worker data_shape, 154*da0073e9SAndroid Build Coastguard Worker prepack_removal=True) 155*da0073e9SAndroid Build Coastguard Worker pattern_count_map["aten::hardtanh"] = -1 156*da0073e9SAndroid Build Coastguard Worker TestVulkanRewritePass.validate_transformed_module( 157*da0073e9SAndroid Build Coastguard Worker Conv2DRelu(), 158*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 159*da0073e9SAndroid Build Coastguard Worker data_shape, 160*da0073e9SAndroid Build Coastguard Worker prepack_removal=True, 161*da0073e9SAndroid Build Coastguard Worker fuse_clamping_ops=True) 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 164*da0073e9SAndroid Build Coastguard Worker run_tests() 165