1 #pragma once 2 3 #include <ostream> 4 5 #include <torch/csrc/jit/tensorexpr/fwd_decls.h> 6 #include <torch/csrc/jit/tensorexpr/ir.h> 7 #include <torch/csrc/jit/tensorexpr/ir_visitor.h> 8 #include <torch/csrc/jit/tensorexpr/unique_name_manager.h> 9 10 namespace torch::jit::tensorexpr { 11 12 class Tensor; 13 14 class TORCH_API IRPrinter : public IRVisitor { 15 public: IRPrinter(std::ostream & os)16 explicit IRPrinter(std::ostream& os) : printer_os_(this, os) {} 17 18 void print(ExprHandle); 19 void print(Expr&); 20 void print(Stmt&); 21 void visit(const AddPtr& v) override; 22 void visit(const SubPtr& v) override; 23 void visit(const MulPtr& v) override; 24 void visit(const DivPtr& v) override; 25 void visit(const ModPtr& v) override; 26 void visit(const MaxPtr& v) override; 27 void visit(const MinPtr& v) override; 28 void visit(const AndPtr& v) override; 29 void visit(const OrPtr& v) override; 30 void visit(const XorPtr& v) override; 31 void visit(const LshiftPtr& v) override; 32 void visit(const RshiftPtr& v) override; 33 void visit(const CompareSelectPtr& v) override; 34 #define IMM_PRINT_VISIT(Type, Name) void visit(const Name##ImmPtr& v) override; 35 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT); 36 #undef IMM_PRINT_VISIT 37 void visit(const CastPtr& v) override; 38 void visit(const BitCastPtr& v) override; 39 void visit(const VarPtr& v) override; 40 void visit(const BufPtr& v) override; 41 void visit(const RampPtr& v) override; 42 void visit(const LoadPtr& v) override; 43 void visit(const BroadcastPtr& v) override; 44 void visit(const IfThenElsePtr& v) override; 45 void visit(const IntrinsicsPtr& v) override; 46 void visit(const TermPtr& v) override; 47 void visit(const PolynomialPtr& v) override; 48 void visit(const RoundOffPtr& v) override; 49 void visit(const MaxTermPtr& v) override; 50 void visit(const MinTermPtr& v) override; 51 void visit(const ReduceOpPtr& v) override; 52 53 void visit(const AtomicAddPtr& v) override; 54 void visit(const SyncThreadsPtr& v) override; 55 void visit(const ExternalCallPtr& v) override; 56 void visit(const ExternalCallWithAllocPtr& v) override; 57 void visit(const StorePtr& v) override; 58 void visit(const ForPtr& v) override; 59 void visit(const CondPtr& v) override; 60 void visit(const BlockPtr& v) override; 61 void visit(const AllocatePtr& v) override; 62 void visit(const FreePtr& v) override; 63 void visit(const FreeExtPtr& v) override; 64 void visit(const PlacementAllocatePtr& v) override; 65 void visit(const LetPtr& v) override; 66 67 // A child class may have a difference rule for generating dtype 68 // string, e.g. CUDA needs int64_t to be generated as long long. 69 virtual std::string dtypeToCppString(const Dtype& dtype); 70 os()71 std::ostream& os() { 72 return printer_os_; 73 } 74 75 class PrinterStream : public std::ostream { 76 public: PrinterStream(IRPrinter * printer,std::ostream & os)77 PrinterStream(IRPrinter* printer, std::ostream& os) 78 : std::ostream(os.rdbuf()), printer_(printer) {} 79 printer()80 IRPrinter* printer() { 81 return printer_; 82 } 83 84 private: 85 IRPrinter* printer_ = nullptr; 86 }; 87 88 protected: 89 std::string to_string(CompareSelectOperation op); 90 name_manager()91 UniqueNameManager* name_manager() { 92 return &name_manager_; 93 } 94 void emitIndent(); 95 96 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 97 int indent_ = 0; 98 99 private: 100 PrinterStream printer_os_; 101 UniqueNameManager name_manager_; 102 }; 103 104 TORCH_API std::ostream& operator<<(std::ostream& stream, const Expr&); 105 TORCH_API std::ostream& operator<<(std::ostream& stream, const ExprHandle&); 106 TORCH_API std::ostream& operator<<(std::ostream& stream, const Stmt&); 107 TORCH_API std::ostream& operator<<(std::ostream& stream, const Tensor&); 108 109 TORCH_API void print(const ExprPtr& expr); 110 TORCH_API void print(const StmtPtr& stmt); 111 TORCH_API void print(const Tensor& t); 112 113 } // namespace torch::jit::tensorexpr 114 115 namespace std { 116 117 using torch::jit::tensorexpr::Expr; 118 using torch::jit::tensorexpr::ExprPtr; 119 using torch::jit::tensorexpr::Stmt; 120 using torch::jit::tensorexpr::StmtPtr; 121 using torch::jit::tensorexpr::Tensor; 122 123 TORCH_API std::string to_string(const ExprPtr& expr); 124 TORCH_API std::string to_string(const StmtPtr& stmt); 125 TORCH_API std::string to_string(const Tensor& t); 126 } // namespace std 127