#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/restore_mutation.h>

namespace torch::jit {

bool MutationRemover::removeListMutation() {
  return RemoveListMutation(graph_->block());
}

bool MutationRemover::removeTensorMutation() {
  return RemoveTensorMutation(graph_->block());
}

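// Returns true when `v` cannot be treated as a freshly created, unaliased
// value: its producing node has side effects, owns blocks or a subgraph, is
// a graph input (prim::Param), or may contain/alias one of its own inputs
// according to alias analysis.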
bool MutationRemover::hasSideEffectOrAlias(Value* v, AliasDb* aliasDb) {
  // Bail on nodes with side effects, nodes with blocks or subgraphs, and
  // graph inputs.
  Node* n = v->node();
  bool unhandled_node = !n->blocks().empty() ||
      n->hasAttribute(attr::Subgraph) || n->hasSideEffects() ||
      (n->kind() == prim::Param);

  // If the output isn't contained in or aliased by its node's inputs, it is
  // unique. No need to check for aliasing if the node is a ListConstruct.
  bool mayAliasInputs = (n->kind() != prim::ListConstruct) &&
      aliasDb->mayContainAlias(n->inputs(), v);
  return unhandled_node || mayAliasInputs;
}

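// Map the handful of in-place ops that have no functional variant of the
// same name onto equivalent out-of-place ops, e.g.:
//   x.fill_(v)   ->  aten::full_like(x, v, dtype=x.dtype)
//   x.zero_()    ->  aten::zeros_like(x)
//   x.normal_()  ->  aten::normal(mean, std, x.size(), ...)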
Node* MutationRemover::createSpecialMappedOp(Node* n) {
  WithInsertPoint guard(n);
  auto inputs = n->inputs();
  Node* new_node = nullptr;
  if (n->matches(
          "aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)")) {
    auto dtype = graph_->insert(prim::dtype, {inputs.at(0)});
    new_node = graph_
                   ->insert(
                       aten::full_like,
                       {inputs.at(0), inputs.at(1)},
                       {NamedValue("dtype", dtype)})
                   ->node();
  } else if (n->matches("aten::zero_(Tensor(a!) self) -> Tensor(a!)")) {
    new_node = graph_->insert(aten::zeros_like, {n->inputs().at(0)})->node();
  } else if (
      n->matches(
          "aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)")) {
    // TODO: we should have a normal_like operator
    // normal(float mean, float std, int[] size, *, Generator? generator=None,
    // ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool?
    // pin_memory=None) -> Tensor
    auto size = graph_->insert(aten::size, {n->inputs().at(0)});
    auto dtype = graph_->insert(prim::dtype, {n->inputs().at(0)});
    auto layout = graph_->insert(prim::layout, {n->inputs().at(0)});
    auto device = graph_->insert(prim::device, {n->inputs().at(0)});
    auto pin_memory = graph_->insert(aten::is_pinned, {n->inputs().at(0)});
    auto generator = graph_->insertConstant(IValue());
    new_node = graph_->insertNode(graph_->create(
        aten::normal,
        {n->inputs().at(1),
         n->inputs().at(2),
         size,
         generator,
         dtype,
         layout,
         device,
         pin_memory}));
  } else {
    TORCH_INTERNAL_ASSERT(false);
  }
  new_node->copyMetadata(n);
  new_node->output()->setType(n->output()->type());
  return new_node;
}

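// aten::_set_item on the output of a prim::ListConstruct can be folded into
// the construct, but only when the index is a compile-time constant that is
// in bounds for the constructed list (negative indices are normalized first).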
static bool removableSetItem(Node* n) {
  if (n->kind() != aten::_set_item ||
      n->input(1)->node()->kind() != prim::Constant) {
    return false;
  }
  if (n->inputs().at(0)->node()->kind() != prim::ListConstruct) {
    return false;
  }
  auto li_node = n->inputs().at(0)->node();
  int64_t index = *constant_as<int64_t>(n->input(1));
  if (index < 0) {
    index += li_node->inputs().size();
  }
  auto li_len = static_cast<int64_t>(li_node->inputs().size());
  return index < li_len && index >= 0;
}

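// True for list mutations we know how to fold: append always; insert and
// _set_item only when the index is constant (and, for _set_item, in bounds);
// in all cases the mutated list must come directly from a
// prim::ListConstruct.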
bool MutationRemover::listMutationFollowingListConstruct(Node* n) {
  return (
      (n->kind() == aten::append ||
       (n->kind() == aten::insert &&
        n->inputs().at(1)->node()->kind() == prim::Constant) ||
       (removableSetItem(n))) &&
      n->inputs().at(0)->node()->kind() == prim::ListConstruct);
}

bool MutationRemover::tryMakeCreationAndMutationAtomic(
    Value* mutated_value,
    Node* mutating_op) {
  // We can only remove mutation to values that are unique aliases in the
  // graph. If x = y[0] or y = self.y, then removing the mutation could
  // change observable semantics.
  if (hasSideEffectOrAlias(mutated_value, getOrCreateAliasDb())) {
    return false;
  }

  // In order to safely remove a mutation, the creation of a tensor and its
  // subsequent mutation need to be one atomic operation.
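  // (e.g. for x = torch.zeros(); ...; x.add_(1), we try to move the zeros
  // node to directly before the add_, so no node in between can observe or
  // alias x before it is mutated)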
  return getOrCreateAliasDb()->moveBeforeTopologicallyValid(
      mutated_value->node(), mutating_op);
}

bool MutationRemover::tryMakeUnaliasedIfOutputAndMutationAtomic(
    Value* mutated_value,
    Node* mutating_op) {
  // if cond:
  //    x = op()
  // else:
  //    x = op()
  // x.add_(1)
  // If the x in each block has no other uses and is unaliased in the graph,
  // and we can make the if node and the mutation atomic,
  // then removing the mutation add_ does not change observable semantics.

  if (mutated_value->node()->kind() != prim::If) {
    return false;
  }

  auto if_node = mutated_value->node();
  auto offset = mutated_value->offset();
  auto true_value = if_node->blocks().at(0)->outputs().at(offset);
  auto false_value = if_node->blocks().at(1)->outputs().at(offset);

  if (true_value->uses().size() > 1 || false_value->uses().size() > 1) {
    return false;
  }

  if (hasSideEffectOrAlias(true_value, getOrCreateAliasDb()) ||
      hasSideEffectOrAlias(false_value, getOrCreateAliasDb())) {
    return false;
  }

  return getOrCreateAliasDb()->moveBeforeTopologicallyValid(
      if_node, mutating_op);
}

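// Fold supported list mutations (append, and insert/_set_item with a
// constant index) into the prim::ListConstruct node that created the list,
// recursing into sub-blocks.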
bool MutationRemover::RemoveListMutation(Block* block) {
  bool changed = false;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    auto* node = *it;
    it++;

    for (Block* sub_block : node->blocks()) {
      changed |= RemoveListMutation(sub_block);
    }

    if (!listMutationFollowingListConstruct(node)) {
      continue;
    }

    Value* mutated_value = node->inputs().at(0);
    if (!tryMakeCreationAndMutationAtomic(mutated_value, node)) {
      continue;
    }

    changed = true;

    // We rewrite something like:
    //    x = {v0}
    //    x.append(v1) (or x.insert(0, v1))
    // to:
    //    x = {v0, v1} (or x = {v1, v0})
    // We can remove x.append from the alias db's list of writes.
    // All other aliasing properties remain valid.
    Node* list_construct = mutated_value->node();
    switch (node->kind()) {
      case aten::append:
        list_construct->addInput(node->inputs().at(1));
        break;
      case aten::insert: {
        int pos = toIValue(node->inputs().at(1))->toInt();
        int size = list_construct->inputs().size();
        // inserting at a negative position is the same as inserting at
        // std::max(pos + size, 0)
        if (pos < 0) {
          pos = std::max(pos + size, 0);
        }
        // inserting beyond the current list length is the same as appending
        pos = std::min(pos, size);
        list_construct->insertInput(pos, node->inputs().at(2));
        break;
      }
      case aten::_set_item: {
        int pos = toIValue(node->inputs().at(1))->toInt();
        int size = list_construct->inputs().size();
        if (pos < 0) {
          pos = std::max(pos + size, 0);
        }
        list_construct->replaceInput(pos, node->input(2));
        break;
      }
      default:
        TORCH_INTERNAL_ASSERT(false);
    }

    // process the use chain and aliasing of the node output
    bool has_output = (!node->outputs().empty());
    if (has_output) {
      node->output()->replaceAllUsesWith(mutated_value);
      getOrCreateAliasDb()->writeIndex_->erase(node);
    }

    node->destroy();

    // TODO: we don't strictly need to reset the write cache; evaluate on
    // models
    getOrCreateAliasDb()->buildWrittenToLocationsIndex();
  }

  return changed;
}

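// Replace in-place tensor ops with their functional variants wherever the
// mutated tensor is a fresh, unaliased value (or an unaliased if-output) and
// its creation can be made atomic with the mutation.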
bool MutationRemover::RemoveTensorMutation(Block* block) {
  bool changed = false;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    auto* node = *it;
    it++;

    for (Block* sub_block : node->blocks()) {
      changed |= RemoveTensorMutation(sub_block);
    }

    if (mutation_filter_) {
      const auto& mutation_filter = *mutation_filter_;
      if (!mutation_filter(node)) {
        continue;
      }
    }

    // TODO: out op variants
    if (!inplaceOpVariant(node)) {
      continue;
    }

    Value* mutated_value = node->inputs().at(0);
    if (!tryMakeCreationAndMutationAtomic(mutated_value, node) &&
        !tryMakeUnaliasedIfOutputAndMutationAtomic(mutated_value, node)) {
      continue;
    }

    Node* new_node = nullptr;
    if (isSpecialMappedOp(node)) {
      new_node = createSpecialMappedOp(node);
    } else {
      // drop the trailing '_' to get the functional variant's name
      auto schema_name = node->schema().name();
      auto new_schema = schema_name.substr(0, schema_name.size() - 1);
      new_node = graph_->create(Symbol::fromQualString(new_schema), 1);
      new_node->copyMetadata(node);
      new_node->insertBefore(node);
      for (Value* input : node->inputs()) {
        new_node->addInput(input);
      }
      new_node->output()->setType(node->output()->type());

      // Weird case: there is an in-place op and an equivalent functional op
      // of the same symbol, but they have different schemas.
      if (!new_node->maybeOperator()) {
        new_node->destroy();
        continue;
      }
    }

    changed = true;
    mutated_value->replaceAllUsesAfterNodeWith(node, new_node->output());
    node->output()->replaceAllUsesWith(new_node->output());

    // We rewrite something like:
    //    x = torch.zeros()
    //    x.add_(1)
    //    x.add_(2)
    // to:
    //    x = torch.zeros()
    //    x0 = x.add(1)
    //    x0.add_(2)
    // For the remainder of the function, x0 will have the
    // same aliasing relationships as the original x.
    // To avoid rebuilding the entire alias db, we can replace
    // the memory DAG element of x with x0.
    getOrCreateAliasDb()->replaceWithNewValue(
        mutated_value, new_node->output());

    // It is an invariant that all mutable types have an element in the
    // memory DAG, so we must give x a new alias db element. We have already
    // verified that the mutated value is a fresh alias with a single use.
    getOrCreateAliasDb()->createValue(mutated_value);

    // We must erase the destroyed node from the AliasDb's list of writes.
    getOrCreateAliasDb()->writeIndex_->erase(node);
    node->destroy();

    // Now that we have removed a mutating op, the write cache is stale.
    // TODO: we don't strictly need to reset the write cache; evaluate on
    // models
    getOrCreateAliasDb()->buildWrittenToLocationsIndex();
  }

  return changed;
}

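// Heuristic for in-place aten ops that have a functional counterpart: the
// schema name ends in '_', alias info comes FROM_SCHEMA, the op writes only
// to (and returns) its first input, and an operator is registered under the
// name with the trailing '_' stripped.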
bool MutationRemover::inplaceOpVariant(Node* n) {
  if (!n->kind().is_aten()) {
    return false;
  }

  if (isSpecialMappedOp(n)) {
    return true;
  }

  auto name = n->schema().name();
  bool inplace_op = name.at(name.size() - 1) == '_';
  if (!inplace_op) {
    return false;
  }

  // needs to have alias analysis by schema
  auto op = n->maybeOperator();
  if (!op) {
    return false;
  }
  if (op->aliasAnalysisKind() != AliasAnalysisKind::FROM_SCHEMA) {
    return false;
  }

  // All in-place ops at the time of writing have a single input that is
  // mutated and returned. Check that this is true; anything else could have
  // strange semantics.
  if (n->outputs().size() != 1 || n->inputs().empty()) {
    return false;
  }
  auto inputs = n->inputs();
  if (!getOrCreateAliasDb()->writesToAlias(n, {inputs.at(0)}) ||
      getOrCreateAliasDb()->writesToAlias(
          n, {inputs.slice(1).begin(), inputs.slice(1).end()})) {
    return false;
  }

  auto new_schema = name.substr(0, name.size() - 1);
  return !getAllOperatorsFor(Symbol::fromQualString(new_schema)).empty();
}

bool RemoveListMutation(const std::shared_ptr<Graph>& graph) {
  MutationRemover mr(graph);
  return mr.removeListMutation();
}

bool RemoveTensorMutation(
    const std::shared_ptr<Graph>& graph,
    std::optional<std::function<bool(Node*)>> mutation_filter) {
  MutationRemover mr(graph, std::move(mutation_filter));
  return mr.removeTensorMutation();
}

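// The in-place ('_'-suffixed) variants of every activation op listed in
// activation_type_promotion_mapping (declared in restore_mutation.h).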
static const std::unordered_set<Symbol> activation_ops = []() {
  std::unordered_set<Symbol> target_ops;
  for (const auto& iter : activation_type_promotion_mapping) {
    std::string name = std::string(iter.first.toQualString()) + "_";
    target_ops.insert(Symbol::fromQualString(name));
  }
  return target_ops;
}();

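// Rewrite in-place activation ops (e.g. aten::sigmoid_ -> aten::sigmoid) to
// their functional forms; the inverse rewrite is provided by
// restore_mutation.h.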
bool InplaceToFunctionalActivation(const std::shared_ptr<Graph>& graph) {
  return RemoveTensorMutation(graph, [](Node* node) {
    return activation_ops.count(node->kind()) != 0;
  });
}

} // namespace torch::jit