xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This pass lifts resource variable operations outside of device computation.
17 
18 #include <cstddef>
19 #include <cstdint>
20 
21 #include "llvm/ADT/BitVector.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/MapVector.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/Casting.h"
30 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
31 #include "mlir/IR/Attributes.h"  // from @llvm-project
32 #include "mlir/IR/Block.h"  // from @llvm-project
33 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
34 #include "mlir/IR/Builders.h"  // from @llvm-project
35 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
36 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
37 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
38 #include "mlir/IR/Operation.h"  // from @llvm-project
39 #include "mlir/IR/Region.h"  // from @llvm-project
40 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
41 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
42 #include "mlir/IR/Types.h"  // from @llvm-project
43 #include "mlir/IR/Value.h"  // from @llvm-project
44 #include "mlir/IR/Verifier.h"  // from @llvm-project
45 #include "mlir/IR/Visitors.h"  // from @llvm-project
46 #include "mlir/Pass/Pass.h"  // from @llvm-project
47 #include "mlir/Support/LLVM.h"  // from @llvm-project
48 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
49 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
50 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
51 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
52 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
53 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
54 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
55 #include "tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h"
56 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes_detail.h"
57 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
58 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
59 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
60 #include "tensorflow/core/framework/tensor_shape.pb.h"
61 
62 namespace mlir {
63 
64 namespace {
65 
// Name of the `device` attribute; it is stripped from the resource reads and
// writes that this pass creates when hoisting accesses out of regions.
constexpr char kDeviceAttr[]{"device"};
67 
68 // Lift resource operations out of device computation.
69 struct ResourceOpLiftingPass
70     : public TFDevice::ResourceOpLiftingPassBase<ResourceOpLiftingPass> {
71   void runOnOperation() override;
72 };
73 
IsResource(Value value)74 bool IsResource(Value value) {
75   return getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>();
76 }
77 
78 // Get the type of the data contained in a resource. Returns null if there is
79 // no single type in the resource.
GetResourceSubtype(Value value)80 Type GetResourceSubtype(Value value) {
81   auto resource_type =
82       getElementTypeOrSelf(value.getType()).dyn_cast<TF::ResourceType>();
83   auto subtypes = resource_type.getSubtypes();
84   if (subtypes.size() == 1) return subtypes[0];
85   return nullptr;
86 }
87 
88 // Replaces all `tf.VarIsInitializedOp` in a block with a constant true.
89 // TODO(b/171039585): Replace this with proper analysis of
90 // `tf.VarIsInitializedOp` in regards to resource writes and control flow.
SetAllVarIsInitializedToTrue(Block * block)91 void SetAllVarIsInitializedToTrue(Block* block) {
92   auto builder = OpBuilder::atBlockBegin(block);
93   TF::ConstOp const_true = nullptr;
94   for (auto op :
95        llvm::make_early_inc_range(block->getOps<TF::VarIsInitializedOp>())) {
96     builder.setInsertionPoint(op);
97     if (!const_true)
98       const_true = builder.create<TF::ConstOp>(
99           op.getLoc(),
100           DenseIntElementsAttr::get(
101               RankedTensorType::get(/*shape=*/{}, builder.getI1Type()), true));
102 
103     op.is_initialized().replaceAllUsesWith(const_true);
104     op.erase();
105   }
106 }
107 
108 // Performs store-load forwarding. This effectively removes
109 // 1) Any resource loads after a store to that same resource is done
110 // 2) Any resource stores except the last one.
111 // TODO(ycao): Store-load forwarding implemented here is only correct when
112 // computation is purely sequential (no concurrency). Need to support concurrent
113 // computation as well.
ForwardStoreToLoad(Block * block)114 void ForwardStoreToLoad(Block* block) {
115   // resource_handle_to_last_store_op keeps track of the most recent (last)
116   // store to each resource. Non-existent entry indicates that a resource has
117   // not been stored to yet.
118   llvm::SmallDenseMap<Value, TF::AssignVariableOp>
119       resource_handle_to_last_store_op;
120 
121   // Only iterate through ops directly in the block as we can't handle ops
122   // nested deeper in regions.
123   for (Operation& op : llvm::make_early_inc_range(*block)) {
124     if (auto read_variable_op = dyn_cast<TF::ReadVariableOp>(&op)) {
125       Value resource = read_variable_op.resource();
126       auto last_store = resource_handle_to_last_store_op[resource];
127       if (!last_store) continue;
128 
129       // Use stored value in last_store to replace all uses of current resource
130       // load's result, then erase this resource load. Add an intermediate
131       // CastOp if the shape of types doesn't exactly match.
132       Type read_type = read_variable_op.value().getType();
133       if (read_type != last_store.value().getType()) {
134         OpBuilder builder(last_store);
135         builder.setInsertionPointAfter(last_store);
136         auto cast = builder.create<TF::CastOp>(
137             last_store.getLoc(), read_type, last_store.value(),
138             /*Truncate=*/builder.getBoolAttr(false));
139         read_variable_op.value().replaceAllUsesWith(cast);
140       } else {
141         read_variable_op.value().replaceAllUsesWith(last_store.value());
142       }
143 
144       read_variable_op.erase();
145       continue;
146     }
147 
148     if (auto assign_variable_op = dyn_cast<TF::AssignVariableOp>(&op)) {
149       Value resource = assign_variable_op.resource();
150       auto last_store = resource_handle_to_last_store_op[resource];
151       // Previous store ops to same resource can be erased.
152       if (last_store) last_store.erase();
153 
154       resource_handle_to_last_store_op[resource] = assign_variable_op;
155     }
156   }
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // RegionResourceHoister
161 //===----------------------------------------------------------------------===//
162 
163 // Helper class to hoist resource ops out of regions attached to an op.
164 class RegionResourceHoister {
165  public:
RegionResourceHoister(Operation * op)166   explicit RegionResourceHoister(Operation* op) : op_(op) {}
167 
168   // Analyzes attached regions to record resources read and written.
169   LogicalResult Analyze();
170 
171   // Returns all resources accessed by the regions attached the op.
GetResources()172   auto& GetResources() { return resources_; }
173 
174   // Returns if the given value is a resource that needs lifting.
Contains(Value resource) const175   bool Contains(Value resource) const {
176     return resources_.find(resource) != resources_.end();
177   }
178 
179   // Drops the given resource from lifting.
DropResource(Value resource)180   void DropResource(Value resource) {
181     resources_.erase(resource);
182     written_resources_.remove(resource);
183   }
184 
185   // Replaces all resource loads in all regions attached to the op.
ReplaceResourceLoads(bool read_only)186   void ReplaceResourceLoads(bool read_only) {
187     llvm::for_each(op_->getRegions(), [&](Region& region) {
188       ReplaceResourceLoads(region, read_only);
189     });
190   }
191 
192   static LogicalResult ReplaceOpWithNewOp(Operation* op);
193 
194  private:
195   // Returns if any resources need lifting.
NeedsLifting() const196   bool NeedsLifting() const { return !resources_.empty(); }
197 
198   // Returns the number of results generated by the lifted op.
GetLiftedNumResults() const199   int GetLiftedNumResults() const { return num_new_results_; }
200 
201   // Generates hoisted reads for resources that need them before the op.
202   void GenerateHoistedReads();
203 
204   // Replaces all resource loads in the given region with hoisted loads. If
205   // `read_only` is true, limit this replacement to read only resources.
206   void ReplaceResourceLoads(Region& region, bool read_only);
207 
208   // Appends final values writte to resources to the region returns for the
209   // given set of regions.
210   void AppendResourceStoreValueToReturn(RegionRange regions);
211 
212   // Performs the final replacement of the op.
213   void ReplaceOpWithNewOp();
214 
215   // Returns is this resource was written to in any of the regions.
IsWritten(Value resource) const216   bool IsWritten(Value resource) const {
217     return written_resources_.contains(resource);
218   }
219 
220   static LogicalResult HoistResourcesOutOfIfCaseCluster(Operation* op);
221   static LogicalResult HoistResourcesOutOfWhileRegion(TF::WhileRegionOp op);
222 
223   Operation* op_;
224 
225   // Per resource information about accesses to that resource.
226   struct ResourceInfo {
227     // Is this resource read in any of the regions?
228     bool is_read;
229     // Is this resource written in any of the regions?
230     bool is_written;
231     // Is this resource written in all of the regions?
232     bool is_written_all;
233     // The hoisted read used to replace region reads.
234     Value hoisted_read;
235     // the type of the data held by the resource.
236     Type data_type;
237     // For written resources, the result # of the lifted op which will hold the
238     // value of the resource. This result will be used to generates writes to
239     // the resource after the lifted op.
240     int result_index;
241     // Attributes on the read operation.
242     DictionaryAttr read_attrs;
243     // Attributes on the write operation.
244     DictionaryAttr write_attrs;
245 
ResourceInfomlir::__anon60ae525e0111::RegionResourceHoister::ResourceInfo246     ResourceInfo()
247         : is_read(false),
248           is_written(false),
249           is_written_all(false),
250           hoisted_read(nullptr),
251           data_type(nullptr),
252           result_index(-1) {}
253 
IsResultIndexAssignedmlir::__anon60ae525e0111::RegionResourceHoister::ResourceInfo254     bool IsResultIndexAssigned() { return result_index != -1; }
255 
256     // Refine the resource type using the given type `type`.
RefineTypemlir::__anon60ae525e0111::RegionResourceHoister::ResourceInfo257     void RefineType(Type type) {
258       if (!data_type) {
259         data_type = type;
260       } else {
261         data_type = TF::GetCastCompatibleType(data_type, type,
262                                               /*may_ignore_ref_type_a=*/false);
263         assert(data_type != nullptr && "Resource used with incompatible types");
264       }
265     }
266   };
267   llvm::MapVector<Value, ResourceInfo> resources_;
268   llvm::SetVector<Value> written_resources_;
269   // number of new results after lifting.
270   int num_new_results_;
271 };
272 
273 // Analyzes resources that are read or written within attached regions.
Analyze()274 LogicalResult RegionResourceHoister::Analyze() {
275   // Hoisting of child regions might have created opportunity for store-load
276   // forwarding.
277   for (Region& region : op_->getRegions()) {
278     ForwardStoreToLoad(&region.front());
279   }
280 
281   llvm::SetVector<Value> all_resources;
282   bool is_func = false;
283   // For functions, the resources to analyze are the function arguments.
284   // Otherwise, its the region captures.
285   if (func::FuncOp func = dyn_cast<func::FuncOp>(op_)) {
286     is_func = true;
287     Region& body = func.getBody();
288     for (BlockArgument arg : body.getArguments()) {
289       if (IsResource(arg)) all_resources.insert(arg);
290     }
291   } else {
292     getUsedValuesDefinedAbove(op_->getRegions(), all_resources);
293     all_resources.remove_if([](Value value) { return !IsResource(value); });
294   }
295 
296   num_new_results_ = op_->getNumResults();
297 
298   for (auto resource : all_resources) {
299     ResourceInfo info;
300     info.data_type = GetResourceSubtype(resource);
301     llvm::BitVector written_regions(op_->getNumRegions());
302     bool unsupported_use = false;
303     for (OpOperand& use : resource.getUses()) {
304       Operation* user = use.getOwner();
305       // If the user is not in one of the regions, we are not interested in it.
306       // Since all the sub-regions within this region (i.e., regions attached to
307       // op's in this region) have themselves gone through lifting, all resource
308       // users are expected to be operations in this region and not embedded
309       // within other sub-regions attached to op's in this region. So the check
310       // for whether a user is in one of the regions attached to this op is
311       // straightforward.
312       if (user->getParentRegion()->getParentOp() != op_) continue;
313 
314       // For functions, if the resource is used as a return operand, use that
315       // as its result index.
316       if (is_func && isa<func::ReturnOp>(user)) {
317         assert(!info.IsResultIndexAssigned() &&
318                "Expect resource argument to returned no more than once");
319         info.result_index = use.getOperandNumber();
320         continue;
321       }
322 
323       auto read = dyn_cast<TF::ReadVariableOp>(user);
324       auto write = dyn_cast<TF::AssignVariableOp>(user);
325       if (!read && !write) {
326         unsupported_use = true;
327         break;
328       }
329 
330       if (read && !info.is_read) {
331         info.is_read = true;
332         info.RefineType(read.value().getType());
333         info.read_attrs = user->getAttrDictionary();
334       }
335 
336       if (write) {
337         info.is_written = true;
338         info.RefineType(write.value().getType());
339         info.write_attrs = user->getAttrDictionary();
340         written_regions.set(user->getParentRegion()->getRegionNumber());
341       }
342     }
343 
344     // If the resource is used in an op that we do not understand, skip
345     // lifting for that resource.
346     if (unsupported_use) continue;
347 
348     info.is_written_all = written_regions.count() == op_->getNumRegions();
349 
350     // If the resource is written in some but not all regions, we would need
351     // a read for the value before these regions. Note that this is applicable
352     // only to multi-region ops:
353     // If/Case: If not all regions write to the resource, post hoisting the read
354     //   value need to be routed through all paths that don't write.
355     // While: since while condition cannot write, any resource written in the
356     //   while body will need to be read as well in case the while body is never
357     //   executed.
358     // Both cases are handled by the condition below.
359     if (info.is_written && !info.is_written_all) info.is_read = true;
360 
361     // Allocate a result index for written resources that don't have one.
362     if (info.is_written) {
363       written_resources_.insert(resource);
364       if (!info.IsResultIndexAssigned()) info.result_index = num_new_results_++;
365     }
366 
367     resources_.insert({resource, info});
368   }
369   return success();
370 }
371 
372 // Generates hoisted reads for all resources that need them just before the op.
GenerateHoistedReads()373 void RegionResourceHoister::GenerateHoistedReads() {
374   OpBuilder builder(op_);
375   DictionaryAttr empty_attrs = builder.getDictionaryAttr({});
376   for (auto& resource_it : GetResources()) {
377     Value resource = resource_it.first;
378     auto& info = resource_it.second;
379 
380     if (info.is_read) {
381       Operation* read = builder.create<TF::ReadVariableOp>(
382           op_->getLoc(), info.data_type, resource);
383       read->setAttrs(info.read_attrs ? info.read_attrs : empty_attrs);
384       read->removeAttr(kDeviceAttr);
385       info.hoisted_read = read->getResult(0);
386     }
387   }
388 }
389 
390 // Replaces all resource reads with the hoisted read.
ReplaceResourceLoads(Region & region,bool read_only)391 void RegionResourceHoister::ReplaceResourceLoads(Region& region,
392                                                  bool read_only) {
393   assert(llvm::hasSingleElement(region) && "Expected single block region");
394   // Only iterate through ops directly in the body as we can't handle
395   // ops nested deeper in regions.
396   auto all_reads = region.front().getOps<TF::ReadVariableOp>();
397   for (auto read_op : llvm::make_early_inc_range(all_reads)) {
398     Value resource = read_op.resource();
399     if (!Contains(resource)) continue;
400 
401     ResourceInfo& info = resources_[resource];
402     // If replacing loads for read only resources, skip if the resource
403     // was written to.
404     if (read_only && info.is_written) continue;
405 
406     read_op.replaceAllUsesWith(info.hoisted_read);
407     read_op.erase();
408   }
409 }
410 
411 // For written resources, add its value at the end of each region to that
412 // regions return value. For a region, its value at the end may be a value
413 // written to that resource in that region, or its hoisted read value if the
414 // resource is not written in that region. The return value can be vended out
415 // either as an existing return value, or a newly allocated return value.
AppendResourceStoreValueToReturn(RegionRange regions)416 void RegionResourceHoister::AppendResourceStoreValueToReturn(
417     RegionRange regions) {
418   for (Region* region : regions) {
419     assert(llvm::hasSingleElement(*region) && "Expected single block region");
420     Block& front = region->front();
421     auto old_return = front.getTerminator();
422     assert(old_return->getNumOperands() == op_->getNumResults());
423     auto new_return_operands = llvm::to_vector<4>(old_return->getOperands());
424     new_return_operands.resize(num_new_results_);
425 
426     // initialize return values for written resources to be the hoisted reads.
427     for (Value resource : written_resources_) {
428       const ResourceInfo& info = resources_[resource];
429       new_return_operands[info.result_index] = info.hoisted_read;
430     }
431 
432     // Only iterate through ops directly in the body as op's embedded in child
433     // regions should have been lifted out.
434     auto assign_ops = front.getOps<TF::AssignVariableOp>();
435     for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) {
436       Value resource = assign_variable_op.resource();
437       if (!IsWritten(resource)) continue;
438 
439       // TODO(ycao): Prevent same value from being returned multiple times.
440       // TODO(ycao): Do not return resource store value if it is defined outside
441       // of cluster. Both of these can be post-resource-op-lifting cleanup
442       // passes.
443       int result_index = resources_[resource].result_index;
444       new_return_operands[result_index] = assign_variable_op.value();
445       assign_variable_op.erase();
446     }
447     old_return->setOperands(new_return_operands);
448   }
449 }
450 
451 // Replace the old op with a new op (with potentially additional results), and
452 // add stores to written resources after the new op.
ReplaceOpWithNewOp()453 void RegionResourceHoister::ReplaceOpWithNewOp() {
454   auto new_result_types = llvm::to_vector<4>(op_->getResultTypes());
455   int result_region = isa<TF::WhileRegionOp>(op_) ? 1 : 0;
456   Operation* terminator = op_->getRegion(result_region).front().getTerminator();
457   auto extra_result_types =
458       terminator->getOperands().drop_front(op_->getNumResults()).getTypes();
459   new_result_types.insert(new_result_types.end(), extra_result_types.begin(),
460                           extra_result_types.end());
461   OpBuilder builder(op_);
462   // Clone this old operation but with new result types.
463   Operation* new_op = Operation::create(
464       op_->getLoc(), op_->getName(), new_result_types, op_->getOperands(),
465       op_->getAttrs(), op_->getSuccessors(), op_->getNumRegions());
466   builder.insert(new_op);
467 
468   // Move regions to the new op.
469   for (auto it : llvm::zip(op_->getRegions(), new_op->getRegions())) {
470     Region& old_region = std::get<0>(it);
471     Region& new_region = std::get<1>(it);
472     new_region.takeBody(old_region);
473   }
474 
475   // Insert stores to all written resources.
476   for (Value resource : written_resources_) {
477     ResourceInfo& info = resources_[resource];
478     Value value_to_write = new_op->getResult(info.result_index);
479     Operation* write = builder.create<TF::AssignVariableOp>(
480         op_->getLoc(), resource, value_to_write);
481     write->setAttrs(info.write_attrs);
482     write->removeAttr(kDeviceAttr);
483   }
484 
485   // As a part of lifting, we either reuse an existing slot for resource type
486   // results or add a new slot. Resource type results should not have any uses
487   // to begin with. So we can safely replace each old op result with the
488   // corresponding new op result.
489   int old_num_results = op_->getNumResults();
490   op_->replaceAllUsesWith(new_op->getResults().take_front(old_num_results));
491   op_->erase();
492   op_ = nullptr;
493 }
494 
495 // Lift resource load and stores out of regions attached to `op`, where op is
496 // an If/case/cluster op.
HoistResourcesOutOfIfCaseCluster(Operation * op)497 LogicalResult RegionResourceHoister::HoistResourcesOutOfIfCaseCluster(
498     Operation* op) {
499   RegionResourceHoister hoister(op);
500   if (failed(hoister.Analyze())) return failure();
501 
502   // If there are no resource region captures, then nothing to do.
503   if (!hoister.NeedsLifting()) return success();
504 
505   // Start the transformation. For each region, replace the resource read with
506   // the value read before the op.
507   hoister.GenerateHoistedReads();
508   hoister.ReplaceResourceLoads(/*read_only=*/false);
509   hoister.AppendResourceStoreValueToReturn(op->getRegions());
510   hoister.ReplaceOpWithNewOp();
511   return success();
512 }
513 
514 // Lift resource loads and stores out of WhileRegion
HoistResourcesOutOfWhileRegion(TF::WhileRegionOp op)515 LogicalResult RegionResourceHoister::HoistResourcesOutOfWhileRegion(
516     TF::WhileRegionOp op) {
517   // For WhileRegion, post canonicalization all resource used within the
518   // body and condition regions are replaced with captured values, so we do not
519   // need to take into account the body and condition region arguments.
520   RegionResourceHoister hoister(op);
521 
522   if (failed(hoister.Analyze())) return failure();
523 
524   // If there are no resource region captures, then nothing to do.
525   if (!hoister.NeedsLifting()) return success();
526 
527   // The resources captured for While loop fall into two categories:
528   // (a) read-only. These reads can be replaced by a hoisted read created
529   //        before the WhileOp (similar to if and case).
530   // (b) written: since the value is written in the loop (which can only in
531   //        loop body, all these will become loop variables. Since all resource
532   //        variables are removed from the loop variabled during
533   //        canonicalizationW, we need to create new operand/result slots. The
534   //        input operands for these slots are the read values
535   //        prior to the op, and all references to these are replaced by the
536   //        corresponding slot argument. We need to generate writes following
537   //        the while for these resources.
538   //
539   // Note that for WhileRegion ops, if a resource is written, it will be written
540   // only in the body and not the condition, so the hoister analysis will infer
541   // it as needing a read as well.
542 
543   // Generate hoisted reads before the while.
544   hoister.GenerateHoistedReads();
545 
546   // Replace just the read-only resources with the hoisted reads.
547   hoister.ReplaceResourceLoads(/*read_only=*/true);
548 
549   // For written resources, add additional operands to the while op.
550   int num_old_results = op.getNumResults();
551   int num_new_results = hoister.GetLiftedNumResults();
552   int num_extra_results = num_new_results - num_old_results;
553 
554   SmallVector<Type, 4> new_result_types;
555   SmallVector<Value, 4> new_while_operands;
556   new_result_types.resize(num_extra_results);
557   new_while_operands.resize(num_extra_results);
558 
559   for (auto& it : hoister.GetResources()) {
560     if (!it.second.is_written) continue;
561     int index = it.second.result_index - num_old_results;
562     new_result_types[index] = it.second.data_type;
563     new_while_operands[index] = it.second.hoisted_read;
564   }
565   op.getOperation()->insertOperands(op.getNumOperands(), new_while_operands);
566 
567   // Patch the cond and body regions to have additional arguments, and replace
568   // the remaining resource reads (which will be resource reads for written
569   // resources) with these arguments.
570   Location loc = op.getLoc();
571   for (Region* region : op.getRegions()) {
572     region->addArguments(new_result_types,
573                          SmallVector<Location>(new_result_types.size(), loc));
574     // Point hoisted read for written resources to the region's arguments.
575     for (auto& it : hoister.GetResources()) {
576       if (!it.second.is_written) continue;
577       it.second.hoisted_read = region->getArgument(it.second.result_index);
578     }
579     hoister.ReplaceResourceLoads(*region, /*read_only=*/false);
580   }
581 
582   // Add additional return values to body return. These correspond to values
583   // written to resources in the body region.
584   hoister.AppendResourceStoreValueToReturn(op.getRegions().drop_front());
585 
586   // Finally, create a new while with additional return values.
587   hoister.ReplaceOpWithNewOp();
588   return success();
589 }
590 
591 // Lift resources out of the regions attached to `op`
ReplaceOpWithNewOp(Operation * op)592 LogicalResult RegionResourceHoister::ReplaceOpWithNewOp(Operation* op) {
593   if (auto while_op = dyn_cast<TF::WhileRegionOp>(op))
594     return HoistResourcesOutOfWhileRegion(while_op);
595   return HoistResourcesOutOfIfCaseCluster(op);
596 }
597 
598 // Holds information about a function's use of a resource argument.
599 struct ResourceArgUseInfo {
600   // Data type of the data contained in the resource.
601   Type data_type;
602   // Is the resource argument used in an assign op?
603   bool updated;
604   // Is the resource argument used in a read or assign op?
605   bool used;
606 };
607 
608 // Finds the ResourceArgUseInfo for each resource argument. Forwarding to the
609 // output (i.e., the argument is an operand of the return op) is not considered
610 // as a use. This doesn't support nesting of ops, so before calling this, nested
611 // ops/functions need to be already resource-lifted.
FindResourceArgUseInfo(func::FuncOp func_op,llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> * result)612 LogicalResult FindResourceArgUseInfo(
613     func::FuncOp func_op,
614     llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>* result) {
615   auto return_op = func_op.front().getTerminator();
616   for (auto arg : TF::filter_resources(func_op.getArguments())) {
617     ResourceArgUseInfo info;
618     info.used = false;
619     info.updated = false;
620     bool read_or_assigned = false;
621     bool used_in_unsupported_op = false;
622     for (auto user : arg.getUsers()) {
623       if (user == return_op) continue;
624       info.used = true;
625       if (auto read = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
626         read_or_assigned = true;
627         info.data_type = read.getType();
628         continue;
629       }
630 
631       if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(user)) {
632         read_or_assigned = true;
633         info.updated = true;
634         info.data_type = assign.value().getType();
635         continue;
636       }
637 
638       used_in_unsupported_op = true;
639       break;
640     }
641 
642     // If the arg is used in an unsupported op, skip lifting it.
643     if (used_in_unsupported_op) continue;
644     (*result)[arg.getArgNumber()] = info;
645   }
646   return success();
647 }
648 
649 // Merges two sets of resource arg use infos. An argument is considered used in
650 // the merged result as long as either set marks it as used. This is used to
651 // merge results from functions that have aliasing inputs, e.g., a while loop's
652 // body and condition. The sets of keys of the two maps must be the same.
MergeArgResourceUseInfo(const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & infos0,const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & infos1)653 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> MergeArgResourceUseInfo(
654     const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos0,
655     const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos1) {
656   llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> result;
657   for (const auto& entry : infos0) {
658     auto info1_it = infos1.find(entry.getFirst());
659     // If the entry is missing in any input, we should not touch this entry.
660     if (info1_it == infos1.end()) continue;
661     auto& info = result[entry.getFirst()];
662     info = entry.getSecond();
663     if (info.updated) continue;
664     if (info1_it->getSecond().used) {
665       info.used = true;
666       info.updated = info1_it->getSecond().updated;
667       info.data_type = info1_it->getSecond().data_type;
668     }
669   }
670   return result;
671 }
672 
673 // Removes the unused resource arguments, and the return values that forward the
674 // removed arguments. If old_to_new_arg_indices is provided, it will store the
675 // new argument index that corresponds to each original index (-1 means it is
676 // removed). If remaining_resource_data_types is provided, it will store the
677 // data types of the remaining resource arguments, where the indices are after
678 // removing unused ones.
RemoveUnusedResourceArgumentsAndForwardedRetvals(const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & infos,func::FuncOp func_op,llvm::SmallVector<int64_t,4> * old_to_new_arg_indices=nullptr,llvm::SmallDenseMap<int64_t,Type> * remaining_resource_data_types=nullptr)679 void RemoveUnusedResourceArgumentsAndForwardedRetvals(
680     const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos,
681     func::FuncOp func_op,
682     llvm::SmallVector<int64_t, 4>* old_to_new_arg_indices = nullptr,
683     llvm::SmallDenseMap<int64_t, Type>* remaining_resource_data_types =
684         nullptr) {
685   // Remove return values forwarded from unused arguments.
686   auto return_op = func_op.front().getTerminator();
687   auto old_return_vals = llvm::to_vector<8>(return_op->getOperands());
688   int64_t skipped_retvals = 0;
689   for (auto entry : llvm::enumerate(old_return_vals)) {
690     auto return_val = entry.value();
691     if (auto arg = return_val.dyn_cast<BlockArgument>()) {
692       auto it = infos.find(arg.getArgNumber());
693       if (it != infos.end() && !it->getSecond().used) {
694         return_op->eraseOperand(entry.index() - skipped_retvals++);
695       }
696     }
697   }
698   llvm::BitVector indices_to_erase(func_op.getNumArguments());
699   llvm::SmallVector<Type, 4> new_types;
700   int64_t skipped_args = 0;
701   for (auto arg : func_op.getArguments()) {
702     auto it = infos.find(arg.getArgNumber());
703     if (it != infos.end() && !it->getSecond().used) {
704       indices_to_erase.set(arg.getArgNumber());
705       skipped_args++;
706       if (old_to_new_arg_indices != nullptr) {
707         old_to_new_arg_indices->push_back(-1);
708       }
709     } else {
710       new_types.push_back(arg.getType());
711       if (old_to_new_arg_indices != nullptr) {
712         old_to_new_arg_indices->push_back(arg.getArgNumber() - skipped_args);
713       }
714       if (it != infos.end() && remaining_resource_data_types != nullptr) {
715         (*remaining_resource_data_types)[arg.getArgNumber() - skipped_args] =
716             it->second.data_type;
717       }
718     }
719   }
720   func_op.eraseArguments(indices_to_erase);
721   func_op.setType(
722       FunctionType::get(func_op.getContext(), new_types,
723                         llvm::to_vector<4>(return_op->getOperandTypes())));
724 }
725 
726 // Lifts reads/writes of resource arguments from func_op and changes its
727 // signature. resource_data_types is the (index, data type) pair for each
728 // resource argument. handle_updated_arg_value is a caller-provided function
729 // that handles the updated value for an resource argument.
LiftArgRetResourcesForFunction(func::FuncOp func_op,const llvm::SmallDenseMap<int64_t,Type> & resource_data_types,llvm::function_ref<void (int64_t,Value)> handle_updated_arg_value)730 LogicalResult LiftArgRetResourcesForFunction(
731     func::FuncOp func_op,
732     const llvm::SmallDenseMap<int64_t, Type>& resource_data_types,
733     llvm::function_ref<void(int64_t, Value)> handle_updated_arg_value) {
734   RegionResourceHoister hoister(func_op);
735   if (failed(hoister.Analyze())) return failure();
736 
737   // Each of these resources could be read or written in the function. If its
738   // read, we need to replace the resource arg with a value arg to get the
739   // read value. If its written, we need to replace the write with an additional
740   // value to be written.
741 
742   // Now create read values that will be used to replace each resource that
743   // is read in the function body. These read values are just the same argument
744   // with type replaced.
745   llvm::SmallVector<Value, 4> skipped_args;
746   for (auto& it : hoister.GetResources()) {
747     BlockArgument arg = it.first.dyn_cast<BlockArgument>();
748     assert(arg && "Expect resources for FuncOp to be its arguments");
749     auto type_iter = resource_data_types.find(arg.getArgNumber());
750     if (type_iter == resource_data_types.end()) {
751       // Skip lifting the resource if it's not present in the data type map.
752       // This indicates that the resource is not to be lifted because it is used
753       // in an unsupported op in some other function.
754       skipped_args.push_back(arg);
755     } else {
756       arg.setType(type_iter->second);
757       it.second.hoisted_read = arg;
758     }
759   }
760 
761   // Drop all the args that have to be skipped.
762   for (Value arg : skipped_args) hoister.DropResource(arg);
763 
764   hoister.ReplaceResourceLoads(/*read_only=*/false);
765 
766   // For writes, invoke the callback and then erase the write.
767   auto assign_ops = func_op.front().getOps<TF::AssignVariableOp>();
768   for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) {
769     Value resource = assign_variable_op.resource();
770     if (!hoister.Contains(resource)) continue;
771 
772     auto arg = resource.dyn_cast<BlockArgument>();
773     handle_updated_arg_value(arg.getArgNumber(), assign_variable_op.value());
774     assign_variable_op.erase();
775   }
776 
777   func_op.setType(FunctionType::get(
778       func_op.getContext(), func_op.front().getArgumentTypes(),
779       func_op.front().getTerminator()->getOperandTypes()));
780 
781   return success();
782 }
783 
784 // Returns a vector filtered from range where the unused elements (specified by
785 // resource_arg_uses) are removed.
786 template <typename T, typename Range>
FilterRange(Range range,const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & resource_arg_uses)787 llvm::SmallVector<T, 4> FilterRange(
788     Range range,
789     const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& resource_arg_uses) {
790   llvm::SmallVector<T, 4> filtered;
791   for (auto entry : llvm::enumerate(range)) {
792     auto it = resource_arg_uses.find(entry.index());
793     if (it == resource_arg_uses.end() || it->getSecond().used)
794       filtered.push_back(entry.value());
795   }
796   return filtered;
797 }
798 
799 // Changes the types of the control flow op (e.g., while, if) and adds loads and
800 // stores around it. arg_data_type_and_updated_output_index maps an operand (to
801 // be changed) index to its data type and the updated value index in the output
802 // (-1 means not updated.)
AddLoadsStoresOutsideControlFlowOp(Operation * caller,const llvm::SmallDenseMap<int64_t,std::pair<Type,int64_t>> & arg_data_type_and_updated_output_index)803 void AddLoadsStoresOutsideControlFlowOp(
804     Operation* caller,
805     const llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>&
806         arg_data_type_and_updated_output_index) {
807   OpBuilder builder(caller);
808   auto new_operands = llvm::to_vector<8>(caller->getOperands());
809   llvm::SmallVector<int64_t, 8> changed_indices;
810   // Find the operands to change, and create the loads.
811   for (auto& entry : arg_data_type_and_updated_output_index) {
812     int64_t index = entry.getFirst();
813     Type new_type = entry.getSecond().first;
814     int64_t updated_index = entry.getSecond().second;
815     auto operand = caller->getOperand(index);
816     builder.setInsertionPoint(caller);
817     new_operands[index] = builder.create<TF::ReadVariableOp>(
818         caller->getLoc(), ArrayRef<Type>{new_type}, ArrayRef<Value>{operand});
819     caller->setOperand(index, new_operands[index]);
820     if (updated_index < 0) continue;
821     builder.setInsertionPointAfter(caller);
822     builder.create<TF::AssignVariableOp>(
823         caller->getLoc(), ArrayRef<Type>{},
824         ArrayRef<Value>{operand, caller->getResult(updated_index)});
825   }
826 }
827 
828 // Lifts loads/stores from while loop's body and cond functions.
HandleWhileLoop(TF::WhileOp while_op,func::FuncOp body,func::FuncOp cond)829 LogicalResult HandleWhileLoop(TF::WhileOp while_op, func::FuncOp body,
830                               func::FuncOp cond) {
831   auto return_op = body.front().getTerminator();
832   llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> body_use_info;
833   llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> cond_use_info;
834   if (failed(FindResourceArgUseInfo(body, &body_use_info)) ||
835       failed(FindResourceArgUseInfo(cond, &cond_use_info))) {
836     return failure();
837   }
838   // A resource is considered used as long as it is used in either body or cond.
839   auto resource_arg_uses =
840       MergeArgResourceUseInfo(body_use_info, cond_use_info);
841   if (resource_arg_uses.empty()) return success();
842 
843   // Remove unused resources in functions.
844   llvm::SmallVector<int64_t, 4> old_to_new_indices;
845   llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
846   RemoveUnusedResourceArgumentsAndForwardedRetvals(
847       resource_arg_uses, body, &old_to_new_indices,
848       &remaining_resource_data_types);
849   RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, cond);
850   (void)LiftArgRetResourcesForFunction(
851       body, remaining_resource_data_types,
852       [&](int64_t index, Value value) { return_op->setOperand(index, value); });
853   (void)LiftArgRetResourcesForFunction(cond, remaining_resource_data_types,
854                                        [&](int64_t index, Value value) {
855                                          // We already checked that cond should
856                                          // not have variable writes.
857                                          assert(false && "Should not happen");
858                                        });
859   // Recreate the while op.
860   OpBuilder builder(while_op);
861   // Now use the filtered original operands, which will be replaced by
862   // AddLoadsStoresOutsideControlFlowOp().
863   auto new_while = builder.create<TF::WhileOp>(
864       while_op.getLoc(), body.getFunctionType().getResults(),
865       FilterRange<Value, OperandRange>(while_op.getOperands(),
866                                        resource_arg_uses),
867       while_op->getAttrs());
868   // Prepare for AddLoadsStoresOutsideControlFlowOp().
869   llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
870       arg_data_type_and_updated_output_index;
871   for (const auto& entry : remaining_resource_data_types) {
872     int64_t update_index = return_op->getOperand(entry.getFirst()) ==
873                                    body.getArgument(entry.getFirst())
874                                ? -1
875                                : entry.getFirst();
876     arg_data_type_and_updated_output_index[entry.getFirst()] = {
877         entry.getSecond(), update_index};
878   }
879   AddLoadsStoresOutsideControlFlowOp(new_while,
880                                      arg_data_type_and_updated_output_index);
881   // Replace uses.
882   for (int64_t i = 0, end = old_to_new_indices.size(); i < end; ++i) {
883     if (old_to_new_indices[i] >= 0) {
884       while_op.getResult(i).replaceAllUsesWith(
885           new_while.getResult(old_to_new_indices[i]));
886     }
887   }
888   while_op.erase();
889   return success();
890 }
891 
892 // Lifts loads/stores from an IfOp or CaseOp's branches.
893 template <class CaseOrIfOp>
HandleCaseOrIfOp(CaseOrIfOp op,ArrayRef<func::FuncOp> branches)894 LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef<func::FuncOp> branches) {
895   // For canonicalized If/Case, there should not be any resource outputs
896   int64_t non_resource_results = op.getNumResults();
897 
898   llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> resource_arg_uses;
899   if (failed(FindResourceArgUseInfo(branches.front(), &resource_arg_uses)))
900     return failure();
901 
902   for (auto func : branches.drop_front()) {
903     llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> branch_use_info;
904     if (failed(FindResourceArgUseInfo(func, &branch_use_info)))
905       return failure();
906     // A resource is considered used as long as it is used in either branch.
907     resource_arg_uses =
908         MergeArgResourceUseInfo(resource_arg_uses, branch_use_info);
909   }
910 
911   if (resource_arg_uses.empty()) return success();
912   // Remove unused resources in functions.
913   llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
914   RemoveUnusedResourceArgumentsAndForwardedRetvals(
915       resource_arg_uses, branches.front(), /*old_to_new_arg_indices=*/nullptr,
916       &remaining_resource_data_types);
917   for (auto func : branches.drop_front())
918     RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, func);
919 
920   // Forward resource inputs updated in any branch to the outputs of both
921   // branches. First prepare the mapping from arg to new update output.
922   llvm::SmallDenseMap<int64_t, int64_t> resource_arg_to_new_output;
923   {
924     int64_t removed_args = 0;
925     for (const auto& entry : resource_arg_uses) {
926       if (!entry.getSecond().used) {
927         removed_args++;
928         continue;
929       }
930       if (!entry.getSecond().updated) continue;
931       int64_t new_output_index =
932           non_resource_results + resource_arg_to_new_output.size();
933       resource_arg_to_new_output[entry.getFirst() - removed_args] =
934           new_output_index;
935     }
936   }
937 
938   // Append resource updates to the return ops: now they are just forwarded
939   // input resources, but will be replaced by the data value in
940   // LiftArgRetResourcesForFunction().
941   for (auto branch : branches) {
942     auto new_retvals =
943         llvm::to_vector<4>(branch.front().getTerminator()->getOperands());
944     new_retvals.resize(new_retvals.size() + resource_arg_to_new_output.size());
945     for (const auto& entry : resource_arg_to_new_output) {
946       int64_t resource_arg_index = entry.getFirst();
947       int64_t output_index = entry.getSecond();
948       new_retvals[output_index] = branch.getArgument(resource_arg_index);
949     }
950     auto old_return = branch.front().getTerminator();
951     OpBuilder builder(old_return);
952     auto new_return =
953         builder.create<func::ReturnOp>(old_return->getLoc(), new_retvals);
954     old_return->erase();
955     (void)LiftArgRetResourcesForFunction(
956         branch, remaining_resource_data_types, [&](int64_t index, Value value) {
957           new_return.setOperand(resource_arg_to_new_output[index], value);
958         });
959   }
960 
961   // Recreate the op without resource operands.
962   OpBuilder builder(op);
963   // Now use the filtered original operands, which will be replaced by
964   // AddLoadsStoresOutsideControlFlowOp().
965   auto new_operands =
966       FilterRange<Value, OperandRange>(op.input(), resource_arg_uses);
967   new_operands.insert(new_operands.begin(), op.getOperand(0));
968   func::FuncOp first_func = branches.front();
969   auto new_op = builder.create<CaseOrIfOp>(
970       op.getLoc(), first_func.getFunctionType().getResults(), new_operands,
971       op->getAttrs());
972   // Prepare for AddLoadsStoresOutsideControlFlowOp()
973   llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
974       arg_data_type_and_updated_output_index;
975   for (const auto& entry : remaining_resource_data_types) {
976     auto new_output_it = resource_arg_to_new_output.find(entry.getFirst());
977     int64_t update_index = new_output_it == resource_arg_to_new_output.end()
978                                ? -1
979                                : new_output_it->getSecond();
980     arg_data_type_and_updated_output_index[entry.getFirst() + 1] = {
981         entry.getSecond(), update_index};
982   }
983   AddLoadsStoresOutsideControlFlowOp(new_op,
984                                      arg_data_type_and_updated_output_index);
985   // Replace uses.
986   op.replaceAllUsesWith(new_op.getResults().take_front(op.getNumResults()));
987   op.erase();
988   return success();
989 }
990 
991 // A resource-lifted function for (potentially multiple) PartitionedCallOps and
992 // information about the lifting changes.
993 struct PartitionedCallLiftingInfo {
994   // Function with resources lifted. Can be nullptr if nothing needs to change.
995   func::FuncOp lifted_callee;
996   // Mapping from old resource outputs to their aliasing output inputs.
997   llvm::SmallDenseMap<int64_t, int64_t> old_outputs_aliasing_old_inputs;
998   // Mapping from old to new output indices in case any output is removed.
999   llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
1000   // ResourceArgUseInfo for each old resource argument.
1001   llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> use_info;
1002   // Input for AddLoadsStoresOutsideControlFlowOp(), see its comment.
1003   llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
1004       arg_data_type_and_updated_output_index;
1005 };
1006 
1007 // Lifts loads/stores from a PartitionedCallOp's callee function. If anything
1008 // needs to be changed, the original function will be preserved, and the lifting
1009 // happens on a clone, which will be stored in `result`.
HandlePartitionedCallOpCallee(func::FuncOp callee,PartitionedCallLiftingInfo * result)1010 LogicalResult HandlePartitionedCallOpCallee(
1011     func::FuncOp callee, PartitionedCallLiftingInfo* result) {
1012   // Sanity check: return of resources should be aliases of inputs. Such outputs
1013   // will be removed later.
1014   int64_t non_resource_results = 0;
1015   for (auto entry :
1016        llvm::enumerate(callee.front().getTerminator()->getOperands())) {
1017     auto retval = entry.value();
1018     if (!getElementTypeOrSelf(retval.getType()).isa<TF::ResourceType>()) {
1019       result->old_to_new_output_indices.push_back(non_resource_results++);
1020       continue;
1021     }
1022     auto aliasing_arg = retval.dyn_cast<BlockArgument>();
1023     if (!aliasing_arg) {
1024       return callee.emitOpError("unsupported function call: ")
1025              << "resource return value does not alias an input.";
1026     }
1027     result->old_outputs_aliasing_old_inputs[entry.index()] =
1028         aliasing_arg.getArgNumber();
1029     result->old_to_new_output_indices.push_back(-1);
1030   }
1031 
1032   if (failed(FindResourceArgUseInfo(callee, &result->use_info))) {
1033     return failure();
1034   }
1035   if (result->use_info.empty()) {
1036     result->lifted_callee = nullptr;
1037     return success();
1038   }
1039 
1040   // Clone the callee before making changes.
1041   SmallString<64> name_base = callee.getName();
1042   auto module = callee->getParentOfType<ModuleOp>();
1043   name_base += "_resource_lifted";
1044   auto name = name_base;
1045   callee = callee.clone();
1046   callee.setPrivate();
1047   callee.setName(mlir::StringAttr::get(callee->getContext(), name));
1048   SymbolTable(module).insert(callee);
1049   result->lifted_callee = callee;
1050 
1051   // Remove unused resources in functions.
1052   llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
1053   RemoveUnusedResourceArgumentsAndForwardedRetvals(
1054       result->use_info, callee, /*old_to_new_arg_indices=*/nullptr,
1055       &remaining_resource_data_types);
1056   for (const auto& entry : remaining_resource_data_types) {
1057     result->arg_data_type_and_updated_output_index[entry.getFirst()] = {
1058         entry.getSecond(), -1};
1059   }
1060   llvm::SmallVector<int64_t, 4> retval_indices_to_preserve;
1061   for (auto& val : callee.front().getTerminator()->getOpOperands()) {
1062     // Store indices of results that are not resources.
1063     if (!getElementTypeOrSelf(val.get().getType()).isa<TF::ResourceType>())
1064       retval_indices_to_preserve.push_back(val.getOperandNumber());
1065   }
1066   int64_t num_retvals = retval_indices_to_preserve.size();
1067   llvm::SmallVector<Value, 4> new_retvals;
1068   // Lift resources.
1069   (void)LiftArgRetResourcesForFunction(
1070       callee, remaining_resource_data_types, [&](int64_t index, Value value) {
1071         result->arg_data_type_and_updated_output_index[index].second =
1072             num_retvals++;
1073         new_retvals.push_back(value);
1074       });
1075 
1076   auto old_return = callee.front().getTerminator();
1077   llvm::SmallVector<Value, 4> old_and_new_retvals;
1078   old_and_new_retvals.reserve(retval_indices_to_preserve.size() +
1079                               new_retvals.size());
1080   for (int64_t retval_index : retval_indices_to_preserve)
1081     old_and_new_retvals.push_back(old_return->getOperand(retval_index));
1082 
1083   old_and_new_retvals.append(new_retvals.begin(), new_retvals.end());
1084   // Replace old return with the new ones with update values.
1085   OpBuilder builder(old_return);
1086   auto new_return =
1087       builder.create<func::ReturnOp>(old_return->getLoc(), old_and_new_retvals);
1088   old_return->erase();
1089   callee.setType(FunctionType::get(
1090       callee.getContext(), callee.getFunctionType().getInputs(),
1091       llvm::to_vector<4>(new_return.getOperandTypes())));
1092   return success();
1093 }
1094 
1095 // Updates a PartitionedCallOp/StatefulPartitionedCallOp according to the
1096 // resource-lifted new callee function in lifting_info.
1097 template <typename CallOpType>
UpdatePartitionedCallOpWithNewCallee(CallOpType call_op,PartitionedCallLiftingInfo & lifting_info)1098 void UpdatePartitionedCallOpWithNewCallee(
1099     CallOpType call_op, PartitionedCallLiftingInfo& lifting_info) {
1100   if (!lifting_info.lifted_callee) return;
1101   // Replace output resource uses with the aliasing input, so that we can remove
1102   // this output.
1103   for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) {
1104     call_op.getResult(entry.getFirst())
1105         .replaceAllUsesWith(call_op.getOperand(entry.getSecond()));
1106   }
1107   // Recreate the call op.
1108   OpBuilder builder(call_op);
1109   // Now use the filtered original operands, which will be replaced by
1110   // AddLoadsStoresOutsideControlFlowOp().
1111   auto new_operands =
1112       FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info);
1113   auto new_call = builder.create<CallOpType>(
1114       call_op.getLoc(),
1115       lifting_info.lifted_callee.getFunctionType().getResults(), new_operands,
1116       call_op->getAttrs());
1117   new_call->setAttr("f",
1118                     SymbolRefAttr::get(builder.getContext(),
1119                                        lifting_info.lifted_callee.getName()));
1120   AddLoadsStoresOutsideControlFlowOp(
1121       new_call, lifting_info.arg_data_type_and_updated_output_index);
1122   // Replace uses.
1123   for (int64_t i = 0, end = lifting_info.old_to_new_output_indices.size();
1124        i < end; ++i) {
1125     if (lifting_info.old_to_new_output_indices[i] >= 0) {
1126       call_op.getResult(i).replaceAllUsesWith(
1127           new_call.getResult(lifting_info.old_to_new_output_indices[i]));
1128     }
1129   }
1130   call_op.erase();
1131 }
1132 
1133 LogicalResult HoistForControlFlow(
1134     Block*, ModuleOp, bool,
1135     llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*);
1136 
1137 // A templated routine for handling both PartitionedCallOp and
1138 // StatefulPartitionedCallOp. If the callee is already lifted, it just updates
1139 // the caller op itself; otherwise, it first recursively handles nested control
1140 // flow, then performs lifting on the callee.
1141 template <typename CallOpType>
HandlePartitionedCallOp(CallOpType call_op,func::FuncOp callee,ModuleOp module,bool vars_initialized,llvm::SmallDenseMap<llvm::StringRef,PartitionedCallLiftingInfo> * lifted_callees)1142 LogicalResult HandlePartitionedCallOp(
1143     CallOpType call_op, func::FuncOp callee, ModuleOp module,
1144     bool vars_initialized,
1145     llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
1146         lifted_callees) {
1147   auto emplace_res = lifted_callees->try_emplace(callee.getName(),
1148                                                  PartitionedCallLiftingInfo());
1149   if (emplace_res.second) {
1150     // Unseen callee. Perform resource lifting on it.
1151     if (failed(HoistForControlFlow(&callee.front(), module, vars_initialized,
1152                                    lifted_callees)))
1153       return failure();
1154 
1155     if (failed(HandlePartitionedCallOpCallee(
1156             callee, &emplace_res.first->getSecond()))) {
1157       return failure();
1158     }
1159   }
1160   UpdatePartitionedCallOpWithNewCallee(call_op, emplace_res.first->getSecond());
1161   return success();
1162 }
1163 
1164 // Hoists resource loads/stores from control flow ops in `block` outside the
1165 // body/cond/branch/callee functions.
HoistForControlFlow(Block * block,ModuleOp module,bool vars_initialized,llvm::SmallDenseMap<llvm::StringRef,PartitionedCallLiftingInfo> * lifted_partitioned_call_callees)1166 LogicalResult HoistForControlFlow(
1167     Block* block, ModuleOp module, bool vars_initialized,
1168     llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
1169         lifted_partitioned_call_callees) {
1170   if (vars_initialized) SetAllVarIsInitializedToTrue(block);
1171 
1172   for (Operation& op : llvm::make_early_inc_range(*block)) {
1173     if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
1174       auto body = while_op.body_function();
1175       auto cond = while_op.cond_function();
1176       // Recursively handle the nested control flow.
1177       (void)HoistForControlFlow(&body.front(), module, vars_initialized,
1178                                 lifted_partitioned_call_callees);
1179       (void)HoistForControlFlow(&cond.front(), module, vars_initialized,
1180                                 lifted_partitioned_call_callees);
1181       if (failed(HandleWhileLoop(while_op, body, cond))) return failure();
1182     } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
1183       auto then_branch = if_op.then_function();
1184       auto else_branch = if_op.else_function();
1185       // Recursively handle the nested control flow.
1186       (void)HoistForControlFlow(&then_branch.front(), module, vars_initialized,
1187                                 lifted_partitioned_call_callees);
1188       (void)HoistForControlFlow(&else_branch.front(), module, vars_initialized,
1189                                 lifted_partitioned_call_callees);
1190       if (failed(HandleCaseOrIfOp(if_op, {then_branch, else_branch})))
1191         return failure();
1192     } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
1193       SmallVector<func::FuncOp, 4> branch_functions;
1194       case_op.get_branch_functions(branch_functions);
1195       for (func::FuncOp func : branch_functions) {
1196         // Recursively handle the nested control flow.
1197         (void)HoistForControlFlow(&func.front(), module, vars_initialized,
1198                                   lifted_partitioned_call_callees);
1199       }
1200       if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure();
1201     } else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
1202       auto callee = call_op.func();
1203       if (!callee) {
1204         return call_op.emitOpError(
1205             "resource lifting does not support call with nested references.");
1206       }
1207       if (failed(HandlePartitionedCallOp(call_op, callee, module,
1208                                          vars_initialized,
1209                                          lifted_partitioned_call_callees))) {
1210         // Nested control flow handling is done in HandlePartitionedCallOp().
1211         return failure();
1212       }
1213     } else if (auto call_op =
1214                    llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
1215       if (failed(HandlePartitionedCallOp(call_op, call_op.func(), module,
1216                                          vars_initialized,
1217                                          lifted_partitioned_call_callees))) {
1218         return failure();
1219       }
1220     } else if (isa<TF::IfRegionOp, TF::CaseRegionOp, TF::WhileRegionOp>(op)) {
1221       for (Region& region : op.getRegions())
1222         (void)HoistForControlFlow(&region.front(), module, vars_initialized,
1223                                   lifted_partitioned_call_callees);
1224       LogicalResult result = RegionResourceHoister::ReplaceOpWithNewOp(&op);
1225       if (failed(result)) return failure();
1226     }
1227   }
1228 
1229   // After we have hoisted operations in the block, we may have added new read
1230   // and writes of resources to this block. Clean them up by doing store-load
1231   // forwarding.
1232   ForwardStoreToLoad(block);
1233   return success();
1234 }
1235 
1236 // Lifts resource operation from tf_device.cluster ops nested in `op` outside.
1237 // Returns failure if there are remaining resource-type values that can not be
1238 // lifted.
runOnOperation()1239 void ResourceOpLiftingPass::runOnOperation() {
1240   llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
1241       lifted_partitioned_call_callees;
1242   ModuleOp module = getOperation();
1243 
1244   if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(module)))
1245     return signalPassFailure();
1246 
1247   auto walk_result = module.walk([&](func::FuncOp func_op) {
1248     return func_op.walk([&](tf_device::ClusterOp cluster) {
1249       LogicalResult result = HoistForControlFlow(
1250           &cluster.GetBody(), module, /*vars_initialized=*/true,
1251           &lifted_partitioned_call_callees);
1252       if (failed(result)) return WalkResult::interrupt();
1253       result = RegionResourceHoister::ReplaceOpWithNewOp(cluster);
1254       if (failed(result)) return WalkResult::interrupt();
1255       return WalkResult::advance();
1256     });
1257   });
1258 
1259   if (walk_result.wasInterrupted()) return signalPassFailure();
1260 }
1261 
1262 struct ResourceOpLiftingForMainFunctionPass
1263     : public TFDevice::ResourceOpLiftingForMainFunctionPassBase<
1264           ResourceOpLiftingForMainFunctionPass> {
1265   void runOnOperation() override;
1266 };
1267 
runOnOperation()1268 void ResourceOpLiftingForMainFunctionPass::runOnOperation() {
1269   ModuleOp module = getOperation();
1270   func::FuncOp main_func = module.lookupSymbol<func::FuncOp>("main");
1271   if (!main_func) {
1272     return;
1273   }
1274 
1275   if (failed(TF::ResourceLiftingForFunctionalControlFlow(main_func))) {
1276     return signalPassFailure();
1277   }
1278 }
1279 
1280 }  // namespace
1281 
1282 namespace TFDevice {
CreateResourceOpLiftingPass()1283 std::unique_ptr<OperationPass<ModuleOp>> CreateResourceOpLiftingPass() {
1284   return std::make_unique<ResourceOpLiftingPass>();
1285 }
1286 
1287 std::unique_ptr<OperationPass<ModuleOp>>
CreateResourceOpLiftingForMainFunctionPass()1288 CreateResourceOpLiftingForMainFunctionPass() {
1289   return std::make_unique<ResourceOpLiftingForMainFunctionPass>();
1290 }
1291 
1292 }  // namespace TFDevice
1293 
1294 namespace TF {
ResourceLiftingForFunctionalControlFlow(func::FuncOp function)1295 LogicalResult ResourceLiftingForFunctionalControlFlow(func::FuncOp function) {
1296   // This routine should only be called when control flow operations are still
1297   // represented with TF IfOp and WhileOp operations. In this case, there should
1298   // be only one basic blocks in the MLIR representation.
1299   if (!llvm::hasSingleElement(function)) {
1300     return function.emitError()
1301            << "expect the function to have 1 block while it has "
1302            << function.getBlocks().size();
1303   }
1304 
1305   if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(function)))
1306     return failure();
1307 
1308   llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
1309       lifted_partitioned_call_callees;
1310   if (failed(HoistForControlFlow(
1311           &function.front(), cast<ModuleOp>(function->getParentOp()),
1312           /*vars_initialized=*/false, &lifted_partitioned_call_callees)))
1313     return failure();
1314 
1315   // Clean up and canonicalize to remove dead local variables as some local
1316   // variables might be dead after hoisting resource loads/stores from control
1317   // flow ops.
1318   return TF::CleanupAndCanonicalizeForResourceOpLifting(function);
1319 }
1320 }  // namespace TF
1321 
1322 }  // namespace mlir
1323