1# Owner(s): ["module: unknown"] 2 3import glob 4import io 5import os 6import unittest 7 8import torch 9from torch.testing._internal.common_utils import run_tests, TestCase 10 11 12try: 13 from third_party.build_bundled import create_bundled 14except ImportError: 15 create_bundled = None 16 17license_file = "third_party/LICENSES_BUNDLED.txt" 18starting_txt = "The PyTorch repository and source distributions bundle" 19site_packages = os.path.dirname(os.path.dirname(torch.__file__)) 20distinfo = glob.glob(os.path.join(site_packages, "torch-*dist-info")) 21 22 23class TestLicense(TestCase): 24 @unittest.skipIf(not create_bundled, "can only be run in a source tree") 25 def test_license_for_wheel(self): 26 current = io.StringIO() 27 create_bundled("third_party", current) 28 with open(license_file) as fid: 29 src_tree = fid.read() 30 if not src_tree == current.getvalue(): 31 raise AssertionError( 32 f'the contents of "{license_file}" do not ' 33 "match the current state of the third_party files. Use " 34 '"python third_party/build_bundled.py" to regenerate it' 35 ) 36 37 @unittest.skipIf(len(distinfo) == 0, "no installation in site-package to test") 38 def test_distinfo_license(self): 39 """If run when pytorch is installed via a wheel, the license will be in 40 site-package/torch-*dist-info/LICENSE. Make sure it contains the third 41 party bundle of licenses""" 42 43 if len(distinfo) > 1: 44 raise AssertionError( 45 'Found too many "torch-*dist-info" directories ' 46 f'in "{site_packages}, expected only one' 47 ) 48 with open(os.path.join(os.path.join(distinfo[0], "LICENSE"))) as fid: 49 txt = fid.read() 50 self.assertTrue(starting_txt in txt) 51 52 53if __name__ == "__main__": 54 run_tests() 55