/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/expansions/resource_spmd_expander.h"

#include <algorithm>
#include <string>

#include "absl/strings/str_join.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {
namespace {

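// Returns the attribute list stored on `op` under `attr_name` as a vector of
// `AttrType` elements, or an empty vector if the attribute is not set.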
template <typename AttrType>
std::vector<AttrType> CreateOrGetMutableAttributeList(
    mlir::tf_device::ClusterOp op, std::string attr_name) {
  auto array_attribute = op->getAttrOfType<mlir::ArrayAttr>(attr_name);

  std::vector<AttrType> output;
  if (array_attribute) {
    auto attr_list = array_attribute.getValue().vec();
    for (auto attr : attr_list) output.emplace_back(attr.cast<AttrType>());
  }
  return output;
}

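// Attaches `layout_string` as the inferred layout of the resource argument at
// `resource_arg_index` to the enclosing cluster `op`, or returns an error if a
// conflicting layout has already been recorded for that argument.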
Status ValidateAndAssignResourceInputLayout(mlir::tf_device::ClusterOp op,
                                            const std::string& layout_string,
                                            const int resource_arg_index,
                                            mlir::OpBuilder* builder) {
  const auto add_layout_as_attributes =
      [&](std::vector<mlir::StringRef> new_resource_layouts,
          std::vector<int> new_resource_indices, int resource_arg_index,
          std::string layout) {
        new_resource_layouts.emplace_back(layout);
        new_resource_indices.emplace_back(resource_arg_index);
        op->setAttr(kNewResourceArgLayouts,
                    builder->getStrArrayAttr(
                        llvm::ArrayRef<mlir::StringRef>(new_resource_layouts)));
        op->setAttr(kNewResourceLayoutIndices,
                    builder->getI32VectorAttr(new_resource_indices));
      };

  auto resource_input_layouts_attrs =
      CreateOrGetMutableAttributeList<mlir::StringAttr>(op,
                                                        kNewResourceArgLayouts);
  auto resource_input_indices_attrs =
      CreateOrGetMutableAttributeList<mlir::IntegerAttr>(
          op, kNewResourceLayoutIndices);
  std::vector<llvm::StringRef> mutable_input_layouts;
  std::vector<int> mutable_input_indices;
  for (auto layout_index_pair :
       llvm::zip(resource_input_indices_attrs, resource_input_layouts_attrs)) {
    mutable_input_indices.emplace_back(std::get<0>(layout_index_pair).getInt());
    mutable_input_layouts.emplace_back(
        std::get<1>(layout_index_pair).getValue());
  }

  if (!mutable_input_indices.empty()) {
    assert(mutable_input_indices.size() == mutable_input_layouts.size());

    auto it = std::find(mutable_input_indices.begin(),
                        mutable_input_indices.end(), resource_arg_index);

    if (it != mutable_input_indices.end()) {
      // Input layout for given resource was already inferred from previous
      // SPMD expansions. Check that layouts of resource are consistent.
      auto previous_layout = mutable_input_layouts[std::distance(
          mutable_input_indices.begin(), it)];

      // TODO(hongjunchoi): Implement relayout logic for resource ops.
      if (layout_string != previous_layout.str())
        return errors::InvalidArgument(
            "Trying to assign a variable to a resource with a different "
            "layout.");
    } else {
      add_layout_as_attributes(mutable_input_layouts, mutable_input_indices,
                               resource_arg_index, layout_string);
    }
  } else {
    add_layout_as_attributes(mutable_input_layouts, mutable_input_indices,
                             resource_arg_index, layout_string);
  }
  return OkStatus();
}

}  // namespace

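// SPMD expansion for resource ops. Variable handle, destruction, and
// initialization-check ops only need local shape inference; the output of
// ReadVariableOp is all-scattered from the resource layout to the requested
// output layout; for Assign* ops the assigned value is relaid out to the
// resource layout, or, when the resource is a function argument, the inferred
// resource layout is recorded on the enclosing cluster.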
StatusOr<mlir::Operation*> ResourceSPMDExpander::ExpandOp(mlir::Operation* op) {
  // These ops need no special handling.
  if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::DestroyResourceOp,
                mlir::TF::VarIsInitializedOp>(op))
    return InferSPMDExpandedLocalShape(op);

  mlir::OpBuilder builder(op);

  // Output of read variable may need to be sliced, so it needs to be treated
  // specially.
  if (llvm::isa<mlir::TF::ReadVariableOp>(op)) {
    builder.setInsertionPointAfter(op);
    TF_ASSIGN_OR_RETURN(auto output_layout, ExtractSingleLayoutFromOp(op));
    TF_ASSIGN_OR_RETURN(auto input_layout,
                        ExtractLayoutFromOperand(op->getOperand(0)));
    if (!output_layout)
      TF_RETURN_WITH_CONTEXT(errors::Internal("output layout is missing"));
    if (!input_layout)
      TF_RETURN_WITH_CONTEXT(errors::Internal("input layout is missing"));
    InferSPMDExpandedLocalShape(op);
    llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;
    TF_ASSIGN_OR_RETURN(
        auto final_output,
        EmitAllScatter(builder, op->getOpResult(0), input_layout.value(),
                       output_layout.value(), &newly_created_ops));
    op->getOpResult(0).replaceAllUsesExcept(final_output, newly_created_ops);
    return final_output.getDefiningOp();
  }

  if (!llvm::isa<mlir::TF::AssignVariableOp, mlir::TF::AssignAddVariableOp,
                 mlir::TF::AssignSubVariableOp>(op))
    TF_RETURN_WITH_CONTEXT(errors::Internal("unsupported resource op"));

  TF_ASSIGN_OR_RETURN(absl::optional<Layout> output_layout,
                      ExtractSingleLayoutFromOp(op));
  TF_ASSIGN_OR_RETURN(absl::optional<Layout> resource_layout,
                      ExtractLayoutFromOperand(op->getOperand(0)));
  TF_ASSIGN_OR_RETURN(absl::optional<Layout> value_layout,
                      ExtractLayoutFromOperand(op->getOperand(1)));

  // For assignment operations, the layout for the resource (first operand),
  // when not present, is inferred from the layout of the input value (second
  // operand). We attach the inferred layout to the resource.
  // Note that in the case that input_resource_value.getDefiningOp() exists, it
  // is a DTensorLayout and this means that the corresponding block argument
  // already has a layout set.
  // If the resource is specified in the graph as an op (e.g. VarHandleOp), we
  // attach the layout directly. Otherwise, the resource is an argument to the
  // SPMD function, and we attach the layout to the appropriate argument.
  auto input_resource_value = op->getOpOperand(0).get();
  if (auto resource_producing_op = input_resource_value.getDefiningOp()) {
    if (!resource_layout)
      TF_RETURN_WITH_CONTEXT(errors::Internal("missing layout on resource"));
    if (!value_layout)
      TF_RETURN_WITH_CONTEXT(errors::Internal("missing layout on value"));
    if (resource_layout != value_layout) {
      TF_ASSIGN_OR_RETURN(auto new_value,
                          EmitRelayout(op->getOperand(1), value_layout.value(),
                                       resource_layout.value()));
      op->setOperand(1, new_value);
    }
  } else {
    if ((!resource_layout || resource_layout->IsEmpty()) && !value_layout)
      TF_RETURN_WITH_CONTEXT(errors::Internal(
          "at least one of resource or value layout must be set"));
    // This error should not happen: if resource_layout is set, then we expect
    // a DTensorLayout op between the resource tensor and this op, so we should
    // actually be in the if case rather than the else case.
    if (resource_layout && !resource_layout->IsEmpty() && value_layout &&
        resource_layout != value_layout)
      TF_RETURN_WITH_CONTEXT(errors::Internal(
          "if both resource and value layout are set they must be equal"));

    auto block_arg = input_resource_value.dyn_cast<mlir::BlockArgument>();
    auto enclosing_device_cluster =
        op->getParentOfType<mlir::tf_device::ClusterOp>();

    if (!enclosing_device_cluster)
      TF_RETURN_WITH_CONTEXT(
          errors::InvalidArgument("op must be enclosed by a cluster"));

    auto block_arg_index = block_arg.getArgNumber();

    // If the layout of the resource already exists, check that the layouts are
    // consistent. Otherwise, add the newly inferred layout of the resource
    // argument as attributes to the enclosing cluster op, to be propagated to
    // the custom device.
    std::string layout_string;
    if (resource_layout && !resource_layout->IsEmpty())
      layout_string = resource_layout->ToString();
    else
      layout_string = value_layout->ToString();
    TF_RETURN_IF_ERROR(ValidateAndAssignResourceInputLayout(
        enclosing_device_cluster, layout_string, block_arg_index, &builder));
  }

  return InferSPMDExpandedLocalShape(op);
}

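// Forward layout propagation for resource ops: computes the layouts of op
// outputs from the layouts of op inputs.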
StatusOr<llvm::DenseMap<int, Layout>>
ResourceSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  // VarHandle and VarIsInitialized have rank 0 outputs.
  if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::VarIsInitializedOp>(op))
    return llvm::DenseMap<int, Layout>({{0, Layout::Empty()}});

  // Handling of resource destruction is a no-op.
  if (llvm::isa<mlir::TF::DestroyResourceOp>(op))
    return llvm::DenseMap<int, Layout>();

  // Read variable ops have one input, so infer the output layout if the input
  // layout exists.
  if (llvm::isa<mlir::TF::ReadVariableOp>(op)) return input_layouts;

  // These ops do not have outputs, so do not infer any layout.
  if (llvm::isa<mlir::TF::AssignVariableOp, mlir::TF::AssignAddVariableOp,
                mlir::TF::AssignSubVariableOp>(op)) {
    return llvm::DenseMap<int, Layout>();
  }
  // Return an error if the op is not any of the ops above.
  return errors::InvalidArgument(
      llvm::formatv(
          "Found unexpected resource op {0} during layout propagation.",
          OpName(op))
          .str());
}

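// Backward layout propagation for resource ops: for Assign* ops the layout of
// the resource operand is propagated to the value operand; the other resource
// ops request no operand layouts.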
StatusOr<llvm::DenseMap<int, Layout>>
ResourceSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
    const llvm::DenseMap<int, Layout>& output_layouts) {
  // For Assign* ops, propagate the resource tensor layout to the value tensor
  // if the resource tensor layout exists.
  if (llvm::isa<mlir::TF::AssignVariableOp, mlir::TF::AssignAddVariableOp,
                mlir::TF::AssignSubVariableOp>(op)) {
    if (input_layouts.find(0) != input_layouts.end())
      return llvm::DenseMap<int, Layout>({{1, input_layouts.lookup(0)}});
    return llvm::DenseMap<int, Layout>();
  }
  // Handling of these ops is a no-op.
  if (llvm::isa<mlir::TF::DestroyResourceOp, mlir::TF::VarHandleOp,
                mlir::TF::VarIsInitializedOp, mlir::TF::ReadVariableOp>(op))
    return llvm::DenseMap<int, Layout>();

  // Return an error if the op is not any of the ops above.
  return errors::InvalidArgument(
      llvm::formatv(
          "Found unexpected resource op {0} during layout propagation.",
          OpName(op))
          .str());
}

}  // namespace dtensor
}  // namespace tensorflow