# Owner(s): ["oncall: package/deploy"]

import torch
from torch.package import analyze
from torch.testing._internal.common_utils import run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


class TestAnalyze(PackageTestCase):
    """Tests for the dependency-analysis API in ``torch.package.analyze``."""

    def test_trace_dependencies(self):
        """Tracing ``test_trace_dep.SumMod`` reports ``test_trace_dep`` but not ``yaml``."""
        # Imported inside the test body so the helper module is only loaded
        # when this test actually runs.
        import test_trace_dep

        # Build the module first, then its example inputs (one call, a single
        # positional tensor of length 4).
        module = test_trace_dep.SumMod()
        example_inputs = [(torch.randn(4),)]
        traced = analyze.trace_dependencies(module, example_inputs)

        self.assertNotIn("yaml", traced)
        self.assertIn("test_trace_dep", traced)


if __name__ == "__main__":
    run_tests()