xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/type_to_shape.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
17 
18 #include <string>
19 #include <vector>
20 
21 #include "mlir/IR/AffineMap.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
23 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
24 #include "mlir/IR/Location.h"  // from @llvm-project
25 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
26 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 using ::int64_t;
34 using mlir::IntegerType;
35 using mlir::MemRefType;
36 using mlir::RankedTensorType;
37 using mlir::ShapedType;
38 using mlir::VectorType;
39 using mlir::mhlo::TypeExtensionsAttr;
40 using xla::PrimitiveType;
41 using xla::ShapeUtil;
42 
43 namespace xla {
44 
TypeToPrimitiveType(mlir::Type type)45 PrimitiveType TypeToPrimitiveType(mlir::Type type) {
46   if (type.isBF16()) {
47     return PrimitiveType::BF16;
48   } else if (type.isF16()) {
49     return PrimitiveType::F16;
50   } else if (type.isF32()) {
51     return PrimitiveType::F32;
52   } else if (type.isF64()) {
53     return PrimitiveType::F64;
54   } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
55     mlir::Type element_ty = complex_type.getElementType();
56     if (element_ty.isF32()) {
57       return PrimitiveType::C64;
58 
59     } else if (element_ty.isF64()) {
60       return PrimitiveType::C128;
61     }
62     return PrimitiveType::PRIMITIVE_TYPE_INVALID;
63   } else if (auto integer_type = type.dyn_cast<mlir::IntegerType>()) {
64     bool is_unsigned = integer_type.isUnsigned();
65     switch (integer_type.getWidth()) {
66       case 1:
67         return PrimitiveType::PRED;
68       case 8:
69         return is_unsigned ? PrimitiveType::U8 : PrimitiveType::S8;
70       case 16:
71         return is_unsigned ? PrimitiveType::U16 : PrimitiveType::S16;
72       case 32:
73         return is_unsigned ? PrimitiveType::U32 : PrimitiveType::S32;
74       case 64:
75         return is_unsigned ? PrimitiveType::U64 : PrimitiveType::S64;
76       default:
77         return PrimitiveType::PRIMITIVE_TYPE_INVALID;
78     }
79   }
80   return PrimitiveType::PRIMITIVE_TYPE_INVALID;
81 }
82 
TypeToShape(mlir::Type type)83 Shape TypeToShape(mlir::Type type) {
84   PrimitiveType ptype = TypeToPrimitiveType(type);
85   if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID)
86     return ShapeUtil::MakeShape(ptype, {});
87 
88   if (type.isIntOrFloat()) {
89     auto* context = type.getContext();
90     mlir::emitError(mlir::UnknownLoc::get(context))
91         << "lowering should have been handled by primitive type lowering for "
92         << debugString(type);
93   } else if (auto v = type.dyn_cast<mlir::VectorType>()) {
94     llvm::SmallVector<int64_t, 4> span(v.getShape().begin(),
95                                        v.getShape().end());
96     mlir::Type element_type = v.getElementType();
97     PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
98     if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID)
99       return ShapeUtil::MakeShape(primitive_type, span);
100   } else if (auto m = type.dyn_cast<mlir::MemRefType>()) {
101     llvm::SmallVector<int64_t, 6> span(m.getShape().begin(),
102                                        m.getShape().end());
103     mlir::Type element_type = m.getElementType();
104     // Treat a memref of a vector as if it was a memref of primitive type with
105     // the vector dimensions at the end.
106     if (auto v = element_type.dyn_cast<mlir::VectorType>()) {
107       element_type = v.getElementType();
108       span.insert(span.end(), v.getShape().begin(), v.getShape().end());
109     }
110     PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
111     if (primitive_type == PrimitiveType::PRIMITIVE_TYPE_INVALID) return {};
112     // For the primitive type case, the shape of the memref is similar to the
113     // vector type case (i.e., it is, modulo the layout, the same dimensions
114     // and primitive type).
115     if (m.getLayout().isIdentity())
116       return ShapeUtil::MakeShape(primitive_type, span);
117 
118     llvm::SmallVector<int64_t, 4> strides;
119     int64_t offset;
120     if (failed(mlir::getStridesAndOffset(m, strides, offset))) return {};
121 
122     llvm::SmallVector<std::pair<int64_t, int>, 4> strides_with_indices;
123     for (const auto& e : llvm::enumerate(strides)) {
124       strides_with_indices.push_back({e.value(), e.index()});
125     }
126     std::stable_sort(strides_with_indices.begin(), strides_with_indices.end());
127 
128     llvm::SmallVector<int64_t, 4> minor_to_major;
129     int64_t stride = 1;
130     for (const auto& pr : strides_with_indices) {
131       minor_to_major.push_back(pr.second);
132 
133       // Either the affine map is not perfectly strided, or the dimensions
134       // recovered from strides don't match the actual dimensions in shapes.
135       if (stride != pr.first && m.getShape()[pr.second] != 1) return {};
136 
137       stride *= m.getShape()[pr.second];
138     }
139 
140     llvm::SmallVector<int64_t, 4> dimensions(m.getShape().begin(),
141                                              m.getShape().end());
142     return ::xla::ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions,
143                                                  minor_to_major);
144   } else if (auto t = type.dyn_cast<mlir::RankedTensorType>()) {
145     // TODO(jpienaar): This is only handling the base case with primitive
146     // element type.
147     int64_t rank = t.getRank();
148     llvm::SmallVector<int64_t, 4> bounds;
149     if (auto extn = t.getEncoding().dyn_cast_or_null<TypeExtensionsAttr>()) {
150       bounds = llvm::to_vector<4>(extn.getBounds());
151     } else {
152       bounds.assign(rank, ShapedType::kDynamicSize);
153     }
154 
155     llvm::SmallVector<int64_t, 4> shape(rank, mlir::ShapedType::kDynamicSize);
156     std::vector<bool> is_dynamic(rank, false);
157     for (int64_t dim = 0; dim < rank; ++dim) {
158       // Only fully static shapes are supported.
159       // TODO(b/115638799): Update once xla::Shape can support dynamic shapes.
160       int64_t size = t.getDimSize(dim);
161       if (size == ShapedType::kDynamicSize) {
162         if (bounds[dim] == ShapedType::kDynamicSize) return {};
163         shape[dim] = bounds[dim];
164         is_dynamic[dim] = true;
165       } else {
166         if (bounds[dim] != ShapedType::kDynamicSize) return {};
167         shape[dim] = size;
168       }
169     }
170 
171     PrimitiveType primitive_type = TypeToPrimitiveType(t.getElementType());
172     if (primitive_type == PrimitiveType::PRIMITIVE_TYPE_INVALID) return {};
173 
174     return ShapeUtil::MakeShape(primitive_type, shape, is_dynamic);
175   } else if (auto tuple_type = type.dyn_cast<mlir::TupleType>()) {
176     llvm::SmallVector<Shape, 4> shapes;
177     shapes.reserve(tuple_type.size());
178     for (mlir::Type sub_type : tuple_type.getTypes()) {
179       shapes.push_back(TypeToShape(sub_type));
180     }
181     return ShapeUtil::MakeTupleShape(shapes);
182 
183   } else if (type.isa<mlir::mhlo::TokenType>()) {
184     return ShapeUtil::MakeTokenShape();
185   }
186 
187   // Return empty XLA shape to signify error. No MLIR Type maps to a empty
188   // Shape.
189   return {};
190 }
191 
192 }  // namespace xla
193