//
//  multiarray.mm
//  coremlexecutorch
//
//  Copyright © 2023 Apple Inc. All rights reserved.
//
// Please refer to the license found in the LICENSE file in the root directory of the source tree.

#import <multiarray.h>

#import <Accelerate/Accelerate.h>
#import <CoreML/CoreML.h>
#import <functional>
#import <numeric>
#import <objc_array_util.h>
#import <optional>
#import <vector>

namespace {
using namespace executorchcoreml;

/// Returns the `BNNSDataLayout` and fills `bnns_strides` from the multiarray strides.
///
/// BNNS requires strides to be in non-decreasing order;
/// `bnns_strides[i] <= bnns_strides[i + 1]`. The `BNNSDataLayout` defines
/// how each dimension is mapped to its stride.
///
/// @param multi_array_strides  The multiarray strides.
/// @param bnns_strides   The bnns strides.
/// @retval The `BNNSDataLayout`.
std::optional<BNNSDataLayout> get_bnns_data_layout(const std::vector<ssize_t>& multi_array_strides,
                                                   size_t *bnns_strides) {
    bool first_major = false;
    uint32_t rank = static_cast<uint32_t>(multi_array_strides.size());
    if (rank > BNNS_MAX_TENSOR_DIMENSION) {
        return std::nullopt;
    }

    if (std::is_sorted(multi_array_strides.begin(), multi_array_strides.end(), std::less())) {
        first_major = false;
        std::copy(multi_array_strides.begin(), multi_array_strides.end(), bnns_strides);
    } else if (std::is_sorted(multi_array_strides.begin(), multi_array_strides.end(), std::greater())) {
        first_major = true;
        std::copy(multi_array_strides.rbegin(), multi_array_strides.rend(), bnns_strides);
    } else {
        return std::nullopt;
    }

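    // A worked example, assuming the BNNS constants keep their documented raw
    // values: a rank-3, first-major (descending strides) layout encodes as
    // 0x08000 + 0x10000 * 3 + 1 == 0x38001, i.e. BNNSDataLayout3DFirstMajor.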
    // See BNNSDataLayout's raw values for how this addition composes the layout.
    return (BNNSDataLayout) (0x08000 +                    // flags as canonical first/last major type
                             0x10000 * rank +             // set dimensionality
                             (first_major ? 1 : 0));      // set first/last major bit
}

/// Returns `BNNSDataType` from `MultiArray::DataType`.
///
/// @param datatype  The multiarray datatype.
/// @retval The `BNNSDataType`.
std::optional<BNNSDataType> get_bnns_data_type(MultiArray::DataType datatype) {
    switch (datatype) {
        case MultiArray::DataType::Bool: {
            return BNNSDataTypeBoolean;
        }
        case MultiArray::DataType::Byte: {
            return BNNSDataTypeUInt8;
        }
        case MultiArray::DataType::Char: {
            return BNNSDataTypeInt8;
        }
        case MultiArray::DataType::Short: {
            return BNNSDataTypeInt16;
        }
        case MultiArray::DataType::Int32: {
            return BNNSDataTypeInt32;
        }
        case MultiArray::DataType::Int64: {
            return BNNSDataTypeInt64;
        }
        case MultiArray::DataType::Float16: {
            return BNNSDataTypeFloat16;
        }
        case MultiArray::DataType::Float32: {
            return BNNSDataTypeFloat32;
        }
        default: {
            return std::nullopt;
        }
    }
}

/// Initializes a BNNS array descriptor from a multiarray.
///
/// @param bnns_descriptor   The descriptor to be initialized.
/// @param multi_array  The multiarray.
/// @retval `true` if the initialization succeeded, otherwise `false`.
bool init_bnns_descriptor(BNNSNDArrayDescriptor& bnns_descriptor, const MultiArray& multi_array) {
    const auto& layout = multi_array.layout();
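    // Rejecting a single-element array here routes it to the simpler
    // element-wise path in `copy` below.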
    if (layout.num_elements() == 1) {
        return false;
    }

    auto bnns_datatype = get_bnns_data_type(layout.dataType());
    if (!bnns_datatype) {
        return false;
    }

    std::memset(&bnns_descriptor, 0, sizeof(bnns_descriptor));
    auto bnns_layout = get_bnns_data_layout(layout.strides(), bnns_descriptor.stride);
    if (!bnns_layout) {
        return false;
    }

    const auto& shape = layout.shape();
    std::copy(shape.begin(), shape.end(), bnns_descriptor.size);
    bnns_descriptor.layout = bnns_layout.value();
    bnns_descriptor.data_scale = 1.0f;
    bnns_descriptor.data_bias = 0.0f;
    bnns_descriptor.data_type = bnns_datatype.value();
    bnns_descriptor.data = multi_array.data();

    return true;
}

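/// Copies `src` into `dst` using `BNNSCopy`.
///
/// Returns `false` when either array cannot be described as a BNNS tensor,
/// letting the caller fall back to a slower copy path.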
bool copy_using_bnns(const MultiArray& src, MultiArray& dst) {
    if (dst.layout().num_bytes() < src.layout().num_bytes()) {
        return false;
    }
    BNNSNDArrayDescriptor src_descriptor;
    if (!init_bnns_descriptor(src_descriptor, src)) {
        return false;
    }

    BNNSNDArrayDescriptor dst_descriptor;
    if (!init_bnns_descriptor(dst_descriptor, dst)) {
        return false;
    }

    return BNNSCopy(&dst_descriptor, &src_descriptor, NULL) == 0;
}

std::vector<MultiArray::MemoryLayout> get_layouts(const std::vector<MultiArray>& arrays) {
    std::vector<MultiArray::MemoryLayout> result;
    result.reserve(arrays.size());

    std::transform(arrays.begin(), arrays.end(), std::back_inserter(result), [](const auto& array) {
        return array.layout();
    });

    return result;
}

std::vector<void *> get_datas(const std::vector<MultiArray>& arrays) {
    std::vector<void *> result;
    result.reserve(arrays.size());

    std::transform(arrays.begin(), arrays.end(), std::back_inserter(result), [](const auto& array) {
        return array.data();
    });

    return result;
}

/// Two adjacent dimensions can be coalesced if either dim has size 1 or if
/// `shape[n] * stride[n] == stride[n + 1]`.
bool can_coalesce_dimensions(const std::vector<size_t>& shape,
                             const std::vector<ssize_t>& strides,
                             size_t dim1,
                             size_t dim2) {
    auto shape1 = shape[dim1];
    auto shape2 = shape[dim2];
    if (shape1 == 1 || shape2 == 1) {
        return true;
    }

    auto stride1 = strides[dim1];
    auto stride2 = strides[dim2];
    return shape1 * stride1 == stride2;
}

bool can_coalesce_dimensions(const std::vector<size_t>& shape,
                             const std::vector<std::vector<ssize_t>>& all_strides,
                             size_t dim1,
                             size_t dim2) {
    for (const auto& strides : all_strides) {
        if (!::can_coalesce_dimensions(shape, strides, dim1, dim2)) {
            return false;
        }
    }

    return true;
}

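/// Copies the stride of `dim2` into slot `dim1` for every array's strides;
/// used by `coalesce_dimensions` below while compacting dimensions.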
void update_strides(std::vector<std::vector<ssize_t>>& all_strides,
                    size_t dim1,
                    size_t dim2) {
    for (auto& strides : all_strides) {
        strides[dim1] = strides[dim2];
    }
}

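/// A worked example: a contiguous array with shape {2, 3, 4} and strides
/// {12, 4, 1} coalesces to shape {24} with stride {1}, because
/// `shape[n] * stride[n] == stride[n + 1]` holds for every adjacent pair once
/// the dimensions are reversed. The iterator below then degenerates to a
/// single flat loop.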
std::vector<MultiArray::MemoryLayout> coalesce_dimensions(std::vector<MultiArray::MemoryLayout> layouts) {
    if (layouts.empty()) {
        return {};
    }

    // Reverse the shape and every layout's strides so that the innermost
    // dimension comes first.
    std::vector<size_t> shape = layouts.back().shape();
    std::reverse(shape.begin(), shape.end());
    std::vector<std::vector<ssize_t>> all_strides;
    all_strides.reserve(layouts.size());
    std::transform(layouts.begin(), layouts.end(), std::back_inserter(all_strides), [](const MultiArray::MemoryLayout& layout) {
        auto strides = layout.strides();
        std::reverse(strides.begin(), strides.end());
        return strides;
    });
    size_t rank = layouts[0].rank();
    size_t prev_dim = 0;
    for (size_t dim = 1; dim < rank; ++dim) {
        if (::can_coalesce_dimensions(shape, all_strides, prev_dim, dim)) {
            if (shape[prev_dim] == 1) {
                ::update_strides(all_strides, prev_dim, dim);
            }
            shape[prev_dim] *= shape[dim];
        } else {
            ++prev_dim;
            if (prev_dim != dim) {
                ::update_strides(all_strides, prev_dim, dim);
                shape[prev_dim] = shape[dim];
            }
        }
    }

    if (rank == prev_dim + 1) {
        return layouts;
    }

    shape.resize(prev_dim + 1);
    for (auto& strides : all_strides) {
        strides.resize(prev_dim + 1);
    }

    std::vector<MultiArray::MemoryLayout> result;
    result.reserve(layouts.size());
    std::reverse(shape.begin(), shape.end());
    for (size_t i = 0; i < layouts.size(); ++i) {
        std::reverse(all_strides[i].begin(), all_strides[i].end());
        result.emplace_back(layouts[i].dataType(), shape, std::move(all_strides[i]));
    }

    return result;
}

enum class Direction : uint8_t {
    Forward = 0,
    Backward
};

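/// Advances (or rewinds) each array's data pointer by `index` steps along
/// dimension `dim`. Strides are stored in elements, so the byte offset is
/// `stride * index * num_bytes`; a `Backward` call undoes a matching
/// `Forward` call.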
void set_data_pointers(std::vector<void *>& data_pointers,
                       ssize_t index,
                       size_t dim,
                       Direction direction,
                       const std::vector<MultiArray::MemoryLayout>& layouts) {
    for (size_t i = 0; i < layouts.size(); ++i) {
        const auto& layout = layouts[i];
        const ssize_t stride = layout.strides()[dim];
        const size_t num_bytes = layout.num_bytes();
        ssize_t offset = 0;
        switch (direction) {
            case Direction::Forward: {
                offset = stride * index * num_bytes;
                break;
            }
            case Direction::Backward: {
                offset = -stride * index * num_bytes;
                break;
            }
        }
        data_pointers[i] = (void *)(static_cast<uint8_t *>(data_pointers[i]) + offset);
    }
}

void increment_data_pointers(std::vector<void *>& data_pointers,
                             size_t index,
                             size_t dim,
                             const std::vector<MultiArray::MemoryLayout>& layouts) {
    set_data_pointers(data_pointers, index, dim, Direction::Forward, layouts);
}

void decrement_data_pointers(std::vector<void *>& data_pointers,
                             size_t index,
                             size_t dim,
                             const std::vector<MultiArray::MemoryLayout>& layouts) {
    set_data_pointers(data_pointers, index, dim, Direction::Backward, layouts);
}

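/// Visits every element of a set of multiarrays that share a shape, invoking
/// a functor with the data pointers for the current element of each array.
/// Layouts are coalesced up front, so a fully contiguous copy runs through
/// the rank-1 fast path below.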
class MultiArrayIterator final {
public:
    explicit MultiArrayIterator(const std::vector<MultiArray>& arrays)
    : datas_(get_datas(arrays)),
      layouts_(coalesce_dimensions(get_layouts(arrays)))
    {}

private:
    template<typename FN>
    void exec(FN&& fn, const std::vector<MultiArray::MemoryLayout>& layouts, std::vector<void *> datas, size_t n) {
        const auto& layout = layouts.back();
        // The unrolled cases avoid recursive calls for rank <= 2.
        switch (n) {
            case 0: {
                break;
            }
            case 1: {
                for (size_t i = 0; i < layout.shape()[0]; ++i) {
                    ::increment_data_pointers(datas, i, 0, layouts);
                    fn(datas);
                    ::decrement_data_pointers(datas, i, 0, layouts);
                }
                break;
            }
            case 2: {
                for (size_t i = 0; i < layout.shape()[1]; ++i) {
                    ::increment_data_pointers(datas, i, 1, layouts);
                    for (size_t j = 0; j < layout.shape()[0]; ++j) {
                        ::increment_data_pointers(datas, j, 0, layouts);
                        fn(datas);
                        ::decrement_data_pointers(datas, j, 0, layouts);
                    }
                    ::decrement_data_pointers(datas, i, 1, layouts);
                }
                break;
            }
            default: {
                const size_t bound = layouts.back().shape()[n - 1];
                for (size_t index = 0; index < bound; ++index) {
                    ::increment_data_pointers(datas, index, n - 1, layouts);
                    exec(std::forward<FN>(fn), layouts, datas, n - 1);
                    ::decrement_data_pointers(datas, index, n - 1, layouts);
                }
            }
        }
    }

public:
    template<typename FN>
    void exec(FN&& fn) {
        std::vector<void *> datas = datas_;
        exec(fn, layouts_, datas, layouts_[0].rank());
    }

private:
    std::vector<void *> datas_;
    std::vector<MultiArray::MemoryLayout> layouts_;
};
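
// A minimal usage sketch (with hypothetical `src` and `dst` arrays of the
// same shape), mirroring the element-wise copy at the bottom of this file:
//
//     MultiArrayIterator({src, dst}).exec([](const std::vector<void *>& ptrs) {
//         // ptrs[0] and ptrs[1] point at the corresponding elements.
//     });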

/// BNNS has no Float64 type, so type conversions that BNNS cannot perform
/// are done element-wise by the helpers below.
template<typename T1, typename T2>
inline void copy_value(void *dst, const void *src) {
    const T2 *src_ptr = static_cast<const T2 *>(src);
    T1 *dst_ptr = static_cast<T1 *>(dst);
    *dst_ptr = static_cast<T1>(*src_ptr);
}

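/// Second dispatch level of the element-wise copy: writes `*src` (read as
/// `T`) into `dst`, converting to `dst_data_type`.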
template<typename T>
void copy(void *dst,
          MultiArray::DataType dst_data_type,
          const void *src) {
    switch (dst_data_type) {
        case MultiArray::DataType::Bool: {
            ::copy_value<bool, T>(dst, src);
            break;
        }

        case MultiArray::DataType::Byte: {
            ::copy_value<uint8_t, T>(dst, src);
            break;
        }

        case MultiArray::DataType::Char: {
            ::copy_value<int8_t, T>(dst, src);
            break;
        }

        case MultiArray::DataType::Short: {
            ::copy_value<int16_t, T>(dst, src);
            break;
        }

        case MultiArray::DataType::Int32: {
            ::copy_value<int32_t, T>(dst, src);
            break;
        }

        case MultiArray::DataType::Int64: {
            ::copy_value<int64_t, T>(dst, src);
            break;
        }

        case MultiArray::DataType::Float16: {
            ::copy_value<_Float16, T>(dst, src);
            break;
        }

        case MultiArray::DataType::Float32: {
            ::copy_value<float, T>(dst, src);
            break;
        }

        case MultiArray::DataType::Float64: {
            ::copy_value<double, T>(dst, src);
            break;
        }
    }
}

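/// First dispatch level of the element-wise copy: selects the source element
/// type, then forwards to the destination-typed `copy<T>` above. `Bool` is
/// read through `uint8_t`, matching its one-byte storage.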
void copy(void *dst,
          MultiArray::DataType dst_data_type,
          const void *src,
          MultiArray::DataType src_data_type) {
    switch (src_data_type) {
        case MultiArray::DataType::Bool: {
            ::copy<uint8_t>(dst, dst_data_type, src);
            break;
        }

        case MultiArray::DataType::Byte: {
            ::copy<uint8_t>(dst, dst_data_type, src);
            break;
        }

        case MultiArray::DataType::Char: {
            ::copy<int8_t>(dst, dst_data_type, src);
            break;
        }

        case MultiArray::DataType::Short: {
            ::copy<int16_t>(dst, dst_data_type, src);
            break;
        }

        case MultiArray::DataType::Int32: {
            ::copy<int32_t>(dst, dst_data_type, src);
            break;
        }

        case MultiArray::DataType::Int64: {
            ::copy<int64_t>(dst, dst_data_type, src);
            break;
        }

        case MultiArray::DataType::Float16: {
            ::copy<_Float16>(dst, dst_data_type, src);
            break;
        }

        case MultiArray::DataType::Float32: {
            ::copy<float>(dst, dst_data_type, src);
            break;
        }

        case MultiArray::DataType::Float64: {
            ::copy<double>(dst, dst_data_type, src);
            break;
        }
    }
}

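/// Copies `src` into `dst`, trying the cheapest path first: a BNNS tensor
/// copy when both layouts are BNNS-expressible, a raw `memcpy` when the data
/// types match and both arrays are packed, and an element-wise strided copy
/// as the fallback.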
void copy(const MultiArray& src, MultiArray& dst, MultiArray::CopyOptions options) {
    if (options.use_bnns && copy_using_bnns(src, dst)) {
        return;
    }

    if (options.use_memcpy &&
        src.layout().dataType() == dst.layout().dataType() &&
        src.layout().is_packed() &&
        dst.layout().is_packed()) {
        std::memcpy(dst.data(), src.data(), src.layout().num_elements() * src.layout().num_bytes());
        return;
    }

    auto iterator = MultiArrayIterator({src, dst});
    iterator.exec([&](const std::vector<void *>& datas){
        void *src_data = datas[0];
        void *dst_data = datas[1];
        ::copy(dst_data, dst.layout().dataType(), src_data, src.layout().dataType());
    });
}

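/// Returns the element offset for a multidimensional index: the dot product
/// of indices and strides.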
ssize_t get_data_offset(const std::vector<size_t>& indices, const std::vector<ssize_t>& strides) {
    ssize_t offset = 0;
    for (size_t i = 0; i < indices.size(); ++i) {
        offset += static_cast<ssize_t>(indices[i]) * strides[i];
    }

    return offset;
}

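/// Maps a flat element index to a strided element offset, peeling off one
/// dimension at a time. For example, with shape {2, 3} and strides {3, 1},
/// index 4 maps to 1 * 3 + 1 * 1 == 4.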
ssize_t get_data_offset(size_t index, const std::vector<size_t>& shape, const std::vector<ssize_t>& strides) {
    size_t div = std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
    size_t offset = 0;
    for (size_t i = 0; i < shape.size(); ++i) {
        div /= shape[i];
        size_t dim_index = index / div;
        offset += dim_index * strides[i];
        index %= div;
    }

    return offset;
}
} // namespace

namespace executorchcoreml {

size_t MultiArray::MemoryLayout::num_elements() const noexcept {
    if (shape_.empty()) {
        return 0;
    }

    return std::accumulate(shape_.begin(), shape_.end(), size_t(1), std::multiplies<size_t>());
}

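/// A layout is packed when the innermost stride is 1 and each outer stride
/// equals the product of the dimensions inside it, i.e. the elements occupy
/// one contiguous block with no gaps.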
bool MultiArray::MemoryLayout::is_packed() const noexcept {
    if (strides_.empty()) {
        return true;
    }

    ssize_t expectedStride = 1;
    auto stridesIt = strides_.crbegin();
    for (auto shapeIt = shape_.crbegin(); shapeIt != shape_.crend(); shapeIt++) {
        if (*stridesIt != expectedStride) {
            return false;
        }
        expectedStride = expectedStride * (*shapeIt);
        stridesIt++;
    }

    return true;
}

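/// Returns the size in bytes of a single element of the layout's data type.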
size_t MultiArray::MemoryLayout::num_bytes() const noexcept {
    switch (dataType()) {
        case MultiArray::DataType::Bool: {
            return 1;
        }
        case MultiArray::DataType::Byte: {
            return 1;
        }
        case MultiArray::DataType::Char: {
            return 1;
        }
        case MultiArray::DataType::Short: {
            return 2;
        }
        case MultiArray::DataType::Int32: {
            return 4;
        }
        case MultiArray::DataType::Int64: {
            return 8;
        }
        case MultiArray::DataType::Float16: {
            return 2;
        }
        case MultiArray::DataType::Float32: {
            return 4;
        }
        case MultiArray::DataType::Float64: {
            return 8;
        }
    }
}

void MultiArray::copy(MultiArray& dst, CopyOptions options) const noexcept {
    assert(layout().shape() == dst.layout().shape());
    ::copy(*this, dst, options);
}

std::optional<MLMultiArrayDataType> to_ml_multiarray_data_type(MultiArray::DataType data_type) {
    switch (data_type) {
        case MultiArray::DataType::Float16: {
            return MLMultiArrayDataTypeFloat16;
        }
        case MultiArray::DataType::Float32: {
            return MLMultiArrayDataTypeFloat32;
        }
        case MultiArray::DataType::Float64: {
            return MLMultiArrayDataTypeDouble;
        }
        case MultiArray::DataType::Int32: {
            return MLMultiArrayDataTypeInt32;
        }
        default: {
            return std::nullopt;
        }
    }
}

std::optional<MultiArray::DataType> to_multiarray_data_type(MLMultiArrayDataType data_type) {
    switch (data_type) {
        case MLMultiArrayDataTypeFloat16: {
            return MultiArray::DataType::Float16;
        }
        case MLMultiArrayDataTypeFloat32: {
            return MultiArray::DataType::Float32;
        }
        case MLMultiArrayDataTypeFloat64: {
            return MultiArray::DataType::Float64;
        }
        case MLMultiArrayDataTypeInt32: {
            return MultiArray::DataType::Int32;
        }
        default: {
            return std::nullopt;
        }
    }
}

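/// Returns a pointer to the element at `indices`: offsets are computed in
/// elements and scaled by the element size to form the byte address.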
void *MultiArray::data(const std::vector<size_t>& indices) const noexcept {
    assert(indices.size() == layout().shape().size());
    uint8_t *ptr = static_cast<uint8_t *>(data());
    ssize_t offset = ::get_data_offset(indices, layout().strides());
    return ptr + offset * layout().num_bytes();
}

void *MultiArray::data(size_t index) const noexcept {
    assert(index < layout().num_elements());
    uint8_t *ptr = static_cast<uint8_t *>(data());
    ssize_t offset = ::get_data_offset(index, layout().shape(), layout().strides());
    return ptr + offset * layout().num_bytes();
}

} // namespace executorchcoreml
633