xref: /aosp_15_r20/external/pytorch/test/test_public_bindings.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: autograd"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport importlib
4*da0073e9SAndroid Build Coastguard Workerimport inspect
5*da0073e9SAndroid Build Coastguard Workerimport json
6*da0073e9SAndroid Build Coastguard Workerimport logging
7*da0073e9SAndroid Build Coastguard Workerimport os
8*da0073e9SAndroid Build Coastguard Workerimport pkgutil
9*da0073e9SAndroid Build Coastguard Workerimport unittest
10*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workerimport torch
13*da0073e9SAndroid Build Coastguard Workerfrom torch._utils_internal import get_file_path_2
14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
15*da0073e9SAndroid Build Coastguard Worker    IS_JETSON,
16*da0073e9SAndroid Build Coastguard Worker    IS_MACOS,
17*da0073e9SAndroid Build Coastguard Worker    IS_WINDOWS,
18*da0073e9SAndroid Build Coastguard Worker    run_tests,
19*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
20*da0073e9SAndroid Build Coastguard Worker    TestCase,
21*da0073e9SAndroid Build Coastguard Worker)
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
# Module-level logger. test_modules_can_be_imported uses it to record the
# traceback (via log.exception) of every module that fails to import.
log = logging.getLogger(__name__)
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Workerclass TestPublicBindings(TestCase):
28*da0073e9SAndroid Build Coastguard Worker    def test_no_new_reexport_callables(self):
29*da0073e9SAndroid Build Coastguard Worker        """
30*da0073e9SAndroid Build Coastguard Worker        This test aims to stop the introduction of new re-exported callables into
31*da0073e9SAndroid Build Coastguard Worker        torch whose names do not start with _. Such callables are made available as
32*da0073e9SAndroid Build Coastguard Worker        torch.XXX, which may not be desirable.
33*da0073e9SAndroid Build Coastguard Worker        """
34*da0073e9SAndroid Build Coastguard Worker        reexported_callables = sorted(
35*da0073e9SAndroid Build Coastguard Worker            k
36*da0073e9SAndroid Build Coastguard Worker            for k, v in vars(torch).items()
37*da0073e9SAndroid Build Coastguard Worker            if callable(v) and not v.__module__.startswith("torch")
38*da0073e9SAndroid Build Coastguard Worker        )
39*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
40*da0073e9SAndroid Build Coastguard Worker            all(k.startswith("_") for k in reexported_callables), reexported_callables
41*da0073e9SAndroid Build Coastguard Worker        )
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker    def test_no_new_bindings(self):
44*da0073e9SAndroid Build Coastguard Worker        """
45*da0073e9SAndroid Build Coastguard Worker        This test aims to stop the introduction of new JIT bindings into torch._C
46*da0073e9SAndroid Build Coastguard Worker        whose names do not start with _. Such bindings are made available as
47*da0073e9SAndroid Build Coastguard Worker        torch.XXX, which may not be desirable.
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker        If your change causes this test to fail, add your new binding to a relevant
50*da0073e9SAndroid Build Coastguard Worker        submodule of torch._C, such as torch._C._jit (or other relevant submodule of
51*da0073e9SAndroid Build Coastguard Worker        torch._C). If your binding really needs to be available as torch.XXX, add it
52*da0073e9SAndroid Build Coastguard Worker        to torch._C and add it to the allowlist below.
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker        If you have removed a binding, remove it from the allowlist as well.
55*da0073e9SAndroid Build Coastguard Worker        """
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker        # This allowlist contains every binding in torch._C that is copied into torch at
58*da0073e9SAndroid Build Coastguard Worker        # the time of writing. It was generated with
59*da0073e9SAndroid Build Coastguard Worker        #
60*da0073e9SAndroid Build Coastguard Worker        #   {elem for elem in dir(torch._C) if not elem.startswith("_")}
61*da0073e9SAndroid Build Coastguard Worker        torch_C_allowlist_superset = {
62*da0073e9SAndroid Build Coastguard Worker            "AggregationType",
63*da0073e9SAndroid Build Coastguard Worker            "AliasDb",
64*da0073e9SAndroid Build Coastguard Worker            "AnyType",
65*da0073e9SAndroid Build Coastguard Worker            "Argument",
66*da0073e9SAndroid Build Coastguard Worker            "ArgumentSpec",
67*da0073e9SAndroid Build Coastguard Worker            "AwaitType",
68*da0073e9SAndroid Build Coastguard Worker            "autocast_decrement_nesting",
69*da0073e9SAndroid Build Coastguard Worker            "autocast_increment_nesting",
70*da0073e9SAndroid Build Coastguard Worker            "AVG",
71*da0073e9SAndroid Build Coastguard Worker            "BenchmarkConfig",
72*da0073e9SAndroid Build Coastguard Worker            "BenchmarkExecutionStats",
73*da0073e9SAndroid Build Coastguard Worker            "Block",
74*da0073e9SAndroid Build Coastguard Worker            "BoolType",
75*da0073e9SAndroid Build Coastguard Worker            "BufferDict",
76*da0073e9SAndroid Build Coastguard Worker            "StorageBase",
77*da0073e9SAndroid Build Coastguard Worker            "CallStack",
78*da0073e9SAndroid Build Coastguard Worker            "Capsule",
79*da0073e9SAndroid Build Coastguard Worker            "ClassType",
80*da0073e9SAndroid Build Coastguard Worker            "clear_autocast_cache",
81*da0073e9SAndroid Build Coastguard Worker            "Code",
82*da0073e9SAndroid Build Coastguard Worker            "CompilationUnit",
83*da0073e9SAndroid Build Coastguard Worker            "CompleteArgumentSpec",
84*da0073e9SAndroid Build Coastguard Worker            "ComplexType",
85*da0073e9SAndroid Build Coastguard Worker            "ConcreteModuleType",
86*da0073e9SAndroid Build Coastguard Worker            "ConcreteModuleTypeBuilder",
87*da0073e9SAndroid Build Coastguard Worker            "cpp",
88*da0073e9SAndroid Build Coastguard Worker            "CudaBFloat16TensorBase",
89*da0073e9SAndroid Build Coastguard Worker            "CudaBoolTensorBase",
90*da0073e9SAndroid Build Coastguard Worker            "CudaByteTensorBase",
91*da0073e9SAndroid Build Coastguard Worker            "CudaCharTensorBase",
92*da0073e9SAndroid Build Coastguard Worker            "CudaComplexDoubleTensorBase",
93*da0073e9SAndroid Build Coastguard Worker            "CudaComplexFloatTensorBase",
94*da0073e9SAndroid Build Coastguard Worker            "CudaDoubleTensorBase",
95*da0073e9SAndroid Build Coastguard Worker            "CudaFloatTensorBase",
96*da0073e9SAndroid Build Coastguard Worker            "CudaHalfTensorBase",
97*da0073e9SAndroid Build Coastguard Worker            "CudaIntTensorBase",
98*da0073e9SAndroid Build Coastguard Worker            "CudaLongTensorBase",
99*da0073e9SAndroid Build Coastguard Worker            "CudaShortTensorBase",
100*da0073e9SAndroid Build Coastguard Worker            "DeepCopyMemoTable",
101*da0073e9SAndroid Build Coastguard Worker            "default_generator",
102*da0073e9SAndroid Build Coastguard Worker            "DeserializationStorageContext",
103*da0073e9SAndroid Build Coastguard Worker            "device",
104*da0073e9SAndroid Build Coastguard Worker            "DeviceObjType",
105*da0073e9SAndroid Build Coastguard Worker            "DictType",
106*da0073e9SAndroid Build Coastguard Worker            "DisableTorchFunction",
107*da0073e9SAndroid Build Coastguard Worker            "DisableTorchFunctionSubclass",
108*da0073e9SAndroid Build Coastguard Worker            "DispatchKey",
109*da0073e9SAndroid Build Coastguard Worker            "DispatchKeySet",
110*da0073e9SAndroid Build Coastguard Worker            "dtype",
111*da0073e9SAndroid Build Coastguard Worker            "EnumType",
112*da0073e9SAndroid Build Coastguard Worker            "ErrorReport",
113*da0073e9SAndroid Build Coastguard Worker            "ExcludeDispatchKeyGuard",
114*da0073e9SAndroid Build Coastguard Worker            "ExecutionPlan",
115*da0073e9SAndroid Build Coastguard Worker            "FatalError",
116*da0073e9SAndroid Build Coastguard Worker            "FileCheck",
117*da0073e9SAndroid Build Coastguard Worker            "finfo",
118*da0073e9SAndroid Build Coastguard Worker            "FloatType",
119*da0073e9SAndroid Build Coastguard Worker            "fork",
120*da0073e9SAndroid Build Coastguard Worker            "FunctionSchema",
121*da0073e9SAndroid Build Coastguard Worker            "Future",
122*da0073e9SAndroid Build Coastguard Worker            "FutureType",
123*da0073e9SAndroid Build Coastguard Worker            "Generator",
124*da0073e9SAndroid Build Coastguard Worker            "GeneratorType",
125*da0073e9SAndroid Build Coastguard Worker            "get_autocast_cpu_dtype",
126*da0073e9SAndroid Build Coastguard Worker            "get_autocast_dtype",
127*da0073e9SAndroid Build Coastguard Worker            "get_autocast_ipu_dtype",
128*da0073e9SAndroid Build Coastguard Worker            "get_default_dtype",
129*da0073e9SAndroid Build Coastguard Worker            "get_num_interop_threads",
130*da0073e9SAndroid Build Coastguard Worker            "get_num_threads",
131*da0073e9SAndroid Build Coastguard Worker            "Gradient",
132*da0073e9SAndroid Build Coastguard Worker            "Graph",
133*da0073e9SAndroid Build Coastguard Worker            "GraphExecutorState",
134*da0073e9SAndroid Build Coastguard Worker            "has_cuda",
135*da0073e9SAndroid Build Coastguard Worker            "has_cudnn",
136*da0073e9SAndroid Build Coastguard Worker            "has_lapack",
137*da0073e9SAndroid Build Coastguard Worker            "has_mkl",
138*da0073e9SAndroid Build Coastguard Worker            "has_mkldnn",
139*da0073e9SAndroid Build Coastguard Worker            "has_mps",
140*da0073e9SAndroid Build Coastguard Worker            "has_openmp",
141*da0073e9SAndroid Build Coastguard Worker            "has_spectral",
142*da0073e9SAndroid Build Coastguard Worker            "iinfo",
143*da0073e9SAndroid Build Coastguard Worker            "import_ir_module_from_buffer",
144*da0073e9SAndroid Build Coastguard Worker            "import_ir_module",
145*da0073e9SAndroid Build Coastguard Worker            "InferredType",
146*da0073e9SAndroid Build Coastguard Worker            "init_num_threads",
147*da0073e9SAndroid Build Coastguard Worker            "InterfaceType",
148*da0073e9SAndroid Build Coastguard Worker            "IntType",
149*da0073e9SAndroid Build Coastguard Worker            "SymFloatType",
150*da0073e9SAndroid Build Coastguard Worker            "SymBoolType",
151*da0073e9SAndroid Build Coastguard Worker            "SymIntType",
152*da0073e9SAndroid Build Coastguard Worker            "IODescriptor",
153*da0073e9SAndroid Build Coastguard Worker            "is_anomaly_enabled",
154*da0073e9SAndroid Build Coastguard Worker            "is_anomaly_check_nan_enabled",
155*da0073e9SAndroid Build Coastguard Worker            "is_autocast_cache_enabled",
156*da0073e9SAndroid Build Coastguard Worker            "is_autocast_cpu_enabled",
157*da0073e9SAndroid Build Coastguard Worker            "is_autocast_ipu_enabled",
158*da0073e9SAndroid Build Coastguard Worker            "is_autocast_enabled",
159*da0073e9SAndroid Build Coastguard Worker            "is_grad_enabled",
160*da0073e9SAndroid Build Coastguard Worker            "is_inference_mode_enabled",
161*da0073e9SAndroid Build Coastguard Worker            "JITException",
162*da0073e9SAndroid Build Coastguard Worker            "layout",
163*da0073e9SAndroid Build Coastguard Worker            "ListType",
164*da0073e9SAndroid Build Coastguard Worker            "LiteScriptModule",
165*da0073e9SAndroid Build Coastguard Worker            "LockingLogger",
166*da0073e9SAndroid Build Coastguard Worker            "LoggerBase",
167*da0073e9SAndroid Build Coastguard Worker            "memory_format",
168*da0073e9SAndroid Build Coastguard Worker            "merge_type_from_type_comment",
169*da0073e9SAndroid Build Coastguard Worker            "ModuleDict",
170*da0073e9SAndroid Build Coastguard Worker            "Node",
171*da0073e9SAndroid Build Coastguard Worker            "NoneType",
172*da0073e9SAndroid Build Coastguard Worker            "NoopLogger",
173*da0073e9SAndroid Build Coastguard Worker            "NumberType",
174*da0073e9SAndroid Build Coastguard Worker            "OperatorInfo",
175*da0073e9SAndroid Build Coastguard Worker            "OptionalType",
176*da0073e9SAndroid Build Coastguard Worker            "OutOfMemoryError",
177*da0073e9SAndroid Build Coastguard Worker            "ParameterDict",
178*da0073e9SAndroid Build Coastguard Worker            "parse_ir",
179*da0073e9SAndroid Build Coastguard Worker            "parse_schema",
180*da0073e9SAndroid Build Coastguard Worker            "parse_type_comment",
181*da0073e9SAndroid Build Coastguard Worker            "PyObjectType",
182*da0073e9SAndroid Build Coastguard Worker            "PyTorchFileReader",
183*da0073e9SAndroid Build Coastguard Worker            "PyTorchFileWriter",
184*da0073e9SAndroid Build Coastguard Worker            "qscheme",
185*da0073e9SAndroid Build Coastguard Worker            "read_vitals",
186*da0073e9SAndroid Build Coastguard Worker            "RRefType",
187*da0073e9SAndroid Build Coastguard Worker            "ScriptClass",
188*da0073e9SAndroid Build Coastguard Worker            "ScriptClassFunction",
189*da0073e9SAndroid Build Coastguard Worker            "ScriptDict",
190*da0073e9SAndroid Build Coastguard Worker            "ScriptDictIterator",
191*da0073e9SAndroid Build Coastguard Worker            "ScriptDictKeyIterator",
192*da0073e9SAndroid Build Coastguard Worker            "ScriptList",
193*da0073e9SAndroid Build Coastguard Worker            "ScriptListIterator",
194*da0073e9SAndroid Build Coastguard Worker            "ScriptFunction",
195*da0073e9SAndroid Build Coastguard Worker            "ScriptMethod",
196*da0073e9SAndroid Build Coastguard Worker            "ScriptModule",
197*da0073e9SAndroid Build Coastguard Worker            "ScriptModuleSerializer",
198*da0073e9SAndroid Build Coastguard Worker            "ScriptObject",
199*da0073e9SAndroid Build Coastguard Worker            "ScriptObjectProperty",
200*da0073e9SAndroid Build Coastguard Worker            "SerializationStorageContext",
201*da0073e9SAndroid Build Coastguard Worker            "set_anomaly_enabled",
202*da0073e9SAndroid Build Coastguard Worker            "set_autocast_cache_enabled",
203*da0073e9SAndroid Build Coastguard Worker            "set_autocast_cpu_dtype",
204*da0073e9SAndroid Build Coastguard Worker            "set_autocast_dtype",
205*da0073e9SAndroid Build Coastguard Worker            "set_autocast_ipu_dtype",
206*da0073e9SAndroid Build Coastguard Worker            "set_autocast_cpu_enabled",
207*da0073e9SAndroid Build Coastguard Worker            "set_autocast_ipu_enabled",
208*da0073e9SAndroid Build Coastguard Worker            "set_autocast_enabled",
209*da0073e9SAndroid Build Coastguard Worker            "set_flush_denormal",
210*da0073e9SAndroid Build Coastguard Worker            "set_num_interop_threads",
211*da0073e9SAndroid Build Coastguard Worker            "set_num_threads",
212*da0073e9SAndroid Build Coastguard Worker            "set_vital",
213*da0073e9SAndroid Build Coastguard Worker            "Size",
214*da0073e9SAndroid Build Coastguard Worker            "StaticModule",
215*da0073e9SAndroid Build Coastguard Worker            "Stream",
216*da0073e9SAndroid Build Coastguard Worker            "StreamObjType",
217*da0073e9SAndroid Build Coastguard Worker            "Event",
218*da0073e9SAndroid Build Coastguard Worker            "StringType",
219*da0073e9SAndroid Build Coastguard Worker            "SUM",
220*da0073e9SAndroid Build Coastguard Worker            "SymFloat",
221*da0073e9SAndroid Build Coastguard Worker            "SymInt",
222*da0073e9SAndroid Build Coastguard Worker            "TensorType",
223*da0073e9SAndroid Build Coastguard Worker            "ThroughputBenchmark",
224*da0073e9SAndroid Build Coastguard Worker            "TracingState",
225*da0073e9SAndroid Build Coastguard Worker            "TupleType",
226*da0073e9SAndroid Build Coastguard Worker            "Type",
227*da0073e9SAndroid Build Coastguard Worker            "unify_type_list",
228*da0073e9SAndroid Build Coastguard Worker            "UnionType",
229*da0073e9SAndroid Build Coastguard Worker            "Use",
230*da0073e9SAndroid Build Coastguard Worker            "Value",
231*da0073e9SAndroid Build Coastguard Worker            "set_autocast_gpu_dtype",
232*da0073e9SAndroid Build Coastguard Worker            "get_autocast_gpu_dtype",
233*da0073e9SAndroid Build Coastguard Worker            "vitals_enabled",
234*da0073e9SAndroid Build Coastguard Worker            "wait",
235*da0073e9SAndroid Build Coastguard Worker            "Tag",
236*da0073e9SAndroid Build Coastguard Worker            "set_autocast_xla_enabled",
237*da0073e9SAndroid Build Coastguard Worker            "set_autocast_xla_dtype",
238*da0073e9SAndroid Build Coastguard Worker            "get_autocast_xla_dtype",
239*da0073e9SAndroid Build Coastguard Worker            "is_autocast_xla_enabled",
240*da0073e9SAndroid Build Coastguard Worker        }
241*da0073e9SAndroid Build Coastguard Worker
242*da0073e9SAndroid Build Coastguard Worker        torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}
243*da0073e9SAndroid Build Coastguard Worker
244*da0073e9SAndroid Build Coastguard Worker        # torch.TensorBase is explicitly removed in torch/__init__.py, so included here (#109940)
245*da0073e9SAndroid Build Coastguard Worker        explicitly_removed_torch_C_bindings = {"TensorBase"}
246*da0073e9SAndroid Build Coastguard Worker
247*da0073e9SAndroid Build Coastguard Worker        torch_C_bindings = torch_C_bindings - explicitly_removed_torch_C_bindings
248*da0073e9SAndroid Build Coastguard Worker
249*da0073e9SAndroid Build Coastguard Worker        # Check that the torch._C bindings are all in the allowlist. Since
250*da0073e9SAndroid Build Coastguard Worker        # bindings can change based on how PyTorch was compiled (e.g. with/without
251*da0073e9SAndroid Build Coastguard Worker        # CUDA), the two may not be an exact match but the bindings should be
252*da0073e9SAndroid Build Coastguard Worker        # a subset of the allowlist.
253*da0073e9SAndroid Build Coastguard Worker        difference = torch_C_bindings.difference(torch_C_allowlist_superset)
254*da0073e9SAndroid Build Coastguard Worker        msg = f"torch._C had bindings that are not present in the allowlist:\n{difference}"
255*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch_C_bindings.issubset(torch_C_allowlist_superset), msg)
256*da0073e9SAndroid Build Coastguard Worker
257*da0073e9SAndroid Build Coastguard Worker    @staticmethod
258*da0073e9SAndroid Build Coastguard Worker    def _is_mod_public(modname):
259*da0073e9SAndroid Build Coastguard Worker        split_strs = modname.split(".")
260*da0073e9SAndroid Build Coastguard Worker        for elem in split_strs:
261*da0073e9SAndroid Build Coastguard Worker            if elem.startswith("_"):
262*da0073e9SAndroid Build Coastguard Worker                return False
263*da0073e9SAndroid Build Coastguard Worker        return True
264*da0073e9SAndroid Build Coastguard Worker
265*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
266*da0073e9SAndroid Build Coastguard Worker        IS_WINDOWS or IS_MACOS,
267*da0073e9SAndroid Build Coastguard Worker        "Inductor/Distributed modules hard fail on windows and macos",
268*da0073e9SAndroid Build Coastguard Worker    )
269*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Broken and not relevant for now")
270*da0073e9SAndroid Build Coastguard Worker    def test_modules_can_be_imported(self):
271*da0073e9SAndroid Build Coastguard Worker        failures = []
272*da0073e9SAndroid Build Coastguard Worker
273*da0073e9SAndroid Build Coastguard Worker        def onerror(modname):
274*da0073e9SAndroid Build Coastguard Worker            failures.append(
275*da0073e9SAndroid Build Coastguard Worker                (modname, ImportError("exception occurred importing package"))
276*da0073e9SAndroid Build Coastguard Worker            )
277*da0073e9SAndroid Build Coastguard Worker
278*da0073e9SAndroid Build Coastguard Worker        for mod in pkgutil.walk_packages(torch.__path__, "torch.", onerror=onerror):
279*da0073e9SAndroid Build Coastguard Worker            modname = mod.name
280*da0073e9SAndroid Build Coastguard Worker            try:
281*da0073e9SAndroid Build Coastguard Worker                # TODO: fix "torch/utils/model_dump/__main__.py"
282*da0073e9SAndroid Build Coastguard Worker                # which calls sys.exit() when we try to import it
283*da0073e9SAndroid Build Coastguard Worker                if "__main__" in modname:
284*da0073e9SAndroid Build Coastguard Worker                    continue
285*da0073e9SAndroid Build Coastguard Worker                importlib.import_module(modname)
286*da0073e9SAndroid Build Coastguard Worker            except Exception as e:
287*da0073e9SAndroid Build Coastguard Worker                # Some current failures are not ImportError
288*da0073e9SAndroid Build Coastguard Worker                log.exception("import_module failed")
289*da0073e9SAndroid Build Coastguard Worker                failures.append((modname, e))
290*da0073e9SAndroid Build Coastguard Worker
291*da0073e9SAndroid Build Coastguard Worker        # It is ok to add new entries here but please be careful that these modules
292*da0073e9SAndroid Build Coastguard Worker        # do not get imported by public code.
293*da0073e9SAndroid Build Coastguard Worker        private_allowlist = {
294*da0073e9SAndroid Build Coastguard Worker            "torch._inductor.codegen.cuda.cuda_kernel",
295*da0073e9SAndroid Build Coastguard Worker            # TODO(#133647): Remove the onnx._internal entries after
296*da0073e9SAndroid Build Coastguard Worker            # onnx and onnxscript are installed in CI.
297*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter",
298*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._analysis",
299*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._building",
300*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._capture_strategies",
301*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._compat",
302*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._core",
303*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._decomp",
304*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._dispatching",
305*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._fx_passes",
306*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._ir_passes",
307*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._isolated",
308*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._onnx_program",
309*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._registration",
310*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._reporting",
311*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._schemas",
312*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._tensors",
313*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.exporter._verification",
314*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx._pass",
315*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.analysis",
316*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.analysis.unsupported_nodes",
317*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.decomposition_skip",
318*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.diagnostics",
319*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.fx_onnx_interpreter",
320*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.fx_symbolic_graph_extractor",
321*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.onnxfunction_dispatcher",
322*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.op_validation",
323*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.passes",
324*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.passes._utils",
325*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.passes.decomp",
326*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.passes.functionalization",
327*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.passes.modularization",
328*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.passes.readability",
329*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.passes.type_promotion",
330*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.passes.virtualization",
331*da0073e9SAndroid Build Coastguard Worker            "torch.onnx._internal.fx.type_utils",
332*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.common_distributed",
333*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.common_fsdp",
334*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.dist_utils",
335*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.common_state_dict",
336*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed._shard.sharded_tensor",
337*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed._shard.test_common",
338*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed._tensor.common_dtensor",
339*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.ddp_under_dist_autograd_test",
340*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.distributed_test",
341*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.distributed_utils",
342*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.fake_pg",
343*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.multi_threaded_pg",
344*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.nn.api.remote_module_test",
345*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.rpc.dist_autograd_test",
346*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.rpc.dist_optimizer_test",
347*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.rpc.examples.parameter_server_test",
348*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.rpc.examples.reinforcement_learning_rpc_test",
349*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.rpc.faulty_agent_rpc_test",
350*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.rpc.faulty_rpc_agent_test_fixture",
351*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.rpc.jit.dist_autograd_test",
352*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.rpc.jit.rpc_test",
353*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.rpc.jit.rpc_test_faulty",
354*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.rpc.rpc_agent_test_fixture",
355*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.rpc.rpc_test",
356*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture",
357*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.rpc_utils",
358*da0073e9SAndroid Build Coastguard Worker            "torch._inductor.codegen.cuda.cuda_template",
359*da0073e9SAndroid Build Coastguard Worker            "torch._inductor.codegen.cuda.gemm_template",
360*da0073e9SAndroid Build Coastguard Worker            "torch._inductor.codegen.cpp_template",
361*da0073e9SAndroid Build Coastguard Worker            "torch._inductor.codegen.cpp_gemm_template",
362*da0073e9SAndroid Build Coastguard Worker            "torch._inductor.codegen.cpp_micro_gemm",
363*da0073e9SAndroid Build Coastguard Worker            "torch._inductor.codegen.cpp_template_kernel",
364*da0073e9SAndroid Build Coastguard Worker            "torch._inductor.runtime.triton_helpers",
365*da0073e9SAndroid Build Coastguard Worker            "torch.ao.pruning._experimental.data_sparsifier.lightning.callbacks.data_sparsity",
366*da0073e9SAndroid Build Coastguard Worker            "torch.backends._coreml.preprocess",
367*da0073e9SAndroid Build Coastguard Worker            "torch.contrib._tensorboard_vis",
368*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._composable",
369*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._functional_collectives",
370*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._functional_collectives_impl",
371*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._shard",
372*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._sharded_tensor",
373*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._sharding_spec",
374*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._spmd.api",
375*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._spmd.batch_dim_utils",
376*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._spmd.comm_tensor",
377*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._spmd.data_parallel",
378*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._spmd.distribute",
379*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._spmd.experimental_ops",
380*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._spmd.parallel_mode",
381*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._tensor",
382*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.algorithms._checkpoint.checkpoint_wrapper",
383*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.algorithms._optimizer_overlap",
384*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.rpc._testing.faulty_agent_backend_registry",
385*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.rpc._utils",
386*da0073e9SAndroid Build Coastguard Worker            "torch.ao.pruning._experimental.data_sparsifier.benchmarks.dlrm_utils",
387*da0073e9SAndroid Build Coastguard Worker            "torch.ao.pruning._experimental.data_sparsifier.benchmarks.evaluate_disk_savings",
388*da0073e9SAndroid Build Coastguard Worker            "torch.ao.pruning._experimental.data_sparsifier.benchmarks.evaluate_forward_time",
389*da0073e9SAndroid Build Coastguard Worker            "torch.ao.pruning._experimental.data_sparsifier.benchmarks.evaluate_model_metrics",
390*da0073e9SAndroid Build Coastguard Worker            "torch.ao.pruning._experimental.data_sparsifier.lightning.tests.test_callbacks",
391*da0073e9SAndroid Build Coastguard Worker            "torch.csrc.jit.tensorexpr.scripts.bisect",
392*da0073e9SAndroid Build Coastguard Worker            "torch.csrc.lazy.test_mnist",
393*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._shard.checkpoint._fsspec_filesystem",
394*da0073e9SAndroid Build Coastguard Worker            "torch.distributed._tensor.examples.visualize_sharding_example",
395*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.checkpoint._fsspec_filesystem",
396*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.examples.memory_tracker_example",
397*da0073e9SAndroid Build Coastguard Worker            "torch.testing._internal.distributed.rpc.fb.thrift_rpc_agent_test_fixture",
398*da0073e9SAndroid Build Coastguard Worker            "torch.utils._cxx_pytree",
399*da0073e9SAndroid Build Coastguard Worker            "torch.utils.tensorboard._convert_np",
400*da0073e9SAndroid Build Coastguard Worker            "torch.utils.tensorboard._embedding",
401*da0073e9SAndroid Build Coastguard Worker            "torch.utils.tensorboard._onnx_graph",
402*da0073e9SAndroid Build Coastguard Worker            "torch.utils.tensorboard._proto_graph",
403*da0073e9SAndroid Build Coastguard Worker            "torch.utils.tensorboard._pytorch_graph",
404*da0073e9SAndroid Build Coastguard Worker            "torch.utils.tensorboard._utils",
405*da0073e9SAndroid Build Coastguard Worker        }
406*da0073e9SAndroid Build Coastguard Worker
407*da0073e9SAndroid Build Coastguard Worker        # No new entries should be added to this list.
408*da0073e9SAndroid Build Coastguard Worker        # All public modules should be importable on all platforms.
409*da0073e9SAndroid Build Coastguard Worker        public_allowlist = {
410*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.algorithms.ddp_comm_hooks",
411*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.algorithms.model_averaging.averagers",
412*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.algorithms.model_averaging.hierarchical_model_averager",
413*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.algorithms.model_averaging.utils",
414*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.checkpoint",
415*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.constants",
416*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.distributed_c10d",
417*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.elastic.agent.server",
418*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.elastic.rendezvous",
419*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.fsdp",
420*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.launch",
421*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.launcher",
422*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.nn",
423*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.nn.api.remote_module",
424*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.optim",
425*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.optim.optimizer",
426*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.rendezvous",
427*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.rpc.api",
428*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.rpc.backend_registry",
429*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.rpc.constants",
430*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.rpc.internal",
431*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.rpc.options",
432*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.rpc.rref_proxy",
433*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.elastic.rendezvous.etcd_rendezvous",
434*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.elastic.rendezvous.etcd_rendezvous_backend",
435*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.elastic.rendezvous.etcd_store",
436*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.rpc.server_process_global_profiler",
437*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.run",
438*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.tensor.parallel",
439*da0073e9SAndroid Build Coastguard Worker            "torch.distributed.utils",
440*da0073e9SAndroid Build Coastguard Worker            "torch.utils.tensorboard",
441*da0073e9SAndroid Build Coastguard Worker            "torch.utils.tensorboard.summary",
442*da0073e9SAndroid Build Coastguard Worker            "torch.utils.tensorboard.writer",
443*da0073e9SAndroid Build Coastguard Worker            "torch.ao.quantization.experimental.fake_quantize",
444*da0073e9SAndroid Build Coastguard Worker            "torch.ao.quantization.experimental.linear",
445*da0073e9SAndroid Build Coastguard Worker            "torch.ao.quantization.experimental.observer",
446*da0073e9SAndroid Build Coastguard Worker            "torch.ao.quantization.experimental.qconfig",
447*da0073e9SAndroid Build Coastguard Worker        }
448*da0073e9SAndroid Build Coastguard Worker
449*da0073e9SAndroid Build Coastguard Worker        errors = []
450*da0073e9SAndroid Build Coastguard Worker        for mod, exc in failures:
451*da0073e9SAndroid Build Coastguard Worker            if mod in public_allowlist:
452*da0073e9SAndroid Build Coastguard Worker                # TODO: Ensure this is the right error type
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker                continue
455*da0073e9SAndroid Build Coastguard Worker            if mod in private_allowlist:
456*da0073e9SAndroid Build Coastguard Worker                continue
457*da0073e9SAndroid Build Coastguard Worker            errors.append(
458*da0073e9SAndroid Build Coastguard Worker                f"{mod} failed to import with error {type(exc).__qualname__}: {str(exc)}"
459*da0073e9SAndroid Build Coastguard Worker            )
460*da0073e9SAndroid Build Coastguard Worker        self.assertEqual("", "\n".join(errors))
461*da0073e9SAndroid Build Coastguard Worker
462*da0073e9SAndroid Build Coastguard Worker    # AttributeError: module 'torch.distributed' has no attribute '_shard'
463*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS or IS_JETSON or IS_MACOS, "Distributed Attribute Error")
464*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Broken and not relevant for now")
465*da0073e9SAndroid Build Coastguard Worker    def test_correct_module_names(self):
466*da0073e9SAndroid Build Coastguard Worker        """
467*da0073e9SAndroid Build Coastguard Worker        An API is considered public, if  its  `__module__` starts with `torch.`
468*da0073e9SAndroid Build Coastguard Worker        and there is no name in `__module__` or the object itself that starts with "_".
469*da0073e9SAndroid Build Coastguard Worker        Each public package should either:
470*da0073e9SAndroid Build Coastguard Worker        - (preferred) Define `__all__` and all callables and classes in there must have their
471*da0073e9SAndroid Build Coastguard Worker         `__module__` start with the current submodule's path. Things not in `__all__` should
472*da0073e9SAndroid Build Coastguard Worker          NOT have their `__module__` start with the current submodule.
473*da0073e9SAndroid Build Coastguard Worker        - (for simple python-only modules) Not define `__all__` and all the elements in `dir(submod)` must have their
474*da0073e9SAndroid Build Coastguard Worker          `__module__` that start with the current submodule.
475*da0073e9SAndroid Build Coastguard Worker        """
476*da0073e9SAndroid Build Coastguard Worker
477*da0073e9SAndroid Build Coastguard Worker        failure_list = []
478*da0073e9SAndroid Build Coastguard Worker        with open(
479*da0073e9SAndroid Build Coastguard Worker            get_file_path_2(os.path.dirname(__file__), "allowlist_for_publicAPI.json")
480*da0073e9SAndroid Build Coastguard Worker        ) as json_file:
481*da0073e9SAndroid Build Coastguard Worker            # no new entries should be added to this allow_dict.
482*da0073e9SAndroid Build Coastguard Worker            # New APIs must follow the public API guidelines.
483*da0073e9SAndroid Build Coastguard Worker
484*da0073e9SAndroid Build Coastguard Worker            allow_dict = json.load(json_file)
485*da0073e9SAndroid Build Coastguard Worker            # Because we want minimal modifications to the `allowlist_for_publicAPI.json`,
486*da0073e9SAndroid Build Coastguard Worker            # we are adding the entries for the migrated modules here from the original
487*da0073e9SAndroid Build Coastguard Worker            # locations.
488*da0073e9SAndroid Build Coastguard Worker
489*da0073e9SAndroid Build Coastguard Worker            for modname in allow_dict["being_migrated"]:
490*da0073e9SAndroid Build Coastguard Worker                if modname in allow_dict:
491*da0073e9SAndroid Build Coastguard Worker                    allow_dict[allow_dict["being_migrated"][modname]] = allow_dict[
492*da0073e9SAndroid Build Coastguard Worker                        modname
493*da0073e9SAndroid Build Coastguard Worker                    ]
494*da0073e9SAndroid Build Coastguard Worker
495*da0073e9SAndroid Build Coastguard Worker        def test_module(modname):
496*da0073e9SAndroid Build Coastguard Worker            try:
497*da0073e9SAndroid Build Coastguard Worker                if "__main__" in modname:
498*da0073e9SAndroid Build Coastguard Worker                    return
499*da0073e9SAndroid Build Coastguard Worker                mod = importlib.import_module(modname)
500*da0073e9SAndroid Build Coastguard Worker            except Exception:
501*da0073e9SAndroid Build Coastguard Worker                # It is ok to ignore here as we have a test above that ensures
502*da0073e9SAndroid Build Coastguard Worker                # this should never happen
503*da0073e9SAndroid Build Coastguard Worker
504*da0073e9SAndroid Build Coastguard Worker                return
505*da0073e9SAndroid Build Coastguard Worker            if not self._is_mod_public(modname):
506*da0073e9SAndroid Build Coastguard Worker                return
507*da0073e9SAndroid Build Coastguard Worker            # verifies that each public API has the correct module name and naming semantics
508*da0073e9SAndroid Build Coastguard Worker
509*da0073e9SAndroid Build Coastguard Worker            def check_one_element(elem, modname, mod, *, is_public, is_all):
510*da0073e9SAndroid Build Coastguard Worker                obj = getattr(mod, elem)
511*da0073e9SAndroid Build Coastguard Worker
512*da0073e9SAndroid Build Coastguard Worker                # torch.dtype is not a class nor callable, so we need to check for it separately
513*da0073e9SAndroid Build Coastguard Worker                if not (
514*da0073e9SAndroid Build Coastguard Worker                    isinstance(obj, (Callable, torch.dtype)) or inspect.isclass(obj)
515*da0073e9SAndroid Build Coastguard Worker                ):
516*da0073e9SAndroid Build Coastguard Worker                    return
517*da0073e9SAndroid Build Coastguard Worker                elem_module = getattr(obj, "__module__", None)
518*da0073e9SAndroid Build Coastguard Worker
519*da0073e9SAndroid Build Coastguard Worker                # Only used for nice error message below
520*da0073e9SAndroid Build Coastguard Worker                why_not_looks_public = ""
521*da0073e9SAndroid Build Coastguard Worker                if elem_module is None:
522*da0073e9SAndroid Build Coastguard Worker                    why_not_looks_public = (
523*da0073e9SAndroid Build Coastguard Worker                        "because it does not have a `__module__` attribute"
524*da0073e9SAndroid Build Coastguard Worker                    )
525*da0073e9SAndroid Build Coastguard Worker
526*da0073e9SAndroid Build Coastguard Worker                # If a module is being migrated from foo.a to bar.a (that is entry {"foo": "bar"}),
527*da0073e9SAndroid Build Coastguard Worker                # the module's starting package would be referred to as the new location even
528*da0073e9SAndroid Build Coastguard Worker                # if there is a "from foo import a" inside the "bar.py".
529*da0073e9SAndroid Build Coastguard Worker                modname = allow_dict["being_migrated"].get(modname, modname)
530*da0073e9SAndroid Build Coastguard Worker                elem_modname_starts_with_mod = (
531*da0073e9SAndroid Build Coastguard Worker                    elem_module is not None
532*da0073e9SAndroid Build Coastguard Worker                    and elem_module.startswith(modname)
533*da0073e9SAndroid Build Coastguard Worker                    and "._" not in elem_module
534*da0073e9SAndroid Build Coastguard Worker                )
535*da0073e9SAndroid Build Coastguard Worker                if not why_not_looks_public and not elem_modname_starts_with_mod:
536*da0073e9SAndroid Build Coastguard Worker                    why_not_looks_public = (
537*da0073e9SAndroid Build Coastguard Worker                        f"because its `__module__` attribute (`{elem_module}`) is not within the "
538*da0073e9SAndroid Build Coastguard Worker                        f"torch library or does not start with the submodule where it is defined (`{modname}`)"
539*da0073e9SAndroid Build Coastguard Worker                    )
540*da0073e9SAndroid Build Coastguard Worker
541*da0073e9SAndroid Build Coastguard Worker                # elem's name must NOT begin with an `_` and it's module name
542*da0073e9SAndroid Build Coastguard Worker                # SHOULD start with it's current module since it's a public API
543*da0073e9SAndroid Build Coastguard Worker                looks_public = not elem.startswith("_") and elem_modname_starts_with_mod
544*da0073e9SAndroid Build Coastguard Worker                if not why_not_looks_public and not looks_public:
545*da0073e9SAndroid Build Coastguard Worker                    why_not_looks_public = f"because it starts with `_` (`{elem}`)"
546*da0073e9SAndroid Build Coastguard Worker                if is_public != looks_public:
547*da0073e9SAndroid Build Coastguard Worker                    if modname in allow_dict and elem in allow_dict[modname]:
548*da0073e9SAndroid Build Coastguard Worker                        return
549*da0073e9SAndroid Build Coastguard Worker                    if is_public:
550*da0073e9SAndroid Build Coastguard Worker                        why_is_public = (
551*da0073e9SAndroid Build Coastguard Worker                            f"it is inside the module's (`{modname}`) `__all__`"
552*da0073e9SAndroid Build Coastguard Worker                            if is_all
553*da0073e9SAndroid Build Coastguard Worker                            else "it is an attribute that does not start with `_` on a module that "
554*da0073e9SAndroid Build Coastguard Worker                            "does not have `__all__` defined"
555*da0073e9SAndroid Build Coastguard Worker                        )
556*da0073e9SAndroid Build Coastguard Worker                        fix_is_public = (
557*da0073e9SAndroid Build Coastguard Worker                            f"remove it from the modules's (`{modname}`) `__all__`"
558*da0073e9SAndroid Build Coastguard Worker                            if is_all
559*da0073e9SAndroid Build Coastguard Worker                            else f"either define a `__all__` for `{modname}` or add a `_` at the beginning of the name"
560*da0073e9SAndroid Build Coastguard Worker                        )
561*da0073e9SAndroid Build Coastguard Worker                    else:
562*da0073e9SAndroid Build Coastguard Worker                        assert is_all
563*da0073e9SAndroid Build Coastguard Worker                        why_is_public = (
564*da0073e9SAndroid Build Coastguard Worker                            f"it is not inside the module's (`{modname}`) `__all__`"
565*da0073e9SAndroid Build Coastguard Worker                        )
566*da0073e9SAndroid Build Coastguard Worker                        fix_is_public = (
567*da0073e9SAndroid Build Coastguard Worker                            f"add it from the modules's (`{modname}`) `__all__`"
568*da0073e9SAndroid Build Coastguard Worker                        )
569*da0073e9SAndroid Build Coastguard Worker                    if looks_public:
570*da0073e9SAndroid Build Coastguard Worker                        why_looks_public = (
571*da0073e9SAndroid Build Coastguard Worker                            "it does look public because it follows the rules from the doc above "
572*da0073e9SAndroid Build Coastguard Worker                            "(does not start with `_` and has a proper `__module__`)."
573*da0073e9SAndroid Build Coastguard Worker                        )
574*da0073e9SAndroid Build Coastguard Worker                        fix_looks_public = "make its name start with `_`"
575*da0073e9SAndroid Build Coastguard Worker                    else:
576*da0073e9SAndroid Build Coastguard Worker                        why_looks_public = why_not_looks_public
577*da0073e9SAndroid Build Coastguard Worker                        if not elem_modname_starts_with_mod:
578*da0073e9SAndroid Build Coastguard Worker                            fix_looks_public = (
579*da0073e9SAndroid Build Coastguard Worker                                "make sure the `__module__` is properly set and points to a submodule "
580*da0073e9SAndroid Build Coastguard Worker                                f"of `{modname}`"
581*da0073e9SAndroid Build Coastguard Worker                            )
582*da0073e9SAndroid Build Coastguard Worker                        else:
583*da0073e9SAndroid Build Coastguard Worker                            fix_looks_public = (
584*da0073e9SAndroid Build Coastguard Worker                                "remove the `_` at the beginning of the name"
585*da0073e9SAndroid Build Coastguard Worker                            )
586*da0073e9SAndroid Build Coastguard Worker                    failure_list.append(f"# {modname}.{elem}:")
587*da0073e9SAndroid Build Coastguard Worker                    is_public_str = "" if is_public else " NOT"
588*da0073e9SAndroid Build Coastguard Worker                    failure_list.append(
589*da0073e9SAndroid Build Coastguard Worker                        f"  - Is{is_public_str} public: {why_is_public}"
590*da0073e9SAndroid Build Coastguard Worker                    )
591*da0073e9SAndroid Build Coastguard Worker                    looks_public_str = "" if looks_public else " NOT"
592*da0073e9SAndroid Build Coastguard Worker                    failure_list.append(
593*da0073e9SAndroid Build Coastguard Worker                        f"  - Does{looks_public_str} look public: {why_looks_public}"
594*da0073e9SAndroid Build Coastguard Worker                    )
595*da0073e9SAndroid Build Coastguard Worker                    # Swap the str below to avoid having to create the NOT again
596*da0073e9SAndroid Build Coastguard Worker                    failure_list.append(
597*da0073e9SAndroid Build Coastguard Worker                        "  - You can do either of these two things to fix this problem:"
598*da0073e9SAndroid Build Coastguard Worker                    )
599*da0073e9SAndroid Build Coastguard Worker                    failure_list.append(
600*da0073e9SAndroid Build Coastguard Worker                        f"    - To make it{looks_public_str} public: {fix_is_public}"
601*da0073e9SAndroid Build Coastguard Worker                    )
602*da0073e9SAndroid Build Coastguard Worker                    failure_list.append(
603*da0073e9SAndroid Build Coastguard Worker                        f"    - To make it{is_public_str} look public: {fix_looks_public}"
604*da0073e9SAndroid Build Coastguard Worker                    )
605*da0073e9SAndroid Build Coastguard Worker
606*da0073e9SAndroid Build Coastguard Worker            if hasattr(mod, "__all__"):
607*da0073e9SAndroid Build Coastguard Worker                public_api = mod.__all__
608*da0073e9SAndroid Build Coastguard Worker                all_api = dir(mod)
609*da0073e9SAndroid Build Coastguard Worker                for elem in all_api:
610*da0073e9SAndroid Build Coastguard Worker                    check_one_element(
611*da0073e9SAndroid Build Coastguard Worker                        elem, modname, mod, is_public=elem in public_api, is_all=True
612*da0073e9SAndroid Build Coastguard Worker                    )
613*da0073e9SAndroid Build Coastguard Worker            else:
614*da0073e9SAndroid Build Coastguard Worker                all_api = dir(mod)
615*da0073e9SAndroid Build Coastguard Worker                for elem in all_api:
616*da0073e9SAndroid Build Coastguard Worker                    if not elem.startswith("_"):
617*da0073e9SAndroid Build Coastguard Worker                        check_one_element(
618*da0073e9SAndroid Build Coastguard Worker                            elem, modname, mod, is_public=True, is_all=False
619*da0073e9SAndroid Build Coastguard Worker                        )
620*da0073e9SAndroid Build Coastguard Worker
621*da0073e9SAndroid Build Coastguard Worker        for mod in pkgutil.walk_packages(torch.__path__, "torch."):
622*da0073e9SAndroid Build Coastguard Worker            modname = mod.name
623*da0073e9SAndroid Build Coastguard Worker            test_module(modname)
624*da0073e9SAndroid Build Coastguard Worker        test_module("torch")
625*da0073e9SAndroid Build Coastguard Worker
626*da0073e9SAndroid Build Coastguard Worker        msg = (
627*da0073e9SAndroid Build Coastguard Worker            "All the APIs below do not meet our guidelines for public API from "
628*da0073e9SAndroid Build Coastguard Worker            "https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation.\n"
629*da0073e9SAndroid Build Coastguard Worker        )
630*da0073e9SAndroid Build Coastguard Worker        msg += (
631*da0073e9SAndroid Build Coastguard Worker            "Make sure that everything that is public is expected (in particular that the module "
632*da0073e9SAndroid Build Coastguard Worker            "has a properly populated `__all__` attribute) and that everything that is supposed to be public "
633*da0073e9SAndroid Build Coastguard Worker            "does look public (it does not start with `_` and has a `__module__` that is properly populated)."
634*da0073e9SAndroid Build Coastguard Worker        )
635*da0073e9SAndroid Build Coastguard Worker
636*da0073e9SAndroid Build Coastguard Worker        msg += "\n\nFull list:\n"
637*da0073e9SAndroid Build Coastguard Worker        msg += "\n".join(map(str, failure_list))
638*da0073e9SAndroid Build Coastguard Worker
639*da0073e9SAndroid Build Coastguard Worker        # empty lists are considered false in python
640*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not failure_list, msg)
641*da0073e9SAndroid Build Coastguard Worker
642*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Only when executed directly (not when imported): hand off to PyTorch's test runner.
    run_tests()
645