xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/swift/Sources/Interpreter.swift (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 // Copyright 2018 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 import Foundation
16 import TensorFlowLiteC
17 
18 #if os(Linux)
19   import SwiftGlibc
20 #else
21   import Darwin
22 #endif
23 
/// A TensorFlow Lite interpreter that performs inference from a given model.
///
/// - Note: Interpreter instances are *not* thread-safe.
public final class Interpreter {
  /// The configuration options for the `Interpreter`.
  public let options: Options?

  /// An `Array` of `Delegate`s for the `Interpreter` to use to perform graph operations.
  public let delegates: [Delegate]?

  /// The total number of input `Tensor`s associated with the model.
  public var inputTensorCount: Int {
    return Int(TfLiteInterpreterGetInputTensorCount(cInterpreter))
  }

  /// The total number of output `Tensor`s associated with the model.
  public var outputTensorCount: Int {
    return Int(TfLiteInterpreterGetOutputTensorCount(cInterpreter))
  }

  /// An ordered list of SignatureDef exported method names available in the model.
  ///
  /// The keys are queried from the C API on first access and cached in `_signatureKeys`.
  /// A signature whose key cannot be retrieved is represented by an empty string so the list's
  /// count and ordering still match the model's signature count.
  public var signatureKeys: [String] {
    guard let signatureKeys = _signatureKeys else {
      let signatureCount = Int(TfLiteInterpreterGetSignatureCount(self.cInterpreter))
      let keys: [String] = (0..<signatureCount).map {
        guard
          let signatureNameCString = TfLiteInterpreterGetSignatureKey(
            self.cInterpreter, Int32($0))
        else {
          return ""
        }
        return String(cString: signatureNameCString)
      }
      _signatureKeys = keys
      return keys
    }
    return signatureKeys
  }

  /// The `TfLiteInterpreter` C pointer type represented as an `UnsafePointer<TfLiteInterpreter>`.
  internal typealias CInterpreter = OpaquePointer

  /// The underlying `TfLiteInterpreter` C pointer.
  internal var cInterpreter: CInterpreter?

  /// The underlying `TfLiteDelegate` C pointer for XNNPACK delegate.
  private var cXNNPackDelegate: Delegate.CDelegate?

  /// Backing cache for `signatureKeys`; `nil` until the keys are first queried.
  private var _signatureKeys: [String]? = nil

  /// Creates a new instance with the given values.
  ///
  /// - Parameters:
  ///   - modelPath: The local file path to a TensorFlow Lite model.
  ///   - options: Configurations for the `Interpreter`. The default is `nil` indicating that the
  ///       `Interpreter` will determine the configuration options.
  ///   - delegates: `Array` of `Delegate`s for the `Interpreter` to use to perform graph
  ///       operations. The default is `nil`.
  /// - Throws: An error if the model could not be loaded or the interpreter could not be created.
  public init(modelPath: String, options: Options? = nil, delegates: [Delegate]? = nil) throws {
    guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel }
    guard let cInterpreterOptions = TfLiteInterpreterOptionsCreate() else {
      throw InterpreterError.failedToCreateInterpreter
    }
    // The options are only needed while creating the interpreter; delete them on every exit path.
    defer { TfLiteInterpreterOptionsDelete(cInterpreterOptions) }

    self.options = options
    self.delegates = delegates
    options.map {
      // A non-positive thread count is ignored, leaving the runtime's default in place.
      if let threadCount = $0.threadCount, threadCount > 0 {
        TfLiteInterpreterOptionsSetNumThreads(cInterpreterOptions, Int32(threadCount))
      }
      // Route C-level error messages through Swift so they surface as `InterpreterError` text.
      TfLiteInterpreterOptionsSetErrorReporter(
        cInterpreterOptions,
        { (_, format, args) -> Void in
          // Workaround for optionality differences for x86_64 (non-optional) and arm64 (optional).
          let optionalArgs: CVaListPointer? = args
          guard let cFormat = format,
            let arguments = optionalArgs,
            let message = String(cFormat: cFormat, arguments: arguments)
          else {
            return
          }
          print(String(describing: InterpreterError.tensorFlowLiteError(message)))
        },
        nil
      )
    }
    delegates?.forEach { TfLiteInterpreterOptionsAddDelegate(cInterpreterOptions, $0.cDelegate) }

    // Configure the XNNPack delegate after the other delegates explicitly added by the user.
    options.map {
      if $0.isXNNPackEnabled {
        configureXNNPack(options: $0, cInterpreterOptions: cInterpreterOptions)
      }
    }

    guard let cInterpreter = TfLiteInterpreterCreate(model.cModel, cInterpreterOptions) else {
      throw InterpreterError.failedToCreateInterpreter
    }
    self.cInterpreter = cInterpreter
  }

  deinit {
    // Delete the interpreter first; the XNNPACK delegate it was configured with is deleted after.
    TfLiteInterpreterDelete(cInterpreter)
    TfLiteXNNPackDelegateDelete(cXNNPackDelegate)
  }

  /// Invokes the interpreter to perform inference from the loaded graph.
  ///
  /// - Throws: An error if the model was not ready because the tensors were not allocated.
  public func invoke() throws {
    guard TfLiteInterpreterInvoke(cInterpreter) == kTfLiteOk else {
      // NOTE(review): every non-OK status from `TfLiteInterpreterInvoke` is surfaced as
      // `allocateTensorsRequired`, even though invocation can fail for other reasons — confirm
      // whether a more specific error would be appropriate.
      throw InterpreterError.allocateTensorsRequired
    }
  }

  /// Returns the input `Tensor` at the given index.
  ///
  /// - Parameters:
  ///   - index: The index for the input `Tensor`.
  /// - Throws: An error if the index is invalid or the tensors have not been allocated.
  /// - Returns: The input `Tensor` at the given index.
  public func input(at index: Int) throws -> Tensor {
    let maxIndex = inputTensorCount - 1
    guard case 0...maxIndex = index else {
      throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
    }
    // A missing tensor, data buffer, or name indicates the tensors have not been allocated yet.
    guard let cTensor = TfLiteInterpreterGetInputTensor(cInterpreter, Int32(index)),
      let bytes = TfLiteTensorData(cTensor),
      let nameCString = TfLiteTensorName(cTensor)
    else {
      throw InterpreterError.allocateTensorsRequired
    }
    guard let dataType = Tensor.DataType(type: TfLiteTensorType(cTensor)) else {
      throw InterpreterError.invalidTensorDataType
    }

    let name = String(cString: nameCString)
    let rank = TfLiteTensorNumDims(cTensor)
    let dimensions = (0..<rank).map { Int(TfLiteTensorDim(cTensor, $0)) }
    let shape = Tensor.Shape(dimensions)
    let byteCount = TfLiteTensorByteSize(cTensor)
    // Copy the tensor's bytes so the returned `Tensor` does not alias the C-owned buffer.
    let data = Data(bytes: bytes, count: byteCount)
    let cQuantizationParams = TfLiteTensorQuantizationParams(cTensor)
    let scale = cQuantizationParams.scale
    let zeroPoint = Int(cQuantizationParams.zero_point)
    // A scale of 0.0 means no quantization parameters are attached to the returned tensor.
    var quantizationParameters: QuantizationParameters? = nil
    if scale != 0.0 {
      quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
    }
    let tensor = Tensor(
      name: name,
      dataType: dataType,
      shape: shape,
      data: data,
      quantizationParameters: quantizationParameters
    )
    return tensor
  }

  /// Returns the output `Tensor` at the given index.
  ///
  /// - Parameters:
  ///   - index: The index for the output `Tensor`.
  /// - Throws: An error if the index is invalid, tensors haven't been allocated, or interpreter
  ///     has not been invoked for models that dynamically compute output tensors based on the
  ///     values of its input tensors.
  /// - Returns: The output `Tensor` at the given index.
  public func output(at index: Int) throws -> Tensor {
    let maxIndex = outputTensorCount - 1
    guard case 0...maxIndex = index else {
      throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
    }
    // Unlike `input(at:)`, a missing tensor/data/name here maps to `invokeInterpreterRequired`.
    guard let cTensor = TfLiteInterpreterGetOutputTensor(cInterpreter, Int32(index)),
      let bytes = TfLiteTensorData(cTensor),
      let nameCString = TfLiteTensorName(cTensor)
    else {
      throw InterpreterError.invokeInterpreterRequired
    }
    guard let dataType = Tensor.DataType(type: TfLiteTensorType(cTensor)) else {
      throw InterpreterError.invalidTensorDataType
    }

    let name = String(cString: nameCString)
    let rank = TfLiteTensorNumDims(cTensor)
    let dimensions = (0..<rank).map { Int(TfLiteTensorDim(cTensor, $0)) }
    let shape = Tensor.Shape(dimensions)
    let byteCount = TfLiteTensorByteSize(cTensor)
    // Copy the tensor's bytes so the returned `Tensor` does not alias the C-owned buffer.
    let data = Data(bytes: bytes, count: byteCount)
    let cQuantizationParams = TfLiteTensorQuantizationParams(cTensor)
    let scale = cQuantizationParams.scale
    let zeroPoint = Int(cQuantizationParams.zero_point)
    // A scale of 0.0 means no quantization parameters are attached to the returned tensor.
    var quantizationParameters: QuantizationParameters? = nil
    if scale != 0.0 {
      quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
    }
    let tensor = Tensor(
      name: name,
      dataType: dataType,
      shape: shape,
      data: data,
      quantizationParameters: quantizationParameters
    )
    return tensor
  }

  /// Resizes the input `Tensor` at the given index to the specified `Tensor.Shape`.
  ///
  /// - Note: After resizing an input tensor, the client **must** explicitly call
  ///     `allocateTensors()` before attempting to access the resized tensor data or invoking the
  ///     interpreter to perform inference.
  /// - Parameters:
  ///   - index: The index for the input `Tensor`.
  ///   - shape: The shape to resize the input `Tensor` to.
  /// - Throws: An error if the input tensor at the given index could not be resized.
  public func resizeInput(at index: Int, to shape: Tensor.Shape) throws {
    let maxIndex = inputTensorCount - 1
    guard case 0...maxIndex = index else {
      throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
    }
    guard
      TfLiteInterpreterResizeInputTensor(
        cInterpreter,
        Int32(index),
        shape.int32Dimensions,
        Int32(shape.rank)
      ) == kTfLiteOk
    else {
      throw InterpreterError.failedToResizeInputTensor(index: index)
    }
  }

  /// Copies the given data to the input `Tensor` at the given index.
  ///
  /// - Parameters:
  ///   - data: The data to be copied to the input `Tensor`'s data buffer.
  ///   - index: The index for the input `Tensor`.
  /// - Throws: An error if the `data.count` does not match the input tensor's `data.count` or if
  ///     the given index is invalid.
  /// - Returns: The input `Tensor` with the copied data.
  @discardableResult
  public func copy(_ data: Data, toInputAt index: Int) throws -> Tensor {
    let maxIndex = inputTensorCount - 1
    guard case 0...maxIndex = index else {
      throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
    }
    guard let cTensor = TfLiteInterpreterGetInputTensor(cInterpreter, Int32(index)) else {
      throw InterpreterError.allocateTensorsRequired
    }

    // The incoming byte count must exactly match the tensor's allocated buffer size.
    let byteCount = TfLiteTensorByteSize(cTensor)
    guard data.count == byteCount else {
      throw InterpreterError.invalidTensorDataCount(provided: data.count, required: byteCount)
    }

    // Swift 5's `withUnsafeBytes` passes an `UnsafeRawBufferPointer`; earlier versions pass a
    // typed pointer directly, hence the two code paths.
    #if swift(>=5.0)
      let status = data.withUnsafeBytes {
        TfLiteTensorCopyFromBuffer(cTensor, $0.baseAddress, data.count)
      }
    #else
      let status = data.withUnsafeBytes { TfLiteTensorCopyFromBuffer(cTensor, $0, data.count) }
    #endif  // swift(>=5.0)
    guard status == kTfLiteOk else { throw InterpreterError.failedToCopyDataToInputTensor }
    return try input(at: index)
  }

  /// Allocates memory for all input `Tensor`s based on their `Tensor.Shape`s.
  ///
  /// - Note: This is a relatively expensive operation and should only be called after creating the
  ///     interpreter and resizing any input tensors.
  /// - Throws: An error if memory could not be allocated for the input tensors.
  public func allocateTensors() throws {
    guard TfLiteInterpreterAllocateTensors(cInterpreter) == kTfLiteOk else {
      throw InterpreterError.failedToAllocateTensors
    }
  }

  /// Returns a new signature runner instance for the signature with the given key in the model.
  ///
  /// - Parameters:
  ///   - key: The signature key.
  /// - Throws: `SignatureRunnerError` if signature runner creation fails.
  /// - Returns: A new signature runner instance for the signature with the given key.
  public func signatureRunner(with key: String) throws -> SignatureRunner {
    // Validate the key against the model's signatures before constructing the runner.
    guard signatureKeys.contains(key) else {
      throw SignatureRunnerError.failedToCreateSignatureRunner(signatureKey: key)
    }
    return try SignatureRunner.init(interpreter: self, signatureKey: key)
  }

  // MARK: - Private

  /// Creates the XNNPACK delegate from `options` and registers it with the interpreter options.
  /// The delegate pointer is retained in `cXNNPackDelegate` so `deinit` can delete it.
  private func configureXNNPack(options: Options, cInterpreterOptions: OpaquePointer) {
    var cXNNPackOptions = TfLiteXNNPackDelegateOptionsDefault()
    // Mirror the interpreter's thread-count handling: non-positive values keep the default.
    if let threadCount = options.threadCount, threadCount > 0 {
      cXNNPackOptions.num_threads = Int32(threadCount)
    }

    cXNNPackDelegate = TfLiteXNNPackDelegateCreate(&cXNNPackOptions)
    TfLiteInterpreterOptionsAddDelegate(cInterpreterOptions, cXNNPackDelegate)
  }
}
328 
extension Interpreter {
  /// Options for configuring the `Interpreter`.
  public struct Options: Equatable, Hashable {
    /// The maximum number of CPU threads that the interpreter should run on. The default is `nil`
    /// indicating that the `Interpreter` will decide the number of threads to use.
    ///
    /// - Note: Values less than or equal to `0` are ignored when the interpreter is created,
    ///     leaving the runtime's default thread count in place.
    public var threadCount: Int? = nil

