/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#import "tensorflow/lite/delegates/gpu/metal_delegate.h"

#import <Metal/Metal.h>

#include <algorithm>
#include <cstring>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/delegates/gpu/common/convert.h"
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"
#include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"

namespace tflite {
namespace gpu {
namespace metal {
namespace {

// Multi-thread safe alarm clock that prevents the GPU from sleeping. While no inference is
// running on a device, it keeps submitting lightweight compute tasks so the GPU stays awake,
// which reduces the CPU-to-CPU inference latency. The class is used only for the
// TFLGpuDelegateWaitTypeAggressive wait type.
class GpuAlarmClock {
 public:
  explicit GpuAlarmClock(id<MTLCommandQueue> command_queue) {
    auto device = [command_queue device];
    std::lock_guard<std::mutex> lock(alarms_mutex_);
    if (!alarms_) alarms_ = new std::map<id<MTLDevice>, GpuAlarmClockInternal*>();
    auto it = alarms_->find(device);
    if (it == alarms_->end()) {
      internal_ = new GpuAlarmClockInternal(command_queue);
      (*alarms_)[device] = internal_;
    } else {
      internal_ = it->second;
      internal_->total_alarms_++;
    }
  }
  ~GpuAlarmClock() {
    std::lock_guard<std::mutex> lock(alarms_mutex_);
    if (--internal_->total_alarms_ > 0) return;
    Stop();
    delete internal_;
    // Remove the alarm from the container to free up the device handle.
    for (auto it = alarms_->begin(); it != alarms_->end(); ++it) {
      if (it->second == internal_) {
        alarms_->erase(it);
        break;
      }
    }
    if (alarms_->empty()) {
      delete alarms_;
      alarms_ = nullptr;
    }
  }
  void Start() {
    if (started_) return;
    started_ = true;
    internal_->active_alarms_++;
  }
  void Stop() {
    if (!started_) return;
    started_ = false;
    internal_->active_alarms_--;
  }

 private:
  class GpuAlarmClockInternal {
   public:
    id<MTLComputePipelineState> stub_program_;
    id<MTLBuffer> stub_buffer_;
    explicit GpuAlarmClockInternal(id<MTLCommandQueue> command_queue) {
      command_queue_ = command_queue;
      device_ = [command_queue_ device];
      total_alarms_ = 1;
      NSString* error;
      id<MTLComputePipelineState> program;
      // TODO(impjdi): Properly handle returned status.
      CreateComputeProgram(device_,
                           "kernel void ComputeFunction(device int* output_buffer [[buffer(0)]]) { "
                           "output_buffer[0] = 0; }",
                           "ComputeFunction", {}, &program)
          .IgnoreError();
      stub_program_ = program;
      stub_buffer_ = [device_ newBufferWithLength:sizeof(int) * 4
                                          options:MTLResourceHazardTrackingModeUntracked];
      alarm_thread_ = std::thread([this]() {
        id<MTLCommandBuffer> prev_command_buffer;
        while (!release_thread_) {
          @autoreleasepool {
            if (active_alarms_ == total_alarms_) {
              id<MTLCommandBuffer> command_buffer = [command_queue_ commandBuffer];
              id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
              [encoder setComputePipelineState:stub_program_];
              [encoder setBuffer:stub_buffer_ offset:0 atIndex:0];
              [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
                      threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
              [encoder endEncoding];
              [command_buffer commit];
              if (prev_command_buffer != nil) [prev_command_buffer waitUntilScheduled];
              prev_command_buffer = command_buffer;
            } else {
              std::this_thread::sleep_for(std::chrono::milliseconds(1));
            }
          }
        }
      });
    }
    ~GpuAlarmClockInternal() {
      release_thread_ = true;
      alarm_thread_.join();
    }

   private:
    friend class GpuAlarmClock;
    std::atomic<int> active_alarms_;
    std::thread alarm_thread_;
    id<MTLCommandQueue> command_queue_;
    id<MTLDevice> device_;
    volatile bool release_thread_ = false;
    int total_alarms_ = 0;
  };
  static std::map<id<MTLDevice>, GpuAlarmClockInternal*>* alarms_;
  std::mutex alarms_mutex_;
  GpuAlarmClockInternal* internal_;
  bool started_ = false;
};
std::map<id<MTLDevice>, GpuAlarmClock::GpuAlarmClockInternal*>* GpuAlarmClock::alarms_ = nullptr;

// Forward declaration.
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate);

class Delegate {
  struct ValueRef {
    BHWC shape;
    int64_t tensor_id;
  };

 public:
  explicit Delegate(const TFLGpuDelegateOptions* options) {
    if (options) {
      options_ = *options;
    } else {
      options_ = TFLGpuDelegateOptionsDefault();
    }
    metal_device_ = MTLCreateSystemDefaultDevice();
    command_queue_ = [metal_device_ newCommandQueue];
    if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive) {
      gpu_alarm_clock_ = std::unique_ptr<GpuAlarmClock>(new GpuAlarmClock(command_queue_));
      const std::string code = R"(
          kernel void ComputeFunction(device int* output_buffer [[buffer(0)]],
                                      constant int& value [[buffer(1)]]) {
            output_buffer[0] = value;
          }
      )";
      NSString* error;
      id<MTLComputePipelineState> signal_program;
      // TODO(impjdi): Properly handle returned status.
      CreateComputeProgram(metal_device_, code, "ComputeFunction", {}, &signal_program)
          .IgnoreError();
      signal_program_ = signal_program;
      signal_buffer_ = [metal_device_ newBufferWithLength:sizeof(int) * 4
                                                  options:MTLResourceStorageModeShared |
                                                          MTLResourceHazardTrackingModeUntracked];
    }
  }

  absl::Status BindBufferToTensor(id<MTLBuffer> buffer, int tensor_index) {
    // The tensor index is expected to be an input or output tensor of the interpreter.
    // For quantized models, the buffer should be linked with its dequantized counterpart.
    if (quant_conversion_map_.find(tensor_index) != quant_conversion_map_.end()) {
      tensor_index = quant_conversion_map_[tensor_index];
      // Remove the [dequantized tensor ID] -> [quantized tensor ID] mapping to prevent extra
      // dequant/quant on inputs/outputs.
      quant_conversion_map_.erase(tensor_index);
    }
    for (auto& input : graph_inputs_) {
      if (input.tensor_id == tensor_index) {
        if (in_out_tensors_[input.id]->GetBufferHandle() != buffer) {
          RETURN_IF_ERROR(in_out_tensors_[input.id]->SetBufferHandle(buffer));
          RETURN_IF_ERROR(inference_context_.SetTensor(input.id, in_out_tensors_[input.id].get()));
        }
        input.set_externally = true;
        return absl::OkStatus();
      }
    }
    for (auto& output : graph_outputs_) {
      if (output.tensor_id == tensor_index) {
        if (in_out_tensors_[output.id]->GetBufferHandle() != buffer) {
          RETURN_IF_ERROR(in_out_tensors_[output.id]->SetBufferHandle(buffer));
          RETURN_IF_ERROR(
              inference_context_.SetTensor(output.id, in_out_tensors_[output.id].get()));
        }
        output.set_externally = true;
        return absl::OkStatus();
      }
    }
    return absl::NotFoundError("Couldn't find tensor: " + std::to_string(tensor_index));
  }

  void SetCommandBuffer(id<MTLCommandBuffer> command_buffer) {
    external_command_buffer_ = command_buffer;
  }

  // This directs the runtime to allocate memory for input/output temporary
  // tensors that require dequantization/quantization.
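  // The temporaries are the float counterparts that model_builder adds for quantized inputs and
  // outputs (the entries tracked in quant_conversion_map_); listing them in node->temporaries
  // lets TFLite allocate buffers for them.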
  absl::Status GetRequiredTemporaries(TfLiteContext* context, TfLiteNode* node,
                                      TfLiteIntArray** temporaries_array_ptr) {
    if (quant_conversion_map_.empty()) return absl::OkStatus();

    std::vector<int> temporary_tensor_ids;
    for (auto index : input_tensor_ids_) {
      if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
        temporary_tensor_ids.push_back(index);
      }
    }
    for (auto index : output_tensor_ids_) {
      if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
        temporary_tensor_ids.push_back(index);
      }
    }
    *temporaries_array_ptr = TfLiteIntArrayCreate(temporary_tensor_ids.size());
    for (int i = 0; i < temporary_tensor_ids.size(); ++i) {
      (*temporaries_array_ptr)->data[i] = temporary_tensor_ids[i];
    }
    return absl::OkStatus();
  }

  absl::Status Prepare(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) {
    // Extract the TFLite delegate execution plan from the context and convert it into
    // GraphFloat32.
    GraphFloat32 graph;
    quant_conversion_map_.clear();
    if (options_.enable_quantization) {
      RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, &graph, &quant_conversion_map_));
    } else {
      RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, &graph));
    }

    // TODO(impjdi): Remove code duplication.
    auto values = graph.values();
    auto find_value = [&](int tensor_index) -> Value* {
      for (auto value : values) {
        if (value->tensor.ref == tensor_index) return value;
      }
      return nullptr;
    };
    tensors_.reserve(values.back()->id + 1);
    for (const auto* value : values) {
      if (tensors_.size() <= value->id) tensors_.resize(value->id + 1);
      tensors_[value->id] = {
          value->tensor.shape,  // .shape
          value->tensor.ref,    // .tensor_id
      };
    }

    // Prepare graph inputs.
    //
    // Note that graph.inputs() cannot be used directly, as the notion of graph input has a
    // different meaning in the public API and the GPU-internal API.
    for (int tensor_index : TfLiteIntArrayView(delegate_params->input_tensors)) {
      auto* tensor = &context->tensors[tensor_index];
      if (IsConstantTensor(tensor)) continue;
      // For quantized models, the actual inputs of the GPU graph are float tensors, so the 8-bit
      // inputs to the delegate kernel need to be dequantized before being fed to the GPU graph.
      if (options_.enable_quantization &&
          quant_conversion_map_.find(tensor_index) != quant_conversion_map_.end()) {
        tensor_index = quant_conversion_map_[tensor_index];
        tensor = &context->tensors[tensor_index];
      }
      const auto* input = find_value(tensor_index);
      if (!input || tensor->type != TfLiteType::kTfLiteFloat32) {
        return absl::NotFoundError("Input tensor is not found in the graph.");
      }

      inputs_.push_back(input->id);
      input_tensor_ids_.push_back(tensor_index);
      tensor->buffer_handle = input->id;
      tensor->delegate = &delegate_;
    }

    // Prepare graph outputs.
    //
    // Note that graph.outputs() cannot be used directly, as the notion of graph output has a
    // different meaning in the public API and the GPU-internal API.
    for (int tensor_index : TfLiteIntArrayView(delegate_params->output_tensors)) {
      auto* tensor = &context->tensors[tensor_index];
      if (IsConstantTensor(tensor)) continue;
      // For quantized models, the actual outputs of the GPU graph are float tensors, so they need
      // to be quantized to produce the 8-bit outputs of the delegate.
      if (options_.enable_quantization &&
          quant_conversion_map_.find(tensor_index) != quant_conversion_map_.end()) {
        tensor_index = quant_conversion_map_[tensor_index];
        tensor = &context->tensors[tensor_index];
      }
      const auto* output = find_value(tensor_index);
      if (!output || tensor->type != TfLiteType::kTfLiteFloat32) {
        return absl::NotFoundError("Output tensor is not found in the graph.");
      }

      outputs_.push_back(output->id);
      output_tensor_ids_.push_back(tensor_index);
      tensor->buffer_handle = output->id;
      tensor->delegate = &delegate_;
    }

    std::string device_name = std::string([[metal_device_ name] UTF8String]);
    GpuInfo gpu_info;
    GetGpuInfoFromDeviceDescription(device_name, GpuApi::kMetal, &gpu_info);
    size_t storage_type_size;
    CalculationsPrecision precision;
    if (options_.allow_precision_loss) {
      storage_type_size = sizeof(HalfBits);
      if (gpu_info.IsRoundToNearestSupported()) {
        precision = CalculationsPrecision::F16;
      } else {
        precision = CalculationsPrecision::F32_F16;
      }
    } else {
      storage_type_size = sizeof(float);
      precision = CalculationsPrecision::F32;
    }

    CreateGpuModelInfo create_info;
    create_info.precision = precision;
    create_info.storage_type = GetFastestStorageType(gpu_info);
    create_info.hints.Add(ModelHints::kAllowSpecialKernels);
    const DataType external_data_type = DeduceDataTypeFromPrecision(create_info.precision);
    const TensorStorageType external_storage_type = TensorStorageType::BUFFER;
    for (auto& value : graph.inputs()) {
      Layout layout = value->tensor.shape.b == 1 ? Layout::HWC : Layout::BHWC;
      create_info.external_mutable_tensors[value->id] =
          TensorDescriptor{external_data_type, external_storage_type, layout};
    }
    for (auto& value : graph.outputs()) {
      Layout layout = value->tensor.shape.b == 1 ? Layout::HWC : Layout::BHWC;
      create_info.external_mutable_tensors[value->id] =
          TensorDescriptor{external_data_type, external_storage_type, layout};
    }

    // TODO(impjdi): Merge logic with above.
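    // Graph inputs and outputs were registered above as external mutable BUFFER tensors, so the
    // MTLBuffer backing each of them can be replaced at runtime: either with the internal BPHWC4
    // buffers allocated below, or with user buffers bound via
    // TFLGpuDelegateBindMetalBufferToTensor.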
    // Pre-allocate input and output Metal buffers.
    std::vector<::tflite::gpu::ValueId> input_ids;
    input_ids.reserve(inputs_.size());
    std::map<::tflite::gpu::ValueId, BHWC> input_dimensions;
    graph_inputs_.reserve(inputs_.size());
    for (const ValueId input : inputs_) {
      const auto& input_tensor = tensors_[input];
      const auto tensor_id = input_tensor.tensor_id;
      input_ids.push_back(input);
      input_dimensions[input] = input_tensor.shape;
      graph_inputs_.push_back({
          input,               // .id
          tensor_id,           // .tensor_id
          input_tensor.shape,  // .shape
          false,               // .set_externally
      });

      // Create a BHWC F32 buffer.
      int bhwc_f32_length =
          static_cast<int>(sizeof(float) * input_tensor.shape.DimensionsProduct());
      in_out_bhwc_f32_buffers_[input] =
          [metal_device_ newBufferWithLength:bhwc_f32_length options:MTLResourceStorageModeShared];

      // Create a shared Metal spatial tensor with storage type BUFFER.
      int bphwc4_length =
          static_cast<int>(storage_type_size * GetElementsSizeForPHWC4(input_tensor.shape));
      id<MTLBuffer> bphwc4_buffer =
          [metal_device_ newBufferWithLength:bphwc4_length options:MTLResourceStorageModeShared];
      MetalSpatialTensor metal_tensor;
      TensorDescriptor descriptor_with_shape = create_info.external_mutable_tensors[input];
      descriptor_with_shape.SetBHWCShape(input_tensor.shape);
      RETURN_IF_ERROR(
          CreateTensorSharedBuffer(bphwc4_buffer, descriptor_with_shape, &metal_tensor));
      in_out_tensors_[input] = std::make_unique<MetalSpatialTensor>(std::move(metal_tensor));
    }

    std::vector<::tflite::gpu::ValueId> output_ids;
    output_ids.reserve(outputs_.size());
    graph_outputs_.reserve(outputs_.size());
    for (const ValueId output : outputs_) {
      const auto& output_tensor = tensors_[output];
      const auto tensor_id = output_tensor.tensor_id;
      output_ids.push_back(output);
      graph_outputs_.push_back({
          output,               // .id
          tensor_id,            // .tensor_id
          output_tensor.shape,  // .shape
          false,                // .set_externally
      });

      // Create a BHWC F32 buffer.
      int bhwc_length = static_cast<int>(sizeof(float) * output_tensor.shape.DimensionsProduct());
      in_out_bhwc_f32_buffers_[output] =
          [metal_device_ newBufferWithLength:bhwc_length options:MTLResourceStorageModeShared];

      // Create a shared Metal spatial tensor with storage type BUFFER.
      int bphwc4_length =
          static_cast<int>(storage_type_size * GetElementsSizeForPHWC4(output_tensor.shape));
      id<MTLBuffer> bphwc4_buffer =
          [metal_device_ newBufferWithLength:bphwc4_length options:MTLResourceStorageModeShared];
      MetalSpatialTensor metal_tensor;
      TensorDescriptor descriptor_with_shape = create_info.external_mutable_tensors[output];
      descriptor_with_shape.SetBHWCShape(output_tensor.shape);
      RETURN_IF_ERROR(
          CreateTensorSharedBuffer(bphwc4_buffer, descriptor_with_shape, &metal_tensor));
      in_out_tensors_[output] = std::make_unique<MetalSpatialTensor>(std::move(metal_tensor));
    }
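
    // Each graph input/output now has two buffers: a BHWC float32 staging buffer that the CPU
    // reads and writes directly via memcpy, and a BPHWC4 buffer (channels padded to a multiple
    // of four) wrapped in a MetalSpatialTensor that the GPU kernels consume. The TFLBufferConvert
    // programs below translate between the two layouts.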

    // Allocate the BHWC -> BPHWC4 converter.
    converter_to_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:metal_device_
                                                          isFloat16:options_.allow_precision_loss
                                                    convertToPBHWC4:true];
    if (converter_to_BPHWC4_ == nil) {
      return absl::InternalError("Error initialization of input buffer converter");
    }

    // Allocate the BPHWC4 -> BHWC converter.
    converter_from_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:metal_device_
                                                            isFloat16:options_.allow_precision_loss
                                                      convertToPBHWC4:false];
    if (converter_from_BPHWC4_ == nil) {
      return absl::InternalError("Error initialization of output buffer converter");
    }

    RETURN_IF_ERROR(
        inference_context_.InitFromGraphWithTransforms(create_info, &graph, metal_device_));
    for (auto& external_tensor : in_out_tensors_) {
      RETURN_IF_ERROR(
          inference_context_.SetTensor(external_tensor.first, external_tensor.second.get()));
    }
    return absl::OkStatus();
  }

  absl::Status Invoke(TfLiteContext* context) {
    if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive)
      gpu_alarm_clock_->Stop();
    // Only synchronization is needed here, so a volatile flag is used instead of an atomic,
    // which would read from global memory on every access.
    __block volatile bool buffer_completed = false;
    id<MTLCommandBuffer> command_buffer = external_command_buffer_;
    if (external_command_buffer_ == nil) {
      command_buffer = [command_queue_ commandBuffer];
    }
    const bool flush = external_command_buffer_ == nil &&
        (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive ||
         options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive);
    const int flush_period = 8;

    const bool is_quantized_model = !quant_conversion_map_.empty();
    if (is_quantized_model) {
      RETURN_IF_ERROR(DequantizeInputs(context, input_tensor_ids_, quant_conversion_map_));
    }

    // Convert CPU HWC input data to PHWC4 and fill the GPU buffers.
    for (const auto& input : graph_inputs_) {
      if (input.set_externally) {
        continue;
      }
      // The user provides data in CPU memory for this buffer; it needs to be copied to the
      // MTLBuffer first.
      TfLiteTensor* tensor = &context->tensors[input.tensor_id];
      void* gpu_ptr = [in_out_bhwc_f32_buffers_[input.id] contents];
      std::memcpy(gpu_ptr, tensor->data.f, input.shape.DimensionsProduct() * sizeof(float));
      id<MTLComputeCommandEncoder> input_encoder = [command_buffer computeCommandEncoder];
      [converter_to_BPHWC4_ convertWithEncoder:input_encoder
                                         shape:input.shape
                                  sourceBuffer:in_out_bhwc_f32_buffers_[input.id]
                               convertedBuffer:in_out_tensors_[input.id]->GetBufferHandle()];
      [input_encoder endEncoding];
    }

    @autoreleasepool {
      if (flush) {
        [command_buffer commit];
        inference_context_.EncodeWithCommandQueue(command_queue_, flush_period);
        command_buffer = [command_queue_ commandBuffer];
      } else {
        inference_context_.EncodeWithCommandBuffer(command_buffer);
      }
    }

    for (const auto& output : graph_outputs_) {
      if (output.set_externally) {
        continue;
      }
      id<MTLComputeCommandEncoder> output_encoder = [command_buffer computeCommandEncoder];
      [converter_from_BPHWC4_ convertWithEncoder:output_encoder
                                           shape:output.shape
                                    sourceBuffer:in_out_tensors_[output.id]->GetBufferHandle()
                                 convertedBuffer:in_out_bhwc_f32_buffers_[output.id]];
      [output_encoder endEncoding];
    }

    if (external_command_buffer_ == nil) {
      if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) {
        [command_buffer addCompletedHandler:^(id<MTLCommandBuffer>) {
          buffer_completed = true;
        }];
      }
      [command_buffer commit];
      if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) {
        while (!buffer_completed) {
          // Busy-wait by spinning on a local variable; the volatile flag would hit RAM on every
          // access.
          for (volatile int i = 0; i < 100; i++) {
          }
        }
      } else if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive) {
        // Passive wait: this thread sleeps until the GPU finishes.
        [command_buffer waitUntilCompleted];
      } else if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive) {
        id<MTLCommandBuffer> signal_cb = [command_queue_ commandBuffer];
        id<MTLComputeCommandEncoder> signal_encoder = [signal_cb computeCommandEncoder];
        [signal_encoder setComputePipelineState:signal_program_];
        [signal_encoder setBuffer:signal_buffer_ offset:0 atIndex:0];
        signal_value_++;
        [signal_encoder setBytes:&signal_value_ length:sizeof(int) atIndex:1];
        [signal_encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
                       threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
        [signal_encoder endEncoding];
        [signal_cb commit];
        gpu_alarm_clock_->Start();
        const int* signal_ptr = reinterpret_cast<const int*>([signal_buffer_ contents]);
        while (signal_ptr[0] != signal_value_) {
          // Busy-wait, spinning on a local variable to avoid RAM pressure.
          for (volatile int i = 0; i < 100; i++) {
          }
        }
      }
    } else {
      // An external command buffer must be set before every Invoke call.
      external_command_buffer_ = nil;
      // An external command buffer is assigned, so all output buffers are controlled by the user.
      for (const auto& output : graph_outputs_) {
        if (!output.set_externally) {
          return absl::InternalError(
              "External command encoder is used, but not all output buffers are bound.");
        }
      }
      return absl::OkStatus();
    }

    // Retrieve data from the GPU and convert from PHWC4 to HWC.
    for (const auto& output : graph_outputs_) {
      if (output.set_externally) continue;
      // The user retrieves data in CPU memory for this buffer; it needs to be copied from the
      // MTLBuffer.
      TfLiteTensor* tensor = context->tensors + output.tensor_id;
      const void* gpu_ptr = [in_out_bhwc_f32_buffers_[output.id] contents];
      std::memcpy(tensor->data.f, gpu_ptr, output.shape.DimensionsProduct() * sizeof(float));
    }
    if (is_quantized_model) {
      RETURN_IF_ERROR(QuantizeOutputs(context, output_tensor_ids_, quant_conversion_map_));
    }
    return absl::OkStatus();
  }

  const TFLGpuDelegateOptions options() const { return options_; }

  TfLiteDelegate* tflite_delegate() { return &delegate_; }

 private:
  TfLiteDelegate delegate_ = {
      reinterpret_cast<void*>(this),  // .data_
      DelegatePrepare,                // .Prepare
      nullptr,                        // .CopyFromBufferHandle
      nullptr,                        // .CopyToBufferHandle
      nullptr,                        // .FreeBufferHandle
      kTfLiteDelegateFlagsNone,       // .flags
  };

  TFLGpuDelegateOptions options_;

  id<MTLDevice> metal_device_;

  std::vector<ValueRef> tensors_;  // indexed by ValueId
  std::vector<ValueId> inputs_;
  std::vector<ValueId> outputs_;
  std::vector<int64_t> input_tensor_ids_;
  std::vector<int64_t> output_tensor_ids_;
  // Whenever quantized inference is enabled, this maps the tensor index of each
  // originally quantized (8-bit) tensor to its float version added in
  // model_builder - and vice versa.
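  // For example, with hypothetical indices, if int8 tensor 3 gets float counterpart 7 during
  // model building, the map holds both 3 -> 7 and 7 -> 3.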
  absl::flat_hash_map<int, int> quant_conversion_map_;

  InferenceContext inference_context_;
  // Metal BHWC F32 input and output buffers for better conversion performance from CPU tensors.
  // Data is memcpy'd between CPU and GPU, and Metal is used for the other conversions (layout
  // changes, for example).
  std::map<ValueId, id<MTLBuffer>> in_out_bhwc_f32_buffers_;
  // Input and output tensors can be set externally with the help of
  // TFLGpuDelegateBindMetalBufferToTensor.
  std::map<ValueId, std::unique_ptr<MetalSpatialTensor>> in_out_tensors_;
  TFLBufferConvert* converter_to_BPHWC4_ = nil;
  TFLBufferConvert* converter_from_BPHWC4_ = nil;

  struct BufferDescriptor {
    ValueId id;
    int64_t tensor_id;
    BHWC shape;
    bool set_externally;  // the user fills/retrieves data in this MTLBuffer
  };
  std::vector<BufferDescriptor> graph_inputs_;
  std::vector<BufferDescriptor> graph_outputs_;

  id<MTLCommandBuffer> external_command_buffer_ = nil;
  id<MTLCommandQueue> command_queue_;
  std::unique_ptr<GpuAlarmClock> gpu_alarm_clock_;
  id<MTLComputePipelineState> signal_program_;
  id<MTLBuffer> signal_buffer_;
  int signal_value_ = 0;
};

Delegate* GetMetalDelegate(TfLiteNode* node) {
  return reinterpret_cast<Delegate*>(node->user_data);
}

Delegate* GetMetalDelegate(TfLiteDelegate* delegate) {
  return reinterpret_cast<Delegate*>(delegate->data_);
}

TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  const TfLiteRegistration kRegistration = {
      // .init
      [](TfLiteContext* context, const char* buffer, size_t) -> void* {
        const auto* params = reinterpret_cast<const TfLiteDelegateParams*>(buffer);
        auto* metal_delegate = GetMetalDelegate(params->delegate);
        // Everything below should happen in the prepare function call, but TFLite for whatever
        // reason forbids that.
        const auto status = metal_delegate->Prepare(context, params);
        if (status.ok()) return metal_delegate;
        TF_LITE_KERNEL_LOG(context, "TfLiteMetalDelegate Prepare: %s",
                           std::string(status.message()).c_str());
        return nullptr;
      },
      // .free
      [](TfLiteContext*, void* buffer) -> void {},
      // .prepare
      [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        if (!node->user_data) {
          return kTfLiteError;
        }

        auto* gpu_delegate_kernel = GetMetalDelegate(node);
        const auto status =
            gpu_delegate_kernel->GetRequiredTemporaries(context, node, &node->temporaries);
        if (!status.ok()) {
          TF_LITE_KERNEL_LOG(context, "TfLiteMetalDelegate Prepare: %s",
                             std::string(status.message()).c_str());
          return kTfLiteError;
        }
        return node->user_data ? kTfLiteOk : kTfLiteError;
      },
      // .invoke
      [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        const auto status = GetMetalDelegate(node)->Invoke(context);
        if (status.ok()) return kTfLiteOk;
        TF_LITE_KERNEL_LOG(context, "TfLiteMetalDelegate Invoke: %s",
                           std::string(status.message()).c_str());
        return kTfLiteError;
      },
      nullptr,                // .profiling_string
      0,                      // .builtin_code
      "TfLiteMetalDelegate",  // .custom_name
      1,                      // .version
  };
  TfLiteIntArray* ops_to_replace =
      GetOpsToReplace(context, GetMetalDelegate(delegate)->options().enable_quantization);
  const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(context, kRegistration,
                                                                     ops_to_replace, delegate);
  TfLiteIntArrayFree(ops_to_replace);
  return status;
}

}  // namespace
}  // namespace metal
}  // namespace gpu
}  // namespace tflite

TfLiteDelegate* TFLGpuDelegateCreate(const TFLGpuDelegateOptions* options) {
  TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "Created TensorFlow Lite delegate for Metal.");
  auto* metal_delegate = new ::tflite::gpu::metal::Delegate(options);
  return metal_delegate ? metal_delegate->tflite_delegate() : nullptr;
}

void TFLGpuDelegateDelete(TfLiteDelegate* delegate) {
  delete ::tflite::gpu::metal::GetMetalDelegate(delegate);
}

bool TFLGpuDelegateBindMetalBufferToTensor(TfLiteDelegate* delegate, int tensor_index,
                                           id<MTLBuffer> buffer) {
  auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate);
  return metal_delegate && metal_delegate->BindBufferToTensor(buffer, tensor_index).ok();
}

// Note: This function is not exposed in `metal_delegate.h`, but it's exposed in
// `metal_delegate_internal.h`.
bool TFLGpuDelegateSetCommandBuffer(TfLiteDelegate* delegate,
                                    id<MTLCommandBuffer> command_buffer) {
  auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate);
  if (!metal_delegate) return false;
  metal_delegate->SetCommandBuffer(command_buffer);
  return true;
}

TFLGpuDelegateOptions TFLGpuDelegateOptionsDefault() {
  TFLGpuDelegateOptions options = {
      .allow_precision_loss = false,
      .wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive,
      .enable_quantization = true,
  };
  return options;
}
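
// Example usage of the public API above (a minimal sketch, kept as a comment so it is not
// compiled into this file). `model`, `resolver`, and `user_buffer` are placeholders assumed to
// exist in the caller's code; error handling is mostly elided.
//
//   TFLGpuDelegateOptions options = TFLGpuDelegateOptionsDefault();
//   options.allow_precision_loss = true;  // run in FP16 where the GPU supports it
//   TfLiteDelegate* delegate = TFLGpuDelegateCreate(&options);
//
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   tflite::InterpreterBuilder(*model, resolver)(&interpreter);
//   if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
//     // Fall back to CPU execution.
//   }
//
//   // Optionally bind a zero-copy Metal buffer to an input or output tensor after the delegate
//   // has been attached:
//   // TFLGpuDelegateBindMetalBufferToTensor(delegate, interpreter->inputs()[0], user_buffer);
//
//   interpreter->Invoke();
//
//   // The delegate must outlive the interpreter; delete it only after the interpreter is gone.
//   interpreter.reset();
//   TFLGpuDelegateDelete(delegate);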