xref: /aosp_15_r20/external/pytorch/aten/src/ATen/MemoryOverlap.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Export.h>
4 
5 namespace c10 {
6 struct TensorImpl; // forward declaration only; this header refers to TensorImpl solely through pointers
7 } // namespace c10
8 
9 namespace at {
10 class TensorBase;
11 
12 // MemOverlap: the answer to "does this memory overlap?"
13 //
14 // No: There is definitely no memory overlap.
15 // Yes: There definitely is memory overlap.
16 // TooHard: There might be memory overlap, but it was too expensive to compute.
17 //
18 // NB: Please update the python test for these if you renumber them.
19 enum class MemOverlap { No, Yes, TooHard };
20 
21 // MemOverlapStatus: the result type of get_overlap_status() below.
22 //
23 // NOTE(review): the enumerators are not described in this header. Going by
24 // their names, Full / Partial / No say how much of the two tensors' memory
25 // coincides and TooHard plays the same role as MemOverlap::TooHard --
26 // confirm against the implementation.
27 enum class MemOverlapStatus { Full, Partial, No, TooHard };
28 
29 // Reports, as a MemOverlap value, the overlap found within a single tensor.
30 // NOTE(review): "internal" presumably means several elements of `t` sharing
31 // one memory location -- confirm against the implementation.
32 TORCH_API MemOverlap has_internal_overlap(const TensorBase& t);
33 TORCH_API MemOverlap has_internal_overlap(c10::TensorImpl* t);
34 
35 // Assertion form of the internal-overlap check; returns nothing.
36 // NOTE(review): how a failed assertion is signalled is decided by the
37 // implementation and is not visible from this header.
38 TORCH_API void assert_no_internal_overlap(const TensorBase& t);
39 TORCH_API void assert_no_internal_overlap(c10::TensorImpl* t);
40 
41 // Classifies the overlap between two tensors `a` and `b` as a
42 // MemOverlapStatus. Note that the TensorImpl overload takes pointers to
43 // const, unlike the other TensorImpl overloads in this header.
44 TORCH_API MemOverlapStatus
45 get_overlap_status(const TensorBase& a, const TensorBase& b);
46 TORCH_API MemOverlapStatus
47 get_overlap_status(const c10::TensorImpl* a, const c10::TensorImpl* b);
48 
49 // Assertion on a pair of tensors; returns nothing.
50 // NOTE(review): by name this rejects only a partial overlap between `a` and
51 // `b` -- confirm which MemOverlapStatus values are accepted.
52 // Both overloads carry TORCH_API, matching every other function declared in
53 // this header.
54 TORCH_API void assert_no_partial_overlap(
55     const TensorBase& a,
56     const TensorBase& b);
57 TORCH_API void assert_no_partial_overlap(
58     c10::TensorImpl* a,
59     c10::TensorImpl* b);
60 
61 // Assertion on a pair of tensors; returns nothing.
62 // NOTE(review): by name this rejects any overlap between `a` and `b` --
63 // confirm how a TooHard result is treated.
64 TORCH_API void assert_no_overlap(const TensorBase& a, const TensorBase& b);
65 TORCH_API void assert_no_overlap(c10::TensorImpl* a, c10::TensorImpl* b);
66 
67 } // namespace at
68