#pragma once

#include <sstream>

#include <torch/csrc/api/include/torch/jit.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/ts_backend/ts_node_lowering.h>

namespace torch {
namespace lazy {

using TSOpVector = std::vector<torch::jit::Value*>;

// A Computation backed by a TorchScript graph. Parameter names are taken from
// the graph inputs; execution is delegated to a torch::jit::GraphExecutor.
class TORCH_API TSComputation : public Computation {
 public:
  TSComputation(const std::shared_ptr<torch::jit::Graph>& graph)
      : graph_(graph), graph_executor_(graph, "") {
    for (torch::jit::Value* input : graph_->inputs()) {
      parameter_names_.push_back(input->debugName());
    }
  }

  int parameters_size() const override {
    return parameter_names_.size();
  }

  const std::vector<Shape>& parameter_shapes() const override {
    throw std::runtime_error(
        "TODO(whc) implement TS computation shapes or change interface");
    return parameter_shapes_;
  }

  const std::vector<std::string>& parameter_names() const override {
    return parameter_names_;
  }

  const Shape& result_shape() const override {
    throw std::runtime_error(
        "TODO(whc) implement TS computation shapes or change interface");
    return result_shape_;
  }

  const std::string to_string() const override {
    std::ostringstream oss;
    oss << *graph_;
    return oss.str();
  }

  std::shared_ptr<torch::jit::Graph> graph() const {
    return graph_;
  }

  torch::jit::GraphExecutor& graph_executor() {
    return graph_executor_;
  }

 private:
  std::shared_ptr<torch::jit::Graph> graph_;
  torch::jit::GraphExecutor graph_executor_;
  std::vector<std::string> parameter_names_;
  std::vector<Shape> parameter_shapes_;
  Shape result_shape_;
};

// Lowering context for the TorchScript backend: lowers lazy IR Nodes into a
// torch::jit::Graph and packages the resulting graph as a TSComputation.
class TORCH_API TSLoweringContext : public LoweringContext {
 public:
  TSLoweringContext(const std::string& name, const BackendDevice device);

  TSLoweringContext(
      const std::string& name,
      BackendDevice device,
      c10::ArrayRef<const Node*> post_order,
      Util::EmissionMap emit_status);

  size_t AddResult(const Output& output) override {
    return AddResult(GetOutputOp(output));
  }

  void AddParameter(
      const torch::lazy::Output& output,
      size_t index,
      const Shape& shape,
      const std::string& name) override {
    TORCH_INTERNAL_ASSERT(false, "not implemented");
  }

  void Lower(const Node* node);

  ComputationPtr Build() override {
    for (torch::jit::Value* output : root_tuple_) {
      graph_->block()->registerOutput(output);
    }
    return std::shared_ptr<Computation>(new TSComputation(graph_));
  }

  // Retrieves the lowered operation for an output. If the requested output is
  // not available yet, the graph behind the output's Node is lowered, and the
  // corresponding TS operation returned.
  torch::jit::Value* GetOutputOp(const Output& output) {
    auto it = emitted_outputs_.find(output);
    if (it == emitted_outputs_.end()) {
      auto post_order = Util::ComputePostOrder(output.node, &emit_status_);
      for (auto node : post_order) {
        Lower(node);
      }
      // At this point the output better be present, otherwise there is an
      // issue with the lowering code.
      it = emitted_outputs_.find(output);
      TORCH_CHECK(
          it != emitted_outputs_.end(),
          "No TS operation emitted for output: ",
          output.ToString());
    }
    return it->second;
  }

  // Assigns the given TS operation to the specified output. As outputs are
  // lowered in a post-order fashion, later nodes should always find their
  // operands among the emitted outputs.
  void AssignOutputOp(const Output& output, torch::jit::Value* op);

  // If a parameter associated with data has already been declared, it will be
  // returned. Otherwise a new one will be created, associated with the tensor
  // held in data.
  torch::jit::Value* GetParameter(BackendDataPtr data);

  std::shared_ptr<torch::jit::Graph> graph() const {
    return graph_;
  }

 private:
  struct Parameter {
    torch::jit::Value* param{nullptr};
    size_t index = 0;
  };

  size_t AddResult(torch::jit::Value* op) {
    root_tuple_.push_back(std::move(op));
    return root_tuple_.size() - 1;
  }

  std::shared_ptr<torch::jit::Graph> graph_;
  std::shared_ptr<torch::jit::GraphFunction> function_;
  std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
  std::vector<torch::jit::Value*> root_tuple_;
  OutputMap<torch::jit::Value*> emitted_outputs_;
};

} // namespace lazy
} // namespace torch
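
// Illustrative usage sketch (not part of this header): a caller that already
// holds the root Outputs of a lazy IR graph could drive lowering roughly like
// this. `device`, `post_order`, `emit_status`, and `roots` are hypothetical
// names standing in for values the caller has computed. Adding each root via
// AddResult goes through GetOutputOp, which lowers the node graph on demand,
// and Build() finalizes the TS graph into a TSComputation.
//
//   torch::lazy::TSLoweringContext lowering_ctx(
//       "TsLowering", device, post_order, std::move(emit_status));
//   for (const torch::lazy::Output& root : roots) {
//     lowering_ctx.AddResult(root);
//   }
//   torch::lazy::ComputationPtr computation = lowering_ctx.Build();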