#include <ATen/core/Tensor.h>
#include <ATen/core/jit_type.h>
#include <c10/core/GradMode.h>
#include <c10/util/irange.h>

#include <numeric>
#include <utility>

namespace c10 {

namespace {

// The idea is to only mark possible overlap across dimensions. We want to
// return false for expanded tensors and permuted tensors, for which dimensional
// collapsing is safe.
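//
// For example (illustrative values, not from any particular workload):
//   sizes [2, 3], strides [3, 1] (contiguous)       -> no possible overlap
//   sizes [2, 3], strides [1, 2] (permuted)         -> no possible overlap
//   sizes [2, 3], strides [1, 1] (elements aliased) -> possible overlap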
bool possible_cross_dimension_overlap(c10::IntArrayRef sizes, c10::IntArrayRef strides) {
  int n_dim = static_cast<int>(sizes.size());
  std::vector<size_t> stride_indices(n_dim);
  std::iota(stride_indices.rbegin(), stride_indices.rend(), 0);
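  // stride_indices now holds [n_dim - 1, ..., 1, 0], i.e. dimensions listed
  // from last to first, which for typical layouts is already roughly
  // ascending in stride before the insertion sort below refines the order.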

  // sort indices by ascending stride (insertion sort)
  for (int i = 1; i < n_dim; i++) {
    auto c = i;
    for (int j = i - 1; j >= 0; j--) {
      if (strides[stride_indices[j]] > strides[stride_indices[c]]) {
        std::swap(stride_indices[j], stride_indices[c]);
        c = j;
      }
    }
  }

  for (const auto i : c10::irange(1, n_dim)) {
    // we are being conservative on checking for memory overlap
    if (sizes[stride_indices[i]] != 1 &&
        strides[stride_indices[i]] <
            sizes[stride_indices[i - 1]] * strides[stride_indices[i - 1]]) {
      return true;
    }
  }
  return false;
}

} // namespace

const TensorTypePtr& TensorType::get() {
  static auto value = TensorType::create(
      {}, {}, SymbolicShape(), VaryingShape<Stride>{}, {});
  return value;
}

ListTypePtr ListType::ofTensors() {
  static auto value = ListType::create(TensorType::get());
  return value;
}

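// Merges two VaryingShapes elementwise: if either side is unranked or the
// ranks differ, the result is unranked; otherwise each dimension keeps its
// value only when both sides agree and becomes unknown otherwise. For
// example, merging (2, *, 4) with (2, 3, 4) yields (2, *, 4).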
template <typename T>
VaryingShape<T> VaryingShape<T>::merge(const VaryingShape<T>& other) const {
  if (!dims_ || !other.dims_ || dims_->size() != other.dims_->size()) {
    return VaryingShape<T>();
  }
  ListOfOptionalElements dims;
  for (size_t i = 0, n = dims_->size(); i < n; i++) {
    dims.push_back(merge_primitive((*dims_)[i], (*other.dims_)[i]));
  }
  return VaryingShape<T>(std::move(dims));
}

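// Prints a VaryingShape in tuple form, writing "*" for unknown entries and
// "(*)" for a fully unranked shape; e.g. a rank-3 shape with an unknown
// middle dimension prints as "(2, *, 4)".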
template <typename T>
std::ostream& operator<<(std::ostream& out, const VaryingShape<T>& vs) {
  out << "(";
  if (!vs.size()) {
    out << "*)";
    return out;
  }

  for (size_t i = 0; i < vs.size(); i++) {
    if (i > 0) {
      out << ", ";
    }
    if (vs[i].has_value()) {
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      out << vs[i].value();
    } else {
      out << "*";
    }
  }
  out << ")";
  return out;
}

template std::ostream& operator<<(
    std::ostream& out,
    const VaryingShape<int64_t>& vs);
template std::ostream& operator<<(
    std::ostream& out,
    const VaryingShape<Stride>& vs);

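// Prints a SymbolicShape the same way: static dimensions appear as numbers,
// symbolic dimensions as "*", and an unranked shape as "(*)"; e.g. a rank-2
// shape whose first dimension is a static 3 prints as "(3, *)".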
std::ostream& operator<<(
    std::ostream& os,
    const SymbolicShape& ss) {
  // TODO: Unranked SymbolicShape printing is ambiguous with that of
  // dynamic-shaped vector.
  if (!ss.rank()) {
    os << "(*)";
    return os;
  }

  auto sizes = ss.sizes().value();

  os << "(";
  for (size_t i = 0; i < ss.rank().value(); i++) {
    if (i > 0) {
      os << ", ";
    }
    if (sizes[i].is_static()) {
      os << sizes[i];
    } else {
      os << "*";
    }
  }
  os << ")";

  return os;
}

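// Prints a ShapeSymbol: non-negative values are concrete (static) sizes and
// print directly; negative values denote distinct symbolic dimensions and
// print as "SS(<value>)", e.g. "SS(-2)".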
std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s) {
  if (s.value_ >= 0) {
    os << s.value_;
  } else {
    os << "SS(" << s.value_ << ')';
  }
  return os;
}

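// Prints a Stride as "{index:stride}", with "*" for either field when it is
// unset; e.g. the innermost, unit-stride dimension of a contiguous 3-d
// tensor prints as "{2:1}". Note the contiguity flag is not printed.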
std::ostream& operator<<(std::ostream& os, const Stride& s) {
  os << "{";
  if (s.stride_index_.has_value()) {
    os << *s.stride_index_;
  } else {
    os << "*";
  }
  os << ":";
  if (s.stride_.has_value()) {
    os << *s.stride_;
  } else {
    os << "*";
  }
  os << '}';
  return os;
}

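// Computes per-dimension stride properties ordered from innermost (smallest
// stride) to outermost. Each Stride records the original dimension index,
// whether that dimension is contiguous with respect to the next-inner one,
// and the stride value. For example, for a contiguous tensor with
// sizes [8, 10, 16] and strides [160, 16, 1], the result prints as
// ({2:1}, {1:16}, {0:160}), with every entry marked contiguous.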
VaryingShape<Stride> TensorType::computeStrideProps(
    at::IntArrayRef sizes,
    at::IntArrayRef strides,
    bool tensor_contiguity) {
  int n_dim = static_cast<int>(sizes.size());
  std::vector<size_t> stride_indices(n_dim);
  // default has_overlap to false as we only compute overlap when:
  // 1. input sizes/strides fail the format check;
  // 2. tensor_contiguity is not set.
  bool has_overlap = false;

  // Sorting strides in ascending order
  // Example:
  //  Prior to sorting
  //  Idx:     [0,   1, 2,  3]
  //  sizes:   [8,   1, 10, 16]
  //  Strides: [160, 1, 16, 1]
  //  After sorting
  //  Idx:     [1, 3, 2,  0]
  //  sizes:   [1, 16, 10, 8]
  //  Strides: [1, 1, 16, 160]
  //
  // The logic below follows what TensorIterator uses:
  // 1. fast_set_up is the short-cut to identify a. channels_last and
  //    b. contiguous format, which the first two branches below handle.
  // 2. In more general cases, it makes a best effort to preserve the
  //    permutation.
  if (is_channels_last_strides_2d(sizes, strides) || is_channels_last_strides_3d(sizes, strides)) {
    // case 1.a. short cut channels last
    std::iota(stride_indices.rbegin() + 1, stride_indices.rend() - 1, 2);
    stride_indices[0] = 1;
    stride_indices[n_dim - 1] = 0;
  } else if (is_contiguous_strides(sizes, strides)) {
    // case 1.b. short cut contiguous
    std::iota(stride_indices.rbegin(), stride_indices.rend(), 0);
  } else {
    std::iota(stride_indices.rbegin(), stride_indices.rend(), 0);
    // case 2.
    //
    // For a broadcasted dimension where the stride is 0, we have to stick to
    // TensorIterator's eager behavior, which introduces an ambiguous
    // comparison result to preserve the permutation on a best-effort basis.
    // For more details, see NOTE: [Computing output strides]
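    //
    // should_swap returns 1 when the two dims must be swapped (the left one
    // has the strictly larger stride, or the strides tie and its size is
    // larger), -1 when the order is already strictly correct, and 0 when the
    // comparison is ambiguous (a zero stride, or a full tie); 0 keeps the
    // insertion sort scanning outward instead of committing either way.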
    auto should_swap = [&](size_t a, size_t b) {
      if (strides[a] == 0 || strides[b] == 0) {
        return 0;
      } else if (strides[a] < strides[b]) {
        return -1;
      } else if (strides[a] > strides[b]) {
        return 1;
      } else { // strides[a] == strides[b]
        if (sizes[a] > sizes[b]) {
          return 1;
        }
      }
      return 0;
    };
    for (int i = 1; i < n_dim; i++) {
      int dim1 = i;
      for (int dim0 = i - 1; dim0 >= 0; dim0--) {
        int comparison = should_swap(stride_indices[dim0], stride_indices[dim1]);
        if (comparison > 0) {
          std::swap(stride_indices[dim0], stride_indices[dim1]);
          dim1 = dim0;
        } else if (comparison < 0) {
          break;
        }
      }
    }
    // conveniently, is_contiguous_strides/is_channels_last_strides only
    // return true when there's no memory overlap, so we only re-compute
    // has_overlap in this last branch when both return false
    if (!tensor_contiguity) {
      // trust tensor_contiguity and only compute overlap when it is not set
      has_overlap = possible_cross_dimension_overlap(sizes, strides);
    }
  }

  std::vector<Stride> stride_properties;
  stride_properties.reserve(stride_indices.size());
  for (size_t i = 0; i < stride_indices.size(); i++) {
    bool contiguous_ = tensor_contiguity;
    if (!contiguous_) {
      if (!has_overlap) {
        // innermost stride expected to be 1
        // TODO: turn contiguous_ into an enum CONTIGUOUS, NONCONTIGUOUS,
        // BROADCASTED
        if (i == 0) {
          contiguous_ = strides[stride_indices[i]] == 1;
        } else {
          contiguous_ = strides[stride_indices[i]] == 1 ||
              (strides[stride_indices[i]] != 0 &&
               strides[stride_indices[i]] ==
                   strides[stride_indices[i - 1]] * sizes[stride_indices[i - 1]]);
        }
      } else {
        // assignment kept for readability
        contiguous_ = false;
      }
    }
    stride_properties.emplace_back(stride_indices[i], contiguous_, strides[stride_indices[i]]);
  }

  return VaryingShape<Stride>{stride_properties};
}

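// Builds a TensorType from a concrete tensor. Strided, non-nested tensors
// contribute their exact sizes and strides; any other layout falls back to
// an unranked symbolic shape with unknown strides.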
TensorTypePtr TensorType::create(const at::Tensor& t) {
  if (t.layout() == at::kStrided && !t.is_nested()) {
    auto sizes = VaryingShape<int64_t>{t.sizes().vec()};
    auto strides = VaryingShape<int64_t>{t.strides().vec()};
    return TensorType::create(
        t.scalar_type(), t.device(), sizes, strides, t.requires_grad(), false, t.is_contiguous());
  }

  return TensorType::create(
      t.scalar_type(),
      t.device(),
      SymbolicShape(),
      VaryingShape<Stride>{},
      t.requires_grad(),
      false);
}

TensorTypePtr TensorType::create(
    std::optional<at::ScalarType> scalar_type,
    std::optional<Device> device,
    const VaryingShape<int64_t>& sizes,
    const VaryingShape<int64_t>& strides,
    std::optional<bool> requires_grad,
    std::optional<bool> undefined,
    bool tensor_contiguity) {
  if (strides.concrete_sizes().has_value()) {
    // handles case where strides are set
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    TORCH_INTERNAL_ASSERT(sizes.concrete_sizes()->size() == strides.concrete_sizes()->size());
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    auto sprops = computeStrideProps(*sizes.concrete_sizes(), *strides.concrete_sizes(), tensor_contiguity);
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    auto symbol_sizes = SymbolicShape(*sizes.concrete_sizes());
    return TensorType::create(
        scalar_type, device, symbol_sizes, sprops, requires_grad, undefined);
  } else {
    // strides are all null, but the number of strides still equals the rank
    TORCH_INTERNAL_ASSERT(sizes.sizes() && sizes.size());
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    auto symbol_sizes = SymbolicShape(*sizes.sizes());
    return TensorType::create(
        // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
        scalar_type, device, symbol_sizes, VaryingShape<Stride>(*sizes.size()), requires_grad, undefined);
  }
}

TensorTypePtr TensorType::create(
    std::optional<at::ScalarType> scalar_type,
    std::optional<Device> device,
    const SymbolicShape& sizes,
    const VaryingShape<Stride>& strides,
    std::optional<bool> requires_grad,
    std::optional<bool> undefined) {
  auto pt = TensorTypePtr(new TensorType(
      scalar_type, device, sizes, strides, requires_grad, undefined));
  return pt;
}

TensorTypePtr TensorType::create(
    std::optional<at::ScalarType> scalar_type,
    std::optional<Device> device,
    std::optional<size_t> dim,
    std::optional<bool> requires_grad) {
  return TensorType::create(
      scalar_type,
      device,
      SymbolicShape(dim),
      VaryingShape<Stride>(dim),
      requires_grad);
}

std::string TensorType::str() const {
  return "Tensor";
}

std::atomic<size_t> ShapeSymbol::num_symbols{1};

template struct VaryingShape<c10::ShapeSymbol>;
template struct VaryingShape<bool>;
template struct VaryingShape<size_t>;
template struct VaryingShape<int64_t>;
template struct VaryingShape<c10::Stride>;

VaryingShape<int64_t> TensorType::sizes() const {
  if (!sizes_.rank()) {
    return VaryingShape<int64_t>();
  }
  return VaryingShape<int64_t>(
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      fmap(*sizes_.sizes(), [](ShapeSymbol ss) {
        // we turn symbolic shapes into unknowns
        return ss.is_static()
            ? std::optional<int64_t>(ss.static_size())
            : std::nullopt;
      }));
}

TensorTypePtr TensorType::merge(const TensorType& other, bool merge_sizes) const {
  auto scalar_type = merge_primitive(scalarType(), other.scalarType());
  auto dev = merge_primitive(device(), other.device());
  auto sprops = stride_properties().merge(other.stride_properties());
  auto gr = merge_primitive(requiresGrad(), other.requiresGrad());
  auto undef = merge_primitive(undefined(), other.undefined());
  return TensorType::create(
      scalar_type,
      dev,
      merge_sizes ? symbolic_sizes().merge(other.symbolic_sizes())
                  : symbolic_sizes(),
      sprops,
      gr,
      undef);
}

template <typename T>
bool is_null_or_equal(std::optional<T> a, c10::IntArrayRef b) {
  return !a.has_value() || a.value() == b;
}

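// Checks whether a concrete tensor satisfies every constraint this type
// carries; unset (nullopt) properties match anything. For example, a
// TensorType with an unset scalar type matches tensors of any dtype.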
bool TensorType::matchTensor(const at::Tensor& t) {
  bool undef = undefined().value_or(!t.defined());
  if (undef != !t.defined()) {
    // When the following are true, we consider it's not a match:
    // - undefined().has_value() == true
    // - undefined().value() != !t.defined()
    return false;
  } else if (!t.defined()) {
    // When the following are true, we consider it's a match:
    // - t is not defined
    // - undefined() == null or undefined().value() == true
    return true;
  }
  // Here we know t.defined() == true and compare all other properties.
  bool rg = at::GradMode::is_enabled() && t.requires_grad();
  bool matched_strides = (!stride_properties().size()) ||
      (!t.has_storage() && !stride_properties().isComplete()) ||
      stride_properties() ==
          computeStrideProps(t.sizes(), t.strides(), t.is_contiguous());
  return scalarType().value_or(t.scalar_type()) == t.scalar_type()
      && device().value_or(t.device()) == t.device()
      && requiresGrad().value_or(rg) == rg
      && matched_strides
      && is_null_or_equal(sizes().concrete_sizes(), t.sizes());
}

bool TensorType::equals(const c10::Type& rhs) const {
  if (rhs.kind() != kind()) {
    return false;
  }
  auto rt = rhs.expect<TensorType>();

  return scalar_type_ == rt->scalarType() && sizes() == rt->sizes() &&
      stride_properties() == rt->stride_properties() &&
      device() == rt->device() && requiresGrad() == rt->requiresGrad() &&
      undefined() == rt->undefined();
}

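// Reconstructs per-dimension strides from the stored stride properties,
// scattering each known Stride back to its original dimension index and
// leaving unknown entries as nullopt.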
VaryingShape<int64_t> TensorType::strides() const {
  if (!strides_.size().has_value()) {
    return VaryingShape<int64_t>();
  }
  std::vector<std::optional<int64_t>> ss(*strides_.size());
  for (size_t i = 0; i < *strides_.size(); i++) {
    if (!strides_[i].has_value()) {
      continue;
    }
    auto s = *strides_[i];
    if (s.stride_index_.has_value() && s.stride_.has_value()) {
      ss[*s.stride_index_] = *s.stride_;
    }
  }
  return VaryingShape<int64_t>(std::move(ss));
}

TensorType::TensorType(
    std::optional<at::ScalarType> scalar_type,
    std::optional<Device> device,
    SymbolicShape sizes,
    VaryingShape<Stride> strides,
    std::optional<bool> requires_grad,
    std::optional<bool> undefined)
    : SharedType(TypeKind::TensorType),
      scalar_type_(scalar_type),
      device_(device),
      sizes_(std::move(sizes)),
      strides_(std::move(strides)),
      requires_grad_(requires_grad),
      undefined_(undefined) {}

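// Creates a fully-specified type for a contiguous tensor; the strides follow
// the usual row-major rule, e.g. sizes [2, 3, 4] yield strides [12, 4, 1].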
TensorTypePtr TensorType::createContiguous(
    at::ScalarType scalar_type,
    at::Device device,
    at::IntArrayRef sizes) {
  auto strides = contiguousStridesOf(sizes);
  TORCH_INTERNAL_ASSERT(strides.size() == sizes.size());
  return create(
      scalar_type,
      device,
      VaryingShape<int64_t>(sizes),
      VaryingShape<int64_t>(strides),
      std::nullopt);
}

const SymbolicShape& TensorType::symbolic_sizes() const {
  return sizes_;
}

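// Subtyping via merge: this type is a subtype of rhs when merging the two
// reproduces rhs exactly, i.e. rhs is at least as general in every property.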
bool TensorType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
  if (auto rhs_p = rhs.cast<TensorType>()) {
    // if we have the same pointer, avoid computing the merge
    if (this == rhs_p.get()) {
      return true;
    }
    return *merge(*rhs_p) == *rhs_p;
  }
  return Type::isSubtypeOfExt(rhs, why_not);
}

} // namespace c10