# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
# pyre-ignore-all-errors[6]
# pyre-ignore-all-errors[16]
import unittest

from typing import List, Optional

import executorch.exir.schema as schema

import torch
from executorch.exir.tensor import (
    contiguous_stride_from_shape,
    dim_order_from_stride,
    make_allocation_info,
    make_tensor_value,
    num_bytes_from_shape_and_dtype,
    scalar_type_enum,
    stride_from_dim_order,
    TensorSpec,
)


class TestTensor(unittest.TestCase):
    """Tests for the tensor helpers imported from ``executorch.exir.tensor``."""

    def compare_tensors(
        self,
        torch_tensor: torch.Tensor,
        flatbuffer_tensor: schema.Tensor,
        dim_order: Optional[List[int]] = None,
    ) -> None:
        """Checks if the given normal torch tensor is equivalent to the
        flatbuffer tensor.

        Compares scalar type, sizes and ``requires_grad``, and checks that the
        flatbuffer tensor's storage offset is 0. When ``dim_order`` is given,
        the flatbuffer tensor's ``dim_order`` must equal it as well.
        """
        self.assertEqual(
            flatbuffer_tensor.scalar_type, scalar_type_enum(torch_tensor.dtype)
        )
        # The runtime currently only supports tensors with offset 0.
        self.assertEqual(flatbuffer_tensor.storage_offset, 0)
        self.assertEqual(flatbuffer_tensor.sizes, list(torch_tensor.size()))
        self.assertEqual(flatbuffer_tensor.requires_grad, torch_tensor.requires_grad)
        if dim_order is not None:
            self.assertEqual(flatbuffer_tensor.dim_order, dim_order)

    def test_normal_tensor_conversion(self) -> None:
        """Testing a normal tensor"""

        # One regular shape followed by shapes with a zero-sized dimension in
        # each possible position.
        shapes = (
            (2, 2, 3),
            (2, 2, 0),
            (2, 0, 3),
            (0, 2, 3),
        )
        for shape in shapes:
            with self.subTest(shape=shape):
                normal_tensor = torch.randn(*shape)
                flatbuffer_tensor = make_tensor_value(
                    1, 0, TensorSpec.from_tensor(normal_tensor)
                )
                self.compare_tensors(normal_tensor, flatbuffer_tensor)

        # Compare dim order
        normal_tensor = torch.rand((2, 2, 3, 4))
        flatbuffer_tensor = make_tensor_value(
            1, 0, TensorSpec.from_tensor(normal_tensor)
        )
        self.compare_tensors(normal_tensor, flatbuffer_tensor, dim_order=[0, 1, 2, 3])
        # cannot compare torch.memory_format = torch.channels_last because make_tensor_value
        # infers strides from sizes assuming tensor dimensions are laid out in memory
        # in the same order as indicated by dimension order of sizes array.
        # e.g. for sizes = (2, 3, 4, 5), it assumes dimension order is (0, 1, 2, 3) and
        # thus strides = (3*4*5, 4*5, 5, 1)
        # whereas strides for torch.memory_format = torch.channels_last is
        # (3*4*5, 1, 5*3, 3)

    def test_allocation_info_succeeds(self) -> None:
        """Valid offsets are split into the expected low/high 32-bit halves."""
        test_cases = (
            (
                {"mem_id": 0, "mem_offset": 0},
                schema.AllocationDetails(
                    memory_id=0, memory_offset_low=0, memory_offset_high=0
                ),
            ),
            (
                # Easily fits in 32 bits
                {"mem_id": 1, "mem_offset": 55555},
                schema.AllocationDetails(
                    memory_id=1, memory_offset_low=55555, memory_offset_high=0
                ),
            ),
            (
                # Just fits in 32 bits
                {"mem_id": 1, "mem_offset": (1 << 32) - 1},
                schema.AllocationDetails(
                    memory_id=1, memory_offset_low=0xFFFFFFFF, memory_offset_high=0
                ),
            ),
            (
                # Smallest 32-bit overflow.
                {"mem_id": 1, "mem_offset": 1 << 32},
                schema.AllocationDetails(
                    memory_id=1, memory_offset_low=0, memory_offset_high=1
                ),
            ),
            (
                # Easily fits in 64 bits.
                {"mem_id": 1, "mem_offset": (1 << 64) - 55555555},
                schema.AllocationDetails(
                    memory_id=1,
                    memory_offset_low=4239411741,
                    memory_offset_high=4294967295,
                ),
            ),
            (
                # Just fits in 64 bits
                {"mem_id": 1, "mem_offset": (1 << 64) - 1},
                schema.AllocationDetails(
                    memory_id=1,
                    memory_offset_low=0xFFFFFFFF,
                    memory_offset_high=0xFFFFFFFF,
                ),
            ),
        )
        for kwargs, expected in test_cases:
            # subTest reports every failing case, labelled with its inputs,
            # instead of stopping at the first one.
            with self.subTest(**kwargs):
                allocation_info = make_allocation_info(**kwargs)
                self.assertEqual(allocation_info, expected)

    def test_allocation_info_fails(self) -> None:
        """Negative and over-64-bit offsets raise with a descriptive message."""
        test_cases = (
            (
                # Significant negative underflow.
                {"mem_id": 0, "mem_offset": -55555},
                # Error message should complain about the negative value.
                "negative",
            ),
            (
                # Smallest negative underflow.
                {"mem_id": 0, "mem_offset": -1},
                # Error message should complain about the negative value.
                "negative",
            ),
            (
                # Smallest 64-bit overflow.
                {"mem_id": 1, "mem_offset": 1 << 64},
                # Error message should complain that the value is too large.
                "64 bits",
            ),
            (
                # Significant 64-bit overflow.
                {"mem_id": 1, "mem_offset": (1 << 64) + 55555},
                # Error message should complain that the value is too large.
                "64 bits",
            ),
        )
        for kwargs, expected_regex in test_cases:
            with self.subTest(**kwargs):
                with self.assertRaisesRegex(
                    Exception, expected_regex, msg=f"{kwargs}"
                ):
                    make_allocation_info(**kwargs)

    def test_contiguous_stride_from_shape(self) -> None:
        """A (2, 3, 4) shape yields strides (12, 4, 1)."""
        shape = (2, 3, 4)
        stride = contiguous_stride_from_shape(torch.Size(shape))
        self.assertEqual((12, 4, 1), stride)

    def test_dim_order_from_stride(self) -> None:
        """Checks the dim order derived from a variety of stride tuples."""
        # Each entry is (strides, expected_dim_order).
        test_cases = (
            # shape = (4)
            ((1,), (0,)),
            # Test contiguous, a.k.a NCHW format
            # shape = (2, 3, 4)
            ((3 * 4, 4, 1), (0, 1, 2)),
            # shape = (2, 3, 4, 5)
            ((3 * 4 * 5, 4 * 5, 5, 1), (0, 1, 2, 3)),
            # shape = (2, 3, 4, 5, 6)
            ((3 * 4 * 5 * 6, 4 * 5 * 6, 5 * 6, 6, 1), (0, 1, 2, 3, 4)),
            # Test channels last format
            # shape = (2, 3, 4)
            ((3 * 4, 1, 3), (0, 2, 1)),
            # shape = (2, 3, 4, 5)
            ((3 * 4 * 5, 1, 5 * 3, 3), (0, 2, 3, 1)),
            # shape = (2, 3, 4, 5, 6)
            ((3 * 4 * 5 * 6, 1, 5 * 6 * 3, 6 * 3, 3), (0, 2, 3, 4, 1)),
            # test ambiguous strides
            # shape = (1, 3, 3, 1)
            ((9, 3, 1, 1), (0, 1, 2, 3)),
            # shape = (1, 3, 1, 1)
            ((3, 1, 3, 3), (0, 2, 3, 1)),
            # shape = (1, 3, 1, 1)
            ((3, 1, 1, 1), (0, 1, 2, 3)),
            # shape = (1, 1, 1, 1)
            ((1, 1, 1, 1), (0, 1, 2, 3)),
        )
        for strides, expected_dim_order in test_cases:
            with self.subTest(strides=strides):
                self.assertEqual(expected_dim_order, dim_order_from_stride(strides))

        # test 0 in strides
        # dim[2] is broadcasting dim
        # shape = (5, 1, 15, 10)
        strides = (10, 10, 0, 1)
        with self.assertRaises(ValueError):
            dim_order_from_stride(strides)

    def test_strides_from_dim_order(self) -> None:
        """Checks the strides computed from sizes and a dim order."""
        # Each entry is (sizes, dim_order, expected_strides).
        test_cases = (
            # Empty sizes and dim order
            ([], [], []),
            ([4], [0], [1]),
            # Test contiguous, a.k.a NCHW format
            ([2, 3, 4], [0, 1, 2], [3 * 4, 4, 1]),
            ([2, 3, 4, 5], [0, 1, 2, 3], [3 * 4 * 5, 4 * 5, 5, 1]),
            (
                [2, 3, 4, 5, 6],
                [0, 1, 2, 3, 4],
                [3 * 4 * 5 * 6, 4 * 5 * 6, 5 * 6, 6, 1],
            ),
            # Test channels last format
            ([2, 3, 4], [0, 2, 1], [3 * 4, 1, 3]),
            ([2, 3, 4, 5], [0, 2, 3, 1], [3 * 4 * 5, 1, 5 * 3, 3]),
            (
                [2, 3, 4, 5, 6],
                [0, 2, 3, 4, 1],
                [3 * 4 * 5 * 6, 1, 5 * 6 * 3, 6 * 3, 3],
            ),
            # test ambiguous strides
            ([1, 3, 3, 1], [0, 1, 2, 3], [9, 3, 1, 1]),
            ([1, 3, 1, 1], [0, 2, 3, 1], [3, 1, 3, 3]),
            ([1, 3, 1, 1], [0, 1, 2, 3], [3, 1, 1, 1]),
            ([1, 1, 1, 1], [0, 1, 2, 3], [1, 1, 1, 1]),
        )
        for sizes, dim_order, expected_strides in test_cases:
            with self.subTest(sizes=sizes, dim_order=dim_order):
                strides = stride_from_dim_order(sizes, dim_order)
                self.assertEqual(expected_strides, strides)

    def test_num_bytes_from_shape_and_dtype(self) -> None:
        """Byte counts for a 24-element shape across several dtypes."""
        shape = (2, 3, 4)
        # Each entry is (expected_num_bytes, dtype).
        test_cases = (
            (24, torch.int8),
            (48, torch.half),
            (96, torch.float),
            (192, torch.float64),
        )
        for expected_num_bytes, dtype in test_cases:
            with self.subTest(dtype=dtype):
                self.assertEqual(
                    expected_num_bytes, num_bytes_from_shape_and_dtype(shape, dtype)
                )