// torch/csrc/lazy/ts_backend/ts_lowering_context.cpp
#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/ts_backend/ts_backend_impl.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

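// Builds an empty lowering context: a fresh TorchScript graph wrapped in a
// GraphFunction that lowered nodes will be emitted into.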
TSLoweringContext::TSLoweringContext(
    const std::string& name,
    BackendDevice device)
    : torch::lazy::LoweringContext(name, device),
      graph_(std::make_shared<torch::jit::Graph>()),
      function_(
          std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)) {}

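// Builds a lowering context and eagerly lowers the given nodes in
// post-order, so each node's operands are emitted before the node itself.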
TSLoweringContext::TSLoweringContext(
    const std::string& name,
    BackendDevice device,
    c10::ArrayRef<const Node*> post_order,
    Util::EmissionMap emit_status)
    : torch::lazy::LoweringContext(name, device, post_order, emit_status),
      graph_(std::make_shared<torch::jit::Graph>()),
      function_(
          std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)) {
  for (auto node : post_order) {
    Lower(node);
  }
}

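// Lowers a single IR node by delegating to the node's own Lower() method,
// then records each resulting torch::jit::Value against the corresponding
// Output so downstream nodes can resolve their operands.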
void TSLoweringContext::Lower(const Node* node) {
  if (auto* tsnode = dynamic_cast<const torch::lazy::TsNode*>(node)) {
    // First, we call the node lowering function, which exists for newly
    // codegenned or refactored nodes
    TSOpVector ops = tsnode->Lower(function_, this);
    TORCH_CHECK(!ops.empty(), "Failed to lower: ", *node);
    TORCH_CHECK_EQ(node->num_outputs(), ops.size());
    for (size_t i = 0; i < ops.size(); ++i) {
      AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
    }
  } else {
    throw std::runtime_error(
        "Expected torch::lazy::TsNode but could not dynamic cast");
  }
}

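// Records the jit::Value produced for an Output. If the originating node
// carries a Python stack trace, it is attached to the emitted jit::Node as
// a "source" attribute for debugging.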
void TSLoweringContext::AssignOutputOp(
    const Output& output,
    torch::jit::Value* op) {
  const TsNode* ts_node = static_cast<const TsNode*>(output.node);
  std::string stack_trace = ts_node->getPythonStacktrace();
  if (!stack_trace.empty()) {
    op->node()->s_(c10::Symbol::attr("source"), stack_trace);
  }
  emitted_outputs_[output] = op;
}

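// Maps backend data to a graph input. Each distinct data handle gets a
// single graph parameter (typed as Float or Int when it wraps a scalar);
// repeated lookups reuse the cached parameter, and every call appends the
// parameter's index to parameter_sequence_.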
torch::jit::Value* TSLoweringContext::GetParameter(BackendDataPtr data) {
  const auto ts_data = std::static_pointer_cast<TSData>(data);
  BackendData::Handle handle = ts_data->GetHandle();
  auto it = parameters_map_.find(handle);
  if (it == parameters_map_.end()) {
    torch::jit::Value* param =
        graph_->addInput(c10::str("p", parameters_.size()));
    if (ts_data->scalar.has_value()) {
      auto scalarType = ts_data->scalar.value().type();
      if (isFloatingType(scalarType)) {
        param->setType(c10::FloatType::get());
      } else if (isIntegralType(scalarType, /*includeBool=*/true)) {
        param->setType(c10::IntType::get());
      } else {
        TORCH_CHECK(
            false, "Unhandled scalar type: ", c10::toString(scalarType));
      }
    }
    it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
             .first;
    parameters_.push_back(ts_data);
  }
  parameter_sequence_.push_back(it->second.index);
  return it->second.param;
}

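// A minimal usage sketch (hypothetical driver code, not part of this file),
// assuming `post_order` was computed with Util::ComputePostOrder over the
// live IR roots, `device` is the target BackendDevice, and the graph()
// accessor declared in ts_lowering_context.h:
//
//   TSLoweringContext ctx("MyGraph", device, post_order, /*emit_status=*/{});
//   std::shared_ptr<torch::jit::Graph> graph = ctx.graph();
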
} // namespace lazy
} // namespace torch