xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
10 
11 namespace vkcompute {
12 
13 //
14 // Tensor output size calculation functions
15 //
16 
calculate_broadcasted_output_size(const api::vTensor & t1,const api::vTensor & t2)17 std::vector<int64_t> calculate_broadcasted_output_size(
18     const api::vTensor& t1,
19     const api::vTensor& t2) {
20   std::vector<int64_t> out_sizes(
21       std::max(t1.sizes().size(), t2.sizes().size()));
22 
23   // Match the sizes in reverse because sizes are in NCHW order
24   for (int i = -1; i >= -out_sizes.size(); --i) {
25     out_sizes.at(out_sizes.size() + i) =
26         std::max(utils::val_at(i, t1.sizes()), utils::val_at(i, t2.sizes()));
27   }
28 
29   return out_sizes;
30 }
31 
32 //
33 // Tensor property checking functions
34 //
35 
check_ndim_is(const api::vTensor & t,size_t ndim)36 bool check_ndim_is(const api::vTensor& t, size_t ndim) {
37   return t.sizes().size() == ndim;
38 }
39 
check_same_sizes_at(const api::vTensor & t1,const int64_t d1,const api::vTensor & t2,const int64_t d2)40 bool check_same_sizes_at(
41     const api::vTensor& t1,
42     const int64_t d1,
43     const api::vTensor& t2,
44     const int64_t d2) {
45   return utils::val_at(d1, t1.sizes()) == utils::val_at(d2, t2.sizes());
46 }
47 
check_packed_dim_is(const api::vTensor & t,const int32_t packed_dim)48 bool check_packed_dim_is(const api::vTensor& t, const int32_t packed_dim) {
49   return t.packed_dim() == packed_dim;
50 }
51 
check_same_ndim(const api::vTensor & t1,const api::vTensor & t2)52 bool check_same_ndim(const api::vTensor& t1, const api::vTensor& t2) {
53   return t1.sizes().size() == t2.sizes().size();
54 }
55 
check_same_packed_dim(const api::vTensor & t1,const api::vTensor & t2)56 bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2) {
57   return t1.packed_dim() == t2.packed_dim();
58 }
59 
check_same_packed_dim(const api::vTensor & t1,const api::vTensor & t2,const api::vTensor & t3)60 bool check_same_packed_dim(
61     const api::vTensor& t1,
62     const api::vTensor& t2,
63     const api::vTensor& t3) {
64   if (t1.packed_dim() != t2.packed_dim()) {
65     return false;
66   }
67   return (t1.packed_dim() == t3.packed_dim());
68 }
69 
70 //
71 // Broadcast flag functions
72 //
73 
is_packed_dim_broadcasted(const api::vTensor & sndr,const api::vTensor & rcvr)74 bool is_packed_dim_broadcasted(
75     const api::vTensor& sndr,
76     const api::vTensor& rcvr) {
77   // We assume that the tensors are broadcastable. If values aren't equal at
78   // some index, then the value of rcvr is 1 and hence should be broadcasted.
79   switch (sndr.packed_dim()) {
80     case WHCN::kChannelsDim:
81       return utils::val_at(-3, sndr.sizes()) > utils::val_at(-3, rcvr.sizes());
82     case WHCN::kHeightDim:
83       return utils::val_at(-2, sndr.sizes()) > utils::val_at(-2, rcvr.sizes());
84     case WHCN::kWidthDim:
85       return utils::val_at(-1, sndr.sizes()) > utils::val_at(-1, rcvr.sizes());
86     default:
87       VK_THROW("Invalid packed dim");
88   }
89 }
90 
create_broadcast_params(const api::vTensor & t1,const api::vTensor & t2)91 utils::ivec2 create_broadcast_params(
92     const api::vTensor& t1,
93     const api::vTensor& t2) {
94   return utils::make_ivec2(
95       {is_packed_dim_broadcasted(t2, t1), is_packed_dim_broadcasted(t1, t2)});
96 }
97 
98 //
99 // Work group size calculation functions
100 //
101 
/// Picks a 64-invocation local work group size suited to the shape of the
/// global work group: a 4x4x4 cube for 3D dispatches, and a flat XY layout
/// (wide in X when the Y extent is small) when the depth extent is 1.
utils::uvec3 adaptive_work_group_size(const utils::uvec3& global_work_group) {
  if (global_work_group[2u] != 1) {
    return {4, 4, 4};
  }
  // 2D dispatch: spend all invocations in the XY plane.
  if (global_work_group[1u] < 8) {
    return {16, 4, 1};
  }
  return {8, 8, 1};
}
117 
118 } // namespace vkcompute
119