# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from types import MappingProxyType

import torch

from executorch import exir
from executorch.exir.backend.backend_details import CompileSpec, ExportedProgram
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_pattern_op_partitions,
)
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
    AnyOperatorSupport,
)
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
    ExecutorBackend,
)
from executorch.exir.backend.test.op_partitioner_demo import (
    AddAttributePartitionerDemo,
    AllNodesPartitionerDemo,
)
from executorch.exir.backend.utils import get_delegates, tag_constant_data
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.tests.models import MLP
from executorch.extension.pybindings.portable_lib import (  # @manual=//executorch/extension/pybindings:portable_lib
    _load_for_executorch_from_buffer,
)
from executorch.extension.pytree import tree_flatten
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export import export, export_for_training
from torch.fx.passes.operator_support import any_chain


class TestPartitioner(unittest.TestCase):
    def test_partitioner_with_spec(self):
        # Create a custom partitioner with a spec and check that the spec can be
        # read back but not mutated.
        class PartitionerWithSpec(Partitioner):
            def __init__(self, spec) -> None:
                super().__init__(spec)
                self.op_support = any_chain(AnyOperatorSupport())
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                partition_list = generate_pattern_op_partitions(
                    edge_exported_program.graph_module, op_support=self.op_support
                )
                for partition in partition_list:
                    for node in partition.nodes:
                        delegation_tag = f"tag{partition.id}"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        mlp = MLP()
        example_inputs = mlp.get_random_inputs()
        model = export_for_training(mlp, example_inputs).module()
        aten = export(model, example_inputs)
        spec_key = "path"
        spec_value = "/a/b/c/d"
        spec = MappingProxyType({spec_key: spec_value})
        my_partitioner = PartitionerWithSpec(spec)
        edge = exir.to_edge(aten).to_backend(my_partitioner)

        lowered_module_nodes = get_delegates(edge.exported_program().graph)
        self.assertEqual(len(lowered_module_nodes), 1)

        # Check that the lowered module has the correct compile spec
        for lower_module_node in lowered_module_nodes:
            lower_module = getattr(
                edge.exported_program().graph_module, lower_module_node.name
            )
            self.assertEqual(lower_module.compile_specs[0].key, spec_key)
            self.assertEqual(lower_module.compile_specs[0].value, spec_value)

        # Check that the custom partitioner has the correct spec
        self.assertEqual(my_partitioner.spec[spec_key], spec_value)

        with self.assertRaisesRegex(
            TypeError,
            "'mappingproxy' object does not support item assignment",
        ):
            my_partitioner.spec[spec_key] = "new_value"

        with self.assertRaisesRegex(
            AttributeError,
            "can't set attribute 'spec'",
        ):
            my_partitioner.spec = {"new_key": "new_value"}

    def test_bad_partitioner_tagged_output(self):
        # Create a bad partitioner that tags the output node, which is not allowed.
        class PartitionerTagOutput(Partitioner):
            def __init__(self) -> None:
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "output":
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        mlp = MLP()
        example_inputs = mlp.get_random_inputs()
        model = export_for_training(mlp, example_inputs).module()
        aten = export(model, example_inputs)
        edge = exir.to_edge(aten)

        with self.assertRaisesRegex(
            RuntimeError,
            "output node output should not be tagged",
        ):
            _ = edge.to_backend(PartitionerTagOutput())

    def test_bad_partitioner_tagged_model_input(self):
        # Create a bad partitioner that tags an input which is neither a param nor
        # a buffer, which is not allowed.
        class PartitionerTagInput(Partitioner):
            def __init__(self) -> None:
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "placeholder":
                        if not is_param(edge_exported_program, node) and not is_buffer(
                            edge_exported_program, node
                        ):
                            delegation_tag = "tag_" + str(node.meta["debug_handle"])
                            node.meta["delegation_tag"] = delegation_tag
                            partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        mlp = MLP()
        example_inputs = mlp.get_random_inputs()
        model = export_for_training(mlp, example_inputs).module()
        edge = exir.to_edge(export(model, example_inputs))

        with self.assertRaisesRegex(
            RuntimeError,
            "placeholder node for non-params, non-buffer, and non-tensor constants should not be tagged",
        ):
            _ = edge.to_backend(PartitionerTagInput())

    class AddConst(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.const1 = torch.ones(2, 2)
            self.register_buffer("const2", torch.ones(2, 2), persistent=False)
            self.register_parameter("const3", torch.nn.Parameter(torch.ones(2, 2)))

        def forward(self, x):
            return x + self.const1 + self.const2 + self.const3

    def test_partitioner_not_tag_data(self):
        """
        We test here that when partitioners do not explicitly tag constant data nodes,
        the partitioned ExportedProgram will not own the data.
        Instead, the owning program retains the constant data and feeds it as
        inputs to the partitioned program.
        """

        class PartitionerNoTagData(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "call_function" and node.target in [
                        exir_ops.edge.aten.add.Tensor
                    ]:
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module()
        edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
        delegated = edge.to_backend(PartitionerNoTagData())

        # Check that the owning program still owns all constant data
        owning_program = delegated.exported_program()
        self.assertEqual(
            len(owning_program.state_dict) + len(owning_program.constants), 3
        )
        self.assertEqual(
            len(owning_program.graph_signature.buffers)
            + len(owning_program.graph_signature.lifted_tensor_constants),
            2,
        )
        self.assertEqual(len(owning_program.graph_signature.parameters), 1)

        # Check that the lowered module's exported program does not own any constant data
        lowered_module_nodes = get_delegates(delegated.exported_program().graph)
        self.assertEqual(len(lowered_module_nodes), 1)
        lowered_module_node = lowered_module_nodes[0]

        # Get the call delegate node
        call_delegate_node = list(lowered_module_node.users.keys())[0]
        # 5 args to the lowered module are: delegated_payload, x, const1, const2, const3
        self.assertEqual(len(call_delegate_node.args), 5)
        lower_module = getattr(
            delegated.exported_program().graph_module, lowered_module_node.name
        )
        delegated_ep = lower_module.original_module
        self.assertEqual(len(delegated_ep.state_dict), 0)
        self.assertEqual(len(delegated_ep.graph_signature.buffers), 0)
        self.assertEqual(len(delegated_ep.graph_signature.parameters), 0)

        # Check that the exported program is still runnable
        output = delegated.exported_program().module()(torch.ones(2, 2))
        reference_output = model(torch.ones(2, 2))
        self.assertTrue(torch.allclose(reference_output, output))

    def test_partitioner_tag_data(self):
        """
        We test here that when partitioners explicitly tag constant data nodes,
        then the partitioned ExportedProgram will own the data, and the data will
        be removed from the owning program.
""" class PartitionerTagData(Partitioner): def __init__(self): super().__init__() self.delegation_spec = DelegationSpec( ExecutorBackend.__name__, [CompileSpec(key, value) for key, value in self.spec.items()], ) def partition( self, edge_exported_program: ExportedProgram ) -> PartitionResult: partition_tags = {} for node in edge_exported_program.graph.nodes: if node.op == "call_function" and node.target in [ exir_ops.edge.aten.add.Tensor ]: delegation_tag = "tag0" node.meta["delegation_tag"] = delegation_tag partition_tags[delegation_tag] = self.delegation_spec if node.op == "placeholder" and ( is_param(edge_exported_program, node) or is_buffer(edge_exported_program, node) or is_lifted_tensor_constant(edge_exported_program, node) ): delegation_tag = "tag0" node.meta["delegation_tag"] = delegation_tag partition_tags[delegation_tag] = self.delegation_spec return PartitionResult( tagged_exported_program=edge_exported_program, partition_tags=partition_tags, ) model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module() edge = exir.to_edge(export(model, (torch.ones(2, 2),))) delegated = edge.to_backend(PartitionerTagData()) # Check Owning Program still owns all constant data owning_program = delegated.exported_program() self.assertEqual(len(owning_program.state_dict), 0) self.assertEqual(len(owning_program.graph_signature.buffers), 0) self.assertEqual(len(owning_program.graph_signature.parameters), 0) # Check Lowered Module Exported Program does not have any constant data lowered_module_nodes = get_delegates(delegated.exported_program().graph) self.assertEqual(len(lowered_module_nodes), 1) lowered_module_node = lowered_module_nodes[0] # get call delegate node call_delegate_node = list(lowered_module_node.users.keys())[0] # 5 args to lowered module are: delegated_payload, x self.assertEqual(len(call_delegate_node.args), 2) lower_module = getattr( delegated.exported_program().graph_module, lowered_module_node.name ) delegated_ep = lower_module.original_module self.assertEqual(len(delegated_ep.state_dict) + len(delegated_ep.constants), 3) self.assertEqual( len(delegated_ep.graph_signature.buffers) + len(delegated_ep.graph_signature.lifted_tensor_constants), 2, ) self.assertEqual(len(delegated_ep.graph_signature.parameters), 1) # check exported program is still runnable output = delegated.exported_program().module()(torch.ones(2, 2)) reference_output = model(torch.ones(2, 2)) self.assertTrue(torch.allclose(reference_output, output)) def test_partitioner_tag_only_params(self): """ We test here that when partitioners explicitly tag constant data nodes, then the partitioned ExportedProgram will own the data, and the data will be removed from the owning program. 
""" class PartitionerTagData(Partitioner): def __init__(self): super().__init__() self.delegation_spec = DelegationSpec( ExecutorBackend.__name__, [CompileSpec(key, value) for key, value in self.spec.items()], ) def partition( self, edge_exported_program: ExportedProgram ) -> PartitionResult: partition_tags = {} for node in edge_exported_program.graph.nodes: if node.op == "call_function" and node.target in [ exir_ops.edge.aten.add.Tensor ]: delegation_tag = "tag0" node.meta["delegation_tag"] = delegation_tag partition_tags[delegation_tag] = self.delegation_spec if node.op == "placeholder" and ( is_param(edge_exported_program, node) ): delegation_tag = "tag0" node.meta["delegation_tag"] = delegation_tag partition_tags[delegation_tag] = self.delegation_spec return PartitionResult( tagged_exported_program=edge_exported_program, partition_tags=partition_tags, ) model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module() edge = exir.to_edge(export(model, (torch.ones(2, 2),))) delegated = edge.to_backend(PartitionerTagData()) # Check Owning Program still owns only buffers owning_program = delegated.exported_program() self.assertEqual( len(owning_program.state_dict) + len(owning_program.constants), 2 ) self.assertEqual( len(owning_program.graph_signature.buffers) + len(owning_program.graph_signature.lifted_tensor_constants), 2, ) self.assertEqual(len(owning_program.graph_signature.parameters), 0) # Check Lowered Module Exported Program does not own any buffers lowered_module_nodes = get_delegates(delegated.exported_program().graph) self.assertEqual(len(lowered_module_nodes), 1) lowered_module_node = lowered_module_nodes[0] # get call delegate node call_delegate_node = list(lowered_module_node.users.keys())[0] # 5 args to lowered module are: delegated_payload, x, buffer1, buffer2 self.assertEqual(len(call_delegate_node.args), 4) lower_module = getattr( delegated.exported_program().graph_module, lowered_module_node.name ) delegated_ep = lower_module.original_module self.assertEqual(len(delegated_ep.state_dict), 1) self.assertEqual(len(delegated_ep.graph_signature.buffers), 0) self.assertEqual(len(delegated_ep.graph_signature.parameters), 1) # check exported program is still runnable output = delegated.exported_program().module()(torch.ones(2, 2)) reference_output = model(torch.ones(2, 2)) self.assertTrue(torch.allclose(reference_output, output)) def test_partitioner_splits_constant_data(self): """ We test that we throw an error when constant data users are split between different delegated payloads or owning program. 
""" class ReuseConstData(torch.nn.Module): def __init__(self): super().__init__() self.const = torch.ones(2, 2) def forward(self, x): y = x + self.const z = x - self.const return y, z class PartitionerTagData(Partitioner): def __init__(self): super().__init__() self.delegation_spec = DelegationSpec( ExecutorBackend.__name__, [CompileSpec(key, value) for key, value in self.spec.items()], ) def partition( self, edge_exported_program: ExportedProgram ) -> PartitionResult: partition_tags = {} for node in edge_exported_program.graph.nodes: if node.op == "call_function" and node.target in [ exir_ops.edge.aten.add.Tensor ]: delegation_tag = "tag0" node.meta["delegation_tag"] = delegation_tag partition_tags[delegation_tag] = self.delegation_spec if node.op == "placeholder" and ( is_param(edge_exported_program, node) or is_buffer(edge_exported_program, node) ): delegation_tag = "tag0" node.meta["delegation_tag"] = delegation_tag partition_tags[delegation_tag] = self.delegation_spec return PartitionResult( tagged_exported_program=edge_exported_program, partition_tags=partition_tags, ) inputs = (torch.ones(2, 2),) model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module() edge = exir.to_edge(export(model, (torch.ones(2, 2),))) exec_prog = edge.to_backend(PartitionerTagData()).to_executorch() executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer) inputs_flattened, _ = tree_flatten(inputs) # Send the input from server executor to client executor, and receive the result from client executor _ = executorch_module.run_method("forward", inputs) def test_partitioner_alert_split_constant_data(self): """ We test that we throw an error when constant data users are split between different delegated payloads or owning program. """ class ReuseConstData(torch.nn.Module): def __init__(self): super().__init__() self.const = torch.ones(2, 2) def forward(self, x): y = x + self.const z = x - self.const return y, z class PartitionerTagData(Partitioner): def __init__(self): super().__init__() self.delegation_spec = DelegationSpec( ExecutorBackend.__name__, [CompileSpec(key, value) for key, value in self.spec.items()], ) def partition( self, edge_exported_program: ExportedProgram ) -> PartitionResult: partition_tags = {} for node in edge_exported_program.graph.nodes: if node.op == "call_function" and node.target in [ exir_ops.edge.aten.add.Tensor ]: delegation_tag = "tag0" node.meta["delegation_tag"] = delegation_tag partition_tags[delegation_tag] = self.delegation_spec if node.op == "placeholder" and ( is_param(edge_exported_program, node) or is_buffer(edge_exported_program, node) or is_lifted_tensor_constant(edge_exported_program, node) ): delegation_tag = "tag0" node.meta["delegation_tag"] = delegation_tag node.meta["no_copy"] = True partition_tags[delegation_tag] = self.delegation_spec return PartitionResult( tagged_exported_program=edge_exported_program, partition_tags=partition_tags, ) model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module() edge = exir.to_edge(export(model, (torch.ones(2, 2),))) with self.assertRaises(RuntimeError) as error: _ = edge.to_backend(PartitionerTagData()) self.assertTrue( "is tagged with (tag0) but has user (aten_sub_tensor) which has tag (None)" in str(error.exception), ) def test_not_delegate_mutable_buffers(self) -> None: """ A test case to check the mutated buffer is not delegated. We'll need to add a test case to consider when the delegate can consume the mutable buffer. 
""" class MutableStateModule(torch.nn.Module): def __init__(self): super().__init__() self.register_buffer("my_state", torch.zeros(1)) def forward(self, x): y = x + self.my_state self.my_state.add_(1) return y edge = exir.to_edge( torch.export.export( MutableStateModule(), (torch.zeros(1),), ) ) self.assertGreater( len(edge.exported_program().graph_signature.buffers_to_mutate), 0, "The test case should at leaset one mutable buffer", ) class PartitionerTagData(Partitioner): def __init__(self): super().__init__() self.delegation_spec = DelegationSpec( ExecutorBackend.__name__, [CompileSpec(key, value) for key, value in self.spec.items()], ) def partition( self, edge_exported_program: ExportedProgram ) -> PartitionResult: partition_tags = {} for node in edge_exported_program.graph.nodes: if node.op == "call_function" and node.target in [ exir_ops.edge.aten.add.Tensor ]: delegation_tag = "tag0" node.meta["delegation_tag"] = delegation_tag partition_tags[delegation_tag] = self.delegation_spec tag_constant_data(edge_exported_program) return PartitionResult( tagged_exported_program=edge_exported_program, partition_tags=partition_tags, ) # Check the edge program inital buffers_to_mutate mutate_op = "aten_add_tensor_1" self.assertEqual( edge.exported_program().graph_signature.buffers_to_mutate[mutate_op], "my_state", ) edge = edge.to_backend(PartitionerTagData()) # After to_backend, add is delegated and is no longer in buffers_to_mutate. self.assertNotIn( mutate_op, edge.exported_program().graph_signature.buffers_to_mutate, ) mutate_op = "getitem_1" # Ensure the mutated buffer is not delegated, and the new mutate node is getitem (from call_delegate) self.assertEqual( edge.exported_program().graph_signature.buffers_to_mutate[mutate_op], "my_state", ) # Check the copy_ node is inserted edge = edge.to_executorch() copy_node = [ node for node in edge.exported_program().graph.nodes if node.op == "call_function" and node.target == torch.ops.aten.copy_.default ] self.assertEqual(len(copy_node), 1) def test_buffer_mutation1(self): class TestModule(torch.nn.Module): def __init__(self): super().__init__() self.register_buffer("b", torch.ones(3, 3)) def forward(self, x): self.b.add_(x) return x + self.b model_inputs = (torch.ones(3, 3),) orig_res = TestModule()(*model_inputs) edge_program = exir.to_edge(torch.export.export(TestModule(), model_inputs)) lowered = edge_program.to_backend(AddAttributePartitionerDemo()) self.assertTrue( torch.allclose(lowered.exported_program().module()(*model_inputs), orig_res) ) self.assertEqual( len(lowered.exported_program().graph_signature.buffers_to_mutate), 0, ) lowered_module_nodes = get_delegates(lowered.exported_program().graph) self.assertEqual(len(lowered_module_nodes), 1) lowered_module_node = lowered_module_nodes[0] # get call delegate node call_delegate_node = list(lowered_module_node.users.keys())[0] self.assertEqual(len(call_delegate_node.args), 2) lower_module = getattr( lowered.exported_program().graph_module, lowered_module_node.name ) delegated_ep = lower_module.original_module self.assertEqual(len(delegated_ep.state_dict), 1) self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1) self.assertEqual(len(delegated_ep.graph_signature.buffers), 1) def test_buffer_mutation_llama_repro(self): SHAPE = (2, 3) class Model(torch.nn.Module): def __init__(self): super().__init__() self.register_buffer("cache", torch.zeros(SHAPE, dtype=torch.float32)) def forward(self, q, k_val, input_pos): q_T = q.transpose(0, 1) k = torch.ops.aten.index_put_(self.cache, 
                attn = k.mm(q_T)
                return attn

        q = torch.rand(1, 3)
        k = torch.rand(1, 3)
        example_inputs = (q, k, torch.tensor([1, 1]))
        model = Model()
        model.eval()
        exir_program_aten = torch.export.export(model, example_inputs)
        exir_program_aten.module()(*example_inputs)
        edge_program_manager = exir.to_edge(exir_program_aten)
        lowered = edge_program_manager.to_backend(AllNodesPartitionerDemo())

        self.assertEqual(
            len(lowered.exported_program().graph_signature.buffers_to_mutate),
            0,
        )
        lowered_module_nodes = get_delegates(lowered.exported_program().graph)
        self.assertEqual(len(lowered_module_nodes), 1)
        lowered_module_node = lowered_module_nodes[0]

        # Get the call delegate node
        call_delegate_node = list(lowered_module_node.users.keys())[0]
        self.assertEqual(len(call_delegate_node.args), 4)
        lower_module = getattr(
            lowered.exported_program().graph_module, lowered_module_node.name
        )
        delegated_ep = lower_module.original_module
        self.assertEqual(len(delegated_ep.state_dict), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)

    def test_buffer_mutation_unsupported(self):
        SHAPE = (2, 3)

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("state_1", torch.zeros(SHAPE, dtype=torch.float32))

            def forward(self, x):
                add = self.state_1.add_(x)
                return add

        model = Model()
        model.eval()
        example_inputs = (torch.randn(SHAPE),)
        exir_program_aten = torch.export.export(model, example_inputs)
        edge_program_manager = exir.to_edge(exir_program_aten)
        with self.assertRaises(AssertionError):
            edge_program_manager.to_backend(AddAttributePartitionerDemo())