xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/clear_undefinedness.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/clear_undefinedness.h>
2 
3 #include <torch/csrc/jit/jit_log.h>
4 
5 namespace torch::jit {
6 
clearUndefinedness(Value * o)7 static void clearUndefinedness(Value* o) {
8   if (o->type()->kind() == TensorType::Kind) {
9     o->setType(TensorType::get());
10   } else if (
11       o->type()->kind() == ListType::Kind &&
12       o->type()->expectRef<ListType>().getElementType()->kind() ==
13           TensorType::Kind) {
14     o->setType(ListType::create(TensorType::get()));
15   }
16 }
17 
clearUndefinedness(Block * block)18 static void clearUndefinedness(Block* block) {
19   for (auto n : block->nodes()) {
20     for (auto o : n->outputs()) {
21       clearUndefinedness(o);
22     }
23     for (auto ib : n->blocks()) {
24       clearUndefinedness(ib);
25     }
26   }
27 }
28 
ClearUndefinedness(const std::shared_ptr<Graph> & graph)29 void ClearUndefinedness(const std::shared_ptr<Graph>& graph) {
30   for (auto i : graph->inputs()) {
31     clearUndefinedness(i);
32   }
33   clearUndefinedness(graph->block());
34   GRAPH_DUMP("After removeUndefinedness: ", graph);
35 }
36 
37 } // namespace torch::jit
38