1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <torch/csrc/jit/ir/ir.h> 5 #include <memory> 6 7 namespace torch::jit { 8 9 struct Graph; 10 11 struct propagation_error : std::exception {}; 12 13 class PropertyPropBase { 14 // Used for both Shape Propagation and Dtype/Device Propagation 15 public: PropertyPropBase(std::shared_ptr<Graph> graph)16 explicit PropertyPropBase(std::shared_ptr<Graph> graph) 17 : graph_(std::move(graph)) {} 18 virtual ~PropertyPropBase() = default; 19 20 void propagateBlock(Block* block, bool insert_expands = true); 21 // insert_expands is used for shape inference 22 23 void processIf(Node* node); 24 void processLoop(Node* node); 25 26 protected: 27 virtual void propagateNode(Node* node, bool insert_expands = true) = 0; 28 void setUnshapedType(Value* o); 29 void setUnshapedType(Node* node); 30 std::shared_ptr<Graph> graph_; 31 }; 32 33 TORCH_API void EraseShapeInformation(const std::shared_ptr<Graph>& graph); 34 TORCH_API void PropagateInputShapes(const std::shared_ptr<Graph>& graph); 35 36 TORCH_API bool mergeTypes( 37 ArrayRef<Value*> lhs, 38 ArrayRef<Value*> rhs, 39 ArrayRef<Value*> outputs); 40 41 } // namespace torch::jit 42