// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/MemoryOverlap.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/MemoryOverlap.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/Layout.h>
#include <c10/util/irange.h>

namespace at {

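// These helpers classify memory overlap for strided tensors. MemOverlap
// describes self-overlap within a single tensor (can two distinct indices
// refer to the same memory location?); MemOverlapStatus describes overlap
// between two tensors. TooHard means the analysis could not decide either
// way and conservatively declined to prove anything.
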
MemOverlap has_internal_overlap(const TensorBase& tensor) {
  return has_internal_overlap(tensor.unsafeGetTensorImpl());
}

MemOverlap has_internal_overlap(TensorImpl* t) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t->layout() == kStrided);

  if (t->is_non_overlapping_and_dense()) {
    return MemOverlap::No;
  }

  auto strides = t->sym_strides();
  auto sizes = t->sym_sizes();
  for (const auto i : c10::irange(strides.size())) {
    // NB: The size-oblivious test is written very carefully here. When
    // unbacked SymInts are involved, we should conservatively report that
    // memory overlap /could/ happen under some setting of the unbacked
    // SymInts. Thus, a dimension whose size is an unbacked u0 must be
    // assumed to hold more than one element (first expression), but a
    // stride that is an unbacked u0 must NOT be assumed to be nonzero
    // (second expression).
    if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[i].sym_gt(1)) && strides[i] == 0) {
      return MemOverlap::Yes;
    }
  }

  return MemOverlap::TooHard;
}
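
// Illustrative only (not part of the original file): a stride-0 view is the
// canonical case this detects. With the public ATen API one might see
//
//   auto v = at::zeros({1, 3}).expand({4, 3});  // dim 0: size 4, stride 0
//   at::has_internal_overlap(v);                // -> MemOverlap::Yes
//
// whereas a contiguous tensor is non-overlapping-and-dense and returns
// MemOverlap::No without inspecting strides at all.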

void assert_no_internal_overlap(const TensorBase& t) {
  assert_no_internal_overlap(t.unsafeGetTensorImpl());
}

void assert_no_internal_overlap(TensorImpl* t) {
  TORCH_CHECK(has_internal_overlap(t) != MemOverlap::Yes,
    "unsupported operation: more than one element of the written-to tensor "
    "refers to a single memory location. Please clone() the tensor before "
    "performing the operation.");
}
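
// Hedged usage sketch (not in the original file): in-place kernels check
// their output with this before writing, so for example
//
//   auto v = at::ones({1, 3}).expand({4, 3});
//   v.add_(1);           // throws: "more than one element of the written-to
//                        //  tensor refers to a single memory location"
//   v.clone().add_(1);   // fine: clone() materializes a dense copy first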

MemOverlapStatus get_overlap_status(const TensorBase& a, const TensorBase& b) {
  return get_overlap_status(a.unsafeGetTensorImpl(), b.unsafeGetTensorImpl());
}

MemOverlapStatus get_overlap_status(const TensorImpl* a, const TensorImpl* b) {
  if (a == b) {
    return MemOverlapStatus::Full;
  }
  if (a->numel() == 0 || b->numel() == 0) {
    return MemOverlapStatus::No;
  }
  if (!a->is_non_overlapping_and_dense() || !b->is_non_overlapping_and_dense()) {
    return MemOverlapStatus::TooHard;
  }
  // Test for storage equality, rather than pointer equality.
  // This reduces precision, but if people are aliasing the
  // same pointer across multiple storages there are many
  // similar situations (e.g., storage().data() == storage().data()+1)
  // which we will miss.
  auto a_storage = a->unsafe_storage();
  if (a_storage && a_storage.is_alias_of(b->unsafe_storage())) {
    const auto a_begin = static_cast<const char*>(a->data());
    const auto a_end = a_begin + a->numel() * a->itemsize();
    const auto b_begin = static_cast<const char*>(b->data());
    const auto b_end = b_begin + b->numel() * b->itemsize();

    if (a_begin == b_begin && a_end == b_end) {
      return (a->strides() == b->strides()) ?
          MemOverlapStatus::Full : MemOverlapStatus::Partial;
    }
    if (a_begin < b_end && b_begin < a_end) {
      return MemOverlapStatus::Partial;
    }
  }
  return MemOverlapStatus::No;
}
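
// Illustrative only (not in the original file): two slices of one storage
// whose byte ranges intersect report Partial, e.g.
//
//   auto base = at::arange(10);
//   auto a = base.slice(/*dim=*/0, /*start=*/0, /*end=*/6);   // covers [0, 6)
//   auto b = base.slice(/*dim=*/0, /*start=*/4, /*end=*/10);  // covers [4, 10)
//   at::get_overlap_status(a, b);  // -> MemOverlapStatus::Partial
//
// Identical byte ranges with equal strides report Full; disjoint slices of
// the same storage report No.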

void assert_no_partial_overlap(const TensorBase& a, const TensorBase& b) {
  assert_no_partial_overlap(a.unsafeGetTensorImpl(), b.unsafeGetTensorImpl());
}

void assert_no_partial_overlap(TensorImpl* a, TensorImpl* b) {
  TORCH_CHECK(get_overlap_status(a, b) != MemOverlapStatus::Partial,
    "unsupported operation: some elements of the input tensor and "
    "the written-to tensor refer to a single memory location. "
    "Please clone() the tensor before performing the operation.");
}

void assert_no_overlap(const TensorBase& a, const TensorBase& b) {
  assert_no_overlap(a.unsafeGetTensorImpl(), b.unsafeGetTensorImpl());
}

void assert_no_overlap(TensorImpl* a, TensorImpl* b) {
  const auto lap = get_overlap_status(a, b);
  TORCH_CHECK(lap != MemOverlapStatus::Partial && lap != MemOverlapStatus::Full,
    "unsupported operation: some elements of the input tensor and "
    "the written-to tensor refer to a single memory location. "
    "Please clone() the tensor before performing the operation.");
}
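
// Sketch of the distinction (not in the original file): assert_no_overlap
// also rejects Full overlap, which suits out= style ops whose output must
// not alias an input at all, while assert_no_partial_overlap tolerates the
// exactly-aliased (Full) case that in-place ops rely on. Hypothetically:
//
//   auto t = at::zeros({4});
//   at::assert_no_partial_overlap(t, t);  // OK: overlap is Full, not Partial
//   at::assert_no_overlap(t, t);          // throws: Full overlap is rejected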

} // namespace at