// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/static/memory_planner.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/core/alignment.h>
#include <torch/csrc/jit/runtime/static/memory_planner.h>

#include <ATen/Tensor.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <iterator>

namespace torch::jit {

namespace {

bool isUnmanagedSpecialCase(const ProcessedNode& pnode, size_t output_idx) {
  DCHECK(output_idx < pnode.outputs().size());
  static const auto to_maybe_copy_out_symbol =
      c10::Symbol::fromQualString("static_runtime::to_maybe_copy_out");
  // Heuristic and special case:
  // If to_maybe_copy_out did not actually do anything in the
  // first iteration, assume it will continue to not do anything
  // and avoid managing its output.
  return pnode.node()->kind() == to_maybe_copy_out_symbol &&
      pnode.Output(output_idx).isNone();
}

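// Builds a map from each managed tensor Value* to the at::Tensor* stored in
// the corresponding ProcessedNode output slot. Outputs hitting the
// to_maybe_copy_out special case may still be None after the first iteration
// and are skipped.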
c10::FastMap<const Value*, at::Tensor*> tensorValueToTensor(
    const std::vector<ProcessedNode>& nodes,
    const c10::FastSet<const Value*>& managed_tensor_values) {
  c10::FastMap<const Value*, at::Tensor*> tensor_value_to_tensor;
  for (auto& pnode : nodes) {
    auto* node = pnode.node();
    for (const auto output_idx : c10::irange(node->outputs().size())) {
      auto* output = node->output(output_idx);

      if (managed_tensor_values.find(output) == managed_tensor_values.end()) {
        continue;
      }

      auto& ival = pnode.Output(output_idx);

      // ival is allowed to be None in special cases, e.g. to_maybe_copy_out
      DCHECK(
          ival.isTensor() ||
          (ival.isNone() && isUnmanagedSpecialCase(pnode, output_idx)));

      if (ival.isTensor()) {
        tensor_value_to_tensor.emplace(
            output,
            // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
            const_cast<at::Tensor*>(&ival.toTensor()));
      }
    }
  }
  return tensor_value_to_tensor;
}

// Don't change the size if it is already aligned, otherwise increase the size
// to make it aligned.
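// For example, assuming c10::gAlignment is 64 on this platform,
// compute_aligned_tensor_size(100) == 128 and
// compute_aligned_tensor_size(128) == 128.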
size_t compute_aligned_tensor_size(size_t nbytes) {
  // Note: everything below is size_t
  return (nbytes + c10::gAlignment - 1) & (~(c10::gAlignment - 1));
}

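// Allocates `size` bytes through the CPU caching allocator rather than the
// default CPU allocator.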
at::DataPtr allocate_buffer(size_t size) {
  at::Allocator* allocator = c10::GetCPUCachingAllocator();
  return allocator->allocate(size);
}

} // namespace

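// Greedily assigns each managed tensor to a storage group. Walking the nodes
// in execution order, an output reuses a free storage group if one is
// available and gets a brand-new group otherwise; once `ranges` reports that a
// node was the last use of some tensors, their groups are returned to the free
// list. All tensors in a group later share a single StorageImpl and a single
// slot in the managed buffer.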
std::vector<StorageGroup> assignStorageToManagedTensors(
    graph_node_list nodes,
    const ManagedTensorRanges& ranges,
    const c10::FastMap<const Value*, at::Tensor*>& tensor_value_to_tensor) {
  std::vector<StorageGroup> managed_tensor_groups;
  // This maps each Value* to its assigned storage group.
  c10::FastMap<const Value*, size_t> storage_group_mapping;
  // On each iteration, this vector stores the set of storage groups that
  // are available for re-use.
  std::vector<size_t> free_storage_groups;

  auto makeNewStorageGroup = [&](const Value* value) {
    const auto storage_group = managed_tensor_groups.size();
    storage_group_mapping.emplace(value, storage_group);
    auto* tensor_ptr = tensor_value_to_tensor.at(value);
    managed_tensor_groups.emplace_back(tensor_ptr);
  };

  auto assignToAvailableStorageGroup = [&](const Value* value) {
    DCHECK(!free_storage_groups.empty());
    const auto storage_group = free_storage_groups.back();
    TORCH_DCHECK_LT(storage_group, managed_tensor_groups.size());
    storage_group_mapping.emplace(value, storage_group);
    auto* tensor_ptr = tensor_value_to_tensor.at(value);
    managed_tensor_groups[storage_group].addTensor(tensor_ptr);
    free_storage_groups.pop_back();
  };

  auto isManagedTensor = [&](const Value* value) {
    return tensor_value_to_tensor.find(value) != tensor_value_to_tensor.end();
  };

  for (auto* node : nodes) {
    // Assign storage groups to outputs
    for (const auto output_idx : c10::irange(node->outputs().size())) {
      Value* output = node->output(output_idx);
      if (!isManagedTensor(output)) {
        continue;
      }
      if (free_storage_groups.empty()) {
        makeNewStorageGroup(output);
        continue;
      }
      assignToAvailableStorageGroup(output);
    }

    // This node may be the last use of some managed tensors. If so, we
    // can mark the corresponding storage groups as free.
    if (ranges.nodeFreesManagedTensors(node)) {
      const auto& new_free_tensors =
          ranges.availableTensorValuesAfterNode(node);
      for (auto* tensor_value : new_free_tensors) {
        // We need to check this here to handle special cases like
        // to_maybe_copy_out. We don't know if the tensor value is managed until
        // after the first iter, but `ranges` is initialized at load time!
        if (!isManagedTensor(tensor_value)) {
          continue;
        }
        const auto storage_group = storage_group_mapping.at(tensor_value);
        free_storage_groups.push_back(storage_group);
      }
    }
  }
  return managed_tensor_groups;
}

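// ManagedStorages is a manually managed array of StorageImpls: allocate()
// reserves raw memory, append() placement-constructs entries into it, and
// deallocate() runs the destructors and frees the bytes. Because the entries
// are not individually heap-allocated, they must never be freed through
// intrusive_ptr (see the unsafe_adapt_non_heap_allocated comment below).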
ManagedStorages::ManagedStorages()
    : storages_(nullptr), size_(0), capacity_(0) {}

ManagedStorages::~ManagedStorages() {
  deallocate();
}

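// Reserves uninitialized memory for `capacity` StorageImpls; entries are
// constructed one at a time by append().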
void ManagedStorages::allocate(size_t capacity) {
  TORCH_CHECK(!is_allocated(), "Must deallocate before allocating again");
  // `size_` should already be 0 if not allocated, so double check it here
  TORCH_INTERNAL_ASSERT(size_ == 0);
  capacity_ = capacity;
  storages_ = reinterpret_cast<at::StorageImpl*>(
      new unsigned char[capacity_ * sizeof(at::StorageImpl)]);
}

void ManagedStorages::deallocate() {
  if (is_allocated()) {
    for (const size_t idx : c10::irange(size_)) {
      storages_[idx].~StorageImpl();
    }
    delete[] reinterpret_cast<unsigned char*>(storages_);
    capacity_ = 0;
    size_ = 0;
    storages_ = nullptr;
  }
}

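// Placement-constructs a StorageImpl in the next free slot, copying the byte
// size, allocator and resizability of `storageImpl`.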
void ManagedStorages::append(at::StorageImpl& storageImpl) {
  TORCH_INTERNAL_ASSERT(size_ < capacity_);
  new (&storages_[size_]) at::StorageImpl(
      at::StorageImpl::use_byte_size_t(),
      storageImpl.nbytes(),
      storageImpl.allocator(),
      storageImpl.resizable());
  size_++;
}

namespace {

bool setIncludes(const c10::FastSet<const Value*>& set, const Value* v) {
  return set.find(v) != set.end();
}

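// Collects the output tensors that should be backed by the shared output
// buffer. Each entry starts with size 0; deallocateOutputTensors() records
// the observed sizes after each run.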
std::vector<std::pair<size_t, at::Tensor*>> assignStorageToOutputTensors(
    BlockRunner* block_runner,
    const c10::FastSet<const Value*>& managed_output_tensor_values) {
  std::vector<std::pair<size_t, at::Tensor*>> managed_output_tensors;
  for (auto& pnode : block_runner->nodes()) {
    for (const auto i : c10::irange(pnode.outputs().size())) {
      auto& ival = pnode.Output(i);
      const auto* val = pnode.node()->outputs()[i];
      if (!setIncludes(managed_output_tensor_values, val) ||
          isUnmanagedSpecialCase(pnode, i)) {
        continue;
      }
      TORCH_CHECK(ival.isTensor());
      at::Tensor* tensor = &ival.toTensor();
      managed_output_tensors.emplace_back(0, tensor);
    }
  }
  return managed_output_tensors;
}

} // namespace

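// Partitions the block's output IValues into the sets the planner tracks:
// counted managed tensors, borrowed outputs that only need destroyBorrow(),
// plain unmanaged IValues that must be reset every iteration, and scalar-like
// values that need no cleanup at all. Graph outputs are pulled out of the
// unmanaged sets (borrowed ones are recorded so they can be converted to
// owning IValues instead). Storage is assigned to managed output tensors here
// only when both enable_out_variant and manage_output_tensors are set.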
MemoryPlanner::MemoryPlanner(
    BlockRunner* block_runner,
    const BlockInfo& block_info,
    bool enable_out_variant,
    bool manage_output_tensors) {
  const auto& managed_tensor_values = block_info.managed_tensor_values();
  const auto& managed_output_tensor_values =
      block_info.managed_output_tensor_values();
  const auto& leaked_values = block_info.leaked_values();

  // collect unmanaged output ivalues
  c10::FastSet<IValue*> unmanaged_ivalues;
  c10::FastSet<IValue*> unmanaged_borrowed_ivalues;
  for (ProcessedNode& pnode : block_runner->nodes()) {
    const auto borrows_outputs = borrowsOutputs(pnode.node()->kind());
    for (const auto i : c10::irange(pnode.outputs().size())) {
      const Value* out_v = pnode.node()->outputs()[i];
      const bool in_managed_tensors = setIncludes(managed_tensor_values, out_v);
      const bool is_unmanaged_special_case = isUnmanagedSpecialCase(pnode, i);
      if (in_managed_tensors && !is_unmanaged_special_case) {
        ++num_managed_tensors_;
      }
      const bool in_managed_sets = in_managed_tensors ||
          // Managing output tensors might have been turned off, so we have to
          // check the flag here
          (manage_output_tensors &&
           setIncludes(managed_output_tensor_values, out_v)) ||
          setIncludes(leaked_values, out_v);

      if (in_managed_sets && !is_unmanaged_special_case) {
        continue;
      }
      if (doesNotHeapAllocateWhenStoredInIValue(*out_v->type())) {
        // Scalars do not need to be freed after each iteration.
        num_unmanaged_scalar_ivalues_++;
      } else if (borrows_outputs) {
        IValue& out = pnode.Output(i);
        unmanaged_borrowed_ivalues.insert(&out);
      } else {
        IValue& out = pnode.Output(i);
        unmanaged_ivalues.insert(&out);
      }
    }
  }
  for (IValue* output : block_runner->outputs()) {
    auto it = unmanaged_borrowed_ivalues.find(output);
    if (it != unmanaged_borrowed_ivalues.end()) {
      borrowed_ivalues_needing_incref_.push_back(output);
      unmanaged_borrowed_ivalues.erase(it);
    } else {
      unmanaged_ivalues.erase(output);
    }
  }

  // copy to unmanaged_ivalues_
  unmanaged_ivalues_.reserve(unmanaged_ivalues.size());
  unmanaged_ivalues_.insert(
      unmanaged_ivalues_.begin(),
      unmanaged_ivalues.begin(),
      unmanaged_ivalues.end());
  unmanaged_borrowed_ivalues_.reserve(unmanaged_borrowed_ivalues.size());
  unmanaged_borrowed_ivalues_.insert(
      unmanaged_borrowed_ivalues_.begin(),
      unmanaged_borrowed_ivalues.begin(),
      unmanaged_borrowed_ivalues.end());

  if (enable_out_variant && manage_output_tensors) {
    managed_output_tensors_ = assignStorageToOutputTensors(
        block_runner, managed_output_tensor_values);
  }
}

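// Allocates the arena used for all managed tensors and records its
// [start, end) bounds in buffer_start_ and buffer_end_.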
uint8_t* MemoryPlanner::allocateBuffer(size_t num_bytes) {
  buffer_ = allocate_buffer(num_bytes);
  uint8_t* start = static_cast<uint8_t*>(buffer_.get());
  buffer_start_ = start;
  buffer_end_ = start + num_bytes;
  return start;
}

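// Carves output_buffer_ into one slice per managed output tensor and points
// each tensor's StorageImpl at its slice. The DataPtrs installed here do not
// own the memory (their deleter is nullptr); see the ownership note below.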
void MemoryPlanner::allocateOutputTensors() {
  if (output_buffer_bytes_ == 0) {
    return;
  }
  TORCH_CHECK(
      !output_buffer_,
      "Previously allocated output_buffer_ was not deallocated properly.");
  output_buffer_ = allocate_buffer(output_buffer_bytes_);

  size_t offset = 0;
  uint8_t* start = static_cast<uint8_t*>(output_buffer_.get());

  for (const auto& ms : managed_output_tensors_) {
    auto tensor_size = ms.first;
    auto* tensor = ms.second;
    if (tensor_size == 0) {
      continue;
    }
    TORCH_DCHECK_LE(offset + tensor_size, output_buffer_bytes_);
    void* src = static_cast<void*>(start + offset);
    // NOTE: Populating `ctx` enables clients to take ownership of a
    // tensor managed by Static Runtime. Some clients use "move" semantics to
    // pass a Tensor object to another holding object (e.g., a thrift message)
    // to avoid `memcpy`.
    // `torch::distributed::detail::WireDumpOp::dumpTensorData` is a concrete
    // example of doing this (see `torch::distributed::detail::hasDeleter`).
    // Since this output Tensor object is permanently owned by Static Runtime,
    // this ownership passing does *not* have the intended effect of keeping
    // the Tensor alive until the "owner" releases it: a premature call to
    // `StaticRuntime::deallocateOutputTensors` can destruct such a Tensor
    // object that a holding object believes it still retains, causing it to
    // read corrupted values from an already destructed Tensor object.
    // Therefore, a client receiving Static Runtime-managed Tensors needs to be
    // very careful to call `StaticRuntime::deallocateOutputTensors` only after
    // these holding objects are gone.
    tensor->storage().set_data_ptr_noswap(
        at::DataPtr(src, /*ctx=*/src, nullptr, tensor->device()));
    tensor->storage().set_nbytes(tensor_size);
    offset += tensor_size;
  }
  TORCH_DCHECK_EQ(offset, output_buffer_bytes_);
}

void MemoryPlanner::allocate() {
  // TODO: Improve this once D31357486 is landed.
  allocateManagedTensors();
  allocateOutputTensors();
}

void MemoryPlanner::deallocate() {
  for (auto& iv : borrowed_ivalues_needing_incref_) {
    auto old = std::move(*iv);
    *iv = IValue(old);
    c10::MaybeOwnedTraits<c10::IValue>::destroyBorrow(old);
  }
  // for unmanaged ivalues (either tensor or non-tensor), we reset the *iv so
  // that the objects pointed to by *iv may be reclaimed by reference counting
  for (auto& iv : unmanaged_ivalues_) {
    *iv = IValue();
  }
  for (auto& iv : unmanaged_borrowed_ivalues_) {
    c10::MaybeOwnedTraits<c10::IValue>::destroyBorrow(*iv);
  }
  // It's important to call this function after all other owning refs
  // of the managed StorageImpls are cleaned up. It can reset the
  // StorageImpl's refcount to (# tensors in storage group),
  // so destructing any owning refs afterwards will bring the refcount
  // lower than expected and trigger the debug assertion in
  // ~intrusive_ptr_target.
  deallocateManagedTensors();
  buffer_ = {};
}

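// Resets every managed output tensor's storage, records the largest aligned
// size seen so far per tensor (used to size output_buffer_ on the next
// iteration), and releases the current output buffer.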
void MemoryPlanner::deallocateOutputTensors() {
  size_t output_buffer_bytes = 0;
  for (auto& ms : managed_output_tensors_) {
    auto* tensor = ms.second;
    size_t current_size =
        compute_aligned_tensor_size(tensor->storage().nbytes());
    tensor->storage().unsafeGetStorageImpl()->reset();
    if (current_size > ms.first) {
      ms.first = current_size;
    }
    output_buffer_bytes += ms.first;
  }
  output_buffer_bytes_ = output_buffer_bytes;
  output_buffer_ = {};
}

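// With optimize_memory, managed tensors are packed into shared storage groups
// computed from their liveness ranges; otherwise, every managed tensor gets
// its own group.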
StandardMemoryPlanner::StandardMemoryPlanner(
    BlockRunner* block_runner,
    const BlockInfo& block_info,
    bool enable_out_variant,
    bool manage_output_tensors,
    bool optimize_memory)
    : MemoryPlanner(
          block_runner,
          block_info,
          enable_out_variant,
          manage_output_tensors) {
  const auto& managed_tensor_values = block_info.managed_tensor_values();
  if (enable_out_variant) {
    const auto tensor_value_to_tensor =
        tensorValueToTensor(block_runner->nodes(), managed_tensor_values);
    if (optimize_memory) {
      managed_tensors_ = assignStorageToManagedTensors(
          block_info.node_ptrs(),
          block_info.managed_tensor_ranges(),
          tensor_value_to_tensor);
    } else {
      for (auto& tensor : tensor_value_to_tensor) {
        managed_tensors_.emplace_back(tensor.second);
      }
    }
  }
}

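// Slices the arena returned by allocateBuffer() into one region per storage
// group, using the per-group sizes recorded by the previous call to
// deallocateManagedTensors(), and points each group's shared StorageImpl at
// its region. On the very first iteration managed_bytes_ is still 0, so this
// is a no-op.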
void StandardMemoryPlanner::allocateManagedTensors() {
  if (managed_bytes_ == 0) {
    return;
  }
  DCHECK(!storages_.empty());
  size_t offset = 0;
  auto* start = allocateBuffer(managed_bytes_);

  reused_tensors_ = 0;
  size_t group_idx = 0;
  for (const size_t storages_idx : c10::irange(storages_.size())) {
    auto tensor_size = storages_nbytes_[storages_idx];
    if (tensor_size == 0) {
      group_idx++;
      continue;
    }
    at::StorageImpl* storageImpl = &storages_[storages_idx];
    TORCH_DCHECK_LE(offset + tensor_size, managed_bytes_);
    void* src = static_cast<void*>(start + offset);

#ifndef NDEBUG
    TORCH_DCHECK_EQ(tensor_size, managed_tensors_[group_idx].maxTensorSize());
    for (auto* tensor : managed_tensors_[group_idx].group()) {
      TORCH_DCHECK_EQ(storageImpl, tensor->storage().unsafeGetStorageImpl());
    }
#endif
    TORCH_DCHECK_NE(managed_tensors_[group_idx].numManagedTensors(), 0);
    reused_tensors_ += managed_tensors_[group_idx].numManagedTensors() - 1;
    storageImpl->set_data_ptr_noswap(
        at::DataPtr(src, src, nullptr, c10::Device(c10::DeviceType::CPU)));
    storageImpl->set_nbytes(tensor_size);

    offset += tensor_size;
    group_idx++;
  }
  TORCH_DCHECK_EQ(offset, managed_bytes_);
}

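// Records each storage group's high-water aligned size so that the next
// allocateManagedTensors() call can size the arena, and rebinds any tensor
// whose StorageImpl was swapped out during the run. On the first iteration
// this also builds the ManagedStorages array and binds every tensor in a
// group to the group's shared, non-heap-allocated StorageImpl. The arena
// itself is released afterwards by MemoryPlanner::deallocate().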
void StandardMemoryPlanner::deallocateManagedTensors() {
  managed_bytes_ = 0;
  // free memory used by outputs of ops in out variants
  // but keep the TensorImpl and StorageImpl around.

  // We don't have any guarantee that the model doesn't change the
  // Storage for managed tensors out from under us during execution,
  // so we have to check the Storages each time we deallocate.
  unsigned group_idx = 0;
  const bool first_time = storages_.empty();
  if (C10_UNLIKELY(first_time)) {
    if (storages_.is_allocated()) {
      storages_.deallocate();
    }
    storages_.allocate(managed_tensors_.size());
    storages_nbytes_.reserve(managed_tensors_.size());
  }
  for (auto& ms : managed_tensors_) {
    const auto& tensors = ms.group();
    size_t max = ms.maxTensorSize();
    for (auto& tensor : tensors) {
      const auto& storage = tensor->storage();
      size_t current_size = compute_aligned_tensor_size(storage.nbytes());
      at::StorageImpl* tensorStorageImpl = storage.unsafeGetStorageImpl();
      if (C10_UNLIKELY(first_time)) {
        tensorStorageImpl->reset();

        DCHECK(
            storages_.size() == group_idx || storages_.size() == group_idx + 1);
        if (storages_.size() == group_idx) {
          storages_.append(*tensorStorageImpl);
          storages_nbytes_.emplace_back(0);
        }
        at::StorageImpl* newImpl = &storages_[storages_.size() - 1];

        // We want to manage StorageImpls' lifetimes ourselves, but TensorImpl
        // expects to refcount them. unsafe_adapt_non_heap_allocated is our
        // escape hatch: it sets the reference count for the StorageImpl to an
        // impractically high value so that it will never get deallocated by
        // intrusive_ptr, leaving us free to manage its lifetime as we see fit.
        // (Note that allowing it to be deallocated by intrusive_ptr would be
        // UB, because that would entail deleting an object that wasn't
        // allocated with operator new.)
        //
        // For more information, see the doc comment for
        // intrusive_ptr::unsafe_adapt_non_heap_allocated.
        tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(at::Storage(
            c10::intrusive_ptr<at::StorageImpl>::
                unsafe_adapt_non_heap_allocated(newImpl, tensors.size())));
      } else if (C10_UNLIKELY(tensorStorageImpl != &storages_[group_idx])) {
        tensorStorageImpl->reset();

        // If somehow the tensor got a different storage, put it back to
        // the shared impl for this group.
        tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(
            at::Storage(c10::intrusive_ptr<at::StorageImpl>::
                            unsafe_adapt_non_heap_allocated(
                                &storages_[group_idx], tensors.size())));
      }
      TORCH_DCHECK_EQ(
          tensor->storage().unsafeGetStorageImpl(), &storages_[group_idx]);
      max = std::max(max, current_size);
    }
    // Static runtime does not know tensor sizes statically, so we use each
    // tensor's size from the previous run to allocate storage for the next
    // run (following C2 tradition), exploiting the fact that a tensor's
    // storage size does not have to match its actual size. The following
    // logic records the tensor storage size for the next run.
    storages_nbytes_[group_idx++] = max;
    ms.setMaxTensorSize(max);
    managed_bytes_ += max;
  }

  TORCH_DCHECK_EQ(storages_.size(), managed_tensors_.size());
  VLOG(1) << "managed_bytes: " << managed_bytes_;
}

} // namespace torch::jit