xref: /aosp_15_r20/external/pytorch/test/package/test_analyze.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: package/deploy"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport torch
4*da0073e9SAndroid Build Coastguard Workerfrom torch.package import analyze
5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workertry:
9*da0073e9SAndroid Build Coastguard Worker    from .common import PackageTestCase
10*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
11*da0073e9SAndroid Build Coastguard Worker    # Support the case where we run this file directly.
12*da0073e9SAndroid Build Coastguard Worker    from common import PackageTestCase
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
class TestAnalyze(PackageTestCase):
    """Tests for the dependency-analysis helpers in ``torch.package.analyze``."""

    def test_trace_dependencies(self):
        """Check that ``analyze.trace_dependencies`` reports what a traced call touches.

        ``SumMod`` is run on a single sample input. The resulting collection
        must name ``test_trace_dep`` and must not name ``yaml``.
        """
        # Imported inside the test rather than at module scope; presumably a
        # helper module importable from the test directory (verify on move).
        import test_trace_dep

        module_under_test = test_trace_dep.SumMod()
        # One call's worth of positional arguments: a single 4-element tensor.
        sample_inputs = [(torch.randn(4),)]

        traced_modules = analyze.trace_dependencies(module_under_test, sample_inputs)

        # "yaml" must be absent from what the trace reports...
        self.assertNotIn("yaml", traced_modules)
        # ...while the module that defines SumMod must be present.
        self.assertIn("test_trace_dep", traced_modules)
28*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Executing this file directly (not via a test collector) hands control
    # to the shared PyTorch test runner imported at the top of the file.
    run_tests()
31