# mypy: ignore-errors
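
"""TorchDynamo backend that compiles FX graphs with Apache TVM.

The FX graph is lowered to TorchScript, imported into TVM's Relay IR, and
built for the device of the example inputs; tensors cross the boundary via
DLPack where possible. Autotuning is optional and selected through the
``scheduler`` option or the ``TVM_SCHEDULER`` environment variable.

Usage sketch, assuming TVM is installed (see ``has_tvm`` below):

    torch.compile(model, backend="tvm")
"""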

import functools
import importlib
import logging
import os
import sys
import tempfile
from types import MappingProxyType
from typing import Optional

import torch

from .common import device_from_inputs, fake_tensor_unsupported
from .registry import register_backend


log = logging.getLogger(__name__)


@register_backend
@fake_tensor_unsupported
def tvm(
    gm,
    example_inputs,
    *,
    options: Optional[MappingProxyType] = MappingProxyType(
        {"scheduler": None, "trials": 20000, "opt_level": 3}
    ),
):
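    """Compile ``gm`` with TVM and return a callable that runs the built module.

    Recognized ``options`` keys:
        scheduler: ``None`` or ``"default"`` for no autotuning,
            ``"auto_scheduler"``, or ``"meta_schedule"``. When ``None``,
            the ``TVM_SCHEDULER`` environment variable is consulted.
        trials: maximum number of tuning trials (must be > 0 when tuning).
        opt_level: optimization level passed to TVM's build.
    """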
    import tvm  # type: ignore[import]
    from tvm import relay  # type: ignore[import]
    from tvm.contrib import graph_executor  # type: ignore[import]

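    # Relay's PyTorch frontend consumes TorchScript, so trace the FX graph
    # first; each positional input is registered under the name "inp_{idx}".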
    jit_mod = torch.jit.trace(gm, example_inputs)
    device = device_from_inputs(example_inputs)
    shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
    example_outputs = gm(*example_inputs)
    if len(example_outputs) == 0:
        log.warning("Explicitly falling back to eager: the graph has no outputs")
        return gm.forward
    mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
    if device.type == "cuda":
        dev = tvm.cuda(device.index)
        target = tvm.target.cuda()
    else:
        dev = tvm.cpu(0)
        target = tvm.target.Target(llvm_target())

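    # Resolve the tuning scheduler: an explicit option wins, then the
    # TVM_SCHEDULER environment variable; unset means no autotuning.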
    scheduler = options.get("scheduler", None)
    if scheduler is None:
        scheduler = os.environ.get("TVM_SCHEDULER", None)

    trials = options.get("trials", 20000)
    opt_level = options.get("opt_level", 3)

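    # auto_scheduler (Ansor): extract tuning tasks from the module, tune them
    # into a record log, and replay the best schedules at build time.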
    if scheduler == "auto_scheduler":
        from tvm import auto_scheduler

        # NamedTemporaryFile creates the file on disk and removes it when the
        # object is closed; pass .name to the TVM APIs, since the wrapper
        # object itself is not path-like.
        log_file = tempfile.NamedTemporaryFile()

        tasks, task_weights = auto_scheduler.extract_tasks(
            mod["main"], params, target
        )
        if len(tasks) != 0:
            for task in tasks:
                print(task.compute_dag)
            tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
            assert trials > 0
            tune_option = auto_scheduler.TuningOptions(
                num_measure_trials=trials,
                measure_callbacks=[auto_scheduler.RecordToFile(log_file.name)],
                early_stopping=2000,
            )
            tuner.tune(tune_option)
        else:
            print("No tasks")

        # Replay the best records found above while building the module.
        with auto_scheduler.ApplyHistoryBest(log_file.name):
            with tvm.transform.PassContext(
                opt_level=opt_level, config={"relay.backend.use_auto_scheduler": True}
            ):
                lib = relay.build(mod, target=target, params=params)
    elif scheduler == "meta_schedule":
        from tvm import meta_schedule as ms

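        # Tune into a database under a temporary work directory, then
        # compile the module from the tuned records.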
        with tempfile.TemporaryDirectory() as work_dir:
            if device.type != "cuda":
                # meta_schedule needs num-cores to be specified;
                # here we use the number of physical cores
                target = tvm.target.Target(
                    f"{llvm_target()} --num-cores {ms.utils.cpu_count(logical=False)}"
                )
            # TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch
            # once USE_PT_TVMDSOOP is updated and turned on by default in TVM.
            assert trials > 0
            database = ms.relay_integration.tune_relay(
                mod=mod,
                target=target,
                work_dir=work_dir,
                max_trials_global=trials,
                num_trials_per_iter=64,
                params=params,
                strategy="evolutionary",
                opt_level=opt_level,
            )
            lib = ms.relay_integration.compile_relay(
                database=database,
                mod=mod,
                target=target,
                params=params,
                opt_level=opt_level,
            )
    elif scheduler == "default" or not scheduler:
        # no autotuning
        with tvm.transform.PassContext(opt_level=opt_level):
            lib = relay.build(mod, target=target, params=params)
    else:
        raise NotImplementedError(
            f"Invalid scheduler {scheduler!r} for torchdynamo's TVM backend. "
            "Available options: default, auto_scheduler, meta_schedule."
        )
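    # "default" is the entry point relay.build generates for the module;
    # instantiating it on the target device yields a runnable graph module.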
    m = graph_executor.GraphModule(lib["default"](dev))

    def to_torch_tensor(nd_tensor):
        """Convert a TVM NDArray to a torch.Tensor."""
        if nd_tensor.dtype == "bool":
            # DLPack does not support boolean tensors, so they can't go
            # through torch.utils.dlpack.from_dlpack. Work around this via
            # numpy, although this brings additional data copy overhead.
            return torch.from_numpy(nd_tensor.numpy())
        return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())

    def to_tvm_tensor(torch_tensor):
        """Convert a torch.Tensor to a TVM NDArray."""
        if torch_tensor.dtype == torch.bool:
            # Same reason as above: fall back to a numpy conversion, which
            # introduces data copy overhead.
            return tvm.nd.array(torch_tensor.cpu().numpy())
        return tvm.nd.from_dlpack(torch_tensor)

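    # Inputs are matched positionally against the "inp_{idx}" names from
    # shape_list above; names absent from the built module are skipped.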
    def exec_tvm(*i_args):
        """Feed torch tensors to the compiled module and collect its outputs."""
        # Ensure dense (contiguous) tensors before handing them to TVM.
        args = [a.contiguous() for a in i_args]
        shape_info, _ = m.get_input_info()
        active_inputs = {name for name, _ in shape_info.items()}
        for idx, arg in enumerate(args):
            if arg.dim() != 0:
                if arg.requires_grad:
                    # TVM runs outside autograd; pass detached storage.
                    arg = arg.detach()
                inp_name = f"inp_{idx}"
                if inp_name not in active_inputs:
                    log.warning(
                        "input %s skipped as not found in tvm's runtime library",
                        inp_name,
                    )
                    continue
                m.set_input(
                    inp_name,
                    to_tvm_tensor(arg),
                )
        m.run()
        return [to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())]

    return exec_tvm


# Scheduler-specific aliases. ``tvm`` only accepts configuration through its
# ``options`` mapping (there is no ``scheduler`` keyword), so bind the
# scheduler there; ``trials`` and ``opt_level`` keep their defaults via
# ``options.get``.
tvm_meta_schedule = functools.partial(
    tvm, options=MappingProxyType({"scheduler": "meta_schedule"})
)
tvm_auto_scheduler = functools.partial(
    tvm, options=MappingProxyType({"scheduler": "auto_scheduler"})
)


def has_tvm():
    try:
        importlib.import_module("tvm")
        return True
    except ImportError:
        return False


@functools.lru_cache(None)
def llvm_target():
    # Pick an LLVM target matching the host CPU's vector extensions; cached
    # because the host does not change over the process lifetime.
    if sys.platform == "linux":
        with open("/proc/cpuinfo") as f:
            cpuinfo = f.read()
        if "avx512" in cpuinfo:
            return "llvm -mcpu=skylake-avx512"
        elif "avx2" in cpuinfo:
            return "llvm -mcpu=core-avx2"
    return "llvm"