# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO

from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.tensor import TensorSpec


class BuildQuantIo(ExportPass):
    """
    To make the lowering process correct, this pass assigns the correct
    quantized dtype to the spec of the call_delegate node.
    """

    def __init__(self):
        super().__init__()

    def _make_spec(self, x):
        # Only tensors and scalar constants carry meaningful specs; anything
        # else is ignored.
        if isinstance(x, torch.Tensor):
            return TensorSpec.from_tensor(x)
        elif isinstance(x, (int, bool, float)):
            return x
        else:
            return None

    def _build(self, graph_module: torch.fx.GraphModule) -> None:
        # Forcibly update the delegate node's meta["spec"] so the runtime sees
        # the correct output tensor sizes.
        call_delegate = [
            node
            for node in graph_module.graph.nodes
            if node.op == "call_function" and node.name == "executorch_call_delegate"
        ]
        assert len(call_delegate) == 1
        spec = []
        for n in graph_module.graph.nodes:
            # Cast the fake tensor to the quantized dtype recorded by the
            # quantizer so downstream passes see quantized I/O.
            if QCOM_QUANTIZED_IO in n.meta:
                n.meta["val"] = n.meta["val"].to(dtype=n.meta[QCOM_QUANTIZED_IO])
            # getitem nodes unpack the delegate's outputs; collect their specs
            # in order so they can be attached to the delegate node below.
            if n.op == "call_function" and "getitem" in n.name:
                fake_tensor = n.meta["val"]
                if QCOM_QUANTIZED_IO in n.meta:
                    fake_tensor = fake_tensor.to(dtype=n.meta[QCOM_QUANTIZED_IO])
                spec.append(self._make_spec(fake_tensor))

        call_delegate[0].meta["spec"] = tuple(spec)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        self._build(graph_module)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
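

# ---------------------------------------------------------------------------
# Usage sketch (not part of the pass itself): the commented snippet below is
# an assumption about how this pass is typically applied, shown only for
# illustration. The import path and the `delegated_program` name are
# placeholders and may differ from the actual integration point.
#
#   from executorch.backends.qualcomm._passes.build_quant_io import BuildQuantIo
#
#   # `delegated_program` is assumed to be an ExportedProgram whose graph
#   # already contains a single executorch_call_delegate node.
#   graph_module = delegated_program.graph_module
#   result = BuildQuantIo()(graph_module)  # ExportPass instances are callable
#   graph_module = result.graph_module
# ---------------------------------------------------------------------------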