// Copyright 2019 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlowLiteCMetal

/// A delegate that uses the `Metal` framework for performing TensorFlow Lite graph operations with
/// GPU acceleration.
///
/// - Important: This is an experimental interface that is subject to change.
public final class MetalDelegate: Delegate {
  /// The configuration options for the `MetalDelegate`.
  public let options: Options

  /// The underlying C delegate handle, exposed for conformance to the `Delegate` protocol.
  ///
  /// It is created with `TFLGpuDelegateCreate` in `init(options:)` and released with
  /// `TFLGpuDelegateDelete` in `deinit`, so its lifetime is tied to this instance.
  public private(set) var cDelegate: CDelegate

  /// Creates a new instance configured with the given `options`.
  ///
  /// - Parameters:
  ///   - options: Configurations for the delegate. The default is a new instance of
  ///       `MetalDelegate.Options` with the default configuration values.
  public init(options: Options = Options()) {
    self.options = options
    // Translate the Swift-level options into the C options struct consumed by the Metal delegate.
    var delegateOptions = TFLGpuDelegateOptions()
    delegateOptions.allow_precision_loss = options.isPrecisionLossAllowed
    delegateOptions.wait_type = options.waitType.cWaitType
    delegateOptions.enable_quantization = options.isQuantizationEnabled
    cDelegate = TFLGpuDelegateCreate(&delegateOptions)
  }

  deinit {
    // Balances the `TFLGpuDelegateCreate` call made in `init(options:)`.
    TFLGpuDelegateDelete(cDelegate)
  }
}

extension MetalDelegate {
  /// Options for configuring the `MetalDelegate`.
  public struct Options: Equatable, Hashable {
    /// Indicates whether the GPU delegate allows precision loss, such as allowing `Float16`
    /// precision for a `Float32` computation. The default is `false`.
    public var isPrecisionLossAllowed = false

    /// Deprecated alias of `isPrecisionLossAllowed`; reads and writes go straight through to it.
    @available(
      *, deprecated, message: "Deprecated since TensorFlow Lite 2.4",
      renamed: "isPrecisionLossAllowed"
    )
    public var allowsPrecisionLoss: Bool {
      get { return isPrecisionLossAllowed }
      set { isPrecisionLossAllowed = newValue }
    }

    /// A type indicating how the current thread should wait for work on the GPU to complete. The
    /// default is `passive`.
    public var waitType: ThreadWaitType = .passive

    /// Indicates whether the GPU delegate allows execution of an 8-bit quantized model. The default
    /// is `true`.
    public var isQuantizationEnabled = true

    /// Creates a new instance with the default values.
    public init() {}
  }
}

/// A type indicating how the current thread should wait for work scheduled on the GPU to complete.
public enum ThreadWaitType: Equatable, Hashable {
  /// The thread does not wait for the work to complete. Useful when the output of the work is used
  /// with the GPU pipeline.
  ///
  /// - Note: Where a `ThreadWaitType?` is involved, spell this as `ThreadWaitType.none` so it is
  ///   not confused with `Optional.none`.
  case none
  /// The thread waits until the work is complete.
  case passive
  /// The thread waits for the work to complete with minimal latency, which may require additional
  /// CPU resources.
  case active
  /// The thread waits for the work while trying to prevent the GPU from going into sleep mode.
  case aggressive

  /// The C `TFLGpuDelegateWaitType` for the current `ThreadWaitType`.
  var cWaitType: TFLGpuDelegateWaitType {
    switch self {
    case .none:
      return TFLGpuDelegateWaitTypeDoNotWait
    case .passive:
      return TFLGpuDelegateWaitTypePassive
    case .active:
      return TFLGpuDelegateWaitTypeActive
    case .aggressive:
      return TFLGpuDelegateWaitTypeAggressive
    }
  }
}