# Owner(s): ["module: unknown"]

from functools import partial
from textwrap import dedent

import torch
from torch.testing import FileCheck
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_jit import (
    check_against_reference,
    JitCommonTestCase,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
    clone_input_helper,
    first_sample,
    IS_SANDCASTLE,
    run_tests,
    TestCase,
    unMarkDynamoStrictTest,
)
from torch.testing._internal.jit_metaprogramming_utils import (
    check_alias_annotation,
    create_script_fn,
    create_traced_fn,
)
from torch.testing._internal.jit_utils import (
    disable_autodiff_subgraph_inlining,
    is_lambda,
)


# variant testing is only done with torch.float and torch.cfloat to avoid
#   excessive test times and maximize signal-to-noise ratio
_variant_ops = partial(
    ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat)
)

# Tests operators for consistency between JIT and eager execution, and also
#   checks the correctness of JIT-specific alias schemas and intended
#   autodifferentiation behavior.
# Inherits from JitCommonTestCase instead of TestCase directly to share
#   functionality with the original test_jit.py method operator tests
@unMarkDynamoStrictTest
class TestJit(JitCommonTestCase):
    exact_dtype = True

    # Tests that the forward and backward passes of operations produce the
    #   same values for the cross-product of op variants (function, method, inplace)
    #   and runtimes (eager, traced, scripted).
    # TODO WARNING: inplace x {traced, scripted} not currently tested
    @_variant_ops(op_db)
    def test_variant_consistency_jit(self, device, dtype, op):
        _requires_grad = dtype in op.supported_backward_dtypes(
            torch.device(device).type
        )

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=_requires_grad,
            include_conjugated_inputs=include_conjugated_inputs,
        )

        # Acquires variants to test
        func = op.get_op()
        method = op.get_method()
        variants = {
            # TODO: inplace tests currently fail, fix and add inplace variant
            "function": func,
            "method": method,
        }

        # scripting strips the torch.ops prefix from these operators
        # incorrectly; don't bother testing this case.  Count this
        # as "testing"
        if isinstance(func, torch._ops.OpOverload):
            self.skipTest("variant consistency doesn't work on torch.ops")

        # TODO: find better way to standardize on op registration itself..
        has_fake_function = op.name in ["resize_", "resize_as_"]

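        # For ops backed by fake functions, only the Tensor method variant is
        #   exercised here, and gradient checks are skipped.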
        if has_fake_function:
            variants = {"method": getattr(torch.Tensor, op.name)}
            samples = op.sample_inputs(device, dtype, requires_grad=False)

        tested = False
        for sample in samples:
            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    continue

                # scripting and check_alias_analysis do not work with lambdas
                # lambdas are typically used as a way to simulate methods without
                # functional variants, so rely on the other variant for testing
                # for now
                if is_lambda(variant):
                    continue

                tested = True
                try:
                    self.indiv_variant_test_jit(
                        device, dtype, op, sample, func_type, variant, has_fake_function
                    )
                except Exception as e:
                    variant_error_info = dedent(
                        f"""
                        Error testing {op.name} {func_type} variant
                        with dtype: {dtype}
                        with inputs {sample}:
                    """
                    )
                    raise Exception(variant_error_info) from e  # noqa: TRY002

        assert tested, "JIT Test does not execute any logic"

    def indiv_variant_test_jit(
        self, device, dtype, op, sample, func_type, variant, has_fake_function
    ):
        _requires_grad = dtype in op.supported_backward_dtypes(
            torch.device(device).type
        )
        support_script = op.supports_scripting
        # Create accessor for script function variant
        name = op.name + "_" if func_type == "inplace" else op.name

        # run with disable_autodiff_subgraph_inlining(True) to test
        #   autodiff support. Context manager forces the graph to contain
        #   DifferentiableGraph nodes if they are present
        with disable_autodiff_subgraph_inlining():
            # Check scripted forward, grad, and grad grad
            if support_script:
                script_fn = create_script_fn(self, name, func_type)

            def out_fn(output):
                # Processes the output for autograd
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

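            # Inplace ops (name ending in "_") mutate their input, so hand each
            #   call a fresh clone; other variants can reuse the sample input.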
            def get_sample():
                return (
                    clone_input_helper(sample.input)
                    if op.name[-1] == "_"
                    else sample.input
                )

            if support_script:
                check_against_reference(
                    self,
                    script_fn,
                    op.get_op(),
                    out_fn,
                    (get_sample(),) + sample.args,
                    sample.kwargs,
                    no_grad=not _requires_grad,
                    no_gradgrad=not op.supports_gradgrad,
                )

            # Check traced forward, grad, and grad grad
            # TODO: fix tracing here
            supports_tracing = op.supports_tracing and not has_fake_function
            if op.assert_jit_shape_analysis:
                self.assertTrue(supports_tracing)

            if supports_tracing:
                traced_fn = create_traced_fn(self, variant)
                check_against_reference(
                    self,
                    traced_fn,
                    op.get_op(),
                    out_fn,
                    (get_sample(),) + sample.args,
                    sample.kwargs,
                    no_grad=not _requires_grad,
                    no_gradgrad=not op.supports_gradgrad,
                )

            # Check alias annotation schema for correctness (make
            #   sure inputs that aren't supposed to be modified aren't)
            # Note: only runs in float32 because the schema isn't affected by dtype,
            #   so running it on all dtypes would be excessive
            if dtype == torch.float32:
                # TODO: no reason why we can't run this with the tracing graph
                if support_script and op.name != "rsub":
                    check_alias_annotation(
                        name,
                        (get_sample(),) + sample.args,
                        sample.kwargs,
                        func_type=func_type,
                        aten_name=op.aten_name,
                    )

                # TODO: use script graph as well
                checked_shape_analysis = False
                if supports_tracing:
                    out = variant(get_sample(), *sample.args, **sample.kwargs)

                    # right now, only a tuple of tensor outputs or a single tensor
                    #   output is supported
                    # TODO: list of tensor outputs
                    tuple_of_tensors = isinstance(out, tuple) and all(
                        isinstance(elem, torch.Tensor) for elem in out
                    )

                    if isinstance(out, torch.Tensor) or tuple_of_tensors:
                        if tuple_of_tensors:
                            sizes = [elem.size() for elem in out]
                        else:
                            sizes = out.size()
                        self.checkShapeAnalysis(
                            sizes, traced_fn.graph, op.assert_jit_shape_analysis
                        )
                        checked_shape_analysis = True
                if op.assert_jit_shape_analysis:
                    self.assertTrue(checked_shape_analysis)

            # Check autodifferentiation of nodes for traced and scripted graphs;
            #   this only needs to be checked once per sample
            if dtype is torch.float32:
                # Sandcastle doesn't fuse nodes
                if IS_SANDCASTLE:
                    # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                    nonfusible_nodes = (
                        op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                    )
                    fusible_nodes = []
                else:
                    nonfusible_nodes = op.autodiff_nonfusible_nodes
                    fusible_nodes = op.autodiff_fusible_nodes

                if supports_tracing:
                    self.assertAutodiffNode(
                        traced_fn.last_graph,
                        op.assert_autodiffed,
                        nonfusible_nodes,
                        fusible_nodes,
                    )
                if support_script:
                    self.assertAutodiffNode(
                        script_fn.last_graph,
                        op.assert_autodiffed,
                        nonfusible_nodes,
                        fusible_nodes,
                    )

    # alias testing is only done with torch.float for the same reason
    _alias_ops = partial(ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float,))

    @_alias_ops(op for op in op_db if op.aliases)
    def test_jit_alias_remapping(self, device, dtype, op):
        # NOTE: only tests on first sample
        samples = op.sample_inputs(device, dtype, requires_grad=True)
        sample = first_sample(self, samples)

        # [Scripting Data Preparation]
        # Prepare data for test scripting
        # Below we prepare strings of args/kwargs with and without type annotations.
        # These strings are inserted into function template strings, which are then
        #   compiled with TorchScript.
        # - args is ["t0"], corresponding to the "input" tensor required by the op
        # - args_kw is args plus string representations of the sample's args and kwargs
        #   used to call the op (without type annotations), for example,
        # ["t0", "1.0", "(1,)", "True", "tensor(1.0)"] -> def fn(t0): return variant(t0, 1.0, (1,), True, tensor(1.0))
        args = ["t0"]

        def quote_strs(v):
            if isinstance(v, str):
                return f"'{v}'"

            return str(v)

        args_kw = (
            args
            + [f"{v}" for v in sample.args]
            + [f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]
        )

        # Prepare data for test tracing
        sample_args_kwargs = ()
        if len(sample.args) > 0:
            sample_args_kwargs += (sample.args,)
        if len(sample.kwargs) > 0:
            sample_args_kwargs += (sample.kwargs,)

        original_name = op.aten_name
        original_name_inplace = original_name + "_"
        expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype
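        # expected_dtype is compared against the input dtype below: inplace
        #   alias calls whose eager result dtype cannot be cast back to the
        #   input dtype are expected to raise.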

        for a_op in op.aliases:
            inplace = a_op.inplace_variant
            method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
            # Materialize as a list (not a generator) because it is iterated
            #   twice below, once for scripting and once for tracing
            variants = [
                v
                for v in (a_op.op, a_op.method_variant, a_op.inplace_variant)
                if v is not None
            ]

            # Test scripting:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                if variant in method_or_inplace:
                    fn_template = """
                        def _fn(t0{c}):
                            return t0.{alias_name}({args_kw})
                    """
                    # remove the first input tensor
                    script = fn_template.format(
                        c=", " if len(args_kw[1:]) > 1 else "",
                        args_kw=", ".join(args_kw[1:]),
                        alias_name=variant_name,
                    )
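                    # For example, a no-extra-argument method alias such as
                    #   Tensor.absolute (alias of abs) yields roughly:
                    #   def _fn(t0): return t0.absolute()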
                else:
                    fn_template = """
                        def _fn({args}):
                            return variant({args_kw})
                    """
                    script = fn_template.format(
                        args=", ".join(args),
                        args_kw=", ".join(args_kw),
                    )

                # Required to avoid an "undefined value: tensor" error during
                # JIT compilation of the function template
                script = script.replace("tensor(", "torch.tensor(")

                scripted = torch.jit.CompilationUnit(script)._fn

                if variant is inplace and not torch.can_cast(expected_dtype, dtype):
                    try:
                        inp = clone_input_helper(sample.input)
                        scripted(inp)
                    except Exception as e:
                        continue
                    self.fail(
                        "Inplace operation on integer tensor that should be promoted to float didn't fail!"
                    )

                inp = clone_input_helper(sample.input)
                scripted(inp)
                inp = clone_input_helper(sample.input)
                graph = scripted.graph_for(inp)
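                # The scripted graph should call the original aten op and never
                #   mention the alias name, i.e. the alias was remapped.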
                FileCheck().check(op.aten_name).check_not(variant_name).run(graph)

            # Test tracing:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                def _fn(*sample_args, **sample_kwargs):
                    return variant(*sample_args, **sample_kwargs)

                inp = (clone_input_helper(sample.input),) + sample_args_kwargs
                traced = torch.jit.trace(_fn, *inp)
                inp = (clone_input_helper(sample.input),) + sample_args_kwargs
                traced(*inp)
                inp = (clone_input_helper(sample.input),) + sample_args_kwargs
                graph = traced.graph_for(*inp)
                FileCheck().check(op_name).check_not(variant_name).run(graph)


instantiate_device_type_tests(TestJit, globals())

if __name__ == "__main__":
    TestCase._default_dtype_check_enabled = True
    run_tests()