/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/expansions/resource_spmd_expander.h"

#include <algorithm>
#include <string>

#include "absl/strings/str_join.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {
namespace {

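// Returns the attribute list `attr_name` of `op` as a vector of AttrType
// elements, or an empty vector if the attribute is not present.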
template <typename AttrType>
std::vector<AttrType> CreateOrGetMutableAttributeList(
    mlir::tf_device::ClusterOp op, std::string attr_name) {
  auto array_attribute = op->getAttrOfType<mlir::ArrayAttr>(attr_name);

  std::vector<AttrType> output;
  if (array_attribute) {
    // Copy the existing entries so the caller can extend the list.
    for (auto attr : array_attribute.getValue())
      output.emplace_back(attr.cast<AttrType>());
  }
  return output;
}

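// Records `layout_string` as the layout for the resource argument at
// `resource_arg_index` of the enclosing cluster `op`, or verifies that it
// matches the layout previously recorded for that argument.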
Status ValidateAndAssignResourceInputLayout(mlir::tf_device::ClusterOp op,
                                            const std::string& layout_string,
                                            const int resource_arg_index,
                                            mlir::OpBuilder* builder) {
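  // Appends the new (index, layout) pair and writes both lists back onto the
  // cluster op as attributes.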
  const auto add_layout_as_attributes =
      [&](std::vector<mlir::StringRef> new_resource_layouts,
          std::vector<int> new_resource_indices, int resource_arg_index,
          std::string layout) {
        new_resource_layouts.emplace_back(layout);
        new_resource_indices.emplace_back(resource_arg_index);
        op->setAttr(kNewResourceArgLayouts,
                    builder->getStrArrayAttr(
                        llvm::ArrayRef<mlir::StringRef>(new_resource_layouts)));
        op->setAttr(kNewResourceLayoutIndices,
                    builder->getI32VectorAttr(new_resource_indices));
      };

  auto resource_input_layouts_attrs =
      CreateOrGetMutableAttributeList<mlir::StringAttr>(op,
                                                        kNewResourceArgLayouts);
  auto resource_input_indices_attrs =
      CreateOrGetMutableAttributeList<mlir::IntegerAttr>(
          op, kNewResourceLayoutIndices);
  std::vector<llvm::StringRef> mutable_input_layouts;
  std::vector<int> mutable_input_indices;
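  // Unpack the previously recorded (index, layout) attribute pairs.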
  for (auto layout_index_pair :
       llvm::zip(resource_input_indices_attrs, resource_input_layouts_attrs)) {
    mutable_input_indices.emplace_back(std::get<0>(layout_index_pair).getInt());
    mutable_input_layouts.emplace_back(
        std::get<1>(layout_index_pair).getValue());
  }

  if (!mutable_input_indices.empty()) {
    assert(mutable_input_indices.size() == mutable_input_layouts.size());

    auto it = std::find(mutable_input_indices.begin(),
                        mutable_input_indices.end(), resource_arg_index);

    if (it != mutable_input_indices.end()) {
      // Input layout for given resource was already inferred from previous
      // SPMD expansions. Check that layouts of resource are consistent.
      auto previous_layout = mutable_input_layouts[std::distance(
          mutable_input_indices.begin(), it)];

      // TODO(hongjunchoi): Implement relayout logic for resource ops.
      if (layout_string != previous_layout.str())
        return errors::InvalidArgument(
            "Trying to assign a variable to a resource with a different "
            "layout.");
    } else {
      add_layout_as_attributes(mutable_input_layouts, mutable_input_indices,
                               resource_arg_index, layout_string);
    }
  } else {
    add_layout_as_attributes(mutable_input_layouts, mutable_input_indices,
                             resource_arg_index, layout_string);
  }
  return OkStatus();
}

}  // namespace

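// SPMD expansion for resource ops: variable handle, initialization-check, and
// destroy ops only need local shape inference; reads may need an all-scatter
// to produce the requested output layout; assignments may need to relayout
// the assigned value to match the resource layout.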
StatusOr<mlir::Operation*> ResourceSPMDExpander::ExpandOp(mlir::Operation* op) {
  // These ops need no special handling.
  if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::DestroyResourceOp,
                mlir::TF::VarIsInitializedOp>(op))
    return InferSPMDExpandedLocalShape(op);

  mlir::OpBuilder builder(op);

  // Output of read variable may need to be sliced, so it needs to be treated
  // specially.
  if (llvm::isa<mlir::TF::ReadVariableOp>(op)) {
    builder.setInsertionPointAfter(op);
    TF_ASSIGN_OR_RETURN(auto output_layout, ExtractSingleLayoutFromOp(op));
    TF_ASSIGN_OR_RETURN(auto input_layout,
                        ExtractLayoutFromOperand(op->getOperand(0)));
    if (!output_layout)
      TF_RETURN_WITH_CONTEXT(errors::Internal("output layout is missing"));
    if (!input_layout)
      TF_RETURN_WITH_CONTEXT(errors::Internal("input layout is missing"));
    InferSPMDExpandedLocalShape(op);
    llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;
    TF_ASSIGN_OR_RETURN(
        auto final_output,
        EmitAllScatter(builder, op->getOpResult(0), input_layout.value(),
                       output_layout.value(), &newly_created_ops));
    op->getOpResult(0).replaceAllUsesExcept(final_output, newly_created_ops);
    return final_output.getDefiningOp();
  }

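  // Any remaining resource op must be one of the variable assignment ops.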
  if (!llvm::isa<mlir::TF::AssignVariableOp, mlir::TF::AssignAddVariableOp,
                 mlir::TF::AssignSubVariableOp>(op))
    TF_RETURN_WITH_CONTEXT(errors::Internal("unsupported resource op"));

  TF_ASSIGN_OR_RETURN(absl::optional<Layout> output_layout,
                      ExtractSingleLayoutFromOp(op));
  TF_ASSIGN_OR_RETURN(absl::optional<Layout> resource_layout,
                      ExtractLayoutFromOperand(op->getOperand(0)));
  TF_ASSIGN_OR_RETURN(absl::optional<Layout> value_layout,
                      ExtractLayoutFromOperand(op->getOperand(1)));

  // For assignment operations, the layout of the resource (first operand),
  // when not present, is inferred from the layout of the input value (second
  // operand). We attach the inferred layout to the resource.
  // Note that when input_resource_value.getDefiningOp() exists, it is a
  // DTensorLayout op, which means the corresponding block argument already
  // has a layout set.
  // If the resource is specified in the graph as an op (e.g. VarHandleOp), we
  // attach the layout directly. Otherwise, the resource is an argument to the
  // SPMD function, and we attach the layout to the appropriate argument.
  auto input_resource_value = op->getOpOperand(0).get();
  if (auto resource_producing_op = input_resource_value.getDefiningOp()) {
    if (!resource_layout)
      TF_RETURN_WITH_CONTEXT(errors::Internal("missing layout on resource"));
    if (!value_layout)
      TF_RETURN_WITH_CONTEXT(errors::Internal("missing layout on value"));
    if (resource_layout != value_layout) {
      TF_ASSIGN_OR_RETURN(auto new_value,
                          EmitRelayout(op->getOperand(1), value_layout.value(),
                                       resource_layout.value()));
      op->setOperand(1, new_value);
    }
  } else {
    if ((!resource_layout || resource_layout->IsEmpty()) && !value_layout)
      TF_RETURN_WITH_CONTEXT(errors::Internal(
          "at least one of resource or value layout must be set"));
    // This error should not happen: if resource_layout is set, then we expect
    // a DTensorLayout op between the resource tensor and this op, so we should
    // actually be in the if case rather than the else case.
    if (resource_layout && !resource_layout->IsEmpty() && value_layout &&
        resource_layout != value_layout)
      TF_RETURN_WITH_CONTEXT(errors::Internal(
          "if both resource and value layout are set they must be equal"));

    auto block_arg = input_resource_value.dyn_cast<mlir::BlockArgument>();
    auto enclosing_device_cluster =
        op->getParentOfType<mlir::tf_device::ClusterOp>();

    if (!enclosing_device_cluster)
      TF_RETURN_WITH_CONTEXT(
          errors::InvalidArgument("op must be enclosed by a cluster"));

    auto block_arg_index = block_arg.getArgNumber();

    // If a layout for the resource already exists, check that the layouts are
    // consistent. Otherwise, add the newly inferred layout of the resource
    // argument as attributes on the enclosing cluster op so it can be
    // propagated to the custom device.
    std::string layout_string;
    if (resource_layout && !resource_layout->IsEmpty())
      layout_string = resource_layout->ToString();
    else
      layout_string = value_layout->ToString();
    TF_RETURN_IF_ERROR(ValidateAndAssignResourceInputLayout(
        enclosing_device_cluster, layout_string, block_arg_index, &builder));
  }

  return InferSPMDExpandedLocalShape(op);
}

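// Forward layout propagation: derives output layouts for resource ops from
// the layouts of their inputs.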
StatusOr<llvm::DenseMap<int, Layout>>
ResourceSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  // VarHandle and VarIsInitialized have rank-0 outputs.
  if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::VarIsInitializedOp>(op))
    return llvm::DenseMap<int, Layout>({{0, Layout::Empty()}});

  // Handling of resource destruction is a no-op.
  if (llvm::isa<mlir::TF::DestroyResourceOp>(op))
    return llvm::DenseMap<int, Layout>();

  // Read variable ops have one input, so infer the output layout if the input
  // layout exists.
  if (llvm::isa<mlir::TF::ReadVariableOp>(op)) return input_layouts;

  // These ops do not have outputs, so do not infer any layout.
  if (llvm::isa<mlir::TF::AssignVariableOp, mlir::TF::AssignAddVariableOp,
                mlir::TF::AssignSubVariableOp>(op)) {
    return llvm::DenseMap<int, Layout>();
  }
  // Return an error if the op is not one of the resource ops handled above.
  return errors::InvalidArgument(
      llvm::formatv(
          "Found unexpected resource op {0} during layout propagation.",
          OpName(op))
          .str());
}

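// Backward layout propagation: derives layouts for the remaining operands of
// resource ops from layouts already assigned to their operands and outputs.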
StatusOr<llvm::DenseMap<int, Layout>>
ResourceSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
    const llvm::DenseMap<int, Layout>& output_layouts) {
  // For Assign* ops, propagate the resource tensor layout (operand 0) to the
  // assigned value tensor (operand 1) if the resource tensor layout exists.
  if (llvm::isa<mlir::TF::AssignVariableOp, mlir::TF::AssignAddVariableOp,
                mlir::TF::AssignSubVariableOp>(op)) {
    if (input_layouts.find(0) != input_layouts.end())
      return llvm::DenseMap<int, Layout>({{1, input_layouts.lookup(0)}});
    return llvm::DenseMap<int, Layout>();
  }
  // Handling of these ops is a no-op.
  if (llvm::isa<mlir::TF::DestroyResourceOp, mlir::TF::VarHandleOp,
                mlir::TF::VarIsInitializedOp, mlir::TF::ReadVariableOp>(op))
    return llvm::DenseMap<int, Layout>();

  // Return an error if the op is not one of the resource ops handled above.
  return errors::InvalidArgument(
      llvm::formatv(
          "Found unexpected resource op {0} during layout propagation.",
          OpName(op))
          .str());
}

}  // namespace dtensor
}  // namespace tensorflow