// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation
import TensorFlowLiteC

/// A single input or output tensor of a TensorFlow Lite graph.
public struct Tensor: Equatable, Hashable {
  /// The name that identifies this `Tensor`.
  public let name: String

  /// The element type of the values stored in this `Tensor`.
  public let dataType: DataType

  /// The dimensions of this `Tensor`.
  public let shape: Shape

  /// The bytes backing this `Tensor`.
  ///
  /// The bytes are a copy of the underlying memory rather than a view onto it. For how `Data` is
  /// created from raw memory, see https://developer.apple.com/documentation/foundation/data.
  public let data: Data

  /// The quantization parameters of this `Tensor` if the model is quantized; may be `nil`.
  public let quantizationParameters: QuantizationParameters?

  /// Creates a `Tensor` describing an input or output of a graph.
  ///
  /// - Parameters:
  ///   - name: The name of the `Tensor`.
  ///   - dataType: The element type of the `Tensor`.
  ///   - shape: The dimensions of the `Tensor`.
  ///   - data: The bytes held by the `Tensor`.
  ///   - quantizationParameters: Quantization parameters for the `Tensor` when the model is
  ///     quantized. Defaults to `nil`.
  init(
    name: String,
    dataType: DataType,
    shape: Shape,
    data: Data,
    quantizationParameters: QuantizationParameters? = nil
  ) {
    self.quantizationParameters = quantizationParameters
    self.data = data
    self.shape = shape
    self.dataType = dataType
    self.name = name
  }
}

extension Tensor {
  /// The element types a `Tensor` can hold.
  public enum DataType: Equatable, Hashable {
    /// A Boolean value.
    case bool
    /// An unsigned 8-bit integer.
    case uInt8
    /// A signed 16-bit integer.
    case int16
    /// A signed 32-bit integer.
    case int32
    /// A signed 64-bit integer.
    case int64
    /// A 16-bit (half precision) floating-point number.
    case float16
    /// A 32-bit (single precision) floating-point number.
    case float32
    /// A 64-bit (double precision) floating-point number.
    case float64

    /// Maps a TensorFlow Lite C `TfLiteType` onto the matching `DataType`.
    ///
    /// Fails (produces `nil`) for `kTfLiteNoType` and for every `TfLiteType` value that has no
    /// counterpart in this enum.
    ///
    /// - Parameter type: The C-level data type of a tensor.
    init?(type: TfLiteType) {
      let converted: DataType?
      switch type {
      case kTfLiteBool: converted = .bool
      case kTfLiteUInt8: converted = .uInt8
      case kTfLiteInt16: converted = .int16
      case kTfLiteInt32: converted = .int32
      case kTfLiteInt64: converted = .int64
      case kTfLiteFloat16: converted = .float16
      case kTfLiteFloat32: converted = .float32
      case kTfLiteFloat64: converted = .float64
      default:
        // Covers `kTfLiteNoType` as well as any C type this enum does not model.
        converted = nil
      }
      guard let dataType = converted else { return nil }
      self = dataType
    }
  }
}

extension Tensor {
  /// The dimensions of a `Tensor`.
  public struct Shape: Equatable, Hashable {
    /// The number of dimensions, i.e. the length of `dimensions`.
    public let rank: Int

    /// The size of each dimension of the `Tensor`.
    public let dimensions: [Int]

    /// `dimensions` converted element-by-element to `Int32`.
    ///
    /// NOTE(review): `Int32(_:)` traps at runtime if a dimension does not fit in `Int32`.
    var int32Dimensions: [Int32] {
      return dimensions.map { Int32($0) }
    }

    /// Creates a shape from an array of dimension sizes.
    ///
    /// - Parameter dimensions: The size of each dimension of the `Tensor`.
    public init(_ dimensions: [Int]) {
      self.dimensions = dimensions
      rank = dimensions.count
    }

    /// Creates a shape from a variadic list of dimension sizes.
    ///
    /// - Parameter sizes: The size of each dimension of the `Tensor`.
    public init(_ sizes: Int...) {
      self.init(sizes)
    }
  }
}

extension Tensor.Shape: ExpressibleByArrayLiteral {
  /// Creates a shape from an array literal of dimension sizes, e.g. `[1, 224, 224, 3]`.
  ///
  /// - Parameter elements: The size of each dimension of the `Tensor`.
  public init(arrayLiteral elements: Int...) {
    self.init(elements)
  }
}

/// The role a tensor plays in a graph: an input or an output.
internal enum TensorType: String {
  case input, output
}