xref: /aosp_15_r20/external/executorch/backends/arm/arm_vela.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright 2023-2024 Arm Limited and/or its affiliates.
2*523fa7a6SAndroid Build Coastguard Worker#
3*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
4*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
5*523fa7a6SAndroid Build Coastguard Worker
6*523fa7a6SAndroid Build Coastguard Worker# pyre-unsafe
7*523fa7a6SAndroid Build Coastguard Worker
8*523fa7a6SAndroid Build Coastguard Workerimport os
9*523fa7a6SAndroid Build Coastguard Workerimport struct
10*523fa7a6SAndroid Build Coastguard Workerimport tempfile
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerfrom typing import List
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Workerimport numpy as np
15*523fa7a6SAndroid Build Coastguard Workerfrom ethosu.vela import vela
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Worker# Pack either input or output tensor block, compose the related arrays into
19*523fa7a6SAndroid Build Coastguard Worker# per-io structs to simplify runtime use.
def vela_bin_pack_io(prefix, data):
    """Pack Vela's per-tensor input or output arrays into one binary block.

    Args:
        prefix: Key prefix, e.g. ``"input"`` or ``"output"``; selects the
            ``<prefix>_shape``, ``<prefix>_elem_size``, ``<prefix>_offset`` and
            ``<prefix>_region`` entries of ``data``.
        data: Mapping (such as the ``NpzFile`` loaded from Vela's output) holding
            those parallel per-tensor sequences. Each shape entry must provide
            ``tolist()`` (e.g. a numpy array).

    Returns:
        bytes: A little-endian int32 tensor count, followed for each tensor by
        seven little-endian int32s: the shape zero-padded to 4 dims, the element
        size, the offset and the region.

    Raises:
        ValueError: If a tensor shape has more than 4 dimensions.
    """
    # Fetch each array once: NpzFile re-reads an entry from the archive on
    # every subscript, so looking these up inside the loop repeats that work.
    shapes = data[prefix + "_shape"]
    elem_sizes = data[prefix + "_elem_size"]
    offsets = data[prefix + "_offset"]
    regions = data[prefix + "_region"]

    packed = [struct.pack("<i", len(shapes))]
    for i, shape in enumerate(shapes):
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(shape) > 4:
            raise ValueError(
                f"{prefix} tensor {i} has {len(shape)} dimensions; "
                "at most 4 are supported"
            )
        padded_shape = shape.tolist() + [0] * (4 - len(shape))
        packed.append(
            struct.pack(
                "<iiiiiii", *padded_shape, elem_sizes[i], offsets[i], regions[i]
            )
        )
    return b"".join(packed)
34*523fa7a6SAndroid Build Coastguard Worker
35*523fa7a6SAndroid Build Coastguard Worker
36*523fa7a6SAndroid Build Coastguard Worker# Output via Vela to binary stream for ArmBackendEthosU
37*523fa7a6SAndroid Build Coastguard Worker# WARNING: Do not change this without changing VelaBinStream.cpp as that
38*523fa7a6SAndroid Build Coastguard Worker#          function consumes this format and the two need to align.
def vela_compile(tosa_graph, args: List[str]):
    """Compile a TOSA graph with Vela and pack the result for ArmBackendEthosU.

    Args:
        tosa_graph: Object whose ``serialize()`` returns the TOSA flatbuffer
            bytes to compile.
        args: Vela command-line flags. An entry may contain several
            space-separated flags; entries are split on whitespace before being
            handed to Vela. The caller's list is not modified.

    Returns:
        bytes: The blocks ``vela_bin_stream``, ``cmd_data``, ``weight_data``,
        ``scratch_data``, ``inputs``, ``outputs`` and ``vela_end_stream``, in
        that order. Each block is a 16-byte NUL-padded name, a 16-byte header
        (little-endian int32 unpadded data length followed by 12 zero bytes),
        then the data zero-padded to a multiple of 16 bytes.

    Raises:
        RuntimeError: If Vela's ``scratch_shape`` entry is not an int64.
        FileNotFoundError: From ``np.load`` if Vela did not write the
            expected ``.npz`` file.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        tosa_path = os.path.join(tmpdir, "out.tosa")
        with open(tosa_path, "wb") as f:
            f.write(tosa_graph.serialize())

        # Caller entries may bundle several flags in one string, so split them
        # on whitespace. Our own paths are appended afterwards as whole
        # arguments so a temp dir containing spaces is not torn apart, and a
        # new list is built so repeated calls do not accumulate flags in the
        # caller's ``args``.
        vela_flags = " ".join(args).split()
        output_dir = os.path.join(tmpdir, "output")
        vela.main(vela_flags + [f"--output-dir={output_dir}", tosa_path])

        # Ethos-U85 / debug-force-regor builds are expected to write
        # out_vela.npz; otherwise the output is expected in out_sg0_vela.npz.
        if any(
            "ethos-u85" in flag or "debug-force-regor" in flag for flag in vela_flags
        ):
            np_path = os.path.join(output_dir, "out_vela.npz")
        else:
            np_path = os.path.join(output_dir, "out_sg0_vela.npz")

        with np.load(np_path, allow_pickle=False) as data:
            # Build the blocks in a form easily digested on the device side.
            # Insertion order is emission order and is part of the format
            # consumed by VelaBinStream.cpp.
            bin_blocks = {"vela_bin_stream": b""}

            # Command stream and weights are copied through unmodified.
            bin_blocks["cmd_data"] = data["cmd_data"].tobytes()
            bin_blocks["weight_data"] = data["weight_data"].tobytes()

            # scratch_shape is a one-element array giving the scratch size in
            # bytes; emit a zero-filled block of that size. Currently this is
            # preallocated on the host to provide SRAM for computation.
            if not isinstance(data["scratch_shape"][0], np.int64):
                raise RuntimeError("Expected scratch to be int64")
            scratch_size = int(data["scratch_shape"][0])
            bin_blocks["scratch_data"] = b"\x00" * scratch_size

            # Capture inputs and outputs.
            bin_blocks["inputs"] = vela_bin_pack_io("input", data)
            bin_blocks["outputs"] = vela_bin_pack_io("output", data)

            bin_blocks["vela_end_stream"] = b""

        # Emit every block as:
        #  - 16-byte block name, NUL-terminated (truncated to 15 chars, padded)
        #  - int32 block length followed by 12 zero bytes
        #  - block data, zero-padded to 16-byte alignment
        packed = []
        for name, payload in bin_blocks.items():
            header_name = name.encode("utf8")[:15].ljust(16, b"\x00")
            # The actual unpadded length is recorded because it is needed for
            # hardware setup.
            header_len = struct.pack("<iiii", len(payload), 0, 0, 0)
            padding = b"\x00" * (-len(payload) % 16)
            packed.append(header_name + header_len + payload + padding)

        return b"".join(packed)
106