#include <ATen/native/vulkan/api/Tensor.h>
#include <ATen/native/vulkan/api/Utils.h>

namespace at {
namespace native {
namespace vulkan {

namespace {

/*
 * Calculates the strides of a contiguous tensor. empty_tensor_restride from
 * TensorImpl.h was used as a reference.
 */
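// Illustrative example (derived from the loop below): sizes = {2, 3, 4, 5}
// yields strides = {60, 20, 5, 1}, i.e. each stride is the product of all
// sizes to its right.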
std::vector<int64_t> calc_contiguous_strides(
    const std::vector<int64_t>& sizes) {
  int64_t ndim = static_cast<int64_t>(sizes.size());
  std::vector<int64_t> strides(ndim);

  int64_t running_product = 1;
  if (ndim >= 1) {
    strides.at(ndim - 1) = running_product;
    for (int i = static_cast<int>(sizes.size()) - 2; i >= 0; --i) {
      running_product *= sizes.at(i + 1);
      strides.at(i) = running_product;
    }
  }

  return strides;
}

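/*
 * Calculates the strides of a tensor stored in channels-last order. Only
 * ndim 3 and 4 are supported. Worked example added for illustration: a 4D
 * NCHW tensor with sizes = {2, 3, 4, 5} yields strides = {60, 1, 15, 3}.
 */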
std::vector<int64_t> calc_channels_last_strides(
    const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size());

  switch (sizes.size()) {
    case 4:
      strides.at(1) = 1;
      strides.at(3) = sizes.at(1);
      strides.at(2) = strides.at(3) * sizes.at(3);
      strides.at(0) = strides.at(2) * sizes.at(2);
      return strides;
    case 3:
      strides.at(0) = 1;
      strides.at(2) = sizes.at(0);
      strides.at(1) = strides.at(2) * sizes.at(2);
      return strides;
    default:
      VK_THROW("ChannelsLast format only available for 3 <= ndim <= 4!");
  }

  return strides;
}

/*
 * Calculates the strides of a tensor based on its sizes and memory format.
 * Note that strides are only valid for vTensors backed by buffer storage; if
 * texture storage is used, the strides are invalid and set to zero.
 */
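// For illustration: with BUFFER storage, sizes = {2, 3, 4, 5} and
// TENSOR_WIDTH_PACKED produce contiguous strides {60, 20, 5, 1}, whereas with
// texture storage the result is simply {0, 0, 0, 0}.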
std::vector<int64_t> calc_strides(
    const std::vector<int64_t>& sizes,
    const api::GPUMemoryLayout memory_layout,
    const api::StorageType storage_type) {
  switch (storage_type) {
    case api::StorageType::BUFFER:
      switch (memory_layout) {
        case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED:
          return calc_contiguous_strides(sizes);
          break;
        case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED:
          return calc_channels_last_strides(sizes);
          break;
        default:
          VK_THROW("Invalid memory format used to create vTensor!");
      }
      break;
    case api::StorageType::TEXTURE_3D:
    case api::StorageType::TEXTURE_2D:
      return std::vector<int64_t>(sizes.size());
    default:
      VK_THROW("Invalid storage type used to create vTensor!");
  }
}

/*
 * When stored on the GPU, one dimension will be aligned to the next multiple of
 * 4 in order to take advantage of vec4 data types. The dimension that is
 * packed is denoted by the GPUMemoryLayout. This function adjusts one of
 * the dimensions based on the desired memory format and storage type and
 * returns a sizes array describing the dimensions of the memory used to store
 * the tensor data on the GPU.
 */
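// Example for illustration: sizes = {2, 3, 4, 5} with texture storage and
// TENSOR_CHANNELS_PACKED returns gpu_sizes = {2, 4, 4, 5}, since the channels
// dimension (3) is aligned up to the next multiple of 4.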
std::vector<int64_t> calc_gpu_sizes(
    const std::vector<int64_t>& sizes,
    const api::GPUMemoryLayout memory_layout,
    const api::StorageType storage_type) {
  VK_CHECK_COND(storage_type != api::StorageType::UNKNOWN);

  std::vector<int64_t> gpu_sizes;
  if (storage_type == api::StorageType::BUFFER) {
    gpu_sizes.resize(sizes.size());
    for (size_t i = 0; i < sizes.size(); i++) {
      gpu_sizes.at(i) = sizes.at(i);
    }
  }
  // For texture storage, tensors are typically stored using 3D image textures,
  // with batches stacked along the depth dimension. To represent the physical
  // 3-dimensionality of the image texture (with concatenated batches), GPU
  // sizes are fixed to 4 dimensions when texture storage is used.
  else {
    VK_CHECK_COND(
        sizes.size() >= 0 && sizes.size() <= 4,
        "Texture storage only valid for 0 <= ndim <= 4, received: ",
        sizes.size());

    gpu_sizes.resize(4);
    gpu_sizes.at(0) = api::utils::val_at(-4, sizes);
    gpu_sizes.at(1) = api::utils::val_at(-3, sizes);
    gpu_sizes.at(2) = api::utils::val_at(-2, sizes);
    gpu_sizes.at(3) = api::utils::val_at(-1, sizes);
  }

  size_t ndim = gpu_sizes.size();
  switch (memory_layout) {
    case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED:
      if (ndim >= 1) {
        gpu_sizes.at(ndim - 1) =
            api::utils::align_up(api::utils::val_at(-1, sizes), INT64_C(4));
      }
      break;

    case api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED:
      if (ndim >= 2) {
        gpu_sizes.at(ndim - 2) =
            api::utils::align_up(api::utils::val_at(-2, sizes), INT64_C(4));
      }
      break;

    case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED:
      if (ndim >= 3) {
        gpu_sizes.at(ndim - 3) =
            api::utils::align_up(api::utils::val_at(-3, sizes), INT64_C(4));
      }
      break;
  }

  return gpu_sizes;
}

/*
 * Creates a uvec3 denoting the extents of the image texture that will be
 * created to store a tensor of a given size.
 */
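// Example for illustration: gpu_sizes = {2, 4, 4, 5} (NCHW) with
// TENSOR_CHANNELS_PACKED yields extents {5, 4, 2}: width 5, height 4, and
// depth = batch * (channels / 4) = 2 * 1 = 2.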
api::utils::uvec3 create_image_extents(
    const std::vector<int64_t>& gpu_sizes,
    const api::StorageType storage_type,
    const api::GPUMemoryLayout memory_layout) {
  size_t ndim = gpu_sizes.size();

  if (storage_type == api::StorageType::BUFFER) {
    // Image extents do not apply to buffer storage
    return {0u, 0u, 0u};
  } else {
    VK_CHECK_COND(
        ndim >= 1 && ndim <= 4,
        "Texture storage only valid for 1 <= ndim <= 4!");

    using namespace api::utils;
    uint32_t width = safe_downcast<uint32_t>(val_at(-1, gpu_sizes));
    uint32_t height = safe_downcast<uint32_t>(val_at(-2, gpu_sizes));
    uint32_t channels = safe_downcast<uint32_t>(val_at(-3, gpu_sizes));
    uint32_t batch = safe_downcast<uint32_t>(val_at(-4, gpu_sizes));

    switch (memory_layout) {
      case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED:
        VK_CHECK_COND(width % 4 == 0, "Width must be divisible by 4!");
        width /= 4;
        break;
      case api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED:
        VK_CHECK_COND(height % 4 == 0, "Height must be divisible by 4!");
        height /= 4;
        break;
      case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED:
        VK_CHECK_COND(channels % 4 == 0, "Channels must be divisible by 4!");
        channels /= 4;
        break;
      default:
        VK_THROW("Invalid memory format used!");
    }

    return {width, height, batch * channels};
  }
}

api::UniformParamsBuffer make_metadata_uniform(
    api::Context* const context,
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides,
    const api::StorageType storage_type) {
  if (storage_type != api::StorageType::BUFFER) {
    return api::UniformParamsBuffer();
  }

  vTensor::BufferMetadata metadata{
      api::utils::make_whcn_uvec4(sizes),
      api::utils::make_whcn_uvec4(strides),
      api::utils::safe_downcast<uint32_t>(sizes.size()),
      api::utils::safe_downcast<uint32_t>(api::utils::multiply_integers(sizes)),
  };

  return api::UniformParamsBuffer(context, metadata);
}

} // namespace

//
// vTensor
//

vTensor::vTensor(
    api::Context* const context,
    const std::vector<int64_t>& sizes,
    const api::ScalarType dtype,
    const api::StorageType storage_type,
    const api::GPUMemoryLayout memory_layout,
    const bool allocate_memory)
    : dtype_(dtype),
      memory_layout_(memory_layout),
      // Calculate sizes and strides
      sizes_(sizes.begin(), sizes.end()),
      strides_{calc_strides(sizes, memory_layout_, storage_type)},
      gpu_sizes_{calc_gpu_sizes(sizes, memory_layout_, storage_type)},
      gpu_strides_{calc_strides(gpu_sizes_, memory_layout_, storage_type)},
      virtual_extents_(
          create_image_extents(gpu_sizes_, storage_type, memory_layout)),
      // Utility Uniform Buffers that can be passed to shaders as arguments
      metadata_uniform_(),
      cpu_sizes_uniform_(nullptr),
      gpu_sizes_uniform_(nullptr),
      extents_uniform_(nullptr),
      // Construct Tensor storage
      view_(std::make_shared<vTensorStorage>(
          context,
          storage_type,
          memory_layout_,
          gpu_sizes_,
          dtype_,
          allocate_memory)) {}

vTensor::vTensor(
    api::Context* const context,
    const std::vector<int64_t>& sizes,
    double q_scale,
    int64_t q_zero_point,
    const api::ScalarType dtype,
    const api::StorageType storage_type,
    const api::GPUMemoryLayout memory_layout)
    : dtype_(dtype),
      memory_layout_(memory_layout),
      // Calculate sizes and strides
      sizes_(sizes.begin(), sizes.end()),
      strides_{calc_strides(sizes, memory_layout_, storage_type)},
      gpu_sizes_{calc_gpu_sizes(sizes, memory_layout_, storage_type)},
      gpu_strides_{calc_strides(gpu_sizes_, memory_layout_, storage_type)},
      virtual_extents_(
          create_image_extents(gpu_sizes_, storage_type, memory_layout)),
      // Vulkan uniform buffer containing sizes and stride info
      metadata_uniform_(),
      cpu_sizes_uniform_(nullptr),
      gpu_sizes_uniform_(nullptr),
      extents_uniform_(nullptr),
      // Quantization params
      is_quantized_{true},
      q_scale_{q_scale},
      q_zero_point_{q_zero_point},
      // Construct Tensor storage
      view_(std::make_shared<vTensorStorage>(
          context,
          storage_type,
          memory_layout_,
          gpu_sizes_,
          dtype_)) {}

api::VulkanImage& vTensor::image(
    api::PipelineBarrier& pipeline_barrier,
    const api::PipelineStageFlags stage) const& {
  view_->transition(pipeline_barrier, stage, api::MemoryAccessType::READ);
  return view_->image_;
}

api::VulkanImage& vTensor::image(
    api::PipelineBarrier& pipeline_barrier,
    const api::PipelineStageFlags stage,
    const api::MemoryAccessFlags access) & {
  view_->transition(pipeline_barrier, stage, access);
  return view_->image_;
}

api::VulkanBuffer& vTensor::buffer(
    api::PipelineBarrier& pipeline_barrier,
    const api::PipelineStageFlags stage) const& {
  view_->transition(pipeline_barrier, stage, api::MemoryAccessType::READ);
  return view_->buffer_;
}

api::VulkanBuffer& vTensor::buffer(
    api::PipelineBarrier& pipeline_barrier,
    const api::PipelineStageFlags stage,
    const api::MemoryAccessFlags access) & {
  view_->transition(pipeline_barrier, stage, access);
  return view_->buffer_;
}

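// The metadata and size/extents uniform buffers below are created lazily on
// first request; the size and extents uniforms are updated in place by
// update_size_metadata() when the tensor is resized.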
api::VulkanBuffer& vTensor::buffer_metadata() {
  if (!metadata_uniform_.buffer()) {
    metadata_uniform_ = make_metadata_uniform(
        view_->context_, gpu_sizes_, gpu_strides_, storage_type());
  }
  return metadata_uniform_.buffer();
}

std::shared_ptr<api::UniformParamsBuffer> vTensor::cpu_sizes_ubo() {
  if (!cpu_sizes_uniform_) {
    cpu_sizes_uniform_.reset(new api::UniformParamsBuffer(
        view_->context_, api::utils::make_whcn_ivec4(sizes_)));
  }
  return cpu_sizes_uniform_;
}

std::shared_ptr<api::UniformParamsBuffer> vTensor::gpu_sizes_ubo() {
  if (!gpu_sizes_uniform_) {
    gpu_sizes_uniform_.reset(new api::UniformParamsBuffer(
        view_->context_, api::utils::make_whcn_ivec4(gpu_sizes_)));
  }
  return gpu_sizes_uniform_;
}

std::shared_ptr<api::UniformParamsBuffer> vTensor::extents_ubo() {
  if (!extents_uniform_) {
    extents_uniform_.reset(new api::UniformParamsBuffer(
        view_->context_,
        api::utils::uvec4(
            {view_->extents_.data[0],
             view_->extents_.data[1],
             view_->extents_.data[2],
             1u})));
  }
  return extents_uniform_;
}

vTensor::BufferMetadata vTensor::get_cpu_buffer_metadata() const {
  return {
      api::utils::make_whcn_uvec4(sizes_),
      api::utils::make_whcn_uvec4(strides_),
      api::utils::safe_downcast<uint32_t>(sizes_.size()),
      api::utils::safe_downcast<uint32_t>(
          api::utils::multiply_integers(sizes_)),
  };
}

VmaAllocationCreateInfo vTensor::get_allocation_create_info() const {
  switch (storage_type()) {
    case api::StorageType::BUFFER:
      return view_->buffer_.allocation_create_info();
    case api::StorageType::TEXTURE_2D:
    case api::StorageType::TEXTURE_3D:
      return view_->image_.allocation_create_info();
    case api::StorageType::UNKNOWN:
      break;
  }
  return {};
}

VkMemoryRequirements vTensor::get_memory_requirements() const {
  switch (storage_type()) {
    case api::StorageType::BUFFER:
      return view_->buffer_.get_memory_requirements();
    case api::StorageType::TEXTURE_2D:
    case api::StorageType::TEXTURE_3D:
      return view_->image_.get_memory_requirements();
    case api::StorageType::UNKNOWN:
      break;
  }
  return {};
}

void vTensor::bind_allocation(const api::MemoryAllocation& allocation) {
  switch (storage_type()) {
    case api::StorageType::BUFFER:
      view_->buffer_.bind_allocation(allocation);
      break;
    case api::StorageType::TEXTURE_2D:
    case api::StorageType::TEXTURE_3D:
      view_->image_.bind_allocation(allocation);
      break;
    case api::StorageType::UNKNOWN:
      break;
  }
}

void vTensor::update_size_metadata(const std::vector<int64_t>& new_sizes) {
  sizes_ = new_sizes;
  gpu_sizes_ = calc_gpu_sizes(sizes_, memory_layout_, storage_type());
  virtual_extents_ =
      create_image_extents(gpu_sizes_, storage_type(), memory_layout_);

  if (cpu_sizes_uniform_) {
    cpu_sizes_uniform_->update(api::utils::make_whcn_ivec4(sizes_));
  }

  if (gpu_sizes_uniform_) {
    gpu_sizes_uniform_->update(api::utils::make_whcn_ivec4(gpu_sizes_));
  }

  if (extents_uniform_) {
    extents_uniform_->update(api::utils::uvec4(
        {virtual_extents_.data[0],
         virtual_extents_.data[1],
         virtual_extents_.data[2],
         1u}));
  }
}

void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
  update_size_metadata(new_sizes);
  view_->discard_and_reallocate(
      calc_gpu_sizes(new_sizes, memory_layout_, storage_type()),
      memory_layout_,
      dtype_);
}

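// Note: unlike reallocate(), virtual_resize() only updates size metadata; the
// underlying buffer or image texture is not reallocated, so the new sizes must
// fit within the existing allocation.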
void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
  update_size_metadata(new_sizes);
  if (storage_type() == api::StorageType::BUFFER) {
    if (gpu_nbytes() > view_->buffer_.mem_size()) {
      VK_THROW(
          "Cannot virtual_resize a vTensor with sizes that require a larger "
          "buffer! reallocate() should be used instead.");
    }
  } else {
    bool valid_resize = true;
    if (virtual_extents_.data[0] > view_->extents_.data[0]) {
      valid_resize = false;
    }
    if (virtual_extents_.data[1] > view_->extents_.data[1]) {
      valid_resize = false;
    }
    if (virtual_extents_.data[2] > view_->extents_.data[2]) {
      valid_resize = false;
    }

    if (!valid_resize) {
      VK_THROW(
          "Cannot virtual_resize a vTensor with sizes that require a larger "
          "image texture! reallocate() should be used instead.");
    }
  }
}

//
// vTensorStorage
//

static api::VulkanImage allocate_image(
    api::Context* const context_ptr,
    api::utils::uvec3& extents,
    const api::StorageType storage_type,
    const VkFormat image_format,
    const bool allocate_memory) {
  api::Adapter* adapter_ptr = context_ptr->adapter_ptr();

  api::ImageSampler::Properties sampler_props{
      VK_FILTER_NEAREST,
      VK_SAMPLER_MIPMAP_MODE_NEAREST,
      VK_SAMPLER_ADDRESS_MODE_REPEAT,
      VK_BORDER_COLOR_FLOAT_TRANSPARENT_BLACK,
  };

  VkImageType image_type = VK_IMAGE_TYPE_3D;
  VkImageViewType image_view_type = VK_IMAGE_VIEW_TYPE_3D;

  switch (storage_type) {
    case api::StorageType::TEXTURE_3D:
      image_type = VK_IMAGE_TYPE_3D;
      image_view_type = VK_IMAGE_VIEW_TYPE_3D;
      break;
    case api::StorageType::TEXTURE_2D:
      image_type = VK_IMAGE_TYPE_2D;
      image_view_type = VK_IMAGE_VIEW_TYPE_2D;
      break;
    default:
      // Return an empty VulkanImage by default
      return api::VulkanImage();
  }

  VkSampler sampler = adapter_ptr->sampler_cache().retrieve(sampler_props);

  return adapter_ptr->vma().create_image(
      api::create_extent3d(extents),
      image_format,
      image_type,
      image_view_type,
      sampler_props,
      sampler,
      /*allow_transfer = */ true,
      /*allocate_memory = */ allocate_memory);
}

static api::VulkanBuffer allocate_buffer(
    api::Context* const context_ptr,
    const int64_t numel,
    const api::StorageType storage_type,
    const api::ScalarType dtype,
    const bool allocate_memory) {
  api::Adapter* adapter_ptr = context_ptr->adapter_ptr();

  switch (storage_type) {
    case api::StorageType::BUFFER:
      break;
    default:
      // Return an empty VulkanBuffer if Buffer storage is not used
      return api::VulkanBuffer();
  }

  return adapter_ptr->vma().create_storage_buffer(
      api::element_size(dtype) * numel, /*gpu_only = */ true, allocate_memory);
}

vTensorStorage::vTensorStorage(
    api::Context* const context,
    const api::StorageType storage_type,
    const api::GPUMemoryLayout gpu_memory_layout,
    const std::vector<int64_t>& gpu_sizes,
    const api::ScalarType dtype,
    const bool allocate_memory)
    : context_(context),
      storage_type_{storage_type},
      extents_(
          create_image_extents(gpu_sizes, storage_type, gpu_memory_layout)),
      buffer_length_{api::utils::multiply_integers(gpu_sizes)},
      image_(allocate_image(
          context_,
          extents_,
          storage_type_,
          api::to_vkformat(dtype),
          allocate_memory)),
      buffer_(allocate_buffer(
          context_,
          buffer_length_,
          storage_type_,
          dtype,
          allocate_memory)),
      last_access_{} {}

vTensorStorage::~vTensorStorage() {
  flush();
}

void vTensorStorage::flush() {
  if (image_) {
    context_->register_image_cleanup(image_);
  } else if (buffer_) {
    context_->register_buffer_cleanup(buffer_);
  }
  last_access_ = {};
}

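/*
 * Records a transition from the last recorded access of the underlying
 * resource to the requested stage/access. A barrier is appended only when the
 * previous access wrote to the resource or when an image layout change is
 * required; in either case the image layout and last-access state are updated.
 */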
void vTensorStorage::transition(
    api::PipelineBarrier& pipeline_barrier,
    const api::PipelineStageFlags cur_stage,
    const api::MemoryAccessFlags cur_access) {
  // Get last stage access
  api::PipelineStageFlags prev_stage = last_access_.stage;
  api::MemoryAccessFlags prev_access = last_access_.access;

  const bool prev_written = (prev_access & api::MemoryAccessType::WRITE) != 0;

  VkImageLayout cur_layout = VK_IMAGE_LAYOUT_UNDEFINED;
  VkImageLayout new_layout = VK_IMAGE_LAYOUT_UNDEFINED;
  bool layout_changed = false;
  if (image_) {
    cur_layout = image_.layout();
    new_layout = api::vk_layout(cur_stage, cur_access);

    layout_changed = cur_layout != new_layout;
  }

  if (prev_written || layout_changed) {
    VkPipelineStageFlags src_stage = api::vk_stage(prev_stage);
    if (0u == src_stage) {
      src_stage = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT;
    }
    VkPipelineStageFlags dst_stage = api::vk_stage(cur_stage);
    if (0u == dst_stage) {
      dst_stage = VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT;
    }

    pipeline_barrier.stage.src |= src_stage;
    pipeline_barrier.stage.dst |= dst_stage;

    if (image_) {
      pipeline_barrier.images.emplace_back(
          api::vk_access(prev_stage, prev_access),
          api::vk_access(cur_stage, cur_access),
          cur_layout,
          new_layout,
          image_);

      image_.set_layout(new_layout);
    } else if (buffer_) {
      pipeline_barrier.buffers.emplace_back(
          api::vk_access(prev_stage, prev_access),
          api::vk_access(cur_stage, cur_access),
          buffer_);
    }
  }

  last_access_.stage = cur_stage;
  last_access_.access = cur_access;
}

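/*
 * Appends a buffer memory barrier to the given pipeline barrier, but only for
 * read-after-write (RAW) hazards, i.e. when the previous access wrote to the
 * buffer and the current access reads from it.
 */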
void add_buffer_barrier(
    api::PipelineBarrier& pipeline_barrier,
    const api::VulkanBuffer& buffer,
    const api::PipelineStageFlags prev_stage,
    const api::MemoryAccessFlags prev_access,
    const api::PipelineStageFlags cur_stage,
    const api::MemoryAccessFlags cur_access) {
  // Check for RAW
  const bool read_requested = (cur_access & api::MemoryAccessType::READ) != 0;
  const bool prev_written = (prev_access & api::MemoryAccessType::WRITE) != 0;

  const bool is_RAW = read_requested && prev_written;

  if (is_RAW) {
    VkPipelineStageFlags src_stage = api::vk_stage(prev_stage);
    if (0u == src_stage) {
      src_stage = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT;
    }
    VkPipelineStageFlags dst_stage = api::vk_stage(cur_stage);
    if (0u == dst_stage) {
      dst_stage = VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT;
    }

    pipeline_barrier.stage.src |= src_stage;
    pipeline_barrier.stage.dst |= dst_stage;

    pipeline_barrier.buffers.emplace_back(
        api::vk_access(prev_stage, prev_access),
        api::vk_access(cur_stage, cur_access),
        buffer);
  }
}

void vTensorStorage::discard_and_reallocate(
    const std::vector<int64_t>& gpu_sizes,
    const api::GPUMemoryLayout gpu_memory_layout,
    const api::ScalarType dtype) {
  const bool image_owns_memory = image_.owns_memory();
  const bool buffer_owns_memory = buffer_.owns_memory();

  flush();

  extents_ = create_image_extents(gpu_sizes, storage_type_, gpu_memory_layout);
  image_ = allocate_image(
      context_,
      extents_,
      storage_type_,
      api::to_vkformat(dtype),
      image_owns_memory);

  buffer_length_ = api::utils::multiply_integers(gpu_sizes);
  buffer_ = allocate_buffer(
      context_, buffer_length_, storage_type_, dtype, buffer_owns_memory);
}

} // namespace vulkan
} // namespace native
} // namespace at