xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/tasks/strided_slice.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/strided_slice.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
22 
23 namespace tflite {
24 namespace gpu {
25 
26 namespace {
Is4Aligned(const SliceAttributes & attr)27 bool Is4Aligned(const SliceAttributes& attr) {
28   return attr.strides.c == 1 && attr.starts.c % 4 == 0;
29 }
30 
GetOffset(const SliceAttributes & attr,int src_width,int src_height,int src_channels,int src_batch)31 int4 GetOffset(const SliceAttributes& attr, int src_width, int src_height,
32                int src_channels, int src_batch) {
33   int4 offset;
34   if (attr.strides.w > 0) {
35     offset.x = attr.starts.w;
36   } else {
37     if (attr.ends.w > 0) {
38       offset.x = attr.ends.w;
39     } else {
40       offset.x = src_width + attr.ends.w;
41     }
42   }
43   if (attr.strides.h > 0) {
44     offset.y = attr.starts.h;
45   } else {
46     if (attr.ends.h > 0) {
47       offset.y = attr.ends.h;
48     } else {
49       offset.y = src_height + attr.ends.h;
50     }
51   }
52   if (attr.strides.c > 0) {
53     offset.z = attr.starts.c;
54   } else {
55     if (attr.ends.c > 0) {
56       offset.z = attr.ends.c;
57     } else {
58       offset.z = src_channels + attr.ends.c;
59     }
60   }
61   if (Is4Aligned(attr)) {
62     offset.z /= 4;
63   }
64   if (attr.strides.b > 0) {
65     offset.w = attr.starts.b;
66   } else {
67     if (attr.ends.b > 0) {
68       offset.w = attr.ends.b;
69     } else {
70       offset.w = src_batch + attr.ends.b;
71     }
72   }
73   return offset;
74 }
75 
76 }  // namespace
77 
StridedSlice(const OperationDef & definition,const SliceAttributes & attr)78 StridedSlice::StridedSlice(const OperationDef& definition,
79                            const SliceAttributes& attr)
80     : GPUOperation(definition), attributes_(attr) {
81   work_group_size_ = int3(8, 4, 1);
82   code_ = GetStridedSliceCode(definition_, Is4Aligned(attributes_));
83 }
84 
StridedSlice(StridedSlice && operation)85 StridedSlice::StridedSlice(StridedSlice&& operation)
86     : GPUOperation(std::move(operation)), attributes_(operation.attributes_) {}
87 
operator =(StridedSlice && operation)88 StridedSlice& StridedSlice::operator=(StridedSlice&& operation) {
89   if (this != &operation) {
90     attributes_ = operation.attributes_;
91     GPUOperation::operator=(std::move(operation));
92   }
93   return *this;
94 }
95 
GetStridedSliceCode(const OperationDef & op_def,bool alignedx4)96 std::string StridedSlice::GetStridedSliceCode(const OperationDef& op_def,
97                                               bool alignedx4) {
98   AddSrcTensor("src_tensor", op_def.src_tensors[0]);
99   AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
100   args_.AddInt("offset_x");
101   args_.AddInt("offset_y");
102   args_.AddInt("offset_z");
103   args_.AddInt("offset_b");
104   args_.AddInt("stride_x");
105   args_.AddInt("stride_y");
106   args_.AddInt("stride_z");
107   args_.AddInt("stride_b");
108 
109   const std::string batch_id =
110       op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0";
111   std::string c;
112   c += "MAIN_FUNCTION($0) {\n";
113   if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
114     c += "  int linear_id = GLOBAL_ID_0;\n";
115     c += "  int X = linear_id / args.dst_tensor.Batch();\n";
116     c += "  int B = linear_id % args.dst_tensor.Batch();\n";
117     c += "  args.dst_tensor.SetBatchRef(B);\n";
118   } else {
119     c += "  int X = GLOBAL_ID_0;\n";
120   }
121   c += "  int Y = GLOBAL_ID_1;\n";
122   c += "  int S = GLOBAL_ID_2;\n";
123   c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
124        "S >= args.dst_tensor.Slices()) { \n";
125   c += "    return; \n";
126   c += "  } \n";
127   c += "  int s_x = X * args.stride_x + args.offset_x;\n";
128   c += "  int s_y = Y * args.stride_y + args.offset_y;\n";
129   if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
130     c += "  int s_b = " + batch_id + " * args.stride_b + args.offset_b;\n";
131     c += "  args.src_tensor.SetBatchRef(s_b);\n";
132   }
133   if (alignedx4) {
134     c += "  int s_z = S + args.offset_z;\n";
135     c += "  args.src_tensor::type result = args.src_tensor.Read(s_x, s_y, "
136          "s_z);\n";
137   } else {
138     c += "  args.src_tensor::type result;\n";
139     const std::string postfixes[] = {"x", "y", "z", "w"};
140     for (int i = 0; i < 4; ++i) {
141       c += "  {\n";
142       const std::string channel = "(S * 4 + " + std::to_string(i) + ")";
143       c += "    int s_ch = min(" + channel +
144            " * args.stride_z + args.offset_z, args.src_tensor.Channels() - "
145            "1);\n";
146       c += "    args.src_tensor.ReadPerChannel(result." + postfixes[i] +
147            ", s_x, s_y, s_ch);\n";
148       c += "  }\n";
149     }
150   }
151   c += "  args.dst_tensor.Write(result, X, Y, S);\n";
152   c += "}\n";
153   return c;
154 }
155 
BindArguments(ArgumentsBinder * args)156 absl::Status StridedSlice::BindArguments(ArgumentsBinder* args) {
157   int4 offset = GetOffset(attributes_, src_[0]->Width(), src_[0]->Height(),
158                           src_[0]->Channels(), src_[0]->Batch());
159   RETURN_IF_ERROR(args->SetInt("offset_x", offset.x));
160   RETURN_IF_ERROR(args->SetInt("offset_y", offset.y));
161   RETURN_IF_ERROR(args->SetInt("offset_z", offset.z));
162   RETURN_IF_ERROR(args->SetInt("offset_b", offset.w));
163   RETURN_IF_ERROR(args->SetInt("stride_x", attributes_.strides.w));
164   RETURN_IF_ERROR(args->SetInt("stride_y", attributes_.strides.h));
165   RETURN_IF_ERROR(args->SetInt("stride_z", attributes_.strides.c));
166   RETURN_IF_ERROR(args->SetInt("stride_b", attributes_.strides.b));
167   return absl::OkStatus();
168 }
169 
GetGridSize() const170 int3 StridedSlice::GetGridSize() const {
171   const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
172   const int grid_y = dst_[0]->Height();
173   const int grid_z = dst_[0]->Slices();
174   return int3(grid_x, grid_y, grid_z);
175 }
176 
CreateStridedSlice(const OperationDef & definition,const SliceAttributes & attr)177 StridedSlice CreateStridedSlice(const OperationDef& definition,
178                                 const SliceAttributes& attr) {
179   return StridedSlice(definition, attr);
180 }
181 
182 }  // namespace gpu
183 }  // namespace tflite
184