1 #include <torch/csrc/jit/passes/erase_number_types.h> 2 3 #include <torch/csrc/jit/ir/constants.h> 4 #include <torch/csrc/jit/jit_log.h> 5 #include <torch/csrc/jit/passes/dead_code_elimination.h> 6 7 #include <ATen/ScalarOps.h> 8 9 namespace torch::jit { 10 SetNumTypeToTensorType(Value * v)11static void SetNumTypeToTensorType(Value* v) { 12 if (v->type()->isSubtypeOf(*NumberType::get())) { 13 v->setType(TensorType::fromNumberType(*v->type())); 14 } else if (v->type()->isSubtypeOf(*BoolType::get())) { 15 v->setType(TensorType::fromBoolType()); 16 } 17 } 18 EraseNumberTypesOnBlock(Block * block)19void EraseNumberTypesOnBlock(Block* block) { 20 for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end; 21 ++it) { 22 for (auto inp : it->inputs()) { 23 SetNumTypeToTensorType(inp); 24 } 25 for (auto sub : it->blocks()) { 26 EraseNumberTypesOnBlock(sub); 27 } 28 switch (it->kind()) { 29 case prim::Constant: { 30 // remove primitive constants, replacing with tensor equivalent 31 // ONNX does not support non-tensor constants 32 if (it->output()->type()->isSubtypeOf(*NumberType::get()) || 33 it->output()->type()->isSubtypeOf(*BoolType::get())) { 34 at::Scalar s; 35 if (it->output()->type()->isSubtypeOf(*BoolType::get())) { 36 s = *constant_as<bool>(it->output()); 37 } else { 38 s = *constant_as<at::Scalar>(it->output()); 39 } 40 41 WithInsertPoint guard(*it); 42 Value* r = block->owningGraph()->insertConstant( 43 scalar_to_tensor(s), std::nullopt, it->scope()); 44 r->copyMetadata(it->output()); 45 it->output()->replaceAllUsesWith(r); 46 it.destroyCurrent(); 47 } 48 } break; 49 case aten::Bool: 50 case aten::Float: 51 case aten::Int: 52 case aten::FloatImplicit: 53 case aten::IntImplicit: 54 case aten::ScalarImplicit: 55 case prim::NumToTensor: { 56 it->output()->replaceAllUsesWith(it->inputs()[0]); 57 it.destroyCurrent(); 58 } break; 59 default: { 60 for (auto o : it->outputs()) { 61 SetNumTypeToTensorType(o); 62 } 63 } break; 64 } 65 } 66 } 67 EraseNumberTypes(const std::shared_ptr<Graph> & graph)68void EraseNumberTypes(const std::shared_ptr<Graph>& graph) { 69 for (auto inp : graph->inputs()) { 70 SetNumTypeToTensorType(inp); 71 } 72 EraseNumberTypesOnBlock(graph->block()); 73 GRAPH_DUMP("After EraseNumberTypes: ", graph); 74 } 75 } // namespace torch::jit 76