1 #pragma once 2 3 #include <string> 4 #include <unordered_map> 5 #include <unordered_set> 6 #include <utility> 7 8 #include <ATen/ATen.h> 9 #include <torch/csrc/jit/resource_guard.h> 10 #include <torch/csrc/jit/tensorexpr/analysis.h> 11 #include <torch/csrc/jit/tensorexpr/codegen.h> 12 #include <torch/csrc/jit/tensorexpr/ir.h> 13 #include <torch/csrc/jit/tensorexpr/ir_printer.h> 14 #include <torch/csrc/jit/tensorexpr/ir_visitor.h> 15 #include <torch/csrc/jit/tensorexpr/unique_name_manager.h> 16 17 namespace torch::jit::tensorexpr { 18 19 // A class that analyzes the given program relevant for Block backend. 20 class BlockAnalysis : public IRVisitor { 21 public: is_buf_store_target(const BufPtr & buf)22 bool is_buf_store_target(const BufPtr& buf) const { 23 return store_targets_.count(buf) > 0; 24 } 25 loads()26 const std::unordered_set<BufPtr>& loads() const { 27 return loads_; 28 } 29 stores()30 const std::unordered_set<BufPtr>& stores() const { 31 return store_targets_; 32 } 33 block_size()34 int64_t block_size() const { 35 return block_size_; 36 } 37 38 bool areBufsInMap(const std::unordered_set<BufPtr>& bufs) const; 39 40 BufPtr getMultiDimBuf(const BufPtr& buf) const; 41 42 std::string getInputName(const BufPtr& buf) const; 43 getFlatInputName(const BufPtr & buf)44 std::string getFlatInputName(const BufPtr& buf) const { 45 return getInputName(buf) + "_flat"; 46 } 47 getBufferMap()48 std::unordered_map<std::string, BufPtr> getBufferMap() const { 49 return map_input_to_tensor_bufs_; 50 } 51 52 private: 53 void visit(const StorePtr& v) override; 54 void visit(const LoadPtr& v) override; 55 void visit(const ForPtr& v) override; 56 57 std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_; 58 std::unordered_set<BufPtr> store_targets_; 59 std::unordered_set<BufPtr> loads_; 60 int64_t block_size_ = 32; 61 }; 62 63 // A class that overrides the underlying IRPrinter to produce Block. 64 class BlockPrinter : public IRPrinter { 65 public: BlockPrinter(std::ostream * os,BlockAnalysis * block_analysis)66 BlockPrinter(std::ostream* os, BlockAnalysis* block_analysis) 67 : IRPrinter(*os), block_analysis_(block_analysis) {} 68 69 using IRPrinter::name_manager; 70 using IRPrinter::visit; 71 72 private: 73 BlockAnalysis* block_analysis_; 74 std::unordered_map<std::string, int> dim_values_map; 75 std::vector<std::string> dim_names = {"N", "H", "W", "C"}; 76 std::vector<std::string> flat_dim_names = {"N", "NH", "NHW", "NHWC"}; 77 void PrintTensorInfo(const std::unordered_set<BufPtr>& bufs); 78 void PrintArguments(const std::unordered_set<BufPtr>& bufs); 79 void PrintBufferInfo(const std::unordered_set<BufPtr>& bufs); 80 void PrintDistribution(const std::unordered_set<BufPtr>& bufs); 81 void PrintLoop(const std::unordered_set<BufPtr>& bufs, bool block_idx = true); 82 void PrintReshapeInfo( 83 const std::unordered_set<BufPtr>& bufs, 84 bool reverse = false); 85 void PrintDMAs(const std::unordered_set<BufPtr>& bufs); 86 void PrintAdjustBuffers(const std::unordered_set<BufPtr>& bufs); 87 88 void visit(const ForPtr& v) override; 89 void visit(const LoadPtr& v) override; 90 void visit(const StorePtr& v) override; 91 void visit(const BlockPtr& v) override; 92 void visit(const AddPtr& v) override; 93 void visit(const MulPtr& v) override; 94 }; 95 96 class TORCH_API BlockCodeGen : public CodeGen { 97 public: 98 template <typename... Ts> 99 /* implicit */ BlockCodeGen(StmtPtr stmt,Ts...ts)100 BlockCodeGen(StmtPtr stmt, Ts... ts) 101 : CodeGen( 102 stmt, 103 std::vector<BufferArg>({BufferArg(ts)...}), 104 at::Device(at::kCPU)) { 105 Initialize(); 106 } 107 108 BlockCodeGen( 109 StmtPtr stmt, 110 const std::vector<BufferArg>& buffer_args, 111 at::Device device = at::Device(at::kCPU), 112 const std::string& kernel_func_name = "func") CodeGen(std::move (stmt),buffer_args,device,kernel_func_name)113 : CodeGen(std::move(stmt), buffer_args, device, kernel_func_name) { 114 Initialize(); 115 } 116 117 ~BlockCodeGen() override; 118 119 void call(const std::vector<CallArg>& args) override; 120 void call_raw(const std::vector<void*>& args) override; 121 122 void Initialize(); 123 124 std::string getCodeText(const std::string& attr = "") override { 125 return oss_.str(); 126 } 127 128 private: name_manager()129 UniqueNameManager* name_manager() { 130 if (!printer_) { 131 throw std::runtime_error("Null IRPrinter is not expected"); 132 } 133 return printer_->name_manager(); 134 } 135 os()136 std::ostream& os() { 137 return printer_->os(); 138 } 139 140 std::ostringstream oss_; 141 std::unique_ptr<BlockPrinter> printer_; 142 std::unique_ptr<BlockAnalysis> block_analysis_; 143 144 std::string GetUniqueFuncName(const std::string& func_prefix); 145 }; 146 } // namespace torch::jit::tensorexpr 147