/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_

#import <Metal/Metal.h>

#include <cstdint>
#include <functional>
#include <list>
#include <map>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/lite/delegates/gpu/common/gpu_model.h"
#include "tensorflow/lite/delegates/gpu/common/gpu_model_generated.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_hints.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/profiling_info.h"
#include "tensorflow/lite/delegates/gpu/common/task/tuning_type.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
#include "tensorflow/lite/delegates/gpu/metal/inference_context_generated.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_device.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"

namespace tflite {
namespace gpu {
namespace metal {

struct MetalNode {
  ComputeTask task;
  std::vector<ValueId> inputs;
  std::vector<ValueId> outputs;

  // Mostly for debug purposes.
  std::string name;

  MetalNode() = default;

  MetalNode(MetalNode&& node) = default;
  MetalNode& operator=(MetalNode&& node) = default;
  MetalNode(const MetalNode&) = delete;
  MetalNode& operator=(const MetalNode&) = delete;
};

class InferenceContext {
 public:
  InferenceContext() = default;

  // IMPORTANT: If InitFromGraph is used, RunGraphTransforms must be applied
  // to the graph beforehand; otherwise correct behavior is not guaranteed.
  absl::Status InitFromGraph(const CreateGpuModelInfo& create_info,
                             const GraphFloat32& graph, id<MTLDevice> device_id,
                             std::vector<uint8_t>* serialized_model = nullptr);

  // Applies transformations specific to this backend to the graph before the
  // initialization. These transformations are either impossible or useless in
  // other backends.
  absl::Status InitFromGraphWithTransforms(
      const CreateGpuModelInfo& create_info, GraphFloat32* graph,
      id<MTLDevice> device_id,
      std::vector<uint8_t>* serialized_model = nullptr);

  absl::Status RestoreDeserialized(
      const absl::Span<const uint8_t> serialized_model,
      id<MTLDevice> device_id, CreateGpuModelInfo* create_info = nullptr);

  /// Inserts all GPU compute tasks into the given command encoder.
  /// @discussion No GPU synchronization functions are used inside. All GPU
  /// resources must be created with the same device that was used to
  /// initialize this InferenceContext.
  void EncodeWithEncoder(id<MTLComputeCommandEncoder> command_encoder);
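  // A minimal usage sketch for the encoder-based path (hypothetical names;
  // `context` is an initialized InferenceContext and `command_queue` is
  // assumed to come from the same device passed to InitFromGraph):
  //
  //   id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
  //   id<MTLComputeCommandEncoder> encoder =
  //       [command_buffer computeCommandEncoder];
  //   context.EncodeWithEncoder(encoder);
  //   [encoder endEncoding];
  //   [command_buffer commit];
  //   [command_buffer waitUntilCompleted];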

  /// Inserts all GPU compute tasks into the command buffer. A separate
  /// encoder is used for every task.
  /// @discussion No GPU synchronization functions are used inside. All GPU
  /// resources must be created with the same device that was used to
  /// initialize this InferenceContext.
  void EncodeWithCommandBuffer(id<MTLCommandBuffer> command_buffer);

  /// Adds all GPU compute tasks to the command queue. A separate encoder is
  /// used for every task. Batches of flush_period encoders are committed in
  /// one command buffer that is sent for execution.
  /// @discussion No GPU synchronization functions are used inside. All GPU
  /// resources must be created with the same device that was used to
  /// initialize this InferenceContext.
  void EncodeWithCommandQueue(id<MTLCommandQueue> command_queue,
                              int flush_period);
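  // A minimal usage sketch for the queue-based path (hypothetical names; a
  // flush_period of 8 is assumed to mean that a command buffer is committed
  // after every 8 encoded tasks):
  //
  //   InferenceContext context;
  //   // ... InitFromGraph(...) and SetInputTensor(...) done earlier ...
  //   context.EncodeWithCommandQueue(command_queue, /*flush_period=*/8);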

  API_AVAILABLE(ios(13.0), macos(11.00), tvos(13.0))
  void AddResources(id<MTLComputeCommandEncoder> command_encoder);
  API_AVAILABLE(ios(13.0), macos(11.00), tvos(13.0))
  void EncodeWithICB(id<MTLComputeCommandEncoder> command_encoder);

  void Profile(id<MTLDevice> device, ProfilingInfo* result);

  // Returns the size in bytes of all intermediate (runtime) tensors owned by
  // this inference context. Does not include constant tensors.
  uint64_t GetIntermediateTensorsSize() const;
  uint64_t GetConstantTensorsSize() const;

  // Can be used only with ids from external_mutable_tensors in create_info.
  // Must be called after initialization and before execution.
  absl::Status SetTensor(const ValueId& tensor_id,
                         MetalSpatialTensor* tensor_ptr);

  MetalSpatialTensor* GetTensor(ValueId tensor_id);
  absl::Status SetInputTensor(ValueId id, const TensorFloat32& tensor);
  absl::Status GetOutputTensor(ValueId id, TensorFloat32* result);

 private:
  enum class TensorMemoryType {
    kStrongShape,
    kBuffer,
    kVariable,
    kConst,
    kExternal
  };

  flatbuffers::Offset<data::InferenceContext> Encode(
      MetalDevice* device,
      flatbuffers::Offset<tflite::gpu::data::GpuModel> gpu_model_fb,
      flatbuffers::FlatBufferBuilder* builder);

  absl::Status Decode(MetalDevice* device,
                      const data::InferenceContext* fb_inference);

  void CopyFromGpuModel(GpuModel* gpu_model);
  absl::Status CompileOperations(MetalDevice* device);
  void PrepareExternal();

  absl::Status AllocateTensors(MetalDevice* device);
  absl::Status AllocateMemoryForConstTensors(MetalDevice* device);
  absl::Status AllocateMemoryForBuffers(MetalDevice* device);
  absl::Status AllocateMemoryForStrongShapes(MetalDevice* device);
  void BindTensorsToOperations();
  absl::Status UpdateParams(const GpuInfo& gpu_info);
  void GetUsages(const std::function<bool(ValueId)>& functor,
                 std::map<ValueId, int2>* usages);
  TensorMemoryType GetTensorMemoryType(ValueId id);
  absl::Status Tune(TuningType tuning_type, MetalDevice* device);

  absl::flat_hash_map<ValueId, TensorDescriptor> tensors_descs_;

  std::vector<MetalNode> nodes_;
  std::vector<ValueId> input_ids_;
  std::vector<ValueId> output_ids_;

  absl::flat_hash_map<ValueId, MetalSpatialTensor*>
      external_immutable_tensors_;
  absl::flat_hash_map<ValueId, MetalSpatialTensor*> external_mutable_tensors_;
  absl::flat_hash_map<ValueId, std::vector<int>> external_tensor_to_nodes_;
  absl::flat_hash_map<ValueId, TensorDescriptor> const_tensors_descs_;
  std::map<ValueId, MetalSpatialTensor> const_tensors_;

  std::map<ValueId, int> graph_ids_to_shared_buffer_tensors_;
  std::vector<id<MTLBuffer>> shared_buffers_;
  // These tensors reference memory owned by shared_buffers_.
  std::vector<MetalSpatialTensor> shared_buffer_tensors_;

  std::map<ValueId, MetalSpatialTensor> strong_shape_tensors_;
  std::map<ValueId, ValueId> graph_ids_to_strong_shape_tensors_;

  id<MTLIndirectCommandBuffer> icb_ = nullptr;
  id<MTLDevice> device_ = nullptr;
};

}  // namespace metal
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_