xref: /aosp_15_r20/external/pytorch/torch/utils/_exposed_in.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Allows one to expose an API in a private submodule publicly as per the definition
3# in PyTorch's public api policy.
4#
5# It is a temporary solution while we figure out if it should be the long-term solution
6# or if we should amend PyTorch's public api policy. The concern is that this approach
7# may not be very robust because it's not clear what __module__ is used for.
8# However, both numpy and jax overwrite the __module__ attribute of their APIs
9# without problem, so it seems fine.
def exposed_in(module):
    """Build a decorator that makes an object report ``module`` as its home.

    Apply it to an API that lives in a private submodule so that its
    ``__module__`` shows the public module path instead.

    Args:
        module: The value to store in the decorated object's ``__module__``
            attribute, e.g. ``"torch.library"``.

    Returns:
        A decorator. It overwrites ``__module__`` on the object it receives
        and hands back that very same object. No wrapper is created, so
        object identity is preserved.
    """

    def decorate(obj):
        # Mutate in place rather than wrapping, so `obj` keeps its identity.
        setattr(obj, "__module__", module)
        return obj

    return decorate
16