/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/shape_utils.h"

#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

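// Returns the global shape of the value feeding `input_value`. The shape is
// taken from a producing DTensorLayout op, from the producing op's
// `_global_shape` attribute, or, for function arguments, from the enclosing
// function's `tf._global_shape` argument attribute.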
StatusOr<llvm::ArrayRef<int64_t>> ExtractGlobalInputShape(
    mlir::OpOperand& input_value) {
  const int operand_index = input_value.getOperandNumber();
  auto input_defining_op = input_value.get().getDefiningOp();

  if (input_defining_op) {
    if (auto layout_op =
            llvm::dyn_cast<mlir::TF::DTensorLayout>(input_defining_op)) {
      auto global_shape = layout_op.global_shape();
      if (!global_shape)
        return errors::Internal("global_shape does not have static rank");
      return *global_shape;
    }
    return ExtractGlobalOutputShape(input_value.get().cast<mlir::OpResult>());
  }

  // If we reach this point, we're working with a function argument.
  auto op = input_value.getOwner();
  auto enclosing_function = op->getParentOfType<mlir::func::FuncOp>();
  if (!enclosing_function)
    return errors::InvalidArgument(
        llvm::formatv("Could not find global shape of {0}-th input to op: {1}",
                      operand_index, op->getName())
            .str());

  auto block_arg = input_value.get().dyn_cast<mlir::BlockArgument>();
  auto global_shape_attr =
      enclosing_function.getArgAttrOfType<mlir::TF::ShapeAttr>(
          block_arg.getArgNumber(), kGlobalShapeDialectAttr);
  if (!global_shape_attr)
    return errors::InvalidArgument(
        "`tf._global_shape` attribute of operation not found.");

  return global_shape_attr.getShape();
}

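// Returns the global shape of `result_value`. If the result's sole user is a
// DTensorLayout op, the shape comes from that op; otherwise it is read from
// the defining op's `_global_shape` attribute.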
StatusOr<llvm::ArrayRef<int64_t>> ExtractGlobalOutputShape(
    mlir::OpResult result_value) {
  auto op = result_value.getOwner();
  const int output_index = result_value.getResultNumber();

  if (op->getOpResult(output_index).hasOneUse()) {
    auto user = op->getOpResult(output_index).getUses().begin().getUser();
    if (auto layout_op = mlir::dyn_cast<mlir::TF::DTensorLayout>(user)) {
      auto global_shape = layout_op.global_shape();
      if (!global_shape)
        return errors::Internal("global_shape does not have static rank");
      return *global_shape;
    }
  }

  auto global_shape_attr = op->getAttrOfType<mlir::ArrayAttr>(kGlobalShape);
  if (!global_shape_attr)
    return errors::InvalidArgument(
        "`_global_shape` attribute of operation not found.");

  const int num_results = op->getNumResults();
  assert(global_shape_attr.size() == num_results);

  if (output_index >= op->getNumResults())
    return errors::InvalidArgument(
        llvm::formatv("Requested global shape of {0} output but op has only "
                      "{1} return values.",
                      output_index, num_results)
            .str());

  auto shape_attr = global_shape_attr[output_index];
  return shape_attr.cast<mlir::TF::ShapeAttr>().getShape();
}

namespace {

// Extracts attributes from an MLIR operation, including derived attributes,
// into one NamedAttrList.
mlir::NamedAttrList GetAllAttributesFromOperation(mlir::Operation* op) {
  mlir::NamedAttrList attr_list;
  attr_list.append(op->getAttrDictionary().getValue());

  if (auto derived = llvm::dyn_cast<mlir::DerivedAttributeOpInterface>(op)) {
    auto materialized = derived.materializeDerivedAttributes();
    attr_list.append(materialized.getValue());
  }

  return attr_list;
}

// Infers the output shape of `op` given its local operand shapes. For shape
// inference functions that require an input operation to be a constant, if
// the input operation is a `DTensorLayout` op, its input is used instead so
// that constant matching works correctly.
mlir::LogicalResult InferShapeOfTFOpWithCustomOperandConstantFn(
    llvm::Optional<mlir::Location> location, mlir::Operation* op,
    int64_t graph_version,
    llvm::SmallVectorImpl<mlir::ShapedTypeComponents>& inferred_return_shapes) {
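  // If the op implements InferTypeOpInterface, infer full return types and
  // convert any shaped results into ShapedTypeComponents.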
  if (auto type_op = llvm::dyn_cast<mlir::InferTypeOpInterface>(op)) {
    auto attributes = GetAllAttributesFromOperation(op);
    llvm::SmallVector<mlir::Type, 4> inferred_return_types;
    auto result = type_op.inferReturnTypes(
        op->getContext(), location, op->getOperands(),
        mlir::DictionaryAttr::get(op->getContext(), attributes),
        op->getRegions(), inferred_return_types);
    if (failed(result)) return mlir::failure();

    inferred_return_shapes.resize(inferred_return_types.size());
    for (const auto& inferred_return_type :
         llvm::enumerate(inferred_return_types)) {
      if (auto shaped_type =
              inferred_return_type.value().dyn_cast<mlir::ShapedType>()) {
        if (shaped_type.hasRank()) {
          inferred_return_shapes[inferred_return_type.index()] =
              mlir::ShapedTypeComponents(shaped_type.getShape(),
                                         shaped_type.getElementType());
        } else {
          inferred_return_shapes[inferred_return_type.index()] =
              mlir::ShapedTypeComponents(shaped_type.getElementType());
        }
      }
    }

    return mlir::success();
  }

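  // Otherwise, fall back to InferShapedTypeOpInterface when available.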
  if (auto shape_type_op =
          llvm::dyn_cast<mlir::InferShapedTypeOpInterface>(op)) {
    auto attributes = GetAllAttributesFromOperation(op);
    return shape_type_op.inferReturnTypeComponents(
        op->getContext(), location, op->getOperands(),
        mlir::DictionaryAttr::get(op->getContext(), attributes),
        op->getRegions(), inferred_return_shapes);
  }

  // If `operand` comes from a DTensorLayout op, use the input value of the
  // DTensorLayout op instead.
  auto operand_as_constant_fn = [](mlir::Value operand) -> mlir::Attribute {
    while (auto input_op = llvm::dyn_cast_or_null<mlir::TF::DTensorLayout>(
               operand.getDefiningOp())) {
      operand = input_op.input();
    }

    mlir::Attribute attr;
    if (matchPattern(operand, m_Constant(&attr))) return attr;
    return nullptr;
  };

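  // Interprets a rank-1, statically shaped op result as a shape value:
  // dimensions are taken from its constant elements when available and are
  // unknown otherwise.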
  auto op_result_as_shape_fn =
      [](shape_inference::InferenceContext& ic,
         mlir::OpResult op_result) -> shape_inference::ShapeHandle {
    auto rt = op_result.getType().dyn_cast<mlir::RankedTensorType>();
    if (!rt || rt.getRank() != 1 || !rt.hasStaticShape()) return {};

    std::vector<shape_inference::DimensionHandle> dims(rt.getDimSize(0),
                                                       ic.UnknownDim());
    mlir::Attribute attr;
    if (matchPattern(op_result, m_Constant(&attr))) {
      auto elements = attr.dyn_cast<mlir::DenseIntElementsAttr>();
      if (elements)
        for (const auto& element :
             llvm::enumerate(elements.getValues<llvm::APInt>()))
          dims[element.index()] = ic.MakeDim(element.value().getSExtValue());
    }
    return ic.MakeShape(dims);
  };

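  // No per-result element type override is supplied; the lambda always
  // returns a null Type.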
  auto result_element_type_fn = [](int) -> mlir::Type { return nullptr; };

  return mlir::TF::InferReturnTypeComponentsForTFOp(
      location, op, graph_version, operand_as_constant_fn,
      op_result_as_shape_fn, result_element_type_fn, inferred_return_shapes);
}

}  // namespace

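// Runs TF shape inference on `op` using its current (local) operand shapes and
// rewrites the op's result types in place with the inferred shapes. Returns
// `op`.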
mlir::Operation* InferSPMDExpandedLocalShape(mlir::Operation* op) {
  llvm::SmallVector<mlir::ShapedTypeComponents, 4> inferred_return_types;
  (void)InferShapeOfTFOpWithCustomOperandConstantFn(
      op->getLoc(), op, TF_GRAPH_DEF_VERSION, inferred_return_types);
  assert(inferred_return_types.size() == op->getNumResults());

  for (auto it : llvm::zip(inferred_return_types, op->getOpResults())) {
    const auto& return_type = std::get<0>(it);
    auto& op_result = std::get<1>(it);
    const auto element_type =
        op_result.getType().cast<mlir::TensorType>().getElementType();

    if (return_type.hasRank()) {
      op_result.setType(
          mlir::RankedTensorType::get(return_type.getDims(), element_type));
    } else {
      op_result.setType(mlir::UnrankedTensorType::get(element_type));
    }
  }

  return op;
}

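// Returns the shape of `value` if its type (or the subtype it points to) is a
// RankedTensorType. When `fail_on_dynamic` is true, the shape must also be
// static.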
StatusOr<llvm::ArrayRef<int64_t>> GetShapeOfValue(const mlir::Value& value,
                                                  bool fail_on_dynamic) {
  // Getting the subtype (or the type itself) supports extracting the
  // underlying shape that variant or resource tensors point to.
  mlir::Type type = GetSubtypeOrSelf(value);
  if (auto ranked_type = type.dyn_cast<mlir::RankedTensorType>()) {
    if (ranked_type.hasStaticShape() || !fail_on_dynamic)
      return ranked_type.getShape();
    else
      return errors::InvalidArgument("value shape is not static");
  }
  return errors::InvalidArgument("value type is not a RankedTensorType");
}

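// Returns the global shape recorded on a DTensorLayout op that either produces
// `value` or is its sole consumer.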
StatusOr<llvm::ArrayRef<int64_t>> GetGlobalShapeOfValueFromDTensorLayout(
    const mlir::Value& value) {
  if (value.isa<mlir::OpResult>() &&
      mlir::isa<mlir::TF::DTensorLayout>(value.getDefiningOp())) {
    auto layout_op = mlir::cast<mlir::TF::DTensorLayout>(value.getDefiningOp());
    if (layout_op.global_shape()) return layout_op.global_shape().getValue();
  } else if (value.hasOneUse() &&
             mlir::isa<mlir::TF::DTensorLayout>(*value.getUsers().begin())) {
    auto layout_op =
        mlir::cast<mlir::TF::DTensorLayout>(*value.getUsers().begin());
    if (layout_op.global_shape()) return layout_op.global_shape().getValue();
  }
  return errors::InvalidArgument(
      "consumer or producer of value is not a DTensorLayout");
}

}  // namespace dtensor
}  // namespace tensorflow