1 #include <torch/csrc/jit/ir/alias_analysis.h>
2 #include <torch/csrc/jit/passes/peephole_dict_idioms.h>
3
4 namespace torch::jit {
5
6 namespace {
7
8 class DictNodeImplBase {
9 public:
10 virtual ~DictNodeImplBase() = default;
11
12 virtual bool contains(const IValue&) const = 0;
13 virtual size_t size() const = 0;
14 virtual Value* get(const IValue&) const = 0;
15
canOptimize()16 bool canOptimize() {
17 return !has_overlap_ && !has_non_const_key_;
18 }
19
20 protected:
21 bool has_overlap_ = false;
22 bool has_non_const_key_ = false;
23 };
24
25 template <class KeyType>
26 class DictNodeImpl : public DictNodeImplBase {
27 public:
DictNodeImpl(std::function<KeyType (const IValue &)> ivalue_converter,Node * dict_creation_node)28 DictNodeImpl(
29 std::function<KeyType(const IValue&)> ivalue_converter,
30 Node* dict_creation_node)
31 : ivalue_converter_(std::move(ivalue_converter)) {
32 for (size_t i = 0; i < dict_creation_node->inputs().size(); i += 2) {
33 auto key_opt = toIValue(dict_creation_node->input(i));
34
35 // Key is not constant if we cannot convert to IValue
36 if (key_opt == std::nullopt) {
37 has_non_const_key_ = true;
38 continue;
39 }
40
41 KeyType key = ivalue_converter_(*key_opt);
42 if (dict_.find(key) == dict_.end()) {
43 dict_.emplace(key, dict_creation_node->input(i + 1));
44 } else {
45 has_overlap_ = true;
46 }
47 }
48 }
49
contains(const IValue & ivalue) const50 bool contains(const IValue& ivalue) const override {
51 auto key = ivalue_converter_(ivalue);
52 return dict_.find(key) != dict_.end();
53 }
54
size() const55 size_t size() const override {
56 return dict_.size();
57 }
58
get(const IValue & ivalue) const59 Value* get(const IValue& ivalue) const override {
60 auto val = ivalue_converter_(ivalue);
61 auto loc = dict_.find(val);
62 if (loc != dict_.end()) {
63 return loc->second;
64 }
65 TORCH_CHECK(false, "Cannot get non-existent key");
66 }
67
68 private:
69 std::unordered_map<KeyType, Value*> dict_;
70 std::function<KeyType(const IValue&)> ivalue_converter_;
71 };
72
73 class DictNode {
74 public:
DictNode(Node * dict_creation_node)75 explicit DictNode(Node* dict_creation_node) {
76 auto dict_type = dict_creation_node->output()->type();
77 auto key_value_types = dict_type->containedTypes();
78 TORCH_CHECK(
79 key_value_types.size() == 2, "Dict must have 2 contained types");
80 const auto& key_type = key_value_types[0];
81
82 switch (key_type->kind()) {
83 case TypeKind::IntType: {
84 auto ivalue_converter = [](const IValue& ival) { return ival.toInt(); };
85 impl_ = std::make_unique<DictNodeImpl<int64_t>>(
86 std::move(ivalue_converter), dict_creation_node);
87 break;
88 }
89
90 case TypeKind::FloatType: {
91 auto ivalue_converter = [](const IValue& ival) {
92 return ival.toDouble();
93 };
94 impl_ = std::make_unique<DictNodeImpl<double>>(
95 std::move(ivalue_converter), dict_creation_node);
96 break;
97 }
98
99 case TypeKind::StringType: {
100 auto ivalue_converter = [](const IValue& ival) {
101 return *ival.toString();
102 };
103 impl_ = std::make_unique<DictNodeImpl<std::string>>(
104 std::move(ivalue_converter), dict_creation_node);
105 break;
106 }
107
108 default:
109 impl_ = nullptr;
110 }
111 }
112
canOptimize() const113 bool canOptimize() const {
114 if (impl_) {
115 return impl_->canOptimize();
116 }
117 return false;
118 }
119
size() const120 size_t size() const {
121 if (impl_) {
122 return impl_->size();
123 }
124 return 0;
125 }
126
getOrNullopt(const IValue & key) const127 std::optional<Value*> getOrNullopt(const IValue& key) const {
128 if (impl_ && impl_->contains(key)) {
129 return impl_->get(key);
130 }
131 return std::nullopt;
132 }
133
134 private:
135 std::unique_ptr<DictNodeImplBase> impl_;
136 };
137
isDict(Value * v)138 bool isDict(Value* v) {
139 return v->type()->castRaw<DictType>() != nullptr;
140 }
141
142 class PeepholeOptimizeDictIdiomsImpl {
143 public:
PeepholeOptimizeDictIdiomsImpl(std::shared_ptr<Graph> graph)144 explicit PeepholeOptimizeDictIdiomsImpl(std::shared_ptr<Graph> graph)
145 : graph_(std::move(graph)), aliasDb_(std::make_unique<AliasDb>(graph_)) {}
146
run()147 bool run() {
148 collectMutatedDicts(graph_->block());
149 return runBlock(graph_->block());
150 }
151
152 private:
checkForMutatedDicts(Value * v)153 void checkForMutatedDicts(Value* v) {
154 if (isDict(v) && aliasDb_->hasWriters(v)) {
155 mutated_dicts_.insert(v);
156 }
157 }
158
collectMutatedDicts(Block * b)159 void collectMutatedDicts(Block* b) {
160 for (Value* v : b->inputs()) {
161 checkForMutatedDicts(v);
162 }
163 for (Node* n : b->nodes()) {
164 for (Value* v : n->outputs()) {
165 checkForMutatedDicts(v);
166 }
167 for (Block* block : n->blocks()) {
168 collectMutatedDicts(block);
169 }
170 }
171 }
172
getDictNode(Node * creation_node)173 const DictNode& getDictNode(Node* creation_node) {
174 auto cached = dict_cache_.find(creation_node);
175 if (cached == dict_cache_.end()) {
176 cached =
177 dict_cache_.emplace(creation_node, DictNode(creation_node)).first;
178 }
179
180 return cached->second;
181 }
182
getValueFromDict(Node * dict_creation_node,Value * key)183 std::optional<Value*> getValueFromDict(Node* dict_creation_node, Value* key) {
184 const DictNode& dict_node = getDictNode(dict_creation_node);
185 auto key_opt = toIValue(key);
186 // Key is not constant if we cannot convert to IValue
187 if (key_opt == std::nullopt) {
188 return std::nullopt;
189 }
190 IValue key_ival = *key_opt;
191 if (dict_node.canOptimize()) {
192 return dict_node.getOrNullopt(key_ival);
193 }
194 return std::nullopt;
195 }
196
computeLen(Node * dict_creation_node)197 std::optional<int64_t> computeLen(Node* dict_creation_node) {
198 const DictNode& dict_node = getDictNode(dict_creation_node);
199 if (dict_node.canOptimize()) {
200 return static_cast<int64_t>(dict_node.size());
201 }
202 return std::nullopt;
203 }
204
optimizeLen(Node * len_node,Node * creation_node)205 bool optimizeLen(Node* len_node, Node* creation_node) {
206 if (creation_node->kind() == prim::DictConstruct) {
207 auto len = computeLen(creation_node);
208 if (len != std::nullopt) {
209 WithInsertPoint guard(len_node);
210 len_node->output()->replaceAllUsesWith(graph_->insertConstant(len));
211 return true;
212 }
213 }
214 return false;
215 }
216
optimizeGetItem(Node * getitem_node,Node * creation_node)217 bool optimizeGetItem(Node* getitem_node, Node* creation_node) {
218 if (creation_node->kind() == prim::DictConstruct) {
219 auto key = getitem_node->input(1);
220 auto value = getValueFromDict(creation_node, key);
221 if (value != std::nullopt) {
222 getitem_node->output()->replaceAllUsesWith(*value);
223 return true;
224 }
225 }
226 return false;
227 }
228
runBlock(Block * block)229 bool runBlock(Block* block) {
230 bool changed = false;
231 for (Node* node : block->nodes()) {
232 for (Block* b : node->blocks()) {
233 changed |= runBlock(b);
234 }
235
236 // only optimizing dict ops
237 if (node->inputs().empty() || !isDict(node->input(0))) {
238 continue;
239 }
240
241 auto first_input = node->input(0);
242
243 // only optimizing ops with unmutated inputs
244 if (mutated_dicts_.count(first_input)) {
245 continue;
246 }
247
248 if (node->kind() == aten::len) {
249 changed |= optimizeLen(node, first_input->node());
250 } else if (node->kind() == aten::__getitem__) {
251 changed |= optimizeGetItem(node, first_input->node());
252 }
253 }
254 return changed;
255 }
256
257 std::shared_ptr<Graph> graph_;
258 std::unordered_set<Value*> mutated_dicts_;
259 std::unique_ptr<AliasDb> aliasDb_;
260 std::unordered_map<Node*, DictNode> dict_cache_;
261 };
262
263 } // namespace
264
PeepholeOptimizeDictIdioms(const std::shared_ptr<Graph> & graph)265 bool PeepholeOptimizeDictIdioms(const std::shared_ptr<Graph>& graph) {
266 PeepholeOptimizeDictIdiomsImpl opt(graph);
267 return opt.run();
268 }
269
270 } // namespace torch::jit
271