// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation
import TensorFlowLiteC

#if os(Linux)
import SwiftGlibc
#else
import Darwin
#endif

/// A TensorFlow Lite interpreter that performs inference from a given model.
///
/// - Note: Interpreter instances are *not* thread-safe.
public final class Interpreter {
  /// The configuration options for the `Interpreter`.
  public let options: Options?

  /// An `Array` of `Delegate`s for the `Interpreter` to use to perform graph operations.
  public let delegates: [Delegate]?

  /// The total number of input `Tensor`s associated with the model.
  public var inputTensorCount: Int {
    return Int(TfLiteInterpreterGetInputTensorCount(cInterpreter))
  }

  /// The total number of output `Tensor`s associated with the model.
  public var outputTensorCount: Int {
    return Int(TfLiteInterpreterGetOutputTensorCount(cInterpreter))
  }

  /// An ordered list of SignatureDef exported method names available in the model.
  ///
  /// Computed lazily on first access and cached, since the signature list of a loaded model does
  /// not change for the lifetime of the interpreter.
  public var signatureKeys: [String] {
    if let cachedKeys = _signatureKeys { return cachedKeys }
    let signatureCount = Int(TfLiteInterpreterGetSignatureCount(self.cInterpreter))
    let keys: [String] = (0..<signatureCount).map {
      guard
        let signatureNameCString = TfLiteInterpreterGetSignatureKey(
          self.cInterpreter, Int32($0))
      else {
        // Preserve positional correspondence with the C API by keeping a placeholder entry.
        return ""
      }
      return String(cString: signatureNameCString)
    }
    _signatureKeys = keys
    return keys
  }

  /// The `TfLiteInterpreter` C pointer type represented as an `UnsafePointer<TfLiteInterpreter>`.
  internal typealias CInterpreter = OpaquePointer

  /// The underlying `TfLiteInterpreter` C pointer.
  internal var cInterpreter: CInterpreter?

  /// The underlying `TfLiteDelegate` C pointer for XNNPACK delegate.
  private var cXNNPackDelegate: Delegate.CDelegate?

  /// Cached backing storage for `signatureKeys`; `nil` until the first access.
  private var _signatureKeys: [String]? = nil

  /// Creates a new instance with the given values.
  ///
  /// - Parameters:
  ///   - modelPath: The local file path to a TensorFlow Lite model.
  ///   - options: Configurations for the `Interpreter`. The default is `nil` indicating that the
  ///       `Interpreter` will determine the configuration options.
  ///   - delegates: `Array` of `Delegate`s for the `Interpreter` to use to perform graph
  ///       operations. The default is `nil`.
  /// - Throws: An error if the model could not be loaded or the interpreter could not be created.
  public init(modelPath: String, options: Options? = nil, delegates: [Delegate]? = nil) throws {
    guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel }
    guard let cInterpreterOptions = TfLiteInterpreterOptionsCreate() else {
      throw InterpreterError.failedToCreateInterpreter
    }
    // The C options are copied into the interpreter by `TfLiteInterpreterCreate`, so they are
    // safe to delete on every exit path.
    defer { TfLiteInterpreterOptionsDelete(cInterpreterOptions) }

    self.options = options
    self.delegates = delegates
    if let options = options {
      if let threadCount = options.threadCount, threadCount > 0 {
        TfLiteInterpreterOptionsSetNumThreads(cInterpreterOptions, Int32(threadCount))
      }
      TfLiteInterpreterOptionsSetErrorReporter(
        cInterpreterOptions,
        { (_, format, args) -> Void in
          // Workaround for optionality differences for x86_64 (non-optional) and arm64 (optional).
          let optionalArgs: CVaListPointer? = args
          guard let cFormat = format,
            let arguments = optionalArgs,
            let message = String(cFormat: cFormat, arguments: arguments)
          else {
            return
          }
          print(String(describing: InterpreterError.tensorFlowLiteError(message)))
        },
        nil
      )
    }
    delegates?.forEach { TfLiteInterpreterOptionsAddDelegate(cInterpreterOptions, $0.cDelegate) }

