xref: /aosp_15_r20/external/executorch/exir/_serialize/test/test_cord.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7
8import io
9import unittest
10
11from executorch.exir._serialize._cord import Cord
12
13
class TestCord(unittest.TestCase):
    """Unit tests for ``Cord`` construction, appending, and file output."""

    def test_cord_init(self) -> None:
        """Construct Cords empty, from bytes, and from another Cord."""
        cord_empty = Cord()
        self.assertEqual(0, len(cord_empty))
        self.assertEqual(b"", bytes(cord_empty))

        cord = Cord(b"HelloWorld")
        self.assertEqual(10, len(cord))
        self.assertEqual(b"HelloWorld", bytes(cord))

        cord2 = Cord(cord)
        self.assertEqual(10, len(cord2))
        # Check the *new* Cord's contents (previously this re-checked `cord`).
        self.assertEqual(b"HelloWorld", bytes(cord2))

        # Confirm no copies were made: both Cords reference the same buffer.
        self.assertIs(cord._buffers[0], cord2._buffers[0])

    def test_cord_append(self) -> None:
        """Appending bytes grows the length and concatenates the contents."""
        cord = Cord()
        cord.append(b"Hello")
        self.assertEqual(5, len(cord))
        self.assertEqual(b"Hello", bytes(cord))

        cord.append(b"World")
        self.assertEqual(10, len(cord))
        self.assertEqual(b"HelloWorld", bytes(cord))

    def test_cord_append_cord(self) -> None:
        """Appending a Cord concatenates its contents without copying buffers."""
        cord = Cord()
        cord.append(b"Hello")
        cord.append(b"World")

        cord2 = Cord()
        cord2.append(b"Prefix")
        cord2.append(cord)

        self.assertEqual(16, len(cord2))
        self.assertEqual(b"PrefixHelloWorld", bytes(cord2))

        # The appended Cord itself must be left unchanged.
        self.assertEqual(10, len(cord))
        self.assertEqual(b"HelloWorld", bytes(cord))

        # Confirm that no copies were made when appending a Cord: the
        # appended buffers are the very same objects held by `cord`.
        self.assertIs(cord2._buffers[1], cord._buffers[0])
        self.assertIs(cord2._buffers[2], cord._buffers[1])

    def test_cord_write_to_file(self) -> None:
        """write_to_file emits the concatenated contents to a binary stream."""
        cord = Cord()
        cord.append(b"Hello")
        cord.append(b"World")

        outfile = io.BytesIO()
        cord.write_to_file(outfile)
        self.assertEqual(b"HelloWorld", outfile.getvalue())
64