// xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/swift/Sources/Tensor.swift (revision b6fb3261f9314811a0f4371741dbb8839866f948)
// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 import Foundation
16 import TensorFlowLiteC
17 
/// An input or output tensor in a TensorFlow Lite graph.
public struct Tensor: Equatable, Hashable {
  /// The name of the `Tensor`.
  public let name: String

  /// The data type of the `Tensor`.
  public let dataType: DataType

  /// The shape of the `Tensor`.
  public let shape: Shape

  /// The data of the `Tensor`. The data is created with copied memory content. See creating data
  /// from raw memory at https://developer.apple.com/documentation/foundation/data.
  public let data: Data

  /// The quantization parameters for the `Tensor` if using a quantized model.
  public let quantizationParameters: QuantizationParameters?

  /// Creates a new input or output `Tensor` instance.
  ///
  /// - Parameters:
  ///   - name: The name of the `Tensor`.
  ///   - dataType: The data type of the `Tensor`.
  ///   - shape: The shape of the `Tensor`.
  ///   - data: The data of the `Tensor` (input or output).
  ///   - quantizationParameters: Parameters for the `Tensor` if using a quantized model. The
  ///       default is `nil`.
  init(
    name: String,
    dataType: DataType,
    shape: Shape,
    data: Data,
    quantizationParameters: QuantizationParameters? = nil
  ) {
    self.name = name
    self.dataType = dataType
    self.shape = shape
    self.data = data
    self.quantizationParameters = quantizationParameters
  }
}
59 
extension Tensor {
  /// The element types that a `Tensor` can hold.
  public enum DataType: Equatable, Hashable {
    /// Boolean values.
    case bool
    /// Unsigned 8-bit integers.
    case uInt8
    /// Signed 16-bit integers.
    case int16
    /// Signed 32-bit integers.
    case int32
    /// Signed 64-bit integers.
    case int64
    /// Half-precision (16-bit) floating-point values.
    case float16
    /// Single-precision (32-bit) floating-point values.
    case float32
    /// Double-precision (64-bit) floating-point values.
    case float64

    /// Maps a TensorFlow Lite C `TfLiteType` onto the matching `DataType`.
    ///
    /// Fails (returns `nil`) for `kTfLiteNoType` and for any C type that has no case in this enum.
    ///
    /// - Parameter type: The C-level data type of a tensor.
    init?(type: TfLiteType) {
      let mapped: DataType?
      switch type {
      case kTfLiteBool: mapped = .bool
      case kTfLiteUInt8: mapped = .uInt8
      case kTfLiteInt16: mapped = .int16
      case kTfLiteInt32: mapped = .int32
      case kTfLiteInt64: mapped = .int64
      case kTfLiteFloat16: mapped = .float16
      case kTfLiteFloat32: mapped = .float32
      case kTfLiteFloat64: mapped = .float64
      default:
        // Covers `kTfLiteNoType` as well as every type this enum does not model.
        mapped = nil
      }
      guard let dataType = mapped else { return nil }
      self = dataType
    }
  }
}
110 
extension Tensor {
  /// Describes the dimensions of a `Tensor`.
  public struct Shape: Equatable, Hashable {
    /// How many dimensions the `Tensor` has; always equal to `dimensions.count`.
    public let rank: Int

    /// The dimensions of the `Tensor`.
    public let dimensions: [Int]

    /// `dimensions` with every entry converted to `Int32`.
    ///
    /// The `Int32` conversion traps at runtime if a dimension does not fit in `Int32`.
    var int32Dimensions: [Int32] {
      return dimensions.map { Int32($0) }
    }

    /// Creates a shape from an array of dimensions.
    ///
    /// - Parameters:
    ///   - dimensions: The dimensions of the `Tensor`.
    public init(_ dimensions: [Int]) {
      self.dimensions = dimensions
      rank = dimensions.count
    }

    /// Creates a shape from a variadic list of dimensions, e.g. `Shape(1, 224, 224, 3)`.
    ///
    /// - Parameters:
    ///   - sizes: The dimensions of the `Tensor`.
    public init(_ sizes: Int...) {
      self.init(sizes)
    }
  }
}
141 
extension Tensor.Shape: ExpressibleByArrayLiteral {
  /// Creates a shape from an array literal, e.g. `let shape: Tensor.Shape = [1, 224, 224, 3]`,
  /// where each element of the literal is one dimension.
  ///
  /// - Parameters:
  ///   - elements: The dimensions of the `Tensor`.
  public init(arrayLiteral elements: Int...) {
    self.init(elements)
  }
}
151 
/// The role a tensor plays in a graph: `input` or `output`.
///
/// The raw value of each case is its name (`"input"` or `"output"`).
enum TensorType: String {
  case input, output
}
157