/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace TFTPU {

namespace {

constexpr char kDeviceAttr[] = "device";
constexpr char kDeviceCPU[] = "CPU";
constexpr char kFuncDeviceAttr[] = "tf.device";

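// A pass that allows TPU input layouts to be determined at run time rather
// than at compile time: tf.TPUGetLayout reads the layout chosen by the
// compiler from the compilation result, and tf.TPUCopyWithLayout copies the
// host input to the device with that layout.
//
// Illustrative sketch of the rewrite (op names follow this file; operand
// details and attributes are elided):
//
//   %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
//   %compile:2 = "tf._TPUCompileMlir"(...)
//   %layout = "tf.TPUGetLayout"(%compile#1) {index = 0, is_output = false}
//   %copy = "tf.TPUCopyWithLayout"(%input, %layout) {device = "/TPU:0"}
//   %execute = "tf.TPUExecute"(%copy, ..., %compile#1) {device = "/TPU:0"}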
struct TPUDynamicLayoutPass
    : public TF::PerFunctionAggregateAnalysisConsumerPass<
          TPUDynamicLayoutPass, TF::ResourceAliasAnalysis> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TPUDynamicLayoutPass)

  void runOnFunction(
      func::FuncOp func,
      const TF::ResourceAliasAnalysis::Info& resource_alias_analysis);

  StringRef getArgument() const final { return "tf-tpu-dynamic-layout-pass"; }

  StringRef getDescription() const final {
    return "Inserts TPU layout ops to determine layout at run time.";
  }
};

// Checks if the input producer op is supported in this transform. Right now,
// we only check whether it is a tf.IteratorGetNext whose resource input comes
// from a VarHandle on CPU or a function argument assigned to CPU.
bool IsSupportedInputOp(
    Operation* op,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  TF::IteratorGetNextOp iterator_op = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
  if (!iterator_op) return false;

  Value resource_iterator = iterator_op.iterator();

  if (resource_alias_analysis.IsUnknownResource(resource_iterator))
    return false;
  llvm::SmallSetVector<Value, 8> aliases =
      resource_alias_analysis.GetResourceAliases(resource_iterator);

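  // A value is a "generator" if it is a block argument or is produced by an
  // op with no operands and a single result (e.g. a VarHandleOp).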
  auto is_generator = [](Value val) {
    if (val.isa<BlockArgument>()) return true;
    Operation* definition = val.getDefiningOp();
    return definition->getNumOperands() == 0 &&
           definition->getNumResults() == 1;
  };

  // Check that all generator aliases (ops or function arguments) are on CPU.
  func::FuncOp func = iterator_op->getParentOfType<func::FuncOp>();
  return llvm::all_of(aliases, [&](Value alias) {
    // Ignore non-generator aliases.
    if (!is_generator(alias)) return true;

    StringAttr device;
    if (auto arg = alias.dyn_cast<BlockArgument>()) {
      device = func.getArgAttrOfType<mlir::StringAttr>(arg.getArgNumber(),
                                                       kFuncDeviceAttr);
    } else {
      device = alias.getDefiningOp()->getAttrOfType<StringAttr>(kDeviceAttr);
    }

    if (!device) return false;
    tensorflow::DeviceNameUtils::ParsedName parsed_device;
    if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
                                                    &parsed_device)) {
      return false;
    }
    return parsed_device.has_type && parsed_device.type == kDeviceCPU;
  });
}

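// Creates an OpBuilder whose insertion point is immediately after `op`.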
OpBuilder CreateBuilderAfterOp(Operation* op) {
  return OpBuilder(op->getBlock(), ++Block::iterator(op));
}

// Builds a TPUGetLayoutOp with the given compile op and input index.
TF::TPUGetLayoutOp BuildGetLayout(const int64_t execute_arg_index,
                                  Value compilation_key,
                                  tf_device::LaunchOp compile_launch,
                                  OpBuilder* builder) {
  return builder->create<TF::TPUGetLayoutOp>(
      compile_launch.getLoc(),
      llvm::ArrayRef<Type>{RankedTensorType::get({ShapedType::kDynamicSize},
                                                 builder->getIntegerType(64))},
      llvm::ArrayRef<Value>{compilation_key},
      llvm::ArrayRef<NamedAttribute>{
          builder->getNamedAttr("index",
                                builder->getI64IntegerAttr(execute_arg_index)),
          builder->getNamedAttr("is_output", builder->getBoolAttr(false))});
}

// Builds a TPUCopyWithLayoutOp with the given get_layout op and input.
TF::TPUCopyWithLayoutOp BuildCopyWithLayout(tf_device::LaunchOp execute_launch,
                                            tf_device::LaunchOp compile_launch,
                                            TF::TPUGetLayoutOp get_layout,
                                            Value input, OpBuilder* builder) {
  return builder->create<TF::TPUCopyWithLayoutOp>(
      execute_launch.getLoc(), llvm::ArrayRef<Type>{input.getType()},
      llvm::ArrayRef<Value>{input, get_layout.layout()});
}

149 // Performs transformation for a non-replicated input.
HandleInput(Value input,const int64_t execute_arg_index,TF::TPUExecuteOp execute,tf_device::LaunchOp execute_launch,tf_device::LaunchOp compile_launch)150 void HandleInput(Value input, const int64_t execute_arg_index,
151                  TF::TPUExecuteOp execute, tf_device::LaunchOp execute_launch,
152                  tf_device::LaunchOp compile_launch) {
153   OpBuilder builder = CreateBuilderAfterOp(compile_launch);
154   auto get_layout = BuildGetLayout(execute_arg_index, execute.key(),
155                                    compile_launch, &builder);
156   builder.setInsertionPoint(execute_launch);
157   auto copy_with_layout = BuildCopyWithLayout(execute_launch, compile_launch,
158                                               get_layout, input, &builder);
159   copy_with_layout->setAttr(kDeviceAttr, execute_launch.deviceAttr());
160   execute.setOperand(execute_arg_index, copy_with_layout);
161 }
162 
// Performs transformation for replicated inputs. Returns true if this is a
// supported case (and thus the transform happened).
bool HandleReplicatedInputs(
    const int64_t execute_arg_index, Value compilation_key,
    tf_device::LaunchOp execute_launch, tf_device::LaunchOp compile_launch,
    mlir::BlockArgument replicate_arg, tf_device::ReplicateOp replicate,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  // We need to know the devices to copy to.
  if (!replicate.devices()) return false;

  MutableArrayRef<OpOperand> inputs =
      replicate.GetOperandsForBlockArgument(replicate_arg);
  for (auto entry : llvm::enumerate(inputs)) {
    auto input_op = entry.value().get().getDefiningOp();
    if (!input_op || !IsSupportedInputOp(input_op, resource_alias_analysis))
      return false;
  }
  OpBuilder builder = CreateBuilderAfterOp(compile_launch);
  auto get_layout = BuildGetLayout(execute_arg_index, compilation_key,
                                   compile_launch, &builder);
  builder.setInsertionPoint(replicate);
  for (auto entry : llvm::enumerate(inputs)) {
    auto copy_with_layout =
        BuildCopyWithLayout(execute_launch, compile_launch, get_layout,
                            entry.value().get(), &builder);

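    // `devices` maps a logical device alias to an array of per-replica
    // physical devices; pick this replica's device for the copy.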
    auto device_list = replicate.devices()
                           .getValue()
                           .get(execute_launch.getDevice())
                           .cast<ArrayAttr>();
    copy_with_layout->setAttr(kDeviceAttr,
                              device_list.getValue()[entry.index()]);

    entry.value().set(copy_with_layout);
  }
  return true;
}

// Performs transformation on a compile op and its associated execute op(s).
// The compile should not have other uses.
void HandleCompileAndExecutes(
    tf_device::LaunchOp compile_launch,
    llvm::MutableArrayRef<tf_device::LaunchOp> execute_launches,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  auto compile =
      llvm::cast<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front());
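  // Parse the compile metadata to recover, for each execute, the mapping from
  // execute operand index to the argument index in the metadata proto.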
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  metadata.ParseFromString(compile.metadata().str());
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> input_mappings =
      tensorflow::GetMetadataArgumentMapping(metadata);

  bool metadata_updated = false;
  auto maybe_replicate =
      execute_launches.front()->getParentOfType<tf_device::ReplicateOp>();

  for (auto execute_and_input_mapping :
       llvm::zip(execute_launches, input_mappings)) {
    auto& execute_launch = std::get<0>(execute_and_input_mapping);
    auto execute =
        llvm::cast<TF::TPUExecuteOp>(execute_launch.GetBody().front());
    const auto& input_mapping = std::get<1>(execute_and_input_mapping);

    for (auto& input_and_idx : llvm::enumerate(execute.args())) {
      Value input = input_and_idx.value();
      const int64_t execute_arg_index = input_and_idx.index();
      if (auto block_arg = input.dyn_cast<BlockArgument>()) {
        // For a block argument, consider transforms only when it is a
        // replicated input (defining ops will be outside the replicate node).
        if (maybe_replicate != block_arg.getParentRegion()->getParentOp() ||
            !HandleReplicatedInputs(execute_arg_index, execute.key(),
                                    execute_launch, compile_launch, block_arg,
                                    maybe_replicate, resource_alias_analysis)) {
          continue;
        }
      } else {
        // For an op output, consider transforms only when 1) there is no
        // replication or 2) it is outside the replicate node that encloses
        // the execute node. (If the op is inside the replicate, it is
        // probably not on the host.)
        auto* input_op = input.getDefiningOp();
        if (maybe_replicate &&
            maybe_replicate.body().isAncestor(input_op->getParentRegion())) {
          continue;
        }
        if (!IsSupportedInputOp(input_op, resource_alias_analysis)) continue;
        HandleInput(input, execute_arg_index, execute, execute_launch,
                    compile_launch);
      }

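      // Let the compiler choose this argument's layout at run time by
      // marking it unrestricted in the compile metadata.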
      metadata.mutable_args(input_mapping[execute_arg_index])
          ->set_unrestricted_layout(true);
      metadata_updated = true;
    }
  }

  if (metadata_updated)
    compile->setAttr("metadata", StringAttr::get(compile.getContext(),
                                                 metadata.SerializeAsString()));
}

void TPUDynamicLayoutPass::runOnFunction(
    func::FuncOp func,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  func.walk([&](TF::_TPUCompileMlirOp compile) {
    // Detect tf._TPUCompileMlir -> tf.TPUExecute(s).
    auto compile_launch =
        llvm::dyn_cast<tf_device::LaunchOp>(compile->getParentOp());
    if (!compile_launch || !compile_launch.WrapsSingleOp()) return;

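    // The first result of the compile op is the compilation status; each
    // remaining result is a program that must feed exactly one TPUExecute
    // wrapped in its own launch.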
    llvm::SmallVector<tf_device::LaunchOp, 4> execute_launches;
    execute_launches.reserve(compile_launch.getNumResults() - 1);
    for (Value program_result : llvm::drop_begin(compile_launch.results(), 1)) {
      if (!program_result.hasOneUse()) return;
      Operation* user = *program_result.user_begin();
      auto execute = llvm::dyn_cast<TF::TPUExecuteOp>(user);
      if (!execute) return;
      auto execute_launch =
          llvm::dyn_cast<tf_device::LaunchOp>(execute->getParentOp());
      if (!execute_launch || !execute_launch.WrapsSingleOp()) return;
      execute_launches.push_back(execute_launch);
    }

    HandleCompileAndExecutes(compile_launch, execute_launches,
                             resource_alias_analysis);
  });
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass() {
  return std::make_unique<TPUDynamicLayoutPass>();
}

}  // namespace TFTPU
}  // namespace mlir