#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/restore_mutation.h>

namespace torch::jit {

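// MutationRemover rewrites in-place list and tensor mutations into their
// functional equivalents when alias analysis proves that doing so preserves
// observable semantics.
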
bool MutationRemover::removeListMutation() {
  return RemoveListMutation(graph_->block());
}

bool MutationRemover::removeTensorMutation() {
  return RemoveTensorMutation(graph_->block());
}

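// Returns true if `v` cannot be proven to be a fresh, unaliased value: its
// creating node has side effects, blocks, or a subgraph, it is a graph input,
// or it may alias one of its node's inputs.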
bool MutationRemover::hasSideEffectOrAlias(Value* v, AliasDb* aliasDb) {
  // bail on nodes with side effects, blocks, or graph / graph inputs
  Node* n = v->node();
  bool unhandled_node = !n->blocks().empty() ||
      n->hasAttribute(attr::Subgraph) || n->hasSideEffects() ||
      (v->node()->kind() == prim::Param);

  // if the output isn't contained in or aliased by the inputs to its node, it
  // is unique. No need to check for aliasing if the node is a ListConstruct.
  bool mayAliasInputs = (v->node()->kind() != prim::ListConstruct) &&
      aliasDb->mayContainAlias(v->node()->inputs(), v);
  return unhandled_node || mayAliasInputs || (v->node()->kind() == prim::Param);
}

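// Rewrites the handful of in-place ops whose functional counterpart is not
// obtained by simply dropping the trailing underscore:
//   aten::fill_   -> aten::full_like
//   aten::zero_   -> aten::zeros_like
//   aten::normal_ -> aten::normal (with size/dtype/layout/device/pin_memory
//                    recovered from the mutated tensor)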
Node* MutationRemover::createSpecialMappedOp(Node* n) {
  WithInsertPoint guard(n);
  auto inputs = n->inputs();
  Node* new_node = nullptr;
  if (n->matches(
          "aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)")) {
    auto dtype = graph_->insert(prim::dtype, {inputs.at(0)});
    new_node = graph_
                   ->insert(
                       aten::full_like,
                       {inputs.at(0), inputs.at(1)},
                       {NamedValue("dtype", dtype)})
                   ->node();
    new_node->copyMetadata(n);
    new_node->output()->setType(n->output()->type());
  } else if (n->matches("aten::zero_(Tensor(a!) self) -> Tensor(a!)")) {
    new_node = graph_->insert(aten::zeros_like, {n->inputs().at(0)})->node();
  } else if (
      n->matches(
          "aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)")) {
    // TODO: we should have a normal_like operator
    // normal(float mean, float std, int[] size, *, Generator? generator=None,
    // ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool?
    // pin_memory=None) -> Tensor
    auto size = graph_->insert(aten::size, {n->inputs().at(0)});
    auto dtype = graph_->insert(prim::dtype, {n->inputs().at(0)});
    auto layout = graph_->insert(prim::layout, {n->inputs().at(0)});
    auto device = graph_->insert(prim::device, {n->inputs().at(0)});
    auto pin_memory = graph_->insert(aten::is_pinned, {n->inputs().at(0)});
    auto generator = graph_->insertConstant(IValue());
    new_node = graph_->insertNode(graph_->create(
        aten::normal,
        {n->inputs().at(1),
         n->inputs().at(2),
         size,
         generator,
         dtype,
         layout,
         device,
         pin_memory}));
  } else {
    TORCH_INTERNAL_ASSERT(false);
  }
  new_node->copyMetadata(n);
  new_node->output()->setType(n->output()->type());
  return new_node;
}

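// A `_set_item` call is removable only when it writes through a constant
// index into a ListConstruct and that index resolves to a valid in-range
// position.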
static bool removableSetItem(Node* n) {
  if (n->kind() != aten::_set_item ||
      n->input(1)->node()->kind() != prim::Constant) {
    return false;
  }
  if (n->inputs().at(0)->node()->kind() != prim::ListConstruct) {
    return false;
  }
  auto li_node = n->inputs().at(0)->node();
  int64_t index = *constant_as<int64_t>(n->input(1));
  if (index < 0) {
    index += li_node->inputs().size();
  }
  auto li_len = static_cast<int64_t>(li_node->inputs().size());
  return index < li_len && index >= 0;
}

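// Matches mutating list ops that directly follow the list's construction,
// e.g. (IR sketch):
//   %list : Tensor[] = prim::ListConstruct(%x)
//   %list.1 : Tensor[] = aten::append(%list, %y)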
bool MutationRemover::listMutationFollowingListConstruct(Node* n) {
  return (
      (n->kind() == aten::append ||
       (n->kind() == aten::insert &&
        n->inputs().at(1)->node()->kind() == prim::Constant) ||
       (removableSetItem(n))) &&
      n->inputs().at(0)->node()->kind() == prim::ListConstruct);
}

bool MutationRemover::tryMakeCreationAndMutationAtomic(
    Value* mutated_value,
    Node* mutating_op) {
  // We can only remove mutation to values that are unique aliases in the
  // graph: if x = y[0] or y = self.y, then removing the mutation could
  // change observable semantics.
  if (hasSideEffectOrAlias(mutated_value, getOrCreateAliasDb())) {
    return false;
  }

  // In order to safely remove a mutation, the creation of a tensor and its
  // subsequent mutation need to be one atomic operation.
  return getOrCreateAliasDb()->moveBeforeTopologicallyValid(
      mutated_value->node(), mutating_op);
}

bool MutationRemover::tryMakeUnaliasedIfOutputAndMutationAtomic(
    Value* mutated_value,
    Node* mutating_op) {
  // if cond:
  //    x = op()
  // else:
  //    x = op()
  // x.add_(1)
  // If x in both blocks has no other uses and is unaliased in the graph,
  // and we make the if node and the mutation atomic,
  // then removing the mutation add_ does not change observable semantics.

  if (mutated_value->node()->kind() != prim::If) {
    return false;
  }

  auto if_node = mutated_value->node();
  auto offset = mutated_value->offset();
  auto true_value = if_node->blocks().at(0)->outputs().at(offset);
  auto false_value = if_node->blocks().at(1)->outputs().at(offset);

  if (true_value->uses().size() > 1 || false_value->uses().size() > 1) {
    return false;
  }

  if (hasSideEffectOrAlias(true_value, getOrCreateAliasDb()) ||
      hasSideEffectOrAlias(false_value, getOrCreateAliasDb())) {
    return false;
  }

  return getOrCreateAliasDb()->moveBeforeTopologicallyValid(
      if_node, mutating_op);
}

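// Folds list mutations into the originating ListConstruct, e.g. (IR sketch):
//   %list : Tensor[] = prim::ListConstruct(%a)
//   %list2 : Tensor[] = aten::append(%list, %b)
// becomes
//   %list : Tensor[] = prim::ListConstruct(%a, %b)
// with uses of %list2 redirected to %list.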
bool MutationRemover::RemoveListMutation(Block* block) {
  bool changed = false;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    auto* node = *it;
    it++;

    for (Block* sub_block : node->blocks()) {
      changed |= RemoveListMutation(sub_block);
    }

    if (!listMutationFollowingListConstruct(node)) {
      continue;
    }

    Value* mutated_value = node->inputs().at(0);
    if (!tryMakeCreationAndMutationAtomic(mutated_value, node)) {
      continue;
    }

    changed = true;

    // We rewrite something like:
    // x = {v0}
    // x.append(v1) (or x.insert(0, v1))
    // to:
    // x = {v0, v1} (or x = {v1, v0})
    // We can remove x.append from the alias db's list of writes.
    // All other aliasing properties remain valid.
    Node* list_construct = mutated_value->node();
    switch (node->kind()) {
      case aten::append:
        list_construct->addInput(node->inputs().at(1));
        break;
      case aten::insert: {
        int pos = toIValue(node->inputs().at(1))->toInt();
        int size = list_construct->inputs().size();
        // inserting at a negative position equals inserting at
        // std::max(pos + size, 0)
        if (pos < 0) {
          pos = std::max(pos + size, 0);
        }
        // inserting beyond the current list length is the same as appending
        pos = std::min(pos, size);
        list_construct->insertInput(pos, node->inputs().at(2));
        break;
      }
      case aten::_set_item: {
        int pos = toIValue(node->inputs().at(1))->toInt();
        int size = list_construct->inputs().size();
        if (pos < 0) {
          pos = std::max(pos + size, 0);
        }
        list_construct->replaceInput(pos, node->input(2));
        break;
      }
      default:
        TORCH_INTERNAL_ASSERT(false);
    }

    // process use-chain and aliasing of node output
    bool has_output = (!node->outputs().empty());
    if (has_output) {
      node->output()->replaceAllUsesWith(mutated_value);
      getOrCreateAliasDb()->writeIndex_->erase(node);
    }

    node->destroy();

    // TODO: we don't strictly need to reset the write cache; evaluate on models
    getOrCreateAliasDb()->buildWrittenToLocationsIndex();
  }

  return changed;
}

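// Rewrites an in-place tensor op into its functional variant, e.g. (IR
// sketch):
//   %y = aten::add_(%x, %other, %alpha)
// becomes
//   %y2 = aten::add(%x, %other, %alpha)
// provided %x is created immediately before the mutation (or can safely be
// moved there) and is not aliased elsewhere.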
bool MutationRemover::RemoveTensorMutation(Block* block) {
  bool changed = false;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    auto* node = *it;
    it++;

    for (Block* sub_block : node->blocks()) {
      changed |= RemoveTensorMutation(sub_block);
    }

    if (mutation_filter_) {
      const auto& mutation_filter = *mutation_filter_;
      if (!mutation_filter(node)) {
        continue;
      }
    }

    // TODO: out op variants
    if (!inplaceOpVariant(node)) {
      continue;
    }

    Value* mutated_value = node->inputs().at(0);
    if (!tryMakeCreationAndMutationAtomic(mutated_value, node) &&
        !tryMakeUnaliasedIfOutputAndMutationAtomic(mutated_value, node)) {
      continue;
    }

    Node* new_node = nullptr;
    if (isSpecialMappedOp(node)) {
      new_node = createSpecialMappedOp(node);
    } else {
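      // For ordinary in-place ops, the functional variant shares the schema
      // name minus the trailing underscore, e.g. "aten::add_" -> "aten::add".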
      auto schema_name = node->schema().name();
      auto new_schema = schema_name.substr(0, schema_name.size() - 1);
      new_node = graph_->create(Symbol::fromQualString(new_schema), 1);
      new_node->copyMetadata(node);
      new_node->insertBefore(node);
      for (Value* input : node->inputs()) {
        new_node->addInput(input);
      }
      new_node->output()->setType(node->output()->type());

      // weird case where there is an in-place op and an equivalent functional
      // op of the same symbol, but they have different schemas
      if (!new_node->maybeOperator()) {
        new_node->destroy();
        continue;
      }
    }

    changed = true;
    mutated_value->replaceAllUsesAfterNodeWith(node, new_node->output());
    node->output()->replaceAllUsesWith(new_node->output());

    // We rewrite something like:
    // x = torch.zeros()
    // x.add_(1)
    // x.add_(2)
    // to:
    // x = torch.zeros()
    // x0 = x.add(1)
    // x0.add_(2)
    // For the remainder of the function, x0 will have the
    // same aliasing relationships as the original x.
    // To avoid rebuilding the entire alias db, we can replace
    // the memory DAG element of x with x0.
    getOrCreateAliasDb()->replaceWithNewValue(
        mutated_value, new_node->output());

    // It is an invariant that all mutable types have an element in the memory
    // DAG, so we must give x an alias db element again. We have already
    // verified that the mutated value is a fresh alias with a single use.
    getOrCreateAliasDb()->createValue(mutated_value);

    // We must erase the destroyed node from the AliasDb's lists of writes
    getOrCreateAliasDb()->writeIndex_->erase(node);
    node->destroy();

    // now that we have removed a mutating op, the write cache is stale
    // TODO: we don't strictly need to reset the write cache; evaluate on models
    getOrCreateAliasDb()->buildWrittenToLocationsIndex();
  }

  return changed;
}

bool MutationRemover::inplaceOpVariant(Node* n) {
  if (!n->kind().is_aten()) {
    return false;
  }

  if (isSpecialMappedOp(n)) {
    return true;
  }

  auto name = n->schema().name();
  bool inplace_op = name.at(name.size() - 1) == '_';
  if (!inplace_op) {
    return false;
  }

  // the op needs to have alias analysis derived from its schema
  auto op = n->maybeOperator();
  if (!op) {
    return false;
  }
  if (op->aliasAnalysisKind() != AliasAnalysisKind::FROM_SCHEMA) {
    return false;
  }

  // All in-place ops at the time of writing have a single input that is
  // mutated and returned. Check that this is true; anything else could have
  // strange semantics.
  if (n->outputs().size() != 1 || n->inputs().empty()) {
    return false;
  }
  auto inputs = n->inputs();
  if (!getOrCreateAliasDb()->writesToAlias(n, {inputs.at(0)}) ||
      getOrCreateAliasDb()->writesToAlias(
          n, {inputs.slice(1).begin(), inputs.slice(1).end()})) {
    return false;
  }

  auto new_schema = name.substr(0, name.size() - 1);
  return !getAllOperatorsFor(Symbol::fromQualString(new_schema)).empty();
}

bool RemoveListMutation(const std::shared_ptr<Graph>& graph) {
  MutationRemover mr(graph);
  return mr.removeListMutation();
}

bool RemoveTensorMutation(
    const std::shared_ptr<Graph>& graph,
    std::optional<std::function<bool(Node*)>> mutation_filter) {
  MutationRemover mr(graph, std::move(mutation_filter));
  return mr.removeTensorMutation();
}

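// Build the set of in-place activation symbols (e.g. aten::relu_) by
// appending an underscore to each functional activation op listed in
// activation_type_promotion_mapping (from restore_mutation.h).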
static const std::unordered_set<Symbol> activation_ops = []() {
  std::unordered_set<Symbol> target_ops;
  for (const auto& iter : activation_type_promotion_mapping) {
    std::string name = std::string(iter.first.toQualString()) + "_";
    target_ops.insert(Symbol::fromQualString(name));
  }
  return target_ops;
}();

bool InplaceToFunctionalActivation(const std::shared_ptr<Graph>& graph) {
  return RemoveTensorMutation(graph, [](Node* node) {
    return activation_ops.count(node->kind()) != 0;
  });
}

} // namespace torch::jit