/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/value_utils.h"

#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/op_utils.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Given an mlir::Value, traces the value back through DTensorLayout ops and
// the basic blocks of while loops.
// This is like a reverse version of TraceUseToNextTFOp.
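// For example, a block argument of a tf.WhileRegion region is traced back to
// the corresponding operand of the while op, and the result of a DTensorLayout
// op is traced back to its input.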
mlir::Value GetForwardedInput(mlir::Value value) {
  bool value_updated;
  do {
    value_updated = false;
    if (mlir::BlockArgument argument = value.dyn_cast<mlir::BlockArgument>()) {
      mlir::Region* region = argument.getParentRegion();
      if (region == nullptr) break;
      mlir::Operation* parent_op = region->getParentOp();
      // TODO(bfontain): handle if and other control flow blocks.
      if (mlir::TF::WhileRegionOp while_op =
              mlir::dyn_cast<mlir::TF::WhileRegionOp>(parent_op)) {
        value = while_op.getOperand(argument.getArgNumber());
        value_updated = true;
      }
    } else {
      mlir::Operation* op = value.getDefiningOp();
      // TODO(bfontain): Add cases for identity and control flow return values.
      if (mlir::TF::DTensorLayout layout_op =
              mlir::dyn_cast<mlir::TF::DTensorLayout>(op)) {
        value = layout_op.input();
        value_updated = true;
      }
    }
  } while (value_updated);

  return value;
}
}  // namespace

namespace ops_util = ::mlir::TF::collection_ops_util;

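// Returns the rank of `operand_value`'s tensor type, looking through resource
// or variant subtypes, or -1 if the type is unranked.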
int ValueRank(mlir::Value operand_value) {
  mlir::Type type = GetSubtypeOrSelf(operand_value);
  const auto operand_type = type.cast<mlir::TensorType>();
  if (!operand_type.hasRank()) return -1;
  return operand_type.getRank();
}

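// Returns a 1-D tensor type holding a single element of `element_type`,
// representing an effectively scalar value.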
mlir::RankedTensorType EffectivelyScalarR1Type(mlir::Type element_type) {
  return mlir::RankedTensorType::get({1}, element_type);
}

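// Reshapes a size tensor to a rank-0 (scalar) int32 tensor via a tf.Reshape
// op.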
mlir::Value ReshapeSizeTypeToScalar(mlir::OpBuilder builder, mlir::Location loc,
                                    mlir::Value tensor) {
  auto scalar_type =
      mlir::RankedTensorType::get({}, builder.getIntegerType(32));
  mlir::Value scalar_shape =
      ops_util::GetR1Const(scalar_type.getShape(), builder, loc);
  return builder.create<mlir::TF::ReshapeOp>(
      loc, mlir::ArrayRef<mlir::Type>{scalar_type},
      mlir::ArrayRef<mlir::Value>{tensor, scalar_shape});
}

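// Emits a tf.Const op holding a 1-D int32 tensor with the given values.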
mlir::Value IntConst(mlir::OpBuilder& builder, mlir::Location loc,
                     llvm::ArrayRef<int32> values) {
  auto const_type = mlir::RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder.getIntegerType(32));
  mlir::Attribute const_attr =
      mlir::DenseIntElementsAttr::get(const_type, values);
  return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult();
}

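// Emits a tf.Const op holding a 1-D int64 tensor with the given values.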
mlir::Value Int64Const(mlir::OpBuilder& builder, mlir::Location loc,
                       llvm::ArrayRef<int64_t> values) {
  auto const_type = mlir::RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder.getIntegerType(64));
  mlir::Attribute const_attr =
      mlir::DenseIntElementsAttr::get(const_type, values);
  return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult();
}

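// Emits a tf.Const op holding a 1-D float32 tensor with the given values.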
mlir::Value FloatConst(mlir::OpBuilder& builder, mlir::Location loc,
                       llvm::ArrayRef<float> values) {
  mlir::RankedTensorType const_type = mlir::RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder.getF32Type());
  mlir::Attribute const_attr =
      mlir::DenseFPElementsAttr::get(const_type, values);
  return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult();
}

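// Emits a tf.Const op holding a 1-D string tensor with the given values.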
mlir::Value StringConst(mlir::OpBuilder& builder, mlir::Location loc,
                        llvm::ArrayRef<llvm::StringRef> values) {
  auto const_type =
      mlir::RankedTensorType::get({static_cast<int64_t>(values.size())},
                                  builder.getType<mlir::TF::StringType>());
  mlir::Attribute const_attr =
      mlir::DenseStringElementsAttr::get(const_type, values);
  return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult();
}

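// Extracts a single signed integer from `value`, which must trace back to a
// constant with exactly one element; returns an Internal error otherwise.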
StatusOr<int64_t> ExtractConstIntFromValue(mlir::Value value) {
  value = GetForwardedInput(value);
  if (value.isa<mlir::BlockArgument>())
    return errors::Internal(
        "unable to get constant value from block argument");
  mlir::DenseIntElementsAttr attr;
  if (!matchPattern(value, m_Constant(&attr))) {
    return errors::Internal(absl::StrCat("required constant value for ",
                                         OpName(value.getDefiningOp())));
  }
  if (attr.size() != 1) {
    return errors::Internal(absl::StrCat("expected 1 element, got ",
                                         attr.size(), " for ",
                                         OpName(value.getDefiningOp())));
  }
  auto a = *attr.value_begin<llvm::APInt>();
  return a.getSExtValue();
}

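// Appends the signed integer elements of the constant that `value` traces
// back to onto `out_vector`.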
Status ExtractConstVectorFromValue(mlir::Value value,
                                   llvm::SmallVector<int64_t, 4>* out_vector) {
  value = GetForwardedInput(value);
  if (value.isa<mlir::BlockArgument>())
    return errors::Internal(
        "unable to get constant value from block argument");
  mlir::DenseIntElementsAttr attr;
  if (!matchPattern(value, m_Constant(&attr))) {
    return errors::Internal(
        absl::StrCat("failed to extract constant value from ",
                     value.getDefiningOp()->getName().getStringRef().str()));
  }
  for (const mlir::APInt& index : attr)
    out_vector->emplace_back(index.getSExtValue());
  return OkStatus();
}

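// Emits a scalar integer tf.Const op, int64-typed when `use_int64` is set and
// int32-typed otherwise.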
mlir::Value CreateIntScalarConst(const int64_t value, mlir::OpBuilder builder,
                                 mlir::Location loc, bool use_int64) {
  if (use_int64) {
    return builder.create<mlir::TF::ConstOp>(
        loc, mlir::DenseIntElementsAttr::get(
                 mlir::RankedTensorType::get({}, builder.getI64Type()), value));
  } else {
    return builder.create<mlir::TF::ConstOp>(
        loc, mlir::DenseIntElementsAttr::get(
                 mlir::RankedTensorType::get({}, builder.getI32Type()),
                 static_cast<int32_t>(value)));
  }
}

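// Emits a scalar zero tf.Const op of `type` for f64, f32, int32, or int64
// types; returns absl::nullopt for any other type.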
absl::optional<mlir::Value> CreateZeroScalarConst(mlir::OpBuilder& builder,
                                                  mlir::Location loc,
                                                  mlir::Type type) {
  if (type.isF64()) {
    return builder.create<mlir::TF::ConstOp>(
        loc, mlir::DenseFPElementsAttr::get(
                 mlir::RankedTensorType::get({}, builder.getF64Type()),
                 static_cast<double>(0.)));
  } else if (type.isF32()) {
    return builder.create<mlir::TF::ConstOp>(
        loc, mlir::DenseFPElementsAttr::get(
                 mlir::RankedTensorType::get({}, builder.getF32Type()),
                 static_cast<float>(0.f)));
  } else if (type.isInteger(32)) {
    return builder.create<mlir::TF::ConstOp>(
        loc, mlir::DenseIntElementsAttr::get(
                 mlir::RankedTensorType::get({}, builder.getI32Type()),
                 static_cast<int32_t>(0)));
  } else if (type.isInteger(64)) {
    return builder.create<mlir::TF::ConstOp>(
        loc, mlir::DenseIntElementsAttr::get(
                 mlir::RankedTensorType::get({}, builder.getI64Type()),
                 static_cast<int64_t>(0)));
  } else {
    return absl::nullopt;
  }
}

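// Slices out the element at column `index` of a [1, N] `array` and reshapes
// it to a scalar value.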
StatusOr<mlir::Value> SelectScalarValueFromArray(mlir::OpBuilder& builder,
                                                 int index,
                                                 mlir::Location location,
                                                 mlir::Value array) {
  mlir::TensorType array_type = array.getType().cast<mlir::TensorType>();
  if (array_type.getRank() != 2 || array_type.getDimSize(0) != 1) {
    return errors::InvalidArgument("Input array must have shape [1, N].");
  }

  mlir::TF::SliceOp sliced_value = builder.create<mlir::TF::SliceOp>(
      location,
      mlir::RankedTensorType::get({1, 1}, array_type.getElementType()),
      /*input=*/array,
      /*begin=*/IntConst(builder, location, {0, index}),
      /*size=*/IntConst(builder, location, {1, 1}));

  // Reshape the sliced (1, 1)-shaped tensor to a rank-0 scalar.
  auto scalar_size_type =
      mlir::RankedTensorType::get({}, builder.getIntegerType(32));
  mlir::Value scalar_shape = mlir::TF::collection_ops_util::GetR1Const(
      scalar_size_type.getShape(), builder, location);
  mlir::Value scalar_sliced_value = builder.create<mlir::TF::ReshapeOp>(
      location, mlir::ArrayRef<mlir::Type>{scalar_size_type},
      mlir::ArrayRef<mlir::Value>{sliced_value.output(), scalar_shape},
      mlir::ArrayRef<mlir::NamedAttribute>{});
  return scalar_sliced_value;
}

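// Returns the single subtype of `val`'s element type (e.g. the subtype of a
// resource or variant type) when there is exactly one, and `val`'s type
// otherwise.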
mlir::Type GetSubtypeOrSelf(mlir::Value val) {
  mlir::Type type = val.getType();
  if (auto type_with_subtype =
          mlir::getElementTypeOrSelf(val)
              .dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>()) {
    if (type_with_subtype.GetSubtypes().size() == 1) {
      type = type_with_subtype.GetSubtypes().front();
    }
  }
  return type;
}

}  // namespace dtensor
}  // namespace tensorflow