xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/value_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/dtensor/mlir/value_utils.h"
17 
18 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
19 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
20 #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
21 #include "tensorflow/core/platform/errors.h"
22 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
23 #include "tensorflow/dtensor/mlir/op_utils.h"
24 
25 namespace tensorflow {
26 namespace dtensor {
27 namespace {
28 
29 // Given a mlir::Value will trace the value back through
30 // DTensorLayout and basic blocks of while loops.
31 // This is like a reverse version of TraceUseToNextTFOp.
GetForwardedInput(mlir::Value value)32 mlir::Value GetForwardedInput(mlir::Value value) {
33   bool value_updated;
34   do {
35     value_updated = false;
36     if (mlir::BlockArgument argument = value.dyn_cast<mlir::BlockArgument>()) {
37       mlir::Region* region = argument.getParentRegion();
38       if (region == nullptr) break;
39       mlir::Operation* parent_op = region->getParentOp();
40       // TODO(bfontain): handle if and other control flow blocks.
41       if (mlir::TF::WhileRegionOp while_op =
42               mlir::dyn_cast<mlir::TF::WhileRegionOp>(parent_op)) {
43         value = while_op.getOperand(argument.getArgNumber());
44         value_updated = true;
45       }
46     } else {
47       mlir::Operation* op = value.getDefiningOp();
48       // TODO(bfontain): Add cases for identity and control flow return values.
49       if (mlir::TF::DTensorLayout layout_op =
50               mlir::dyn_cast<mlir::TF::DTensorLayout>(op)) {
51         value = layout_op.input();
52         value_updated = true;
53       }
54     }
55   } while (value_updated);
56 
57   return value;
58 }
59 }  // namespace
60 
61 namespace ops_util = ::mlir::TF::collection_ops_util;
62 
ValueRank(mlir::Value operand_value)63 int ValueRank(mlir::Value operand_value) {
64   mlir::Type type = GetSubtypeOrSelf(operand_value);
65   const auto operand_type = type.cast<mlir::TensorType>();
66   if (!operand_type.hasRank()) return -1;
67   return operand_type.getRank();
68 }
69 
EffectivelyScalarR1Type(mlir::Type element_type)70 mlir::RankedTensorType EffectivelyScalarR1Type(mlir::Type element_type) {
71   return mlir::RankedTensorType::get({1}, element_type);
72 }
73 
ReshapeSizeTypeToScalar(mlir::OpBuilder builder,mlir::Location loc,mlir::Value tensor)74 mlir::Value ReshapeSizeTypeToScalar(mlir::OpBuilder builder, mlir::Location loc,
75                                     mlir::Value tensor) {
76   auto scalar_type =
77       mlir::RankedTensorType::get({}, builder.getIntegerType(32));
78   mlir::Value scalar_shape =
79       ops_util::GetR1Const(scalar_type.getShape(), builder, loc);
80   return builder.create<mlir::TF::ReshapeOp>(
81       loc, mlir::ArrayRef<mlir::Type>{scalar_type},
82       mlir::ArrayRef<mlir::Value>{tensor, scalar_shape});
83 }
84 
IntConst(mlir::OpBuilder & builder,mlir::Location loc,llvm::ArrayRef<int32> values)85 mlir::Value IntConst(mlir::OpBuilder& builder, mlir::Location loc,
86                      llvm::ArrayRef<int32> values) {
87   auto const_type = mlir::RankedTensorType::get(
88       {static_cast<int64_t>(values.size())}, builder.getIntegerType(32));
89   mlir::Attribute const_attr =
90       mlir::DenseIntElementsAttr::get(const_type, values);
91   return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult();
92 }
93 
Int64Const(mlir::OpBuilder & builder,mlir::Location loc,llvm::ArrayRef<int64_t> values)94 mlir::Value Int64Const(mlir::OpBuilder& builder, mlir::Location loc,
95                        llvm::ArrayRef<int64_t> values) {
96   auto const_type = mlir::RankedTensorType::get(
97       {static_cast<int64_t>(values.size())}, builder.getIntegerType(64));
98   mlir::Attribute const_attr =
99       mlir::DenseIntElementsAttr::get(const_type, values);
100   return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult();
101 }
102 
FloatConst(mlir::OpBuilder & builder,mlir::Location loc,llvm::ArrayRef<float> values)103 mlir::Value FloatConst(mlir::OpBuilder& builder, mlir::Location loc,
104                        llvm::ArrayRef<float> values) {
105   mlir::RankedTensorType const_type = mlir::RankedTensorType::get(
106       {static_cast<int64_t>(values.size())}, builder.getF32Type());
107   mlir::Attribute const_attr =
108       mlir::DenseFPElementsAttr::get(const_type, values);
109   return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult();
110 }
111 
StringConst(mlir::OpBuilder & builder,mlir::Location loc,llvm::ArrayRef<llvm::StringRef> values)112 mlir::Value StringConst(mlir::OpBuilder& builder, mlir::Location loc,
113                         llvm::ArrayRef<llvm::StringRef> values) {
114   auto const_type =
115       mlir::RankedTensorType::get({static_cast<int64_t>(values.size())},
116                                   builder.getType<mlir::TF::StringType>());
117   mlir::Attribute const_attr =
118       mlir::DenseStringElementsAttr::get(const_type, values);
119   return builder.create<mlir::TF::ConstOp>(loc, const_attr).getResult();
120 }
121 
ExtractConstIntFromValue(mlir::Value value)122 StatusOr<int64_t> ExtractConstIntFromValue(mlir::Value value) {
123   value = GetForwardedInput(value);
124   if (value.isa<mlir::BlockArgument>())
125     return errors::Internal("unable get constant value from block argument");
126   mlir::DenseIntElementsAttr attr;
127   if (!matchPattern(value, m_Constant(&attr))) {
128     return errors::Internal(absl::StrCat("required constant value for ",
129                                          OpName(value.getDefiningOp())));
130   }
131   if (attr.size() != 1) {
132     return errors::Internal(absl::StrCat("expected 1 element, got ",
133                                          attr.size(), " for ",
134                                          OpName(value.getDefiningOp())));
135   }
136   auto a = *attr.value_begin<llvm::APInt>();
137   return a.getSExtValue();
138 }
139 
ExtractConstVectorFromValue(mlir::Value value,llvm::SmallVector<int64_t,4> * out_vector)140 Status ExtractConstVectorFromValue(mlir::Value value,
141                                    llvm::SmallVector<int64_t, 4>* out_vector) {
142   value = GetForwardedInput(value);
143   if (value.isa<mlir::BlockArgument>())
144     return errors::Internal("unable get constant value from block argument");
145   mlir::DenseIntElementsAttr attr;
146   if (!matchPattern(value, m_Constant(&attr))) {
147     return errors::Internal(
148         absl::StrCat("failed to extract constant value from ",
149                      value.getDefiningOp()->getName().getStringRef().str()));
150   }
151   for (const mlir::APInt& index : attr)
152     out_vector->emplace_back(index.getSExtValue());
153   return OkStatus();
154 }
155 
CreateIntScalarConst(const int64_t value,mlir::OpBuilder builder,mlir::Location loc,bool use_int64)156 mlir::Value CreateIntScalarConst(const int64_t value, mlir::OpBuilder builder,
157                                  mlir::Location loc, bool use_int64) {
158   if (use_int64) {
159     return builder.create<mlir::TF::ConstOp>(
160         loc, mlir::DenseIntElementsAttr::get(
161                  mlir::RankedTensorType::get({}, builder.getI64Type()), value));
162   } else {
163     return builder.create<mlir::TF::ConstOp>(
164         loc, mlir::DenseIntElementsAttr::get(
165                  mlir::RankedTensorType::get({}, builder.getI32Type()),
166                  static_cast<int32_t>(value)));
167   }
168 }
169 
CreateZeroScalarConst(mlir::OpBuilder & builder,mlir::Location loc,mlir::Type type)170 absl::optional<mlir::Value> CreateZeroScalarConst(mlir::OpBuilder& builder,
171                                                   mlir::Location loc,
172                                                   mlir::Type type) {
173   if (type.isF64()) {
174     return builder.create<mlir::TF::ConstOp>(
175         loc, mlir::DenseFPElementsAttr::get(
176                  mlir::RankedTensorType::get({}, builder.getF64Type()),
177                  static_cast<double>(0.)));
178   } else if (type.isF32()) {
179     return builder.create<mlir::TF::ConstOp>(
180         loc, mlir::DenseFPElementsAttr::get(
181                  mlir::RankedTensorType::get({}, builder.getF32Type()),
182                  static_cast<float>(0.f)));
183   } else if (type.isInteger(32)) {
184     return builder.create<mlir::TF::ConstOp>(
185         loc, mlir::DenseIntElementsAttr::get(
186                  mlir::RankedTensorType::get({}, builder.getI32Type()),
187                  static_cast<int32_t>(0)));
188   } else if (type.isInteger(64)) {
189     return builder.create<mlir::TF::ConstOp>(
190         loc, mlir::DenseIntElementsAttr::get(
191                  mlir::RankedTensorType::get({}, builder.getI64Type()),
192                  static_cast<int64_t>(0)));
193   } else {
194     return absl::nullopt;
195   }
196 }
197 
SelectScalarValueFromArray(mlir::OpBuilder & builder,int index,mlir::Location location,mlir::Value array)198 StatusOr<mlir::Value> SelectScalarValueFromArray(mlir::OpBuilder& builder,
199                                                  int index,
200                                                  mlir::Location location,
201                                                  mlir::Value array) {
202   mlir::TensorType arrayType = array.getType().cast<mlir::TensorType>();
203   if (arrayType.getRank() != 2 || arrayType.getDimSize(0) != 1) {
204     return errors::InvalidArgument("Input array must have shape [1, N].");
205   }
206 
207   mlir::TF::SliceOp sliced_value = builder.create<mlir::TF::SliceOp>(
208       location, mlir::RankedTensorType::get({1, 1}, arrayType.getElementType()),
209       /*input=*/array,
210       /*begin=*/IntConst(builder, location, {0, index}),
211       /*size=*/IntConst(builder, location, {1, 1}));
212 
213   // Reshape the sliced shape (1,1) tensor to shape 0 scalar.
214   auto scalar_size_type =
215       mlir::RankedTensorType::get({}, builder.getIntegerType(32));
216   mlir::Value scalar_shape = mlir::TF::collection_ops_util::GetR1Const(
217       scalar_size_type.getShape(), builder, location);
218   mlir::Value scalar_sliced_value = builder.create<mlir::TF::ReshapeOp>(
219       location, mlir::ArrayRef<mlir::Type>{scalar_size_type},
220       mlir::ArrayRef<mlir::Value>{sliced_value.output(), scalar_shape},
221       mlir::ArrayRef<mlir::NamedAttribute>{});
222   return scalar_sliced_value;
223 }
224 
GetSubtypeOrSelf(mlir::Value val)225 mlir::Type GetSubtypeOrSelf(mlir::Value val) {
226   mlir::Type type = val.getType();
227   if (auto type_with_subtype =
228           mlir::getElementTypeOrSelf(val)
229               .dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>()) {
230     if (type_with_subtype.GetSubtypes().size() == 1) {
231       type = type_with_subtype.GetSubtypes().front();
232     }
233   }
234   return type;
235 }
236 
237 }  // namespace dtensor
238 }  // namespace tensorflow
239