xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/swift/Sources/MetalDelegate.swift (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 // Copyright 2019 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 import TensorFlowLiteCMetal
16 
/// A delegate that runs TensorFlow Lite graph operations with GPU acceleration by way of the
/// `Metal` framework.
///
/// - Important: This is an experimental interface that is subject to change.
public final class MetalDelegate: Delegate {
  /// The options this delegate was configured with.
  public let options: Options

  // Required by the `Delegate` protocol: the underlying C delegate handle. It is created in
  // `init` and released in `deinit`.
  public private(set) var cDelegate: CDelegate

  /// Creates a Metal delegate configured with `options`.
  ///
  /// - Parameters:
  ///   - options: The delegate configuration. Defaults to `MetalDelegate.Options()`, i.e. every
  ///       option at its default value.
  public init(options: Options = Options()) {
    self.options = options
    var cOptions = MetalDelegate.makeCOptions(from: options)
    self.cDelegate = TFLGpuDelegateCreate(&cOptions)
  }

  deinit {
    // Release the C delegate that was created in `init`.
    TFLGpuDelegateDelete(cDelegate)
  }

  /// Translates the Swift `options` into the C `TFLGpuDelegateOptions` struct.
  ///
  /// The struct starts zero-initialized, and each field it exposes is then set from `options`.
  private static func makeCOptions(from options: Options) -> TFLGpuDelegateOptions {
    var cOptions = TFLGpuDelegateOptions()
    cOptions.allow_precision_loss = options.isPrecisionLossAllowed
    cOptions.wait_type = options.waitType.cWaitType
    cOptions.enable_quantization = options.isQuantizationEnabled
    return cOptions
  }
}
46 
extension MetalDelegate {
  /// Configuration options for a `MetalDelegate`.
  public struct Options: Equatable, Hashable {
    /// Whether the GPU delegate may trade precision for speed, for example by running a
    /// `Float32` computation with `Float16` precision. Defaults to `false`.
    public var isPrecisionLossAllowed: Bool = false

    /// The former name of `isPrecisionLossAllowed`. Reads and writes are forwarded to that
    /// property.
    @available(
      *, deprecated, message: "Deprecated since TensorFlow Lite 2.4",
      renamed: "isPrecisionLossAllowed"
    )
    public var allowsPrecisionLoss: Bool {
      get {
        return isPrecisionLossAllowed
      }
      set {
        isPrecisionLossAllowed = newValue
      }
    }

    /// How the current thread waits for work on the GPU to complete. Defaults to `.passive`.
    public var waitType: ThreadWaitType = .passive

    /// Whether the GPU delegate may execute an 8-bit quantized model. Defaults to `true`.
    public var isQuantizationEnabled: Bool = true

    /// Creates an instance with every option set to its default value.
    public init() {}
  }
}
75 
/// A type describing how the current thread waits for work scheduled on the GPU to complete.
///
/// - Note: This enum declares a `none` case. When working with an optional `ThreadWaitType?`,
///   spell the case out as `ThreadWaitType.none`, because a bare `.none` there refers to
///   `Optional.none`.
public enum ThreadWaitType: Equatable, Hashable {
  /// Do not wait for the work to complete. Useful when the output of the work is used with the
  /// GPU pipeline.
  case none
  /// Wait until the work is complete.
  case passive
  /// Wait for the work to complete with minimal latency, which may require additional CPU
  /// resources.
  case active
  /// Wait for the work while trying to prevent the GPU from going into sleep mode.
  case aggressive

  /// The C `TFLGpuDelegateWaitType` value that corresponds to this `ThreadWaitType`.
  var cWaitType: TFLGpuDelegateWaitType {
    let mapped: TFLGpuDelegateWaitType
    switch self {
    case .none: mapped = TFLGpuDelegateWaitTypeDoNotWait
    case .passive: mapped = TFLGpuDelegateWaitTypePassive
    case .active: mapped = TFLGpuDelegateWaitTypeActive
    case .aggressive: mapped = TFLGpuDelegateWaitTypeAggressive
    }
    return mapped
  }
}
103