#include <ATen/MemoryOverlap.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/Layout.h>
#include <c10/util/irange.h>

namespace at {

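// Checks whether distinct indices of `tensor` can refer to the same memory
// location. Returns MemOverlap::No when the tensor is known to be
// non-overlapping and dense, MemOverlap::Yes when an overlap is definitely
// present, and MemOverlap::TooHard when it cannot be decided cheaply.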
MemOverlap has_internal_overlap(const TensorBase& tensor) {
  return has_internal_overlap(tensor.unsafeGetTensorImpl());
}

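// A strided tensor is known to overlap with itself when some dimension has
// size > 1 but stride 0 (e.g. the result of expand()), since every index
// along that dimension maps to the same element.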
MemOverlap has_internal_overlap(TensorImpl* t) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t->layout() == kStrided);

  if (t->is_non_overlapping_and_dense()) {
    return MemOverlap::No;
  }

  auto strides = t->sym_strides();
  auto sizes = t->sym_sizes();
  for (const auto i : c10::irange(strides.size())) {
    // NB: The size-oblivious test is written very carefully here. When
    // unbacked SymInts are involved, we should conservatively report that
    // memory overlap /could/ happen under some setting of the unbacked
    // SymInts. Thus, given a u0 size we should assume it may have > 1
    // elements (first expression), but given a u0 stride we should NOT
    // assume that it is not zero (second expression).
    if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[i].sym_gt(1)) && strides[i] == 0) {
      return MemOverlap::Yes;
    }
  }

  return MemOverlap::TooHard;
}

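// Throws if `t` definitely overlaps internally; the conservative
// MemOverlap::TooHard case is allowed through.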
void assert_no_internal_overlap(const TensorBase& t) {
  assert_no_internal_overlap(t.unsafeGetTensorImpl());
}

void assert_no_internal_overlap(TensorImpl* t) {
  TORCH_CHECK(has_internal_overlap(t) != MemOverlap::Yes,
    "unsupported operation: more than one element of the written-to tensor "
    "refers to a single memory location. Please clone() the tensor before "
    "performing the operation.");
}

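// Describes how the memory of `a` and `b` relates: Full when they are the
// same tensor or cover exactly the same bytes with identical strides,
// Partial when their byte ranges otherwise intersect, TooHard when either
// tensor is not non-overlapping-and-dense, and No otherwise (empty tensors,
// unrelated storages, or disjoint ranges).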
MemOverlapStatus get_overlap_status(const TensorBase& a, const TensorBase& b) {
  return get_overlap_status(a.unsafeGetTensorImpl(), b.unsafeGetTensorImpl());
}

MemOverlapStatus get_overlap_status(const TensorImpl* a, const TensorImpl* b) {
  if (a == b) {
    return MemOverlapStatus::Full;
  }
  if (a->numel() == 0 || b->numel() == 0) {
    return MemOverlapStatus::No;
  }
  if (!a->is_non_overlapping_and_dense() || !b->is_non_overlapping_and_dense()) {
    return MemOverlapStatus::TooHard;
  }
  // Test for storage equality, rather than pointer equality.
  // This reduces precision, but if people are aliasing the
  // same pointer across multiple storages there are many
  // similar situations (e.g., storage().data() == storage().data()+1)
  // which we will miss.
  auto a_storage = a->unsafe_storage();
  if (a_storage && a_storage.is_alias_of(b->unsafe_storage())) {
    const auto a_begin = static_cast<const char*>(a->data());
    const auto a_end = a_begin + a->numel() * a->itemsize();
    const auto b_begin = static_cast<const char*>(b->data());
    const auto b_end = b_begin + b->numel() * b->itemsize();

    if (a_begin == b_begin && a_end == b_end) {
      return (a->strides() == b->strides()) ?
          MemOverlapStatus::Full : MemOverlapStatus::Partial;
    }
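    // Half-open interval test: [a_begin, a_end) and [b_begin, b_end)
    // intersect iff each range starts before the other one ends.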
    if (a_begin < b_end && b_begin < a_end) {
      return MemOverlapStatus::Partial;
    }
  }
  return MemOverlapStatus::No;
}

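// Throws only on Partial overlap; Full overlap (writing a tensor onto
// itself) is permitted here.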
void assert_no_partial_overlap(const TensorBase& a, const TensorBase& b) {
  assert_no_partial_overlap(a.unsafeGetTensorImpl(), b.unsafeGetTensorImpl());
}

void assert_no_partial_overlap(TensorImpl* a, TensorImpl* b) {
  TORCH_CHECK(get_overlap_status(a, b) != MemOverlapStatus::Partial,
    "unsupported operation: some elements of the input tensor and "
    "the written-to tensor refer to a single memory location. "
    "Please clone() the tensor before performing the operation.");
}

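// Throws on both Partial and Full overlap between `a` and `b`.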
void assert_no_overlap(const TensorBase& a, const TensorBase& b) {
  assert_no_overlap(a.unsafeGetTensorImpl(), b.unsafeGetTensorImpl());
}

void assert_no_overlap(TensorImpl* a, TensorImpl* b) {
  const auto lap = get_overlap_status(a, b);
  TORCH_CHECK(lap != MemOverlapStatus::Partial && lap != MemOverlapStatus::Full,
    "unsupported operation: some elements of the input tensor and "
    "the written-to tensor refer to a single memory location. "
    "Please clone() the tensor before performing the operation.");
}

} // namespace at