1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/tasks/strided_slice.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
22
23 namespace tflite {
24 namespace gpu {
25
26 namespace {
Is4Aligned(const SliceAttributes & attr)27 bool Is4Aligned(const SliceAttributes& attr) {
28 return attr.strides.c == 1 && attr.starts.c % 4 == 0;
29 }
30
GetOffset(const SliceAttributes & attr,int src_width,int src_height,int src_channels,int src_batch)31 int4 GetOffset(const SliceAttributes& attr, int src_width, int src_height,
32 int src_channels, int src_batch) {
33 int4 offset;
34 if (attr.strides.w > 0) {
35 offset.x = attr.starts.w;
36 } else {
37 if (attr.ends.w > 0) {
38 offset.x = attr.ends.w;
39 } else {
40 offset.x = src_width + attr.ends.w;
41 }
42 }
43 if (attr.strides.h > 0) {
44 offset.y = attr.starts.h;
45 } else {
46 if (attr.ends.h > 0) {
47 offset.y = attr.ends.h;
48 } else {
49 offset.y = src_height + attr.ends.h;
50 }
51 }
52 if (attr.strides.c > 0) {
53 offset.z = attr.starts.c;
54 } else {
55 if (attr.ends.c > 0) {
56 offset.z = attr.ends.c;
57 } else {
58 offset.z = src_channels + attr.ends.c;
59 }
60 }
61 if (Is4Aligned(attr)) {
62 offset.z /= 4;
63 }
64 if (attr.strides.b > 0) {
65 offset.w = attr.starts.b;
66 } else {
67 if (attr.ends.b > 0) {
68 offset.w = attr.ends.b;
69 } else {
70 offset.w = src_batch + attr.ends.b;
71 }
72 }
73 return offset;
74 }
75
76 } // namespace
77
StridedSlice(const OperationDef & definition,const SliceAttributes & attr)78 StridedSlice::StridedSlice(const OperationDef& definition,
79 const SliceAttributes& attr)
80 : GPUOperation(definition), attributes_(attr) {
81 work_group_size_ = int3(8, 4, 1);
82 code_ = GetStridedSliceCode(definition_, Is4Aligned(attributes_));
83 }
84
StridedSlice(StridedSlice && operation)85 StridedSlice::StridedSlice(StridedSlice&& operation)
86 : GPUOperation(std::move(operation)), attributes_(operation.attributes_) {}
87
operator =(StridedSlice && operation)88 StridedSlice& StridedSlice::operator=(StridedSlice&& operation) {
89 if (this != &operation) {
90 attributes_ = operation.attributes_;
91 GPUOperation::operator=(std::move(operation));
92 }
93 return *this;
94 }
95
GetStridedSliceCode(const OperationDef & op_def,bool alignedx4)96 std::string StridedSlice::GetStridedSliceCode(const OperationDef& op_def,
97 bool alignedx4) {
98 AddSrcTensor("src_tensor", op_def.src_tensors[0]);
99 AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
100 args_.AddInt("offset_x");
101 args_.AddInt("offset_y");
102 args_.AddInt("offset_z");
103 args_.AddInt("offset_b");
104 args_.AddInt("stride_x");
105 args_.AddInt("stride_y");
106 args_.AddInt("stride_z");
107 args_.AddInt("stride_b");
108
109 const std::string batch_id =
110 op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0";
111 std::string c;
112 c += "MAIN_FUNCTION($0) {\n";
113 if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
114 c += " int linear_id = GLOBAL_ID_0;\n";
115 c += " int X = linear_id / args.dst_tensor.Batch();\n";
116 c += " int B = linear_id % args.dst_tensor.Batch();\n";
117 c += " args.dst_tensor.SetBatchRef(B);\n";
118 } else {
119 c += " int X = GLOBAL_ID_0;\n";
120 }
121 c += " int Y = GLOBAL_ID_1;\n";
122 c += " int S = GLOBAL_ID_2;\n";
123 c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
124 "S >= args.dst_tensor.Slices()) { \n";
125 c += " return; \n";
126 c += " } \n";
127 c += " int s_x = X * args.stride_x + args.offset_x;\n";
128 c += " int s_y = Y * args.stride_y + args.offset_y;\n";
129 if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
130 c += " int s_b = " + batch_id + " * args.stride_b + args.offset_b;\n";
131 c += " args.src_tensor.SetBatchRef(s_b);\n";
132 }
133 if (alignedx4) {
134 c += " int s_z = S + args.offset_z;\n";
135 c += " args.src_tensor::type result = args.src_tensor.Read(s_x, s_y, "
136 "s_z);\n";
137 } else {
138 c += " args.src_tensor::type result;\n";
139 const std::string postfixes[] = {"x", "y", "z", "w"};
140 for (int i = 0; i < 4; ++i) {
141 c += " {\n";
142 const std::string channel = "(S * 4 + " + std::to_string(i) + ")";
143 c += " int s_ch = min(" + channel +
144 " * args.stride_z + args.offset_z, args.src_tensor.Channels() - "
145 "1);\n";
146 c += " args.src_tensor.ReadPerChannel(result." + postfixes[i] +
147 ", s_x, s_y, s_ch);\n";
148 c += " }\n";
149 }
150 }
151 c += " args.dst_tensor.Write(result, X, Y, S);\n";
152 c += "}\n";
153 return c;
154 }
155
BindArguments(ArgumentsBinder * args)156 absl::Status StridedSlice::BindArguments(ArgumentsBinder* args) {
157 int4 offset = GetOffset(attributes_, src_[0]->Width(), src_[0]->Height(),
158 src_[0]->Channels(), src_[0]->Batch());
159 RETURN_IF_ERROR(args->SetInt("offset_x", offset.x));
160 RETURN_IF_ERROR(args->SetInt("offset_y", offset.y));
161 RETURN_IF_ERROR(args->SetInt("offset_z", offset.z));
162 RETURN_IF_ERROR(args->SetInt("offset_b", offset.w));
163 RETURN_IF_ERROR(args->SetInt("stride_x", attributes_.strides.w));
164 RETURN_IF_ERROR(args->SetInt("stride_y", attributes_.strides.h));
165 RETURN_IF_ERROR(args->SetInt("stride_z", attributes_.strides.c));
166 RETURN_IF_ERROR(args->SetInt("stride_b", attributes_.strides.b));
167 return absl::OkStatus();
168 }
169
GetGridSize() const170 int3 StridedSlice::GetGridSize() const {
171 const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
172 const int grid_y = dst_[0]->Height();
173 const int grid_z = dst_[0]->Slices();
174 return int3(grid_x, grid_y, grid_z);
175 }
176
CreateStridedSlice(const OperationDef & definition,const SliceAttributes & attr)177 StridedSlice CreateStridedSlice(const OperationDef& definition,
178 const SliceAttributes& attr) {
179 return StridedSlice(definition, attr);
180 }
181
182 } // namespace gpu
183 } // namespace tflite
184