1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5 6import torch 7 8 9# Make the helper files in test/ importable 10pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 11sys.path.append(pytorch_test_dir) 12from torch.testing._internal.jit_utils import JitTestCase 13 14 15if __name__ == "__main__": 16 raise RuntimeError( 17 "This test file is not meant to be run directly, use:\n\n" 18 "\tpython test/test_jit.py TESTNAME\n\n" 19 "instead." 20 ) 21 22 23class TestTensorCreationOps(JitTestCase): 24 """ 25 A suite of tests for ops that create tensors. 26 """ 27 28 def test_randperm_default_dtype(self): 29 def randperm(x: int): 30 perm = torch.randperm(x) 31 # Have to perform assertion here because TorchScript returns dtypes 32 # as integers, which are not comparable against eager torch.dtype. 33 assert perm.dtype == torch.int64 34 35 self.checkScript(randperm, (3,)) 36 37 def test_randperm_specifed_dtype(self): 38 def randperm(x: int): 39 perm = torch.randperm(x, dtype=torch.float) 40 # Have to perform assertion here because TorchScript returns dtypes 41 # as integers, which are not comparable against eager torch.dtype. 42 assert perm.dtype == torch.float 43 44 self.checkScript(randperm, (3,)) 45 46 def test_triu_indices_default_dtype(self): 47 def triu_indices(rows: int, cols: int): 48 indices = torch.triu_indices(rows, cols) 49 # Have to perform assertion here because TorchScript returns dtypes 50 # as integers, which are not comparable against eager torch.dtype. 51 assert indices.dtype == torch.int64 52 53 self.checkScript(triu_indices, (3, 3)) 54 55 def test_triu_indices_specified_dtype(self): 56 def triu_indices(rows: int, cols: int): 57 indices = torch.triu_indices(rows, cols, dtype=torch.int32) 58 # Have to perform assertion here because TorchScript returns dtypes 59 # as integers, which are not comparable against eager torch.dtype. 60 assert indices.dtype == torch.int32 61 62 self.checkScript(triu_indices, (3, 3)) 63 64 def test_tril_indices_default_dtype(self): 65 def tril_indices(rows: int, cols: int): 66 indices = torch.tril_indices(rows, cols) 67 # Have to perform assertion here because TorchScript returns dtypes 68 # as integers, which are not comparable against eager torch.dtype. 69 assert indices.dtype == torch.int64 70 71 self.checkScript(tril_indices, (3, 3)) 72 73 def test_tril_indices_specified_dtype(self): 74 def tril_indices(rows: int, cols: int): 75 indices = torch.tril_indices(rows, cols, dtype=torch.int32) 76 # Have to perform assertion here because TorchScript returns dtypes 77 # as integers, which are not comparable against eager torch.dtype. 78 assert indices.dtype == torch.int32 79 80 self.checkScript(tril_indices, (3, 3)) 81