1 #include <torch/csrc/lazy/core/helpers.h>
2
3 #include <c10/util/Half.h>
4 #include <c10/util/irange.h>
5 #include <torch/csrc/lazy/core/tensor_util.h>
6
7 #include <limits>
8
9 namespace torch {
10 namespace lazy {
11
DropDimensions(c10::ArrayRef<int64_t> sizes,c10::ArrayRef<int64_t> drop_dims)12 std::vector<int64_t> DropDimensions(
13 c10::ArrayRef<int64_t> sizes,
14 c10::ArrayRef<int64_t> drop_dims) {
15 std::vector<int64_t> new_dims;
16 size_t drop_index = 0;
17 for (const auto i : c10::irange(sizes.size())) {
18 if (drop_index < drop_dims.size() &&
19 static_cast<int64_t>(i) == drop_dims[drop_index]) {
20 ++drop_index;
21 } else {
22 new_dims.push_back(sizes[i]);
23 }
24 }
25 TORCH_CHECK(drop_index == drop_dims.size());
26 return new_dims;
27 }
28
GetCanonicalDimensionIndex(int64_t dim,int64_t rank)29 int64_t GetCanonicalDimensionIndex(int64_t dim, int64_t rank) {
30 int64_t min_shape_dim = -rank;
31 int64_t max_shape_dim = rank - 1;
32 TORCH_CHECK(
33 min_shape_dim <= dim && dim <= max_shape_dim,
34 "Value out of range (expected to be in range of [",
35 min_shape_dim,
36 ", ",
37 max_shape_dim,
38 "], but got ",
39 dim,
40 ")");
41 int64_t dim_index = dim < 0 ? rank + dim : dim;
42 TORCH_CHECK(dim_index >= 0);
43 TORCH_CHECK(dim_index < rank);
44 return dim_index;
45 }
46
GetCanonicalDimensionIndices(c10::ArrayRef<int64_t> dimensions,int64_t rank)47 std::vector<int64_t> GetCanonicalDimensionIndices(
48 c10::ArrayRef<int64_t> dimensions,
49 int64_t rank) {
50 std::vector<int64_t> canonical_dim_indices;
51 for (int64_t dim : dimensions) {
52 canonical_dim_indices.push_back(GetCanonicalDimensionIndex(dim, rank));
53 }
54 return canonical_dim_indices;
55 }
56
GetCanonicalPosition(c10::ArrayRef<int64_t> dimensions,int64_t dim,int64_t pos)57 int64_t GetCanonicalPosition(
58 c10::ArrayRef<int64_t> dimensions,
59 int64_t dim,
60 int64_t pos) {
61 dim = GetCanonicalDimensionIndex(dim, dimensions.size());
62 if (pos < 0) {
63 pos = GetCanonicalDimensionIndex(pos, dimensions[dim]);
64 } else {
65 pos = std::min<int64_t>(pos, dimensions[dim]);
66 }
67 return pos;
68 }
69
MakeTransposePermutation(int64_t dim0,int64_t dim1,int64_t rank)70 std::vector<int64_t> MakeTransposePermutation(
71 int64_t dim0,
72 int64_t dim1,
73 int64_t rank) {
74 int64_t canonical_dim0 = GetCanonicalDimensionIndex(dim0, rank);
75 int64_t canonical_dim1 = GetCanonicalDimensionIndex(dim1, rank);
76 auto permute_dims = Iota<int64_t>(rank);
77 std::swap(permute_dims[canonical_dim0], permute_dims[canonical_dim1]);
78 return permute_dims;
79 }
80
GetPromotedShape(c10::ArrayRef<int64_t> shape1_dims,c10::ArrayRef<int64_t> shape2_dims)81 std::vector<int64_t> GetPromotedShape(
82 c10::ArrayRef<int64_t> shape1_dims,
83 c10::ArrayRef<int64_t> shape2_dims) {
84 std::vector<int64_t> dimensions;
85 // If the rank of a shape is bigger than then other, fill up the first
86 // dimensions with the ones of the bigger.
87 // Example:
88 // shape1 = [9, 7, 6, 5, 2]
89 // shape2 = [6, 1, 2]
90 // Insert [9, 7] into the dimensions vector.
91 if (shape1_dims.size() > shape2_dims.size()) {
92 dimensions.insert(
93 dimensions.end(),
94 shape1_dims.begin(),
95 shape1_dims.begin() + (shape1_dims.size() - shape2_dims.size()));
96 } else if (shape2_dims.size() > shape1_dims.size()) {
97 dimensions.insert(
98 dimensions.end(),
99 shape2_dims.begin(),
100 shape2_dims.begin() + (shape2_dims.size() - shape1_dims.size()));
101 }
102 // For the common dimensions, they must match, or one of them be 1.
103 size_t min_size = std::min(shape1_dims.size(), shape2_dims.size());
104 for (const auto i : c10::irange(min_size)) {
105 int64_t dim1 = shape1_dims[shape1_dims.size() - min_size + i];
106 int64_t dim2 = shape2_dims[shape2_dims.size() - min_size + i];
107 TORCH_CHECK(
108 dim1 == dim2 || dim1 == 1 || dim2 == 1,
109 "(",
110 c10::Join(", ", shape1_dims),
111 ") and (",
112 c10::Join(", ", shape1_dims),
113 ")");
114 if (dim1 == 0 || dim2 == 0) {
115 dimensions.push_back(0);
116 } else {
117 dimensions.push_back(std::max<int64_t>(dim1, dim2));
118 }
119 }
120 return dimensions;
121 }
122
// Result shape of a binary op on operands described by `shape1` and `shape2`.
// The element type comes from promoteTypes on the two scalar types, and the
// sizes come from GetPromotedShape on the two size lists.
Shape GetPromotedBinaryOpShape(const Shape& shape1, const Shape& shape2) {
  const auto promoted_type =
      promoteTypes(shape1.scalar_type(), shape2.scalar_type());
  const std::vector<int64_t> promoted_sizes =
      GetPromotedShape(shape1.sizes(), shape2.sizes());
  return Shape(promoted_type, promoted_sizes);
}
128
StrSplit(c10::string_view text,char delim)129 std::vector<std::string> StrSplit(c10::string_view text, char delim) {
130 size_t start = 0;
131 size_t end = 0;
132
133 std::vector<std::string> tokens;
134 while ((start = text.find_first_not_of(delim, end)) != std::string::npos) {
135 end = text.find(delim, start);
136 auto token = text.substr(start, end - start);
137 tokens.emplace_back(token.begin(), token.end());
138 }
139 return tokens;
140 }
141
142 } // namespace lazy
143 } // namespace torch
144