xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/remove_mutation.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/Exception.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/csrc/jit/ir/alias_analysis.h>
6 #include <torch/csrc/jit/ir/ir.h>
7 
8 #include <utility>
9 
10 namespace torch::jit {
11 
12 struct TORCH_API MutationRemover {
13   MutationRemover(
14       std::shared_ptr<Graph> graph,
15       std::optional<std::function<bool(Node*)>> mutation_filter = std::nullopt)
mutation_filter_MutationRemover16       : mutation_filter_(std::move(mutation_filter)),
17         aliasDb_(nullptr),
18         graph_(std::move(graph)) {}
19 
20   // return true if graph is modified
21   bool removeListMutation();
22 
23   // return true if graph is modified
24   bool removeTensorMutation();
25 
isSpecialMappedOpMutationRemover26   bool isSpecialMappedOp(Node* n) {
27     return n->matches("aten::zero_(Tensor(a!) self) -> Tensor(a!)") ||
28         n->matches(
29             "aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)") ||
30         n->matches(
31             "aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)");
32   }
33 
34   bool inplaceOpVariant(Node* n);
35 
36   static bool hasSideEffectOrAlias(Value* v, AliasDb* aliasDb);
37 
38  private:
39   Node* createSpecialMappedOp(Node* n);
40   bool listMutationFollowingListConstruct(Node* n);
41   bool tryMakeCreationAndMutationAtomic(
42       Value* mutated_value,
43       Node* mutating_op);
44   bool tryMakeUnaliasedIfOutputAndMutationAtomic(
45       Value* mutated_value,
46       Node* mutating_op);
47   // return true if graph is modified
48   bool RemoveListMutation(Block* block);
49   // return true if graph is modified
50   bool RemoveTensorMutation(Block* block);
51 
getOrCreateAliasDbMutationRemover52   AliasDb* getOrCreateAliasDb() {
53     if (!aliasDb_) {
54       aliasDb_ = std::make_unique<AliasDb>(graph_);
55     }
56     return aliasDb_.get();
57   }
58 
59   std::optional<std::function<bool(Node*)>> mutation_filter_;
60   std::unique_ptr<AliasDb> aliasDb_ = nullptr;
61   std::shared_ptr<Graph> graph_;
62 };
63 
64 // Removes list mutation with functional equivalents
65 // return true if graph is modified
66 TORCH_API bool RemoveListMutation(const std::shared_ptr<Graph>& graph);
67 
68 // Replaces in-place aten ops with their functional equivalents
69 // when it can be proven that this does not change graph semantics
70 // if `mutation_filter` is present, the pass will only attempt to
71 // remove mutation on nodes which return true for the filter
72 // return true if graph is modified
73 TORCH_API bool RemoveTensorMutation(
74     const std::shared_ptr<Graph>& graph,
75     std::optional<std::function<bool(Node*)>> mutation_filter = std::nullopt);
76 
77 // Replaces in-place aten activation ops with their functional equivalence
78 TORCH_API bool InplaceToFunctionalActivation(
79     const std::shared_ptr<Graph>& graph);
80 
81 } // namespace torch::jit
82