# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestSub(unittest.TestCase):
    """Tests that elementwise subtraction is lowered to the XNNPACK delegate."""

    class Sub(torch.nn.Module):
        """Returns ``x - y`` for two (possibly broadcasting) tensors."""

        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            z = x - y
            return z

    class Sub2(torch.nn.Module):
        """Returns ``x - x``: both operands of the sub are the same input."""

        def __init__(self):
            super().__init__()

        def forward(self, x):
            z = x - x
            return z

    def _test_sub(self, inputs):
        """Export, lower and run ``Sub`` on ``inputs`` (no quantization).

        Asserts the exported graph holds exactly one ``aten.sub.Tensor``, that
        lowering produces one delegate call with no edge sub op left behind,
        and that the serialized program's outputs match eager outputs.
        """
        (
            Tester(self.Sub(), inputs)
            .export()
            .check_count({"torch.ops.aten.sub.Tensor": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_sub_Tensor"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_sub(self):
        # (1, 3) - (4, 3) also exercises broadcasting along dim 0.
        inputs = (
            torch.randn((1, 3)).to(torch.float16),
            torch.randn((4, 3)).to(torch.float16),
        )
        self._test_sub(inputs)

    def test_fp32_sub(self):
        inputs = (torch.randn((1, 3)), torch.randn((4, 3)))
        self._test_sub(inputs)

    def _run_qs8_sub_pipeline(
        self,
        module,
        inputs,
        aten_ops=("torch.ops.aten.sub.Tensor",),
        edge_ops=("executorch_exir_dialects_edge__ops_aten_sub_Tensor",),
    ):
        """Quantize, export, lower and run ``module``; check everything is delegated.

        Args:
            module: the ``torch.nn.Module`` under test.
            inputs: tuple of example inputs handed to ``Tester``.
            aten_ops: aten op names each expected exactly once in the exported
                (quantized) graph.
            edge_ops: edge-dialect op names that must no longer appear after
                ``to_edge_transform_and_lower``; quantize/dequantize ops
                (``torch.ops.quantized_decomposed``) are always checked too.
        """
        # Not named ``test_*`` so unittest never collects it on its own.
        (
            Tester(module, inputs)
            .quantize()
            .export()
            .check_count({op: 1 for op in aten_ops})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not([*edge_ops, "torch.ops.quantized_decomposed"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    # NOTE: the qs8 tests below are prefixed with ``_``, so unittest's default
    # ``test`` method prefix never collects them; the skip decorator only records
    # why. Rename them to ``test_*`` once T171957656 is resolved.

    @unittest.skip("T171957656 - Quantized sub not implemented.")
    def _test_qs8_sub(self):
        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
        self._run_qs8_sub_pipeline(self.Sub(), inputs)

    @unittest.skip("T171957656 - Quantized sub not implemented.")
    def _test_qs8_sub2(self):
        inputs = (torch.randn(1, 1, 4, 4),)
        self._run_qs8_sub_pipeline(self.Sub2(), inputs)

    @unittest.skip("T171957656 - Quantized sub not implemented.")
    def _test_qs8_sub3(self):
        # Second operand broadcasts along the last dimension.
        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1))
        self._run_qs8_sub_pipeline(self.Sub(), inputs)

    @unittest.skip("T171957656 - Quantized sub not implemented.")
    def _test_qs8_sub_relu(self):
        class SubRelu(torch.nn.Module):
            """Returns ``relu(x - y)`` so sub+relu fusion can be checked."""

            def forward(self, x, y):
                z = x - y
                return torch.nn.functional.relu(z)

        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
        # Use the local sub+relu module; the plain ``self.Sub`` has no relu,
        # so the relu count/absence checks below could never be meaningful.
        self._run_qs8_sub_pipeline(
            SubRelu(),
            inputs,
            aten_ops=(
                "torch.ops.aten.sub.Tensor",
                "torch.ops.aten.relu.default",
            ),
            edge_ops=(
                "executorch_exir_dialects_edge__ops_aten_sub_Tensor",
                "executorch_exir_dialects_edge__ops_aten_relu_default",
            ),
        )