1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
17
18 #include <Availability.h>
19
20 #include <map>
21 #include <string>
22 #include <tuple>
23 #include <utility>
24
25 #include "absl/strings/match.h"
26 #include "absl/strings/substitute.h"
27 #include "tensorflow/lite/delegates/gpu/common/kernel_info.h"
28 #include "tensorflow/lite/delegates/gpu/common/shape.h"
29 #include "tensorflow/lite/delegates/gpu/common/status.h"
30 #include "tensorflow/lite/delegates/gpu/common/types.h"
31 #include "tensorflow/lite/delegates/gpu/common/util.h"
32 #include "tensorflow/lite/delegates/gpu/metal/common.h"
33
34 namespace tflite {
35 namespace gpu {
36 namespace metal {
37 namespace {
// Returns true when `symbol` may appear inside a C-like identifier, i.e. it
// is an underscore or an ASCII letter/digit.
bool IsWordSymbol(char symbol) {
  return symbol == '_' || absl::ascii_isalnum(symbol);
}
41
// Replaces every standalone occurrence of `old_word` in `*str` with
// `new_word`. An occurrence counts as standalone only when it is not
// immediately preceded or followed by an identifier character (see
// IsWordSymbol), so substrings embedded in longer identifiers are preserved.
void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
                     std::string* str) {
  for (size_t pos = str->find(old_word); pos != std::string::npos;) {
    // Neighboring characters; string edges act like virtual spaces.
    const char before = (pos == 0) ? ' ' : (*str)[pos - 1];
    const size_t match_end = pos + old_word.size();
    const char after = (match_end < str->size()) ? (*str)[match_end] : ' ';
    if (IsWordSymbol(before) || IsWordSymbol(after)) {
      // Part of a longer identifier - skip it and keep scanning.
      pos = str->find(old_word, pos + 1);
    } else {
      str->replace(pos, old_word.size(), new_word);
      // Continue the search after the inserted replacement.
      pos = str->find(old_word, pos + new_word.size());
    }
  }
}
58
// Builds the macro-substitution table used to translate the portable kernel
// dialect into Metal Shading Language for the given device and precision.
// Keys may carry a parameter list (e.g. "INIT_FLT4(value)"); values are the
// Metal replacements.
std::map<std::string, std::string> GetMetalDefines(
    MetalDevice* device, CalculationsPrecision precision) {
  std::string simdgroup_barrier;
  // simdgroup_barrier is supported since Metal shading language version 2.0
  if (device->IsLanguageVersion2orHigher()) {
    simdgroup_barrier = "simdgroup_barrier";
  } else {
    // MSL 1.x fallback: a full threadgroup barrier is stronger but valid.
    simdgroup_barrier = "threadgroup_barrier";
  }
  // storage_type: scalar type of values held in registers/memory.
  // accumulator_type: scalar type used for accumulation (may be wider).
  // to_accumulator_type4: 4-component conversion to the accumulator type.
  // It is intentionally left EMPTY when storage and accumulator types match,
  // so that "TO_ACCUM_TYPE(x)" expands to just "(x)" - an identity with no
  // cast emitted.
  std::string storage_type;
  std::string accumulator_type;
  std::string to_accumulator_type4;
  if (precision == CalculationsPrecision::F32) {
    storage_type = "float";
    accumulator_type = "float";
  } else {
    // FP16
    storage_type = "half";
    if (precision == CalculationsPrecision::F32_F16) {
      // Mixed precision: store in half, accumulate in float.
      accumulator_type = "float";
      to_accumulator_type4 = "float4";
    } else {
      accumulator_type = "half";
    }
  }
  return {
      // FLT16 is emulated as a 4x4 matrix (see ReplaceAllWords in Compile);
      // component groups map to matrix columns.
      {"FLT16_0123(V)", "V[0]"},
      {"FLT16_4567(V)", "V[1]"},
      {"FLT16_89ab(V)", "V[2]"},
      {"FLT16_cdef(V)", "V[3]"},
      {"FLT", storage_type},
      {"FLT2", storage_type + "2"},
      {"FLT3", storage_type + "3"},
      {"FLT4", storage_type + "4"},
      {"ACCUM_FLT", accumulator_type},
      {"ACCUM_FLT2", accumulator_type + "2"},
      {"ACCUM_FLT3", accumulator_type + "3"},
      {"ACCUM_FLT4", accumulator_type + "4"},
      {"INIT_ACCUM_FLT4(value)", accumulator_type + "4(value)"},
      // May expand to the empty string - see to_accumulator_type4 above.
      {"TO_ACCUM_TYPE", to_accumulator_type4},
      {"TO_ACCUM_FLT", accumulator_type},
      {"TO_ACCUM_FLT2", accumulator_type + "2"},
      {"TO_ACCUM_FLT3", accumulator_type + "3"},
      {"TO_ACCUM_FLT4", accumulator_type + "4"},
      {"TO_FLT4", storage_type + "4"},
      {"SIMDGROUP_BARRIER", simdgroup_barrier},
      {"SIMD_LOCAL_MEM_BARRIER", simdgroup_barrier},
      {"MAIN_FUNCTION", "kernel void ComputeFunction"},
      // reserved_* names are the kernel parameters bound with Metal
      // thread-position attributes by the surrounding codegen.
      {"GLOBAL_ID_0", "static_cast<int>(reserved_gid.x)"},
      {"GLOBAL_ID_1", "static_cast<int>(reserved_gid.y)"},
      {"GLOBAL_ID_2", "static_cast<int>(reserved_gid.z)"},
      {"LOCAL_ID_0", "static_cast<int>(reserved_lid.x)"},
      {"LOCAL_ID_1", "static_cast<int>(reserved_lid.y)"},
      {"LOCAL_ID_2", "static_cast<int>(reserved_lid.z)"},
      {"GROUP_ID_0", "static_cast<int>(reserved_group_id.x)"},
      {"GROUP_ID_1", "static_cast<int>(reserved_group_id.y)"},
      {"GROUP_ID_2", "static_cast<int>(reserved_group_id.z)"},
      {"GROUP_SIZE_0", "static_cast<int>(reserved_group_size.x)"},
      {"GROUP_SIZE_1", "static_cast<int>(reserved_group_size.y)"},
      {"GROUP_SIZE_2", "static_cast<int>(reserved_group_size.z)"},
      {"SUB_GROUP_LOCAL_ID", "static_cast<int>(reserved_simd_id)"},
      {"SUB_GROUP_BROADCAST(V, ID)", "simd_broadcast(V, ID)"},
      // OpenCL address-space qualifiers -> Metal equivalents.
      {"__local", "threadgroup"},
      {"__global", "device"},
      {"__constant", "constant"},
      {"LOCAL_MEM_BARRIER", "threadgroup_barrier(mem_flags::mem_threadgroup)"},
      {"INIT_FLT(value)", storage_type + "(value)"},
      {"INIT_FLT4(value)", storage_type + "4(value)"},
      {"INIT_FLT4v4(v0, v1, v2, v3)", storage_type + "4(v0, v1, v2, v3)"},
      {"INIT_FLOAT(value)", "float(value)"},
      {"INIT_FLOAT2(value)", "float2(value)"},
      {"INIT_FLOAT2v2(v0, v1)", "float2(v0, v1)"},
      {"INIT_FLOAT3(value)", "float3(value)"},
      {"INIT_FLOAT3v3(v0, v1, v2)", "float3(v0, v1, v2)"},
      {"INIT_FLOAT4(value)", "float4(value)"},
      {"INIT_FLOAT4v4(v0, v1, v2, v3)", "float4(v0, v1, v2, v3)"},
      {"INIT_INT(value)", "int(value)"},
      {"INIT_INT2v2(v0, v1)", "int2(v0, v1)"},
      {"INIT_INT4v4(v0, v1, v2, v3)", "int4(v0, v1, v2, v3)"},
      {"CONVERT_TO_INT4(value)", "int4(value)"},
  };
}
141 } // namespace
142
// Move constructor: takes over the wrapped GPUOperation, the compiled
// pipeline state, and the argument-buffer objects. The ObjC handles in the
// source are nulled afterwards so `task` no longer co-owns them.
ComputeTask::ComputeTask(ComputeTask&& task)
    : operation_(std::move(task.operation_)),
      program_(task.program_),
      metal_args_(std::move(task.metal_args_)),
      use_arguments_buffer_(task.use_arguments_buffer_),
      need_icb_support_(task.need_icb_support_),
      arguments_encoder_(task.arguments_encoder_),
      arg_buffer_(task.arg_buffer_) {
  task.program_ = nullptr;
  task.arguments_encoder_ = nullptr;
  task.arg_buffer_ = nullptr;
}
155
// Move assignment: releases this task's Metal objects first, then steals the
// state from `task`. Swapping (rather than copying) the handles leaves the
// source holding this task's now-released values, keeping ownership unique.
ComputeTask& ComputeTask::operator=(ComputeTask&& task) {
  if (this != &task) {
    Release();
    operation_ = std::move(task.operation_);
    std::swap(program_, task.program_);
    metal_args_ = std::move(task.metal_args_);
    std::swap(use_arguments_buffer_, task.use_arguments_buffer_);
    std::swap(need_icb_support_, task.need_icb_support_);
    std::swap(arguments_encoder_, task.arguments_encoder_);
    std::swap(arg_buffer_, task.arg_buffer_);
  }
  return *this;
}
169
~ComputeTask()170 ComputeTask::~ComputeTask() { Release(); }
171
// Drops all Metal objects owned by this task. Assigning nullptr releases the
// underlying object (under ARC) and is a harmless no-op for handles that are
// already null, so no guards are needed.
void ComputeTask::Release() {
  program_ = nullptr;
  arguments_encoder_ = nullptr;
  arg_buffer_ = nullptr;
}
183
// Takes ownership of the GPU operation this task will compile and run.
void ComputeTask::Init(std::unique_ptr<GPUOperation>&& operation) {
  operation_ = std::move(operation);
}
187
// Returns the definition (precision, tensor descriptors) of the wrapped
// operation. Requires Init() to have been called.
const OperationDef& ComputeTask::GetDefinition() const {
  return operation_->GetDefinition();
}
191
// Compiles the wrapped operation for `device`: initializes argument bindings,
// rewrites wide-vector type names, and builds the pipeline with the
// precision-dependent defines.
absl::Status ComputeTask::Compile(MetalDevice* device) {
  RETURN_IF_ERROR(metal_args_.Init(use_arguments_buffer_, device,
                                   &operation_->args_, &operation_->code_));

  // Argument data has been uploaded; the CPU-side copies are not needed.
  operation_->args_.ReleaseCPURepresentation();

  // Manually resolve these defines, since Metal has reserved words for them.
  ReplaceAllWords("float16", "float4x4", &operation_->code_);
  ReplaceAllWords("half16", "half4x4", &operation_->code_);
  ReplaceAllWords("float8", "float2x4", &operation_->code_);
  ReplaceAllWords("half8", "half2x4", &operation_->code_);
  defines_ = GetMetalDefines(device, operation_->GetDefinition().precision);
  return CompileProgram(device, operation_->code_, defines_);
}
206
// Builds the MTLComputePipelineState for `code` compiled with `defines`.
// When argument buffers are enabled, also creates the argument encoder and
// allocates its backing MTLBuffer (with ICB support if requested).
// Returns InternalError if the buffer allocation fails.
absl::Status ComputeTask::CompileProgram(
    MetalDevice* device, const std::string& code,
    const std::map<std::string, std::string>& defines) {
  id<MTLComputePipelineState> program;
  if (use_arguments_buffer_) {
    // All task resources will be referenced through one argument buffer
    // filled via this encoder.
    id<MTLArgumentEncoder> arguments_encoder;
    if (need_icb_support_) {
      RETURN_IF_ERROR(CreateComputeProgramWithICBSupport(
          device->device(), code, "ComputeFunction", defines, &program,
          &arguments_encoder));
    } else {
      RETURN_IF_ERROR(CreateComputeProgramWithArgumentBuffer(
          device->device(), code, "ComputeFunction", defines, &program,
          &arguments_encoder));
    }
    arguments_encoder_ = arguments_encoder;
    // options:0 requests the default resource options for the buffer.
    arg_buffer_ =
        [device->device() newBufferWithLength:arguments_encoder_.encodedLength
                                      options:0];
    if (!arg_buffer_) {
      return absl::InternalError("Failed to create MTLBuffer.");
    }
  } else {
    RETURN_IF_ERROR(CreateComputeProgram(device->device(), code,
                                         "ComputeFunction", defines, &program));
  }
  program_ = program;
  return absl::OkStatus();
}
236
// Initializes the task directly from already-prepared shader code and
// defines, bypassing the GPUOperation-based Compile() path.
absl::Status ComputeTask::Init(
    MetalDevice* device, const std::string& code,
    const std::map<std::string, std::string>& defines) {
  return CompileProgram(device, code, defines);
}
242
// Re-initializes argument bindings for an operation restored from a
// serialized form (no shader code rewriting is needed in this path).
absl::Status ComputeTask::RestoreDeserialized(MetalDevice* device) {
  RETURN_IF_ERROR(
      metal_args_.Init(use_arguments_buffer_, device, &operation_->args_));

  // Argument data has been uploaded; the CPU-side copies are not needed.
  operation_->args_.ReleaseCPURepresentation();
  return absl::OkStatus();
}
250
// Rebinds all source/destination tensors into the Metal argument state,
// refreshes operation-specific bindings, recomputes the dispatch grid, and
// re-encodes the argument buffer (via Update()).
// Returns InvalidArgumentError if any bound tensor is not a
// MetalSpatialTensor.
absl::Status ComputeTask::UpdateParams() {
  // Hoist the name lists out of the loops (loop-invariant), and use size_t to
  // match the unsigned size() type.
  const auto& src_names = operation_->GetSrcTensorsNames();
  for (size_t i = 0; i < src_names.size(); ++i) {
    const auto* metal_spatial_tensor =
        dynamic_cast<const MetalSpatialTensor*>(operation_->GetSrcTensors()[i]);
    if (!metal_spatial_tensor) {
      return absl::InvalidArgumentError("Expected MetalSpatialTensor.");
    }
    RETURN_IF_ERROR(
        metal_args_.SetObjectRef(src_names[i], *metal_spatial_tensor));
  }
  const auto& dst_names = operation_->GetDstTensorsNames();
  for (size_t i = 0; i < dst_names.size(); ++i) {
    const auto* metal_spatial_tensor =
        dynamic_cast<const MetalSpatialTensor*>(operation_->GetDstTensors()[i]);
    if (!metal_spatial_tensor) {
      return absl::InvalidArgumentError("Expected MetalSpatialTensor.");
    }
    RETURN_IF_ERROR(
        metal_args_.SetObjectRef(dst_names[i], *metal_spatial_tensor));
  }
  RETURN_IF_ERROR(operation_->BindArguments(&metal_args_));
  operation_->RecalculateGridSize();
  operation_->RecalculateWorkGroupsCount();
  Update();
  return absl::OkStatus();
}
276
API_AVAILABLE(ios(13.0), macos(11.00), tvos(13.0))
// Records this task into an indirect compute command: pipeline state, the
// argument buffer at index 0, a concurrent dispatch sized from the current
// grid, and a barrier after the dispatch.
void ComputeTask::EncodeToICB(id<MTLIndirectComputeCommand> icb_command) {
  const int3 groups_count = operation_->GetWorkGroupsCount();
  const int3 groups_size = operation_->work_group_size_;
  [icb_command setComputePipelineState:program_];
  [icb_command setKernelBuffer:arg_buffer_ offset:0 atIndex:0];
  [icb_command
      concurrentDispatchThreadgroups:MTLSizeMake(groups_count.x,
                                                 groups_count.y,
                                                 groups_count.z)
               threadsPerThreadgroup:MTLSizeMake(groups_size.x, groups_size.y,
                                                 groups_size.z)];
  [icb_command setBarrier];
}
292
API_AVAILABLE(ios(11.0), macos(10.13), tvos(11.0))
// Declares the resources referenced through the argument buffer to `encoder`
// (useResource calls), as required when argument buffers are in use.
void ComputeTask::AddResourcesToEncoder(
    id<MTLComputeCommandEncoder> encoder) const {
  metal_args_.AddResourcesToEncoder(encoder);
}
298
// Re-encodes the current argument values into the argument buffer. No-op
// when argument buffers are disabled or unavailable on this OS version.
void ComputeTask::Update() {
  if (use_arguments_buffer_) {
    if (@available(macOS 10.13, iOS 11.0, tvOS 11.0, *)) {
      [arguments_encoder_ setArgumentBuffer:arg_buffer_ offset:0];
      metal_args_.EncodeArguments(arguments_encoder_);
    }
  }
}
307
// Encodes this task onto a compute command encoder: binds the pipeline,
// binds arguments either through the argument buffer (declaring its
// resources first) or directly, then dispatches the current grid.
void ComputeTask::Encode(id<MTLComputeCommandEncoder> encoder) {
  [encoder setComputePipelineState:program_];
  if (use_arguments_buffer_) {
    if (@available(macOS 10.13, iOS 11.0, tvOS 11.0, *)) {
      metal_args_.AddResourcesToEncoder(encoder);
      [encoder setBuffer:arg_buffer_ offset:0 atIndex:0];
    }
  } else {
    metal_args_.Encode(encoder, 0);
  }
  const int3 grid = operation_->GetWorkGroupsCount();
  const int3 group = operation_->work_group_size_;
  [encoder dispatchThreadgroups:MTLSizeMake(grid.x, grid.y, grid.z)
          threadsPerThreadgroup:MTLSizeMake(group.x, group.y, group.z)];
}
327
// Binds `tensor` as the `index`-th source of the operation and refreshes the
// corresponding argument reference.
// The SetObjectRef status cannot be propagated through this void interface;
// IgnoreError() makes the deliberate drop explicit instead of parking the
// status in an unused local.
void ComputeTask::SetSrcTensor(MetalSpatialTensor* tensor, int index) {
  operation_->SetSrc(tensor, index);
  metal_args_.SetObjectRef(operation_->GetSrcTensorsNames()[index], *tensor)
      .IgnoreError();
}
333
// Binds `tensor` as the `index`-th destination of the operation and
// refreshes the corresponding argument reference.
// The SetObjectRef status cannot be propagated through this void interface;
// IgnoreError() makes the deliberate drop explicit instead of parking the
// status in an unused local.
void ComputeTask::SetDstTensor(MetalSpatialTensor* tensor, int index) {
  operation_->SetDst(tensor, index);
  metal_args_.SetObjectRef(operation_->GetDstTensorsNames()[index], *tensor)
      .IgnoreError();
}
339
// Selects work-group dimensions for this task using the device limits and
// the operation's dispatch heuristics.
// Returns NotFoundError when no dispatch candidate is produced.
absl::Status ComputeTask::Tune(TuningType tuning_type, MetalDevice* device) {
  KernelInfo kernel_info;
  kernel_info.max_work_group_size = [program_ maxTotalThreadsPerThreadgroup];
  // NOTE(review): private memory usage is not queried here and is reported
  // as zero - confirm whether Metal exposes an equivalent counter.
  kernel_info.private_memory_size = 0;
  std::vector<GPUOperation::DispatchInfo> possible_dispatches;
  operation_->GetPossibleDispatches(tuning_type, device->GetInfo(), kernel_info,
                                    &possible_dispatches);
  if (possible_dispatches.empty()) {
    return absl::NotFoundError("No dispatch parameters to launch kernel");
  }
  // Presumably candidates are ordered best-first; the first one is taken.
  operation_->work_group_size_ = possible_dispatches[0].work_group_size;
  operation_->RecalculateWorkGroupsCount();
  return absl::OkStatus();
}
354
// Overrides the work-group size and recomputes the number of work groups so
// the dispatch grid stays consistent.
void ComputeTask::SetWorkGroupSize(const int3& work_group_size) {
  operation_->work_group_size_ = work_group_size;
  operation_->RecalculateWorkGroupsCount();
}
359
360 } // namespace metal
361 } // namespace gpu
362 } // namespace tflite
363