/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"

#include <Availability.h>

#include <map>
#include <string>
#include <tuple>
#include <utility>

#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/kernel_info.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"

namespace tflite {
namespace gpu {
namespace metal {
namespace {
bool IsWordSymbol(char symbol) {
  return absl::ascii_isalnum(symbol) || symbol == '_';
}

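// Replaces whole-word occurrences of `old_word` in *str with `new_word`.
// Matches that are part of a longer identifier (adjacent to an alphanumeric
// character or '_') are left untouched.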
void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
                     std::string* str) {
  size_t position = str->find(old_word);
  while (position != std::string::npos) {
    const char prev = position == 0 ? ' ' : (*str)[position - 1];
    const char next = position + old_word.size() < str->size()
                          ? (*str)[position + old_word.size()]
                          : ' ';
    if (IsWordSymbol(prev) || IsWordSymbol(next)) {
      position = str->find(old_word, position + 1);
      continue;
    }
    str->replace(position, old_word.size(), new_word);
    position = str->find(old_word, position + new_word.size());
  }
}

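// Builds the macro substitutions that map the portable kernel vocabulary
// (FLT4, GLOBAL_ID_0, MAIN_FUNCTION, ...) to Metal Shading Language
// equivalents for the given device and calculation precision.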
std::map<std::string, std::string> GetMetalDefines(
    MetalDevice* device, CalculationsPrecision precision) {
  std::string simdgroup_barrier;
  // simdgroup_barrier is supported since Metal shading language version 2.0
  if (device->IsLanguageVersion2orHigher()) {
    simdgroup_barrier = "simdgroup_barrier";
  } else {
    simdgroup_barrier = "threadgroup_barrier";
  }
  std::string storage_type;
  std::string accumulator_type;
  std::string to_accumulator_type4;
  if (precision == CalculationsPrecision::F32) {
    storage_type = "float";
    accumulator_type = "float";
  } else {
    // FP16
    storage_type = "half";
    if (precision == CalculationsPrecision::F32_F16) {
      accumulator_type = "float";
      to_accumulator_type4 = "float4";
    } else {
      accumulator_type = "half";
    }
  }
  return {
      {"FLT16_0123(V)", "V[0]"},
      {"FLT16_4567(V)", "V[1]"},
      {"FLT16_89ab(V)", "V[2]"},
      {"FLT16_cdef(V)", "V[3]"},
      {"FLT", storage_type},
      {"FLT2", storage_type + "2"},
      {"FLT3", storage_type + "3"},
      {"FLT4", storage_type + "4"},
      {"ACCUM_FLT", accumulator_type},
      {"ACCUM_FLT2", accumulator_type + "2"},
      {"ACCUM_FLT3", accumulator_type + "3"},
      {"ACCUM_FLT4", accumulator_type + "4"},
      {"INIT_ACCUM_FLT4(value)", accumulator_type + "4(value)"},
      {"TO_ACCUM_TYPE", to_accumulator_type4},
      {"TO_ACCUM_FLT", accumulator_type},
      {"TO_ACCUM_FLT2", accumulator_type + "2"},
      {"TO_ACCUM_FLT3", accumulator_type + "3"},
      {"TO_ACCUM_FLT4", accumulator_type + "4"},
      {"TO_FLT4", storage_type + "4"},
      {"SIMDGROUP_BARRIER", simdgroup_barrier},
      {"SIMD_LOCAL_MEM_BARRIER", simdgroup_barrier},
      {"MAIN_FUNCTION", "kernel void ComputeFunction"},
      {"GLOBAL_ID_0", "static_cast<int>(reserved_gid.x)"},
      {"GLOBAL_ID_1", "static_cast<int>(reserved_gid.y)"},
      {"GLOBAL_ID_2", "static_cast<int>(reserved_gid.z)"},
      {"LOCAL_ID_0", "static_cast<int>(reserved_lid.x)"},
      {"LOCAL_ID_1", "static_cast<int>(reserved_lid.y)"},
      {"LOCAL_ID_2", "static_cast<int>(reserved_lid.z)"},
      {"GROUP_ID_0", "static_cast<int>(reserved_group_id.x)"},
      {"GROUP_ID_1", "static_cast<int>(reserved_group_id.y)"},
      {"GROUP_ID_2", "static_cast<int>(reserved_group_id.z)"},
      {"GROUP_SIZE_0", "static_cast<int>(reserved_group_size.x)"},
      {"GROUP_SIZE_1", "static_cast<int>(reserved_group_size.y)"},
      {"GROUP_SIZE_2", "static_cast<int>(reserved_group_size.z)"},
      {"SUB_GROUP_LOCAL_ID", "static_cast<int>(reserved_simd_id)"},
      {"SUB_GROUP_BROADCAST(V, ID)", "simd_broadcast(V, ID)"},
      {"__local", "threadgroup"},
      {"__global", "device"},
      {"__constant", "constant"},
      {"LOCAL_MEM_BARRIER", "threadgroup_barrier(mem_flags::mem_threadgroup)"},
      {"INIT_FLT(value)", storage_type + "(value)"},
      {"INIT_FLT4(value)", storage_type + "4(value)"},
      {"INIT_FLT4v4(v0, v1, v2, v3)", storage_type + "4(v0, v1, v2, v3)"},
      {"INIT_FLOAT(value)", "float(value)"},
      {"INIT_FLOAT2(value)", "float2(value)"},
      {"INIT_FLOAT2v2(v0, v1)", "float2(v0, v1)"},
      {"INIT_FLOAT3(value)", "float3(value)"},
      {"INIT_FLOAT3v3(v0, v1, v2)", "float3(v0, v1, v2)"},
      {"INIT_FLOAT4(value)", "float4(value)"},
      {"INIT_FLOAT4v4(v0, v1, v2, v3)", "float4(v0, v1, v2, v3)"},
      {"INIT_INT(value)", "int(value)"},
      {"INIT_INT2v2(v0, v1)", "int2(v0, v1)"},
      {"INIT_INT4v4(v0, v1, v2, v3)", "int4(v0, v1, v2, v3)"},
      {"CONVERT_TO_INT4(value)", "int4(value)"},
  };
}
}  // namespace

ComputeTask::ComputeTask(ComputeTask&& task)
    : operation_(std::move(task.operation_)),
      program_(task.program_),
      metal_args_(std::move(task.metal_args_)),
      use_arguments_buffer_(task.use_arguments_buffer_),
      need_icb_support_(task.need_icb_support_),
      arguments_encoder_(task.arguments_encoder_),
      arg_buffer_(task.arg_buffer_) {
  task.program_ = nullptr;
  task.arguments_encoder_ = nullptr;
  task.arg_buffer_ = nullptr;
}

ComputeTask& ComputeTask::operator=(ComputeTask&& task) {
  if (this != &task) {
    Release();
    operation_ = std::move(task.operation_);
    std::swap(program_, task.program_);
    metal_args_ = std::move(task.metal_args_);
    std::swap(use_arguments_buffer_, task.use_arguments_buffer_);
    std::swap(need_icb_support_, task.need_icb_support_);
    std::swap(arguments_encoder_, task.arguments_encoder_);
    std::swap(arg_buffer_, task.arg_buffer_);
  }
  return *this;
}

ComputeTask::~ComputeTask() { Release(); }

void ComputeTask::Release() {
  if (program_) {
    program_ = nullptr;
  }
  if (arguments_encoder_) {
    arguments_encoder_ = nullptr;
  }
  if (arg_buffer_) {
    arg_buffer_ = nullptr;
  }
}

void ComputeTask::Init(std::unique_ptr<GPUOperation>&& operation) {
  operation_ = std::move(operation);
}

const OperationDef& ComputeTask::GetDefinition() const {
  return operation_->GetDefinition();
}

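// Initializes the Metal argument bindings for the operation's generated code,
// rewrites identifiers that are not valid MSL, and compiles the compute
// pipeline for the requested precision.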
absl::Status ComputeTask::Compile(MetalDevice* device) {
  RETURN_IF_ERROR(metal_args_.Init(use_arguments_buffer_, device,
                                   &operation_->args_, &operation_->code_));

  operation_->args_.ReleaseCPURepresentation();

  // Manually resolving these defines, as Metal has reserved words for them.
  ReplaceAllWords("float16", "float4x4", &operation_->code_);
  ReplaceAllWords("half16", "half4x4", &operation_->code_);
  ReplaceAllWords("float8", "float2x4", &operation_->code_);
  ReplaceAllWords("half8", "half2x4", &operation_->code_);
  defines_ = GetMetalDefines(device, operation_->GetDefinition().precision);
  return CompileProgram(device, operation_->code_, defines_);
}

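// Compiles the MSL source into a compute pipeline state. When argument
// buffers are used, also creates the argument encoder and its backing
// MTLBuffer, using the ICB-capable program variant if requested.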
absl::Status ComputeTask::CompileProgram(
    MetalDevice* device, const std::string& code,
    const std::map<std::string, std::string>& defines) {
  id<MTLComputePipelineState> program;
  if (use_arguments_buffer_) {
    id<MTLArgumentEncoder> arguments_encoder;
    if (need_icb_support_) {
      RETURN_IF_ERROR(CreateComputeProgramWithICBSupport(
          device->device(), code, "ComputeFunction", defines, &program,
          &arguments_encoder));
    } else {
      RETURN_IF_ERROR(CreateComputeProgramWithArgumentBuffer(
          device->device(), code, "ComputeFunction", defines, &program,
          &arguments_encoder));
    }
    arguments_encoder_ = arguments_encoder;
    arg_buffer_ =
        [device->device() newBufferWithLength:arguments_encoder_.encodedLength
                                      options:0];
    if (!arg_buffer_) {
      return absl::InternalError("Failed to create MTLBuffer.");
    }
  } else {
    RETURN_IF_ERROR(CreateComputeProgram(device->device(), code,
                                         "ComputeFunction", defines, &program));
  }
  program_ = program;
  return absl::OkStatus();
}

absl::Status ComputeTask::Init(
    MetalDevice* device, const std::string& code,
    const std::map<std::string, std::string>& defines) {
  return CompileProgram(device, code, defines);
}

absl::Status ComputeTask::RestoreDeserialized(MetalDevice* device) {
  RETURN_IF_ERROR(
      metal_args_.Init(use_arguments_buffer_, device, &operation_->args_));

  operation_->args_.ReleaseCPURepresentation();
  return absl::OkStatus();
}

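// Rebinds all source and destination tensors, refreshes the operation's
// arguments, grid size and work-groups count, and re-encodes the argument
// buffer when one is used.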
absl::Status ComputeTask::UpdateParams() {
  for (int i = 0; i < operation_->GetSrcTensorsNames().size(); ++i) {
    const auto* metal_spatial_tensor =
        dynamic_cast<const MetalSpatialTensor*>(operation_->GetSrcTensors()[i]);
    if (!metal_spatial_tensor) {
      return absl::InvalidArgumentError("Expected MetalSpatialTensor.");
    }
    RETURN_IF_ERROR(metal_args_.SetObjectRef(
        operation_->GetSrcTensorsNames()[i], *metal_spatial_tensor));
  }
  for (int i = 0; i < operation_->GetDstTensorsNames().size(); ++i) {
    const auto* metal_spatial_tensor =
        dynamic_cast<const MetalSpatialTensor*>(operation_->GetDstTensors()[i]);
    if (!metal_spatial_tensor) {
      return absl::InvalidArgumentError("Expected MetalSpatialTensor.");
    }
    RETURN_IF_ERROR(metal_args_.SetObjectRef(
        operation_->GetDstTensorsNames()[i], *metal_spatial_tensor));
  }
  RETURN_IF_ERROR(operation_->BindArguments(&metal_args_));
  operation_->RecalculateGridSize();
  operation_->RecalculateWorkGroupsCount();
  Update();
  return absl::OkStatus();
}

API_AVAILABLE(ios(13.0), macos(11.00), tvos(13.0))
void ComputeTask::EncodeToICB(id<MTLIndirectComputeCommand> icb_command) {
  MTLSize groupsCount, groupsSize;
  groupsCount.width = operation_->GetWorkGroupsCount().x;
  groupsCount.height = operation_->GetWorkGroupsCount().y;
  groupsCount.depth = operation_->GetWorkGroupsCount().z;
  groupsSize.width = operation_->work_group_size_.x;
  groupsSize.height = operation_->work_group_size_.y;
  groupsSize.depth = operation_->work_group_size_.z;
  [icb_command setComputePipelineState:program_];
  [icb_command setKernelBuffer:arg_buffer_ offset:0 atIndex:0];
  [icb_command concurrentDispatchThreadgroups:groupsCount
                        threadsPerThreadgroup:groupsSize];
  [icb_command setBarrier];
}

API_AVAILABLE(ios(11.0), macos(10.13), tvos(11.0))
void ComputeTask::AddResourcesToEncoder(
    id<MTLComputeCommandEncoder> encoder) const {
  metal_args_.AddResourcesToEncoder(encoder);
}

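// Re-encodes the current argument values into the argument buffer. No-op when
// argument buffers are not in use.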
void ComputeTask::Update() {
  if (use_arguments_buffer_) {
    if (@available(macOS 10.13, iOS 11.0, tvOS 11.0, *)) {
      [arguments_encoder_ setArgumentBuffer:arg_buffer_ offset:0];
      metal_args_.EncodeArguments(arguments_encoder_);
    }
  }
}

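// Encodes the dispatch into the given command encoder: binds the pipeline,
// binds either the argument buffer or the individual resources, and issues
// dispatchThreadgroups with the current work-group configuration.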
void ComputeTask::Encode(id<MTLComputeCommandEncoder> encoder) {
  [encoder setComputePipelineState:program_];
  if (use_arguments_buffer_) {
    if (@available(macOS 10.13, iOS 11.0, tvOS 11.0, *)) {
      metal_args_.AddResourcesToEncoder(encoder);
      [encoder setBuffer:arg_buffer_ offset:0 atIndex:0];
    }
  } else {
    metal_args_.Encode(encoder, 0);
  }
  MTLSize groupsCount, groupsSize;
  groupsCount.width = operation_->GetWorkGroupsCount().x;
  groupsCount.height = operation_->GetWorkGroupsCount().y;
  groupsCount.depth = operation_->GetWorkGroupsCount().z;
  groupsSize.width = operation_->work_group_size_.x;
  groupsSize.height = operation_->work_group_size_.y;
  groupsSize.depth = operation_->work_group_size_.z;
  [encoder dispatchThreadgroups:groupsCount threadsPerThreadgroup:groupsSize];
}

void ComputeTask::SetSrcTensor(MetalSpatialTensor* tensor, int index) {
  operation_->SetSrc(tensor, index);
  auto status = metal_args_.SetObjectRef(
      operation_->GetSrcTensorsNames()[index], *tensor);
}

void ComputeTask::SetDstTensor(MetalSpatialTensor* tensor, int index) {
  operation_->SetDst(tensor, index);
  auto status = metal_args_.SetObjectRef(
      operation_->GetDstTensorsNames()[index], *tensor);
}

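// Picks a work-group size for the compiled kernel from the dispatch options
// the operation proposes for the given tuning strategy and device.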
absl::Status ComputeTask::Tune(TuningType tuning_type, MetalDevice* device) {
  KernelInfo kernel_info;
  kernel_info.max_work_group_size = [program_ maxTotalThreadsPerThreadgroup];
  kernel_info.private_memory_size = 0;
  std::vector<GPUOperation::DispatchInfo> possible_dispatches;
  operation_->GetPossibleDispatches(tuning_type, device->GetInfo(), kernel_info,
                                    &possible_dispatches);
  if (possible_dispatches.empty()) {
    return absl::NotFoundError("No dispatch parameters to launch kernel");
  }
  operation_->work_group_size_ = possible_dispatches[0].work_group_size;
  operation_->RecalculateWorkGroupsCount();
  return absl::OkStatus();
}

void ComputeTask::SetWorkGroupSize(const int3& work_group_size) {
  operation_->work_group_size_ = work_group_size;
  operation_->RecalculateWorkGroupsCount();
}

}  // namespace metal
}  // namespace gpu
}  // namespace tflite