xref: /aosp_15_r20/external/pytorch/test/test_license.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: unknown"]
2
3import glob
4import io
5import os
6import unittest
7
8import torch
9from torch.testing._internal.common_utils import run_tests, TestCase
10
11
# The bundling helper lives in ``third_party/build_bundled.py``; when that
# module cannot be imported the name is bound to ``None`` instead, which the
# source-tree test uses as its skip condition.
try:
    from third_party.build_bundled import create_bundled
except ImportError:
    create_bundled = None

# Text that must appear somewhere in the LICENSE of an installed wheel.
starting_txt = "The PyTorch repository and source distributions bundle"
# Checked-in license bundle; the path is relative to the working directory.
license_file = "third_party/LICENSES_BUNDLED.txt"

# Walk up two levels from torch/__init__.py: first to the ``torch`` package
# directory, then to the directory that contains it.
_torch_pkg_dir = os.path.dirname(torch.__file__)
site_packages = os.path.dirname(_torch_pkg_dir)
# Every ``torch-*dist-info`` entry found next to the package (may be empty).
distinfo = glob.glob(os.path.join(site_packages, "torch-*dist-info"))
21
22
class TestLicense(TestCase):
    """Checks on the bundled third-party license text.

    One test compares the checked-in bundle against a freshly generated one;
    the other inspects the LICENSE file of an installed distribution.
    """

    @unittest.skipIf(not create_bundled, "can only be run in a source tree")
    def test_license_for_wheel(self):
        """Verify that ``license_file`` matches what ``create_bundled`` emits.

        Raises:
            AssertionError: if the file on disk differs from the text
                generated for the ``third_party`` directory.
        """
        current = io.StringIO()
        create_bundled("third_party", current)
        with open(license_file) as fid:
            src_tree = fid.read()
        # The message tells the developer how to regenerate the file.
        if src_tree != current.getvalue():
            raise AssertionError(
                f'the contents of "{license_file}" do not '
                "match the current state of the third_party files. Use "
                '"python third_party/build_bundled.py" to regenerate it'
            )

    @unittest.skipIf(not distinfo, "no installation in site-package to test")
    def test_distinfo_license(self):
        """If run when pytorch is installed via a wheel, the license will be in
        site-package/torch-*dist-info/LICENSE. Make sure it contains the third
        party bundle of licenses.

        Raises:
            AssertionError: if more than one ``torch-*dist-info`` directory
                was found, or if the LICENSE lacks ``starting_txt``.
        """
        # More than one match makes it ambiguous which LICENSE to inspect.
        if len(distinfo) > 1:
            raise AssertionError(
                'Found too many "torch-*dist-info" directories '
                f'in "{site_packages}", expected only one'
            )
        license_path = os.path.join(distinfo[0], "LICENSE")
        with open(license_path) as fid:
            txt = fid.read()
        # Checked after the file is closed; the explicit message replaces the
        # bare "False is not true" without printing the whole license text.
        self.assertTrue(
            starting_txt in txt,
            f'"{license_path}" does not contain the bundled license header '
            f'"{starting_txt}"',
        )
51
52
if __name__ == "__main__":
    # Executed as a script: hand control to the ``run_tests`` entry point
    # imported from torch's common test utilities.
    run_tests()
55