/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#import "tensorflow/lite/delegates/gpu/metal_delegate.h"

#import <Metal/Metal.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstring>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/delegates/gpu/common/convert.h"
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"
#include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"

namespace tflite {
namespace gpu {
namespace metal {
namespace {

// Multi-thread safe alarm clock for preventing GPU sleeping. It spawns lightweight compute tasks
// until no inference is being performed on the device, which reduces the CPU-to-CPU inference
// latency. The class is used only for the kAggressive wait type.
class GpuAlarmClock {
 public:
  explicit GpuAlarmClock(id<MTLCommandQueue> command_queue) {
    auto device = [command_queue device];
    std::lock_guard<std::mutex> lock(alarms_mutex_);
    if (!alarms_) alarms_ = new std::map<id<MTLDevice>, GpuAlarmClockInternal*>();
    auto it = alarms_->find(device);
    if (it == alarms_->end()) {
      internal_ = new GpuAlarmClockInternal(command_queue);
      (*alarms_)[device] = internal_;
    } else {
      internal_ = it->second;
      internal_->total_alarms_++;
    }
  }
  ~GpuAlarmClock() {
    std::lock_guard<std::mutex> lock(alarms_mutex_);
    if (--internal_->total_alarms_ > 0) return;
    Stop();
    delete internal_;
    // Remove the alarm from the container to free up the device handle.
    for (auto it = alarms_->begin(); it != alarms_->end(); ++it) {
      if (it->second == internal_) {
        alarms_->erase(it);
        break;
      }
    }
    if (alarms_->empty()) {
      delete alarms_;
      alarms_ = nullptr;
    }
  }
  void Start() {
    if (started_) return;
    started_ = true;
    internal_->active_alarms_++;
  }
  void Stop() {
    if (!started_) return;
    started_ = false;
    internal_->active_alarms_--;
  }

 private:
  class GpuAlarmClockInternal {
   public:
    id<MTLComputePipelineState> stub_program_;
    id<MTLBuffer> stub_buffer_;
    explicit GpuAlarmClockInternal(id<MTLCommandQueue> command_queue) {
      command_queue_ = command_queue;
      device_ = [command_queue_ device];
      total_alarms_ = 1;
      NSString* error;
      id<MTLComputePipelineState> program;
      // TODO(impjdi): Properly handle returned status.
      CreateComputeProgram(device_,
                           "kernel void ComputeFunction(device int* output_buffer [[buffer(0)]]) { "
                           "output_buffer[0] = 0; }",
                           "ComputeFunction", {}, &program)
          .IgnoreError();
      stub_program_ = program;
      stub_buffer_ = [device_ newBufferWithLength:sizeof(int) * 4
                                          options:MTLResourceHazardTrackingModeUntracked];
      alarm_thread_ = std::thread([this]() {
        id<MTLCommandBuffer> prev_command_buffer;
        while (!release_thread_) {
          @autoreleasepool {
            if (active_alarms_ == total_alarms_) {
              id<MTLCommandBuffer> command_buffer = [command_queue_ commandBuffer];
              id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
              [encoder setComputePipelineState:stub_program_];
              [encoder setBuffer:stub_buffer_ offset:0 atIndex:0];
              [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
                      threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
              [encoder endEncoding];
              [command_buffer commit];
              if (prev_command_buffer != nil) [prev_command_buffer waitUntilScheduled];
              prev_command_buffer = command_buffer;
            } else {
              std::this_thread::sleep_for(std::chrono::milliseconds(1));
            }
          }
        }
      });
    }
    ~GpuAlarmClockInternal() {
      release_thread_ = true;
      alarm_thread_.join();
    }

   private:
    friend class GpuAlarmClock;
    std::atomic<int> active_alarms_;
    std::thread alarm_thread_;
    id<MTLCommandQueue> command_queue_;
    id<MTLDevice> device_;
    volatile bool release_thread_ = false;
    int total_alarms_ = 0;
  };
  static std::map<id<MTLDevice>, GpuAlarmClockInternal*>* alarms_;
  std::mutex alarms_mutex_;
  GpuAlarmClockInternal* internal_;
  bool started_ = false;
};
std::map<id<MTLDevice>, GpuAlarmClock::GpuAlarmClockInternal*>* GpuAlarmClock::alarms_ = nullptr;

// Forward declaration.
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate);

class Delegate {
  struct ValueRef {
    BHWC shape;
    int64_t tensor_id;
  };

 public:
  explicit Delegate(const TFLGpuDelegateOptions* options) {
    if (options) {
      options_ = *options;
    } else {
      options_ = TFLGpuDelegateOptionsDefault();
    }
    metal_device_ = MTLCreateSystemDefaultDevice();
    command_queue_ = [metal_device_ newCommandQueue];
    if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive) {
      gpu_alarm_clock_ = std::unique_ptr<GpuAlarmClock>(new GpuAlarmClock(command_queue_));
      const std::string code = R"(
          kernel void ComputeFunction(device int* output_buffer [[buffer(0)]],
                                      constant int& value [[buffer(1)]]) {
            output_buffer[0] = value;
          }
        )";
      NSString* error;
      id<MTLComputePipelineState> signal_program;
      // TODO(impjdi): Properly handle returned status.
      CreateComputeProgram(metal_device_, code, "ComputeFunction", {}, &signal_program)
          .IgnoreError();
      signal_program_ = signal_program;
      signal_buffer_ = [metal_device_ newBufferWithLength:sizeof(int) * 4
                                                  options:MTLResourceStorageModeShared |
                                                          MTLResourceHazardTrackingModeUntracked];
    }
  }

  absl::Status BindBufferToTensor(id<MTLBuffer> buffer, int tensor_index) {
    // The tensor index is expected to be an input or output tensor of the interpreter.
    // For quantized models, the buffer should be linked with its dequantized counterpart.
    if (quant_conversion_map_.find(tensor_index) != quant_conversion_map_.end()) {
      tensor_index = quant_conversion_map_[tensor_index];
      // Remove the [dequantized tensor ID] -> [quantized tensor ID] mapping to prevent extra
      // dequant/quant on inputs/outputs.
      quant_conversion_map_.erase(tensor_index);
    }
    for (auto& input : graph_inputs_) {
      if (input.tensor_id == tensor_index) {
        if (in_out_tensors_[input.id]->GetBufferHandle() != buffer) {
          RETURN_IF_ERROR(in_out_tensors_[input.id]->SetBufferHandle(buffer));
          RETURN_IF_ERROR(inference_context_.SetTensor(input.id, in_out_tensors_[input.id].get()));
        }
        input.set_externally = true;
        return absl::OkStatus();
      }
    }
    for (auto& output : graph_outputs_) {
      if (output.tensor_id == tensor_index) {
        if (in_out_tensors_[output.id]->GetBufferHandle() != buffer) {
          RETURN_IF_ERROR(in_out_tensors_[output.id]->SetBufferHandle(buffer));
          RETURN_IF_ERROR(
              inference_context_.SetTensor(output.id, in_out_tensors_[output.id].get()));
        }
        output.set_externally = true;
        return absl::OkStatus();
      }
    }
    return absl::NotFoundError("Couldn't find tensor: " + std::to_string(tensor_index));
  }

  void SetCommandBuffer(id<MTLCommandBuffer> command_buffer) {
    external_command_buffer_ = command_buffer;
  }

  // This directs the runtime to allocate memory for input/output temporary
  // tensors that require dequantization/quantization.
  absl::Status GetRequiredTemporaries(TfLiteContext* context, TfLiteNode* node,
                                      TfLiteIntArray** temporaries_array_ptr) {
    if (quant_conversion_map_.empty()) return absl::OkStatus();

    std::vector<int> temporary_tensor_ids;
    for (auto index : input_tensor_ids_) {
      if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
        temporary_tensor_ids.push_back(index);
      }
    }
    for (auto index : output_tensor_ids_) {
      if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
        temporary_tensor_ids.push_back(index);
      }
    }
    *temporaries_array_ptr = TfLiteIntArrayCreate(temporary_tensor_ids.size());
    for (int i = 0; i < temporary_tensor_ids.size(); ++i) {
      (*temporaries_array_ptr)->data[i] = temporary_tensor_ids[i];
    }
    return absl::OkStatus();
  }

  absl::Status Prepare(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) {
    // Extract TFLite delegate execution plan from the context and convert it into GraphFloat32.
    GraphFloat32 graph;
    quant_conversion_map_.clear();
    if (options_.enable_quantization) {
      RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, &graph, &quant_conversion_map_));
    } else {
      RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, &graph));
    }

    // TODO(impjdi): Remove code duplication.
    auto values = graph.values();
    auto find_value = [&](int tensor_index) -> Value* {
      for (auto value : values) {
        if (value->tensor.ref == tensor_index) return value;
      }
      return nullptr;
    };
    tensors_.reserve(values.back()->id + 1);
    for (const auto* value : values) {
      if (tensors_.size() <= value->id) tensors_.resize(value->id + 1);
      tensors_[value->id] = {
          value->tensor.shape,  // .shape
          value->tensor.ref,    // .tensor_id
      };
    }

    // Prepare graph inputs.
    //
    // Note that graph.inputs() cannot be used directly, as the notion of graph input has a
    // different meaning in the public API and the GPU-internal API.
    for (int tensor_index : TfLiteIntArrayView(delegate_params->input_tensors)) {
      auto* tensor = &context->tensors[tensor_index];
      if (IsConstantTensor(tensor)) continue;
      // For quantized models, the actual inputs of the GPU graph are float tensors, so the 8-bit
      // inputs to the delegate kernel need to be dequantized before being fed to the GPU graph.
      if (options_.enable_quantization &&
          quant_conversion_map_.find(tensor_index) != quant_conversion_map_.end()) {
        tensor_index = quant_conversion_map_[tensor_index];
        tensor = &context->tensors[tensor_index];
      }
      const auto* input = find_value(tensor_index);
      if (!input || tensor->type != TfLiteType::kTfLiteFloat32) {
        return absl::NotFoundError("Input tensor is not found in the graph.");
      }

      inputs_.push_back(input->id);
      input_tensor_ids_.push_back(tensor_index);
      tensor->buffer_handle = input->id;
      tensor->delegate = &delegate_;
    }

    // Prepare graph outputs.
    //
    // Note that graph.outputs() cannot be used directly, as the notion of graph output has a
    // different meaning in the public API and the GPU-internal API.
    for (int tensor_index : TfLiteIntArrayView(delegate_params->output_tensors)) {
      auto* tensor = &context->tensors[tensor_index];
      if (IsConstantTensor(tensor)) continue;
      // For quantized models, the actual outputs of the GPU graph are float tensors, so they
      // should be quantized to produce the 8-bit outputs of the delegate.
      if (options_.enable_quantization &&
          quant_conversion_map_.find(tensor_index) != quant_conversion_map_.end()) {
        tensor_index = quant_conversion_map_[tensor_index];
        tensor = &context->tensors[tensor_index];
      }
      const auto* output = find_value(tensor_index);
      if (!output || tensor->type != TfLiteType::kTfLiteFloat32) {
        return absl::NotFoundError("Output tensor is not found in the graph.");
      }

      outputs_.push_back(output->id);
      output_tensor_ids_.push_back(tensor_index);
      tensor->buffer_handle = output->id;
      tensor->delegate = &delegate_;
    }

    std::string device_name = std::string([[metal_device_ name] UTF8String]);
    GpuInfo gpu_info;
    GetGpuInfoFromDeviceDescription(device_name, GpuApi::kMetal, &gpu_info);
    size_t storage_type_size;
    CalculationsPrecision precision;
    if (options_.allow_precision_loss) {
      storage_type_size = sizeof(HalfBits);
      if (gpu_info.IsRoundToNearestSupported()) {
        precision = CalculationsPrecision::F16;
      } else {
        precision = CalculationsPrecision::F32_F16;
      }
    } else {
      storage_type_size = sizeof(float);
      precision = CalculationsPrecision::F32;
    }

    CreateGpuModelInfo create_info;
    create_info.precision = precision;
    create_info.storage_type = GetFastestStorageType(gpu_info);
    create_info.hints.Add(ModelHints::kAllowSpecialKernels);
    const DataType external_data_type = DeduceDataTypeFromPrecision(create_info.precision);
    const TensorStorageType external_storage_type = TensorStorageType::BUFFER;
    for (auto& value : graph.inputs()) {
      Layout layout = value->tensor.shape.b == 1 ? Layout::HWC : Layout::BHWC;
      create_info.external_mutable_tensors[value->id] =
          TensorDescriptor{external_data_type, external_storage_type, layout};
    }
    for (auto& value : graph.outputs()) {
      Layout layout = value->tensor.shape.b == 1 ? Layout::HWC : Layout::BHWC;
      create_info.external_mutable_tensors[value->id] =
          TensorDescriptor{external_data_type, external_storage_type, layout};
    }

    // TODO(impjdi): Merge logic with above.
    // Pre-allocate input and output Metal buffers.
    std::vector<::tflite::gpu::ValueId> input_ids;
    input_ids.reserve(inputs_.size());
    std::map<::tflite::gpu::ValueId, BHWC> input_dimensions;
    graph_inputs_.reserve(inputs_.size());
    for (const ValueId input : inputs_) {
      const auto& input_tensor = tensors_[input];
      const auto tensor_id = input_tensor.tensor_id;
      input_ids.push_back(input);
      input_dimensions[input] = input_tensor.shape;
      graph_inputs_.push_back({
          input,               // .id
          tensor_id,           // .tensor_id
          input_tensor.shape,  // .shape
          false,               // .set_externally
      });

      // Create BHWC F32 buffer
      int bhwc_f32_length =
          static_cast<int>(sizeof(float) * input_tensor.shape.DimensionsProduct());
      in_out_bhwc_f32_buffers_[input] =
          [metal_device_ newBufferWithLength:bhwc_f32_length options:MTLResourceStorageModeShared];

      // Create shared Metal spatial tensor with storage type BUFFER
      int bphwc4_length =
          static_cast<int>(storage_type_size * GetElementsSizeForPHWC4(input_tensor.shape));
      id<MTLBuffer> bphwc4_buffer =
          [metal_device_ newBufferWithLength:bphwc4_length options:MTLResourceStorageModeShared];
      MetalSpatialTensor metal_tensor;
      TensorDescriptor descriptor_with_shape = create_info.external_mutable_tensors[input];
      descriptor_with_shape.SetBHWCShape(input_tensor.shape);
      RETURN_IF_ERROR(
          CreateTensorSharedBuffer(bphwc4_buffer, descriptor_with_shape, &metal_tensor));
      in_out_tensors_[input] = std::make_unique<MetalSpatialTensor>(std::move(metal_tensor));
    }

    std::vector<::tflite::gpu::ValueId> output_ids;
    output_ids.reserve(outputs_.size());
    graph_outputs_.reserve(outputs_.size());
    for (const ValueId output : outputs_) {
      const auto& output_tensor = tensors_[output];
      const auto tensor_id = output_tensor.tensor_id;
      output_ids.push_back(output);
      graph_outputs_.push_back({
          output,               // .id
          tensor_id,            // .tensor_id
          output_tensor.shape,  // .shape
          false,                // .set_externally
      });

      // Create BHWC F32 buffer
      int bhwc_length = static_cast<int>(sizeof(float) * output_tensor.shape.DimensionsProduct());
      in_out_bhwc_f32_buffers_[output] =
          [metal_device_ newBufferWithLength:bhwc_length options:MTLResourceStorageModeShared];

      // Create shared Metal spatial tensor with storage type BUFFER
      int bphwc4_length =
          static_cast<int>(storage_type_size * GetElementsSizeForPHWC4(output_tensor.shape));
      id<MTLBuffer> bphwc4_buffer =
          [metal_device_ newBufferWithLength:bphwc4_length options:MTLResourceStorageModeShared];
      MetalSpatialTensor metal_tensor;
      TensorDescriptor descriptor_with_shape = create_info.external_mutable_tensors[output];
      descriptor_with_shape.SetBHWCShape(output_tensor.shape);
      RETURN_IF_ERROR(
          CreateTensorSharedBuffer(bphwc4_buffer, descriptor_with_shape, &metal_tensor));
      in_out_tensors_[output] = std::make_unique<MetalSpatialTensor>(std::move(metal_tensor));
    }

    // Allocate converter BHWC -> BPHWC4.
    converter_to_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:metal_device_
                                                          isFloat16:options_.allow_precision_loss
                                                    convertToPBHWC4:true];
    if (converter_to_BPHWC4_ == nil) {
      return absl::InternalError("Error initializing input buffer converter");
    }

    // Allocate converter BPHWC4 -> BHWC.
    converter_from_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:metal_device_
                                                            isFloat16:options_.allow_precision_loss
                                                      convertToPBHWC4:false];
    if (converter_from_BPHWC4_ == nil) {
      return absl::InternalError("Error initializing output buffer converter");
    }

    RETURN_IF_ERROR(
        inference_context_.InitFromGraphWithTransforms(create_info, &graph, metal_device_));
    for (auto& external_tensor : in_out_tensors_) {
      RETURN_IF_ERROR(
          inference_context_.SetTensor(external_tensor.first, external_tensor.second.get()));
    }
    return absl::OkStatus();
  }

  absl::Status Invoke(TfLiteContext* context) {
    if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive)
      gpu_alarm_clock_->Stop();
    // We only need synchronization, so volatile works better than atomic, which reads from global
    // memory each time.
    __block volatile bool buffer_completed = false;
    id<MTLCommandBuffer> command_buffer = external_command_buffer_;
    if (external_command_buffer_ == nil) {
      command_buffer = [command_queue_ commandBuffer];
    }
    const bool flush = external_command_buffer_ == nil &&
        (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive ||
         options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive);
    const int flush_period = 8;

    const bool is_quantized_model = !quant_conversion_map_.empty();
    if (is_quantized_model) {
      RETURN_IF_ERROR(DequantizeInputs(context, input_tensor_ids_, quant_conversion_map_));
    }

    // Convert CPU HWC input data to PHWC4 and fill the GPU buffer.
    for (const auto& input : graph_inputs_) {
      if (input.set_externally) {
        continue;
      }
      // The user provides data in CPU memory for this buffer; copy it into the MTLBuffer.

      TfLiteTensor* tensor = &context->tensors[input.tensor_id];
      void* gpu_ptr = [in_out_bhwc_f32_buffers_[input.id] contents];
      std::memcpy(gpu_ptr, tensor->data.f, input.shape.DimensionsProduct() * sizeof(float));
      id<MTLComputeCommandEncoder> input_encoder = [command_buffer computeCommandEncoder];
      [converter_to_BPHWC4_ convertWithEncoder:input_encoder
                                         shape:input.shape
                                  sourceBuffer:in_out_bhwc_f32_buffers_[input.id]
                               convertedBuffer:in_out_tensors_[input.id]->GetBufferHandle()];
      [input_encoder endEncoding];
    }

    @autoreleasepool {
      if (flush) {
        [command_buffer commit];
        inference_context_.EncodeWithCommandQueue(command_queue_, flush_period);
        command_buffer = [command_queue_ commandBuffer];
      } else {
        inference_context_.EncodeWithCommandBuffer(command_buffer);
      }
    }

    for (const auto& output : graph_outputs_) {
      if (output.set_externally) {
        continue;
      }
      id<MTLComputeCommandEncoder> output_encoder = [command_buffer computeCommandEncoder];
      [converter_from_BPHWC4_ convertWithEncoder:output_encoder
                                           shape:output.shape
                                    sourceBuffer:in_out_tensors_[output.id]->GetBufferHandle()
                                 convertedBuffer:in_out_bhwc_f32_buffers_[output.id]];
      [output_encoder endEncoding];
    }

    if (external_command_buffer_ == nil) {
      if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) {
        [command_buffer addCompletedHandler:^(id<MTLCommandBuffer>) {
          buffer_completed = true;
        }];
      }
      [command_buffer commit];
      if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) {
        while (!buffer_completed) {
          // Busy wait. Use a local variable; reading the volatile flag hits RAM every time.
          for (volatile int i = 0; i < 100; i++) {
          }
        }
      } else if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive) {
        // Passive wait: this thread sleeps until the GPU finishes.
        [command_buffer waitUntilCompleted];
      } else if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive) {
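        // Aggressive wait: a tiny "signal" kernel is committed to the same queue behind the
        // inference work; it writes the incremented signal_value_ into signal_buffer_. The CPU
        // then spins until that value shows up in the shared buffer, while the alarm clock keeps
        // submitting stub work so the GPU does not drop into a low-power state between inferences.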
        id<MTLCommandBuffer> signal_cb = [command_queue_ commandBuffer];
        id<MTLComputeCommandEncoder> signal_encoder = [signal_cb computeCommandEncoder];
        [signal_encoder setComputePipelineState:signal_program_];
        [signal_encoder setBuffer:signal_buffer_ offset:0 atIndex:0];
        signal_value_++;
        [signal_encoder setBytes:&signal_value_ length:sizeof(int) atIndex:1];
        [signal_encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
                threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
        [signal_encoder endEncoding];
        [signal_cb commit];
        gpu_alarm_clock_->Start();
        const int* signal_ptr = reinterpret_cast<const int*>([signal_buffer_ contents]);
        while (signal_ptr[0] != signal_value_) {
          // Busy wait. Spin on a local variable to avoid RAM pressure.
          for (volatile int i = 0; i < 100; i++) {
          }
        }
      }
    } else {
      // The external command buffer must be set before every invoke call.
      external_command_buffer_ = nil;
      // An external command buffer is assigned, so all output buffers must be controlled by the
      // user.
      for (const auto& output : graph_outputs_) {
        if (!output.set_externally) {
          return absl::InternalError(
              "External command encoder is used, but not all output buffers are bound.");
        }
      }
      return absl::OkStatus();
    }

    // Retrieve data from GPU and convert from PHWC4 to HWC.
    for (const auto& output : graph_outputs_) {
      if (output.set_externally) continue;
      // The user retrieves data into CPU memory for this buffer; copy it from the MTLBuffer.
      TfLiteTensor* tensor = context->tensors + output.tensor_id;
      const void* gpu_ptr = [in_out_bhwc_f32_buffers_[output.id] contents];
      std::memcpy(tensor->data.f, gpu_ptr, output.shape.DimensionsProduct() * sizeof(float));
    }
    if (is_quantized_model) {
      RETURN_IF_ERROR(QuantizeOutputs(context, output_tensor_ids_, quant_conversion_map_));
    }
    return absl::OkStatus();
  }

  const TFLGpuDelegateOptions options() const { return options_; }

  TfLiteDelegate* tflite_delegate() { return &delegate_; }

 private:
  TfLiteDelegate delegate_ = {
      reinterpret_cast<void*>(this),  // .data_
      DelegatePrepare,                // .Prepare
      nullptr,                        // .CopyFromBufferHandle
      nullptr,                        // .CopyToBufferHandle
      nullptr,                        // .FreeBufferHandle
      kTfLiteDelegateFlagsNone,       // .flags
  };

  TFLGpuDelegateOptions options_;

  id<MTLDevice> metal_device_;

  std::vector<ValueRef> tensors_;  // indexed by ValueId
  std::vector<ValueId> inputs_;
  std::vector<ValueId> outputs_;
  std::vector<int64_t> input_tensor_ids_;
  std::vector<int64_t> output_tensor_ids_;
  // Whenever quantized inference is enabled, this maps the tensor index of each
  // originally quantized (8-bit) tensor to its float version added in
  // model_builder - and vice versa.
  absl::flat_hash_map<int, int> quant_conversion_map_;

  InferenceContext inference_context_;
  // Metal BHWC F32 input and output buffers for better conversion performance from CPU tensors.
  // We memcpy CPU<->GPU and use Metal for other conversions (layout changes, for example).
  std::map<ValueId, id<MTLBuffer>> in_out_bhwc_f32_buffers_;
  // Input and output tensors can be set externally with the help of
  // TFLGpuDelegateBindMetalBufferToTensor.
  std::map<ValueId, std::unique_ptr<MetalSpatialTensor>> in_out_tensors_;
  TFLBufferConvert* converter_to_BPHWC4_ = nil;
  TFLBufferConvert* converter_from_BPHWC4_ = nil;

  struct BufferDescriptor {
    ValueId id;
    int64_t tensor_id;
    BHWC shape;
    bool set_externally;  // the user fills/retrieves data for this MTLBuffer
  };
  std::vector<BufferDescriptor> graph_inputs_;
  std::vector<BufferDescriptor> graph_outputs_;

  id<MTLCommandBuffer> external_command_buffer_ = nil;
  id<MTLCommandQueue> command_queue_;
  std::unique_ptr<GpuAlarmClock> gpu_alarm_clock_;
  id<MTLComputePipelineState> signal_program_;
  id<MTLBuffer> signal_buffer_;
  int signal_value_ = 0;
};

Delegate* GetMetalDelegate(TfLiteNode* node) {
  return reinterpret_cast<Delegate*>(node->user_data);
}

Delegate* GetMetalDelegate(TfLiteDelegate* delegate) {
  return reinterpret_cast<Delegate*>(delegate->data_);
}

TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  const TfLiteRegistration kRegistration = {
      // .init
      [](TfLiteContext* context, const char* buffer, size_t) -> void* {
        const auto* params = reinterpret_cast<const TfLiteDelegateParams*>(buffer);
        auto* metal_delegate = GetMetalDelegate(params->delegate);
        // Everything below should happen in prepare function call, but TFLite for whatever reason
        // forbids that.
        const auto status = metal_delegate->Prepare(context, params);
        if (status.ok()) return metal_delegate;
        TF_LITE_KERNEL_LOG(context, "TfLiteMetalDelegate Prepare: %s",
                           std::string(status.message()).c_str());
        return nullptr;
      },
      // .free
      [](TfLiteContext*, void* buffer) -> void {},
      // .prepare
      [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        if (!node->user_data) {
          return kTfLiteError;
        }

        auto* gpu_delegate_kernel = GetMetalDelegate(node);
        const auto status =
            gpu_delegate_kernel->GetRequiredTemporaries(context, node, &node->temporaries);
        if (!status.ok()) {
          TF_LITE_KERNEL_LOG(context, "TfLiteMetalDelegate Prepare: %s",
                             std::string(status.message()).c_str());
          return kTfLiteError;
        }
        return node->user_data ? kTfLiteOk : kTfLiteError;
      },
      // .invoke
      [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        const auto status = GetMetalDelegate(node)->Invoke(context);
        if (status.ok()) return kTfLiteOk;
        TF_LITE_KERNEL_LOG(context, "TfLiteMetalDelegate Invoke: %s",
                           std::string(status.message()).c_str());
        return kTfLiteError;
      },
      nullptr,                // .profiling_string
      0,                      // .builtin_code
      "TfLiteMetalDelegate",  // .custom_name
      1,                      // .version
  };
  TfLiteIntArray* ops_to_replace =
      GetOpsToReplace(context, GetMetalDelegate(delegate)->options().enable_quantization);
  const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(context, kRegistration,
                                                                     ops_to_replace, delegate);
  TfLiteIntArrayFree(ops_to_replace);
  return status;
}

}  // namespace
}  // namespace metal
}  // namespace gpu
}  // namespace tflite

TfLiteDelegate* TFLGpuDelegateCreate(const TFLGpuDelegateOptions* options) {
  TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "Created TensorFlow Lite delegate for Metal.");
  auto* metal_delegate = new ::tflite::gpu::metal::Delegate(options);
  return metal_delegate ? metal_delegate->tflite_delegate() : nullptr;
}

void TFLGpuDelegateDelete(TfLiteDelegate* delegate) {
  delete ::tflite::gpu::metal::GetMetalDelegate(delegate);
}
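
// Usage sketch (illustrative only; `interpreter` is assumed to be an already-built
// std::unique_ptr<tflite::Interpreter> and is not part of this file): create the delegate,
// attach it to the interpreter, and release it once the interpreter no longer uses it.
//
//   TFLGpuDelegateOptions options = TFLGpuDelegateOptionsDefault();
//   options.allow_precision_loss = true;  // optional: prefer FP16 where supported
//   TfLiteDelegate* delegate = TFLGpuDelegateCreate(&options);
//   if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
//     // Delegation failed; fall back to CPU execution.
//   }
//   ...
//   interpreter.reset();             // Destroy the interpreter first.
//   TFLGpuDelegateDelete(delegate);  // Then release the delegate.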

bool TFLGpuDelegateBindMetalBufferToTensor(TfLiteDelegate* delegate, int tensor_index,
                                           id<MTLBuffer> buffer) {
  auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate);
  return metal_delegate && metal_delegate->BindBufferToTensor(buffer, tensor_index).ok();
}
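
// Usage sketch (illustrative only; `delegate`, `device`, `buffer_length`, and `output_index` are
// assumptions, not part of this file): bind a caller-owned MTLBuffer to a graph output so that
// Invoke() uses that buffer directly (in the delegate's internal layout and precision) instead of
// copying through the CPU tensor. The tensor index must be an input or output of the delegated
// graph; for quantized models it is remapped to the dequantized counterpart internally (see
// Delegate::BindBufferToTensor above).
//
//   id<MTLBuffer> user_buffer = [device newBufferWithLength:buffer_length
//                                                   options:MTLResourceStorageModeShared];
//   if (!TFLGpuDelegateBindMetalBufferToTensor(delegate, output_index, user_buffer)) {
//     // Binding failed, e.g. the index is not a graph input/output.
//   }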

// Note: This function is not exposed in `metal_delegate.h`, but it's exposed in
// `metal_delegate_internal.h`.
bool TFLGpuDelegateSetCommandBuffer(TfLiteDelegate* delegate,
                                    id<MTLCommandBuffer> command_buffer) {
  auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate);
  if (!metal_delegate) return false;
  metal_delegate->SetCommandBuffer(command_buffer);
  return true;
}
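
// Usage sketch (illustrative only; `delegate`, `command_queue`, and `interpreter` are assumptions,
// not part of this file): encode the delegate's work into an externally managed command buffer.
// Per Delegate::Invoke above, the command buffer must be set again before every Invoke() call,
// every graph output must also be bound with TFLGpuDelegateBindMetalBufferToTensor, and the caller
// is responsible for committing the command buffer.
//
//   id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
//   TFLGpuDelegateSetCommandBuffer(delegate, command_buffer);
//   if (interpreter->Invoke() != kTfLiteOk) {
//     // Handle the error.
//   }
//   [command_buffer commit];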

TFLGpuDelegateOptions TFLGpuDelegateOptionsDefault() {
  TFLGpuDelegateOptions options = {
      .allow_precision_loss = false,
      .wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive,
      .enable_quantization = true,
  };
  return options;
}