1 #include <torch/csrc/jit/passes/remove_expands.h> 2 3 namespace torch::jit { 4 RemoveExpands(Block * block)5static void RemoveExpands(Block* block) { 6 for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end; 7 ++it) { 8 for (auto sub : it->blocks()) 9 RemoveExpands(sub); 10 11 if (it->kind() == aten::expand && it->get<bool>(attr::implicit) == true) { 12 it->output()->replaceAllUsesWith(it->namedInput(attr::self)); 13 it.destroyCurrent(); 14 } 15 } 16 } 17 RemoveExpands(const std::shared_ptr<Graph> & graph)18void RemoveExpands(const std::shared_ptr<Graph>& graph) { 19 RemoveExpands(graph->block()); 20 } 21 22 } // namespace torch::jit 23