/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/api/containers/Tensor.h>

namespace vkcompute {
namespace api {

std::vector<int64_t> calculate_sizes(
    const vkapi::VulkanImage& image,
    const utils::GPUMemoryLayout memory_layout) {
  auto sizes = std::vector<int64_t>{
      image.extents().width, image.extents().height, image.extents().depth};
  const auto packed_dim = utils::to_packed_dim<int32_t>(memory_layout);
  sizes.at(packed_dim) *= 4;
  return sizes;
}
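
// Illustrative example (hypothetical values, not exercised by this file): an
// image with extents (W=8, H=6, D=4) and a width-packed layout (packed_dim
// == 0) yields sizes {32, 6, 4}, since each texel holds 4 values along the
// packed dimension.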

std::vector<int64_t> calculate_dim_order(
    const size_t ndim,
    const int32_t packed_dim) {
  // Special case for zero dim tensors
  if (ndim == 0) {
    return {0};
  }
  std::vector<int64_t> dim_order(ndim);
  // Explicitly convert ndim to signed to prevent underflow
  int64_t last_dim = int64_t(ndim) - 1 - packed_dim;

  int64_t cur_dim = 0;
  for (int d = 0; d < ndim; ++d) {
    if (d == last_dim) {
      cur_dim++;
    }
    dim_order[d] = cur_dim;
    cur_dim++;
  }
  if (last_dim >= 0) {
    dim_order[ndim - 1] = last_dim;
  }

  return dim_order;
}
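
// Worked example (hypothetical values, using the WHCN convention from this
// file): ndim = 4 with packed_dim = 2 (channels-packed) gives last_dim = 1,
// so the computed dim order is {0, 2, 3, 1}; the channels dim (NCHW index 1)
// becomes the innermost, fastest-varying dimension. With packed_dim = 0
// (width-packed) the result is the contiguous order {0, 1, 2, 3}.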

std::vector<int64_t> calculate_strides(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& dim_order) {
  // For zero dim tensors
  if (sizes.size() == 0) {
    return {1};
  }

  size_t ndim = sizes.size();
  std::vector<int64_t> strides(ndim);

  strides[dim_order[ndim - 1]] = 1;
  for (int32_t i = ndim - 2; i >= 0; --i) {
    if (sizes[dim_order[i + 1]] == 0) {
      strides[dim_order[i]] = strides[dim_order[i + 1]];
    } else {
      strides[dim_order[i]] =
          strides[dim_order[i + 1]] * sizes[dim_order[i + 1]];
    }
  }

  return strides;
}
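
// Worked example (hypothetical values): sizes {2, 3, 4} with the contiguous
// dim order {0, 1, 2} produce strides {12, 4, 1}. With the permuted dim
// order {0, 2, 1} (dim 1 innermost) the result is strides {12, 1, 3}.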

/*
 * Axis mapping is somewhat analogous to strides for texture backed tensors.
 *
 * The axis mapping is normalized to 4 dimensions, similar to the padded sizes.
 * The first 3 values of the axis mapping indicate the (X,Y,Z) image texture
 * axis that corresponds to the width, height, and channels dimension of the
 * tensor. Thus the axis mapping can be considered to be in WHCN dimension
 * order.
 *
 * The last value `axis_map.at(3)` indicates the WHCN index of the tensor
 * dimension along which batches will be concatenated. This dimension can be
 * referred to as the "inner dimension". To determine which image texture axis
 * is used for the concatenation, a double lookup will need to be performed
 * (axis_map.at(axis_map.at(3))).
 *
 * The reason for structuring the axis mapping this way is that, for the batch
 * dim, two things need to be easily derived:
 *
 * 1. The dim idx of the inner dimension, so that the size of the inner
 *    dimension can be easily determined.
 * 2. The texture axis used to concatenate batches
 *
 * By storing the dim index of the inner dimension instead of the texture axis
 * it maps to, both pieces of information are readily available.
 *
 * The axis mapping allows for permuted views of texture-backed tensors.
 */
std::vector<int64_t> default_axis_map() {
  // Currently, all compute shaders have an assumption that the channels dim is
  // used to combine with the batch dim of a tensor. However, once dim mapping
  // is integrated into the tensor indexing logic for each compute shader, we
  // can be more flexible with mapping the batch dim to different texture axes
  // in order to improve performance or memory footprint.
  return {0, 1, 2, 2};
}
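
// Illustrative double lookup with the default axis map {0, 1, 2, 2}:
// axis_map.at(3) == 2 means batches are concatenated along the channels dim
// (WHCN index 2), and axis_map.at(axis_map.at(3)) == axis_map.at(2) == 2
// means the concatenation happens along the Z axis of the image texture.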

bool dim_order_is_valid(const std::vector<int64_t>& dim_order) {
  int64_t sum = 0;
  for (size_t i = 0; i < dim_order.size(); ++i) {
    if (dim_order[i] < 0 || dim_order[i] >= dim_order.size()) {
      return false;
    }
    sum += dim_order[i];
  }
  int64_t n = static_cast<int64_t>(dim_order.size() - 1);
  // Sanity check that the sum of the indices in the vector is equal to the sum
  // of 0 + 1 + 2 + ... + (ndim - 1)
  return sum == n * (n + 1) / 2;
}
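
// Example (hypothetical inputs): {0, 2, 3, 1} is accepted (all indices in
// range, sum 6), while {0, 1, 1, 3} is rejected because its sum is 5 rather
// than 0 + 1 + 2 + 3 = 6.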

std::vector<int64_t> unsqueeze_strides(
    const std::vector<int64_t>& strides,
    const int64_t numel) {
  const size_t ndim = strides.size();
  const size_t ndim_up4 = utils::align_up_4(strides.size());
  std::vector<int64_t> unsqueezed_strides(ndim_up4);
  for (int32_t i = 1; i <= ndim; ++i) {
    int64_t dim_stride = strides.at(ndim - i);
    unsqueezed_strides.at(ndim_up4 - i) = dim_stride;
  }

  for (int32_t i = ndim + 1; i <= ndim_up4; ++i) {
    unsqueezed_strides.at(ndim_up4 - i) = numel;
  }
  return unsqueezed_strides;
}
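
// Worked example (hypothetical values): strides {12, 4, 1} with numel 24 are
// unsqueezed up to 4 dimensions as {24, 12, 4, 1}; the prepended dims use the
// total element count as their stride.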

std::vector<int64_t> calculate_padded_sizes(
    const std::vector<int64_t>& sizes,
    const int32_t packed_dim) {
  int64_t ndim = sizes.size();
  if (ndim == 0) {
    ndim = 1;
  }

  // Tensor sizes will be unsqueezed up to the next multiple of 4
  const int64_t ndim_up4 = utils::align_up_4(ndim);
  std::vector<int64_t> padded_sizes(ndim_up4);
  for (int64_t i = 0; i < ndim_up4; ++i) {
    padded_sizes.at(i) = utils::val_at(i - ndim_up4, sizes);
  }

  // Pad the packed dim to the next multiple of 4.
  const int64_t dim_offset = packed_dim + 1;
  const int64_t padded_dim_size = utils::val_at(-dim_offset, sizes);
  padded_sizes.at(ndim_up4 - dim_offset) = utils::align_up_4(padded_dim_size);

  return padded_sizes;
}
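
// Worked example (hypothetical values, assuming utils::val_at returns 1 for
// out-of-range dims): sizes {2, 3, 5} with packed_dim = 0 (width-packed) are
// first unsqueezed to {1, 2, 3, 5}, then the packed (last) dim is aligned up
// to a multiple of 4, giving padded sizes {1, 2, 3, 8}.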

utils::uvec3 calculate_image_extents(
    const std::vector<int64_t>& padded_sizes,
    const std::vector<int64_t>& axis_map,
    const int32_t packed_dim) {
  VK_CHECK_COND(padded_sizes.size() == 4);
  VK_CHECK_COND(axis_map.size() == 4);

  utils::uvec3 extents({1, 1, 1});
  // First three elements of axis_map indicate which (X,Y,Z) image axis the
  // width, height, and channels dim of the tensor maps to.
  for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
    const int64_t axis = axis_map.at(whcn_dim);
    const int64_t dim = padded_sizes.size() - 1 - whcn_dim;
    extents[axis] = utils::safe_downcast<uint32_t>(padded_sizes.at(dim));
  }

  // axis_map[3] indicates the WHCN index of the dimension used for batch
  // concatenation. Thus a double lookup is required to determine the image axis
  // used for batch concatenation.
  const int64_t concatted_whcn_dim = axis_map.at(3);
  const int64_t batch_axis = axis_map.at(concatted_whcn_dim);
  // Multiply the extents of the batch axis by the batch size.
  extents[batch_axis] *= padded_sizes.at(0);

  VK_CHECK_COND(extents[axis_map.at(packed_dim)] % 4 == 0);
  extents[axis_map.at(packed_dim)] /= 4;
  return extents;
}
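
// Worked example (hypothetical values): padded sizes {1, 2, 3, 8} with the
// default axis map {0, 1, 2, 2} and packed_dim = 0 (width-packed) place
// W=8 on X, H=3 on Y, and C=2 on Z. The batch size (1) multiplies the Z
// axis, and the packed axis is divided by 4, yielding extents {2, 3, 2}.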

//
// vTensorStorage
//

utils::StorageType storage_type(const vkapi::VulkanImage& image) {
  const auto type = image.type();
  switch (type) {
    case VK_IMAGE_TYPE_3D:
      return utils::kTexture3D;
    case VK_IMAGE_TYPE_2D:
      return utils::kTexture2D;
    default:
      VK_THROW("Unsupported image type", type);
  }
}

vkapi::VulkanImage allocate_image(
    Context* const context_ptr,
    utils::uvec3& image_extents,
    const utils::StorageType storage_type,
    const VkFormat image_format,
    const bool allocate_memory) {
  vkapi::Adapter* adapter_ptr = context_ptr->adapter_ptr();

  vkapi::ImageSampler::Properties sampler_props{
      VK_FILTER_NEAREST,
      VK_SAMPLER_MIPMAP_MODE_NEAREST,
      VK_SAMPLER_ADDRESS_MODE_REPEAT,
      VK_BORDER_COLOR_FLOAT_TRANSPARENT_BLACK,
  };

  VkImageType image_type = VK_IMAGE_TYPE_3D;
  VkImageViewType image_view_type;

  switch (storage_type) {
    case utils::kTexture3D:
      image_type = VK_IMAGE_TYPE_3D;
      image_view_type = VK_IMAGE_VIEW_TYPE_3D;
      break;
    case utils::kTexture2D:
      image_type = VK_IMAGE_TYPE_2D;
      image_view_type = VK_IMAGE_VIEW_TYPE_2D;
      break;
    default:
      // Return an empty VulkanImage by default
      return vkapi::VulkanImage();
  }

  VkSampler sampler = adapter_ptr->sampler_cache().retrieve(sampler_props);

  return adapter_ptr->vma().create_image(
      context_ptr->device(),
      vkapi::create_extent3d(image_extents),
      image_format,
      image_type,
      context_ptr->preferred_image_tiling(),
      image_view_type,
      sampler_props,
      sampler,
      /*allow_transfer = */ true,
      /*allocate_memory = */ allocate_memory);
}

vkapi::VulkanBuffer allocate_buffer(
    Context* const context_ptr,
    const int64_t numel,
    const utils::StorageType storage_type,
    const vkapi::ScalarType dtype,
    const bool allocate_memory) {
  vkapi::Adapter* adapter_ptr = context_ptr->adapter_ptr();

  switch (storage_type) {
    case utils::kBuffer:
      break;
    default:
      // Return an empty VulkanBuffer if Buffer storage is not used
      return vkapi::VulkanBuffer();
  }

  return adapter_ptr->vma().create_storage_buffer(
      element_size(dtype) * numel, allocate_memory);
}

vTensorStorage::vTensorStorage(
    Context* const context,
    const utils::StorageType storage_type,
    const std::vector<int64_t>& axis_map,
    const int32_t packed_dim,
    const std::vector<int64_t>& padded_sizes,
    const vkapi::ScalarType dtype,
    const bool allocate_memory)
    : context_(context),
      storage_type_{storage_type},
      image_extents_(
          calculate_image_extents(padded_sizes, axis_map, packed_dim)),
      buffer_length_{utils::multiply_integers(padded_sizes)},
      buffer_offset_{0},
      image_(allocate_image(
          context_,
          image_extents_,
          storage_type_,
          to_vkformat(dtype),
          allocate_memory)),
      buffer_(allocate_buffer(
          context_,
          buffer_length_,
          storage_type_,
          dtype,
          allocate_memory)),
      last_access_{},
      has_copies_{false} {}

vTensorStorage::vTensorStorage(
    Context* const context,
    const vkapi::VulkanImage& image)
    : context_(context),
      storage_type_{storage_type(image)},
      image_extents_(
          {image.extents().width,
           image.extents().height,
           image.extents().depth}),
      buffer_length_{0},
      buffer_offset_{0},
      image_(image),
      buffer_(vkapi::VulkanBuffer()),
      last_access_{} {}

vTensorStorage::vTensorStorage(
    vTensorStorage& other,
    const int64_t buffer_offset)
    : context_(other.context_),
      storage_type_{other.storage_type_},
      image_extents_(other.image_extents_),
      buffer_length_{other.buffer_length_},
      buffer_offset_{buffer_offset},
      image_(other.image_),
      buffer_(other.buffer_, buffer_offset),
      last_access_{other.last_access_},
      has_copies_{false} {
  other.has_copies_ = true;
}

vTensorStorage::~vTensorStorage() {
  flush();
}

void vTensorStorage::flush() {
  if (image_) {
    context_->register_image_cleanup(image_);
  } else if (buffer_) {
    context_->register_buffer_cleanup(buffer_);
  }
  last_access_ = {};
}

void vTensorStorage::transition(
    vkapi::PipelineBarrier& pipeline_barrier,
    const vkapi::PipelineStageFlags cur_stage,
    const vkapi::MemoryAccessFlags cur_access) {
  // Get last stage access
  vkapi::PipelineStageFlags prev_stage = last_access_.stage;
  vkapi::MemoryAccessFlags prev_access = last_access_.access;

  // If the underlying resource is a copy of another tensor's resource the
  // last_access may not be accurate, since the original storage may have been
  // written to as part of the original tensor. Likewise, if the underlying
  // resource has copies, then the resource may have been updated as part of the
  // view tensors.
  //
  // If the resource is a copy, or has copies of it, then cowardly assume that
  // it has previously been written to as part of a compute shader before the
  // current access event so that the appropriate memory barriers may be
  // inserted.
  if (is_copy() || has_copies_) {
    prev_stage = vkapi::PipelineStage::COMPUTE;
    prev_access = vkapi::kWrite;
  }

  const bool prev_written = (prev_access & vkapi::MemoryAccessType::WRITE) != 0;

  VkImageLayout cur_layout = VK_IMAGE_LAYOUT_UNDEFINED;
  VkImageLayout new_layout = VK_IMAGE_LAYOUT_UNDEFINED;
  bool layout_changed = false;
  if (image_) {
    cur_layout = image_.layout();
    new_layout = vkapi::vk_layout(cur_stage, cur_access);

    layout_changed = cur_layout != new_layout;
  }

  if (prev_written || layout_changed) {
    VkPipelineStageFlags src_stage = vkapi::vk_stage(prev_stage);
    if (0u == src_stage) {
      src_stage = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT;
    }
    VkPipelineStageFlags dst_stage = vkapi::vk_stage(cur_stage);
    if (0u == dst_stage) {
      dst_stage = VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT;
    }

    pipeline_barrier.stage.src |= src_stage;
    pipeline_barrier.stage.dst |= dst_stage;

    if (image_) {
      pipeline_barrier.images.emplace_back(
          vkapi::vk_access(prev_stage, prev_access),
          vkapi::vk_access(cur_stage, cur_access),
          cur_layout,
          new_layout,
          image_);

      image_.set_layout(new_layout);
    } else if (buffer_) {
      pipeline_barrier.buffers.emplace_back(
          vkapi::vk_access(prev_stage, prev_access),
          vkapi::vk_access(cur_stage, cur_access),
          buffer_);
    }
  }

  last_access_.stage = cur_stage;
  last_access_.access = cur_access;
}

bool vTensorStorage::is_copy() const {
  if (storage_type_ == utils::kBuffer) {
    return buffer_.is_copy();
  }
  return image_.is_copy();
}

bool vTensorStorage::is_copy_of(const vTensorStorage& other) const {
  if (storage_type_ == utils::kBuffer) {
    return buffer_.is_copy_of(other.buffer_);
  }
  return image_.is_copy_of(other.image_);
}

//
// vTensor
//

vTensor::vTensor(
    Context* const context,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::StorageType storage_type,
    const utils::GPUMemoryLayout memory_layout,
    const bool allocate_memory)
    : dtype_(dtype),
      // Calculate tensor metadata
      sizes_(sizes.begin(), sizes.end()),
      packed_dim_(utils::to_packed_dim<int32_t>(memory_layout)),
      dim_order_(calculate_dim_order(sizes_.size(), packed_dim_)),
      axis_map_(default_axis_map()),
      strides_(calculate_strides(sizes, dim_order_)),
      numel_(utils::multiply_integers(sizes_)),
      padded_sizes_{calculate_padded_sizes(sizes, packed_dim_)},
      unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
      padded_numel_(utils::multiply_integers(padded_sizes_)),
      logical_limits_{{0, 0, 0}},
      // Utility Uniform Buffers that can be passed to shaders as arguments
      sizes_uniform_(),
      strides_uniform_(),
      numel_uniform_(),
      logical_limits_uniform_(),
      // Construct Tensor storage
      storage_(
          context,
          storage_type,
          axis_map_,
          packed_dim_,
          padded_sizes_,
          dtype_,
          allocate_memory) {
  VK_CHECK_COND(
      dim_order_is_valid(dim_order_), "computed dim order is invalid");

  if (storage_type != utils::kBuffer) {
    set_logical_limits(storage_.image_extents_);
  }

  if (dtype == vkapi::kHalf) {
    VK_CHECK_COND(
        api::context()->adapter_ptr()->supports_16bit_storage_buffers(),
        "Half dtype is only available if the physical device supports float16 "
        "storage buffers!");
  }
}

// NOLINTNEXTLINE
vTensor::vTensor(
    Context* context,
    const vkapi::VulkanImage& image,
    const utils::GPUMemoryLayout memory_layout)
    : dtype_(vkapi::element_scalartype(image.format())),
      // Calculate tensor metadata
      sizes_(calculate_sizes(image, memory_layout)),
      packed_dim_(utils::to_packed_dim<int32_t>(memory_layout)),
      dim_order_(),
      axis_map_(default_axis_map()),
      strides_(),
      numel_(utils::multiply_integers(sizes_)),
      padded_sizes_(calculate_padded_sizes(sizes_, packed_dim_)),
      unsqueezed_strides_(),
      padded_numel_(utils::multiply_integers(padded_sizes_)),
      logical_limits_(),
      // Utility Uniform Buffers that can be passed to shaders as arguments
      sizes_uniform_(),
      strides_uniform_(),
      numel_uniform_(),
      logical_limits_uniform_(),
      // Construct Tensor storage
      storage_(context, image) {
  set_logical_limits(storage_.image_extents_);
}

vTensor::vTensor(vTensor& other)
    : dtype_(other.dtype_),
      // Copy tensor size metadata
      sizes_(other.sizes_.begin(), other.sizes_.end()),
      packed_dim_{other.packed_dim_},
      dim_order_(other.dim_order_.begin(), other.dim_order_.end()),
      axis_map_(other.axis_map_.begin(), other.axis_map_.end()),
      strides_(other.strides_.begin(), other.strides_.end()),
      numel_(other.numel_),
      padded_sizes_{other.padded_sizes_.begin(), other.padded_sizes_.end()},
      unsqueezed_strides_{
          other.unsqueezed_strides_.begin(),
          other.unsqueezed_strides_.end()},
      padded_numel_(other.padded_numel_),
      logical_limits_{other.logical_limits_},
      // Empty initialize Utility Uniform Buffers
      sizes_uniform_(),
      strides_uniform_(),
      numel_uniform_(),
      logical_limits_uniform_(),
      // Copy Tensor storage
      storage_(other.storage_) {}

vTensor::vTensor(
    vTensor& other,
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& dim_order,
    const int64_t offset_numel)
    : dtype_(other.dtype_),
      // Copy tensor size metadata
      sizes_(sizes.begin(), sizes.end()),
      packed_dim_(other.packed_dim_),
      dim_order_(dim_order.begin(), dim_order.end()),
      axis_map_(default_axis_map()),
      strides_(calculate_strides(sizes_, dim_order_)),
      numel_(utils::multiply_integers(sizes_)),
      padded_sizes_{calculate_padded_sizes(sizes, packed_dim_)},
      unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
      padded_numel_(utils::multiply_integers(padded_sizes_)),
      logical_limits_(other.logical_limits_),
      // Empty initialize Utility Uniform Buffers
      sizes_uniform_(),
      strides_uniform_(),
      numel_uniform_(),
      logical_limits_uniform_(),
      // Copy Tensor storage
      storage_(other.storage_, vkapi::element_size(dtype_) * offset_numel) {
  VK_CHECK_COND(
      dim_order_is_valid(dim_order_), "new dim order provided is invalid");
  VK_CHECK_COND(
      offset_numel + numel_ <= other.numel(),
561       "Tensor alias cannot access more elements than available in the original"
562       "tensor");
}

vkapi::VulkanImage& vTensor::image(
    vkapi::PipelineBarrier& pipeline_barrier,
    const vkapi::PipelineStageFlags stage) & {
  storage_.transition(pipeline_barrier, stage, vkapi::MemoryAccessType::READ);
  return storage_.image_;
}

vkapi::VulkanImage& vTensor::image(
    vkapi::PipelineBarrier& pipeline_barrier,
    const vkapi::PipelineStageFlags stage,
    const vkapi::MemoryAccessFlags access) & {
  storage_.transition(pipeline_barrier, stage, access);
  return storage_.image_;
}

vkapi::VulkanBuffer& vTensor::buffer(
    vkapi::PipelineBarrier& pipeline_barrier,
    const vkapi::PipelineStageFlags stage) & {
  storage_.transition(pipeline_barrier, stage, vkapi::MemoryAccessType::READ);
  return storage_.buffer_;
}

vkapi::VulkanBuffer& vTensor::buffer(
    vkapi::PipelineBarrier& pipeline_barrier,
    const vkapi::PipelineStageFlags stage,
    const vkapi::MemoryAccessFlags access) & {
  storage_.transition(pipeline_barrier, stage, access);
  return storage_.buffer_;
}

void vTensor::set_logical_limits(const utils::uvec3& image_extents) {
  logical_limits_.limits[0] = image_extents[axis_map_.at(0)];
  logical_limits_.limits[1] = image_extents[axis_map_.at(1)];
  logical_limits_.limits[2] = image_extents[axis_map_.at(2)];
}

utils::GPUMemoryLayout vTensor::estimate_memory_layout() const {
  switch (packed_dim_) {
    case WHCN::kWidthDim:
      return utils::kWidthPacked;
    case WHCN::kHeightDim:
      return utils::kHeightPacked;
    case WHCN::kChannelsDim:
      return utils::kChannelsPacked;
    default:
      VK_THROW("Invalid packed dim");
  }
}

const vkapi::BufferBindInfo vTensor::sizes_ubo() {
  if (!sizes_uniform_.buffer()) {
    sizes_uniform_ =
        ParamsBuffer(storage_.context_, utils::make_whcn_ivec4(sizes_));
  }
  return vkapi::BufferBindInfo(sizes_uniform_.buffer());
}

const vkapi::BufferBindInfo vTensor::strides_ubo() {
  if (!strides_uniform_.buffer()) {
    strides_uniform_ = ParamsBuffer(
        storage_.context_, utils::make_whcn_ivec4(unsqueezed_strides_));
  }
  return vkapi::BufferBindInfo(strides_uniform_.buffer());
}

const vkapi::BufferBindInfo vTensor::logical_limits_ubo() {
  if (!logical_limits_uniform_.buffer()) {
    logical_limits_uniform_ = ParamsBuffer(storage_.context_, logical_limits_);
  }
  return vkapi::BufferBindInfo(logical_limits_uniform_.buffer());
}

const vkapi::BufferBindInfo vTensor::numel_ubo() {
  if (!numel_uniform_.buffer()) {
    numel_uniform_ = ParamsBuffer(storage_.context_, numel_);
  }
  return vkapi::BufferBindInfo(numel_uniform_.buffer());
}

size_t vTensor::staging_buffer_numel() const {
  const bool is_int8 = dtype_ == vkapi::kChar;
  const bool int8_supported =
      storage_.context_->adapter_ptr()->has_full_int8_buffers_support();
  if (is_int8 && !int8_supported) {
    return utils::align_up_4(numel_);
  }
  if (storage_type() == utils::kBuffer) {
    return numel_;
  }
  return padded_numel_;
}
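
// Example (hypothetical values): an int8 (kChar) tensor with numel 10 on a
// device without full int8 buffer support stages align_up_4(10) = 12
// elements; otherwise a buffer-backed tensor stages numel_ elements and a
// texture-backed tensor stages padded_numel_ elements.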

VkMemoryRequirements vTensor::get_memory_requirements() const {
  switch (storage_type()) {
    case utils::kBuffer:
      return storage_.buffer_.get_memory_requirements();
    case utils::kTexture2D:
    case utils::kTexture3D:
      return storage_.image_.get_memory_requirements();
  }
  return {};
}

void vTensor::bind_allocation(const vkapi::Allocation& allocation) {
  switch (storage_type()) {
    case utils::kBuffer:
      storage_.buffer_.bind_allocation(allocation);
      break;
    case utils::kTexture2D:
    case utils::kTexture3D:
      storage_.image_.bind_allocation(allocation);
      break;
  }
}

void vTensor::update_metadata() {
  strides_ = calculate_strides(sizes_, dim_order_);
  numel_ = utils::multiply_integers(sizes_);

  padded_sizes_ = calculate_padded_sizes(sizes_, packed_dim_);
  unsqueezed_strides_ = unsqueeze_strides(strides_, numel_);
  padded_numel_ = utils::multiply_integers(padded_sizes_);

  // Calculate the image extents that would have been used to allocate a
  // texture with the current sizes, and use that to set the logical limits.
  set_logical_limits(
      calculate_image_extents(padded_sizes_, axis_map_, packed_dim_));

  if (sizes_uniform_.buffer()) {
    sizes_uniform_.update(utils::make_whcn_ivec4(sizes_));
  }
  if (strides_uniform_.buffer()) {
    strides_uniform_.update(utils::make_whcn_ivec4(unsqueezed_strides_));
  }
  if (numel_uniform_.buffer()) {
    numel_uniform_.update(numel_);
  }
  if (logical_limits_uniform_.buffer()) {
    logical_limits_uniform_.update(logical_limits_);
  }
}

void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
  if (storage_type() != utils::kBuffer) {
    // For texture storage check that the current texture is large enough for
    // the new sizes of the tensor.
    utils::uvec3 virtual_extents =
        calculate_image_extents(padded_sizes_, axis_map_, packed_dim_);

    bool valid_resize = virtual_extents[0] <= storage_.image_extents_[0];
    valid_resize =
        valid_resize && virtual_extents[1] <= storage_.image_extents_[1];
    valid_resize =
        valid_resize && virtual_extents[2] <= storage_.image_extents_[2];

    VK_CHECK_COND(
        valid_resize,
722         "tensor sizes requires a larger texture than the current one.");
  } else {
    // For buffer storage check that the current buffer is large enough for the
    // new sizes of the tensor.
    int64_t numel = utils::multiply_integers(sizes);
    bool valid_resize =
        numel + storage_.buffer_offset_ <= storage_.buffer_length_;
    VK_CHECK_COND(
        valid_resize,
731         "tensor sizes requires a larger buffer than the current one.");
  }
}

void vTensor::virtual_reconfigure(
    const std::vector<int64_t>& new_sizes,
    const std::vector<int64_t>& new_dim_order) {
  VK_CHECK_COND(
      storage_type() == utils::kBuffer,
      "virtual_reconfigure is only applicable for buffer backed tensors");
  VK_CHECK_COND(new_sizes.size() == new_dim_order.size());
  VK_CHECK_COND(dim_order_is_valid(new_dim_order));

  check_sizes(new_sizes);
  sizes_ = new_sizes;
  dim_order_ = new_dim_order;
  update_metadata();
}

void vTensor::virtual_clone(const vTensor& other) {
  VK_CHECK_COND(is_view_of(other));
  sizes_ = other.sizes_;
  dim_order_ = other.dim_order_;
  axis_map_ = other.axis_map_;
  packed_dim_ = other.packed_dim_;
}

void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
  VK_CHECK_COND(
      new_sizes.size() == dim_order_.size(),
      "new sizes cannot modify the dimensionality of the tensor ");

  check_sizes(new_sizes);
  sizes_ = new_sizes;
  update_metadata();
}

/*
 * Transposing the dim order is a bit unintuitive. dim0 and dim1 have swapped
 * their "identities", so we need to swap the values of dim0 and dim1 wherever
 * they appear in the dim order vector. Compare this to just swapping the
 * elements at dim0 and dim1 in the `sizes` vector.
 */
void transpose_dim_order_inplace(
    std::vector<int64_t>& dim_order,
    const int64_t dim0,
    const int64_t dim1) {
  for (int i = 0; i < dim_order.size(); ++i) {
    if (dim_order[i] == dim0) {
      dim_order[i] = dim1;
    } else if (dim_order[i] == dim1) {
      dim_order[i] = dim0;
    }
  }
}
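
// Worked example (hypothetical values): with dim_order {0, 2, 3, 1},
// transposing dims 1 and 2 swaps the *values* 1 and 2 wherever they appear,
// giving {0, 1, 3, 2}. Swapping the *elements* at positions 1 and 2 would
// instead give {0, 3, 2, 1}, which is not the intended result.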

void vTensor::virtual_transpose(const int64_t dim0, const int64_t dim1) {
  std::iter_swap(sizes_.begin() + dim0, sizes_.begin() + dim1);

  const int dim0_whcn = sizes_.size() - 1 - dim0;
  const int dim1_whcn = sizes_.size() - 1 - dim1;
  if (packed_dim_ == dim0_whcn) {
    packed_dim_ = dim1_whcn;
  } else if (packed_dim_ == dim1_whcn) {
    packed_dim_ = dim0_whcn;
  }

  if (storage_type() == utils::kBuffer) {
    transpose_dim_order_inplace(dim_order_, dim0, dim1);
  } else {
    // Cannot transpose batch dimension for texture storage
    VK_CHECK_COND(dim0_whcn < 3 && dim1_whcn < 3);
    std::iter_swap(
        axis_map_.begin() + dim0_whcn, axis_map_.begin() + dim1_whcn);
    // Update the "identity" of the concatted dimension
    if (axis_map_.at(3) == dim0_whcn) {
      axis_map_.at(3) = dim1_whcn;
    } else if (axis_map_.at(3) == dim1_whcn) {
      axis_map_.at(3) = dim0_whcn;
    }
  }
  update_metadata();
}
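
// Worked example (hypothetical values): for a 4-dim tensor, transposing dims
// 2 and 3 (H and W in NCHW) gives dim0_whcn = 1 and dim1_whcn = 0, so a
// width-packed tensor (packed_dim_ == 0) becomes height-packed. For texture
// storage the default axis map {0, 1, 2, 2} becomes {1, 0, 2, 2}, while the
// concatenated dim entry (index 3) is unchanged since neither swapped WHCN
// index equals 2.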

} // namespace api
} // namespace vkcompute