xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/api/containers/ParamsBuffer.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #pragma once
10 
// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
12 
13 #include <executorch/backends/vulkan/runtime/api/Context.h>
14 
15 #include <executorch/backends/vulkan/runtime/vk_api/memory/Buffer.h>
16 
17 namespace vkcompute {
18 namespace api {
19 
20 class ParamsBuffer final {
21  private:
22   Context* context_p_;
23   size_t nbytes_;
24   vkapi::VulkanBuffer vulkan_buffer_;
25 
26  public:
ParamsBuffer()27   ParamsBuffer() : context_p_{nullptr}, vulkan_buffer_{} {}
28 
29   template <typename Block>
ParamsBuffer(Context * context_p,const Block & block)30   ParamsBuffer(Context* context_p, const Block& block)
31       : context_p_(context_p),
32         nbytes_(sizeof(block)),
33         vulkan_buffer_(
34             context_p_->adapter_ptr()->vma().create_params_buffer(block)) {}
35 
36   ParamsBuffer(const ParamsBuffer&);
37   ParamsBuffer& operator=(const ParamsBuffer&);
38 
39   ParamsBuffer(ParamsBuffer&&) = default;
40   ParamsBuffer& operator=(ParamsBuffer&&) = default;
41 
~ParamsBuffer()42   ~ParamsBuffer() {
43     if (vulkan_buffer_) {
44       context_p_->register_buffer_cleanup(vulkan_buffer_);
45     }
46   }
47 
buffer()48   const vkapi::VulkanBuffer& buffer() const {
49     return vulkan_buffer_;
50   }
51 
52   template <typename Block>
update(const Block & block)53   void update(const Block& block) {
54     if (sizeof(block) != nbytes_) {
55       VK_THROW("Attempted to update ParamsBuffer with data of different size");
56     }
57     // Fill the uniform buffer with data in block
58     {
59       vkapi::MemoryMap mapping(vulkan_buffer_, vkapi::kWrite);
60       Block* data_ptr = mapping.template data<Block>();
61 
62       *data_ptr = block;
63     }
64   }
65 
66   template <typename T>
read()67   T read() const {
68     T val;
69     if (sizeof(val) != nbytes_) {
70       VK_THROW(
71           "Attempted to store value from ParamsBuffer to type of different size");
72     }
73     // Read value from uniform buffer and store in val
74     {
75       vkapi::MemoryMap mapping(vulkan_buffer_, vkapi::kRead);
76       T* data_ptr = mapping.template data<T>();
77 
78       val = *data_ptr;
79     }
80     return val;
81   }
82 };
83 
84 } // namespace api
85 } // namespace vkcompute
86