# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import copy
import io
import logging
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union

import torch
import torch._export
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir._serialize._cord import Cord
from executorch.exir._warnings import experimental
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.emit import emit_program, EmitterOutput
from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
from executorch.exir.error import ExportError
from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.pass_base import PassBase
from executorch.exir.pass_manager import PassType
from executorch.exir.passes import (
    base_post_op_replace_passes,
    base_pre_op_replace_passes,
    dead_code_elimination_pass,
    EdgeToBackendOpsPass,
    MemoryFormatOpsPass,
    OpReplacePass,
)
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
    insert_write_back_for_buffers_pass,
)
from executorch.exir.passes.normalize_view_copy_base_pass import (
    NormalizeViewCopyBasePass,
)
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
from executorch.exir.passes.replace_aten_with_edge_pass import aten_to_edge
from executorch.exir.passes.replace_view_copy_with_view_pass import (
    ReplaceViewCopyWithViewPass,
)
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass
from executorch.exir.print_program import pretty_print, print_program
from executorch.exir.schema import Program
from executorch.exir.tracer import _default_decomposition_table
from executorch.exir.verification.verifier import (
    EXIRATenDialectVerifier,
    EXIREdgeDialectVerifier,
    get_aten_verifier,
)
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
from torch.export import ExportedProgram
from torch.export._remove_auto_functionalized_pass import (
    unsafe_remove_auto_functionalized_pass,
)
from torch.export.exported_program import (
    ConstantArgument,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputSpec,
    TensorArgument,
)
from torch.fx import _pytree as fx_pytree
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_manager import PassManager
from torch.utils import _pytree as pytree

Val = Any

from torch.library import Library

# This is the reserved namespace that is used to register ops that will
# be prevented from being decomposed during to_edge_transform_and_lower.
edge_no_decomp_namespace = "EDGE_DO_NOT_DECOMP"
lib = Library(edge_no_decomp_namespace, "DEF")
# Map from aten ops to the transformed ops registered in the edge_no_decomp_namespace.
aten_op_to_transform_op = {}
# Map from the transformed ops registered in the edge_no_decomp_namespace to aten ops.
transform_op_to_aten_op = {}


def _get_updated_range_constraints(gm):
    def get_shape_env(gm):
        vals = [
            node.meta["val"]
            for node in gm.graph.nodes
            if node.meta.get("val", None) is not None
        ]
        from torch._guards import detect_fake_mode  # type: ignore[21]

        fake_mode = detect_fake_mode(vals)
        if fake_mode is not None:
            return fake_mode.shape_env
        for v in vals:
            if isinstance(v, torch.SymInt):
                return v.node.shape_env

    shape_env = get_shape_env(gm)
    if shape_env is None:
        return {}
    range_constraints = {
        k: v
        for k, v in shape_env.var_to_range.items()
        if k not in shape_env.replacements
    }
    # Only when we have an unbacked symint, and it's used as constructor inputs,
    # runtime_var_to_range will make a difference compared to var_to_range.
    # e.g. [2, oo) -> [0, oo)
    for k, v in shape_env.var_to_range.items():
        if k not in shape_env.replacements:
            range_constraints[k] = v
    return range_constraints


def _get_updated_graph_signature(
    old_signature: ExportGraphSignature,
    new_gm: torch.fx.GraphModule,
) -> ExportGraphSignature:
    """
    Update the graph signature's user_input/user_outputs.
    """
    new_input_specs = []
    i = 0
    for node in new_gm.graph.nodes:
        if node.op != "placeholder":
            continue

        assert i < len(
            old_signature.input_specs
        ), "Number of inputs changed after transformation"
        old_input_spec = old_signature.input_specs[i]
        arg = (
            old_input_spec.arg
            if isinstance(old_input_spec.arg, ConstantArgument)
            # pyre-fixme[20]: Argument `class_fqn` expected.
            else type(old_input_spec.arg)(node.name)
        )
        new_input_specs.append(
            InputSpec(
                old_input_spec.kind,
                arg,
                old_input_spec.target,
                persistent=old_input_spec.persistent,
            )
        )
        i += 1

    output_node = list(new_gm.graph.nodes)[-1]
    assert output_node.op == "output"

    new_output_specs = []
    for i, node in enumerate(output_node.args[0]):
        assert i < len(
            old_signature.output_specs
        ), "Number of outputs changed after transformation"
        old_output_spec = old_signature.output_specs[i]
        arg = (
            old_output_spec.arg
            if isinstance(old_output_spec.arg, ConstantArgument)
            # pyre-fixme[20]: Argument `class_fqn` expected.
            else type(old_output_spec.arg)(node.name)
        )
        new_output_specs.append(
            OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
        )

    new_signature = ExportGraphSignature(
        input_specs=new_input_specs, output_specs=new_output_specs
    )
    return new_signature


def _transform(self, *passes: PassType) -> "ExportedProgram":
    pm = PassManager(list(passes))
    res = pm(self.graph_module)
    transformed_gm = res.graph_module if res is not None else self.graph_module
    assert transformed_gm is not None

    if transformed_gm is self.graph_module and not res.modified:
        return self

    transformed_ep = ExportedProgram(
        root=transformed_gm,
        graph=transformed_gm.graph,
        graph_signature=_get_updated_graph_signature(
            self.graph_signature, transformed_gm
        ),
        state_dict=self.state_dict,
        range_constraints=_get_updated_range_constraints(transformed_gm),
        module_call_graph=copy.deepcopy(self._module_call_graph),
        example_inputs=self.example_inputs,
        constants=self.constants,
        verifiers=[self.verifier],
    )
    transformed_ep.graph_module.meta.update(self.graph_module.meta)
    transformed_ep.graph_module.meta.update(res.graph_module.meta)
    return transformed_ep


def _copy_module(new_prog, new_gm):
    new_prog.meta.update(new_gm.meta)
    new_prog.graph = new_gm.graph
    submodules = [name for name, _ in new_prog.named_children()]
    for name in submodules:
        delattr(new_prog, name)
    for name, mod in new_gm.named_children():
        setattr(new_prog, name, mod)
    for node in new_gm.graph.nodes:
        if node.op == "get_attr":
            t = getattr(new_gm, node.target, None)
            if isinstance(t, torch.Tensor):
                setattr(new_prog, node.target, t)


def lift_constant_tensor_pass(ep):
    """
    Takes an ExportedProgram and returns the ExportedProgram modified in-place,
    with the constant tensors lifted as buffers.
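
    For illustration (a rough sketch of the effect, not an exact trace): a
    ``get_attr`` node holding a constant tensor is replaced by a placeholder
    inserted before the first user input, registered in the graph signature as
    a persistent buffer named ``_lifted_tensor_constant<N>``, and the tensor
    itself is stored in ``ep.state_dict`` under that name.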
    """
    if len([node for node in ep.graph.nodes if node.op == "placeholder"]) == 0:
        return ep

    graph_signature = ep.graph_signature
    buffers = list(graph_signature.buffers)

    fake_mode = list(ep.graph.nodes)[0].meta["val"].fake_mode
    first_user_input = None
    lifted_constants = []
    for node in ep.graph.nodes:
        if node.op == "placeholder" and node.name in graph_signature.user_inputs:
            first_user_input = node
            break

    for node in ep.graph.nodes:
        if node.op == "get_attr":
            constant_tensor = getattr(ep.graph_module, node.target)
            if not isinstance(constant_tensor, torch.Tensor):
                continue

            constant_tensor_fqn = f"_lifted_tensor_constant{len(buffers)}"

            with ep.graph.inserting_before(first_user_input):
                # Insert the constant node before the first user input
                const_placeholder_node = ep.graph.placeholder(constant_tensor_fqn)
                for k, v in node.meta.items():
                    const_placeholder_node.meta[k] = v
                if fake_mode is not None:
                    const_placeholder_node.meta["val"] = fake_mode.from_tensor(
                        constant_tensor, static_shapes=True
                    )
                else:
                    const_placeholder_node.meta["val"] = constant_tensor
                const_placeholder_node.meta["val"].constant = constant_tensor
                node.replace_all_uses_with(const_placeholder_node)
                ep.graph.erase_node(node)

                # Add the constant as a buffer to the graph signature
                lifted_constants.append(
                    InputSpec(
                        kind=InputKind.BUFFER,
                        arg=TensorArgument(name=const_placeholder_node.name),
                        target=constant_tensor_fqn,
                        persistent=True,
                    )
                )
                buffers.append(constant_tensor_fqn)
                ep.state_dict[constant_tensor_fqn] = constant_tensor

    new_input_specs = []
    for s in graph_signature.input_specs:
        if s.kind == InputKind.USER_INPUT and len(lifted_constants) > 0:
            new_input_specs.extend(lifted_constants)
            lifted_constants.clear()
        new_input_specs.append(s)
    ep.graph_signature.input_specs = new_input_specs
    ep.graph_module.recompile()
    return ep


# Stub to ease migration from `transform` to private `_transform`
def transform_exported_program(ep, *passes: PassType) -> ExportedProgram:
    if hasattr(ep, "_transform"):
        return ep._transform(*passes)
    else:
        return ep.transform(*passes)


class HackedUpExportedProgramDONOTUSE(ExportedProgram):
    def __init__(
        self,
        root,
        graph,
        graph_signature,
        call_spec,
        state_dict,
        range_constraints,
        module_call_graph,
        example_inputs,
        verifier,
    ):
        super().__init__(
            root=root,
            graph=graph,
            graph_signature=graph_signature,
            state_dict=state_dict,
            range_constraints=range_constraints,
            module_call_graph=module_call_graph,
            example_inputs=example_inputs,
            verifier=verifier,
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        import torch._export.error as error

        if self.call_spec.in_spec is not None:
            user_args = args
            try:
                args = fx_pytree.tree_flatten_spec(user_args, self.call_spec.in_spec)  # type: ignore[assignment]
            except Exception:
                _, received_spec = pytree.tree_flatten(user_args)
                raise error.InternalError(
                    "Trying to flatten user inputs with exported input tree spec: \n"
                    f"{self.call_spec.in_spec}\n"
                    "but actually got inputs with tree spec of: \n"
                    f"{received_spec}"
                )

        ordered_params = tuple(
            self.state_dict[name] for name in self.graph_signature.parameters
        )
        ordered_buffers = tuple(
            self.state_dict[name] for name in self.graph_signature.buffers
        )

        with torch.no_grad():
            # NOTE: calling convention is first params, then buffers, then args as user supplied them.
            # See: torch/_functorch/aot_autograd.py#L1034
            res = torch.fx.Interpreter(self.graph_module).run(
                *ordered_params, *ordered_buffers, *args, enable_io_processing=False
            )

        if self.call_spec.out_spec is not None:
            mutation = self.graph_signature.buffers_to_mutate
            num_mutated = len(mutation)
            mutated_buffers = res[:num_mutated]

            # Exclude dependency token from final result.
            assertion_dep_token = self.graph_signature.assertion_dep_token
            if assertion_dep_token is not None:
                assertion_dep_token_index = list(assertion_dep_token.keys())[0]
                res = res[:assertion_dep_token_index]

            res = res[num_mutated:]
            try:
                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
            except Exception:
                _, received_spec = pytree.tree_flatten(res)
                raise error.InternalError(
                    "Trying to flatten user outputs with exported output tree spec: \n"
                    f"{self.call_spec.out_spec}\n"
                    "but actually got outputs with tree spec of: \n"
                    f"{received_spec}"
                )
            finally:
                ix = 0
                for buffer in self.graph_signature.buffers_to_mutate.values():
                    self.state_dict[buffer] = mutated_buffers[ix]
                    ix += 1
        return res


@compatibility(is_backward_compatible=False)
class ExirExportedProgram:
    def __init__(
        self,
        exported_program: ExportedProgram,
        after_to_edge_passes: bool,
    ):
        self.exported_program = exported_program

        # Add a flag to denote whether to_edge is called on this program,
        # to detect misuse of directly calling to_executorch without to_edge.
        self.after_to_edge_passes = after_to_edge_passes

    def transform(self, *passes: PassType) -> "ExirExportedProgram":
        self.exported_program = _transform(self.exported_program, *passes)
        return self

    def __call__(self, *args: Any) -> Any:
        return self.exported_program.module()(*args)

    # TODO(ycao): Change this to a composable function.
    def to_edge(
        self, config: Optional[EdgeCompileConfig] = None
    ) -> "ExirExportedProgram":
        config = config or EdgeCompileConfig()
        assert isinstance(
            self.exported_program.graph_module, torch.fx.GraphModule
        ), f"type is instead: {type(self.exported_program.graph_module).__name__}"

        return _to_edge(self, config)

    def dump(self) -> None:
        print(self.exported_program.graph_module.graph)

    def to_executorch(
        self,
        config: Optional[ExecutorchBackendConfig] = None,
    ) -> "ExecutorchProgram":
        if not self.after_to_edge_passes:
            raise RuntimeError("Must run to_edge before to_executorch.")
        config = config or ExecutorchBackendConfig()
        new_gm = self.exported_program.graph_module
        for p in edge_to_executorch_passes(config):
            new_gm_res = p(new_gm)
            assert new_gm_res is not None
            new_gm = new_gm_res.graph_module

        # This is tech debt on tech debt. The memory planning pass inherits from some
        # pass infra for GMs. That isn't enough info now, so I can't use `__call__` and
        # have to use a new function, `run`. Existing user passes don't use `run`, so I'm
        # just cheating here because they don't need to work on mutable buffers yet.
        # After exir.capture is gone I will clean up the memory planning infra to be
        # consistent. Frankly, all of exir has big code quality issues because of the
        # migrations that need to be addressed.
        new_gm_res = config.memory_planning_pass(new_gm)  # pyre-ignore[29]
        assert new_gm_res is not None
        new_gm = new_gm_res.graph_module
        new_prog = ExirExportedProgram(
            copy.deepcopy(self.exported_program), self.after_to_edge_passes
        )
        _copy_module(new_prog.exported_program.graph_module, new_gm)
        executorch_prog = ExecutorchProgram(
            new_prog,
            emit_stacktrace=config.emit_stacktrace,
            extract_delegate_segments=config.extract_delegate_segments,
            segment_alignment=config.segment_alignment,
            constant_tensor_alignment=config.constant_tensor_alignment,
            delegate_alignment=config.delegate_alignment,
        )
        executorch_prog.graph_module.meta.update(new_gm.meta)
        executorch_prog.graph_module.meta.update(
            self.exported_program.graph_module.meta
        )
        return executorch_prog

    def __deepcopy__(
        self, memo: Optional[Dict[int, Any]] = None
    ) -> "ExirExportedProgram":
        new_eep = ExirExportedProgram(
            copy.deepcopy(self.exported_program, memo),
            self.after_to_edge_passes,
        )
        return new_eep


@compatibility(is_backward_compatible=False)
class ExecutorchProgram:
    def __init__(
        self,
        exir_exported_program: ExirExportedProgram,
        emit_stacktrace: bool,
        extract_delegate_segments: bool,
        segment_alignment: int,
        constant_tensor_alignment: Optional[int] = None,
        delegate_alignment: Optional[int] = None,
    ) -> None:
        if not exir_exported_program.after_to_edge_passes:
            raise RuntimeError(
                "Need to call prog.to_edge prior to constructing ExecutorchProgram."
            )
        self.exported_program = exir_exported_program.exported_program
        self._pte_data: Optional[Cord] = None
        self._buffer: Optional[bytes] = None
        self._emitter_output: Optional[EmitterOutput] = None
        self._emit_stacktrace: bool = emit_stacktrace
        self._extract_delegate_segments: bool = extract_delegate_segments
        self._segment_alignment: int = segment_alignment
        self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment
        self._delegate_alignment: Optional[int] = delegate_alignment

    def _get_pte_data(self) -> Cord:
        if self._pte_data is None:
            self._pte_data = _serialize_pte_binary(
                program=self.program,
                extract_delegate_segments=self._extract_delegate_segments,
                segment_alignment=self._segment_alignment,
                constant_tensor_alignment=self._constant_tensor_alignment,
                delegate_alignment=self._delegate_alignment,
            )
        return self._pte_data

    @property
    def buffer(self) -> bytes:
        """Returns the serialized ExecuTorch binary as a byte string.

        Note that the call to `buffer` may allocate a very large amount of
        contiguous memory, depending on the model size. If writing to a file,
        use `write_to_file` which won't incur additional copies.
        """
        # TODO(T181494963): update pybinding to remove buffer cache, which can consume large
        # amounts of memory longer than necessary.
        if self._buffer is None:
            self._buffer = bytes(self._get_pte_data())
        return self._buffer

    @property
    def program(self) -> Program:
        if self._emitter_output is None:
            self._emitter_output = emit_program(
                self.exported_program, self._emit_stacktrace
            )
        return self._emitter_output.program

    @property
    def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
        if self._emitter_output:
            return self._emitter_output.debug_handle_map
        return {}

    @property
    def delegate_map(
        self,
    ) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
        if self._emitter_output:
            return self._emitter_output.method_to_delegate_debug_id_map
        return {}

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.exported_program.graph_module

    # TODO (zhxchen17) Change this to a property.
    def dump_graph_module(self) -> torch.fx.GraphModule:
        return self.exported_program.graph_module

    def dump_exported_program(self) -> ExportedProgram:
        return self.exported_program

    def write_to_file(self, open_file: io.BufferedIOBase) -> None:
        """
        Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over
        `buffer`, as it writes to file without copying into a contiguous block of memory first,
        reducing the peak memory usage.
        """
        self._get_pte_data().write_to_file(open_file)


def _get_aten_to_edge_passes(config: EdgeCompileConfig):
    # TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator.
    # After enabling use_edge_op they can be moved to aten_to_edge_passes, before the eliminated_dead_code
    # pass. Also, ExportPass doesn't play well with node.meta, meaning that after some passes permuting
    # operators we may lose some information in node.meta. It might be regenerated in SpecPropPass, so it
    # may not be visible. However, the debug handle will be lost.

    pre_op_replace_passes = base_pre_op_replace_passes + (
        [] if config._skip_type_promotion else [RemoveMixedTypeOperators()]
    )

    post_op_replace_passes = base_post_op_replace_passes

    return pre_op_replace_passes, post_op_replace_passes


def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
    if config._check_ir_validity:
        try:
            EXIRATenDialectVerifier()(ep.exported_program.graph_module)
        except ExportError:
            logging.info(
                "If a particular operator failed core ATen IR check, please consider adding it to the exception list. "
                "Add the operator to _core_aten_ops_exception_list in EdgeCompileConfig. This is the recommended way "
                "to resolve this type of failure, so that the rest of the IR validation check can still be performed.\n"
                "If you'd like to disable IR validation checking, please set _check_ir_validity in EdgeCompileConfig, "
                "like *.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))."
            )
            raise

    dialect = ep.exported_program.dialect
    if dialect == "ATEN":
        ep = ExirExportedProgram(
            ExportedProgram(
                root=ep.exported_program.graph_module,
                graph=ep.exported_program.graph_module.graph,
                graph_signature=ep.exported_program.graph_signature,
                state_dict=ep.exported_program.state_dict,
                range_constraints=ep.exported_program.range_constraints,
                module_call_graph=ep.exported_program.module_call_graph,
                example_inputs=ep.exported_program.example_inputs,
                constants=ep.exported_program.constants,
                verifiers=[
                    get_aten_verifier(
                        config=config,
                    )
                ],
            ),
            False,
        )
    pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)

    new_ep = copy.deepcopy(ep).transform(*pre_op_replace_passes)
    if dialect == "ATEN":
        new_ep.exported_program = lift_constant_tensor_pass(new_ep.exported_program)

    new_gm = new_ep.exported_program.graph_module
    if config._use_edge_ops:
        new_gm_res = OpReplacePass()(new_gm)
        assert new_gm_res is not None
        new_gm = new_gm_res.graph_module
        if not config._skip_dim_order:
            new_gm_res = MemoryFormatOpsPass()(new_gm)
            assert new_gm_res is not None
            new_gm = new_gm_res.graph_module

    for p in post_op_replace_passes:
        new_gm_res = p(new_gm)
        assert new_gm_res is not None
        new_gm = new_gm_res.graph_module

    new_ep.exported_program = ExportedProgram(
        root=new_gm,
        graph=new_gm.graph,
        graph_signature=_get_updated_graph_signature(
            new_ep.exported_program.graph_signature, new_gm
        ),
        state_dict=new_ep.exported_program.state_dict,
        range_constraints=new_ep.exported_program.range_constraints,
        module_call_graph=new_ep.exported_program.module_call_graph,
        example_inputs=new_ep.exported_program.example_inputs,
        constants=new_ep.exported_program.constants,
        verifiers=[
            EXIREdgeDialectVerifier(
                edge_compile_config=config,
                class_only=True,
            )
        ],
    )
    new_ep.after_to_edge_passes = True
    return new_ep


def pre_memory_planning_passes(
    config: ExecutorchBackendConfig, name: Optional[str] = None
) -> List[PassType]:
    """
    Returns a list of passes to run before memory planning.
    The sym shape eval pass is selected based on the method name; if the name is not
    in the dict, the default pass is used.
    """
    # Handle symbolic shape eval pass
    if isinstance(config.sym_shape_eval_pass, dict):
        default_pass = ExecutorchBackendConfig().sym_shape_eval_pass
        if not name:
            sym_shape_eval_pass = default_pass
        # pyre-ignore: Undefined attribute [16]
        sym_shape_eval_pass = config.sym_shape_eval_pass.get(name, default_pass)
    elif isinstance(config.sym_shape_eval_pass, PassBase):
        sym_shape_eval_pass = config.sym_shape_eval_pass
    else:
        raise RuntimeError(
            f"sym_shape_eval_pass must be a dict or a PassBase, got {config.sym_shape_eval_pass}"
        )
    if config.remove_view_copy:
        return [
            NormalizeViewCopyBasePass(),
            dead_code_elimination_pass,
            ReplaceViewCopyWithViewPass(),
            sym_shape_eval_pass,
            config.to_out_var_pass,
        ]
    else:
        return [
            sym_shape_eval_pass,
            config.to_out_var_pass,
        ]


def edge_to_executorch_passes(
    config: ExecutorchBackendConfig, name: Optional[str] = None
) -> List[PassType]:
    """
    Returns a list of passes to lower from edge to executorch.
    The pre memory planning passes are selected based on the method name; if the
    name is not in the dict, the default passes are used.
    """
    passes: List[PassType] = [
        *config.passes,
        SpecPropPass(),
        # ExecuTorch backend ops are unable to handle unbacked symints. So after
        # this pass, passes cannot be Interpreter-based, because it will fail if
        # there exists an unbacked symint operation.
        EdgeToBackendOpsPass(),
        RemoveGraphAssertsPass(),
    ] + pre_memory_planning_passes(config, name)

    return passes


def _generate_edge_program(
    name: str,
    config: EdgeCompileConfig,
    program: ExportedProgram,
    ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
) -> ExportedProgram:
    if config._check_ir_validity:
        try:
            EXIRATenDialectVerifier(
                edge_compile_config=config,
                class_only=False,
                exception_list=ops_set_to_not_decompose,
            )(program.graph_module)
        except ExportError as e:
            logging.info(f"Input program {name} is not in ATen dialect.")
            raise e

    pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)

    passes = []
    passes.append(
        ReplaceViewOpsWithViewCopyOpsPass()
    )  # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
    passes.extend(pre_op_replace_passes)
    if config._use_edge_ops:
        passes.append(OpReplacePass())
        if not config._skip_dim_order:
            passes.append(MemoryFormatOpsPass())

    gm = program.graph_module
    for p in passes:
        gm_res = p(gm)
        assert gm_res is not None
        gm = gm_res.graph_module

    edge_program = ExportedProgram(
        root=gm,
        graph=gm.graph,
        graph_signature=_get_updated_graph_signature(program.graph_signature, gm),
        state_dict=program.state_dict,
        range_constraints=program.range_constraints,
        module_call_graph=program.module_call_graph,
        example_inputs=program.example_inputs,
        constants=program.constants,
        verifiers=[
            EXIREdgeDialectVerifier(
                edge_compile_config=config,
                class_only=True,
                exception_list=ops_set_to_not_decompose,
            )
        ],
    )
    # Lift the tensor constants created in ScalarToTensorPass
    edge_program = lift_constant_tensor_pass(edge_program)
    edge_program = _transform(edge_program, *post_op_replace_passes)

    return edge_program


def _replace_aten_ops_with_transformed_ops(
    name: str,
    program: ExportedProgram,
    partitioner,
):
    ops_to_not_decompose = set()
    partitioners = partitioner.get(name)
    if partitioners is None:
        return

    # Iterate through the graph and replace the aten ops with the corresponding
    # transformed ops.
    for partitioner in partitioners:
        ops_set_to_not_decompose, check_op_support = partitioner.ops_to_not_decompose(
            program
        )

        for op_aten in ops_set_to_not_decompose:
            _register_no_decomp_op(op_aten)

        for node in program.graph.nodes:
            is_op_supported = check_op_support(node) if check_op_support else True
            if (
                node.op == "call_function"
                and node.target in ops_set_to_not_decompose
                and is_op_supported
            ):
                ops_to_not_decompose.add(node.target)
                node.target = aten_op_to_transform_op[node.target]

        for _, submod, _ in get_control_flow_submodules(program.graph_module):
            for node in submod.graph.nodes:
                is_op_supported = check_op_support(node) if check_op_support else True
                if (
                    node.op == "call_function"
                    and node.target in ops_set_to_not_decompose
                    and is_op_supported
                ):
                    ops_to_not_decompose.add(node.target)
                    node.target = aten_op_to_transform_op[node.target]

    return ops_to_not_decompose


def _restore_transformed_ops_to_aten_ops(program: ExportedProgram):
    # Iterate through the graph and replace back the transformed ops with their
    # corresponding aten ops.
    for node in program.graph.nodes:
        if node.op == "call_function" and str(node.target) in transform_op_to_aten_op:
            node.target = transform_op_to_aten_op[str(node.target)]
    for _, submod, _ in get_control_flow_submodules(program.graph_module):
        for node in submod.graph.nodes:
            if (
                node.op == "call_function"
                and str(node.target) in transform_op_to_aten_op
            ):
                node.target = transform_op_to_aten_op[str(node.target)]


# Returns the op in the edge_no_decomp_namespace namespace for the aten
# op that is passed in.
def _get_transformed_op(op_aten):
    op_name = op_aten._schema.name.split("::")[1]
    overload_name = op_aten._schema.overload_name
    assert hasattr(
        torch.ops, edge_no_decomp_namespace
    ), f"Couldn't find {edge_no_decomp_namespace} in torch.ops. Please make sure the Library has been registered."
    op_namespace = getattr(torch.ops, edge_no_decomp_namespace)
    op = getattr(op_namespace, op_name)
    return getattr(op, overload_name)


# Registers the op in the edge_no_decomp_namespace namespace for the aten
# op that is passed in, if it is not already cached in the table.
def _register_no_decomp_op(op_aten):
    # Check if the op is already cached in the table. If not, then we need to
    # create a new op in the edge_no_decomp_namespace namespace.
    if aten_op_to_transform_op.get(op_aten) is None and isinstance(
        op_aten, torch._ops.OpOverload
    ):
        # Extract the schema from the aten op.
        op_schema = str(op_aten._schema).split("::")[1]
        op_name = op_aten._schema.name.split("::")[1]
        # Define an op in the edge_no_decomp_namespace namespace with the aten schema.
        lib.define(op_schema)
        # Define the implementation of the op in the edge_no_decomp_namespace namespace.
        # Important to note that the implementation of the op is the same as the aten op.

        overload_name = op_aten._schema.overload_name
        if overload_name != "":
            op_name += "." + overload_name
        lib.impl(op_name, op_aten, "CompositeExplicitAutograd")

        # Cache the aten op and transformed op in their corresponding tables for future use.
        aten_op_to_transform_op[op_aten] = _get_transformed_op(op_aten)
        transform_op_to_aten_op[str(aten_op_to_transform_op[op_aten])] = op_aten


def _sanity_check_graph_for_non_decomp_ops(
    name: str,
    program: ExportedProgram,
    ops_set_to_not_decompose,
    check_op_support,
    generate_error=False,
    partitioner_name=None,
):
    warning_str = f"Found {ops_set_to_not_decompose} in edge dialect program {name}."
    if partitioner_name is not None:
        warning_str += f" This op was registered by the partitioner {partitioner_name} to not be decomposed."

    # Check that the ops that were registered to not be decomposed are not present in the
    # graph anymore, as the transform passes and backends should have consumed them by now.
    ops_set_to_not_decompose = {
        aten_to_edge(op) for op in ops_set_to_not_decompose
    }.union(ops_set_to_not_decompose)
    for node in program.graph_module.graph.nodes:
        is_op_supported = check_op_support(node) if check_op_support else True
        if (
            node.op == "call_function" and node.target in ops_set_to_not_decompose
        ) and is_op_supported:
            if generate_error:
                raise RuntimeError(warning_str)
            else:
                logging.warning(warning_str)
    for _, submod, _ in get_control_flow_submodules(program.graph_module):
        for node in submod.graph.nodes:
            is_op_supported = check_op_support(node) if check_op_support else True
            if (
                node.op == "call_function" and node.target in ops_set_to_not_decompose
            ) and is_op_supported:
                if generate_error:
                    raise RuntimeError(warning_str)
                else:
                    logging.warning(warning_str)


def _gen_edge_manager_for_partitioners(
    partitioner: Dict[str, List[Partitioner]],
    aten_programs: Dict[str, ExportedProgram],
    config: EdgeCompileConfig,
    constant_methods: Optional[Dict[str, Any]],
) -> "EdgeProgramManager":
    """
    Generates an EdgeProgramManager for subsequent lowering to the
    partitioners specified by partitioner. The EdgeProgramManager is generated from
    aten_programs.

    Partitioners specify which nodes should not be decomposed from the original aten
    programs. This is done through two passes of run_decompositions:
    - The first pass preserves all aten targets specified by the partitioners, to protect
      them from nested decompositions.
    - The second pass uses the check_op fn provided by the partitioners to perform
      additional checks on nodes with preserved aten targets. Those that are truly
      preserved are replaced with transformed ops to keep them through the second pass
      of decompositions.
    """
    ops_set_to_not_decompose_by_program = {}
    edge_programs: Dict[str, ExportedProgram] = {}
    for name, program in aten_programs.items():
        if partitioner is not None:
            # preserve all ops listed by all partitioners first
            all_ops_no_decomp = set()
            for curr_partitioner in partitioner.get(name, []):
                curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
                all_ops_no_decomp |= set(curr_ops_no_decomp)

            table = _default_decomposition_table()

            for op in all_ops_no_decomp:
                table.pop(op, None)

            program = program.run_decompositions(table)
            # Among all the preserved aten ops, use the check_op_fn to do an additional
            # check on which ops need to be preserved and which ops need to be decomposed.
            # Those which are truly preserved will be replaced with transformed ops.
            ops_set_to_not_decompose_by_program[name] = (
                _replace_aten_ops_with_transformed_ops(name, program, partitioner) or []
            )
        program = program.run_decompositions(_default_decomposition_table())

        _restore_transformed_ops_to_aten_ops(program)

        edge_programs[name] = program

        edge_programs[name] = _generate_edge_program(
            name,
            config,
            program,
            list(ops_set_to_not_decompose_by_program.get(name, [])),
        )

    edge_manager = EdgeProgramManager(
        edge_programs,
        constant_methods,
        config,
        list(set().union(*ops_set_to_not_decompose_by_program.values())),
    )
    return edge_manager


def to_edge_transform_and_lower(
    programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
    transform_passes: Optional[
        Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
    ] = None,
    partitioner: Optional[
        Union[List[Partitioner], Dict[str, List[Partitioner]]]
    ] = None,
    constant_methods: Optional[Dict[str, Any]] = None,
    compile_config: Optional[EdgeCompileConfig] = None,
) -> "EdgeProgramManager":
    """
    :func:`to_edge_transform_and_lower` constructs an EdgeProgramManager from a set of
    exported programs in ATen dialect. It differs from to_edge in that it combines, into
    a single API, the conversion from ATen dialect to edge dialect, the running of
    transformation passes, and the lowering of the programs to their corresponding
    backends.

    This is useful for lowering to backends that register ops which they do not want
    decomposed and thus rely on matching against these non-decomposed ops. For such
    backends this is the *only* API that should be used to lower to the edge dialect.
    Using a combination of to_edge(...) and to_backend(...) will result in inconsistent
    or wrong behavior.

    This API is the primary recommended way to lower to the CPU-based XNNPACK backend.

    Args:
        programs: Can be a single ExportedProgram or a dictionary mapping function names
            to their corresponding ExportedPrograms. If only a single ExportedProgram is
            provided it will be assigned the name "forward".

        transform_passes: The passes can either be a list of passes, or a dictionary
            mapping method names to lists of passes. If it is just a list of passes, all
            methods in the given EdgeProgramManager will be transformed with the provided
            passes. If it is a dictionary, only method names specified in the dictionary
            will be transformed with their corresponding passes.

        partitioner: The partitioner can either be a Partitioner subclass instance, or a
            dictionary mapping method names to Partitioner subclass instances. If it is a
            Partitioner subclass, all programs in the given EdgeProgramManager will be
            lowered using the given partitioner. If it is a dictionary, only method names
            specified in the dictionary will be lowered with the given partitioner.

        constant_methods: An optional dictionary of method name to the constant value
            returned by that method in eager mode. Often used to store config information
            on Edge models.

        compile_config: An optional argument used to provide greater control over the
            transformation to edge dialect process.

    Returns:
        EdgeProgramManager
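
    Example (an illustrative sketch, not a full recipe; ``MyModel`` is a placeholder
    for a user-defined ``torch.nn.Module``, and the partitioner import path below is
    an assumption that depends on the backend in use):

        import torch
        from torch.export import export
        from executorch.exir import to_edge_transform_and_lower
        # Assumed import path for the XNNPACK partitioner; adjust for your backend.
        from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
            XnnpackPartitioner,
        )

        class MyModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.relu(x)

        aten_program = export(MyModel(), (torch.randn(1, 4),))
        edge_manager = to_edge_transform_and_lower(
            aten_program,
            partitioner=[XnnpackPartitioner()],
        )
        executorch_manager = edge_manager.to_executorch()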
    """
    assert not isinstance(constant_methods, EdgeCompileConfig)
    config = compile_config or EdgeCompileConfig()
    if not isinstance(programs, dict):
        aten_programs = {"forward": programs}
    else:
        aten_programs = programs

    if not isinstance(partitioner, dict) and partitioner is not None:
        partitioner = {name: partitioner for name in aten_programs.keys()}
    elif partitioner is None:
        partitioner = {name: [] for name in aten_programs.keys()}

    edge_manager = _gen_edge_manager_for_partitioners(
        partitioner, aten_programs, config, constant_methods
    )

    if transform_passes is not None:
        edge_manager = edge_manager.transform(transform_passes)

    if partitioner is not None:
        for name, partitioner_list in partitioner.items():
            for curr_partitioner in partitioner_list:
                edge_manager = edge_manager.to_backend({name: curr_partitioner})

    for name, program in edge_manager._edge_programs.items():
        ops_set_to_not_decompose: Set[torch._ops.OpOverload] = set()
        partitioners = partitioner.get(name, [])
        for curr_partitioner in partitioners:
            curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
                program
            )
            ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
            _sanity_check_graph_for_non_decomp_ops(
                name,
                program,
                ops_set_to_not_decompose,
                check_op_support,
                partitioner_name=curr_partitioner.__class__.__name__,
                generate_error=True,
            )

        if config._check_ir_validity:
            EXIREdgeDialectVerifier(
                edge_compile_config=config,
                class_only=True,
                exception_list=list(ops_set_to_not_decompose),
            )()(program.graph_module)

    return edge_manager


@experimental(
    """
    This is an experimental API which extends to_edge by preserving specified ops from decomposition.
    This function will be combined with to_edge in the future.
    """
)
def to_edge_with_preserved_ops(
    programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
    constant_methods: Optional[Dict[str, Any]] = None,
    compile_config: Optional[EdgeCompileConfig] = None,
    preserve_ops: Tuple[torch._ops.OpOverload, ...] = (),
) -> "EdgeProgramManager":
    """
    :func:`to_edge_with_preserved_ops` constructs an EdgeProgramManager from a set of
    exported programs in ATen dialect, keeping the ops listed in `preserve_ops` out of
    decomposition. Upon construction those programs are transformed into edge dialect.

    Args:
        programs: Can be a single ExportedProgram or a dictionary mapping function names to their corresponding ExportedPrograms. If only a single ExportedProgram is provided it will be assigned the name "forward".
        constant_methods: An optional dictionary of method name to the constant value returned by that method in eager mode. Often used to store config information on Edge models.
        compile_config: An optional argument used to provide greater control over the transformation to edge dialect process.
        preserve_ops: An argument used to specify ops that should not be decomposed.

    Returns:
        EdgeProgramManager
    """
    assert not isinstance(constant_methods, EdgeCompileConfig)
    config = compile_config or EdgeCompileConfig()
    if not isinstance(programs, dict):
        aten_programs = {"forward": programs}
    else:
        aten_programs = programs

    edge_programs: Dict[str, ExportedProgram] = {}

    for name, program in aten_programs.items():
        # Decompose to Core ATen
        table = _default_decomposition_table()
        for op in preserve_ops:
            table.pop(op, None)
        program = program.run_decompositions(table)
        edge_programs[name] = _generate_edge_program(
            name, config, program, list(preserve_ops)
        )

    return EdgeProgramManager(
        edge_programs, constant_methods, config, list(preserve_ops)
    )


def to_edge(
    programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
    constant_methods: Optional[Dict[str, Any]] = None,
    compile_config: Optional[EdgeCompileConfig] = None,
) -> "EdgeProgramManager":
    """
    :func:`to_edge` constructs an EdgeProgramManager from a set of exported programs in
    ATen dialect. Upon construction those programs are transformed into edge dialect.

    Args:
        programs: Can be a single ExportedProgram or a dictionary mapping function names to their corresponding ExportedPrograms. If only a single ExportedProgram is provided it will be assigned the name "forward".

        constant_methods: An optional dictionary of method name to the constant value returned by that method in eager mode. Often used to store config information on Edge models.

        compile_config: An optional argument used to provide greater control over the transformation to edge dialect process.

    Returns:
        EdgeProgramManager
    """
    assert not isinstance(constant_methods, EdgeCompileConfig)
    config = compile_config or EdgeCompileConfig()
    if not isinstance(programs, dict):
        aten_programs = {"forward": programs}
    else:
        aten_programs = programs

    edge_programs: Dict[str, ExportedProgram] = {}

    for name, program in aten_programs.items():
        # Decompose to Core ATen
        program = program.run_decompositions(_default_decomposition_table())
        edge_programs[name] = _generate_edge_program(name, config, program)

    return EdgeProgramManager(edge_programs, constant_methods, config)


class EdgeProgramManager:
    """
    Package of one or more `ExportedPrograms` in Edge dialect. Designed to simplify
    lowering to ExecuTorch. See: https://pytorch.org/executorch/stable/ir-exir.html

    Allows easy application of transforms across a collection of exported programs,
    including the delegation of subgraphs.

    Manages the second link in the lowering chain of ATen -> Edge -> ExecuTorch.
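
    A typical flow, as an illustrative sketch (``aten_program`` is assumed to be an
    already-exported ``ExportedProgram`` and ``SomePartitioner`` stands in for a
    backend-provided partitioner):

        edge_manager = to_edge(aten_program)
        edge_manager = edge_manager.to_backend(SomePartitioner())  # optional delegation
        executorch_manager = edge_manager.to_executorch()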
    """

    def __init__(
        self,
        edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
        constant_methods: Optional[Dict[str, Any]] = None,
        compile_config: Optional[EdgeCompileConfig] = None,
        ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
    ):
        """
        Should not be called directly by users. Users should use :func:`to_edge` instead.

        Constructs an EdgeProgramManager from an existing set of exported programs in edge dialect.
        """
        self.compile_config = compile_config or EdgeCompileConfig()
        if not isinstance(edge_programs, dict):
            edge_programs = {"forward": edge_programs}

        for name, program in edge_programs.items():
            try:
                EXIREdgeDialectVerifier(
                    edge_compile_config=self.compile_config,
                    exception_list=ops_set_to_not_decompose,
                )(program.graph_module)
            except ExportError as e:
                logging.info(f"Input program {name} is not in edge dialect.")
                raise e

        self._edge_programs: Dict[str, ExportedProgram] = edge_programs
        self._config_methods = constant_methods

    @property
    def methods(self) -> Set[str]:
        """
        Returns the set of methods in this EdgeProgramManager.
        """
        return set(self._edge_programs.keys())

    @property
    def config_methods(self) -> Set[str]:
        """
        Returns the set of config methods in this EdgeProgramManager.
        """
        return set(self._config_methods.keys()) if self._config_methods else set()

    def exported_program(self, method_name: str = "forward") -> ExportedProgram:
        """
        Returns the ExportedProgram specified by 'method_name'.
        """
        return self._edge_programs[method_name]

    def transform(
        self,
        passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]],
        compile_config: Optional[EdgeCompileConfig] = None,
    ) -> "EdgeProgramManager":
        """
        Transforms the program according to the provided passes.

        Args:
            passes: The passes can either be a list of passes, or a
                dictionary mapping method names to lists of passes. If it is
                just a list of passes, all methods in the given EdgeProgramManager
                will be transformed with the provided passes. If it is a
                dictionary, only method names specified in the dictionary will be
                transformed with their corresponding passes.
            compile_config: Compile config to use for verifying the correctness of the
                model graph after each pass. If not specified, the compile config of the
                calling EdgeProgramManager will be used. It will also be used as the
                compile config of the returned EdgeProgramManager.

        Returns:
            EdgeProgramManager: A copy of the calling EdgeProgramManager with the
                transformations applied.
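
        Example (illustrative sketch; ``MyPass`` stands in for any user-defined pass,
        e.g. an ``ExportPass`` subclass):

            # Apply the pass to every method:
            edge_manager = edge_manager.transform([MyPass()])
            # Or apply it only to the "forward" method:
            edge_manager = edge_manager.transform({"forward": [MyPass()]})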
        """
        compile_config = compile_config or self.compile_config
        new_programs: Dict[str, ExportedProgram] = {}
        if isinstance(passes, dict):
            for name, program in self._edge_programs.items():
                if name in passes.keys():
                    new_programs[name] = _transform(program, *passes[name])
                    EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
                        new_programs[name].graph_module
                    )
                else:
                    new_programs[name] = copy.deepcopy(program)

        else:  # apply passes to every method
            for name, program in self._edge_programs.items():
                new_programs[name] = _transform(program, *passes)
                EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
                    new_programs[name].graph_module
                )

        return EdgeProgramManager(
            new_programs, copy.deepcopy(self._config_methods), compile_config
        )

    def to_backend(
        self, partitioner: Union[Partitioner, Dict[str, Partitioner]]
    ) -> "EdgeProgramManager":
        """
        Returns a semantically-equivalent program to the one given as input,
        but with portions of each program in the EdgeProgramManager targeted
        for delegation as determined by the partitioner.

        Args:
            partitioner: The partitioner can either be a Partitioner subclass instance, or a
                dictionary mapping method names to Partitioner subclass instances. If it is a
                Partitioner subclass, all programs in the given EdgeProgramManager
                will be lowered using the given partitioner. If it is a
                dictionary, only method names specified in the dictionary will be
                lowered with the given partitioner.

                The Partitioner subclass instance is in charge of tagging portions of the
                input program for delegation. A valid partitioner must return a
                PartitionResult including valid partition_tags: Dict[str, DelegationSpec],
                where each key is a tag name; the nodes with the same tag will be fused
                into one subgraph and delegated to the backend specified in the
                delegation spec.

        Returns:
            EdgeProgramManager: A copy of the calling EdgeProgramManager with the
                specified subgraphs lowered.
        """
        new_edge_programs: Dict[str, ExportedProgram] = {}
        if isinstance(partitioner, dict):
            for name, program in self._edge_programs.items():
                if name in partitioner.keys():
                    new_edge_programs[name] = to_backend(program, partitioner[name])
                else:
                    new_edge_programs[name] = program

        else:  # apply partitioner to every method
            for name, program in self._edge_programs.items():
                new_edge_programs[name] = to_backend(program, partitioner)

        config = EdgeCompileConfig(_check_ir_validity=False)
        return EdgeProgramManager(
            new_edge_programs, copy.deepcopy(self._config_methods), config
        )

    def to_executorch(
        self,
        config: Optional[ExecutorchBackendConfig] = None,
    ) -> "ExecutorchProgramManager":
        """
        Transforms the program to the ExecuTorch backend.

        Args:
            config: An optional argument used to provide greater control over
                the transformation to the ExecuTorch backend.

        Returns:
            ExecutorchProgramManager: A manager representing the state of the EdgeProgramManager
                after it has been transformed to the ExecuTorch backend.
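
        Example (illustrative sketch; the default ``ExecutorchBackendConfig()`` is
        usually sufficient and is shown only to illustrate the argument):

            executorch_manager = edge_manager.to_executorch(ExecutorchBackendConfig())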
        """
        config = config if config else ExecutorchBackendConfig()

        execution_programs: Dict[str, ExportedProgram] = {}
        for name, program in self._edge_programs.items():
            program = weights_to_outputs_pass(program)
            program = unsafe_remove_auto_functionalized_pass(program)
            gm, new_signature = insert_write_back_for_buffers_pass(program)
            new_gm = program.graph_module
            for p in edge_to_executorch_passes(config, name):
                new_gm_res = p(new_gm)
                assert new_gm_res is not None
                new_gm = new_gm_res.graph_module
                if isinstance(p, SpecPropPass):
                    # Note that this is a hacky way to get around the fact that
                    # placeholder nodes corresponding to the parameters of the graph
                    # module shall not participate in memory planning. It increases
                    # runtime memory footprint.
                    # The proper way would be to have ExportPass work with
                    # ExportedProgram instead of GraphModule, because ExportPass should
                    # work on top of the export artifact of torch.export, which is
                    # ExportedProgram. Working with GraphModule does not provide all
                    # the information contained in the ExportedProgram.
                    # TODO(who?)
                    p.update_placeholder_tensor_specs(program, new_gm)

            if isinstance(config.memory_planning_pass, dict):
                memory_planning_pass = config.memory_planning_pass.get(
                    name, ExecutorchBackendConfig().memory_planning_pass
                )
            else:
                memory_planning_pass = config.memory_planning_pass
            # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work
            if hasattr(memory_planning_pass, "run"):
                new_gm_res = memory_planning_pass.run(  # pyre-ignore[16]
                    new_gm, new_signature
                )
            else:
                new_gm_res = memory_planning_pass(new_gm)  # pyre-ignore[29]
            assert new_gm_res is not None
            new_gm = new_gm_res.graph_module

            _copy_module(program.graph_module, new_gm)
            execution_programs[name] = program

        return ExecutorchProgramManager(
            execution_programs, self._config_methods, config
        )


class ExecutorchProgramManager:
    """
    Package of one or more `ExportedPrograms` in Execution dialect. Designed to simplify
    lowering to ExecuTorch. See: https://pytorch.org/executorch/stable/ir-exir.html

    When the ExecutorchProgramManager is constructed, the ExportedPrograms in execution
    dialect are used to form the ExecuTorch binary (in a process called emission) and
    then serialized to a buffer.

    Manages the final link in the lowering chain of ATen -> Edge -> ExecuTorch.
    """

    def __init__(
        self,
        execution_programs: Dict[str, ExportedProgram],
        config_methods: Optional[Dict[str, Any]] = None,
        backend_config: Optional[ExecutorchBackendConfig] = None,
    ):
        """
        End users should not call this constructor directly. Instead, they should use
        :func:`to_executorch` to construct an ExecutorchProgramManager.

        Constructs an ExecutorchProgramManager from a set of exported programs in
        execution dialect.

        Args:
            execution_programs: A dictionary of method name to the corresponding
                ExportedProgram.

            config_methods: A dictionary of method name to the config value returned
                by that method in eager mode.

            backend_config: An optional argument used to provide greater control over
                the emission and serialization.
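
        Once constructed (normally via ``to_executorch``), the manager can be inspected
        and serialized, for example (illustrative sketch; the output path is arbitrary):

            executorch_manager.dump_executorch_program(verbose=False)
            with open("model.pte", "wb") as f:
                executorch_manager.write_to_file(f)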
        """
        # Set up methods
        self._execution_programs: Dict[str, ExportedProgram] = execution_programs
        self._config_methods: Optional[Dict[str, Any]] = config_methods

        backend_config = backend_config or ExecutorchBackendConfig()

        # Emit methods
        self._emitter_output: EmitterOutput = emit_program(
            self._execution_programs,
            backend_config.emit_stacktrace,
            self._config_methods,
        )

        # Serialize emitter output, ready to be written to a file.
        self._pte_data: Cord = _serialize_pte_binary(
            program=self._emitter_output.program,
            mutable_data=self._emitter_output.mutable_data,
            extract_delegate_segments=backend_config.extract_delegate_segments,
            segment_alignment=backend_config.segment_alignment,
            constant_tensor_alignment=backend_config.constant_tensor_alignment,
            delegate_alignment=backend_config.delegate_alignment,
        )
        self._buffer: Optional[bytes] = None

    @property
    def methods(self) -> Set[str]:
        """
        Returns the set of methods in this ExecutorchProgramManager.
        """
        return set(self._execution_programs.keys())

    @property
    def config_methods(self) -> Set[str]:
        """
        Returns the set of config methods in this ExecutorchProgramManager.
        """
        return set(self._config_methods.keys()) if self._config_methods else set()

    def exported_program(self, method_name: str = "forward") -> ExportedProgram:
        """
        Returns the ExportedProgram specified by 'method_name'.
        """
        return self._execution_programs[method_name]

    def dump_executorch_program(
        self, verbose: bool = False, out: Optional[TextIO] = None
    ) -> None:
        """
        Prints the ExecuTorch binary in a human readable format.

        Args:
            verbose (bool):
                If False prints the binary in a condensed format.
                If True prints the binary 1-1 with the specification in the schema.
            out:
                If None, prints to stdout.
                If non-None, writes the string to that stream object. It can be
                a file object, a StringIO object, or any other TextIO subclass.
        """
        if verbose:
            pretty_print(self._emitter_output.program, out=out)
        else:
            print_program(self._emitter_output.program, out=out)

    @property
    def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
        return self._emitter_output.debug_handle_map

    @property
    def delegate_map(
        self,
    ) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
        return self._emitter_output.method_to_delegate_debug_id_map

    @property
    def executorch_program(self) -> Program:
        """
        Returns the object that represents the ExecuTorch binary before serialization.
        """
        return self._emitter_output.program

    @property
    def buffer(self) -> bytes:
        """Returns the serialized ExecuTorch binary as a byte string.

        Note that the call to `buffer` may allocate a very large amount of
        contiguous memory, depending on the model size. If writing to a file,
        use `write_to_file` which won't incur additional copies.
        """
        # TODO(T181494963): update pybinding to remove buffer cache, which can consume large
        # amounts of memory longer than necessary.
        if self._buffer is None:
            self._buffer = bytes(self._pte_data)
        return self._buffer

    def write_to_file(self, open_file: io.BufferedIOBase) -> None:
        """
        Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over
        `buffer`, as it writes to file without copying into a contiguous block of memory first,
        reducing the peak memory usage.
        """
        self._pte_data.write_to_file(open_file)