/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_SPATIAL_TENSOR_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_SPATIAL_TENSOR_H_

#import <Metal/Metal.h>

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/gpu_tensor.h"
#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"
#include "tensorflow/lite/delegates/gpu/metal/gpu_object.h"

namespace tflite {
namespace gpu {
namespace metal {

// A spatial GPU tensor whose storage is a Metal buffer (`memory_`) and/or a
// Metal texture (`texture_mem_`). Shape, data type and storage type are all
// taken from the stored TensorDescriptor. The type is move-only.
class MetalSpatialTensor : public GPUObject, public GpuSpatialTensor {
 public:
  // Creates an empty tensor: no buffer, no texture, and both ownership flags
  // set to true.
  MetalSpatialTensor()
      : memory_(nullptr),
        texture_mem_(nullptr),
        memory_owner_(true),
        texture_mem_owner_(true) {}

  // Wraps an existing buffer and/or texture described by `descriptor`.
  // `memory_owner` / `texture_mem_owner` record whether this tensor owns
  // `buffer` / `texture`.
  // NOTE(review): presumably these flags decide what Release() frees —
  // confirm against the implementation file.
  MetalSpatialTensor(id<MTLBuffer> buffer, id<MTLTexture> texture,
                     bool memory_owner, bool texture_mem_owner,
                     const TensorDescriptor& descriptor);

  // Move only
  MetalSpatialTensor(MetalSpatialTensor&& tensor);
  MetalSpatialTensor& operator=(MetalSpatialTensor&& tensor);
  MetalSpatialTensor(const MetalSpatialTensor&) = delete;
  MetalSpatialTensor& operator=(const MetalSpatialTensor&) = delete;

  ~MetalSpatialTensor() override { Release(); }

  absl::Status GetGPUResources(const GPUObjectDescriptor* obj_ptr,
                               GPUResourcesWithValue* resources) const override;

  // Shape accessors; each reads the BHWDC shape from the descriptor.
  int Width() const override { return descriptor_.GetBHWDCShape().w; }
  int Height() const override { return descriptor_.GetBHWDCShape().h; }
  int Depth() const override { return descriptor_.GetBHWDCShape().d; }
  int Channels() const override { return descriptor_.GetBHWDCShape().c; }
  // Number of 4-channel slices, i.e. ceil(channels / 4).
  int Slices() const override {
    return DivideRoundUp(descriptor_.GetBHWDCShape().c, 4);
  }
  int Batch() const override { return descriptor_.GetBHWDCShape().b; }

  // Returns a copy of the tensor's descriptor.
  TensorDescriptor GetDescriptor() const override { return descriptor_; }
  DataType GetDataType() const { return descriptor_.GetDataType(); }
  TensorStorageType GetStorageType() const {
    return descriptor_.GetStorageType();
  }
  // Size as reported by the descriptor.
  uint64_t GetMemorySizeInBytes() const {
    return descriptor_.GetMemorySizeInBytes();
  }

  absl::Status CreateFromDescriptor(const TensorDescriptor& desc,
                                    id<MTLDevice> device);
  absl::Status UploadDescriptorData(const TensorDescriptor& desc,
                                    id<MTLDevice> device);
  absl::Status ToDescriptor(TensorDescriptor* desc, id<MTLDevice> device) const;

  absl::Status SetBufferHandle(id<MTLBuffer> buffer);
  id<MTLBuffer> GetBufferHandle() const;

 private:
  // The shared-memory factories below need direct access to the private
  // state (e.g. buffer_offset_).
  friend absl::Status CreateTensorSharedBuffer(
      id<MTLBuffer> buffer, const TensorDescriptor& descriptor,
      MetalSpatialTensor* result, uint64_t buffer_offset);

  friend absl::Status CreateTensorSharedImage2DBuffer(
      id<MTLBuffer> buffer, const TensorDescriptor& descriptor,
      int row_bytes_alignment, MetalSpatialTensor* result,
      uint64_t buffer_offset);

  absl::Status WriteData(id<MTLDevice> device, const void* ptr);
  absl::Status ReadData(id<MTLDevice> device, void* ptr) const;

  // Called from the destructor.
  void Release();

  // Buffer storage; nullptr for a default-constructed tensor.
  id<MTLBuffer> memory_;
  // Texture storage; nullptr for a default-constructed tensor.
  id<MTLTexture> texture_mem_;
  // Whether this tensor owns `memory_` / `texture_mem_`.
  bool memory_owner_;
  bool texture_mem_owner_;
  TensorDescriptor descriptor_;
  // for use with TEXTURE_2D and when texture created from buffer.
  // Zero-initialized so a default-constructed tensor never carries an
  // indeterminate value here (the default constructor does not set it).
  int aligned_texture_width_ = 0;
  // used when created from shared buffer
  uint64_t buffer_offset_ = 0;
};

absl::Status CreateTensor(id<MTLDevice> device,
                          const TensorDescriptor& descriptor,
                          MetalSpatialTensor* result);

// Creates a tensor that uses an existing `buffer` as its storage, starting at
// `buffer_offset`.
absl::Status CreateTensorSharedBuffer(id<MTLBuffer> buffer,
                                      const TensorDescriptor& descriptor,
                                      MetalSpatialTensor* result,
                                      uint64_t buffer_offset = 0);

// Like CreateTensorSharedBuffer, but for a 2D-image view of `buffer`;
// `row_bytes_alignment` is forwarded to the implementation.
absl::Status CreateTensorSharedImage2DBuffer(id<MTLBuffer> buffer,
                                             const TensorDescriptor& descriptor,
                                             int row_bytes_alignment,
                                             MetalSpatialTensor* result,
                                             uint64_t buffer_offset = 0);

TensorStorageType GetFastestStorageType(const GpuInfo& gpu_info);

}  // namespace metal
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_SPATIAL_TENSOR_H_