1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This pass lifts resource variable operations outside of device computation.
17
18 #include <cstddef>
19 #include <cstdint>
20
21 #include "llvm/ADT/BitVector.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/MapVector.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/Casting.h"
30 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
31 #include "mlir/IR/Attributes.h" // from @llvm-project
32 #include "mlir/IR/Block.h" // from @llvm-project
33 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
34 #include "mlir/IR/Builders.h" // from @llvm-project
35 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
36 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
37 #include "mlir/IR/Diagnostics.h" // from @llvm-project
38 #include "mlir/IR/Operation.h" // from @llvm-project
39 #include "mlir/IR/Region.h" // from @llvm-project
40 #include "mlir/IR/SymbolTable.h" // from @llvm-project
41 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
42 #include "mlir/IR/Types.h" // from @llvm-project
43 #include "mlir/IR/Value.h" // from @llvm-project
44 #include "mlir/IR/Verifier.h" // from @llvm-project
45 #include "mlir/IR/Visitors.h" // from @llvm-project
46 #include "mlir/Pass/Pass.h" // from @llvm-project
47 #include "mlir/Support/LLVM.h" // from @llvm-project
48 #include "mlir/Support/LogicalResult.h" // from @llvm-project
49 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
50 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
51 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
52 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
53 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
54 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
55 #include "tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h"
56 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes_detail.h"
57 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
58 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
59 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
60 #include "tensorflow/core/framework/tensor_shape.pb.h"
61
62 namespace mlir {
63
64 namespace {
65
// Name of the attribute that this file strips from the hoisted
// `TF::ReadVariableOp` / `TF::AssignVariableOp` ops it creates outside of the
// lifted op.
constexpr char kDeviceAttr[] = "device";
67
// Pass that lifts resource operations out of device computation. Only the
// `runOnOperation` override is declared here; its definition lives outside of
// this struct.
struct ResourceOpLiftingPass
    : public TFDevice::ResourceOpLiftingPassBase<ResourceOpLiftingPass> {
  void runOnOperation() override;
};
73
IsResource(Value value)74 bool IsResource(Value value) {
75 return getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>();
76 }
77
78 // Get the type of the data contained in a resource. Returns null if there is
79 // no single type in the resource.
GetResourceSubtype(Value value)80 Type GetResourceSubtype(Value value) {
81 auto resource_type =
82 getElementTypeOrSelf(value.getType()).dyn_cast<TF::ResourceType>();
83 auto subtypes = resource_type.getSubtypes();
84 if (subtypes.size() == 1) return subtypes[0];
85 return nullptr;
86 }
87
88 // Replaces all `tf.VarIsInitializedOp` in a block with a constant true.
89 // TODO(b/171039585): Replace this with proper analysis of
90 // `tf.VarIsInitializedOp` in regards to resource writes and control flow.
SetAllVarIsInitializedToTrue(Block * block)91 void SetAllVarIsInitializedToTrue(Block* block) {
92 auto builder = OpBuilder::atBlockBegin(block);
93 TF::ConstOp const_true = nullptr;
94 for (auto op :
95 llvm::make_early_inc_range(block->getOps<TF::VarIsInitializedOp>())) {
96 builder.setInsertionPoint(op);
97 if (!const_true)
98 const_true = builder.create<TF::ConstOp>(
99 op.getLoc(),
100 DenseIntElementsAttr::get(
101 RankedTensorType::get(/*shape=*/{}, builder.getI1Type()), true));
102
103 op.is_initialized().replaceAllUsesWith(const_true);
104 op.erase();
105 }
106 }
107
108 // Performs store-load forwarding. This effectively removes
109 // 1) Any resource loads after a store to that same resource is done
110 // 2) Any resource stores except the last one.
111 // TODO(ycao): Store-load forwarding implemented here is only correct when
112 // computation is purely sequential (no concurrency). Need to support concurrent
113 // computation as well.
ForwardStoreToLoad(Block * block)114 void ForwardStoreToLoad(Block* block) {
115 // resource_handle_to_last_store_op keeps track of the most recent (last)
116 // store to each resource. Non-existent entry indicates that a resource has
117 // not been stored to yet.
118 llvm::SmallDenseMap<Value, TF::AssignVariableOp>
119 resource_handle_to_last_store_op;
120
121 // Only iterate through ops directly in the block as we can't handle ops
122 // nested deeper in regions.
123 for (Operation& op : llvm::make_early_inc_range(*block)) {
124 if (auto read_variable_op = dyn_cast<TF::ReadVariableOp>(&op)) {
125 Value resource = read_variable_op.resource();
126 auto last_store = resource_handle_to_last_store_op[resource];
127 if (!last_store) continue;
128
129 // Use stored value in last_store to replace all uses of current resource
130 // load's result, then erase this resource load. Add an intermediate
131 // CastOp if the shape of types doesn't exactly match.
132 Type read_type = read_variable_op.value().getType();
133 if (read_type != last_store.value().getType()) {
134 OpBuilder builder(last_store);
135 builder.setInsertionPointAfter(last_store);
136 auto cast = builder.create<TF::CastOp>(
137 last_store.getLoc(), read_type, last_store.value(),
138 /*Truncate=*/builder.getBoolAttr(false));
139 read_variable_op.value().replaceAllUsesWith(cast);
140 } else {
141 read_variable_op.value().replaceAllUsesWith(last_store.value());
142 }
143
144 read_variable_op.erase();
145 continue;
146 }
147
148 if (auto assign_variable_op = dyn_cast<TF::AssignVariableOp>(&op)) {
149 Value resource = assign_variable_op.resource();
150 auto last_store = resource_handle_to_last_store_op[resource];
151 // Previous store ops to same resource can be erased.
152 if (last_store) last_store.erase();
153
154 resource_handle_to_last_store_op[resource] = assign_variable_op;
155 }
156 }
157 }
158
159 //===----------------------------------------------------------------------===//
160 // RegionResourceHoister
161 //===----------------------------------------------------------------------===//
162
163 // Helper class to hoist resource ops out of regions attached to an op.
164 class RegionResourceHoister {
165 public:
RegionResourceHoister(Operation * op)166 explicit RegionResourceHoister(Operation* op) : op_(op) {}
167
168 // Analyzes attached regions to record resources read and written.
169 LogicalResult Analyze();
170
171 // Returns all resources accessed by the regions attached the op.
GetResources()172 auto& GetResources() { return resources_; }
173
174 // Returns if the given value is a resource that needs lifting.
Contains(Value resource) const175 bool Contains(Value resource) const {
176 return resources_.find(resource) != resources_.end();
177 }
178
179 // Drops the given resource from lifting.
DropResource(Value resource)180 void DropResource(Value resource) {
181 resources_.erase(resource);
182 written_resources_.remove(resource);
183 }
184
185 // Replaces all resource loads in all regions attached to the op.
ReplaceResourceLoads(bool read_only)186 void ReplaceResourceLoads(bool read_only) {
187 llvm::for_each(op_->getRegions(), [&](Region& region) {
188 ReplaceResourceLoads(region, read_only);
189 });
190 }
191
192 static LogicalResult ReplaceOpWithNewOp(Operation* op);
193
194 private:
195 // Returns if any resources need lifting.
NeedsLifting() const196 bool NeedsLifting() const { return !resources_.empty(); }
197
198 // Returns the number of results generated by the lifted op.
GetLiftedNumResults() const199 int GetLiftedNumResults() const { return num_new_results_; }
200
201 // Generates hoisted reads for resources that need them before the op.
202 void GenerateHoistedReads();
203
204 // Replaces all resource loads in the given region with hoisted loads. If
205 // `read_only` is true, limit this replacement to read only resources.
206 void ReplaceResourceLoads(Region& region, bool read_only);
207
208 // Appends final values writte to resources to the region returns for the
209 // given set of regions.
210 void AppendResourceStoreValueToReturn(RegionRange regions);
211
212 // Performs the final replacement of the op.
213 void ReplaceOpWithNewOp();
214
215 // Returns is this resource was written to in any of the regions.
IsWritten(Value resource) const216 bool IsWritten(Value resource) const {
217 return written_resources_.contains(resource);
218 }
219
220 static LogicalResult HoistResourcesOutOfIfCaseCluster(Operation* op);
221 static LogicalResult HoistResourcesOutOfWhileRegion(TF::WhileRegionOp op);
222
223 Operation* op_;
224
225 // Per resource information about accesses to that resource.
226 struct ResourceInfo {
227 // Is this resource read in any of the regions?
228 bool is_read;
229 // Is this resource written in any of the regions?
230 bool is_written;
231 // Is this resource written in all of the regions?
232 bool is_written_all;
233 // The hoisted read used to replace region reads.
234 Value hoisted_read;
235 // the type of the data held by the resource.
236 Type data_type;
237 // For written resources, the result # of the lifted op which will hold the
238 // value of the resource. This result will be used to generates writes to
239 // the resource after the lifted op.
240 int result_index;
241 // Attributes on the read operation.
242 DictionaryAttr read_attrs;
243 // Attributes on the write operation.
244 DictionaryAttr write_attrs;
245
ResourceInfomlir::__anon60ae525e0111::RegionResourceHoister::ResourceInfo246 ResourceInfo()
247 : is_read(false),
248 is_written(false),
249 is_written_all(false),
250 hoisted_read(nullptr),
251 data_type(nullptr),
252 result_index(-1) {}
253
IsResultIndexAssignedmlir::__anon60ae525e0111::RegionResourceHoister::ResourceInfo254 bool IsResultIndexAssigned() { return result_index != -1; }
255
256 // Refine the resource type using the given type `type`.
RefineTypemlir::__anon60ae525e0111::RegionResourceHoister::ResourceInfo257 void RefineType(Type type) {
258 if (!data_type) {
259 data_type = type;
260 } else {
261 data_type = TF::GetCastCompatibleType(data_type, type,
262 /*may_ignore_ref_type_a=*/false);
263 assert(data_type != nullptr && "Resource used with incompatible types");
264 }
265 }
266 };
267 llvm::MapVector<Value, ResourceInfo> resources_;
268 llvm::SetVector<Value> written_resources_;
269 // number of new results after lifting.
270 int num_new_results_;
271 };
272
273 // Analyzes resources that are read or written within attached regions.
Analyze()274 LogicalResult RegionResourceHoister::Analyze() {
275 // Hoisting of child regions might have created opportunity for store-load
276 // forwarding.
277 for (Region& region : op_->getRegions()) {
278 ForwardStoreToLoad(®ion.front());
279 }
280
281 llvm::SetVector<Value> all_resources;
282 bool is_func = false;
283 // For functions, the resources to analyze are the function arguments.
284 // Otherwise, its the region captures.
285 if (func::FuncOp func = dyn_cast<func::FuncOp>(op_)) {
286 is_func = true;
287 Region& body = func.getBody();
288 for (BlockArgument arg : body.getArguments()) {
289 if (IsResource(arg)) all_resources.insert(arg);
290 }
291 } else {
292 getUsedValuesDefinedAbove(op_->getRegions(), all_resources);
293 all_resources.remove_if([](Value value) { return !IsResource(value); });
294 }
295
296 num_new_results_ = op_->getNumResults();
297
298 for (auto resource : all_resources) {
299 ResourceInfo info;
300 info.data_type = GetResourceSubtype(resource);
301 llvm::BitVector written_regions(op_->getNumRegions());
302 bool unsupported_use = false;
303 for (OpOperand& use : resource.getUses()) {
304 Operation* user = use.getOwner();
305 // If the user is not in one of the regions, we are not interested in it.
306 // Since all the sub-regions within this region (i.e., regions attached to
307 // op's in this region) have themselves gone through lifting, all resource
308 // users are expected to be operations in this region and not embedded
309 // within other sub-regions attached to op's in this region. So the check
310 // for whether a user is in one of the regions attached to this op is
311 // straightforward.
312 if (user->getParentRegion()->getParentOp() != op_) continue;
313
314 // For functions, if the resource is used as a return operand, use that
315 // as its result index.
316 if (is_func && isa<func::ReturnOp>(user)) {
317 assert(!info.IsResultIndexAssigned() &&
318 "Expect resource argument to returned no more than once");
319 info.result_index = use.getOperandNumber();
320 continue;
321 }
322
323 auto read = dyn_cast<TF::ReadVariableOp>(user);
324 auto write = dyn_cast<TF::AssignVariableOp>(user);
325 if (!read && !write) {
326 unsupported_use = true;
327 break;
328 }
329
330 if (read && !info.is_read) {
331 info.is_read = true;
332 info.RefineType(read.value().getType());
333 info.read_attrs = user->getAttrDictionary();
334 }
335
336 if (write) {
337 info.is_written = true;
338 info.RefineType(write.value().getType());
339 info.write_attrs = user->getAttrDictionary();
340 written_regions.set(user->getParentRegion()->getRegionNumber());
341 }
342 }
343
344 // If the resource is used in an op that we do not understand, skip
345 // lifting for that resource.
346 if (unsupported_use) continue;
347
348 info.is_written_all = written_regions.count() == op_->getNumRegions();
349
350 // If the resource is written in some but not all regions, we would need
351 // a read for the value before these regions. Note that this is applicable
352 // only to multi-region ops:
353 // If/Case: If not all regions write to the resource, post hoisting the read
354 // value need to be routed through all paths that don't write.
355 // While: since while condition cannot write, any resource written in the
356 // while body will need to be read as well in case the while body is never
357 // executed.
358 // Both cases are handled by the condition below.
359 if (info.is_written && !info.is_written_all) info.is_read = true;
360
361 // Allocate a result index for written resources that don't have one.
362 if (info.is_written) {
363 written_resources_.insert(resource);
364 if (!info.IsResultIndexAssigned()) info.result_index = num_new_results_++;
365 }
366
367 resources_.insert({resource, info});
368 }
369 return success();
370 }
371
372 // Generates hoisted reads for all resources that need them just before the op.
GenerateHoistedReads()373 void RegionResourceHoister::GenerateHoistedReads() {
374 OpBuilder builder(op_);
375 DictionaryAttr empty_attrs = builder.getDictionaryAttr({});
376 for (auto& resource_it : GetResources()) {
377 Value resource = resource_it.first;
378 auto& info = resource_it.second;
379
380 if (info.is_read) {
381 Operation* read = builder.create<TF::ReadVariableOp>(
382 op_->getLoc(), info.data_type, resource);
383 read->setAttrs(info.read_attrs ? info.read_attrs : empty_attrs);
384 read->removeAttr(kDeviceAttr);
385 info.hoisted_read = read->getResult(0);
386 }
387 }
388 }
389
390 // Replaces all resource reads with the hoisted read.
ReplaceResourceLoads(Region & region,bool read_only)391 void RegionResourceHoister::ReplaceResourceLoads(Region& region,
392 bool read_only) {
393 assert(llvm::hasSingleElement(region) && "Expected single block region");
394 // Only iterate through ops directly in the body as we can't handle
395 // ops nested deeper in regions.
396 auto all_reads = region.front().getOps<TF::ReadVariableOp>();
397 for (auto read_op : llvm::make_early_inc_range(all_reads)) {
398 Value resource = read_op.resource();
399 if (!Contains(resource)) continue;
400
401 ResourceInfo& info = resources_[resource];
402 // If replacing loads for read only resources, skip if the resource
403 // was written to.
404 if (read_only && info.is_written) continue;
405
406 read_op.replaceAllUsesWith(info.hoisted_read);
407 read_op.erase();
408 }
409 }
410
411 // For written resources, add its value at the end of each region to that
412 // regions return value. For a region, its value at the end may be a value
413 // written to that resource in that region, or its hoisted read value if the
414 // resource is not written in that region. The return value can be vended out
415 // either as an existing return value, or a newly allocated return value.
AppendResourceStoreValueToReturn(RegionRange regions)416 void RegionResourceHoister::AppendResourceStoreValueToReturn(
417 RegionRange regions) {
418 for (Region* region : regions) {
419 assert(llvm::hasSingleElement(*region) && "Expected single block region");
420 Block& front = region->front();
421 auto old_return = front.getTerminator();
422 assert(old_return->getNumOperands() == op_->getNumResults());
423 auto new_return_operands = llvm::to_vector<4>(old_return->getOperands());
424 new_return_operands.resize(num_new_results_);
425
426 // initialize return values for written resources to be the hoisted reads.
427 for (Value resource : written_resources_) {
428 const ResourceInfo& info = resources_[resource];
429 new_return_operands[info.result_index] = info.hoisted_read;
430 }
431
432 // Only iterate through ops directly in the body as op's embedded in child
433 // regions should have been lifted out.
434 auto assign_ops = front.getOps<TF::AssignVariableOp>();
435 for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) {
436 Value resource = assign_variable_op.resource();
437 if (!IsWritten(resource)) continue;
438
439 // TODO(ycao): Prevent same value from being returned multiple times.
440 // TODO(ycao): Do not return resource store value if it is defined outside
441 // of cluster. Both of these can be post-resource-op-lifting cleanup
442 // passes.
443 int result_index = resources_[resource].result_index;
444 new_return_operands[result_index] = assign_variable_op.value();
445 assign_variable_op.erase();
446 }
447 old_return->setOperands(new_return_operands);
448 }
449 }
450
451 // Replace the old op with a new op (with potentially additional results), and
452 // add stores to written resources after the new op.
ReplaceOpWithNewOp()453 void RegionResourceHoister::ReplaceOpWithNewOp() {
454 auto new_result_types = llvm::to_vector<4>(op_->getResultTypes());
455 int result_region = isa<TF::WhileRegionOp>(op_) ? 1 : 0;
456 Operation* terminator = op_->getRegion(result_region).front().getTerminator();
457 auto extra_result_types =
458 terminator->getOperands().drop_front(op_->getNumResults()).getTypes();
459 new_result_types.insert(new_result_types.end(), extra_result_types.begin(),
460 extra_result_types.end());
461 OpBuilder builder(op_);
462 // Clone this old operation but with new result types.
463 Operation* new_op = Operation::create(
464 op_->getLoc(), op_->getName(), new_result_types, op_->getOperands(),
465 op_->getAttrs(), op_->getSuccessors(), op_->getNumRegions());
466 builder.insert(new_op);
467
468 // Move regions to the new op.
469 for (auto it : llvm::zip(op_->getRegions(), new_op->getRegions())) {
470 Region& old_region = std::get<0>(it);
471 Region& new_region = std::get<1>(it);
472 new_region.takeBody(old_region);
473 }
474
475 // Insert stores to all written resources.
476 for (Value resource : written_resources_) {
477 ResourceInfo& info = resources_[resource];
478 Value value_to_write = new_op->getResult(info.result_index);
479 Operation* write = builder.create<TF::AssignVariableOp>(
480 op_->getLoc(), resource, value_to_write);
481 write->setAttrs(info.write_attrs);
482 write->removeAttr(kDeviceAttr);
483 }
484
485 // As a part of lifting, we either reuse an existing slot for resource type
486 // results or add a new slot. Resource type results should not have any uses
487 // to begin with. So we can safely replace each old op result with the
488 // corresponding new op result.
489 int old_num_results = op_->getNumResults();
490 op_->replaceAllUsesWith(new_op->getResults().take_front(old_num_results));
491 op_->erase();
492 op_ = nullptr;
493 }
494
495 // Lift resource load and stores out of regions attached to `op`, where op is
496 // an If/case/cluster op.
HoistResourcesOutOfIfCaseCluster(Operation * op)497 LogicalResult RegionResourceHoister::HoistResourcesOutOfIfCaseCluster(
498 Operation* op) {
499 RegionResourceHoister hoister(op);
500 if (failed(hoister.Analyze())) return failure();
501
502 // If there are no resource region captures, then nothing to do.
503 if (!hoister.NeedsLifting()) return success();
504
505 // Start the transformation. For each region, replace the resource read with
506 // the value read before the op.
507 hoister.GenerateHoistedReads();
508 hoister.ReplaceResourceLoads(/*read_only=*/false);
509 hoister.AppendResourceStoreValueToReturn(op->getRegions());
510 hoister.ReplaceOpWithNewOp();
511 return success();
512 }
513
514 // Lift resource loads and stores out of WhileRegion
HoistResourcesOutOfWhileRegion(TF::WhileRegionOp op)515 LogicalResult RegionResourceHoister::HoistResourcesOutOfWhileRegion(
516 TF::WhileRegionOp op) {
517 // For WhileRegion, post canonicalization all resource used within the
518 // body and condition regions are replaced with captured values, so we do not
519 // need to take into account the body and condition region arguments.
520 RegionResourceHoister hoister(op);
521
522 if (failed(hoister.Analyze())) return failure();
523
524 // If there are no resource region captures, then nothing to do.
525 if (!hoister.NeedsLifting()) return success();
526
527 // The resources captured for While loop fall into two categories:
528 // (a) read-only. These reads can be replaced by a hoisted read created
529 // before the WhileOp (similar to if and case).
530 // (b) written: since the value is written in the loop (which can only in
531 // loop body, all these will become loop variables. Since all resource
532 // variables are removed from the loop variabled during
533 // canonicalizationW, we need to create new operand/result slots. The
534 // input operands for these slots are the read values
535 // prior to the op, and all references to these are replaced by the
536 // corresponding slot argument. We need to generate writes following
537 // the while for these resources.
538 //
539 // Note that for WhileRegion ops, if a resource is written, it will be written
540 // only in the body and not the condition, so the hoister analysis will infer
541 // it as needing a read as well.
542
543 // Generate hoisted reads before the while.
544 hoister.GenerateHoistedReads();
545
546 // Replace just the read-only resources with the hoisted reads.
547 hoister.ReplaceResourceLoads(/*read_only=*/true);
548
549 // For written resources, add additional operands to the while op.
550 int num_old_results = op.getNumResults();
551 int num_new_results = hoister.GetLiftedNumResults();
552 int num_extra_results = num_new_results - num_old_results;
553
554 SmallVector<Type, 4> new_result_types;
555 SmallVector<Value, 4> new_while_operands;
556 new_result_types.resize(num_extra_results);
557 new_while_operands.resize(num_extra_results);
558
559 for (auto& it : hoister.GetResources()) {
560 if (!it.second.is_written) continue;
561 int index = it.second.result_index - num_old_results;
562 new_result_types[index] = it.second.data_type;
563 new_while_operands[index] = it.second.hoisted_read;
564 }
565 op.getOperation()->insertOperands(op.getNumOperands(), new_while_operands);
566
567 // Patch the cond and body regions to have additional arguments, and replace
568 // the remaining resource reads (which will be resource reads for written
569 // resources) with these arguments.
570 Location loc = op.getLoc();
571 for (Region* region : op.getRegions()) {
572 region->addArguments(new_result_types,
573 SmallVector<Location>(new_result_types.size(), loc));
574 // Point hoisted read for written resources to the region's arguments.
575 for (auto& it : hoister.GetResources()) {
576 if (!it.second.is_written) continue;
577 it.second.hoisted_read = region->getArgument(it.second.result_index);
578 }
579 hoister.ReplaceResourceLoads(*region, /*read_only=*/false);
580 }
581
582 // Add additional return values to body return. These correspond to values
583 // written to resources in the body region.
584 hoister.AppendResourceStoreValueToReturn(op.getRegions().drop_front());
585
586 // Finally, create a new while with additional return values.
587 hoister.ReplaceOpWithNewOp();
588 return success();
589 }
590
591 // Lift resources out of the regions attached to `op`
ReplaceOpWithNewOp(Operation * op)592 LogicalResult RegionResourceHoister::ReplaceOpWithNewOp(Operation* op) {
593 if (auto while_op = dyn_cast<TF::WhileRegionOp>(op))
594 return HoistResourcesOutOfWhileRegion(while_op);
595 return HoistResourcesOutOfIfCaseCluster(op);
596 }
597
598 // Holds information about a function's use of a resource argument.
599 struct ResourceArgUseInfo {
600 // Data type of the data contained in the resource.
601 Type data_type;
602 // Is the resource argument used in an assign op?
603 bool updated;
604 // Is the resource argument used in a read or assign op?
605 bool used;
606 };
607
608 // Finds the ResourceArgUseInfo for each resource argument. Forwarding to the
609 // output (i.e., the argument is an operand of the return op) is not considered
610 // as a use. This doesn't support nesting of ops, so before calling this, nested
611 // ops/functions need to be already resource-lifted.
FindResourceArgUseInfo(func::FuncOp func_op,llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> * result)612 LogicalResult FindResourceArgUseInfo(
613 func::FuncOp func_op,
614 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>* result) {
615 auto return_op = func_op.front().getTerminator();
616 for (auto arg : TF::filter_resources(func_op.getArguments())) {
617 ResourceArgUseInfo info;
618 info.used = false;
619 info.updated = false;
620 bool read_or_assigned = false;
621 bool used_in_unsupported_op = false;
622 for (auto user : arg.getUsers()) {
623 if (user == return_op) continue;
624 info.used = true;
625 if (auto read = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
626 read_or_assigned = true;
627 info.data_type = read.getType();
628 continue;
629 }
630
631 if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(user)) {
632 read_or_assigned = true;
633 info.updated = true;
634 info.data_type = assign.value().getType();
635 continue;
636 }
637
638 used_in_unsupported_op = true;
639 break;
640 }
641
642 // If the arg is used in an unsupported op, skip lifting it.
643 if (used_in_unsupported_op) continue;
644 (*result)[arg.getArgNumber()] = info;
645 }
646 return success();
647 }
648
649 // Merges two sets of resource arg use infos. An argument is considered used in
650 // the merged result as long as either set marks it as used. This is used to
651 // merge results from functions that have aliasing inputs, e.g., a while loop's
652 // body and condition. The sets of keys of the two maps must be the same.
MergeArgResourceUseInfo(const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & infos0,const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & infos1)653 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> MergeArgResourceUseInfo(
654 const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos0,
655 const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos1) {
656 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> result;
657 for (const auto& entry : infos0) {
658 auto info1_it = infos1.find(entry.getFirst());
659 // If the entry is missing in any input, we should not touch this entry.
660 if (info1_it == infos1.end()) continue;
661 auto& info = result[entry.getFirst()];
662 info = entry.getSecond();
663 if (info.updated) continue;
664 if (info1_it->getSecond().used) {
665 info.used = true;
666 info.updated = info1_it->getSecond().updated;
667 info.data_type = info1_it->getSecond().data_type;
668 }
669 }
670 return result;
671 }
672
673 // Removes the unused resource arguments, and the return values that forward the
674 // removed arguments. If old_to_new_arg_indices is provided, it will store the
675 // new argument index that corresponds to each original index (-1 means it is
676 // removed). If remaining_resource_data_types is provided, it will store the
677 // data types of the remaining resource arguments, where the indices are after
678 // removing unused ones.
RemoveUnusedResourceArgumentsAndForwardedRetvals(const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & infos,func::FuncOp func_op,llvm::SmallVector<int64_t,4> * old_to_new_arg_indices=nullptr,llvm::SmallDenseMap<int64_t,Type> * remaining_resource_data_types=nullptr)679 void RemoveUnusedResourceArgumentsAndForwardedRetvals(
680 const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos,
681 func::FuncOp func_op,
682 llvm::SmallVector<int64_t, 4>* old_to_new_arg_indices = nullptr,
683 llvm::SmallDenseMap<int64_t, Type>* remaining_resource_data_types =
684 nullptr) {
685 // Remove return values forwarded from unused arguments.
686 auto return_op = func_op.front().getTerminator();
687 auto old_return_vals = llvm::to_vector<8>(return_op->getOperands());
688 int64_t skipped_retvals = 0;
689 for (auto entry : llvm::enumerate(old_return_vals)) {
690 auto return_val = entry.value();
691 if (auto arg = return_val.dyn_cast<BlockArgument>()) {
692 auto it = infos.find(arg.getArgNumber());
693 if (it != infos.end() && !it->getSecond().used) {
694 return_op->eraseOperand(entry.index() - skipped_retvals++);
695 }
696 }
697 }
698 llvm::BitVector indices_to_erase(func_op.getNumArguments());
699 llvm::SmallVector<Type, 4> new_types;
700 int64_t skipped_args = 0;
701 for (auto arg : func_op.getArguments()) {
702 auto it = infos.find(arg.getArgNumber());
703 if (it != infos.end() && !it->getSecond().used) {
704 indices_to_erase.set(arg.getArgNumber());
705 skipped_args++;
706 if (old_to_new_arg_indices != nullptr) {
707 old_to_new_arg_indices->push_back(-1);
708 }
709 } else {
710 new_types.push_back(arg.getType());
711 if (old_to_new_arg_indices != nullptr) {
712 old_to_new_arg_indices->push_back(arg.getArgNumber() - skipped_args);
713 }
714 if (it != infos.end() && remaining_resource_data_types != nullptr) {
715 (*remaining_resource_data_types)[arg.getArgNumber() - skipped_args] =
716 it->second.data_type;
717 }
718 }
719 }
720 func_op.eraseArguments(indices_to_erase);
721 func_op.setType(
722 FunctionType::get(func_op.getContext(), new_types,
723 llvm::to_vector<4>(return_op->getOperandTypes())));
724 }
725
726 // Lifts reads/writes of resource arguments from func_op and changes its
727 // signature. resource_data_types is the (index, data type) pair for each
728 // resource argument. handle_updated_arg_value is a caller-provided function
729 // that handles the updated value for an resource argument.
LiftArgRetResourcesForFunction(func::FuncOp func_op,const llvm::SmallDenseMap<int64_t,Type> & resource_data_types,llvm::function_ref<void (int64_t,Value)> handle_updated_arg_value)730 LogicalResult LiftArgRetResourcesForFunction(
731 func::FuncOp func_op,
732 const llvm::SmallDenseMap<int64_t, Type>& resource_data_types,
733 llvm::function_ref<void(int64_t, Value)> handle_updated_arg_value) {
734 RegionResourceHoister hoister(func_op);
735 if (failed(hoister.Analyze())) return failure();
736
737 // Each of these resources could be read or written in the function. If its
738 // read, we need to replace the resource arg with a value arg to get the
739 // read value. If its written, we need to replace the write with an additional
740 // value to be written.
741
742 // Now create read values that will be used to replace each resource that
743 // is read in the function body. These read values are just the same argument
744 // with type replaced.
745 llvm::SmallVector<Value, 4> skipped_args;
746 for (auto& it : hoister.GetResources()) {
747 BlockArgument arg = it.first.dyn_cast<BlockArgument>();
748 assert(arg && "Expect resources for FuncOp to be its arguments");
749 auto type_iter = resource_data_types.find(arg.getArgNumber());
750 if (type_iter == resource_data_types.end()) {
751 // Skip lifting the resource if it's not present in the data type map.
752 // This indicates that the resource is not to be lifted because it is used
753 // in an unsupported op in some other function.
754 skipped_args.push_back(arg);
755 } else {
756 arg.setType(type_iter->second);
757 it.second.hoisted_read = arg;
758 }
759 }
760
761 // Drop all the args that have to be skipped.
762 for (Value arg : skipped_args) hoister.DropResource(arg);
763
764 hoister.ReplaceResourceLoads(/*read_only=*/false);
765
766 // For writes, invoke the callback and then erase the write.
767 auto assign_ops = func_op.front().getOps<TF::AssignVariableOp>();
768 for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) {
769 Value resource = assign_variable_op.resource();
770 if (!hoister.Contains(resource)) continue;
771
772 auto arg = resource.dyn_cast<BlockArgument>();
773 handle_updated_arg_value(arg.getArgNumber(), assign_variable_op.value());
774 assign_variable_op.erase();
775 }
776
777 func_op.setType(FunctionType::get(
778 func_op.getContext(), func_op.front().getArgumentTypes(),
779 func_op.front().getTerminator()->getOperandTypes()));
780
781 return success();
782 }
783
784 // Returns a vector filtered from range where the unused elements (specified by
785 // resource_arg_uses) are removed.
786 template <typename T, typename Range>
FilterRange(Range range,const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & resource_arg_uses)787 llvm::SmallVector<T, 4> FilterRange(
788 Range range,
789 const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& resource_arg_uses) {
790 llvm::SmallVector<T, 4> filtered;
791 for (auto entry : llvm::enumerate(range)) {
792 auto it = resource_arg_uses.find(entry.index());
793 if (it == resource_arg_uses.end() || it->getSecond().used)
794 filtered.push_back(entry.value());
795 }
796 return filtered;
797 }
798
799 // Changes the types of the control flow op (e.g., while, if) and adds loads and
800 // stores around it. arg_data_type_and_updated_output_index maps an operand (to
801 // be changed) index to its data type and the updated value index in the output
802 // (-1 means not updated.)
AddLoadsStoresOutsideControlFlowOp(Operation * caller,const llvm::SmallDenseMap<int64_t,std::pair<Type,int64_t>> & arg_data_type_and_updated_output_index)803 void AddLoadsStoresOutsideControlFlowOp(
804 Operation* caller,
805 const llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>&
806 arg_data_type_and_updated_output_index) {
807 OpBuilder builder(caller);
808 auto new_operands = llvm::to_vector<8>(caller->getOperands());
809 llvm::SmallVector<int64_t, 8> changed_indices;
810 // Find the operands to change, and create the loads.
811 for (auto& entry : arg_data_type_and_updated_output_index) {
812 int64_t index = entry.getFirst();
813 Type new_type = entry.getSecond().first;
814 int64_t updated_index = entry.getSecond().second;
815 auto operand = caller->getOperand(index);
816 builder.setInsertionPoint(caller);
817 new_operands[index] = builder.create<TF::ReadVariableOp>(
818 caller->getLoc(), ArrayRef<Type>{new_type}, ArrayRef<Value>{operand});
819 caller->setOperand(index, new_operands[index]);
820 if (updated_index < 0) continue;
821 builder.setInsertionPointAfter(caller);
822 builder.create<TF::AssignVariableOp>(
823 caller->getLoc(), ArrayRef<Type>{},
824 ArrayRef<Value>{operand, caller->getResult(updated_index)});
825 }
826 }
827
828 // Lifts loads/stores from while loop's body and cond functions.
HandleWhileLoop(TF::WhileOp while_op,func::FuncOp body,func::FuncOp cond)829 LogicalResult HandleWhileLoop(TF::WhileOp while_op, func::FuncOp body,
830 func::FuncOp cond) {
831 auto return_op = body.front().getTerminator();
832 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> body_use_info;
833 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> cond_use_info;
834 if (failed(FindResourceArgUseInfo(body, &body_use_info)) ||
835 failed(FindResourceArgUseInfo(cond, &cond_use_info))) {
836 return failure();
837 }
838 // A resource is considered used as long as it is used in either body or cond.
839 auto resource_arg_uses =
840 MergeArgResourceUseInfo(body_use_info, cond_use_info);
841 if (resource_arg_uses.empty()) return success();
842
843 // Remove unused resources in functions.
844 llvm::SmallVector<int64_t, 4> old_to_new_indices;
845 llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
846 RemoveUnusedResourceArgumentsAndForwardedRetvals(
847 resource_arg_uses, body, &old_to_new_indices,
848 &remaining_resource_data_types);
849 RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, cond);
850 (void)LiftArgRetResourcesForFunction(
851 body, remaining_resource_data_types,
852 [&](int64_t index, Value value) { return_op->setOperand(index, value); });
853 (void)LiftArgRetResourcesForFunction(cond, remaining_resource_data_types,
854 [&](int64_t index, Value value) {
855 // We already checked that cond should
856 // not have variable writes.
857 assert(false && "Should not happen");
858 });
859 // Recreate the while op.
860 OpBuilder builder(while_op);
861 // Now use the filtered original operands, which will be replaced by
862 // AddLoadsStoresOutsideControlFlowOp().
863 auto new_while = builder.create<TF::WhileOp>(
864 while_op.getLoc(), body.getFunctionType().getResults(),
865 FilterRange<Value, OperandRange>(while_op.getOperands(),
866 resource_arg_uses),
867 while_op->getAttrs());
868 // Prepare for AddLoadsStoresOutsideControlFlowOp().
869 llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
870 arg_data_type_and_updated_output_index;
871 for (const auto& entry : remaining_resource_data_types) {
872 int64_t update_index = return_op->getOperand(entry.getFirst()) ==
873 body.getArgument(entry.getFirst())
874 ? -1
875 : entry.getFirst();
876 arg_data_type_and_updated_output_index[entry.getFirst()] = {
877 entry.getSecond(), update_index};
878 }
879 AddLoadsStoresOutsideControlFlowOp(new_while,
880 arg_data_type_and_updated_output_index);
881 // Replace uses.
882 for (int64_t i = 0, end = old_to_new_indices.size(); i < end; ++i) {
883 if (old_to_new_indices[i] >= 0) {
884 while_op.getResult(i).replaceAllUsesWith(
885 new_while.getResult(old_to_new_indices[i]));
886 }
887 }
888 while_op.erase();
889 return success();
890 }
891
892 // Lifts loads/stores from an IfOp or CaseOp's branches.
893 template <class CaseOrIfOp>
HandleCaseOrIfOp(CaseOrIfOp op,ArrayRef<func::FuncOp> branches)894 LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef<func::FuncOp> branches) {
895 // For canonicalized If/Case, there should not be any resource outputs
896 int64_t non_resource_results = op.getNumResults();
897
898 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> resource_arg_uses;
899 if (failed(FindResourceArgUseInfo(branches.front(), &resource_arg_uses)))
900 return failure();
901
902 for (auto func : branches.drop_front()) {
903 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> branch_use_info;
904 if (failed(FindResourceArgUseInfo(func, &branch_use_info)))
905 return failure();
906 // A resource is considered used as long as it is used in either branch.
907 resource_arg_uses =
908 MergeArgResourceUseInfo(resource_arg_uses, branch_use_info);
909 }
910
911 if (resource_arg_uses.empty()) return success();
912 // Remove unused resources in functions.
913 llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
914 RemoveUnusedResourceArgumentsAndForwardedRetvals(
915 resource_arg_uses, branches.front(), /*old_to_new_arg_indices=*/nullptr,
916 &remaining_resource_data_types);
917 for (auto func : branches.drop_front())
918 RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, func);
919
920 // Forward resource inputs updated in any branch to the outputs of both
921 // branches. First prepare the mapping from arg to new update output.
922 llvm::SmallDenseMap<int64_t, int64_t> resource_arg_to_new_output;
923 {
924 int64_t removed_args = 0;
925 for (const auto& entry : resource_arg_uses) {
926 if (!entry.getSecond().used) {
927 removed_args++;
928 continue;
929 }
930 if (!entry.getSecond().updated) continue;
931 int64_t new_output_index =
932 non_resource_results + resource_arg_to_new_output.size();
933 resource_arg_to_new_output[entry.getFirst() - removed_args] =
934 new_output_index;
935 }
936 }
937
938 // Append resource updates to the return ops: now they are just forwarded
939 // input resources, but will be replaced by the data value in
940 // LiftArgRetResourcesForFunction().
941 for (auto branch : branches) {
942 auto new_retvals =
943 llvm::to_vector<4>(branch.front().getTerminator()->getOperands());
944 new_retvals.resize(new_retvals.size() + resource_arg_to_new_output.size());
945 for (const auto& entry : resource_arg_to_new_output) {
946 int64_t resource_arg_index = entry.getFirst();
947 int64_t output_index = entry.getSecond();
948 new_retvals[output_index] = branch.getArgument(resource_arg_index);
949 }
950 auto old_return = branch.front().getTerminator();
951 OpBuilder builder(old_return);
952 auto new_return =
953 builder.create<func::ReturnOp>(old_return->getLoc(), new_retvals);
954 old_return->erase();
955 (void)LiftArgRetResourcesForFunction(
956 branch, remaining_resource_data_types, [&](int64_t index, Value value) {
957 new_return.setOperand(resource_arg_to_new_output[index], value);
958 });
959 }
960
961 // Recreate the op without resource operands.
962 OpBuilder builder(op);
963 // Now use the filtered original operands, which will be replaced by
964 // AddLoadsStoresOutsideControlFlowOp().
965 auto new_operands =
966 FilterRange<Value, OperandRange>(op.input(), resource_arg_uses);
967 new_operands.insert(new_operands.begin(), op.getOperand(0));
968 func::FuncOp first_func = branches.front();
969 auto new_op = builder.create<CaseOrIfOp>(
970 op.getLoc(), first_func.getFunctionType().getResults(), new_operands,
971 op->getAttrs());
972 // Prepare for AddLoadsStoresOutsideControlFlowOp()
973 llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
974 arg_data_type_and_updated_output_index;
975 for (const auto& entry : remaining_resource_data_types) {
976 auto new_output_it = resource_arg_to_new_output.find(entry.getFirst());
977 int64_t update_index = new_output_it == resource_arg_to_new_output.end()
978 ? -1
979 : new_output_it->getSecond();
980 arg_data_type_and_updated_output_index[entry.getFirst() + 1] = {
981 entry.getSecond(), update_index};
982 }
983 AddLoadsStoresOutsideControlFlowOp(new_op,
984 arg_data_type_and_updated_output_index);
985 // Replace uses.
986 op.replaceAllUsesWith(new_op.getResults().take_front(op.getNumResults()));
987 op.erase();
988 return success();
989 }
990
// A resource-lifted function for (potentially multiple) PartitionedCallOps and
// information about the lifting changes. One instance is kept per callee so
// that every call site of the same callee can be updated consistently.
struct PartitionedCallLiftingInfo {
  // Function with resources lifted. Can be nullptr if nothing needs to change.
  func::FuncOp lifted_callee;
  // Mapping from the index of each old resource output to the index of the
  // old input (function argument) it aliases.
  llvm::SmallDenseMap<int64_t, int64_t> old_outputs_aliasing_old_inputs;
  // Mapping from old to new output indices in case any output is removed.
  // Removed (resource) outputs are recorded as -1.
  llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
  // ResourceArgUseInfo for each old resource argument.
  llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> use_info;
  // Input for AddLoadsStoresOutsideControlFlowOp(), see its comment.
  llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
      arg_data_type_and_updated_output_index;
};
1006
1007 // Lifts loads/stores from a PartitionedCallOp's callee function. If anything
1008 // needs to be changed, the original function will be preserved, and the lifting
1009 // happens on a clone, which will be stored in `result`.
HandlePartitionedCallOpCallee(func::FuncOp callee,PartitionedCallLiftingInfo * result)1010 LogicalResult HandlePartitionedCallOpCallee(
1011 func::FuncOp callee, PartitionedCallLiftingInfo* result) {
1012 // Sanity check: return of resources should be aliases of inputs. Such outputs
1013 // will be removed later.
1014 int64_t non_resource_results = 0;
1015 for (auto entry :
1016 llvm::enumerate(callee.front().getTerminator()->getOperands())) {
1017 auto retval = entry.value();
1018 if (!getElementTypeOrSelf(retval.getType()).isa<TF::ResourceType>()) {
1019 result->old_to_new_output_indices.push_back(non_resource_results++);
1020 continue;
1021 }
1022 auto aliasing_arg = retval.dyn_cast<BlockArgument>();
1023 if (!aliasing_arg) {
1024 return callee.emitOpError("unsupported function call: ")
1025 << "resource return value does not alias an input.";
1026 }
1027 result->old_outputs_aliasing_old_inputs[entry.index()] =
1028 aliasing_arg.getArgNumber();
1029 result->old_to_new_output_indices.push_back(-1);
1030 }
1031
1032 if (failed(FindResourceArgUseInfo(callee, &result->use_info))) {
1033 return failure();
1034 }
1035 if (result->use_info.empty()) {
1036 result->lifted_callee = nullptr;
1037 return success();
1038 }
1039
1040 // Clone the callee before making changes.
1041 SmallString<64> name_base = callee.getName();
1042 auto module = callee->getParentOfType<ModuleOp>();
1043 name_base += "_resource_lifted";
1044 auto name = name_base;
1045 callee = callee.clone();
1046 callee.setPrivate();
1047 callee.setName(mlir::StringAttr::get(callee->getContext(), name));
1048 SymbolTable(module).insert(callee);
1049 result->lifted_callee = callee;
1050
1051 // Remove unused resources in functions.
1052 llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
1053 RemoveUnusedResourceArgumentsAndForwardedRetvals(
1054 result->use_info, callee, /*old_to_new_arg_indices=*/nullptr,
1055 &remaining_resource_data_types);
1056 for (const auto& entry : remaining_resource_data_types) {
1057 result->arg_data_type_and_updated_output_index[entry.getFirst()] = {
1058 entry.getSecond(), -1};
1059 }
1060 llvm::SmallVector<int64_t, 4> retval_indices_to_preserve;
1061 for (auto& val : callee.front().getTerminator()->getOpOperands()) {
1062 // Store indices of results that are not resources.
1063 if (!getElementTypeOrSelf(val.get().getType()).isa<TF::ResourceType>())
1064 retval_indices_to_preserve.push_back(val.getOperandNumber());
1065 }
1066 int64_t num_retvals = retval_indices_to_preserve.size();
1067 llvm::SmallVector<Value, 4> new_retvals;
1068 // Lift resources.
1069 (void)LiftArgRetResourcesForFunction(
1070 callee, remaining_resource_data_types, [&](int64_t index, Value value) {
1071 result->arg_data_type_and_updated_output_index[index].second =
1072 num_retvals++;
1073 new_retvals.push_back(value);
1074 });
1075
1076 auto old_return = callee.front().getTerminator();
1077 llvm::SmallVector<Value, 4> old_and_new_retvals;
1078 old_and_new_retvals.reserve(retval_indices_to_preserve.size() +
1079 new_retvals.size());
1080 for (int64_t retval_index : retval_indices_to_preserve)
1081 old_and_new_retvals.push_back(old_return->getOperand(retval_index));
1082
1083 old_and_new_retvals.append(new_retvals.begin(), new_retvals.end());
1084 // Replace old return with the new ones with update values.
1085 OpBuilder builder(old_return);
1086 auto new_return =
1087 builder.create<func::ReturnOp>(old_return->getLoc(), old_and_new_retvals);
1088 old_return->erase();
1089 callee.setType(FunctionType::get(
1090 callee.getContext(), callee.getFunctionType().getInputs(),
1091 llvm::to_vector<4>(new_return.getOperandTypes())));
1092 return success();
1093 }
1094
1095 // Updates a PartitionedCallOp/StatefulPartitionedCallOp according to the
1096 // resource-lifted new callee function in lifting_info.
1097 template <typename CallOpType>
UpdatePartitionedCallOpWithNewCallee(CallOpType call_op,PartitionedCallLiftingInfo & lifting_info)1098 void UpdatePartitionedCallOpWithNewCallee(
1099 CallOpType call_op, PartitionedCallLiftingInfo& lifting_info) {
1100 if (!lifting_info.lifted_callee) return;
1101 // Replace output resource uses with the aliasing input, so that we can remove
1102 // this output.
1103 for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) {
1104 call_op.getResult(entry.getFirst())
1105 .replaceAllUsesWith(call_op.getOperand(entry.getSecond()));
1106 }
1107 // Recreate the call op.
1108 OpBuilder builder(call_op);
1109 // Now use the filtered original operands, which will be replaced by
1110 // AddLoadsStoresOutsideControlFlowOp().
1111 auto new_operands =
1112 FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info);
1113 auto new_call = builder.create<CallOpType>(
1114 call_op.getLoc(),
1115 lifting_info.lifted_callee.getFunctionType().getResults(), new_operands,
1116 call_op->getAttrs());
1117 new_call->setAttr("f",
1118 SymbolRefAttr::get(builder.getContext(),
1119 lifting_info.lifted_callee.getName()));
1120 AddLoadsStoresOutsideControlFlowOp(
1121 new_call, lifting_info.arg_data_type_and_updated_output_index);
1122 // Replace uses.
1123 for (int64_t i = 0, end = lifting_info.old_to_new_output_indices.size();
1124 i < end; ++i) {
1125 if (lifting_info.old_to_new_output_indices[i] >= 0) {
1126 call_op.getResult(i).replaceAllUsesWith(
1127 new_call.getResult(lifting_info.old_to_new_output_indices[i]));
1128 }
1129 }
1130 call_op.erase();
1131 }
1132
// Forward declaration. HoistForControlFlow() and HandlePartitionedCallOp()
// call each other, so the former must be declared before the latter is
// defined.
LogicalResult HoistForControlFlow(
    Block*, ModuleOp, bool,
    llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*);
