1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/Support/Casting.h"
18 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
19 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/IR/Operation.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Pass/PassManager.h"
24 #include "mlir/Support/LLVM.h"
25 #include "mlir/Transforms/RegionUtils.h"
26 
27 namespace mlir {
28 namespace mhlo {
29 
30 namespace {
31 
// A pass that sinks constants implicitly captured in control flow regions. This
// is necessary to export to XLA.
//
// TODO(b/203775547): Any value used within a region but defined outside of the
// op's regions should be sunk into the regions, not just constants. Ops such
// as If and While, whose computations don't require a fixed signature (unlike
// Sort or Reduce), have the option to pass outside values as operands of the op
// instead of recomputing them internally. Note that doing so is the only option
// for outside values that are BlockArguments of any of the parent regions.
42 class SinkConstantsToControlFlowPass
43     : public SinkConstantsToControlFlowPassBase<
44           SinkConstantsToControlFlowPass> {
runOnOperation()45   void runOnOperation() override {
46     getOperation().walk([](Operation* op) {
47       for (Region& region : op->getRegions()) sinkToRegion(&region);
48     });
49   }
50 
51  private:
52   // Performs constant sinking into a region.
sinkToRegion(Region * region)53   static void sinkToRegion(Region* region) {
54     llvm::DenseMap<Value, Operation*> sunkConstant;
55     visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
56       Value constant = use->get();
57       auto* op = constant.getDefiningOp();
58       if (!op || !op->hasTrait<mlir::OpTrait::ConstantLike>()) return;
59       auto mapEntry = sunkConstant.try_emplace(constant, nullptr);
60       if (!mapEntry.second) {
61         // This constant has already been cloned into the region, reuse it.
62         use->set(mapEntry.first->getSecond()->getResult(0));
63         if (op->use_empty()) op->erase();
64         return;
65       }
66       if (constant.hasOneUse()) {
67         op->moveBefore(&region->front().front());
68         return;
69       }
70       mapEntry.first->getSecond() = op->clone();
71       region->front().getOperations().insert(region->front().begin(),
72                                              mapEntry.first->getSecond());
73       use->set(mapEntry.first->getSecond()->getResult(0));
74     });
75   }
76 };
77 
78 }  // anonymous namespace
79 
80 // TODO(hinsu): Rename this pass and move to a different file along with the
81 // generalization to make all ops isolated from above.
82 std::unique_ptr<OperationPass<func::FuncOp>>
createSinkConstantsToControlFlowPass()83 createSinkConstantsToControlFlowPass() {
84   return std::make_unique<SinkConstantsToControlFlowPass>();
85 }
86 
87 }  // namespace mhlo
88 }  // namespace mlir
89