// torch/csrc/lazy/ts_backend/ts_node_lowering.cpp
#include <torch/csrc/lazy/ts_backend/ts_node_lowering.h>

#include <ATen/Functions.h>
#include <torch/csrc/jit/frontend/sugared_value.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/core/permutation_util.h>
#include <torch/csrc/lazy/ts_backend/ir_builder.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>

namespace torch {
namespace lazy {

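// Helper overloads that lower either a lazy node (using its op symbol) or a
// raw op symbol to a TorchScript builtin call; both forward to LowerTSBuiltin.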
static TSOpVector LowerBuiltin(
    const torch::lazy::Node* node,
    std::shared_ptr<torch::jit::GraphFunction> function,
    const std::vector<torch::jit::NamedValue>& arguments,
    const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
  return LowerTSBuiltin(function, node->op().op, arguments, kwarguments);
}
static TSOpVector LowerBuiltin(
    c10::Symbol sym,
    std::shared_ptr<torch::jit::GraphFunction> function,
    const std::vector<torch::jit::NamedValue>& arguments,
    const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
  return LowerTSBuiltin(function, sym, arguments, kwarguments);
}

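// Emits a call to the TorchScript builtin `sym` into the graph owned by
// `function`, using the JIT sugared-value machinery. Tuple-typed results are
// unpacked so each component is returned as its own jit::Value.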
TSOpVector LowerTSBuiltin(
    std::shared_ptr<torch::jit::GraphFunction> function,
    c10::Symbol sym,
    const std::vector<torch::jit::NamedValue>& arguments,
    const std::vector<torch::jit::NamedValue>& kwarguments) {
  auto builtin =
      std::make_shared<torch::jit::BuiltinFunction>(sym, std::nullopt);
  auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
  auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
  auto& sv = dynamic_cast<torch::jit::SimpleValue&>(*ret);
  if (sv.getValue()->type()->kind() == c10::TypeKind::TupleType) {
    const auto tuple_call_result = sv.asTuple({}, *function);
    TSOpVector tuple_result;
    for (const auto& tuple_component : tuple_call_result) {
      auto tuple_component_sv =
          dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get());
      tuple_result.push_back(tuple_component_sv->getValue());
    }
    return tuple_result;
  }
  return {sv.getValue()};
}

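// Emits an aten::clone of `val` into the graph and returns the cloned value.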
static torch::jit::Value* GenerateClone(
    torch::jit::Value* val,
    std::shared_ptr<torch::jit::GraphFunction> function) {
  std::vector<torch::jit::NamedValue> clone_arguments;
  clone_arguments.emplace_back(val);
  TSOpVector cloned = LowerBuiltin(at::aten::clone, function, clone_arguments);
  TORCH_CHECK_EQ(cloned.size(), 1);
  return cloned.front();
}

// Node Lowerings

// Default node lowering
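// Passes every operand's output positionally to the builtin matching the
// node's op symbol.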
TSOpVector TsNode::Lower(
    std::shared_ptr<torch::jit::GraphFunction> function,
    TSLoweringContext* loctx) const {
  std::vector<torch::jit::NamedValue> arguments;
  for (const torch::lazy::Output& output : operands()) {
    arguments.emplace_back(loctx->GetOutputOp(output));
  }
  return LowerBuiltin(this, function, arguments);
}

// Non-native ops
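// Cast lowers to aten::to, passing the input value and the target dtype.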
torch::lazy::TSOpVector Cast::Lower(
    std::shared_ptr<torch::jit::GraphFunction> function,
    torch::lazy::TSLoweringContext* loctx) const {
  std::vector<torch::jit::NamedValue> arguments;
  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
  arguments.emplace_back(dtype);
  return LowerBuiltin(at::aten::to, function, arguments);
}

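// DeviceData nodes lower to graph parameters looked up from the lowering
// context; when graph dumping is enabled, the backing tensor id is logged.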
torch::lazy::TSOpVector DeviceData::Lower(
    std::shared_ptr<torch::jit::GraphFunction> function,
    torch::lazy::TSLoweringContext* loctx) const {
  auto infoptr = data_->info();
  auto deviceDataInfoPtr =
      (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
  if (GRAPH_DUMP_ENABLED) {
    LOG(ERROR) << "Lowering device data node, tensor id "
               << deviceDataInfoPtr->tensor_id << std::endl;
  }
  return {loctx->GetParameter(data_)};
}

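// Expand lowers to aten::expand with the target size; expansions of rank-0
// tensors are cloned afterwards (see the comment below).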
torch::lazy::TSOpVector Expand::Lower(
    std::shared_ptr<torch::jit::GraphFunction> function,
    torch::lazy::TSLoweringContext* loctx) const {
  std::vector<torch::jit::NamedValue> arguments;
  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
  arguments.emplace_back(size);
  auto expand_out = LowerBuiltin(this, function, arguments);
  if (is_scalar_expand) {
    // The aten::expand operation sets all strides to 0 when the original
    // tensor is of rank 0. This leads to false positives when checking for
    // internal memory overlap, because at::has_internal_overlap returns
    // MemOverlap::YES when a stride is set to 0.
    TORCH_CHECK_EQ(expand_out.size(), 1);
    return {GenerateClone(expand_out.front(), function)};
  }
  return expand_out;
}

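// Scalar nodes are materialized as constant tensors on the eager fallback
// device and inserted directly into the graph as constants.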
torch::lazy::TSOpVector Scalar::Lower(
    std::shared_ptr<torch::jit::GraphFunction> function,
    torch::lazy::TSLoweringContext* loctx) const {
  auto options =
      at::TensorOptions()
          .device(torch::lazy::getBackend()->EagerFallbackDeviceType())
          .dtype(shape().scalar_type());
  return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))};
}

} // namespace lazy
} // namespace torch