xref: /aosp_15_r20/external/pytorch/test/package/test_analyze.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: package/deploy"]
2
3import torch
4from torch.package import analyze
5from torch.testing._internal.common_utils import run_tests
6
7
8try:
9    from .common import PackageTestCase
10except ImportError:
11    # Support the case where we run this file directly.
12    from common import PackageTestCase
13
14
class TestAnalyze(PackageTestCase):
    """Tests for the ``torch.package.analyze`` dependency-analysis API."""

    def test_trace_dependencies(self):
        """Tracing ``SumMod`` reports the modules it actually uses.

        ``analyze.trace_dependencies`` is given the module instance and a
        list of example argument tuples (here a single call with one
        4-element random tensor). The reported modules must include
        ``test_trace_dep`` (where ``SumMod`` lives) and must not include
        ``yaml``.
        """
        # NOTE(review): presumably ``test_trace_dep`` imports ``yaml`` without
        # using it when running the module, which is why ``yaml`` is expected
        # to be absent from the traced set — confirm against test_trace_dep.
        import test_trace_dep

        # Build the module before drawing random inputs so the RNG call order
        # matches the original test exactly.
        model = test_trace_dep.SumMod()
        sample_calls = [(torch.randn(4),)]

        traced_modules = analyze.trace_dependencies(model, sample_calls)

        self.assertNotIn("yaml", traced_modules)
        self.assertIn("test_trace_dep", traced_modules)
27
28
if __name__ == "__main__":
    # Allow running this file directly (e.g. ``python test_analyze.py``);
    # hands off to the shared ``run_tests`` runner imported from
    # torch.testing._internal.common_utils.
    run_tests()
31