# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
# pyre-ignore-all-errors[6]
# pyre-ignore-all-errors[16]
import unittest

from typing import List, Optional

import executorch.exir.schema as schema

import torch
from executorch.exir.tensor import (
    contiguous_stride_from_shape,
    dim_order_from_stride,
    make_allocation_info,
    make_tensor_value,
    num_bytes_from_shape_and_dtype,
    scalar_type_enum,
    stride_from_dim_order,
    TensorSpec,
)


class TestTensor(unittest.TestCase):
    def compare_tensors(
        self,
        torch_tensor: torch.Tensor,
        flatbuffer_tensor: schema.Tensor,
        dim_order: Optional[List[int]] = None,
    ) -> None:
        """Checks that the given torch tensor is equivalent to the given
        flatbuffer tensor.
        """
        self.assertEqual(
            flatbuffer_tensor.scalar_type, scalar_type_enum(torch_tensor.dtype)
        )
        # The runtime currently only supports tensors with a storage offset of 0.
        self.assertEqual(flatbuffer_tensor.storage_offset, 0)
        self.assertEqual(flatbuffer_tensor.sizes, list(torch_tensor.size()))
        self.assertEqual(flatbuffer_tensor.requires_grad, torch_tensor.requires_grad)
        if dim_order is not None:
            self.assertEqual(flatbuffer_tensor.dim_order, dim_order)

    def test_normal_tensor_conversion(self) -> None:
        """Tests conversion of a normal tensor."""

        normal_tensor = torch.randn(2, 2, 3)
        flatbuffer_tensor = make_tensor_value(
            1, 0, TensorSpec.from_tensor(normal_tensor)
        )
        self.compare_tensors(normal_tensor, flatbuffer_tensor)

        # Test a zero-size tensor (zero in the last dimension)
        normal_tensor = torch.randn(2, 2, 0)
        flatbuffer_tensor = make_tensor_value(
            1, 0, TensorSpec.from_tensor(normal_tensor)
        )
        self.compare_tensors(normal_tensor, flatbuffer_tensor)

        # Test a zero-size tensor (zero in a middle dimension)
        normal_tensor = torch.randn(2, 0, 3)
        flatbuffer_tensor = make_tensor_value(
            1, 0, TensorSpec.from_tensor(normal_tensor)
        )
        self.compare_tensors(normal_tensor, flatbuffer_tensor)

        # Test a zero-size tensor (zero in the first dimension)
        normal_tensor = torch.randn(0, 2, 3)
        flatbuffer_tensor = make_tensor_value(
            1, 0, TensorSpec.from_tensor(normal_tensor)
        )
        self.compare_tensors(normal_tensor, flatbuffer_tensor)

        # Compare dim order
        normal_tensor = torch.rand((2, 2, 3, 4))
        flatbuffer_tensor = make_tensor_value(
            1, 0, TensorSpec.from_tensor(normal_tensor)
        )
        self.compare_tensors(normal_tensor, flatbuffer_tensor, dim_order=[0, 1, 2, 3])
        # We cannot compare with torch.memory_format = torch.channels_last
        # because make_tensor_value infers strides from sizes, assuming the
        # tensor's dimensions are laid out in memory in the same order as
        # they appear in the sizes array.
        # e.g. for sizes = (2, 3, 4, 5) it assumes dim order (0, 1, 2, 3) and
        # thus strides = (3*4*5, 4*5, 5, 1), whereas the strides for
        # torch.memory_format = torch.channels_last are (3*4*5, 1, 5*3, 3).
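        # Illustrative cross-check (an aside, relying only on standard
        # PyTorch stride semantics rather than anything about
        # make_tensor_value): a channels_last tensor of shape (2, 3, 4, 5)
        # reports exactly the strides described above.
        channels_last_tensor = torch.rand(2, 3, 4, 5).to(
            memory_format=torch.channels_last
        )
        self.assertEqual((3 * 4 * 5, 1, 5 * 3, 3), channels_last_tensor.stride())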

    def test_allocation_info_succeeds(self) -> None:
        test_cases = (
            (
                {"mem_id": 0, "mem_offset": 0},
                schema.AllocationDetails(
                    memory_id=0, memory_offset_low=0, memory_offset_high=0
                ),
            ),
            (
                # Easily fits in 32 bits.
                {"mem_id": 1, "mem_offset": 55555},
                schema.AllocationDetails(
                    memory_id=1, memory_offset_low=55555, memory_offset_high=0
                ),
            ),
            (
                # Just fits in 32 bits.
                {"mem_id": 1, "mem_offset": (1 << 32) - 1},
                schema.AllocationDetails(
                    memory_id=1, memory_offset_low=0xFFFFFFFF, memory_offset_high=0
                ),
            ),
            (
                # Smallest 32-bit overflow.
                {"mem_id": 1, "mem_offset": 1 << 32},
                schema.AllocationDetails(
                    memory_id=1, memory_offset_low=0, memory_offset_high=1
                ),
            ),
            (
                # Easily fits in 64 bits.
                {"mem_id": 1, "mem_offset": (1 << 64) - 55555555},
                schema.AllocationDetails(
                    memory_id=1,
                    memory_offset_low=4239411741,
                    memory_offset_high=4294967295,
                ),
            ),
            (
                # Just fits in 64 bits.
                {"mem_id": 1, "mem_offset": (1 << 64) - 1},
                schema.AllocationDetails(
                    memory_id=1,
                    memory_offset_low=0xFFFFFFFF,
                    memory_offset_high=0xFFFFFFFF,
                ),
            ),
        )
        for test_case in test_cases:
            allocation_info = make_allocation_info(**(test_case[0]))
            self.assertEqual(allocation_info, test_case[1])
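
        # Cross-check (illustrative): the expected low/high words above are
        # the conventional 32-bit split of a 64-bit offset.
        for kwargs, expected in test_cases:
            self.assertEqual(
                expected.memory_offset_low, kwargs["mem_offset"] & 0xFFFFFFFF
            )
            self.assertEqual(expected.memory_offset_high, kwargs["mem_offset"] >> 32)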

    def test_allocation_info_fails(self) -> None:
        test_cases = (
            (
                # Significant negative underflow.
                {"mem_id": 0, "mem_offset": -55555},
                # Error message should complain about the negative value.
                "negative",
            ),
            (
                # Smallest negative underflow.
                {"mem_id": 0, "mem_offset": -1},
                # Error message should complain about the negative value.
                "negative",
            ),
            (
                # Smallest 64-bit overflow.
                {"mem_id": 1, "mem_offset": 1 << 64},
                # Error message should complain that the value is too large.
                "64 bits",
            ),
            (
                # Significant 64-bit overflow.
                {"mem_id": 1, "mem_offset": (1 << 64) + 55555},
                # Error message should complain that the value is too large.
                "64 bits",
            ),
        )
        for test_case in test_cases:
            kwargs = test_case[0]
            with self.assertRaisesRegex(Exception, test_case[1], msg=f"{kwargs}"):
                make_allocation_info(**kwargs)

    def test_contiguous_stride_from_shape(self) -> None:
        shape = (2, 3, 4)
        stride = contiguous_stride_from_shape(torch.Size(shape))
        self.assertEqual((12, 4, 1), stride)
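
        # Cross-check (illustrative): matches the strides torch reports for
        # a freshly allocated contiguous tensor of the same shape.
        self.assertEqual(torch.randn(*shape).stride(), stride)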

    def test_dim_order_from_stride(self) -> None:
        # shape = (4,)
        strides = (1,)
        dim_order = dim_order_from_stride(strides)
        self.assertEqual((0,), dim_order)

        # Test contiguous, a.k.a. NCHW format
        # shape = (2, 3, 4)
        strides = (3 * 4, 4, 1)
        dim_order = dim_order_from_stride(strides)
        self.assertEqual((0, 1, 2), dim_order)

        # shape = (2, 3, 4, 5)
        strides = (3 * 4 * 5, 4 * 5, 5, 1)
        dim_order = dim_order_from_stride(strides)
        self.assertEqual((0, 1, 2, 3), dim_order)

        # shape = (2, 3, 4, 5, 6)
        strides = (3 * 4 * 5 * 6, 4 * 5 * 6, 5 * 6, 6, 1)
        dim_order = dim_order_from_stride(strides)
        self.assertEqual((0, 1, 2, 3, 4), dim_order)

        # Test channels-last format
        # shape = (2, 3, 4)
        strides = (3 * 4, 1, 3)
        dim_order = dim_order_from_stride(strides)
        self.assertEqual((0, 2, 1), dim_order)

        # shape = (2, 3, 4, 5)
        strides = (3 * 4 * 5, 1, 5 * 3, 3)
        dim_order = dim_order_from_stride(strides)
        self.assertEqual((0, 2, 3, 1), dim_order)

        # shape = (2, 3, 4, 5, 6)
        strides = (3 * 4 * 5 * 6, 1, 5 * 6 * 3, 6 * 3, 3)
        dim_order = dim_order_from_stride(strides)
        self.assertEqual((0, 2, 3, 4, 1), dim_order)
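
        # Cross-check (illustrative, relying only on standard PyTorch
        # channels_last strides): a real channels_last tensor of shape
        # (2, 3, 4, 5) yields the dim order asserted above.
        channels_last_tensor = torch.empty(2, 3, 4, 5).to(
            memory_format=torch.channels_last
        )
        self.assertEqual(
            (0, 2, 3, 1), dim_order_from_stride(channels_last_tensor.stride())
        )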

        # test ambiguous strides
        # shape = (1, 3, 3, 1)
        strides = (9, 3, 1, 1)
        dim_order = dim_order_from_stride(strides)
        self.assertEqual((0, 1, 2, 3), dim_order)

        # test ambiguous strides
        # shape = (1, 3, 1, 1)
        strides = (3, 1, 3, 3)
        dim_order = dim_order_from_stride(strides)
        self.assertEqual((0, 2, 3, 1), dim_order)

        # test ambiguous strides
        # shape = (1, 3, 1, 1)
        strides = (3, 1, 1, 1)
        dim_order = dim_order_from_stride(strides)
        self.assertEqual((0, 1, 2, 3), dim_order)

        # test ambiguous strides
        # shape = (1, 1, 1, 1)
        strides = (1, 1, 1, 1)
        dim_order = dim_order_from_stride(strides)
        self.assertEqual((0, 1, 2, 3), dim_order)
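
        # A reference model for the tie-breaking above (an assumption about
        # the observed behavior, not necessarily the algorithm inside
        # dim_order_from_stride): a stable sort of dims by decreasing stride.
        strides = (3, 1, 3, 3)
        model_order = tuple(sorted(range(len(strides)), key=lambda d: -strides[d]))
        self.assertEqual((0, 2, 3, 1), model_order)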

        # Test a 0 in strides, which arises when a dim is broadcast
        # (expanded); here dim[2] is the broadcast dim.
        # shape = (5, 1, 15, 10)
        strides = (10, 10, 0, 1)
        with self.assertRaises(ValueError):
            dim_order = dim_order_from_stride(strides)

    def test_strides_from_dim_order(self) -> None:
        sizes = []
        dim_order = []
        strides = stride_from_dim_order(sizes, dim_order)
        self.assertEqual([], strides)

        sizes = [4]
        dim_order = [0]
        expected_strides = [1]
        strides = stride_from_dim_order(sizes, dim_order)
        self.assertEqual(expected_strides, strides)

        # Test contiguous, a.k.a. NCHW format
        sizes = [2, 3, 4]
        dim_order = [0, 1, 2]
        expected_strides = [3 * 4, 4, 1]
        strides = stride_from_dim_order(sizes, dim_order)
        self.assertEqual(expected_strides, strides)

        sizes = [2, 3, 4, 5]
        dim_order = [0, 1, 2, 3]
        expected_strides = [3 * 4 * 5, 4 * 5, 5, 1]
        strides = stride_from_dim_order(sizes, dim_order)
        self.assertEqual(expected_strides, strides)

        sizes = [2, 3, 4, 5, 6]
        dim_order = [0, 1, 2, 3, 4]
        expected_strides = [3 * 4 * 5 * 6, 4 * 5 * 6, 5 * 6, 6, 1]
        strides = stride_from_dim_order(sizes, dim_order)
        self.assertEqual(expected_strides, strides)

        # Test channels-last format
        sizes = [2, 3, 4]
        dim_order = [0, 2, 1]
        expected_strides = [3 * 4, 1, 3]
        strides = stride_from_dim_order(sizes, dim_order)
        self.assertEqual(expected_strides, strides)

        sizes = [2, 3, 4, 5]
        dim_order = [0, 2, 3, 1]
        expected_strides = [3 * 4 * 5, 1, 5 * 3, 3]
        strides = stride_from_dim_order(sizes, dim_order)
        self.assertEqual(expected_strides, strides)

        sizes = [2, 3, 4, 5, 6]
        dim_order = [0, 2, 3, 4, 1]
        expected_strides = [3 * 4 * 5 * 6, 1, 5 * 6 * 3, 6 * 3, 3]
        strides = stride_from_dim_order(sizes, dim_order)
        self.assertEqual(expected_strides, strides)

        # test ambiguous strides
        sizes = [1, 3, 3, 1]
        dim_order = [0, 1, 2, 3]
        expected_strides = [9, 3, 1, 1]
        strides = stride_from_dim_order(sizes, dim_order)
        self.assertEqual(expected_strides, strides)

        # test ambiguous strides
        sizes = [1, 3, 1, 1]
        dim_order = [0, 2, 3, 1]
        expected_strides = [3, 1, 3, 3]
        strides = stride_from_dim_order(sizes, dim_order)
        self.assertEqual(expected_strides, strides)

        # test ambiguous strides
        sizes = [1, 3, 1, 1]
        dim_order = [0, 1, 2, 3]
        expected_strides = [3, 1, 1, 1]
        strides = stride_from_dim_order(sizes, dim_order)
        self.assertEqual(expected_strides, strides)

        # test ambiguous strides
        sizes = [1, 1, 1, 1]
        dim_order = [0, 1, 2, 3]
        expected_strides = [1, 1, 1, 1]
        strides = stride_from_dim_order(sizes, dim_order)
        self.assertEqual(expected_strides, strides)
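
        # Round-trip sketch (an assumption; nothing above guarantees the two
        # helpers are exact inverses): for an unambiguous shape,
        # dim_order_from_stride recovers the dim order we started from.
        sizes = [2, 3, 4, 5]
        dim_order = [0, 2, 3, 1]
        strides = stride_from_dim_order(sizes, dim_order)
        self.assertEqual(tuple(dim_order), dim_order_from_stride(tuple(strides)))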

    def test_num_bytes_from_shape_and_dtype(self) -> None:
        shape = (2, 3, 4)
        self.assertEqual(24, num_bytes_from_shape_and_dtype(shape, torch.int8))
        self.assertEqual(48, num_bytes_from_shape_and_dtype(shape, torch.half))
        self.assertEqual(96, num_bytes_from_shape_and_dtype(shape, torch.float))
        self.assertEqual(192, num_bytes_from_shape_and_dtype(shape, torch.float64))
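
        # Cross-check (illustrative): each expected byte count is numel
        # (2 * 3 * 4 = 24) times the element size torch itself reports for
        # that dtype.
        for dtype in (torch.int8, torch.half, torch.float, torch.float64):
            element_size = torch.empty((), dtype=dtype).element_size()
            self.assertEqual(
                24 * element_size, num_bytes_from_shape_and_dtype(shape, dtype)
            )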
342