xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/layout_parsing.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/dtensor/mlir/layout_parsing.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "absl/strings/str_cat.h"
22 #include "absl/types/optional.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
26 #include "mlir/IR/Attributes.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
28 #include "mlir/IR/Operation.h"  // from @llvm-project
29 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
31 #include "tensorflow/core/platform/errors.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/dtensor/cc/constants.h"
34 #include "tensorflow/dtensor/cc/tensor_layout.h"
35 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
36 #include "tensorflow/stream_executor/lib/statusor.h"
37 
38 namespace tensorflow {
39 namespace dtensor {
40 namespace {
41 
OpUsesV2LayoutAnnotation(mlir::Operation * op)42 bool OpUsesV2LayoutAnnotation(mlir::Operation* op) {
43   return !op->getUsers().empty() &&
44          llvm::all_of(op->getUsers(), [](mlir::Operation* user_op) {
45            return llvm::isa<mlir::TF::DTensorLayout>(user_op);
46          });
47 }
48 
49 }  // namespace
50 
ExtractSingleLayoutFromOp(mlir::Operation * op,std::string attr_name)51 StatusOr<absl::optional<Layout>> ExtractSingleLayoutFromOp(
52     mlir::Operation* op, std::string attr_name) {
53   absl::optional<Layout> out;
54 
55   // If v2 layout propagation algorithm is used, parse layout from DTensorLayout
56   // op.
57   if (OpUsesV2LayoutAnnotation(op)) {
58     // If DTensorLayout is used, then DTensorLayout op is the only consumer for
59     // the operation output value.
60     auto users = op->getUsers();
61     out.emplace(llvm::cast<mlir::TF::DTensorLayout>(*users.begin()).layout());
62   } else {
63     TF_ASSIGN_OR_RETURN(auto layouts, ExtractLayoutFromOp(op, attr_name));
64     if (layouts.empty()) return out;
65     if (layouts.size() != 1) {
66       return errors::Internal(
67           "Extracting single layout on Op that has multiple layout attached is "
68           "ambiguous. op : ",
69           op->getName().getStringRef().str());
70     }
71     out.swap(layouts[0]);
72   }
73   return out;
74 }
75 
ExtractSingleLayoutFromOp(mlir::Operation * op)76 StatusOr<absl::optional<Layout>> ExtractSingleLayoutFromOp(
77     mlir::Operation* op) {
78   return ExtractSingleLayoutFromOp(op, kLayoutAttr);
79 }
80 
ExtractRequiredSingleLayoutFromOp(mlir::Operation * op)81 StatusOr<Layout> ExtractRequiredSingleLayoutFromOp(mlir::Operation* op) {
82   TF_ASSIGN_OR_RETURN(absl::optional<Layout> layout,
83                       ExtractSingleLayoutFromOp(op));
84   if (!layout) return errors::Internal("expected layout missing");
85 
86   return *layout;
87 }
88 
ExtractLayoutFromOp(mlir::Operation * op,std::string attr_name)89 StatusOr<std::vector<absl::optional<Layout>>> ExtractLayoutFromOp(
90     mlir::Operation* op, std::string attr_name) {
91   std::vector<absl::optional<Layout>> outs;
92   outs.reserve(op->getNumResults());
93 
94   // If v2 layout propagation algorithm is used, parse layout from DTensorLayout
95   // op.
96   if (OpUsesV2LayoutAnnotation(op)) {
97     for (auto op_result : op->getOpResults()) {
98       outs.emplace_back(
99           llvm::cast<mlir::TF::DTensorLayout>(*op_result.getUsers().begin())
100               .layout());
101     }
102   } else {
103     auto serialized_layouts = op->getAttrOfType<mlir::ArrayAttr>(attr_name);
104     if (!serialized_layouts) return outs;
105 
106     for (auto const& attr : serialized_layouts) {
107       auto attr_str = attr.cast<mlir::StringAttr>().getValue().str();
108       if (!attr_str.empty()) {
109         TF_ASSIGN_OR_RETURN(auto layout, Layout::FromString(attr_str));
110         outs.emplace_back(std::move(layout));
111       } else {
112         outs.emplace_back(absl::nullopt);
113       }
114     }
115   }
116   return outs;
117 }
118 
ExtractLayoutFromOp(mlir::Operation * op)119 StatusOr<std::vector<absl::optional<Layout>>> ExtractLayoutFromOp(
120     mlir::Operation* op) {
121   return ExtractLayoutFromOp(op, kLayoutAttr);
122 }
123 
ExtractRequiredLayoutFromOp(mlir::Operation * op)124 StatusOr<std::vector<Layout>> ExtractRequiredLayoutFromOp(mlir::Operation* op) {
125   TF_ASSIGN_OR_RETURN(std::vector<absl::optional<Layout>> optional_layouts,
126                       ExtractLayoutFromOp(op));
127   std::vector<Layout> layouts;
128   for (const absl::optional<Layout>& layout : optional_layouts) {
129     if (!layout) return errors::Internal("expected layout missing");
130     layouts.emplace_back(*layout);
131   }
132 
133   return layouts;
134 }
135 
ExtractDeviceMeshEnclosingCluster(mlir::Operation * op)136 StatusOr<Mesh> ExtractDeviceMeshEnclosingCluster(mlir::Operation* op) {
137   auto enclosing_cluster = op->getParentOfType<mlir::tf_device::ClusterOp>();
138   if (!enclosing_cluster)
139     return errors::InvalidArgument("op is not inside a device mesh cluster.");
140 
141   TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshFromOp(enclosing_cluster));
142   if (!mesh)
143     return errors::InvalidArgument(
144         "op's enclosing device cluster does not have mesh defined.");
145 
146   return *mesh;
147 }
148 
ExtractDeviceMeshFromOp(mlir::Operation * op)149 StatusOr<absl::optional<Mesh>> ExtractDeviceMeshFromOp(mlir::Operation* op) {
150   absl::optional<Mesh> extracted_mesh;
151   if (op == nullptr) return extracted_mesh;
152 
153   auto mesh_str_attr = op->getAttrOfType<mlir::StringAttr>(kMeshAttr);
154   if (!mesh_str_attr) return extracted_mesh;
155 
156   TF_ASSIGN_OR_RETURN(Mesh mesh,
157                       Mesh::FromString(mesh_str_attr.getValue().str()));
158 
159   extracted_mesh.emplace(std::move(mesh));
160   return extracted_mesh;
161 }
162 
ExtractLayoutFromOperand(mlir::Value operand)163 StatusOr<absl::optional<Layout>> ExtractLayoutFromOperand(mlir::Value operand) {
164   if (auto op_result = operand.dyn_cast<mlir::OpResult>()) {
165     mlir::Operation* op = op_result.getDefiningOp();
166     absl::optional<Layout> out;
167     if (auto layout_op = llvm::dyn_cast<mlir::TF::DTensorLayout>(op)) {
168       out.emplace(layout_op.layout());
169     } else {
170       const int result_number = op_result.getResultNumber();
171       TF_ASSIGN_OR_RETURN(auto layouts, ExtractLayoutFromOp(op, kLayoutAttr));
172 
173       if (layouts.empty()) return out;
174 
175       if (result_number >= layouts.size()) {
176         return errors::Internal(
177             "Expect to extract the ", result_number,
178             "-th output's layout, but "
179             "only see ",
180             layouts.size(), " outputs: ", op->getName().getStringRef().str());
181       }
182       out.swap(layouts[result_number]);
183     }
184     return out;
185   }
186 
187   auto block_arg = operand.dyn_cast<mlir::BlockArgument>();
188   if (!block_arg)
189     return errors::Internal(
190         "Operand is not either a OpResult or a BlockArgument. This should not "
191         "happen.");
192   auto func_op = mlir::dyn_cast_or_null<mlir::func::FuncOp>(
193       block_arg.getOwner()->getParentOp());
194   if (!func_op) {
195     return errors::InvalidArgument("op must be enclosed by a function");
196   }
197 
198   absl::optional<Layout> extracted_layout;
199   auto layout_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
200       block_arg.getArgNumber(), kCustomDeviceAttr);
201   if (!layout_attr) return extracted_layout;
202 
203   TF_ASSIGN_OR_RETURN(auto layout,
204                       Layout::FromString(layout_attr.getValue().str()));
205   extracted_layout.emplace(std::move(layout));
206   return extracted_layout;
207 }
208 
ExtractRequiredLayoutFromOperand(mlir::Value operand)209 StatusOr<Layout> ExtractRequiredLayoutFromOperand(mlir::Value operand) {
210   TF_ASSIGN_OR_RETURN(absl::optional<Layout> layout,
211                       ExtractLayoutFromOperand(operand));
212   if (!layout) return errors::Internal("expected layout missing");
213 
214   return *layout;
215 }
216 
ExtractRequiredLayoutFromOperands(mlir::Operation * op)217 StatusOr<std::vector<Layout>> ExtractRequiredLayoutFromOperands(
218     mlir::Operation* op) {
219   std::vector<Layout> layouts;
220   for (const auto& operand : op->getOpOperands()) {
221     TF_ASSIGN_OR_RETURN(auto operand_layout,
222                         ExtractRequiredLayoutFromOperand(operand.get()));
223     layouts.emplace_back(operand_layout);
224   }
225   return layouts;
226 }
227 
SetLayoutOnOp(mlir::Operation * op,mlir::OpBuilder builder,absl::Span<const absl::optional<Layout>> layouts)228 void SetLayoutOnOp(mlir::Operation* op, mlir::OpBuilder builder,
229                    absl::Span<const absl::optional<Layout>> layouts) {
230   llvm::SmallVector<std::string, 8> serialized_layouts;
231   for (auto const& layout : layouts) {
232     serialized_layouts.emplace_back(layout.has_value() ? layout->ToString()
233                                                        : "");
234   }
235   op->setAttr(kLayoutAttr,
236               builder.getStrArrayAttr(llvm::SmallVector<llvm::StringRef, 8>(
237                   serialized_layouts.begin(), serialized_layouts.end())));
238 }
239 
SetLayoutOnOp(mlir::Operation * op,absl::Span<const absl::optional<Layout>> layouts)240 void SetLayoutOnOp(mlir::Operation* op,
241                    absl::Span<const absl::optional<Layout>> layouts) {
242   SetLayoutOnOp(op, mlir::OpBuilder(op), layouts);
243 }
244 
SetSingleLayoutOnOp(mlir::Operation * op,const Layout & layout)245 void SetSingleLayoutOnOp(mlir::Operation* op, const Layout& layout) {
246   SetLayoutOnOp(op, mlir::OpBuilder(op), {absl::optional<Layout>(layout)});
247 }
248 
ExtractLayoutFromFunctionReturnAttr(mlir::func::ReturnOp return_op,const int return_index)249 StatusOr<absl::optional<Layout>> ExtractLayoutFromFunctionReturnAttr(
250     mlir::func::ReturnOp return_op, const int return_index) {
251   absl::optional<Layout> layout;
252   // If value feeds into func op return op, then check to see if layout
253   // attribute is set for the return value.
254   auto function = return_op->getParentOfType<mlir::func::FuncOp>();
255   auto layout_attr_from_func_result =
256       function.getResultAttrOfType<mlir::StringAttr>(return_index,
257                                                      kCustomDefaultLayoutAttr);
258   if (!layout_attr_from_func_result) return layout;
259 
260   const std::string layout_string =
261       layout_attr_from_func_result.getValue().str();
262   auto result_layout_or_status = Layout::FromString(layout_string);
263   if (!result_layout_or_status.ok())
264     return errors::InvalidArgument(
265         llvm::formatv("Malformed default return layout received. {0} Received "
266                       "layout : {1}",
267                       result_layout_or_status.status().error_message(),
268                       layout_string)
269             .str());
270 
271   layout.emplace(result_layout_or_status.ValueOrDie());
272   return layout;
273 }
274 
275 }  // namespace dtensor
276 }  // namespace tensorflow
277