    /// Indicates whether an optimized set of floating point CPU kernels, provided by XNNPACK, is
    /// enabled.
    ///
    /// - Experiment:
    /// Enabling this flag will enable use of a new, highly optimized set of CPU kernels provided
    /// via the XNNPACK delegate. Currently, this is restricted to a subset of floating point
    /// operations. Eventually, we plan to enable this by default, as it can provide significant
    /// performance benefits for many classes of floating point models. See
    /// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md
    /// for more details.
    ///
    /// - Important:
    /// Things to keep in mind when enabling this flag:
    ///
    ///     * Startup time and resize time may increase.
    ///     * Baseline memory consumption may increase.
    ///     * Compatibility with other delegates (e.g., GPU) has not been fully validated.
    ///     * Quantized models will not see any benefit.
    ///
    /// - Warning: This is an experimental interface that is subject to change.
    public var isXNNPackEnabled: Bool = false

    /// Creates a new instance with the default values.
    public init() {}
  }
}
362 
/// A type alias for `Interpreter.Options` to support backwards compatibility with the deprecated
/// `InterpreterOptions` struct. New code should use `Interpreter.Options` directly.
@available(*, deprecated, renamed: "Interpreter.Options")
public typealias InterpreterOptions = Interpreter.Options
367 
extension String {
  /// Returns a new `String` initialized by using the given format C array as a template into which
  /// the remaining argument values are substituted according to the user’s default locale.
  ///
  /// - Note: Returns `nil` if a new `String` could not be constructed from the given values.
  /// - Parameters:
  ///   - cFormat: The format C array as a template for substituting values.
  ///   - arguments: A C pointer to a `va_list` of arguments to substitute into `cFormat`.
  init?(cFormat: UnsafePointer<CChar>, arguments: CVaListPointer) {
    // `vasprintf` is used on all platforms (it is a GNU extension available in glibc, which the
    // Linux Swift toolchain exposes): it formats in a single pass, avoiding the undefined behavior
    // of consuming the same `va_list` twice with a `vsnprintf` size probe followed by a second
    // formatting call.
    var buffer: UnsafeMutablePointer<CChar>?
    // `vasprintf` returns the number of characters written (>= 0) on success and -1 on failure.
    // Checking `>= 0` (rather than `!= 0`) correctly rejects the -1 error return and accepts an
    // empty formatted string.
    guard vasprintf(&buffer, cFormat, arguments) >= 0, let cString = buffer else { return nil }
    // `vasprintf` heap-allocates the buffer; `String` copies the bytes, so free it on every exit
    // path to avoid leaking one allocation per formatted message.
    defer { free(cString) }
    self.init(validatingUTF8: cString)
  }
}
393