# mypy: allow-untyped-defs
# Allows one to expose an API in a private submodule publicly, as per the definition
# in PyTorch's public API policy.
#
# It is a temporary solution while we figure out if it should be the long-term solution
# or if we should amend PyTorch's public API policy. The concern is that this approach
# may not be very robust because it's not clear what __module__ is used for.
# However, both numpy and jax overwrite the __module__ attribute of their APIs
# without problem, so it seems fine.
def exposed_in(module):
    """Return a decorator that sets ``fn.__module__`` to ``module``.

    The decorated object is returned unchanged apart from that attribute, so it
    reports ``module`` (e.g. a public namespace) as its home module.
    """

    def wrapper(fn):
        # Overwrite the defining module recorded on the function object.
        fn.__module__ = module
        return fn

    return wrapper