// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/restore_mutation.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/core/jit_type.h>
#include <ATen/core/symbol.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/restore_mutation.h>

namespace torch::jit {

FunctionalToInplaceRewriter::FunctionalToInplaceRewriter(
    std::shared_ptr<Graph> graph)
    : aliasDb_(nullptr), graph_(std::move(graph)) {}
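// Note: aliasDb_ deliberately starts out null; getOrCreateAliasDb() (declared
// in restore_mutation.h) is expected to construct the AliasDb lazily on first
// use, so graphs with no rewrite candidates never pay for alias analysis.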

bool FunctionalToInplaceRewriter::CanBeInplace(Node* node) {
  if (activation_type_promotion_mapping.find(node->kind()) ==
      activation_type_promotion_mapping.end()) {
    return false;
  }

  Symbol inplace_op =
      Symbol::fromQualString(std::string(node->kind().toQualString()) + "_");
  if (!inplace_op) {
    return false;
  }

  // If this op can type-promote, we must verify that the dtypes match.
  bool check_dtype = activation_type_promotion_mapping.at(node->kind());

  Value* input = node->inputs().at(0);
  Value* output = node->outputs().at(0);
  auto inputDtype = input->type()->expect<TensorType>()->scalarType();
  auto outputDtype = output->type()->expect<TensorType>()->scalarType();

  // In general, we don't need to check shapes for activation ops, as they are
  // element-wise. But for those where type promotion could happen, we need to
  // make sure the dtypes of the input and output are the same. For now, the
  // dtype check will always fail until type inference is ready.
  if (check_dtype &&
      (!inputDtype || !outputDtype ||
       inputDtype.value() != outputDtype.value())) {
    return false;
  }

  // Skip if the input's defining node has side effects or the input has an
  // alias.
  if (MutationRemover::hasSideEffectOrAlias(input, getOrCreateAliasDb())) {
    return false;
  }

  // If the input has more than one use, skip the conversion.
  // TODO: use liveness analysis to catch more general scenarios
  return (input->uses().size() == 1);
}
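
// Worked example (informal, assuming aten::sigmoid is listed in
// activation_type_promotion_mapping with the dtype check enabled): for
// `%y = aten::sigmoid(%x)`, the rewrite is accepted only when %x's and %y's
// scalar types are both known and equal, %x is alias-free with a
// side-effect-free producer, and %x is consumed solely by this node.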

bool FunctionalToInplaceRewriter::FunctionalToInplace(Block* block) {
  bool changed = false;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    auto* node = *it;
    // Advance the iterator now, since the current node may be destroyed below.
    it++;

    // Recurse into nested blocks (e.g. the bodies of prim::If / prim::Loop).
    for (Block* sub_block : node->blocks()) {
      changed |= FunctionalToInplace(sub_block);
    }

    if (!CanBeInplace(node)) {
      continue;
    }

    changed = true;
    // Swap in the in-place variant of the op, e.g. aten::relu -> aten::relu_.
    Node* inplace_node = node->replaceWithNewSymbol(
        Symbol::fromQualString(node->schema().name() + "_"));
    // The in-place op mutates its input, so consumers can read the input
    // value directly.
    inplace_node->output()->replaceAllUsesWith(node->inputs().at(0));
    getOrCreateAliasDb()->replaceWithNewValue(
        node->output(), inplace_node->output());

    node->destroy();
  }
  return changed;
}
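
// Illustrative rewrite (informal IR, assuming %x is unaliased and used only
// by the activation node):
//
//   before:  %y = aten::sigmoid(%x)
//            %z = aten::mul(%y, %y)
//   after:   %y = aten::sigmoid_(%x)
//            %z = aten::mul(%x, %x)
//
// The functional node's output disappears; downstream consumers read the
// (now mutated) input value instead.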

bool FunctionalToInplaceActivation(const std::shared_ptr<Graph>& graph) {
  FunctionalToInplaceRewriter rewriter(graph);
  return rewriter.FunctionalToInplace(graph->block());
}

} // namespace torch::jit
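
// Usage sketch (a minimal, hedged example; parseIR comes from
// torch/csrc/jit/ir/irparser.h and the IR text below is hypothetical):
//
//   auto graph = std::make_shared<Graph>();
//   parseIR(R"IR(
//     graph(%x : Float(2, 2)):
//       %y : Float(2, 2) = aten::relu(%x)
//       return (%y))IR", graph.get());
//   bool changed = FunctionalToInplaceActivation(graph);
//   // If %x has no aliases and a single use, aten::relu is rewritten to
//   // aten::relu_ and `changed` is true.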