xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/peephole_dict_idioms.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/ir/alias_analysis.h>
2 #include <torch/csrc/jit/passes/peephole_dict_idioms.h>
3 
4 namespace torch::jit {
5 
6 namespace {
7 
8 class DictNodeImplBase {
9  public:
10   virtual ~DictNodeImplBase() = default;
11 
12   virtual bool contains(const IValue&) const = 0;
13   virtual size_t size() const = 0;
14   virtual Value* get(const IValue&) const = 0;
15 
canOptimize()16   bool canOptimize() {
17     return !has_overlap_ && !has_non_const_key_;
18   }
19 
20  protected:
21   bool has_overlap_ = false;
22   bool has_non_const_key_ = false;
23 };
24 
25 template <class KeyType>
26 class DictNodeImpl : public DictNodeImplBase {
27  public:
DictNodeImpl(std::function<KeyType (const IValue &)> ivalue_converter,Node * dict_creation_node)28   DictNodeImpl(
29       std::function<KeyType(const IValue&)> ivalue_converter,
30       Node* dict_creation_node)
31       : ivalue_converter_(std::move(ivalue_converter)) {
32     for (size_t i = 0; i < dict_creation_node->inputs().size(); i += 2) {
33       auto key_opt = toIValue(dict_creation_node->input(i));
34 
35       // Key is not constant if we cannot convert to IValue
36       if (key_opt == std::nullopt) {
37         has_non_const_key_ = true;
38         continue;
39       }
40 
41       KeyType key = ivalue_converter_(*key_opt);
42       if (dict_.find(key) == dict_.end()) {
43         dict_.emplace(key, dict_creation_node->input(i + 1));
44       } else {
45         has_overlap_ = true;
46       }
47     }
48   }
49 
contains(const IValue & ivalue) const50   bool contains(const IValue& ivalue) const override {
51     auto key = ivalue_converter_(ivalue);
52     return dict_.find(key) != dict_.end();
53   }
54 
size() const55   size_t size() const override {
56     return dict_.size();
57   }
58 
get(const IValue & ivalue) const59   Value* get(const IValue& ivalue) const override {
60     auto val = ivalue_converter_(ivalue);
61     auto loc = dict_.find(val);
62     if (loc != dict_.end()) {
63       return loc->second;
64     }
65     TORCH_CHECK(false, "Cannot get non-existent key");
66   }
67 
68  private:
69   std::unordered_map<KeyType, Value*> dict_;
70   std::function<KeyType(const IValue&)> ivalue_converter_;
71 };
72 
73 class DictNode {
74  public:
DictNode(Node * dict_creation_node)75   explicit DictNode(Node* dict_creation_node) {
76     auto dict_type = dict_creation_node->output()->type();
77     auto key_value_types = dict_type->containedTypes();
78     TORCH_CHECK(
79         key_value_types.size() == 2, "Dict must have 2 contained types");
80     const auto& key_type = key_value_types[0];
81 
82     switch (key_type->kind()) {
83       case TypeKind::IntType: {
84         auto ivalue_converter = [](const IValue& ival) { return ival.toInt(); };
85         impl_ = std::make_unique<DictNodeImpl<int64_t>>(
86             std::move(ivalue_converter), dict_creation_node);
87         break;
88       }
89 
90       case TypeKind::FloatType: {
91         auto ivalue_converter = [](const IValue& ival) {
92           return ival.toDouble();
93         };
94         impl_ = std::make_unique<DictNodeImpl<double>>(
95             std::move(ivalue_converter), dict_creation_node);
96         break;
97       }
98 
99       case TypeKind::StringType: {
100         auto ivalue_converter = [](const IValue& ival) {
101           return *ival.toString();
102         };
103         impl_ = std::make_unique<DictNodeImpl<std::string>>(
104             std::move(ivalue_converter), dict_creation_node);
105         break;
106       }
107 
108       default:
109         impl_ = nullptr;
110     }
111   }
112 
canOptimize() const113   bool canOptimize() const {
114     if (impl_) {
115       return impl_->canOptimize();
116     }
117     return false;
118   }
119 
size() const120   size_t size() const {
121     if (impl_) {
122       return impl_->size();
123     }
124     return 0;
125   }
126 
getOrNullopt(const IValue & key) const127   std::optional<Value*> getOrNullopt(const IValue& key) const {
128     if (impl_ && impl_->contains(key)) {
129       return impl_->get(key);
130     }
131     return std::nullopt;
132   }
133 
134  private:
135   std::unique_ptr<DictNodeImplBase> impl_;
136 };
137 
isDict(Value * v)138 bool isDict(Value* v) {
139   return v->type()->castRaw<DictType>() != nullptr;
140 }
141 
142 class PeepholeOptimizeDictIdiomsImpl {
143  public:
PeepholeOptimizeDictIdiomsImpl(std::shared_ptr<Graph> graph)144   explicit PeepholeOptimizeDictIdiomsImpl(std::shared_ptr<Graph> graph)
145       : graph_(std::move(graph)), aliasDb_(std::make_unique<AliasDb>(graph_)) {}
146 
run()147   bool run() {
148     collectMutatedDicts(graph_->block());
149     return runBlock(graph_->block());
150   }
151 
152  private:
checkForMutatedDicts(Value * v)153   void checkForMutatedDicts(Value* v) {
154     if (isDict(v) && aliasDb_->hasWriters(v)) {
155       mutated_dicts_.insert(v);
156     }
157   }
158 
collectMutatedDicts(Block * b)159   void collectMutatedDicts(Block* b) {
160     for (Value* v : b->inputs()) {
161       checkForMutatedDicts(v);
162     }
163     for (Node* n : b->nodes()) {
164       for (Value* v : n->outputs()) {
165         checkForMutatedDicts(v);
166       }
167       for (Block* block : n->blocks()) {
168         collectMutatedDicts(block);
169       }
170     }
171   }
172 
getDictNode(Node * creation_node)173   const DictNode& getDictNode(Node* creation_node) {
174     auto cached = dict_cache_.find(creation_node);
175     if (cached == dict_cache_.end()) {
176       cached =
177           dict_cache_.emplace(creation_node, DictNode(creation_node)).first;
178     }
179 
180     return cached->second;
181   }
182 
getValueFromDict(Node * dict_creation_node,Value * key)183   std::optional<Value*> getValueFromDict(Node* dict_creation_node, Value* key) {
184     const DictNode& dict_node = getDictNode(dict_creation_node);
185     auto key_opt = toIValue(key);
186     // Key is not constant if we cannot convert to IValue
187     if (key_opt == std::nullopt) {
188       return std::nullopt;
189     }
190     IValue key_ival = *key_opt;
191     if (dict_node.canOptimize()) {
192       return dict_node.getOrNullopt(key_ival);
193     }
194     return std::nullopt;
195   }
196 
computeLen(Node * dict_creation_node)197   std::optional<int64_t> computeLen(Node* dict_creation_node) {
198     const DictNode& dict_node = getDictNode(dict_creation_node);
199     if (dict_node.canOptimize()) {
200       return static_cast<int64_t>(dict_node.size());
201     }
202     return std::nullopt;
203   }
204 
optimizeLen(Node * len_node,Node * creation_node)205   bool optimizeLen(Node* len_node, Node* creation_node) {
206     if (creation_node->kind() == prim::DictConstruct) {
207       auto len = computeLen(creation_node);
208       if (len != std::nullopt) {
209         WithInsertPoint guard(len_node);
210         len_node->output()->replaceAllUsesWith(graph_->insertConstant(len));
211         return true;
212       }
213     }
214     return false;
215   }
216 
optimizeGetItem(Node * getitem_node,Node * creation_node)217   bool optimizeGetItem(Node* getitem_node, Node* creation_node) {
218     if (creation_node->kind() == prim::DictConstruct) {
219       auto key = getitem_node->input(1);
220       auto value = getValueFromDict(creation_node, key);
221       if (value != std::nullopt) {
222         getitem_node->output()->replaceAllUsesWith(*value);
223         return true;
224       }
225     }
226     return false;
227   }
228 
runBlock(Block * block)229   bool runBlock(Block* block) {
230     bool changed = false;
231     for (Node* node : block->nodes()) {
232       for (Block* b : node->blocks()) {
233         changed |= runBlock(b);
234       }
235 
236       // only optimizing dict ops
237       if (node->inputs().empty() || !isDict(node->input(0))) {
238         continue;
239       }
240 
241       auto first_input = node->input(0);
242 
243       // only optimizing ops with unmutated inputs
244       if (mutated_dicts_.count(first_input)) {
245         continue;
246       }
247 
248       if (node->kind() == aten::len) {
249         changed |= optimizeLen(node, first_input->node());
250       } else if (node->kind() == aten::__getitem__) {
251         changed |= optimizeGetItem(node, first_input->node());
252       }
253     }
254     return changed;
255   }
256 
257   std::shared_ptr<Graph> graph_;
258   std::unordered_set<Value*> mutated_dicts_;
259   std::unique_ptr<AliasDb> aliasDb_;
260   std::unordered_map<Node*, DictNode> dict_cache_;
261 };
262 
263 } // namespace
264 
PeepholeOptimizeDictIdioms(const std::shared_ptr<Graph> & graph)265 bool PeepholeOptimizeDictIdioms(const std::shared_ptr<Graph>& graph) {
266   PeepholeOptimizeDictIdiomsImpl opt(graph);
267   return opt.run();
268 }
269 
270 } // namespace torch::jit
271