xref: /aosp_15_r20/external/executorch/backends/arm/arm_vela.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2023-2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8import os
9import struct
10import tempfile
11
12from typing import List
13
14import numpy as np
15from ethosu.vela import vela
16
17
# Pack either an input or an output tensor block, composing the related arrays
# into per-io structs to simplify runtime use.
def vela_bin_pack_io(prefix, data):
    """Pack the per-io arrays for ``prefix`` into a flat little-endian struct list.

    Args:
        prefix: Key prefix, e.g. ``"input"`` or ``"output"``. The keys
            ``{prefix}_shape``, ``{prefix}_elem_size``, ``{prefix}_offset`` and
            ``{prefix}_region`` are read from ``data``.
        data: Mapping (such as the ``NpzFile`` loaded from Vela's output) whose
            four arrays are indexed per io, in the same order.

    Returns:
        bytes: An int32 count of ios, followed by one record of seven int32
        values per io: the shape zero-padded to 4 dims, then elem_size,
        offset and region.

    Raises:
        RuntimeError: If an io shape has more than 4 dimensions.
    """
    # Look each array up once: indexing a numpy NpzFile re-reads and decodes
    # the member on every access, which the per-iteration lookups used to do.
    shapes = data[prefix + "_shape"]
    elem_sizes = data[prefix + "_elem_size"]
    offsets = data[prefix + "_offset"]
    regions = data[prefix + "_region"]

    packed = [struct.pack("<i", len(shapes))]
    for i, io_shape in enumerate(shapes):
        rank = len(io_shape)
        if rank > 4:
            raise RuntimeError(
                f"{prefix} {i} has rank {rank}; at most 4 dimensions are supported"
            )
        # Fixed 4-dim layout on the device side: pad missing dims with 0.
        padded_shape = list(io_shape) + [0] * (4 - rank)
        packed.append(
            struct.pack(
                "<iiiiiii", *padded_shape, elem_sizes[i], offsets[i], regions[i]
            )
        )
    return b"".join(packed)
34
35
# Output via Vela to a binary stream for ArmBackendEthosU.
# WARNING: Do not change this without changing VelaBinStream.cpp, as that
#          function consumes this format and the two need to align.
def vela_compile(tosa_graph, args: List[str]):
    """Compile ``tosa_graph`` with Vela and repack the result as a block stream.

    The graph is serialized into a temporary directory and Vela is run on it.
    The ``.npz`` Vela writes is then converted into the block stream consumed
    by VelaBinStream.cpp.

    Args:
        tosa_graph: Object whose ``serialize()`` returns the TOSA flatbuffer
            bytes.
        args: Vela command-line flags. An entry may bundle several
            whitespace-separated flags; each entry is split. The caller's list
            is not modified.

    Returns:
        bytes: Blocks emitted in the order vela_bin_stream, cmd_data,
        weight_data, scratch_data, inputs, outputs, vela_end_stream. Each block
        is:
          - a 16-byte NUL-terminated name (truncated to 15 bytes, NUL-padded),
          - a little-endian int32 unpadded payload length plus 12 zero bytes,
          - the payload, zero-padded to a multiple of 16 bytes.

    Raises:
        RuntimeError: If Vela's ``scratch_shape`` entry is not int64.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        tosa_path = os.path.join(tmpdir, "out.tosa")
        with open(tosa_path, "wb") as f:
            f.write(tosa_graph.serialize())

        output_dir = os.path.join(tmpdir, "output")
        # Split the caller's entries (one entry may hold several flags) but
        # append the paths whole, so a temp dir containing spaces survives.
        # A fresh list is built so repeated calls with the same ``args`` do
        # not accumulate --output-dir / input-path entries.
        vela_args = [flag for arg in args for flag in arg.split()]
        vela_args += [f"--output-dir={output_dir}", tosa_path]
        vela.main(vela_args)

        # Ethos-U85 and debug-force-regor runs produce out_vela.npz; other
        # targets produce out_sg0_vela.npz.
        if any("ethos-u85" in arg or "debug-force-regor" in arg for arg in args):
            np_path = os.path.join(output_dir, "out_vela.npz")
        else:
            np_path = os.path.join(output_dir, "out_sg0_vela.npz")

        with np.load(np_path, allow_pickle=False) as data:
            # scratch_shape is a 1-element array giving the scratch size in
            # bytes; the scratch is preallocated as zeros on the host to
            # provide SRAM for computation.
            scratch_shape = data["scratch_shape"]
            if not isinstance(scratch_shape[0], np.int64):
                raise RuntimeError("Expected scratch to be int64")
            scratch_size = int(scratch_shape[0])

            # Block order is part of the format: keep it in sync with
            # VelaBinStream.cpp. Command and weight data pass through unmodified.
            bin_blocks = [
                ("vela_bin_stream", b""),
                ("cmd_data", data["cmd_data"].tobytes()),
                ("weight_data", data["weight_data"].tobytes()),
                ("scratch_data", b"\x00" * scratch_size),
                ("inputs", vela_bin_pack_io("input", data)),
                ("outputs", vela_bin_pack_io("output", data)),
                ("vela_end_stream", b""),
            ]

    return b"".join(_vela_bin_block(name, payload) for name, payload in bin_blocks)


def _vela_bin_block(name, payload):
    """Frame one payload as a 16-byte name, a 16-byte length header and padded data."""
    header_name = name.encode("utf8")[:15].ljust(16, b"\x00")
    # The unpadded length is recorded because the device side needs the
    # actual block size for hw setup.
    header_length = struct.pack("<iiii", len(payload), 0, 0, 0)
    padding = b"\x00" * (-len(payload) % 16)
    return header_name + header_length + payload + padding
106