1 #include <torch/csrc/jit/passes/clear_undefinedness.h> 2 3 #include <torch/csrc/jit/jit_log.h> 4 5 namespace torch::jit { 6 clearUndefinedness(Value * o)7static void clearUndefinedness(Value* o) { 8 if (o->type()->kind() == TensorType::Kind) { 9 o->setType(TensorType::get()); 10 } else if ( 11 o->type()->kind() == ListType::Kind && 12 o->type()->expectRef<ListType>().getElementType()->kind() == 13 TensorType::Kind) { 14 o->setType(ListType::create(TensorType::get())); 15 } 16 } 17 clearUndefinedness(Block * block)18static void clearUndefinedness(Block* block) { 19 for (auto n : block->nodes()) { 20 for (auto o : n->outputs()) { 21 clearUndefinedness(o); 22 } 23 for (auto ib : n->blocks()) { 24 clearUndefinedness(ib); 25 } 26 } 27 } 28 ClearUndefinedness(const std::shared_ptr<Graph> & graph)29void ClearUndefinedness(const std::shared_ptr<Graph>& graph) { 30 for (auto i : graph->inputs()) { 31 clearUndefinedness(i); 32 } 33 clearUndefinedness(graph->block()); 34 GRAPH_DUMP("After removeUndefinedness: ", graph); 35 } 36 37 } // namespace torch::jit 38