// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/tensor_type.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/core/Tensor.h>
#include <ATen/core/jit_type.h>
#include <c10/core/GradMode.h>

#include <numeric>
#include <utility>
#include <vector>

namespace c10 {

namespace {

// The idea is to only mark possible overlap across dimensions. We want to
// return false for expanded tensors and permuted tensors, for which
// dimensional collapsing is safe.
bool possible_cross_dimension_overlap(c10::IntArrayRef sizes, c10::IntArrayRef strides) {
  int n_dim = static_cast<int>(sizes.size());
  std::vector<size_t> stride_indices(n_dim);
  std::iota(stride_indices.rbegin(), stride_indices.rend(), 0);

  // sort indices so that strides are in ascending order (insertion sort)
  for (int i = 1; i < n_dim; i++) {
    auto c = i;
    for (int j = i - 1; j >= 0; j--) {
      if (strides[stride_indices[j]] > strides[stride_indices[c]]) {
        std::swap(stride_indices[j], stride_indices[c]);
        c = j;
      }
    }
  }

  for (const auto i : c10::irange(1, n_dim)) {
    // we are conservative when checking for memory overlap
    if (sizes[stride_indices[i]] != 1 && strides[stride_indices[i]] < sizes[stride_indices[i-1]] * strides[stride_indices[i-1]]) {
      return true;
    }
  }
  return false;
}
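
// Illustrative behavior (not exhaustive), given the ascending-stride sort above:
//   expanded tensor:   sizes [4, 4], strides [0, 1] -> false (the broadcast dim
//                      contributes size * stride == 0, so collapsing is safe)
//   overlapping view:  sizes [2, 2], strides [1, 1] -> true (the outer dim
//                      steps inside the inner dim's extent)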

} // namespace

const TensorTypePtr& TensorType::get() {
  static auto value = TensorType::create(
      {}, {}, SymbolicShape(), VaryingShape<Stride>{}, {});
  return value;
}

ListTypePtr ListType::ofTensors() {
  static auto value = ListType::create(TensorType::get());
  return value;
}

template <typename T>
VaryingShape<T> VaryingShape<T>::merge(const VaryingShape<T>& other) const {
  if (!dims_ || !other.dims_ || dims_->size() != other.dims_->size()) {
    return VaryingShape<T>();
  }
  ListOfOptionalElements dims;
  for (size_t i = 0, n = dims_->size(); i < n; i++) {
    dims.push_back(merge_primitive((*dims_)[i], (*other.dims_)[i]));
  }
  return VaryingShape<T>(std::move(dims));
}
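
// Rough example, assuming merge_primitive keeps a value only when both sides
// agree (and returns nullopt otherwise):
//   merge((2, 3), (2, *))    -> (2, *)
//   merge((2, 3), (2, 4))    -> (2, *)
//   merge((2, 3), (2, 3, 1)) -> unranked (*), since the ranks differ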

template <typename T>
std::ostream& operator<<(std::ostream& out, const VaryingShape<T>& vs) {
  out << "(";
  if (!vs.size()) {
    out << "*)";
    return out;
  }

  for (size_t i = 0; i < vs.size(); i++) {
    if (i > 0) {
      out << ", ";
    }
    if (vs[i].has_value()) {
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      out << vs[i].value();
    } else {
      out << "*";
    }
  }
  out << ")";
  return out;
}
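
// Printing sketch: an unranked shape prints as "(*)"; a ranked shape prints
// each known element and "*" for unknown ones, e.g. (2, *, 5).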

template std::ostream& operator<<(
    std::ostream& out,
    const VaryingShape<int64_t>& vs);
template std::ostream& operator<<(
    std::ostream& out,
    const VaryingShape<Stride>& vs);

std::ostream& operator<<(
    std::ostream& os,
    const SymbolicShape& ss) {
  // TODO: Unranked SymbolicShape printing is ambiguous with that of
  // dynamic-shaped vector.
  if (!ss.rank()) {
    os << "(*)";
    return os;
  }

  auto sizes = ss.sizes().value();

  os << "(";
  for (size_t i = 0; i < ss.rank().value(); i++) {
    if (i > 0) {
      os << ", ";
    }
    if (sizes[i].is_static()) {
      os << sizes[i];
    } else {
      os << "*";
    }
  }
  os << ")";

  return os;
}
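
// Printing sketch: static dimensions print their size and symbolic dimensions
// print as "*", e.g. a rank-2 shape with one known dim prints as (4, *).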

std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s) {
  if (s.value_ >= 0) {
    os << s.value_;
  } else {
    os << "SS(" << s.value_ << ')';
  }
  return os;
}

std::ostream& operator<<(std::ostream& os, const Stride& s) {
  os << "{";
  if (s.stride_index_.has_value()) {
    os << *s.stride_index_;
  } else {
    os << "*";
  }
  os << ":";
  if (s.stride_.has_value()) {
    os << *s.stride_;
  } else {
    os << "*";
  }
  os << '}';
  return os;
}
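
// Printing sketch: a Stride prints as {<dim index>:<stride value>}, with "*"
// for an unknown index or stride, e.g. {0:4} or {*:*}; contiguity is not
// printed.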

VaryingShape<Stride> TensorType::computeStrideProps(
    at::IntArrayRef sizes,
    at::IntArrayRef strides,
    bool tensor_contiguity) {
  int n_dim = static_cast<int>(sizes.size());
  std::vector<size_t> stride_indices(n_dim);
  // default has_overlap to false, as we only compute overlap when:
  // 1. the input sizes/strides fail the format checks below;
  // 2. tensor_contiguity is not set.
  bool has_overlap = false;

  // Sort stride indices in ascending order of stride.
  // Example:
  //  Prior to sorting
  //  Idx:     [0,   1,  2,  3]
  //  sizes:   [8,   1, 10, 16]
  //  Strides: [160, 1, 16,  1]
  //  After sorting
  //  Idx:     [1,  3,  2,   0]
  //  sizes:   [1, 16, 10,   8]
  //  Strides: [1,  1, 16, 160]
  //
  // The logic below follows what TensorIterator does:
  //   1. fast_set_up is the short-cut that identifies a. channels_last and
  //      b. contiguous format, which is what the first two branches handle.
  //   2. In more general cases, it makes a best effort to preserve the
  //      permutation.
  if (is_channels_last_strides_2d(sizes, strides) || is_channels_last_strides_3d(sizes, strides)) {
    // case 1.a. short cut channels last
    std::iota(stride_indices.rbegin() + 1, stride_indices.rend() - 1, 2);
    stride_indices[0] = 1;
    stride_indices[n_dim - 1] = 0;
  } else if (is_contiguous_strides(sizes, strides)) {
    // case 1.b. short cut contiguous
    std::iota(stride_indices.rbegin(), stride_indices.rend(), 0);
  } else {
    std::iota(stride_indices.rbegin(), stride_indices.rend(), 0);
    // case 2.
    //
    // For broadcasted dimensions, where the stride is 0, we have to stick to
    // TensorIterator's behavior in eager mode, where it introduces an
    // ambiguous comparison result to preserve the permutation by best effort.
    // For more details, see NOTE: [Computing output strides]
    auto should_swap = [&](size_t a, size_t b) {
      if (strides[a] == 0 || strides[b] == 0) {
        return 0;
      } else if (strides[a] < strides[b]) {
        return -1;
      } else if (strides[a] > strides[b]) {
        return 1;
      } else { // strides[a] == strides[b]
        if (sizes[a] > sizes[b]) {
          return 1;
        }
      }
      return 0;
    };
    for (int i = 1; i < n_dim; i++) {
      int dim1 = i;
      for (int dim0 = i - 1; dim0 >= 0; dim0--) {
        int comparison = should_swap(stride_indices[dim0], stride_indices[dim1]);
        if (comparison > 0) {
          std::swap(stride_indices[dim0], stride_indices[dim1]);
          dim1 = dim0;
        } else if (comparison < 0) {
          break;
        }
      }
    }
    // Conveniently, is_channels_last_strides_2d/_3d and is_contiguous_strides
    // only return true when there's no memory overlap, so we only re-compute
    // has_overlap in this last branch, where both returned false.
    if (!tensor_contiguity) {
      // trust tensor_contiguity and only compute overlap when it is not set
      has_overlap = possible_cross_dimension_overlap(sizes, strides);
    }
  }

  std::vector<Stride> stride_properties;
  stride_properties.reserve(stride_indices.size());
  for (size_t i = 0; i < stride_indices.size(); i++) {
    bool contiguous_ = tensor_contiguity;
    if (!contiguous_) {
      if (!has_overlap) {
        // innermost stride expected to be 1
        // TODO: turn contiguous_ into an enum CONTIGUOUS, NONCONTIGUOUS,
        // BROADCASTED
        if (i == 0) {
          contiguous_ = strides[stride_indices[i]] == 1;
        } else {
          contiguous_ = strides[stride_indices[i]] == 1 ||
              (strides[stride_indices[i]] != 0 &&
               strides[stride_indices[i]] ==
                   strides[stride_indices[i - 1]] * sizes[stride_indices[i - 1]]);
        }
      } else {
        // keeping this assignment for readability
        contiguous_ = false;
      }
    }
    stride_properties.emplace_back(stride_indices[i], contiguous_, strides[stride_indices[i]]);
  }

  return VaryingShape<Stride>{stride_properties};
}
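
// Rough example of the result: for a contiguous tensor with sizes [2, 3] and
// strides [3, 1] (tensor_contiguity == true), the contiguous short-cut keeps
// the natural order and every dimension is marked contiguous, so the returned
// properties print as ({1:1}, {0:3}), innermost dimension first.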

TensorTypePtr TensorType::create(const at::Tensor& t) {
  VaryingShape<bool> contiguity;
  VaryingShape<size_t> stride_indices;
  VaryingShape<int64_t> strides;
  VaryingShape<int64_t> sizes;
  if (t.layout() == at::kStrided && !t.is_nested()) {
    sizes = VaryingShape<int64_t>{t.sizes().vec()};
    strides = VaryingShape<int64_t>{t.strides().vec()};
    return TensorType::create(
        t.scalar_type(), t.device(), sizes, strides, t.requires_grad(), false, t.is_contiguous());
  }

  return TensorType::create(
      t.scalar_type(),
      t.device(),
      SymbolicShape(),
      VaryingShape<Stride>{},
      t.requires_grad(),
      false);
}
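
// Usage sketch (illustrative, not part of this file):
//   at::Tensor t = at::ones({2, 3});      // strided, contiguous, float, CPU
//   auto type = TensorType::create(t);    // captures dtype/device/sizes/strides
//   // non-strided or nested tensors fall through to the unranked overload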

TensorTypePtr TensorType::create(
    std::optional<at::ScalarType> scalar_type,
    std::optional<Device> device,
    const VaryingShape<int64_t>& sizes,
    const VaryingShape<int64_t>& strides,
    std::optional<bool> requires_grad,
    std::optional<bool> undefined, bool tensor_contiguity) {
  if (strides.concrete_sizes() && strides.concrete_sizes().has_value()) {
    // handles the case where strides are set
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    TORCH_INTERNAL_ASSERT(sizes.concrete_sizes()->size() == strides.concrete_sizes()->size());
    auto sprops = strides.concrete_sizes().has_value()
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      ? computeStrideProps(*sizes.concrete_sizes(), *strides.concrete_sizes(), tensor_contiguity)
      : VaryingShape<Stride>();
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    auto symbol_sizes = SymbolicShape(*sizes.concrete_sizes());
    return TensorType::create(
      scalar_type, device, symbol_sizes, sprops, requires_grad, undefined);
  } else {
    // strides are all null, but the number of strides still equals the rank
    TORCH_INTERNAL_ASSERT(sizes.sizes() && sizes.size());
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    auto symbol_sizes = SymbolicShape(*sizes.sizes());
    return TensorType::create(
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      scalar_type, device, symbol_sizes, VaryingShape<Stride>(*sizes.size()), requires_grad, undefined);
  }
}

TensorTypePtr TensorType::create(
    std::optional<at::ScalarType> scalar_type,
    std::optional<Device> device,
    const SymbolicShape& sizes,
    const VaryingShape<Stride>& strides,
    std::optional<bool> requires_grad,
    std::optional<bool> undefined) {
  auto pt = TensorTypePtr(new TensorType(
      scalar_type, device, sizes, strides, requires_grad, undefined));
  return pt;
}

TensorTypePtr TensorType::create(
    std::optional<at::ScalarType> scalar_type,
    std::optional<Device> device,
    std::optional<size_t> dim,
    std::optional<bool> requires_grad) {
  return TensorType::create(
      scalar_type,
      device,
      SymbolicShape(dim),
      VaryingShape<Stride>(dim),
      requires_grad);
}

std::string TensorType::str() const {
  return "Tensor";
}

std::atomic<size_t> ShapeSymbol::num_symbols{1};

template struct VaryingShape<c10::ShapeSymbol>;
template struct VaryingShape<bool>;
template struct VaryingShape<size_t>;
template struct VaryingShape<int64_t>;
template struct VaryingShape<c10::Stride>;

VaryingShape<int64_t> TensorType::sizes() const {
  if (!sizes_.rank()) {
    return VaryingShape<int64_t>();
  }
  return VaryingShape<int64_t>(
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      fmap(*sizes_.sizes(), [](ShapeSymbol ss) {
        // we turn symbolic shapes into unknowns
        return ss.is_static()
            ? std::optional<int64_t>(ss.static_size())
            : std::nullopt;
      }));
}
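
// Sketch: sizes() drops symbolic information, e.g. a symbolic shape
// (4, SS(-2)) becomes the VaryingShape (4, *).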

TensorTypePtr TensorType::merge(const TensorType& other, bool merge_sizes) const {
  auto scalar_type = merge_primitive(scalarType(), other.scalarType());
  auto dev = merge_primitive(device(), other.device());
  auto sprops = stride_properties().merge(other.stride_properties());
  auto gr = merge_primitive(requiresGrad(), other.requiresGrad());
  auto undef = merge_primitive(undefined(), other.undefined());
  return TensorType::create(
      scalar_type,
      dev,
      merge_sizes ? symbolic_sizes().merge(other.symbolic_sizes())
                  : symbolic_sizes(),
      sprops,
      gr,
      undef);
}

template <typename T>
bool is_null_or_equal(std::optional<T> a, c10::IntArrayRef b) {
  return !a.has_value() || a.value() == b;
}

bool TensorType::matchTensor(const at::Tensor& t) {
  bool undef = undefined().value_or(!t.defined());
  if (undef != !t.defined()) {
    // We consider it not a match when:
    // - undefined().has_value() == true
    // - undefined().value() != !t.defined()
    return false;
  } else if (!t.defined()) {
    // We consider it a match when:
    // - t is not defined
    // - undefined() == null or undefined().value() == true
    return true;
  }
  // Here we know t.defined() == true and compare all other properties.
  bool rg = at::GradMode::is_enabled() && t.requires_grad();
  bool matched_strides = (!stride_properties().size()) ||
      (!t.has_storage() && !stride_properties().isComplete()) ||
      stride_properties() ==
          computeStrideProps(t.sizes(), t.strides(), t.is_contiguous());
  return scalarType().value_or(t.scalar_type()) == t.scalar_type()
    && device().value_or(t.device()) == t.device()
    && requiresGrad().value_or(rg) == rg
    && matched_strides
    && is_null_or_equal(sizes().concrete_sizes(), t.sizes());
}
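
// Matching sketch: unset properties act as wildcards, so the fully unspecified
// TensorType::get() matches any defined tensor, while a type built from a
// concrete tensor (see TensorType::create above) roughly matches only tensors
// with the same dtype, device, sizes, strides and requires_grad state.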

bool TensorType::equals(const c10::Type& rhs) const {
  if (rhs.kind() != kind()) {
    return false;
  }
  auto rt = rhs.expect<TensorType>();

  return scalar_type_ == rt->scalarType() && sizes() == rt->sizes() &&
      stride_properties() == rt->stride_properties() &&
      device() == rt->device() && requiresGrad() == rt->requiresGrad() &&
      undefined() == rt->undefined();
}

VaryingShape<int64_t> TensorType::strides() const {
  if (!strides_.size().has_value()) {
    return VaryingShape<int64_t>();
  }
  std::vector<std::optional<int64_t>> ss(*strides_.size());
  for (size_t i = 0; i < *strides_.size(); i++) {
    if (!strides_[i].has_value()) {
      continue;
    }
    auto s = *strides_[i];
    if (s.stride_index_.has_value() && s.stride_.has_value()) {
      ss[*s.stride_index_] = *s.stride_;
    }
  }
  return VaryingShape<int64_t>(std::move(ss));
}

TensorType::TensorType(
    std::optional<at::ScalarType> scalar_type,
    std::optional<Device> device,
    SymbolicShape sizes,
    VaryingShape<Stride> strides,
    std::optional<bool> requires_grad,
    std::optional<bool> undefined)
    : SharedType(TypeKind::TensorType),
      scalar_type_(scalar_type),
      device_(device),
      sizes_(std::move(sizes)),
      strides_(std::move(strides)),
      requires_grad_(requires_grad),
      undefined_(undefined) {}

TensorTypePtr TensorType::createContiguous(
    at::ScalarType scalar_type,
    at::Device device,
    at::IntArrayRef sizes) {
  auto strides = contiguousStridesOf(sizes);
  TORCH_INTERNAL_ASSERT(strides.size() == sizes.size());
  return create(
      scalar_type,
      device,
      VaryingShape<int64_t>(sizes),
      VaryingShape<int64_t>(strides),
      std::nullopt);
}
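
// Usage sketch (illustrative): createContiguous(at::kFloat, at::kCPU, {2, 3})
// describes a float CPU tensor with sizes (2, 3) and contiguous strides (3, 1).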

const SymbolicShape& TensorType::symbolic_sizes() const {
  return sizes_;
}

bool TensorType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
  if (auto rhs_p = rhs.cast<TensorType>()) {
    // if we have the same pointer, avoid computing the merge
    if (this == rhs_p.get()) {
      return true;
    }
    return *merge(*rhs_p) == *rhs_p;
  }
  return Type::isSubtypeOfExt(rhs, why_not);
}
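
// Subtyping sketch: one TensorType is a subtype of another when merging the
// two gives back the other, i.e. every property the other pins down is pinned
// down to the same value here. For example, a fully specified TensorType
// (known dtype, device, sizes, strides) is a subtype of the unspecified
// TensorType::get().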

} // namespace c10