xref: /aosp_15_r20/external/executorch/exir/tests/test_common.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport re
10*523fa7a6SAndroid Build Coastguard Workerimport unittest
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerimport torch
13*523fa7a6SAndroid Build Coastguard Workerimport torch.fx
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.common import extract_out_arguments, get_schema_for_operators
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.print_program import add_cursor_to_graph
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Worker
class TestExirCommon(unittest.TestCase):
    """Tests for ``get_schema_for_operators``, ``extract_out_arguments`` and
    ``add_cursor_to_graph``."""

    def test_get_schema_for_operators(self) -> None:
        """Each schema string returned for the requested operators must start
        with a ``name(args) -> `` shaped prefix."""
        op_list = [
            "torch.ops._caffe2.RoIAlign.default",
            "torch.ops.aten.add.Tensor",
            "torch.ops.aten.batch_norm.default",
            "torch.ops.aten.cat.default",
            "torch.ops.aten.clamp.default",
        ]

        schemas = get_schema_for_operators(op_list)
        # Guard against a vacuous pass: with an empty result the loop below
        # never runs and the test would succeed without checking anything.
        self.assertTrue(schemas, "get_schema_for_operators returned no schemas")

        # One or more non-"(" characters, a parenthesised non-empty argument
        # list, then " -> ". Anchored at the start of the string via match().
        pat = re.compile(r"[^\(]+\([^\)]+\) -> ")
        for op_name, schema in schemas.items():
            # subTest reports every malformed schema, not just the first one.
            with self.subTest(op=op_name):
                self.assertIsNotNone(
                    pat.match(schema),
                    f"unexpected schema format for {op_name!r}: {schema!r}",
                )

    def test_get_out_args(self) -> None:
        """``extract_out_arguments`` must report ``"out"`` as the out-argument
        name, both for a single ``Tensor(a!) out`` argument and for a
        ``Tensor(a!)[] out`` list argument."""
        schema1 = torch._C.parse_schema(
            "aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
        )
        schema2 = torch._C.parse_schema(
            "split_copy.Tensor_out(Tensor self, int split_size, int dim=0, *, Tensor(a!)[] out) -> ()"
        )

        out_args_1 = extract_out_arguments(schema1, {"out": torch.ones(5)})
        out_args_2 = extract_out_arguments(
            schema2, {"out": [torch.ones(5), torch.ones(5)]}
        )

        # Only the first element (the name) of each returned pair is checked;
        # the second element is deliberately ignored.
        out_arg_name_1, _ = out_args_1
        self.assertEqual(out_arg_name_1, "out")

        out_arg_name_2, _ = out_args_2
        self.assertEqual(out_arg_name_2, "out")

    def test_add_cursor(self) -> None:
        """``add_cursor_to_graph`` must prefix the line of the selected node
        with ``-->``."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        module = MyModule()

        # torch.fx is imported at module level, so no local import is needed.
        symbolic_traced = torch.fx.symbolic_trace(module)

        # Graph we are testing:
        # graph():
        #   %x : [#users=1] = placeholder[target=x]
        #   %param : [#users=1] = get_attr[target=param]
        #   %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
        # --> %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
        #   %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
        #   return clamp

        # Node index 3 is %linear in the listing above.
        cursor_node = list(symbolic_traced.graph.nodes)[3]
        actual_str = add_cursor_to_graph(symbolic_traced.graph, cursor_node)

        # Per the listing above, text line 0 is the "graph():" header, so the
        # node at index 3 is expected on text line 4.
        self.assertTrue(
            actual_str.split("\n")[4].startswith("-->"),
            f"cursor marker not found on line 4 of:\n{actual_str}",
        )
82