// torch/csrc/lazy/ts_backend/ts_lowering_context.h
#pragma once

#include <sstream>

#include <torch/csrc/api/include/torch/jit.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/ts_backend/ts_node_lowering.h>

namespace torch {
namespace lazy {

using TSOpVector = std::vector<torch::jit::Value*>;

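// A Computation backed by a TorchScript graph. Parameter names are taken from
// the graph inputs; a GraphExecutor for the graph is exposed via
// graph_executor().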
class TORCH_API TSComputation : public Computation {
 public:
  TSComputation(const std::shared_ptr<torch::jit::Graph>& graph)
      : graph_(graph), graph_executor_(graph, "") {
    for (torch::jit::Value* input : graph_->inputs()) {
      parameter_names_.push_back(input->debugName());
    }
  }

  int parameters_size() const override {
    return parameter_names_.size();
  }

  const std::vector<Shape>& parameter_shapes() const override {
    throw std::runtime_error(
        "TODO(whc) implement TS computation shapes or change interface");
    return parameter_shapes_;
  }

  const std::vector<std::string>& parameter_names() const override {
    return parameter_names_;
  }

  const Shape& result_shape() const override {
    throw std::runtime_error(
        "TODO(whc) implement TS computation shapes or change interface");
    return result_shape_;
  }

  const std::string to_string() const override {
    std::ostringstream oss;
    oss << *graph_;
    return oss.str();
  }

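  // Returns the underlying TorchScript graph.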
  std::shared_ptr<torch::jit::Graph> graph() const {
    return graph_;
  }

  torch::jit::GraphExecutor& graph_executor() {
    return graph_executor_;
  }

 private:
  std::shared_ptr<torch::jit::Graph> graph_;
  torch::jit::GraphExecutor graph_executor_;
  std::vector<std::string> parameter_names_;
  std::vector<Shape> parameter_shapes_;
  Shape result_shape_;
};

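// LoweringContext for the TorchScript backend: lowers lazy IR nodes into a
// torch::jit::Graph and packages the finished graph as a TSComputation.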
class TORCH_API TSLoweringContext : public LoweringContext {
 public:
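  // The second overload additionally takes a post-order list of nodes and the
  // emission status for the graph being lowered.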
  TSLoweringContext(const std::string& name, const BackendDevice device);

  TSLoweringContext(
      const std::string& name,
      BackendDevice device,
      c10::ArrayRef<const Node*> post_order,
      Util::EmissionMap emit_status);

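  // Appends the TS value lowered for `output` to the root tuple and returns
  // its position among the graph results.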
  size_t AddResult(const Output& output) override {
    return AddResult(GetOutputOp(output));
  }

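  // Not implemented for the TS backend; always asserts.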
  void AddParameter(
      const torch::lazy::Output& output,
      size_t index,
      const Shape& shape,
      const std::string& name) override {
    TORCH_INTERNAL_ASSERT(false, "not implemented");
  }

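  // Lowers a single lazy IR node into the TorchScript graph being built.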
  void Lower(const Node* node);

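  // Registers the collected root outputs on the graph and wraps the graph in
  // a TSComputation.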
  ComputationPtr Build() override {
    for (torch::jit::Value* output : root_tuple_) {
      graph_->block()->registerOutput(output);
    }
    return std::shared_ptr<Computation>(new TSComputation(graph_));
  }

  // Retrieves the lowered operation for an output. If the requested output is
  // not available yet, the graph behind the output's Node is lowered, and the
  // corresponding TS operation returned.
  torch::jit::Value* GetOutputOp(const Output& output) {
    auto it = emitted_outputs_.find(output);
    if (it == emitted_outputs_.end()) {
      auto post_order = Util::ComputePostOrder(output.node, &emit_status_);
      for (auto node : post_order) {
        Lower(node);
      }
      // At this point the output better be present, otherwise there is an issue
      // with the lowering code.
      it = emitted_outputs_.find(output);
      TORCH_CHECK(
          it != emitted_outputs_.end(),
          "No TS operation emitted for output: ",
          output.ToString());
    }
    return it->second;
  }

  // Assigns the given TS operation to the specified output. As outputs are
  // lowered in a post-order fashion, later nodes should always find their
  // operands among the emitted outputs.
  void AssignOutputOp(const Output& output, torch::jit::Value* op);

  // If a parameter associated with data has already been declared, it will be
  // returned. Otherwise a new one will be created, associated with the tensor
  // held in data.
  torch::jit::Value* GetParameter(BackendDataPtr data);

  std::shared_ptr<torch::jit::Graph> graph() const {
    return graph_;
  }

 private:
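  // A graph input value together with its positional index.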
  struct Parameter {
    torch::jit::Value* param{nullptr};
    size_t index = 0;
  };

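  // Appends `op` to the root tuple and returns its index.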
  size_t AddResult(torch::jit::Value* op) {
    root_tuple_.push_back(std::move(op));
    return root_tuple_.size() - 1;
  }

  std::shared_ptr<torch::jit::Graph> graph_;
  std::shared_ptr<torch::jit::GraphFunction> function_;
  std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
  std::vector<torch::jit::Value*> root_tuple_;
  OutputMap<torch::jit::Value*> emitted_outputs_;
};

} // namespace lazy
} // namespace torch