xref: /aosp_15_r20/external/pytorch/test/test_license.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport glob
4*da0073e9SAndroid Build Coastguard Workerimport io
5*da0073e9SAndroid Build Coastguard Workerimport os
6*da0073e9SAndroid Build Coastguard Workerimport unittest
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport torch
9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workertry:
13*da0073e9SAndroid Build Coastguard Worker    from third_party.build_bundled import create_bundled
14*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
15*da0073e9SAndroid Build Coastguard Worker    create_bundled = None
16*da0073e9SAndroid Build Coastguard Worker
# Checked-in file listing the licenses of bundled third-party code. The path is
# relative to the current working directory.
license_file = "third_party/LICENSES_BUNDLED.txt"
# Opening phrase that an installed distribution's LICENSE file must contain.
starting_txt = "The PyTorch repository and source distributions bundle"

# Parent directory of the installed ``torch`` package (two path components up
# from ``torch.__file__``).
site_packages = os.path.split(os.path.split(torch.__file__)[0])[0]
# All ``torch-*dist-info`` directories found beside the package; an empty list
# when there are none.
distinfo = glob.glob(os.path.join(site_packages, "torch-*dist-info"))
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker
class TestLicense(TestCase):
    """Check that the bundled third-party license text is present and up to date."""

    @unittest.skipIf(not create_bundled, "can only be run in a source tree")
    def test_license_for_wheel(self):
        """Regenerate the bundled-license text and compare it with the checked-in file.

        ``create_bundled`` writes the expected contents of ``license_file`` into an
        in-memory buffer. The test fails if the file on disk differs from it.

        Raises:
            AssertionError: if ``license_file`` is out of date.
        """
        current = io.StringIO()
        create_bundled("third_party", current)
        # Decode explicitly as UTF-8. License texts may contain non-ASCII
        # characters, and the platform default encoding is not always UTF-8
        # (e.g. on Windows).
        with open(license_file, encoding="utf-8") as fid:
            src_tree = fid.read()
        if src_tree != current.getvalue():
            raise AssertionError(
                f'the contents of "{license_file}" do not '
                "match the current state of the third_party files. Use "
                '"python third_party/build_bundled.py" to regenerate it'
            )

    @unittest.skipIf(len(distinfo) == 0, "no installation in site-package to test")
    def test_distinfo_license(self):
        """Check the LICENSE shipped in an installed distribution.

        When PyTorch is installed from a wheel, the license is found at
        ``site-packages/torch-*dist-info/LICENSE``. Make sure it contains the
        third-party bundle of licenses (identified by ``starting_txt``).

        Raises:
            AssertionError: if more than one ``torch-*dist-info`` directory
                exists, or if the LICENSE lacks the bundled-license text.
        """
        # With several dist-info directories it is ambiguous which one
        # belongs to the torch being imported, so refuse to guess.
        if len(distinfo) > 1:
            raise AssertionError(
                'Found too many "torch-*dist-info" directories '
                f'in "{site_packages}", expected only one'
            )
        with open(os.path.join(distinfo[0], "LICENSE"), encoding="utf-8") as fid:
            txt = fid.read()
        self.assertIn(starting_txt, txt)
51*da0073e9SAndroid Build Coastguard Worker
52*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Run this module's tests through run_tests from
    # torch.testing._internal.common_utils.
    run_tests()
55