1136
1137 // A templated routine for handling both PartitionedCallOp and
1138 // StatefulPartitionedCallOp. If the callee is already lifted, it just updates
1139 // the caller op itself; otherwise, it first recursively handles nested control
1140 // flow, then performs lifting on the callee.
1141 template <typename CallOpType>
HandlePartitionedCallOp(CallOpType call_op,func::FuncOp callee,ModuleOp module,bool vars_initialized,llvm::SmallDenseMap<llvm::StringRef,PartitionedCallLiftingInfo> * lifted_callees)1142 LogicalResult HandlePartitionedCallOp(
1143 CallOpType call_op, func::FuncOp callee, ModuleOp module,
1144 bool vars_initialized,
1145 llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
1146 lifted_callees) {
1147 auto emplace_res = lifted_callees->try_emplace(callee.getName(),
1148 PartitionedCallLiftingInfo());
1149 if (emplace_res.second) {
1150 // Unseen callee. Perform resource lifting on it.
1151 if (failed(HoistForControlFlow(&callee.front(), module, vars_initialized,
1152 lifted_callees)))
1153 return failure();
1154
1155 if (failed(HandlePartitionedCallOpCallee(
1156 callee, &emplace_res.first->getSecond()))) {
1157 return failure();
1158 }
1159 }
1160 UpdatePartitionedCallOpWithNewCallee(call_op, emplace_res.first->getSecond());
1161 return success();
1162 }
1163
1164 // Hoists resource loads/stores from control flow ops in `block` outside the
1165 // body/cond/branch/callee functions.
HoistForControlFlow(Block * block,ModuleOp module,bool vars_initialized,llvm::SmallDenseMap<llvm::StringRef,PartitionedCallLiftingInfo> * lifted_partitioned_call_callees)1166 LogicalResult HoistForControlFlow(
1167 Block* block, ModuleOp module, bool vars_initialized,
1168 llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
1169 lifted_partitioned_call_callees) {
1170 if (vars_initialized) SetAllVarIsInitializedToTrue(block);
1171
1172 for (Operation& op : llvm::make_early_inc_range(*block)) {
1173 if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
1174 auto body = while_op.body_function();
1175 auto cond = while_op.cond_function();
1176 // Recursively handle the nested control flow.
1177 (void)HoistForControlFlow(&body.front(), module, vars_initialized,
1178 lifted_partitioned_call_callees);
1179 (void)HoistForControlFlow(&cond.front(), module, vars_initialized,
1180 lifted_partitioned_call_callees);
1181 if (failed(HandleWhileLoop(while_op, body, cond))) return failure();
1182 } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
1183 auto then_branch = if_op.then_function();
1184 auto else_branch = if_op.else_function();
1185 // Recursively handle the nested control flow.
1186 (void)HoistForControlFlow(&then_branch.front(), module, vars_initialized,
1187 lifted_partitioned_call_callees);
1188 (void)HoistForControlFlow(&else_branch.front(), module, vars_initialized,
1189 lifted_partitioned_call_callees);
1190 if (failed(HandleCaseOrIfOp(if_op, {then_branch, else_branch})))
1191 return failure();
1192 } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
1193 SmallVector<func::FuncOp, 4> branch_functions;
1194 case_op.get_branch_functions(branch_functions);
1195 for (func::FuncOp func : branch_functions) {
1196 // Recursively handle the nested control flow.
1197 (void)HoistForControlFlow(&func.front(), module, vars_initialized,
1198 lifted_partitioned_call_callees);
1199 }
1200 if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure();
1201 } else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
1202 auto callee = call_op.func();
1203 if (!callee) {
1204 return call_op.emitOpError(
1205 "resource lifting does not support call with nested references.");
1206 }
1207 if (failed(HandlePartitionedCallOp(call_op, callee, module,
1208 vars_initialized,
1209 lifted_partitioned_call_callees))) {
1210 // Nested control flow handling is done in HandlePartitionedCallOp().
1211 return failure();
1212 }
1213 } else if (auto call_op =
1214 llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
1215 if (failed(HandlePartitionedCallOp(call_op, call_op.func(), module,
1216 vars_initialized,
1217 lifted_partitioned_call_callees))) {
1218 return failure();
1219 }
1220 } else if (isa<TF::IfRegionOp, TF::CaseRegionOp, TF::WhileRegionOp>(op)) {
1221 for (Region& region : op.getRegions())
1222 (void)HoistForControlFlow(®ion.front(), module, vars_initialized,
1223 lifted_partitioned_call_callees);
1224 LogicalResult result = RegionResourceHoister::ReplaceOpWithNewOp(&op);
1225 if (failed(result)) return failure();
1226 }
1227 }
1228
1229 // After we have hoisted operations in the block, we may have added new read
1230 // and writes of resources to this block. Clean them up by doing store-load
1231 // forwarding.
1232 ForwardStoreToLoad(block);
1233 return success();
1234 }
1235
1236 // Lifts resource operation from tf_device.cluster ops nested in `op` outside.
1237 // Returns failure if there are remaining resource-type values that can not be
1238 // lifted.
runOnOperation()1239 void ResourceOpLiftingPass::runOnOperation() {
1240 llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
1241 lifted_partitioned_call_callees;
1242 ModuleOp module = getOperation();
1243
1244 if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(module)))
1245 return signalPassFailure();
1246
1247 auto walk_result = module.walk([&](func::FuncOp func_op) {
1248 return func_op.walk([&](tf_device::ClusterOp cluster) {
1249 LogicalResult result = HoistForControlFlow(
1250 &cluster.GetBody(), module, /*vars_initialized=*/true,
1251 &lifted_partitioned_call_callees);
1252 if (failed(result)) return WalkResult::interrupt();
1253 result = RegionResourceHoister::ReplaceOpWithNewOp(cluster);
1254 if (failed(result)) return WalkResult::interrupt();
1255 return WalkResult::advance();
1256 });
1257 });
1258
1259 if (walk_result.wasInterrupted()) return signalPassFailure();
1260 }
1261
// Pass that runs TF::ResourceLiftingForFunctionalControlFlow() on the
// module's `main` function, if the module has one.
struct ResourceOpLiftingForMainFunctionPass
    : public TFDevice::ResourceOpLiftingForMainFunctionPassBase<
          ResourceOpLiftingForMainFunctionPass> {
  void runOnOperation() override;
};
1267
runOnOperation()1268 void ResourceOpLiftingForMainFunctionPass::runOnOperation() {
1269 ModuleOp module = getOperation();
1270 func::FuncOp main_func = module.lookupSymbol<func::FuncOp>("main");
1271 if (!main_func) {
1272 return;
1273 }
1274
1275 if (failed(TF::ResourceLiftingForFunctionalControlFlow(main_func))) {
1276 return signalPassFailure();
1277 }
1278 }
1279
1280 } // namespace
1281
1282 namespace TFDevice {
CreateResourceOpLiftingPass()1283 std::unique_ptr<OperationPass<ModuleOp>> CreateResourceOpLiftingPass() {
1284 return std::make_unique<ResourceOpLiftingPass>();
1285 }
1286
1287 std::unique_ptr<OperationPass<ModuleOp>>
CreateResourceOpLiftingForMainFunctionPass()1288 CreateResourceOpLiftingForMainFunctionPass() {
1289 return std::make_unique<ResourceOpLiftingForMainFunctionPass>();
1290 }
1291
1292 } // namespace TFDevice
1293
1294 namespace TF {
ResourceLiftingForFunctionalControlFlow(func::FuncOp function)1295 LogicalResult ResourceLiftingForFunctionalControlFlow(func::FuncOp function) {
1296 // This routine should only be called when control flow operations are still
1297 // represented with TF IfOp and WhileOp operations. In this case, there should
1298 // be only one basic blocks in the MLIR representation.
1299 if (!llvm::hasSingleElement(function)) {
1300 return function.emitError()
1301 << "expect the function to have 1 block while it has "
1302 << function.getBlocks().size();
1303 }
1304
1305 if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(function)))
1306 return failure();
1307
1308 llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
1309 lifted_partitioned_call_callees;
1310 if (failed(HoistForControlFlow(
1311 &function.front(), cast<ModuleOp>(function->getParentOp()),
1312 /*vars_initialized=*/false, &lifted_partitioned_call_callees)))
1313 return failure();
1314
1315 // Clean up and canonicalize to remove dead local variables as some local
1316 // variables might be dead after hoisting resource loads/stores from control
1317 // flow ops.
1318 return TF::CleanupAndCanonicalizeForResourceOpLifting(function);
1319 }
1320 } // namespace TF
1321
1322 } // namespace mlir
1323