# Owner(s): ["oncall: package/deploy"]

from io import BytesIO
from sys import version_info
from textwrap import dedent
from unittest import skipIf

from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_utils import run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


@skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
class TestResources(PackageTestCase):
    """Tests for access APIs for packaged resources."""

    def test_resource_reader(self):
        """Test compliance with the get_resource_reader importlib API."""
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            # Layout looks like:
            #    package
            #    |-- one/
            #    |   |-- a.txt
            #    |   |-- b.txt
            #    |   |-- c.txt
            #    |   +-- three/
            #    |       |-- d.txt
            #    |       +-- e.txt
            #    +-- two/
            #       |-- f.txt
            #       +-- g.txt
            pe.save_text("one", "a.txt", "hello, a!")
            pe.save_text("one", "b.txt", "hello, b!")
            pe.save_text("one", "c.txt", "hello, c!")

            pe.save_text("one.three", "d.txt", "hello, d!")
            pe.save_text("one.three", "e.txt", "hello, e!")

            pe.save_text("two", "f.txt", "hello, f!")
            pe.save_text("two", "g.txt", "hello, g!")

        buffer.seek(0)
        importer = PackageImporter(buffer)

        reader_one = importer.get_resource_reader("one")
        # The test expects no filesystem path to be available for a packaged
        # resource.
        with self.assertRaises(FileNotFoundError):
            reader_one.resource_path("a.txt")

        self.assertTrue(reader_one.is_resource("a.txt"))
        # open_resource is only guaranteed to return a binary file-like object,
        # so read it through the stream API and close it when done.
        with reader_one.open_resource("a.txt") as resource:
            self.assertEqual(resource.read(), b"hello, a!")
        # `three` is a subpackage (directory), not a resource file.
        self.assertFalse(reader_one.is_resource("three"))
        reader_one_contents = list(reader_one.contents())
        self.assertSequenceEqual(
            reader_one_contents, ["a.txt", "b.txt", "c.txt", "three"]
        )

        reader_two = importer.get_resource_reader("two")
        self.assertTrue(reader_two.is_resource("f.txt"))
        with reader_two.open_resource("f.txt") as resource:
            self.assertEqual(resource.read(), b"hello, f!")
        reader_two_contents = list(reader_two.contents())
        self.assertSequenceEqual(reader_two_contents, ["f.txt", "g.txt"])

        reader_one_three = importer.get_resource_reader("one.three")
        self.assertTrue(reader_one_three.is_resource("d.txt"))
        with reader_one_three.open_resource("d.txt") as resource:
            self.assertEqual(resource.read(), b"hello, d!")
        reader_one_three_contents = list(reader_one_three.contents())
        self.assertSequenceEqual(reader_one_three_contents, ["d.txt", "e.txt"])

        self.assertIsNone(importer.get_resource_reader("nonexistent_package"))

    def test_package_resource_access(self):
        """Packaged modules should be able to use the importlib.resources API to access
        resources saved in the package.
        """
        mod_src = dedent(
            """\
            import importlib.resources
            import my_cool_resources

            def secret_message():
                return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
            """
        )
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_source_string("foo.bar", mod_src)
            pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")

        buffer.seek(0)
        importer = PackageImporter(buffer)
        self.assertEqual(
            importer.import_module("foo.bar").secret_message(), "my sekrit plays"
        )

    def test_importer_access(self):
        """Packaged code can load saved text and binary resources through the
        `torch_package_importer` module.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_text("main", "main", "my string")
            pe.save_binary("main", "main_binary", b"my string")
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                t = resources.load_text('main', 'main')
                b = resources.load_binary('main', 'main_binary')
                """
            )
            pe.save_source_string("main", src, is_package=True)
        buffer.seek(0)
        importer = PackageImporter(buffer)
        m = importer.import_module("main")
        self.assertEqual(m.t, "my string")
        self.assertEqual(m.b, b"my string")

    def test_resource_access_by_path(self):
        """
        Tests that packaged code can use importlib.resources.path.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_binary("string_module", "my_string", b"my string")
            src = dedent(
                """\
                import importlib.resources
                import string_module

                with importlib.resources.path(string_module, 'my_string') as path:
                    with open(path, mode='r', encoding='utf-8') as f:
                        s = f.read()
                """
            )
            pe.save_source_string("main", src, is_package=True)
        buffer.seek(0)
        importer = PackageImporter(buffer)
        m = importer.import_module("main")
        self.assertEqual(m.s, "my string")


if __name__ == "__main__":
    run_tests()