# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class RecomposePixelUnshuffle(ExportPass):
    """
    Merge decomposed operators from mathematically equivalent implementation
    back to one super node.

    The recognized pattern is the decomposition of ``pixel_unshuffle(x, r)``
    for an input ``x`` of shape ``(B, C, H, W)``::

        v = view(x, [B, C, H // r, r, W // r, r])
        p = permute(v, [0, 1, 3, 5, 2, 4])
        y = reshape(p, [B, C * r * r, H // r, W // r])

    In quantization-capture mode a ``clone`` sits between the permute and the
    final reshape (``_unsafe_view``); it is stepped over.
    """

    # Only this permutation is equivalent to pixel_unshuffle; other 6-D
    # permutes (e.g. DCR-style space-to-depth) give the same shapes but a
    # different channel ordering and must not be rewritten.
    _PERMUTE_DIMS = [0, 1, 3, 5, 2, 4]

    def __init__(self, quantization_capture=False):
        """
        Args:
            quantization_capture: if True, match the ATen ops seen during
                quantization capture (``_unsafe_view`` / ``permute`` /
                ``view``, with an intermediate ``clone``) instead of the edge
                dialect ``*_copy`` ops.
        """
        super(RecomposePixelUnshuffle, self).__init__()
        self.reshape_target = exir_ops.edge.aten.view_copy.default
        self.permute_target = exir_ops.edge.aten.permute_copy.default
        self.view_target = exir_ops.edge.aten.view_copy.default
        self.op = exir_ops.edge.aten.pixel_unshuffle.default

        self.quantization_capture = quantization_capture
        if quantization_capture:
            self.reshape_target = torch.ops.aten._unsafe_view.default
            self.permute_target = torch.ops.aten.permute.default
            self.view_target = torch.ops.aten.view.default
            self.op = torch.ops.aten.pixel_unshuffle.default

    @staticmethod
    def _is_call(node, target, min_args=1):
        """Return True if ``node`` is an FX call_function node on ``target``
        with at least ``min_args`` positional arguments."""
        return (
            isinstance(node, torch.fx.Node)
            and node.op == "call_function"
            and node.target == target
            and len(node.args) >= min_args
        )

    def _match(self, node):
        """Return ``(input_node, downscale_factor)`` if ``node`` is the final
        reshape of a decomposed pixel_unshuffle, otherwise ``None``."""
        if not self._is_call(node, self.reshape_target, min_args=2):
            return None
        out_shape = node.args[1]
        if len(out_shape) != 4:
            return None

        permute_node = node.args[0]
        if self.quantization_capture:
            # Clone op still exists between permute and reshape_target during
            # quantization; step over it, but only if it really is a clone so
            # no other op gets silently dropped.
            if not self._is_call(permute_node, torch.ops.aten.clone.default):
                return None
            permute_node = permute_node.args[0]
        if not self._is_call(permute_node, self.permute_target, min_args=2):
            return None
        dims = permute_node.args[1]
        if len(dims) != 6 or [d % 6 for d in dims] != self._PERMUTE_DIMS:
            return None

        view_node = permute_node.args[0]
        if not self._is_call(view_node, self.view_target, min_args=2):
            return None
        view_shape = view_node.args[1]
        if len(view_shape) != 6:
            return None

        b_in, d_nominal, h_in, s_h, w_in, s_w = view_shape
        # Output layout is (B, C * r * r, H // r, W // r): height before width.
        b_out, d_out, h_out, w_out = out_shape
        if any(
            [
                b_out != b_in,
                # pixel_unshuffle takes a single factor for both spatial dims.
                s_h != s_w,
                d_out != d_nominal * s_h * s_w,
                h_in != h_out,
                w_in != w_out,
            ]
        ):
            return None

        input_node = view_node.args[0]
        # When shape metadata is available, make sure the view really splits
        # a (B, C, H, W) input; otherwise pixel_unshuffle would not be
        # equivalent.
        meta = getattr(input_node, "meta", None) or {}
        in_shape = getattr(meta.get("val"), "shape", None)
        if in_shape is not None and (
            len(in_shape) != 4
            or list(in_shape) != [b_in, d_nominal, h_in * s_h, w_in * s_w]
        ):
            return None

        return input_node, s_h

    def call(self, graph_module: torch.fx.GraphModule):
        """Replace each matched view/permute/reshape chain with a single
        pixel_unshuffle node, then drop the now-dead decomposed ops."""
        graph = graph_module.graph
        # Iterate over a snapshot because nodes are inserted while scanning.
        for node in list(graph.nodes):
            match = self._match(node)
            if match is None:
                continue
            input_node, downscale_factor = match
            with graph.inserting_after(node):
                pixel_unshuffle_node = graph.create_node(
                    "call_function", self.op, (input_node, downscale_factor)
                )
            node.replace_all_uses_with(pixel_unshuffle_node)
            # copy metadata
            pixel_unshuffle_node.meta = node.meta

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)