xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/var_substitutor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <unordered_map>
4 #include <utility>
5 #include <vector>
6 
7 #include <torch/csrc/jit/tensorexpr/analysis.h>
8 #include <torch/csrc/jit/tensorexpr/ir.h>
9 #include <torch/csrc/jit/tensorexpr/ir_mutator.h>
10 #include <torch/csrc/jit/tensorexpr/ir_visitor.h>
11 #include <torch/csrc/jit/tensorexpr/reduction.h>
12 
13 namespace torch::jit::tensorexpr {
14 
// Ordered list of (variable, replacement expression) pairs. VarSubMutator's
// constructor reads these entries to build its substitution table.
15 using VarMapping = std::vector<std::pair<VarPtr, ExprPtr>>;
16 
17 class VarSubMutator : public IRMutator {
18  public:
VarSubMutator(const VarMapping & var_mapping)19   VarSubMutator(const VarMapping& var_mapping) {
20     for (auto& entry : var_mapping) {
21       VarPtr key_var = entry.first;
22       ExprPtr value = entry.second;
23       if (!key_var) {
24         throw malformed_input("missing key in VarSubMutator");
25       }
26       var_mapping_[std::move(key_var)] = std::move(value);
27     }
28   }
29 
mutate(const VarPtr & var)30   ExprPtr mutate(const VarPtr& var) override {
31     auto iter = var_mapping_.find(var);
32     if (iter == var_mapping_.end()) {
33       return var;
34     }
35     return iter->second;
36   }
37 
mutate(const ReduceOpPtr & var)38   ExprPtr mutate(const ReduceOpPtr& var) override {
39     auto body = var->body()->accept_mutator(this);
40     std::vector<VarPtr> new_inner;
41 
42     for (const auto& v : var->reduce_args()) {
43       ExprPtr e = v->accept_mutator(this);
44       if (VarPtr new_var = to<Var>(e)) {
45         new_inner.push_back(std::move(new_var));
46       } else {
47         VarFinder varFinder;
48         e->accept(&varFinder);
49         auto varlist = varFinder.vars();
50         new_inner.insert(new_inner.end(), varlist.begin(), varlist.end());
51       }
52     }
53 
54     return alloc<ReduceOp>(body, new_inner, var->reducer());
55   }
56 
57  private:
58   std::unordered_map<VarPtr, ExprPtr> var_mapping_;
59 };
60 
61 } // namespace torch::jit::tensorexpr
62