# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.exir import to_edge
from executorch.exir.print_program import inspect_node
from torch.export import export


class TestPrintProgram(unittest.TestCase):
    """Checks the text that ``inspect_node`` produces for an edge program."""

    def test_inspect_node(self) -> None:
        """Export a small nested model and inspect every node of its graph.

        Every node's report must contain the ``-->`` marker, every node that
        carries ``stack_trace`` metadata must have a traceback header in its
        report, and at least one such node must exist.
        """

        class TestModel(torch.nn.Module):
            """Three 1x1 convolutions with a residual add, followed by GELU."""

            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(32, 32, 1)
                self.conv2 = torch.nn.Conv2d(32, 32, 1)
                self.conv3 = torch.nn.Conv2d(32, 32, 1)
                self.gelu = torch.nn.GELU()

            def forward(self, x: torch.Tensor):
                first = self.conv1(x)
                second = self.conv2(first)
                merged = self.conv3(first + second)
                return self.gelu(merged)

        class WrapModule(torch.nn.Module):
            """Holds ``TestModel`` as a submodule and forwards straight to it."""

            def __init__(self):
                super().__init__()
                self.test_model = TestModel()

            def forward(self, x):
                return self.test_model(x)

        sample_args = (torch.rand(1, 32, 16, 16),)
        edge_manager = to_edge(export(WrapModule(), sample_args))

        # Fetch the graph once; every inspect_node call below reports on it.
        graph = edge_manager.exported_program().graph

        traced_nodes = 0
        for node in graph.nodes:
            report = inspect_node(graph, node)
            self.assertRegex(report, r".*-->.*")
            if "stack_trace" not in node.meta:
                continue
            self.assertRegex(
                report, r".*Traceback \(most recent call last\)\:.*"
            )
            traced_nodes += 1

        self.assertGreaterEqual(traced_nodes, 1)