# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import copy
import json
import os
import tempfile
import unittest

import executorch.exir.tests.models as models
import torch
from executorch import exir
from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
from executorch.devtools.bundled_program.core import BundledProgram
from executorch.devtools.etrecord import generate_etrecord, parse_etrecord
from executorch.devtools.etrecord._etrecord import (
    _get_reference_outputs,
    ETRecordReservedFileNames,
)
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
from torch.export import export


# TODO : T154728484 Add test cases to cover multiple entry points
class TestETRecord(unittest.TestCase):
    """Tests that write an ETRecord with generate_etrecord() and read it back with parse_etrecord()."""

    def _capture_and_lower(self, model, example_inputs):
        """Lower ``model`` through the legacy ``exir.capture`` pipeline.

        Returns ``(captured_copy, edge_copy, et_output)``:
        - a deep copy of the captured program, taken before ``to_edge()`` runs,
        - a deep copy of the edge program, taken before ``to_executorch()`` runs,
        - the program returned by ``to_executorch()``.
        """
        captured_output = exir.capture(model, example_inputs, exir.CaptureConfig())
        # Snapshot each stage before the next lowering step runs on it. This is
        # presumably because lowering may modify its input in place; the tests
        # compare the ETRecord against these pre-lowering graphs.
        captured_output_copy = copy.deepcopy(captured_output)
        edge_output = captured_output.to_edge(
            # TODO(gasoon): Remove _use_edge_ops=False once serde is fully migrated to Edge ops
            exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False)
        )
        edge_output_copy = copy.deepcopy(edge_output)
        et_output = edge_output.to_executorch()
        return (captured_output_copy, edge_output_copy, et_output)

    def get_test_model(self):
        """Return (captured, edge, executorch) programs for BasicSinMax on random inputs."""
        f = models.BasicSinMax()
        return self._capture_and_lower(f, f.get_random_inputs())

    def get_test_model_with_bundled_program(self):
        """Return (captured, edge, BundledProgram) for BasicSinMax.

        The bundled program carries one "forward" test suite with two random
        test cases. The expected outputs of each case come from running the
        eager module.
        """
        f = models.BasicSinMax()
        inputs = [f.get_random_inputs() for _ in range(2)]
        m_name = "forward"

        method_test_suites = [
            MethodTestSuite(
                method_name=m_name,
                test_cases=[
                    MethodTestCase(
                        inputs=inp, expected_outputs=getattr(f, m_name)(*inp)
                    )
                    for inp in inputs
                ],
            )
        ]
        (
            captured_output_copy,
            edge_output_copy,
            et_output,
        ) = self._capture_and_lower(f, inputs[0])

        bundled_program = BundledProgram(et_output, method_test_suites)
        return (captured_output_copy, edge_output_copy, bundled_program)

    def get_test_model_with_manager(self):
        """Lower BasicSinMax through torch.export + to_edge (EdgeProgramManager API).

        Returns the exported ATen program, a deep copy of the EdgeProgramManager
        taken before ``to_executorch()`` is called, and the result of
        ``to_executorch()``.
        """
        f = models.BasicSinMax()
        aten_dialect = export(f, f.get_random_inputs())
        edge_program: EdgeProgramManager = to_edge(
            aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=False)
        )
        edge_program_copy = copy.deepcopy(edge_program)
        return (aten_dialect, edge_program_copy, edge_program.to_executorch())

    def check_graph_closeness(self, graph_a, graph_b):
        """Assert that two graph modules match node-for-node on the fields the Developer Tools use.

        Serialized and deserialized graph modules are not completely identical.
        This checks that they are close enough, comparing node count, target,
        arg and kwarg counts, name, type and op. For nodes other than
        placeholders and outputs, it also compares the ``debug_handle`` metadata.
        """
        self.assertEqual(len(graph_a.graph.nodes), len(graph_b.graph.nodes))
        for node_a, node_b in zip(graph_a.graph.nodes, graph_b.graph.nodes):
            self.assertEqual(node_a.target, node_b.target)
            self.assertEqual(len(node_a.args), len(node_b.args))
            self.assertEqual(len(node_a.kwargs), len(node_b.kwargs))
            self.assertEqual(node_a.name, node_b.name)
            self.assertEqual(node_a.type, node_b.type)
            self.assertEqual(node_a.op, node_b.op)
            if node_a.op not in {"placeholder", "output"}:
                self.assertEqual(
                    node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle")
                )

    def _assert_reference_outputs_equal(self, expected, actual):
        """Assert two ``{method_name: [per-test-case outputs]}`` mappings are equal.

        assertEqual() raises "RuntimeError: Boolean value of Tensor with more
        than one value is ambiguous" on tensors. Instead, the container shapes
        (method names, test-case counts, output counts) are checked first, then
        each tensor is compared with torch.equal() inside a subTest, so a
        mismatch reports exactly which entry differs.
        """
        self.assertEqual(set(expected), set(actual))
        for method_name, actual_cases in actual.items():
            expected_cases = expected[method_name]
            self.assertEqual(len(expected_cases), len(actual_cases))
            for case_idx, (expected_outputs, actual_outputs) in enumerate(
                zip(expected_cases, actual_cases)
            ):
                self.assertEqual(len(expected_outputs), len(actual_outputs))
                for out_idx, (exp, act) in enumerate(
                    zip(expected_outputs, actual_outputs)
                ):
                    with self.subTest(
                        method=method_name, test_case=case_idx, output=out_idx
                    ):
                        self.assertTrue(torch.equal(exp, act))

    def test_etrecord_generation(self):
        captured_output, edge_output, et_output = self.get_test_model()
        with tempfile.TemporaryDirectory() as tmpdirname:
            etrecord_path = os.path.join(tmpdirname, "etrecord.bin")
            generate_etrecord(
                etrecord_path,
                edge_output,
                et_output,
                {
                    "aten_dialect_output": captured_output,
                },
            )

            etrecord = parse_etrecord(etrecord_path)
            self.check_graph_closeness(
                etrecord.graph_map["aten_dialect_output/forward"],
                captured_output.exported_program.graph_module,
            )
            self.check_graph_closeness(
                etrecord.edge_dialect_program,
                edge_output.exported_program.graph_module,
            )
            # Round-trip through JSON so that the in-memory map has the same
            # key/value types as the one deserialized from the ETRecord file.
            self.assertEqual(
                etrecord._debug_handle_map,
                json.loads(json.dumps(et_output.debug_handle_map)),
            )

    def test_etrecord_generation_with_bundled_program(self):
        (
            captured_output,
            edge_output,
            bundled_program,
        ) = self.get_test_model_with_bundled_program()
        with tempfile.TemporaryDirectory() as tmpdirname:
            etrecord_path = os.path.join(tmpdirname, "etrecord.bin")
            generate_etrecord(
                etrecord_path,
                edge_output,
                bundled_program,
                {
                    "aten_dialect_output": captured_output,
                },
            )
            etrecord = parse_etrecord(etrecord_path)

            expected = etrecord._reference_outputs
            actual = _get_reference_outputs(bundled_program)
            # Compare every method, test case and output, not just the first
            # output of two hard-coded test cases.
            self._assert_reference_outputs_equal(expected, actual)

    def test_etrecord_generation_with_manager(self):
        captured_output, edge_output, et_output = self.get_test_model_with_manager()
        with tempfile.TemporaryDirectory() as tmpdirname:
            etrecord_path = os.path.join(tmpdirname, "etrecord.bin")
            generate_etrecord(
                etrecord_path,
                edge_output,
                et_output,
            )

            etrecord = parse_etrecord(etrecord_path)
            self.check_graph_closeness(
                etrecord.edge_dialect_program,
                edge_output.exported_program().graph_module,
            )
            self.assertEqual(
                etrecord._debug_handle_map,
                json.loads(json.dumps(et_output.debug_handle_map)),
            )

    def test_etrecord_invalid_input(self):
        """An ExecuTorch program passed as an extra "other module" must be rejected."""
        _, edge_output, et_output = self.get_test_model()
        with tempfile.TemporaryDirectory() as tmpdirname:
            with self.assertRaises(RuntimeError):
                generate_etrecord(
                    os.path.join(tmpdirname, "etrecord.bin"),
                    edge_output,
                    et_output,
                    {"fail_test_case": et_output},
                )

    def test_etrecord_reserved_name(self):
        """Every reserved ETRecord file name must be refused as a user-supplied module key."""
        captured_output, edge_output, et_output = self.get_test_model()
        with tempfile.TemporaryDirectory() as tmpdirname:
            etrecord_path = os.path.join(tmpdirname, "etrecord.bin")
            for reserved_name in ETRecordReservedFileNames:
                # subTest reports which reserved name was accepted on failure.
                with self.subTest(reserved_name=reserved_name):
                    with self.assertRaises(RuntimeError):
                        generate_etrecord(
                            etrecord_path,
                            edge_output,
                            et_output,
                            {
                                reserved_name: captured_output.exported_program.graph_module
                            },
                        )