/*
 * Copyright (c) 2020-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_CPP_SPLIT_H
#define ARM_COMPUTE_CPP_SPLIT_H

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"

#include "support/ToolchainSupport.h"

#include "arm_compute/runtime/IFunction.h"

namespace arm_compute
{
/** Basic function to split a tensor along a given axis */
template <typename SliceType, typename TensorInterfaceType = ITensor>
class CPPSplit : public IFunction
{
public:
    CPPSplit()
        : _outputs_vector(), _slice_functions(), _num_outputs(0)
    {
    }
    /** Static function to check if given info will lead to a valid configuration of @ref CPPSplit
     *
     * @param[in] input   The input tensor info. Data types supported: All.
     * @param[in] outputs A vector containing the output tensors' info. Data types supported: same as @p input.
     *                    The output tensors should match the input tensor dimensions for all shape dimensions apart
     *                    from the split dimension.
     * @param[in] axis    Axis on which to split the input.
     *
     * @return a status
     */
    static Status validate(const ITensorInfo *input, const std::vector<ITensorInfo *> &outputs, unsigned int axis)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
        ARM_COMPUTE_RETURN_ERROR_ON(axis >= input->num_dimensions());
        ARM_COMPUTE_RETURN_ERROR_ON(outputs.size() < 2);

        // Get output shape
        TensorShape  output_shape{};
        unsigned int total_output_shape_size = 0;

        // Sum the output sizes and fall back to evenly-sized splits if any are zero
        const bool using_split_shapes = std::none_of(outputs.begin(), outputs.end(), [&total_output_shape_size](ITensorInfo * info)
        {
            unsigned int output_shape_size = info->tensor_shape().total_size();
            total_output_shape_size += output_shape_size;
            return output_shape_size == 0;
        });

        if(using_split_shapes)
        {
            ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape().total_size() != total_output_shape_size);
        }
        else
        {
            output_shape = arm_compute::misc::shape_calculator::compute_split_shape(input, axis, outputs.size());
            ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() == 0);
        }

        // Validate output tensors
        unsigned int axis_offset = 0;
        for(const auto &output : outputs)
        {
            ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
            if(using_split_shapes)
            {
                output_shape = output->tensor_shape();
                ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() == 0);
            }

            const size_t axis_split_step = output_shape[axis];

            // Start/End coordinates
            Coordinates start_coords;
            Coordinates end_coords;
            for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d)
            {
                end_coords.set(d, -1);
            }

            // Output auto-initialization if not yet initialized
            TensorInfo tmp_output_info = *output->clone();
            if(tmp_output_info.tensor_shape().total_size() == 0)
            {
                tmp_output_info = input->clone()->set_is_resizable(true).set_tensor_shape(output_shape);
            }

            // Update coordinate on axis
            start_coords.set(axis, axis_offset);
            end_coords.set(axis, axis_offset + axis_split_step);

            ARM_COMPUTE_RETURN_ON_ERROR(SliceType::validate(input, output, start_coords, end_coords));
            axis_offset += axis_split_step;
        }

        return Status{};
    }
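
    // Worked example (illustrative values, not taken from the library's documentation): for an input
    // of shape (8, 6), axis = 1 and three outputs whose shapes are still empty, compute_split_shape()
    // yields (8, 2) for every output, so the loop above validates slices covering [0, 2), [2, 4) and
    // [4, 6) along axis 1. If the output shapes are already set, they are used directly and must sum
    // to the input's total size.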

    /** Initialise the kernel's input and outputs.
     *
     * @param[in]  input   The input tensor. Data types supported: All.
     * @param[out] outputs A vector containing the output tensors. Data types supported: Same as @p input.
     *                     The output tensors should match the input tensor dimensions for all shape dimensions apart
     *                     from the split dimension.
     * @param[in]  axis    Axis on which to split the input.
     */
    void configure(const TensorInterfaceType *input, const std::vector<TensorInterfaceType *> &outputs, unsigned int axis)
    {
        // Create Slice functions
        _num_outputs = outputs.size();
        _slice_functions.resize(_num_outputs);

        // Extract output tensor info
        std::vector<ITensorInfo *> outputs_info;
        for(auto &output : outputs)
        {
            ARM_COMPUTE_ERROR_ON_NULLPTR(output);
            outputs_info.emplace_back(output->info());
        }

        // If any of the outputs have a zero size, fall back to using evenly-sized output splits
        const bool outputs_have_sizes = std::none_of(outputs_info.begin(), outputs_info.end(), [](ITensorInfo * info)
        {
            return info->tensor_shape().total_size() == 0;
        });

        // Validate
        ARM_COMPUTE_ERROR_THROW_ON(CPPSplit::validate(input->info(), outputs_info, axis));

        unsigned int axis_offset = 0;
        unsigned int i           = 0;

        for(const auto &output_info : outputs_info)
        {
            // Get output shape
            TensorShape output_shape = (outputs_have_sizes ?
                                        output_info->tensor_shape() :
                                        arm_compute::misc::shape_calculator::compute_split_shape(input->info(), axis, _num_outputs));

            const size_t axis_split_step = output_shape[axis];

            // Start/End coordinates
            Coordinates start_coords;
            Coordinates end_coords;

            for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d)
            {
                end_coords.set(d, -1);
            }

            // Update coordinate on axis
            start_coords.set(axis, axis_offset);
            end_coords.set(axis, axis_offset + axis_split_step);

            // Configure slice function
            _slice_functions[i].configure(input, outputs[i], start_coords, end_coords);

            // Set valid region from shape
            outputs[i]->info()->set_valid_region(ValidRegion(Coordinates(), output_shape));

            // Update axis offset
            axis_offset += axis_split_step;
            ++i;
        }
    }

protected:
    std::vector<TensorInterfaceType *> _outputs_vector;
    std::vector<SliceType>             _slice_functions;
    unsigned int                       _num_outputs;
};
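
// Usage sketch (illustrative only): CPPSplit is not meant to be used directly; a backend-specific
// split function supplies the slice type and tensor interface. The class and type names below are
// assumptions for illustration rather than a definitive part of this header.
//
//     // e.g. class NESplit : public CPPSplit<NESlice> { public: void run() override; };
//
//     Tensor                 input;                // e.g. shape (8, 6)
//     Tensor                 out0, out1, out2;     // shapes left empty -> evenly-sized splits
//     std::vector<ITensor *> outputs = { &out0, &out1, &out2 };
//
//     NESplit split;
//     split.configure(&input, outputs, 1 /* axis */);  // each output resolves to shape (8, 2)
//     // ... allocate the tensors, then:
//     split.run();                                     // run() is provided by the derived function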

} // namespace arm_compute
#endif /* ARM_COMPUTE_CPP_SPLIT_H */