xref: /aosp_15_r20/external/executorch/exir/backend/test/hta_partitioner_demo.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import logging
8from typing import final, List
9
10import torch
11from executorch import exir
12from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
13    generate_pattern_op_partitions,
14)
15
16from executorch.exir.backend.partitioner import (
17    DelegationSpec,
18    Partitioner,
19    PartitionResult,
20)
21from executorch.exir.backend.test.qnn_backend_demo import QnnBackend
22from executorch.exir.backend.utils import tag_constant_data
23from torch.export import ExportedProgram
24from torch.fx.passes.infra.partitioner import Partition
25
26
@final
class HTAPartitionerMultiplePatternsDemo(Partitioner):
    """
    An example implementation to partition a graph for HTA. In this example, the
    backend associated with this partitioner is QnnBackend. With QnnBackend, the two
    lowerable patterns are: (lstm + conv) and (sub). The backend id is derived from
    the QnnBackend class name, as it is a property of HTAPartitionerMultiplePatternsDemo
    and won't differ between HTAPartitionerMultiplePatternsDemo instances.

    The partition algorithm is:
    1. Find a list of partitions for a graph: generate_partition_list(GraphModule) -> List[Partition]
    2. is_exclusive() can check whether partitions from generate_partition_list() overlap.
    3. Tag all nodes of every partition so they are fused as delegated submodules.
    """

    def __init__(self) -> None:
        """
        Initialize the list of pattern graphs to match: (lstm + conv) and (sub),
        each captured in both lifted and unlifted form.
        """

        class LSTMConvPattern(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                return output, hn, cn, k

        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])

        pattern_lstm_conv_lifted = (
            exir.capture(
                LSTMConvPattern(),
                (input_x, input_h, input_c),
                exir.CaptureConfig(enable_aot=True),
            )
            .to_edge(
                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
                exir.EdgeCompileConfig(_check_ir_validity=False)
            )
            .exported_program.graph_module
        )
        pattern_lstm_conv = (
            exir.capture(
                LSTMConvPattern(),
                (input_x, input_h, input_c),
                exir.CaptureConfig(),
            )
            .to_edge(
                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
                exir.EdgeCompileConfig(_check_ir_validity=False)
            )
            .exported_program.graph_module
        )

        def sub(x, y):
            return torch.sub(x, y)

        pattern_sub_lifted = (
            exir.capture(
                sub,
                (input_x, input_h),
                exir.CaptureConfig(enable_aot=True, _unlift=False),
            )
            .to_edge(exir.EdgeCompileConfig(_use_edge_ops=True))
            .exported_program.graph_module
        )
        pattern_sub = (
            exir.capture(
                sub,
                (input_x, input_h),
                exir.CaptureConfig(),
            )
            .to_edge()
            .exported_program.graph_module
        )
        # Match both lifted and unlifted captures of each pattern so the
        # partitioner works regardless of how the user program was captured.
        self.patterns = [
            pattern_lstm_conv_lifted.graph,
            pattern_lstm_conv.graph,
            pattern_sub_lifted.graph,
            pattern_sub.graph,
        ]

        backend_id = QnnBackend.__name__
        self.delegation_spec = DelegationSpec(backend_id, [])

    def is_exclusive(self, partition_list_list: List[List[Partition]]) -> bool:
        """
        Check whether all partitions from all pattern partitioners are mutually
        exclusive, by comparing len(all_nodes) with len(set(all_nodes)).
        This partitioner only supports merging exclusive partitions.

        Args:
            partition_list_list: all partitions from all pattern partitioners

        Returns:
            bool: True if all nodes from all partitions are exclusive (no node
            appears in more than one partition).

        For example, 0/1 are the partition id, A/B/../L are nodes:
        [
            [(0: A, B, C), (1: D, E, F)], # from pattern lstm + conv
            [(0: B, J, L)], # from sub
        ]
        node B shows up in both partitions. Usually some special trick (either
        merging the two lists, or only keeping one pattern [A, B, C]) needs to
        be done here, depending on the user's need.
        """
        all_partition = [
            partition
            for partition_list in partition_list_list
            for partition in partition_list
        ]

        # All nodes from all partitions from all pattern match results
        all_nodes = []
        for partition in all_partition:
            all_nodes.extend(partition.nodes)
        all_nodes_set = set(all_nodes)

        # Calculate the number of duplicate nodes
        duplicated_node_number = len(all_nodes) - len(all_nodes_set)
        logging.info(f"duplicate node number is {duplicated_node_number}.")
        return duplicated_node_number == 0

    def generate_partition_list(
        self, graph_module: torch.fx.GraphModule
    ) -> List[Partition]:
        """
        Generate a flat list of partitions from all matched patterns, with each
        partition re-numbered with a unique id.

        Args:
            graph_module: the input graph module

        Returns:
            List[Partition]: all matched partitions with unique, sequential ids.

        For example, 0/1 are the partition id, A/B/../L are nodes:
        [
            [(0: A, B, C), (1: D, E, F)], # from pattern lstm + conv
            [(0: G, H, I)], # from sub
        ]
        the output will be
        [
            (0: A, B, C), (1: D, E, F), (2: G, H, I)
        ]
        """
        partitions_from_all_pattern = generate_pattern_op_partitions(
            graph_module, self.patterns
        )

        # Re-assign a unique, sequential id to each partition so delegation
        # tags derived from the ids don't collide across patterns.
        for partition_id, partition in enumerate(partitions_from_all_pattern):
            partition.id = partition_id

        return partitions_from_all_pattern

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Tag every node of every matched partition with a per-partition
        delegation tag, and map each tag to this partitioner's delegation spec.
        """
        partition_tags = {}
        partition_list = self.generate_partition_list(exported_program.graph_module)
        for partition in partition_list:
            # The tag depends only on the partition, not on the node, so
            # compute and register it once per partition.
            delegation_tag = f"tag{partition.id}"
            if partition.nodes:
                partition_tags[delegation_tag] = self.delegation_spec
            for node in partition.nodes:
                node.meta["delegation_tag"] = delegation_tag
        tag_constant_data(exported_program)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
207
208
@final
class HTAPartitionerOnePatternDemo(Partitioner):
    """
    Similar to HTAPartitionerMultiplePatternsDemo; the only difference is that only
    one pattern (lstm + conv) is lowerable. Partitions are produced directly with
    generate_pattern_op_partitions() and their nodes tagged for delegation.
    """

    def __init__(self) -> None:
        """
        Initialize the single lowerable pattern (lstm + conv), captured in both
        lifted and unlifted form.
        """

        # Only lowering lstm + conv pattern
        class LSTMConvPattern(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                return output, hn, cn, k

        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])

        pattern_lstm_conv_lifted = (
            exir.capture(
                LSTMConvPattern(),
                (input_x, input_h, input_c),
                exir.CaptureConfig(enable_aot=True),
            )
            .to_edge(
                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
                exir.EdgeCompileConfig(_check_ir_validity=False)
            )
            .exported_program.graph_module
        )
        pattern_lstm_conv_unlifted = (
            exir.capture(
                LSTMConvPattern(),
                (input_x, input_h, input_c),
                exir.CaptureConfig(),
            )
            .to_edge(
                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
                exir.EdgeCompileConfig(_check_ir_validity=False)
            )
            .exported_program.graph_module
        )
        # Only (lstm + conv) pattern is lowerable; match both capture flavors.
        self.patterns = [
            pattern_lstm_conv_lifted.graph,
            pattern_lstm_conv_unlifted.graph,
        ]

        backend_id = QnnBackend.__name__
        self.delegation_spec = DelegationSpec(backend_id, [])

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Tag every node of every matched (lstm + conv) partition with a
        per-partition delegation tag, and map each tag to the delegation spec.
        """
        partition_tags = {}
        partition_list = generate_pattern_op_partitions(
            exported_program.graph_module, patterns=self.patterns
        )
        for partition in partition_list:
            # The tag depends only on the partition, not on the node, so
            # compute and register it once per partition.
            delegation_tag = f"tag{partition.id}"
            if partition.nodes:
                partition_tags[delegation_tag] = self.delegation_spec
            for node in partition.nodes:
                node.meta["delegation_tag"] = delegation_tag
        tag_constant_data(exported_program)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
289