# Owner(s): ["oncall: jit"]

from itertools import product
from typing import Tuple
from unittest.case import expectedFailure

import torch
from torch import complex32, float32, float64, int32, int64
from torch.jit._passes import _property_propagation
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
)
from torch.testing._internal.common_methods_invocations import (
    op_db,
    sample_inputs_adaptive_avg_pool2d,
    sample_inputs_conv2d,
    SampleInput,
)
from torch.testing._internal.common_utils import first_sample, set_default_dtype
from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
from torch.testing._internal.jit_utils import JitTestCase


"""
Dtype Analysis relies on symbolic shape analysis, which is still in beta
"""

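# A minimal sketch (illustrative only; this helper is never invoked by the
# suite) of the propagation pipeline these tests exercise: seed a scripted
# graph with properties from example inputs, propagate shapes, then dtypes.
# All of the pass entry points below are the real ones used in
# TestDtypeBase.prop_dtype_on_graph.
def _dtype_propagation_pipeline_sketch():
    def fn(x):
        return torch.log(x)

    graph = torch.jit.script(fn).graph
    example_inputs = [torch.rand(2, 2, dtype=torch.float32)]
    _property_propagation.apply_input_props_using_example(graph, example_inputs)
    torch._C._jit_pass_propagate_shapes_on_graph(graph)
    torch._C._jit_pass_propagate_dtype(graph)
    # At this point the graph output's TensorType should carry torch.float32.

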
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


custom_rules_works_list = {
    "nn.functional.adaptive_avg_pool1d",
    "nn.functional.adaptive_avg_pool2d",
    "nn.functional.adaptive_avg_pool3d",
    "nn.functional.adaptive_max_pool1d",
    "nn.functional.adaptive_max_pool2d",
    "avg_pool1d",
    "avg_pool3d",
    "conv_transpose2d",
    "conv1d",
    "conv2d",
    "hardswish",
    "avg_pool2d",
    "max_pool1d",
    "max_pool2d",
    "max_pool3d",
    "nn.functional.prelu",
    "batch_norm",
}

custom_rules_expected_failure_list = {
    # create_traced_fn generates prim::NumToTensor nodes in the graph, which
    # dtype analysis does not support yet (see the sketch below).
    "nn.functional.adaptive_max_pool3d",
}
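

# A minimal sketch of how prim::NumToTensor ends up in traced graphs
# (illustrative only; this helper is never invoked by the suite): feeding a
# Python int obtained from .size() back into tensor math records a
# prim::NumToTensor node in the trace.
def _num_to_tensor_sketch():
    traced = torch.jit.trace(lambda x: x + x.size(0), torch.rand(3))
    # "prim::NumToTensor" is expected to appear in str(traced.graph).
    return traced
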

# These ops do not appear to have OpInfo entries
custom_rules_not_tested_list = [
    "conv3d",
    "conv_tbc",
    "conv_transpose1d",
    "conv_transpose3d",
    "convolution",
    "_convolution",
    "max_unpool2d",
    "max_unpool3d",
    "reflection_pad1d",
    "reflection_pad2d",
    "reflection_pad3d",
    "replication_pad1d",
    "replication_pad2d",
    "replication_pad3d",
    "upsample_bilinear2d",
    "upsample_linear1d",
    "upsample_nearest1d",
    "upsample_nearest2d",
    "upsample_nearest3d",
    "upsample_trilinear3d",
    "flatten",
]


class TestDtypeBase(JitTestCase):
    SCALAR = "SCALAR"  # Sentinel marking a Python scalar argument (vs. a 0-dim tensor)

    def setUp(self):
        self.prev_symbolic_shapes_test_enabled = (
            torch._C._jit_symbolic_shapes_test_mode_enabled()
        )
        torch._C._jit_set_symbolic_shapes_test_mode(True)

    def tearDown(self):
        torch._C._jit_set_symbolic_shapes_test_mode(
            self.prev_symbolic_shapes_test_enabled
        )

    @staticmethod
    def node_output_dtypes(graph):
        dtypes = []
        for out in graph.outputs():
            if isinstance(out.type(), torch._C.TensorType):
                dtypes.append(out.type().dtype())
            else:
                dtypes.append(None)
        return dtypes

    @staticmethod
    def node_output_dtype_single(graph):
        dtypes = TestDtypeBase.node_output_dtypes(graph)
        assert len(dtypes) == 1
        return dtypes[0]

    def prop_dtype_on_graph(self, graph, example_inputs):
        # We need to clear shape information because torch.jit.script
        # will return a cached graph if the function is scripted twice.
        torch._C._jit_pass_erase_shape_information(graph)
        _property_propagation.apply_input_props_using_example(graph, example_inputs)
        torch._C._jit_pass_propagate_shapes_on_graph(graph)
        torch._C._jit_pass_propagate_dtype(graph)
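
        # For illustration: torch.jit.script caches compilation per function,
        # so scripting the same fn twice returns the same graph object. Without
        # the erase pass above, shape/dtype annotations from a previous
        # propagation run would leak into the next one. A sketch:
        #   g1 = torch.jit.script(fn).graph
        #   g2 = torch.jit.script(fn).graph  # same cached graph object as g1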

    def assert_dtype_equal(self, fn, in_shapes, in_dtypes):
        inputs = [self.get_rand_tensor(s, d) for s, d in zip(in_shapes, in_dtypes)]
        try:
            self.assert_dtype_equal_custom_args(fn, inputs)
        except Exception as e:
            fail_text = f"Failed for shapes {in_shapes} and dtypes {in_dtypes}"
            raise AssertionError(fail_text) from e

    def assert_dtype_equal_custom_args(self, fn, args):
        try:
            # Eager execution
            expected_res = fn(*args)
        except RuntimeError:
            # If eager execution fails, there is nothing to verify against.
            return

        expected_dtype = expected_res.dtype

        # Run the Dtype Analysis
        graph = torch.jit.script(fn).graph  # Note this is a cached graph
        self.prop_dtype_on_graph(graph, args)
        actual_dtype = self.node_output_dtype_single(graph)

        self.assertEqual(actual_dtype, expected_dtype, "Failed Verification")

    def get_rand_tensor(self, shape, dtype):
        if shape is self.SCALAR:
            if dtype is float32:
                return 1.1
            elif dtype is int64:
                return 2
            else:
                raise RuntimeError(
                    "Testing of scalars only supported for fp32 and int64"
                )

        if dtype in (int32, int64):
            rand_tensor = torch.randint(0, 10, shape, dtype=dtype)
        else:
            rand_tensor = torch.rand(shape, dtype=dtype)

        # Sanity check!
        self.assertEqual(rand_tensor.dtype, dtype)
        return rand_tensor


class TestDtypeAnalysis(TestDtypeBase):
    def test_unary(self):
        # Testing the unary implementation that uses meta tensors

        def relu_inplace(x):
            return x.relu_()

        def log(x):
            return torch.log(x)

        functions = [relu_inplace, log]

        input_shapes = [
            ((2, 2),),  # Simple Case
            ((0, 2),),  # Size 0 Tensor
            ((),),  # zerodim
        ]

        input_dtypes = [
            (float32,),  # Simple Case
            (int64,),  # Some unary ops implicitly convert to float (see note below)
            (complex32,),  # Show we can handle complex vals as well
        ]
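
        # Note (an illustrative sketch of the promotion being exercised):
        # meta tensors expose the result dtype without allocating real data,
        # e.g. torch.log(torch.empty((), dtype=torch.int64, device="meta")).dtype
        # is expected to be torch.float32 under the default dtype.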

        for fn, in_shapes, in_dtypes in product(functions, input_shapes, input_dtypes):
            self.assert_dtype_equal(fn, in_shapes, in_dtypes)

    def test_binary_tensors(self):
        # Testing using meta tensors
        def add(x, y):
            return x + y

        def div(x, y):
            return x / y

        functions = [add, div]

        input_shapes = [
            ((1, 1, 2), (1, 2)),  # Different Dim, non-zerodim
            ((), (1, 2)),  # One zerodim
            ((1, 2), ()),  # Other zerodim
            ((2, 0, 3), (1, 3)),  # Test a tensor with a dim of 0
            ((), ()),  # both zerodim
        ]

        input_dtypes = [
            (float32, float32),  # Simple Case
            (int32, int64),  # Size promotion (a complicated case for 0-dim tensors; see note below)
            (float32, int32),  # Type promotion
            (int64, float32),  # Type promotion with size change
            (float64, complex32),  # Show we can handle complex vals as well
        ]
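
        # Note (an illustrative sketch of the 0-dim case): a zero-dim tensor
        # promotes like a scalar against a non-zero-dim tensor of the same
        # category, e.g.
        #   (torch.tensor(1, dtype=torch.int64) + torch.ones(2, dtype=torch.int32)).dtype
        # is expected to be torch.int32, not torch.int64.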

        for fn, in_shapes, in_dtypes in product(functions, input_shapes, input_dtypes):
            self.assert_dtype_equal(fn, in_shapes, in_dtypes)

    def test_binary_scalar(self):
        # Test the mixing of scalar and non-scalar args

        input_shapes = [
            ((2, 2), self.SCALAR),  # Non-zerodim vs scalar
            ((), self.SCALAR),  # Zerodim vs scalar
            # Scalar vs scalar is automatically inferred.
        ]

        input_dtypes = [
            (float32, float32),  # Simple Case
            (int32, int64),  # Size promotion (a complicated case for 0-dim tensors)
            (int32, float32),  # Type promotion
        ]

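        # Note on set_default_dtype (an illustrative sketch): a Python float
        # scalar promotes to the current default dtype when mixed with an
        # integral tensor, e.g. with float32 as the default,
        #   (torch.ones(2, dtype=torch.int32) + 1.1).dtype
        # is expected to be torch.float32.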
        with set_default_dtype(float32):
            for in_shapes, in_dtypes in product(input_shapes, input_dtypes):
                scalar_type = in_dtypes[1]

                if scalar_type == float32:

                    def add(x, y: float):
                        return x + y

                else:

                    def add(x, y: int):
                        return x + y

                self.assert_dtype_equal(add, in_shapes, in_dtypes)

    def test_custom_rules(self):
        # Test some of the ops that are not covered by meta tensors

        # Note that unlike the Conv2d module, the function conv2d
        # does not take dtype/device arguments.

        def conv2d_fn(input, weight, bias):
            return torch.nn.functional.conv2d(input, weight, bias)

        def adaptive_avg_pool2d_fn(input, output_size: Tuple[int]):
            return torch._C._nn.adaptive_avg_pool2d(input, output_size)

        for fn, inputs_fn in (
            (conv2d_fn, sample_inputs_conv2d),
            (adaptive_avg_pool2d_fn, sample_inputs_adaptive_avg_pool2d),
        ):
            for dtype in (torch.int8, torch.float64):
                # The last sample uses the default variant for conv2d
                sample_input: SampleInput = list(inputs_fn(None, "cpu", dtype, False))[
                    -1
                ]
                input_args = [sample_input.input, *sample_input.args]
                self.assert_dtype_equal_custom_args(fn, input_args)

    def test_conv_no_mixed_args(self):
        def conv2d_fn(input, weight, bias):
            return torch.nn.functional.conv2d(input, weight, bias)

        # Now make sure that conv2d doesn't support mixed-dtype args
        conv_ins = sample_inputs_conv2d(None, "cpu", torch.float, False)
        conv_in = list(conv_ins)[-1]
        weight, bias = conv_in.args
        weight = weight.type(torch.long)

        with self.assertRaises(RuntimeError):
            conv2d_fn(conv_in.input, weight, bias)

        # Check that we also don't propagate a dtype in this case
        graph = torch.jit.script(conv2d_fn).graph  # Note this is a cached graph
        self.prop_dtype_on_graph(graph, [conv_in.input, weight, bias])
        actual_dtype = self.node_output_dtype_single(graph)
        self.assertEqual(actual_dtype, None)

    def test_combined(self):
        # Test a case with both custom rules and meta tensors

        def func(input, weight, bias, y):
            conv_out = torch.nn.functional.conv2d(input, weight, bias)
            conv_2 = conv_out + y
            flattened = torch.flatten(conv_2, start_dim=2)
            add_res = flattened + y
            return add_res

        conv_ins = sample_inputs_conv2d(None, "cpu", torch.int8, False)
        conv_in = list(conv_ins)[-1]
        y_val = torch.rand((1,), dtype=torch.float32)
        input_args = [conv_in.input, *conv_in.args, y_val]
        self.assert_dtype_equal_custom_args(func, input_args)


class TestDtypeCustomRules(TestDtypeBase):
    def assert_output_dtype_equal(self, expected_res, prop_graph):
        actual_dtype = self.node_output_dtypes(prop_graph)
        if len(actual_dtype) == 1:
            # For len == 1, there is no tuple packing for expected_res.
            self.assert_tensor_dtype_equal(expected_res, actual_dtype[0])
        else:
            self.assertEqual(len(expected_res), len(actual_dtype))
            for expected, actual in zip(expected_res, actual_dtype):
                self.assert_tensor_dtype_equal(expected, actual)

    def assert_tensor_dtype_equal(self, tensor_output, graph_dtype):
        if not isinstance(tensor_output, torch.Tensor):
            return
        self.assertEqual(tensor_output.dtype, graph_dtype)

    def custom_rules_test_base(self, device, dtype, op, allow_eager_fail=False):
        try:
            samples = op.sample_inputs(device, dtype, requires_grad=False)
            sample_input = first_sample(self, samples)
            input_args = [sample_input.input, *sample_input.args]
            expected_res = op(*input_args, **sample_input.kwargs)
        except Exception as e:
            if allow_eager_fail:
                return
            else:
                raise e

        func = op.get_op()
        traced_fn = create_traced_fn(self, func)

        # Have to run the traced function to actually generate the trace
        traced_fn(sample_input.input, *sample_input.args, **sample_input.kwargs)

        # Run the Dtype Analysis
        graph = traced_fn.graph  # Note this is a cached graph
        input_tensors = [t for t in input_args if isinstance(t, torch.Tensor)]
        input_tensors += [
            v for v in sample_input.kwargs.values() if isinstance(v, torch.Tensor)
        ]
        self.prop_dtype_on_graph(graph, input_tensors)
        self.assert_output_dtype_equal(expected_res, graph)

    @ops([op for op in op_db if op.aten_name in custom_rules_works_list])
    def test_custom_rules(self, device, dtype, op):
        self.custom_rules_test_base(device, dtype, op)

    @ops([op for op in op_db if op.aten_name in custom_rules_works_list])
    def test_custom_rules_ints(self, device, dtype, op):
        # OpInfos currently only run with float dtypes, so swap in an int
        # dtype here to also cover integer inputs.
        if dtype == torch.float32:
            dtype = torch.int32
        else:
            dtype = torch.int64

        # Because int dtypes are not always implemented, we need to allow
        # eager execution to fail.
        self.custom_rules_test_base(device, dtype, op, allow_eager_fail=True)

    @expectedFailure
    @ops([op for op in op_db if op.aten_name in custom_rules_expected_failure_list])
    def test_custom_rules_expected_failure(self, device, dtype, op):
        self.custom_rules_test_base(device, dtype, op)


TestDtypeCustomRulesCPU = None
# This creates TestDtypeCustomRulesCPU
instantiate_device_type_tests(TestDtypeCustomRules, globals(), only_for=("cpu",))
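
# instantiate_device_type_tests generates one concrete test per (op, dtype)
# combination declared by the @ops decorators above. The generated names
# follow a pattern along the lines of (an assumed example, for illustration)
# TestDtypeCustomRulesCPU.test_custom_rules_conv2d_cpu_float32, which is what
# you would pass to `python test/test_jit.py TESTNAME`.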