xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/remove_expands.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/remove_expands.h>
2 
3 namespace torch::jit {
4 
RemoveExpands(Block * block)5 static void RemoveExpands(Block* block) {
6   for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
7        ++it) {
8     for (auto sub : it->blocks())
9       RemoveExpands(sub);
10 
11     if (it->kind() == aten::expand && it->get<bool>(attr::implicit) == true) {
12       it->output()->replaceAllUsesWith(it->namedInput(attr::self));
13       it.destroyCurrent();
14     }
15   }
16 }
17 
RemoveExpands(const std::shared_ptr<Graph> & graph)18 void RemoveExpands(const std::shared_ptr<Graph>& graph) {
19   RemoveExpands(graph->block());
20 }
21 
22 } // namespace torch::jit
23