    // Configure the XNNPack delegate after the other delegates explicitly added by the user.
    if let options = options, options.isXNNPackEnabled {
      configureXNNPack(options: options, cInterpreterOptions: cInterpreterOptions)
    }

    guard let cInterpreter = TfLiteInterpreterCreate(model.cModel, cInterpreterOptions) else {
      throw InterpreterError.failedToCreateInterpreter
    }
    self.cInterpreter = cInterpreter
  }

  deinit {
    TfLiteInterpreterDelete(cInterpreter)
    TfLiteXNNPackDelegateDelete(cXNNPackDelegate)
  }

  /// Invokes the interpreter to perform inference from the loaded graph.
  ///
  /// - Throws: An error if the model was not ready because the tensors were not allocated.
invokenull136 public func invoke() throws { 137 guard TfLiteInterpreterInvoke(cInterpreter) == kTfLiteOk else { 138 throw InterpreterError.allocateTensorsRequired 139 } 140 } 141 142 /// Returns the input `Tensor` at the given index. 143 /// 144 /// - Parameters: 145 /// - index: The index for the input `Tensor`. 146 /// - Throws: An error if the index is invalid or the tensors have not been allocated. 147 /// - Returns: The input `Tensor` at the given index. inputnull148 public func input(at index: Int) throws -> Tensor { 149 let maxIndex = inputTensorCount - 1 150 guard case 0...maxIndex = index else { 151 throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex) 152 } 153 guard let cTensor = TfLiteInterpreterGetInputTensor(cInterpreter, Int32(index)), 154 let bytes = TfLiteTensorData(cTensor), 155 let nameCString = TfLiteTensorName(cTensor) 156 else { 157 throw InterpreterError.allocateTensorsRequired 158 } 159 guard let dataType = Tensor.DataType(type: TfLiteTensorType(cTensor)) else { 160 throw InterpreterError.invalidTensorDataType 161 } 162 163 let name = String(cString: nameCString) 164 let rank = TfLiteTensorNumDims(cTensor) 165 let dimensions = (0..<rank).map { Int(TfLiteTensorDim(cTensor, $0)) } 166 let shape = Tensor.Shape(dimensions) 167 let byteCount = TfLiteTensorByteSize(cTensor) 168 let data = Data(bytes: bytes, count: byteCount) 169 let cQuantizationParams = TfLiteTensorQuantizationParams(cTensor) 170 let scale = cQuantizationParams.scale 171 let zeroPoint = Int(cQuantizationParams.zero_point) 172 var quantizationParameters: QuantizationParameters? = nil 173 if scale != 0.0 { 174 quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint) 175 } 176 let tensor = Tensor( 177 name: name, 178 dataType: dataType, 179 shape: shape, 180 data: data, 181 quantizationParameters: quantizationParameters 182 ) 183 return tensor 184 } 185 186 /// Returns the output `Tensor` at the given index. 
187 /// 188 /// - Parameters: 189 /// - index: The index for the output `Tensor`. 190 /// - Throws: An error if the index is invalid, tensors haven't been allocated, or interpreter 191 /// has not been invoked for models that dynamically compute output tensors based on the 192 /// values of its input tensors. 193 /// - Returns: The output `Tensor` at the given index. outputnull194 public func output(at index: Int) throws -> Tensor { 195 let maxIndex = outputTensorCount - 1 196 guard case 0...maxIndex = index else { 197 throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex) 198 } 199 guard let cTensor = TfLiteInterpreterGetOutputTensor(cInterpreter, Int32(index)), 200 let bytes = TfLiteTensorData(cTensor), 201 let nameCString = TfLiteTensorName(cTensor) 202 else { 203 throw InterpreterError.invokeInterpreterRequired 204 } 205 guard let dataType = Tensor.DataType(type: TfLiteTensorType(cTensor)) else { 206 throw InterpreterError.invalidTensorDataType 207 } 208 209 let name = String(cString: nameCString) 210 let rank = TfLiteTensorNumDims(cTensor) 211 let dimensions = (0..<rank).map { Int(TfLiteTensorDim(cTensor, $0)) } 212 let shape = Tensor.Shape(dimensions) 213 let byteCount = TfLiteTensorByteSize(cTensor) 214 let data = Data(bytes: bytes, count: byteCount) 215 let cQuantizationParams = TfLiteTensorQuantizationParams(cTensor) 216 let scale = cQuantizationParams.scale 217 let zeroPoint = Int(cQuantizationParams.zero_point) 218 var quantizationParameters: QuantizationParameters? = nil 219 if scale != 0.0 { 220 quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint) 221 } 222 let tensor = Tensor( 223 name: name, 224 dataType: dataType, 225 shape: shape, 226 data: data, 227 quantizationParameters: quantizationParameters 228 ) 229 return tensor 230 } 231 232 /// Resizes the input `Tensor` at the given index to the specified `Tensor.Shape`. 
233 /// 234 /// - Note: After resizing an input tensor, the client **must** explicitly call 235 /// `allocateTensors()` before attempting to access the resized tensor data or invoking the 236 /// interpreter to perform inference. 237 /// - Parameters: 238 /// - index: The index for the input `Tensor`. 239 /// - shape: The shape to resize the input `Tensor` to. 240 /// - Throws: An error if the input tensor at the given index could not be resized. resizeInputnull241 public func resizeInput(at index: Int, to shape: Tensor.Shape) throws { 242 let maxIndex = inputTensorCount - 1 243 guard case 0...maxIndex = index else { 244 throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex) 245 } 246 guard 247 TfLiteInterpreterResizeInputTensor( 248 cInterpreter, 249 Int32(index), 250 shape.int32Dimensions, 251 Int32(shape.rank) 252 ) == kTfLiteOk 253 else { 254 throw InterpreterError.failedToResizeInputTensor(index: index) 255 } 256 } 257 258 /// Copies the given data to the input `Tensor` at the given index. 259 /// 260 /// - Parameters: 261 /// - data: The data to be copied to the input `Tensor`'s data buffer. 262 /// - index: The index for the input `Tensor`. 263 /// - Throws: An error if the `data.count` does not match the input tensor's `data.count` or if 264 /// the given index is invalid. 265 /// - Returns: The input `Tensor` with the copied data. 
266 @discardableResult copynull267 public func copy(_ data: Data, toInputAt index: Int) throws -> Tensor { 268 let maxIndex = inputTensorCount - 1 269 guard case 0...maxIndex = index else { 270 throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex) 271 } 272 guard let cTensor = TfLiteInterpreterGetInputTensor(cInterpreter, Int32(index)) else { 273 throw InterpreterError.allocateTensorsRequired 274 } 275 276 let byteCount = TfLiteTensorByteSize(cTensor) 277 guard data.count == byteCount else { 278 throw InterpreterError.invalidTensorDataCount(provided: data.count, required: byteCount) 279 } 280 281 #if swift(>=5.0) 282 let status = data.withUnsafeBytes { 283 TfLiteTensorCopyFromBuffer(cTensor, $0.baseAddress, data.count) 284 } 285 #else 286 let status = data.withUnsafeBytes { TfLiteTensorCopyFromBuffer(cTensor, $0, data.count) } 287 #endif // swift(>=5.0) 288 guard status == kTfLiteOk else { throw InterpreterError.failedToCopyDataToInputTensor } 289 return try input(at: index) 290 } 291 292 /// Allocates memory for all input `Tensor`s based on their `Tensor.Shape`s. 293 /// 294 /// - Note: This is a relatively expensive operation and should only be called after creating the 295 /// interpreter and resizing any input tensors. 296 /// - Throws: An error if memory could not be allocated for the input tensors. allocateTensorsnull297 public func allocateTensors() throws { 298 guard TfLiteInterpreterAllocateTensors(cInterpreter) == kTfLiteOk else { 299 throw InterpreterError.failedToAllocateTensors 300 } 301 } 302 303 /// Returns a new signature runner instance for the signature with the given key in the model. 304 /// 305 /// - Parameters: 306 /// - key: The signature key. 307 /// - Throws: `SignatureRunnerError` if signature runner creation fails. 308 /// - Returns: A new signature runner instance for the signature with the given key. 
signatureRunnernull309 public func signatureRunner(with key: String) throws -> SignatureRunner { 310 guard signatureKeys.contains(key) else { 311 throw SignatureRunnerError.failedToCreateSignatureRunner(signatureKey: key) 312 } 313 return try SignatureRunner.init(interpreter: self, signatureKey: key) 314 } 315 316 // MARK: - Private 317 configureXNNPacknull318 private func configureXNNPack(options: Options, cInterpreterOptions: OpaquePointer) { 319 var cXNNPackOptions = TfLiteXNNPackDelegateOptionsDefault() 320 if let threadCount = options.threadCount, threadCount > 0 { 321 cXNNPackOptions.num_threads = Int32(threadCount) 322 } 323 324 cXNNPackDelegate = TfLiteXNNPackDelegateCreate(&cXNNPackOptions) 325 TfLiteInterpreterOptionsAddDelegate(cInterpreterOptions, cXNNPackDelegate) 326 } 327 } 328 329 extension Interpreter { 330 /// Options for configuring the `Interpreter`. 331 public struct Options: Equatable, Hashable { 332 /// The maximum number of CPU threads that the interpreter should run on. The default is `nil` 333 /// indicating that the `Interpreter` will decide the number of threads to use. 334 public var threadCount: Int? = nil 335 336 /// Indicates whether an optimized set of floating point CPU kernels, provided by XNNPACK, is 337 /// enabled. 338 /// 339 /// - Experiment: 340 /// Enabling this flag will enable use of a new, highly optimized set of CPU kernels provided 341 /// via the XNNPACK delegate. Currently, this is restricted to a subset of floating point 342 /// operations. Eventually, we plan to enable this by default, as it can provide significant 343 /// performance benefits for many classes of floating point models. See 344 /// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md 345 /// for more details. 346 /// 347 /// - Important: 348 /// Things to keep in mind when enabling this flag: 349 /// 350 /// * Startup time and resize time may increase. 351 /// * Baseline memory consumption may increase. 
352 /// * Compatibility with other delegates (e.g., GPU) has not been fully validated. 353 /// * Quantized models will not see any benefit. 354 /// 355 /// - Warning: This is an experimental interface that is subject to change. 356 public var isXNNPackEnabled: Bool = false 357 358 /// Creates a new instance with the default values. 359 public init() {} 360 } 361 } 362 363 /// A type alias for `Interpreter.Options` to support backwards compatibility with the deprecated 364 /// `InterpreterOptions` struct. 365 @available(*, deprecated, renamed: "Interpreter.Options") 366 public typealias InterpreterOptions = Interpreter.Options 367 368 extension String { 369 /// Returns a new `String` initialized by using the given format C array as a template into which 370 /// the remaining argument values are substituted according to the user’s default locale. 371 /// 372 /// - Note: Returns `nil` if a new `String` could not be constructed from the given values. 373 /// - Parameters: 374 /// - cFormat: The format C array as a template for substituting values. 375 /// - arguments: A C pointer to a `va_list` of arguments to substitute into `cFormat`. 376 init?(cFormat: UnsafePointer<CChar>, arguments: CVaListPointer) { 377 #if os(Linux) 378 let length = Int(vsnprintf(nil, 0, cFormat, arguments) + 1) // null terminator 379 guard length > 0 else { return nil } 380 let buffer = UnsafeMutablePointer<CChar>.allocate(capacity: length) 381 defer { 382 buffer.deallocate() 383 } 384 guard vsnprintf(buffer, length, cFormat, arguments) == length - 1 else { return nil } 385 self.init(validatingUTF8: buffer) 386 #else 387 var buffer: UnsafeMutablePointer<CChar>? 388 guard vasprintf(&buffer, cFormat, arguments) != 0, let cString = buffer else { return nil } 389 self.init(validatingUTF8: cString) 390 #endif 391 } 392 } 393