xref: /aosp_15_r20/external/pytorch/torch/_compile.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Worker"""
3*da0073e9SAndroid Build Coastguard WorkerAPIs related to torch.compile which lazily import torch._dynamo to avoid
4*da0073e9SAndroid Build Coastguard Workercircular dependencies.
5*da0073e9SAndroid Build Coastguard Worker"""
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport functools
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerdef _disable_dynamo(fn=None, recursive=True):
11*da0073e9SAndroid Build Coastguard Worker    """
12*da0073e9SAndroid Build Coastguard Worker    This API should be only used inside torch, external users should still use
13*da0073e9SAndroid Build Coastguard Worker    torch._dynamo.disable. The main goal of this API is to avoid circular
14*da0073e9SAndroid Build Coastguard Worker    imports issues that is common while using _dynamo.disable inside torch
15*da0073e9SAndroid Build Coastguard Worker    itself.
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker    This API avoids it by lazily importing torch._dynamo from the import time to
18*da0073e9SAndroid Build Coastguard Worker    the invocation of the decorated function.
19*da0073e9SAndroid Build Coastguard Worker    """
20*da0073e9SAndroid Build Coastguard Worker    if fn is not None:
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker        @functools.wraps(fn)
23*da0073e9SAndroid Build Coastguard Worker        def inner(*args, **kwargs):
24*da0073e9SAndroid Build Coastguard Worker            # cache this on the first invocation to avoid adding too much overhead.
25*da0073e9SAndroid Build Coastguard Worker            disable_fn = getattr(fn, "__dynamo_disable", None)
26*da0073e9SAndroid Build Coastguard Worker            if disable_fn is None:
27*da0073e9SAndroid Build Coastguard Worker                import torch._dynamo
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker                disable_fn = torch._dynamo.disable(fn, recursive)
30*da0073e9SAndroid Build Coastguard Worker                fn.__dynamo_disable = disable_fn
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker            return disable_fn(*args, **kwargs)
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker        return inner
35*da0073e9SAndroid Build Coastguard Worker    else:
36*da0073e9SAndroid Build Coastguard Worker        # decorator usage like @_disable_dynamo(recursive=False). The resulting
37*da0073e9SAndroid Build Coastguard Worker        # object expects the original decorated function as the arg.
38*da0073e9SAndroid Build Coastguard Worker        return functools.partial(_disable_dynamo, recursive=recursive)
39