xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/ir_printer.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ostream>
4 
5 #include <torch/csrc/jit/tensorexpr/fwd_decls.h>
6 #include <torch/csrc/jit/tensorexpr/ir.h>
7 #include <torch/csrc/jit/tensorexpr/ir_visitor.h>
8 #include <torch/csrc/jit/tensorexpr/unique_name_manager.h>
9 
10 namespace torch::jit::tensorexpr {
11 
12 class Tensor;
13 
14 class TORCH_API IRPrinter : public IRVisitor {
15  public:
IRPrinter(std::ostream & os)16   explicit IRPrinter(std::ostream& os) : printer_os_(this, os) {}
17 
18   void print(ExprHandle);
19   void print(Expr&);
20   void print(Stmt&);
21   void visit(const AddPtr& v) override;
22   void visit(const SubPtr& v) override;
23   void visit(const MulPtr& v) override;
24   void visit(const DivPtr& v) override;
25   void visit(const ModPtr& v) override;
26   void visit(const MaxPtr& v) override;
27   void visit(const MinPtr& v) override;
28   void visit(const AndPtr& v) override;
29   void visit(const OrPtr& v) override;
30   void visit(const XorPtr& v) override;
31   void visit(const LshiftPtr& v) override;
32   void visit(const RshiftPtr& v) override;
33   void visit(const CompareSelectPtr& v) override;
34 #define IMM_PRINT_VISIT(Type, Name) void visit(const Name##ImmPtr& v) override;
35   AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT);
36 #undef IMM_PRINT_VISIT
37   void visit(const CastPtr& v) override;
38   void visit(const BitCastPtr& v) override;
39   void visit(const VarPtr& v) override;
40   void visit(const BufPtr& v) override;
41   void visit(const RampPtr& v) override;
42   void visit(const LoadPtr& v) override;
43   void visit(const BroadcastPtr& v) override;
44   void visit(const IfThenElsePtr& v) override;
45   void visit(const IntrinsicsPtr& v) override;
46   void visit(const TermPtr& v) override;
47   void visit(const PolynomialPtr& v) override;
48   void visit(const RoundOffPtr& v) override;
49   void visit(const MaxTermPtr& v) override;
50   void visit(const MinTermPtr& v) override;
51   void visit(const ReduceOpPtr& v) override;
52 
53   void visit(const AtomicAddPtr& v) override;
54   void visit(const SyncThreadsPtr& v) override;
55   void visit(const ExternalCallPtr& v) override;
56   void visit(const ExternalCallWithAllocPtr& v) override;
57   void visit(const StorePtr& v) override;
58   void visit(const ForPtr& v) override;
59   void visit(const CondPtr& v) override;
60   void visit(const BlockPtr& v) override;
61   void visit(const AllocatePtr& v) override;
62   void visit(const FreePtr& v) override;
63   void visit(const FreeExtPtr& v) override;
64   void visit(const PlacementAllocatePtr& v) override;
65   void visit(const LetPtr& v) override;
66 
67   // A child class may have a difference rule for generating dtype
68   // string, e.g. CUDA needs int64_t to be generated as long long.
69   virtual std::string dtypeToCppString(const Dtype& dtype);
70 
os()71   std::ostream& os() {
72     return printer_os_;
73   }
74 
75   class PrinterStream : public std::ostream {
76    public:
PrinterStream(IRPrinter * printer,std::ostream & os)77     PrinterStream(IRPrinter* printer, std::ostream& os)
78         : std::ostream(os.rdbuf()), printer_(printer) {}
79 
printer()80     IRPrinter* printer() {
81       return printer_;
82     }
83 
84    private:
85     IRPrinter* printer_ = nullptr;
86   };
87 
88  protected:
89   std::string to_string(CompareSelectOperation op);
90 
name_manager()91   UniqueNameManager* name_manager() {
92     return &name_manager_;
93   }
94   void emitIndent();
95 
96   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
97   int indent_ = 0;
98 
99  private:
100   PrinterStream printer_os_;
101   UniqueNameManager name_manager_;
102 };
103 
104 TORCH_API std::ostream& operator<<(std::ostream& stream, const Expr&);
105 TORCH_API std::ostream& operator<<(std::ostream& stream, const ExprHandle&);
106 TORCH_API std::ostream& operator<<(std::ostream& stream, const Stmt&);
107 TORCH_API std::ostream& operator<<(std::ostream& stream, const Tensor&);
108 
109 TORCH_API void print(const ExprPtr& expr);
110 TORCH_API void print(const StmtPtr& stmt);
111 TORCH_API void print(const Tensor& t);
112 
113 } // namespace torch::jit::tensorexpr
114 
115 namespace std {
116 
117 using torch::jit::tensorexpr::Expr;
118 using torch::jit::tensorexpr::ExprPtr;
119 using torch::jit::tensorexpr::Stmt;
120 using torch::jit::tensorexpr::StmtPtr;
121 using torch::jit::tensorexpr::Tensor;
122 
123 TORCH_API std::string to_string(const ExprPtr& expr);
124 TORCH_API std::string to_string(const StmtPtr& stmt);
125 TORCH_API std::string to_string(const Tensor& t);
126 } // namespace std
127