xref: /aosp_15_r20/external/executorch/exir/tests/test_print_program.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.exir import to_edge
11from executorch.exir.print_program import inspect_node
12from torch.export import export
13
14
class TestPrintProgram(unittest.TestCase):
    """Checks the text that ``inspect_node`` renders for edge-dialect graph nodes."""

    def test_inspect_node(self) -> None:
        """Inspect every node of an exported model and validate the rendered text.

        Each node's rendering must contain the ``-->`` marker, and every node
        carrying a ``stack_trace`` entry in its metadata must also render a
        traceback header. At least one such node has to be present.
        """

        class TestModel(torch.nn.Module):
            """Three 1x1 convolutions with a residual-style add, then GELU."""

            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(32, 32, 1)
                self.conv2 = torch.nn.Conv2d(32, 32, 1)
                self.conv3 = torch.nn.Conv2d(32, 32, 1)
                self.gelu = torch.nn.GELU()

            def forward(self, x: torch.Tensor):
                first = self.conv1(x)
                second = self.conv2(first)
                # conv3 consumes the sum of the first two convolution outputs.
                merged = self.conv3(first + second)
                return self.gelu(merged)

        class WrapModule(torch.nn.Module):
            """Holds ``TestModel`` as a submodule and forwards calls to it."""

            def __init__(self) -> None:
                super().__init__()
                self.test_model = TestModel()

            def forward(self, x):
                return self.test_model(x)

        wrapped_model = WrapModule()
        sample_args = (torch.rand(1, 32, 16, 16),)

        edge_manager = to_edge(export(wrapped_model, sample_args))
        # The same graph is both iterated and handed to inspect_node.
        graph = edge_manager.exported_program().graph

        traced_node_count = 0
        for node in graph.nodes:
            rendered = inspect_node(graph, node)
            self.assertRegex(rendered, r".*-->.*")

            if "stack_trace" not in node.meta:
                continue

            # Nodes that recorded a stack trace must show the traceback header.
            self.assertRegex(
                rendered, r".*Traceback \(most recent call last\)\:.*"
            )
            traced_node_count += 1

        self.assertGreaterEqual(traced_node_count, 1)
55