# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

from collections import Counter
from pprint import pformat
from typing import Any, Iterable, List, Literal, Optional, Tuple, Union

import executorch.backends.xnnpack.test.tester.tester as tester

import numpy as np
import serializer.tosa_serializer as ts

import torch.fx

from executorch.backends.arm.arm_backend import get_intermediate_path, is_permute_memory
from executorch.backends.arm.arm_partitioner import ArmPartitioner
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from executorch.backends.arm.test.common import (
    arm_test_options,
    current_time_formated,
    get_option,
)

from executorch.backends.arm.test.runner_utils import (
    _get_input_quantization_params,
    _get_output_node,
    _get_output_quantization_params,
    dbg_tosa_fb_to_json,
    RunnerUtil,
)
from executorch.backends.arm.tosa_mapping import extract_tensor_meta

from executorch.backends.xnnpack.test.tester import Tester
from executorch.devtools.backend_debug import get_delegation_info
from executorch.exir import EdgeCompileConfig, ExecutorchProgramManager
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.lowered_backend_module import LoweredBackendModule

from tabulate import tabulate
from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec
from torch.fx import Graph

logger = logging.getLogger(__name__)


def _dump_lowered_modules_artifact(
    path_to_dump: Optional[str],
    artifact: ExecutorchProgramManager,
    graph_module: torch.fx.GraphModule,
):
    output = "Formatted Graph Signature:\n"
    output += _format_export_graph_signature(
        artifact.exported_program().graph_signature
    )

    def get_output_format(lowered_module) -> str | None:
        for spec in lowered_module.compile_specs:
            if spec.key == "output_format":
                return spec.value.decode()
        return None

    for node in graph_module.graph.nodes:
        if node.op == "get_attr" and node.name.startswith("lowered_module_"):
            lowered_module = getattr(graph_module, node.name)
            assert isinstance(
                lowered_module, LoweredBackendModule
            ), f"Attribute {node.name} must be of type LoweredBackendModule."

            output_format = get_output_format(lowered_module)
            if output_format == "tosa":
                tosa_fb = lowered_module.processed_bytes
                to_print = dbg_tosa_fb_to_json(tosa_fb)
                to_print = pformat(to_print, compact=True, indent=1)
                output += f"\nTOSA deserialized {node.name}: \n{to_print}\n"
            elif output_format == "vela":
                vela_cmd_stream = lowered_module.processed_bytes
                output += f"\nVela command stream {node.name}: \n{vela_cmd_stream}\n"
            else:
                logger.warning(
                    f"No TOSA or Vela compile spec found in compile specs of {node.name}."
                )
                continue

    if not output:
        logger.warning("No output to print generated from artifact.")
        return

    _dump_str(output, path_to_dump)
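
# Illustrative note (not exercised in this file): the dump above keys off the
# "output_format" compile spec of each lowered module. A backend compiled with,
# for example, CompileSpec("output_format", b"tosa") gets its payload
# de-serialized to TOSA JSON, while a b"vela" payload is dumped as the raw
# command stream bytes.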


class Partition(tester.Partition):
    def dump_artifact(self, path_to_dump: Optional[str]):
        super().dump_artifact(path_to_dump)
        _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module)


class ToEdgeTransformAndLower(tester.ToEdgeTransformAndLower):
    def dump_artifact(self, path_to_dump: Optional[str]):
        super().dump_artifact(path_to_dump)
        _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module)


class Serialize(tester.Serialize):
    def __init__(self, runner_util: RunnerUtil, timeout: int = 1):
        super().__init__()
        self.runner = runner_util
        self.runner.set_timeout(timeout)

    def run_artifact(self, inputs):
        return self.runner.run_corstone(inputs)

    def dump_artifact(self, path_to_dump: Optional[str]):
        if not path_to_dump:
            path_to_dump = self.path + "/program.pte"
        super().dump_artifact(path_to_dump)


class ToExecutorch(tester.ToExecutorch):
    def __init__(
        self,
        tosa_test_util: RunnerUtil,
        dynamic_shapes: Optional[Tuple[Any]] = None,
    ):
        super().__init__(dynamic_shapes)
        self.tosa_test_util = tosa_test_util

    def run_artifact(self, inputs):
        tosa_output = self.tosa_test_util.run_tosa_ref_model(
            inputs=inputs,
        )
        return tosa_output


class InitialModel(tester.Stage):
    def __init__(self, model: torch.nn.Module):
        self.model = model

    def run(self, artifact, inputs=None) -> None:
        pass

    @property
    def artifact(self) -> torch.nn.Module:
        return self.model

    @property
    def graph_module(self) -> None:
        return None

    def artifact_str(self) -> str:
        return str(self.model)

    def run_artifact(self, inputs):
        return self.model.forward(*inputs)


class ArmTester(Tester):
    def __init__(
        self,
        model: torch.nn.Module,
        example_inputs: Tuple[torch.Tensor],
        compile_spec: Optional[List[CompileSpec]] = None,
        tosa_ref_model_path: str | None = None,
    ):
        """
        Args:
            model (torch.nn.Module): The model to test
            example_inputs (Tuple[torch.Tensor]): Example inputs to the model
            compile_spec (List[CompileSpec]): The compile spec to use
            tosa_ref_model_path (str): Optional path to a TOSA reference model
        """

        # Initiate runner_util
        intermediate_path = get_intermediate_path(compile_spec)
        self.runner_util = RunnerUtil(
            intermediate_path=intermediate_path,
            tosa_ref_model_path=tosa_ref_model_path,
        )

        self.compile_spec = compile_spec
        super().__init__(model, example_inputs)
        self.pipeline[self.stage_name(InitialModel)] = [
            self.stage_name(tester.Quantize),
            self.stage_name(tester.Export),
        ]

        # Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
        self.stages[self.stage_name(InitialModel)] = None
        self._run_stage(InitialModel(self.original_module))
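
    # Illustrative usage sketch, not part of the class itself: a typical test
    # builds a stage chain and compares the final stage against a reference.
    # The model, example inputs and compile_spec below are assumed placeholders
    # (compile_spec would normally come from the Arm backend's helpers).
    #
    #   tester = ArmTester(model, example_inputs, compile_spec=compile_spec)
    #   (
    #       tester.quantize()
    #       .export()
    #       .to_edge()
    #       .partition()
    #       .to_executorch()
    #       .run_method_and_compare_outputs()
    #   )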

    def quantize(self, quantize_stage: Optional[tester.Quantize] = None):
        if quantize_stage is None:
            quantize_stage = tester.Quantize(
                ArmQuantizer(),
                get_symmetric_quantization_config(is_per_channel=False),
            )
        return super().quantize(quantize_stage)

    def to_edge(
        self,
        to_edge_stage: Optional[tester.ToEdge] = None,
        config: Optional[EdgeCompileConfig] = None,
    ):
        if to_edge_stage is None:
            to_edge_stage = tester.ToEdge(config)
        else:
            if config is not None:
                to_edge_stage.edge_compile_conf = config

        # TODO(T182928844): Delegate dim order op to backend.
        to_edge_stage.edge_compile_conf._skip_dim_order = True
        return super().to_edge(to_edge_stage)

    def partition(self, partition_stage: Optional[Partition] = None):
        if partition_stage is None:
            arm_partitioner = ArmPartitioner(compile_spec=self.compile_spec)
            partition_stage = Partition(arm_partitioner)
        return super().partition(partition_stage)

    def to_edge_transform_and_lower(
        self,
        to_edge_and_lower_stage: Optional[ToEdgeTransformAndLower] = None,
        partitioners: Optional[List[Partitioner]] = None,
        edge_compile_config: Optional[EdgeCompileConfig] = None,
    ):
        if to_edge_and_lower_stage is None:
            if partitioners is None:
                partitioners = [ArmPartitioner(compile_spec=self.compile_spec)]
            to_edge_and_lower_stage = ToEdgeTransformAndLower(
                partitioners, edge_compile_config
            )
        else:
            if partitioners is not None:
                to_edge_and_lower_stage.partitioners = partitioners
            if edge_compile_config is not None:
                to_edge_and_lower_stage.edge_compile_conf = edge_compile_config
        to_edge_and_lower_stage.edge_compile_conf._skip_dim_order = True
        return super().to_edge_transform_and_lower(to_edge_and_lower_stage)

    def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] = None):
        if to_executorch_stage is None:
            to_executorch_stage = ToExecutorch(self.runner_util)
        return super().to_executorch(to_executorch_stage)

    def serialize(
        self, serialize_stage: Optional[Serialize] = None, timeout: int = 120
    ):
        if serialize_stage is None:
            serialize_stage = Serialize(self.runner_util, timeout=timeout)
        assert (
            get_intermediate_path(self.compile_spec) is not None
        ), "Can't dump serialized file when compile specs do not contain an artifact path."

        return (
            super()
            .serialize(serialize_stage)
            .dump_artifact(get_intermediate_path(self.compile_spec) + "/program.pte")
        )

    def run_method_and_compare_outputs(
        self,
        inputs: Optional[Tuple[torch.Tensor]] = None,
        stage: Optional[str] = None,
        target_board: Optional[str] = "corstone-300",
        num_runs=1,
        atol=1e-03,
        rtol=1e-03,
        qtol=0,
    ):
        """
        Compares the run_artifact output of 'stage' with the output of a reference stage.
        If the model is quantized, the reference stage is the Quantize stage output.
        Otherwise, the reference stage is the initial PyTorch module.

        Asserts that the outputs are equal (within tolerances).
        Returns self to allow the function to be run in a test chain.

        Args:
            stage (Optional[str]): The name of the stage to compare.
                The default is the latest run stage.
            inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data.
                The default is random data.
289 """ 290 edge_stage = self.stages[self.stage_name(tester.ToEdge)] 291 if edge_stage is None: 292 edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)] 293 assert ( 294 self.runner_util is not None 295 ), "self.tosa_test_util is not initialized, cannot use run_method()" 296 assert ( 297 edge_stage is not None 298 ), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run." 299 300 stage = stage or self.cur 301 test_stage = self.stages[stage] 302 is_quantized = self.stages[self.stage_name(tester.Quantize)] is not None 303 304 exported_program = self.stages[self.stage_name(tester.Export)].artifact 305 edge_program = edge_stage.artifact.exported_program() 306 self.runner_util.init_run( 307 exported_program, 308 edge_program, 309 is_quantized, 310 target_board, 311 ) 312 313 if is_quantized: 314 reference_stage = self.stages[self.stage_name(tester.Quantize)] 315 quantization_scale = self.runner_util.qp_output.scale 316 else: 317 reference_stage = self.stages[self.stage_name(InitialModel)] 318 quantization_scale = None 319 320 logger.info( 321 f"Comparing Stage '{self.stage_name(test_stage)}' with Stage '{self.stage_name(reference_stage)}'" 322 ) 323 is_nhwc = is_permute_memory(self.compile_spec) 324 325 # Loop inputs and compare reference stage with the compared stage. 326 for run_iteration in range(num_runs): 327 reference_input = inputs if inputs else next(self.generate_random_inputs()) 328 329 # Test parameters can include constants that are used in eager mode but are already set as attributes 330 # in TOSA. Therefore, only accept torch.Tensor inputs. 331 test_input: list[torch.Tensor] = [] 332 for arg in reference_input: 333 if isinstance(arg, torch.Tensor): 334 test_input.append(arg.clone()) 335 if isinstance(arg, tuple) and isinstance(arg[0], torch.Tensor): 336 test_input.extend([tensor.clone() for tensor in arg]) 337 338 if ( 339 is_nhwc 340 and test_stage == self.stages[self.stage_name(tester.ToExecutorch)] 341 ): 342 test_input = self.transpose_data_format(test_input, "NHWC") 343 344 input_shapes = [ 345 generated_input.shape if hasattr(generated_input, "shape") else (1,) 346 for generated_input in reference_input 347 ] 348 input_shape_str = ", ".join([str(list(i)) for i in input_shapes]) 349 logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}") 350 351 reference_output = reference_stage.run_artifact(reference_input) 352 test_output = tuple(test_stage.run_artifact(test_input)) 353 if ( 354 is_nhwc 355 and test_stage == self.stages[self.stage_name(tester.ToExecutorch)] 356 ): 357 test_output = self.transpose_data_format(test_output, "NCHW") 358 359 self._compare_outputs( 360 reference_output, test_output, quantization_scale, atol, rtol, qtol 361 ) 362 363 return self 364 365 def get_graph(self, stage: str | None = None) -> Graph: 366 if stage is None: 367 stage = self.cur 368 artifact = self.get_artifact(stage) 369 if ( 370 self.cur == self.stage_name(tester.ToEdge) 371 or self.cur == self.stage_name(Partition) 372 or self.cur == self.stage_name(ToEdgeTransformAndLower) 373 ): 374 graph = artifact.exported_program().graph 375 elif self.cur == self.stage_name(tester.Export) or self.cur == self.stage_name( 376 tester.Quantize 377 ): 378 graph = artifact.graph 379 else: 380 raise RuntimeError( 381 "Can only get a graph from Quantize, ToEdge, Export, and Partition stages." 

    def dump_operator_distribution(
        self, path_to_dump: Optional[str] = None, print_table: bool = True
    ):
        """Dump the distribution of operators in the current stage.
        In the partition stage, additional information is included, such as the number of
        delegates and the distribution of TOSA operators.
        Set parameter print_table to False to dump in a parseable format.

        Returns self for daisy-chaining.
        """
        line = "#" * 10
        to_print = f"{line} {self.cur.capitalize()} Operator Distribution {line}\n"

        if (
            self.cur
            in (
                self.stage_name(tester.Partition),
                self.stage_name(ToEdgeTransformAndLower),
            )
            and print_table
        ):
            graph_module = self.get_artifact().exported_program().graph_module
            delegation_info = get_delegation_info(graph_module)
            op_dist = delegation_info.get_operator_delegation_dataframe()
            to_print += _format_dict(op_dist, print_table)
            to_print += "\n" + _get_tosa_operator_distribution(
                graph_module, print_table
            )
            to_print += "\n"
            to_print += delegation_info.get_summary()
        else:
            graph = self.get_graph(self.cur)
            op_dist = dict(_get_operator_distribution(graph))
            if print_table:
                op_dist = {
                    "Operator": list(op_dist),
                    "Count": [op_dist[key] for key in op_dist],
                }
            to_print += _format_dict(op_dist, print_table) + "\n"

        _dump_str(to_print, path_to_dump)

        return self

    def dump_dtype_distribution(
        self, path_to_dump: Optional[str] = None, print_table: bool = True
    ):
        """Dump the distribution of dtypes of nodes and placeholders in the current stage.
        Set parameter print_table to False to dump in a parseable format.

        Returns self for daisy-chaining.
        """

        line = "#" * 10
        to_print = (
            f"{line} {self.cur.capitalize()} Placeholder Dtype Distribution {line}\n"
        )

        graph = self.get_graph(self.cur)
        dtype_dist_placeholders, dtype_dist_tensors = _get_dtype_distribution(graph)
        all_dtypes = set(dtype_dist_placeholders.keys()) | set(
            dtype_dist_tensors.keys()
        )
        if print_table:
            dtype_dist = {
                "Dtype": all_dtypes,
                "Placeholder Count": [
                    (
                        dtype_dist_placeholders[key]
                        if key in dtype_dist_placeholders
                        else 0
                    )
                    for key in all_dtypes
                ],
                "Tensor Count": [
                    (dtype_dist_tensors[key] if key in dtype_dist_tensors else 0)
                    for key in all_dtypes
                ],
            }
        else:
            dtype_dist = dict(dtype_dist_placeholders + dtype_dist_tensors)
        to_print += _format_dict(dtype_dist, print_table) + "\n"
        _dump_str(to_print, path_to_dump)
        return self

    @staticmethod
    def _calculate_reference_output(
        module: Union[torch.fx.GraphModule, torch.nn.Module], inputs
    ) -> torch.Tensor:
        """
        Note: I'd prefer to use the base class method here, but since it uses the
        exported program, I can't. The partitioner stage clears the state_dict
        of the exported program, which causes an issue when evaluating the
        module.
        """

        return module.forward(*inputs)
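
    # The helper below permutes 4D tensors between the two memory formats used
    # in these tests. As a concrete example (illustrative shapes only): an
    # NCHW tensor of shape (1, 3, 224, 224) transposed with dim order
    # (0, 2, 3, 1) becomes the NHWC shape (1, 224, 224, 3).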
484 """ 485 486 return module.forward(*inputs) 487 488 def transpose_data_format( 489 self, data: Tuple[torch.Tensor], to: Literal["NHWC", "NCHW"] 490 ): 491 if to == "NCHW": 492 dim_order = (0, 3, 1, 2) 493 if to == "NHWC": 494 dim_order = (0, 2, 3, 1) 495 inputs_transposed = list(data) 496 for i in range(len(data)): 497 if hasattr(data[i], "shape") and len(data[i].shape) == 4: 498 inputs_transposed[i] = np.transpose(data[i], dim_order) 499 return tuple(inputs_transposed) 500 501 def _compare_outputs( 502 self, 503 reference_output, 504 stage_output, 505 quantization_scale=None, 506 atol=1e-03, 507 rtol=1e-03, 508 qtol=0, 509 ): 510 try: 511 super()._compare_outputs( 512 reference_output, stage_output, quantization_scale, atol, rtol, qtol 513 ) 514 except AssertionError as e: 515 # Capture assertion error and print more info 516 banner = "=" * 40 + "TOSA debug info" + "=" * 40 517 logger.error(banner) 518 path_to_tosa_files = self.runner_util.intermediate_path 519 520 export_stage = self.stages.get(self.stage_name(tester.Export), None) 521 quantize_stage = self.stages.get(self.stage_name(tester.Quantize), None) 522 if export_stage is not None and quantize_stage is not None: 523 output_node = _get_output_node(export_stage.artifact) 524 qp_input = _get_input_quantization_params(export_stage.artifact) 525 qp_output = _get_output_quantization_params( 526 export_stage.artifact, output_node 527 ) 528 logger.error(f"{qp_input=}") 529 logger.error(f"{qp_output=}") 530 531 logger.error(f"{path_to_tosa_files=}") 532 import os 533 534 torch.save( 535 stage_output, 536 os.path.join(path_to_tosa_files, "torch_tosa_output.pt"), 537 ) 538 torch.save( 539 reference_output, 540 os.path.join(path_to_tosa_files, "torch_ref_output.pt"), 541 ) 542 logger.error(f"{atol=}, {rtol=}, {qtol=}") 543 raise e 544 545 546def _get_dtype_distribution(graph: Graph) -> tuple[dict, dict]: 547 """Counts the occurences of placeholder and call_function dtypes in a graph. 548 The result is a tuple of Counters (placeholder_distribution, call_function_distribution) 549 """ 550 placeholder_dtypes = [] 551 call_function_dtypes = [] 552 for node in graph.nodes: 553 if node.op == "placeholder": 554 placeholder_dtypes.append(str(node.meta["val"].dtype)) 555 if node.op == "call_function": 556 if "val" in node.meta: 557 dtype, _, _ = extract_tensor_meta(node.meta) 558 call_function_dtypes.append(ts.DTypeNames[dtype]) 559 return Counter(placeholder_dtypes), Counter(call_function_dtypes) 560 561 562def _get_operator_distribution(graph: Graph) -> dict[str, int]: 563 """Counts the occurences of operator names in a graph. 


def _format_export_graph_signature(signature: ExportGraphSignature) -> str:
    def specs_dict(specs: list[InputSpec | OutputSpec], title: str):
        _dict: dict[str, list] = {title: [], "arg": [], "kind": [], "target": []}
        for i, spec in enumerate(specs):
            _dict[title].append(i)
            _dict["arg"].append(spec.arg)
            _dict["kind"].append(spec.kind)
            _dict["target"].append(spec.target if spec.target else "-")
        return _dict

    input_dict = specs_dict(signature.input_specs, "Inputs")
    output_dict = specs_dict(signature.output_specs, "Outputs")

    return f"{_format_dict(input_dict)}\n{_format_dict(output_dict)}"


def _get_tosa_operator_distribution(
    graph_module: torch.fx.GraphModule, print_table=False
) -> str:
    """Counts the occurrences of operator names of all lowered modules containing
    a TOSA flatbuffer.
    The result is a string with the operator distribution or an error message.
    """
    op_list = []
    id = 0
    while lowered_module := getattr(graph_module, f"lowered_module_{id}", None):
        for spec in lowered_module.compile_specs:
            if spec.key != "output_format":
                continue
            if spec.value == b"tosa":
                tosa_fb = lowered_module.processed_bytes
                tosa_json = dbg_tosa_fb_to_json(tosa_fb)
                for region in tosa_json["regions"]:
                    for block in region["blocks"]:
                        op_list.extend(
                            [operator["op"] for operator in block["operators"]]
                        )
                break
            elif spec.value == b"vela":
                return "Cannot get operator distribution for Vela command stream."
            else:
                return f"Unknown output format '{spec.value}'."
        id += 1
    if id == 0:
        return "No delegate with name 'lowered_module_0' found in graph module."
    op_dist = dict(Counter(op_list))
    op_dist = {
        "Operator": list(op_dist.keys()),
        "Count": [item[1] for item in op_dist.items()],
    }
    return "TOSA operators:\n" + _format_dict(dict(op_dist), print_table)


def _dump_str(to_print: str, path_to_dump: Optional[str] = None):
    default_dump_path = get_option(arm_test_options.dump_path)
    if not path_to_dump and default_dump_path:
        path_to_dump = default_dump_path / f"ArmTester_{current_time_formated()}.log"
    if path_to_dump:
        with open(path_to_dump, "a") as fp:
            fp.write(to_print)
    else:
        logger.info(to_print)


def _format_dict(to_print: dict, print_table: bool = True) -> str:
    if isinstance(list(to_print.values())[0], Iterable) and print_table:
        return tabulate(
            to_print, headers="keys", tablefmt="fancy_grid", maxcolwidths=35
        )
    else:
        return pformat(to_print, compact=True, indent=1)