// torch/csrc/lazy/core/helpers.cpp
#include <torch/csrc/lazy/core/helpers.h>

#include <c10/util/Half.h>
#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/tensor_util.h>

#include <limits>

namespace torch {
namespace lazy {

std::vector<int64_t> DropDimensions(
    c10::ArrayRef<int64_t> sizes,
    c10::ArrayRef<int64_t> drop_dims) {
  std::vector<int64_t> new_dims;
  size_t drop_index = 0;
  for (const auto i : c10::irange(sizes.size())) {
    if (drop_index < drop_dims.size() &&
        static_cast<int64_t>(i) == drop_dims[drop_index]) {
      ++drop_index;
    } else {
      new_dims.push_back(sizes[i]);
    }
  }
  TORCH_CHECK(drop_index == drop_dims.size());
  return new_dims;
}

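// Illustrative sketch (not part of the original source): drop_dims is expected
// to be sorted and already canonical (non-negative); matching entries are
// skipped and everything else is kept.
//
//   std::vector<int64_t> out = DropDimensions({2, 3, 4}, {1});
//   // out == {2, 4}
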
int64_t GetCanonicalDimensionIndex(int64_t dim, int64_t rank) {
  int64_t min_shape_dim = -rank;
  int64_t max_shape_dim = rank - 1;
  TORCH_CHECK(
      min_shape_dim <= dim && dim <= max_shape_dim,
      "Value out of range (expected to be in range of [",
      min_shape_dim,
      ", ",
      max_shape_dim,
      "], but got ",
      dim,
      ")");
  int64_t dim_index = dim < 0 ? rank + dim : dim;
  TORCH_CHECK(dim_index >= 0);
  TORCH_CHECK(dim_index < rank);
  return dim_index;
}

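// Illustrative sketch (not part of the original source): negative dims wrap
// around the rank, matching eager PyTorch dim indexing.
//
//   GetCanonicalDimensionIndex(-1, /*rank=*/4);  // 3
//   GetCanonicalDimensionIndex(2, /*rank=*/4);   // 2
//   GetCanonicalDimensionIndex(4, /*rank=*/4);   // fails the range TORCH_CHECK
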
std::vector<int64_t> GetCanonicalDimensionIndices(
    c10::ArrayRef<int64_t> dimensions,
    int64_t rank) {
  std::vector<int64_t> canonical_dim_indices;
  for (int64_t dim : dimensions) {
    canonical_dim_indices.push_back(GetCanonicalDimensionIndex(dim, rank));
  }
  return canonical_dim_indices;
}

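// Illustrative sketch (not part of the original source): each entry is
// canonicalized independently against the same rank.
//
//   auto dims = GetCanonicalDimensionIndices({0, -1}, /*rank=*/3);
//   // dims == {0, 2}
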
int64_t GetCanonicalPosition(
    c10::ArrayRef<int64_t> dimensions,
    int64_t dim,
    int64_t pos) {
  dim = GetCanonicalDimensionIndex(dim, dimensions.size());
  if (pos < 0) {
    pos = GetCanonicalDimensionIndex(pos, dimensions[dim]);
  } else {
    pos = std::min<int64_t>(pos, dimensions[dim]);
  }
  return pos;
}

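// Illustrative sketch (not part of the original source): a negative position
// wraps around the size of the selected dimension, while a non-negative one is
// clamped to that size.
//
//   GetCanonicalPosition({4, 5}, /*dim=*/1, /*pos=*/-2);  // 3
//   GetCanonicalPosition({4, 5}, /*dim=*/1, /*pos=*/9);   // 5 (clamped)
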
std::vector<int64_t> MakeTransposePermutation(
    int64_t dim0,
    int64_t dim1,
    int64_t rank) {
  int64_t canonical_dim0 = GetCanonicalDimensionIndex(dim0, rank);
  int64_t canonical_dim1 = GetCanonicalDimensionIndex(dim1, rank);
  auto permute_dims = Iota<int64_t>(rank);
  std::swap(permute_dims[canonical_dim0], permute_dims[canonical_dim1]);
  return permute_dims;
}

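// Illustrative sketch (not part of the original source): starts from the
// identity permutation [0, 1, ..., rank - 1] and swaps the two canonicalized
// entries.
//
//   auto perm = MakeTransposePermutation(/*dim0=*/0, /*dim1=*/-1, /*rank=*/3);
//   // perm == {2, 1, 0}
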
std::vector<int64_t> GetPromotedShape(
    c10::ArrayRef<int64_t> shape1_dims,
    c10::ArrayRef<int64_t> shape2_dims) {
  std::vector<int64_t> dimensions;
  // If the rank of one shape is bigger than the other, fill up the leading
  // dimensions with the ones of the bigger shape.
  // Example:
  //   shape1 = [9, 7, 6, 5, 2]
  //   shape2 =       [6, 1, 2]
  // Insert [9, 7] into the dimensions vector.
  if (shape1_dims.size() > shape2_dims.size()) {
    dimensions.insert(
        dimensions.end(),
        shape1_dims.begin(),
        shape1_dims.begin() + (shape1_dims.size() - shape2_dims.size()));
  } else if (shape2_dims.size() > shape1_dims.size()) {
    dimensions.insert(
        dimensions.end(),
        shape2_dims.begin(),
        shape2_dims.begin() + (shape2_dims.size() - shape1_dims.size()));
  }
  // The common (trailing) dimensions must match, or one of them must be 1.
  size_t min_size = std::min(shape1_dims.size(), shape2_dims.size());
  for (const auto i : c10::irange(min_size)) {
    int64_t dim1 = shape1_dims[shape1_dims.size() - min_size + i];
    int64_t dim2 = shape2_dims[shape2_dims.size() - min_size + i];
    TORCH_CHECK(
        dim1 == dim2 || dim1 == 1 || dim2 == 1,
        "(",
        c10::Join(", ", shape1_dims),
        ") and (",
        c10::Join(", ", shape2_dims),
        ")");
    if (dim1 == 0 || dim2 == 0) {
      dimensions.push_back(0);
    } else {
      dimensions.push_back(std::max<int64_t>(dim1, dim2));
    }
  }
  return dimensions;
}

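// Illustrative sketch (not part of the original source): standard broadcasting,
// aligning trailing dimensions and taking the larger of each compatible pair
// (zero-sized dimensions stay zero).
//
//   auto dims = GetPromotedShape({9, 7, 6, 5, 2}, {6, 1, 2});
//   // dims == {9, 7, 6, 5, 2}
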
Shape GetPromotedBinaryOpShape(const Shape& shape1, const Shape& shape2) {
  return Shape(
      promoteTypes(shape1.scalar_type(), shape2.scalar_type()),
      GetPromotedShape(shape1.sizes(), shape2.sizes()));
}

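// Illustrative sketch (not part of the original source): combines dtype
// promotion with shape broadcasting for the result of a binary op.
//
//   Shape out = GetPromotedBinaryOpShape(
//       Shape(c10::ScalarType::Long, {3, 1}),
//       Shape(c10::ScalarType::Float, {4}));
//   // out: Float, sizes {3, 4}
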
std::vector<std::string> StrSplit(c10::string_view text, char delim) {
  size_t start = 0;
  size_t end = 0;

  std::vector<std::string> tokens;
  while ((start = text.find_first_not_of(delim, end)) != std::string::npos) {
    end = text.find(delim, start);
    auto token = text.substr(start, end - start);
    tokens.emplace_back(token.begin(), token.end());
  }
  return tokens;
}

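// Illustrative sketch (not part of the original source): splits on the
// delimiter and drops empty tokens, so runs of the delimiter collapse and
// leading/trailing delimiters produce no entries.
//
//   auto tokens = StrSplit("a,,b,c,", ',');
//   // tokens == {"a", "b", "c"}
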
} // namespace lazy
} // namespace torch