/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

namespace vkcompute {

//
// Tensor output size calculation functions
//

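// Computes the output sizes produced by broadcasting t1 against t2. Sizes
// are matched from the innermost dim outwards since they are stored in NCHW
// order, with utils::val_at resolving dims that the lower-dimensional tensor
// lacks; e.g. sizes of {3, 1, 5} and {4, 1} yield output sizes of {3, 4, 5}.
// The two tensors are assumed to be broadcastable.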
std::vector<int64_t> calculate_broadcasted_output_size(
    const api::vTensor& t1,
    const api::vTensor& t2) {
  std::vector<int64_t> out_sizes(
      std::max(t1.sizes().size(), t2.sizes().size()));

  // Match the sizes in reverse because sizes are in NCHW order. Use a signed
  // loop bound to avoid comparing a signed index against an unsigned size.
  const int64_t out_ndim = static_cast<int64_t>(out_sizes.size());
  for (int64_t i = -1; i >= -out_ndim; --i) {
    out_sizes.at(out_sizes.size() + i) =
        std::max(utils::val_at(i, t1.sizes()), utils::val_at(i, t2.sizes()));
  }

  return out_sizes;
}

//
// Tensor property checking functions
//

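// The check_* functions below are simple predicates over tensor metadata,
// intended to be used by op implementations to validate their arguments.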
bool check_ndim_is(const api::vTensor& t, size_t ndim) {
  return t.sizes().size() == ndim;
}

bool check_same_sizes_at(
    const api::vTensor& t1,
    const int64_t d1,
    const api::vTensor& t2,
    const int64_t d2) {
  return utils::val_at(d1, t1.sizes()) == utils::val_at(d2, t2.sizes());
}

bool check_packed_dim_is(const api::vTensor& t, const int32_t packed_dim) {
  return t.packed_dim() == packed_dim;
}

bool check_same_ndim(const api::vTensor& t1, const api::vTensor& t2) {
  return t1.sizes().size() == t2.sizes().size();
}

bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2) {
  return t1.packed_dim() == t2.packed_dim();
}

bool check_same_packed_dim(
    const api::vTensor& t1,
    const api::vTensor& t2,
    const api::vTensor& t3) {
  if (t1.packed_dim() != t2.packed_dim()) {
    return false;
  }
  return (t1.packed_dim() == t3.packed_dim());
}

//
// Broadcast flag functions
//

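// Returns true if rcvr needs to be broadcast along sndr's packed dim, i.e.
// if sndr has a larger size than rcvr at that dim. Assumes the two tensors
// have already been checked to be broadcastable.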
bool is_packed_dim_broadcasted(
    const api::vTensor& sndr,
    const api::vTensor& rcvr) {
  // We assume that the tensors are broadcastable. Therefore, if the sizes
  // differ at some dim, the size of rcvr at that dim must be 1, meaning rcvr
  // is broadcast along it.
  switch (sndr.packed_dim()) {
    case WHCN::kChannelsDim:
      return utils::val_at(-3, sndr.sizes()) > utils::val_at(-3, rcvr.sizes());
    case WHCN::kHeightDim:
      return utils::val_at(-2, sndr.sizes()) > utils::val_at(-2, rcvr.sizes());
    case WHCN::kWidthDim:
      return utils::val_at(-1, sndr.sizes()) > utils::val_at(-1, rcvr.sizes());
    default:
      VK_THROW("Invalid packed dim");
  }
}

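// Packs the broadcast flags for a binary op into an ivec2 for consumption by
// a compute shader: the first component is set if t1 must be broadcast along
// the packed dim, and the second component is set if t2 must be.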
utils::ivec2 create_broadcast_params(
    const api::vTensor& t1,
    const api::vTensor& t2) {
  return utils::make_ivec2(
      {is_packed_dim_broadcasted(t2, t1), is_packed_dim_broadcasted(t1, t2)});
}

//
// Work group size calculation functions
//

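// Selects a local work group size totaling 64 invocations, shaped to fit the
// global work group: 4x4x4 by default, flattened to 16x4x1 or 8x8x1 when the
// global work group is 2D (i.e. has a depth of 1).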
utils::uvec3 adaptive_work_group_size(const utils::uvec3& global_work_group) {
  utils::uvec3 local_group_size = {4, 4, 4};
  if (global_work_group[2u] == 1) {
    if (global_work_group[1u] < 8) {
      local_group_size[0u] = 16;
      local_group_size[1u] = 4;
      local_group_size[2u] = 1;
    } else {
      local_group_size[0u] = 8;
      local_group_size[1u] = 8;
      local_group_size[2u] = 1;
    }
  }
  return local_group_size;
}

} // namespace vkcompute