1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
17
18 #include <string>
19 #include <vector>
20
21 #include "mlir/IR/AffineMap.h" // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
23 #include "mlir/IR/Diagnostics.h" // from @llvm-project
24 #include "mlir/IR/Location.h" // from @llvm-project
25 #include "mlir/Support/DebugStringHelper.h" // from @llvm-project
26 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/types.h"
32
33 using ::int64_t;
34 using mlir::IntegerType;
35 using mlir::MemRefType;
36 using mlir::RankedTensorType;
37 using mlir::ShapedType;
38 using mlir::VectorType;
39 using mlir::mhlo::TypeExtensionsAttr;
40 using xla::PrimitiveType;
41 using xla::ShapeUtil;
42
43 namespace xla {
44
TypeToPrimitiveType(mlir::Type type)45 PrimitiveType TypeToPrimitiveType(mlir::Type type) {
46 if (type.isBF16()) {
47 return PrimitiveType::BF16;
48 } else if (type.isF16()) {
49 return PrimitiveType::F16;
50 } else if (type.isF32()) {
51 return PrimitiveType::F32;
52 } else if (type.isF64()) {
53 return PrimitiveType::F64;
54 } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
55 mlir::Type element_ty = complex_type.getElementType();
56 if (element_ty.isF32()) {
57 return PrimitiveType::C64;
58
59 } else if (element_ty.isF64()) {
60 return PrimitiveType::C128;
61 }
62 return PrimitiveType::PRIMITIVE_TYPE_INVALID;
63 } else if (auto integer_type = type.dyn_cast<mlir::IntegerType>()) {
64 bool is_unsigned = integer_type.isUnsigned();
65 switch (integer_type.getWidth()) {
66 case 1:
67 return PrimitiveType::PRED;
68 case 8:
69 return is_unsigned ? PrimitiveType::U8 : PrimitiveType::S8;
70 case 16:
71 return is_unsigned ? PrimitiveType::U16 : PrimitiveType::S16;
72 case 32:
73 return is_unsigned ? PrimitiveType::U32 : PrimitiveType::S32;
74 case 64:
75 return is_unsigned ? PrimitiveType::U64 : PrimitiveType::S64;
76 default:
77 return PrimitiveType::PRIMITIVE_TYPE_INVALID;
78 }
79 }
80 return PrimitiveType::PRIMITIVE_TYPE_INVALID;
81 }
82
// Converts an MLIR type to an XLA Shape.
//
// Handled cases:
//   * scalar int/float/complex types  -> rank-0 shape
//   * VectorType                      -> static shape of the vector dims
//   * MemRefType (identity or strided layout; vector elements folded in)
//   * RankedTensorType (static dims, or bounded-dynamic dims declared
//     via mhlo TypeExtensionsAttr encoding)
//   * TupleType (recursively converted)
//   * mhlo::TokenType
//
// Returns a default-constructed (empty) Shape to signal an unconvertible
// type; callers must treat that as an error sentinel.
Shape TypeToShape(mlir::Type type) {
  PrimitiveType ptype = TypeToPrimitiveType(type);
  if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID)
    return ShapeUtil::MakeShape(ptype, {});

  if (type.isIntOrFloat()) {
    // An int/float that TypeToPrimitiveType rejected (e.g. an unsupported
    // bit width): emit a diagnostic, then fall through to the error return.
    auto* context = type.getContext();
    mlir::emitError(mlir::UnknownLoc::get(context))
        << "lowering should have been handled by primitive type lowering for "
        << debugString(type);
  } else if (auto v = type.dyn_cast<mlir::VectorType>()) {
    llvm::SmallVector<int64_t, 4> span(v.getShape().begin(),
                                       v.getShape().end());
    mlir::Type element_type = v.getElementType();
    PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
    if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID)
      return ShapeUtil::MakeShape(primitive_type, span);
    // Non-primitive vector element type: fall through to the error return.
  } else if (auto m = type.dyn_cast<mlir::MemRefType>()) {
    llvm::SmallVector<int64_t, 6> span(m.getShape().begin(),
                                       m.getShape().end());
    mlir::Type element_type = m.getElementType();
    // Treat a memref of a vector as if it was a memref of primitive type with
    // the vector dimensions at the end.
    if (auto v = element_type.dyn_cast<mlir::VectorType>()) {
      element_type = v.getElementType();
      span.insert(span.end(), v.getShape().begin(), v.getShape().end());
    }
    PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
    if (primitive_type == PrimitiveType::PRIMITIVE_TYPE_INVALID) return {};
    // For the primitive type case, the shape of the memref is similar to the
    // vector type case (i.e., it is, modulo the layout, the same dimensions
    // and primitive type).
    if (m.getLayout().isIdentity())
      return ShapeUtil::MakeShape(primitive_type, span);

    // Non-identity layout: recover a dense XLA layout from the memref's
    // strides. Bail out (error shape) if the layout is not strided.
    llvm::SmallVector<int64_t, 4> strides;
    int64_t offset;
    if (failed(mlir::getStridesAndOffset(m, strides, offset))) return {};

    // Sort dimension indices by ascending stride: the smallest stride is
    // the fastest-varying (most minor) dimension, which is exactly XLA's
    // minor_to_major order. stable_sort keeps ties in original dim order.
    llvm::SmallVector<std::pair<int64_t, int>, 4> strides_with_indices;
    for (const auto& e : llvm::enumerate(strides)) {
      strides_with_indices.push_back({e.value(), e.index()});
    }
    std::stable_sort(strides_with_indices.begin(), strides_with_indices.end());

    // Walk minor-to-major, checking each stride equals the product of the
    // sizes of all more-minor dimensions (i.e. the layout is dense).
    llvm::SmallVector<int64_t, 4> minor_to_major;
    int64_t stride = 1;
    for (const auto& pr : strides_with_indices) {
      minor_to_major.push_back(pr.second);

      // Either the affine map is not perfectly strided, or the dimensions
      // recovered from strides don't match the actual dimensions in shapes.
      // (A size-1 dimension is exempt: its stride is irrelevant.)
      if (stride != pr.first && m.getShape()[pr.second] != 1) return {};

      stride *= m.getShape()[pr.second];
    }

    llvm::SmallVector<int64_t, 4> dimensions(m.getShape().begin(),
                                             m.getShape().end());
    return ::xla::ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions,
                                                 minor_to_major);
  } else if (auto t = type.dyn_cast<mlir::RankedTensorType>()) {
    // TODO(jpienaar): This is only handling the base case with primitive
    // element type.
    int64_t rank = t.getRank();
    // Per-dimension upper bounds for dynamic dims; kDynamicSize where no
    // bound was declared via the mhlo TypeExtensionsAttr encoding.
    llvm::SmallVector<int64_t, 4> bounds;
    if (auto extn = t.getEncoding().dyn_cast_or_null<TypeExtensionsAttr>()) {
      bounds = llvm::to_vector<4>(extn.getBounds());
    } else {
      bounds.assign(rank, ShapedType::kDynamicSize);
    }

    llvm::SmallVector<int64_t, 4> shape(rank, mlir::ShapedType::kDynamicSize);
    std::vector<bool> is_dynamic(rank, false);
    for (int64_t dim = 0; dim < rank; ++dim) {
      // Only fully static shapes are supported.
      // TODO(b/115638799): Update once xla::Shape can support dynamic shapes.
      int64_t size = t.getDimSize(dim);
      if (size == ShapedType::kDynamicSize) {
        // Dynamic dim: requires a declared bound, which becomes the XLA
        // dimension size with the dynamic bit set.
        if (bounds[dim] == ShapedType::kDynamicSize) return {};
        shape[dim] = bounds[dim];
        is_dynamic[dim] = true;
      } else {
        // Static dim: a bound on a static dimension is inconsistent.
        if (bounds[dim] != ShapedType::kDynamicSize) return {};
        shape[dim] = size;
      }
    }

    PrimitiveType primitive_type = TypeToPrimitiveType(t.getElementType());
    if (primitive_type == PrimitiveType::PRIMITIVE_TYPE_INVALID) return {};

    return ShapeUtil::MakeShape(primitive_type, shape, is_dynamic);
  } else if (auto tuple_type = type.dyn_cast<mlir::TupleType>()) {
    // Recursively convert each element; a failed element yields an empty
    // sub-shape inside the tuple.
    llvm::SmallVector<Shape, 4> shapes;
    shapes.reserve(tuple_type.size());
    for (mlir::Type sub_type : tuple_type.getTypes()) {
      shapes.push_back(TypeToShape(sub_type));
    }
    return ShapeUtil::MakeTupleShape(shapes);

  } else if (type.isa<mlir::mhlo::TokenType>()) {
    return ShapeUtil::MakeTokenShape();
  }

  // Return empty XLA shape to signify error. No MLIR Type maps to a empty
  // Shape.
  return {};
}
191
192 } // namespace xla
193