xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/containers/SharedObject.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/backends/vulkan/runtime/graph/containers/SharedObject.h>
10 
11 #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
12 
13 namespace vkcompute {
14 
has_user(const ValueRef idx) const15 bool SharedObject::has_user(const ValueRef idx) const {
16   return std::find(users.begin(), users.end(), idx) != users.end();
17 }
18 
add_user(ComputeGraph * const graph,const ValueRef idx)19 void SharedObject::add_user(ComputeGraph* const graph, const ValueRef idx) {
20   vTensorPtr t = graph->get_tensor(idx);
21 
22   // Aggregate Memory Requirements
23   const VkMemoryRequirements mem_reqs = t->get_memory_requirements();
24   aggregate_memory_requirements.size =
25       std::max(mem_reqs.size, aggregate_memory_requirements.size);
26   aggregate_memory_requirements.alignment =
27       std::max(mem_reqs.alignment, aggregate_memory_requirements.alignment);
28   aggregate_memory_requirements.memoryTypeBits |= mem_reqs.memoryTypeBits;
29 
30   users.emplace_back(idx);
31 }
32 
allocate(ComputeGraph * const graph)33 void SharedObject::allocate(ComputeGraph* const graph) {
34   if (aggregate_memory_requirements.size == 0) {
35     return;
36   }
37 
38   VmaAllocationCreateInfo alloc_create_info =
39       graph->context()->adapter_ptr()->vma().gpuonly_resource_create_info();
40 
41   allocation = graph->context()->adapter_ptr()->vma().create_allocation(
42       aggregate_memory_requirements, alloc_create_info);
43 }
44 
bind_users(ComputeGraph * const graph)45 void SharedObject::bind_users(ComputeGraph* const graph) {
46   if (users.empty()) {
47     return;
48   }
49   for (const ValueRef idx : users) {
50     graph->get_tensor(idx)->bind_allocation(allocation);
51   }
52 }
53 
54 } // namespace vkcompute
55