1 #include <ATen/core/jit_type.h> 2 #include <ATen/core/symbol.h> 3 #include <torch/csrc/jit/passes/remove_mutation.h> 4 #include <torch/csrc/jit/passes/restore_mutation.h> 5 6 namespace torch::jit { 7 FunctionalToInplaceRewriter(std::shared_ptr<Graph> graph)8FunctionalToInplaceRewriter::FunctionalToInplaceRewriter( 9 std::shared_ptr<Graph> graph) 10 : aliasDb_(nullptr), graph_(std::move(graph)) {} 11 CanBeInplace(Node * node)12bool FunctionalToInplaceRewriter::CanBeInplace(Node* node) { 13 if (activation_type_promotion_mapping.find(node->kind()) == 14 activation_type_promotion_mapping.end()) { 15 return false; 16 } 17 18 Symbol inplace_op = 19 Symbol::fromQualString(std::string(node->kind().toQualString()) + "_"); 20 if (!inplace_op) { 21 return false; 22 } 23 24 // If type promotion is allowed, then perform dtype check 25 bool check_dtype = activation_type_promotion_mapping.at(node->kind()); 26 27 Value* input = node->inputs().at(0); 28 Value* output = node->outputs().at(0); 29 auto inputDtype = input->type()->expect<TensorType>()->scalarType(); 30 auto outputDtype = output->type()->expect<TensorType>()->scalarType(); 31 32 // In general, we don't need to check shape for activation ops as they 33 // element-wise. But for those where type promotion could happen, we need to 34 // make sure the dtype of input and output are the same. For now the dtype 35 // checking will always fail until the type inference is ready. 36 if (check_dtype && 37 (!inputDtype || !outputDtype || 38 inputDtype.value() != outputDtype.value())) { 39 return false; 40 } 41 42 // Skip if input's def node has side effect or input has alias 43 if (MutationRemover::hasSideEffectOrAlias(input, getOrCreateAliasDb())) { 44 return false; 45 } 46 47 // If x has more than one use, skip the conversion. 48 // TODO: Use liveness analysis to catch more general scenario 49 return (input->uses().size() == 1); 50 } 51 FunctionalToInplace(Block * block)52bool FunctionalToInplaceRewriter::FunctionalToInplace(Block* block) { 53 bool changed = false; 54 for (auto it = block->nodes().begin(); it != block->nodes().end();) { 55 auto* node = *it; 56 it++; 57 58 for (Block* sub_block : node->blocks()) { 59 changed |= FunctionalToInplace(sub_block); 60 } 61 62 if (!CanBeInplace(node)) { 63 continue; 64 } 65 66 changed = true; 67 Node* inplace_node = node->replaceWithNewSymbol( 68 Symbol::fromQualString(node->schema().name() + "_")); 69 inplace_node->output()->replaceAllUsesWith(node->inputs().at(0)); 70 getOrCreateAliasDb()->replaceWithNewValue( 71 node->output(), inplace_node->output()); 72 73 node->destroy(); 74 } 75 return changed; 76 } 77 FunctionalToInplaceActivation(const std::shared_ptr<Graph> & graph)78bool FunctionalToInplaceActivation(const std::shared_ptr<Graph>& graph) { 79 FunctionalToInplaceRewriter rewriter(graph); 80 return rewriter.FunctionalToInplace(graph->block()); 81 } 82 83 } // namespace torch::jit 84