// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/shape_analysis.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 #include <memory>
6 
7 namespace torch::jit {
8 
struct Graph;

/// Exception type for property-propagation failures (purpose inferred from
/// the name; NOTE(review): the throw/catch sites live in the .cpp — confirm).
///
/// Overrides what() so that, if it ever escapes, the message identifies the
/// type instead of the implementation-defined text inherited from
/// std::exception. Kept as `struct` (not `class`) so the class-key matches
/// any existing forward declarations (MSVC encodes it in mangled names).
struct propagation_error : std::exception {
  const char* what() const noexcept override {
    return "torch::jit::propagation_error";
  }
};
12 
13 class PropertyPropBase {
14   // Used for both Shape Propagation and Dtype/Device Propagation
15  public:
PropertyPropBase(std::shared_ptr<Graph> graph)16   explicit PropertyPropBase(std::shared_ptr<Graph> graph)
17       : graph_(std::move(graph)) {}
18   virtual ~PropertyPropBase() = default;
19 
20   void propagateBlock(Block* block, bool insert_expands = true);
21   // insert_expands is used for shape inference
22 
23   void processIf(Node* node);
24   void processLoop(Node* node);
25 
26  protected:
27   virtual void propagateNode(Node* node, bool insert_expands = true) = 0;
28   void setUnshapedType(Value* o);
29   void setUnshapedType(Node* node);
30   std::shared_ptr<Graph> graph_;
31 };
32 
33 TORCH_API void EraseShapeInformation(const std::shared_ptr<Graph>& graph);
34 TORCH_API void PropagateInputShapes(const std::shared_ptr<Graph>& graph);
35 
36 TORCH_API bool mergeTypes(
37     ArrayRef<Value*> lhs,
38     ArrayRef<Value*> rhs,
39     ArrayRef<Value*> outputs);
40 
41 } // namespace torch::jit
42