# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import io
import unittest

from executorch.exir._serialize._cord import Cord


class TestCord(unittest.TestCase):
    """Tests for ``Cord`` construction, appending, and writing to a file."""

    def test_cord_init(self) -> None:
        """Construct empty, from bytes, and from another Cord."""
        cord_empty = Cord()
        self.assertEqual(0, len(cord_empty))

        cord = Cord(b"HelloWorld")
        self.assertEqual(10, len(cord))
        self.assertEqual(b"HelloWorld", bytes(cord))

        cord2 = Cord(cord)
        self.assertEqual(10, len(cord2))
        # Check the *new* cord's contents; the source cord was verified above.
        self.assertEqual(b"HelloWorld", bytes(cord2))

        # Confirm no copies were made: the underlying buffer object is shared.
        self.assertIs(cord._buffers[0], cord2._buffers[0])

    def test_cord_append(self) -> None:
        """Appending bytes grows the length and concatenates the contents."""
        cord = Cord()
        cord.append(b"Hello")
        self.assertEqual(5, len(cord))
        self.assertEqual(b"Hello", bytes(cord))

        cord.append(b"World")
        self.assertEqual(10, len(cord))
        self.assertEqual(b"HelloWorld", bytes(cord))

    def test_cord_append_cord(self) -> None:
        """Appending a Cord splices its buffers in without copying them."""
        cord = Cord()
        cord.append(b"Hello")
        cord.append(b"World")

        cord2 = Cord()
        cord2.append(b"Prefix")
        cord2.append(cord)

        self.assertEqual(16, len(cord2))
        self.assertEqual(b"PrefixHelloWorld", bytes(cord2))

        # Confirm that no copies were made when appending a Cord: cord2's
        # buffers after the prefix are the very same objects held by cord.
        self.assertIs(cord2._buffers[1], cord._buffers[0])
        self.assertIs(cord2._buffers[2], cord._buffers[1])

    def test_cord_write_to_file(self) -> None:
        """write_to_file emits the concatenated contents to a binary stream."""
        cord = Cord()
        cord.append(b"Hello")
        cord.append(b"World")

        outfile = io.BytesIO()
        cord.write_to_file(outfile)
        self.assertEqual(b"HelloWorld", outfile.getvalue())