1 #include <torch/csrc/lazy/ts_backend/ts_node_lowering.h>
2
3 #include <ATen/Functions.h>
4 #include <torch/csrc/jit/frontend/sugared_value.h>
5 #include <torch/csrc/jit/jit_log.h>
6 #include <torch/csrc/lazy/backend/backend_interface.h>
7 #include <torch/csrc/lazy/core/helpers.h>
8 #include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
9 #include <torch/csrc/lazy/core/ir_builder.h>
10 #include <torch/csrc/lazy/core/lazy_graph_executor.h>
11 #include <torch/csrc/lazy/core/ops/utils.h>
12 #include <torch/csrc/lazy/core/permutation_util.h>
13 #include <torch/csrc/lazy/ts_backend/ir_builder.h>
14 #include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
15
16 namespace torch {
17 namespace lazy {
18
LowerBuiltin(const torch::lazy::Node * node,std::shared_ptr<torch::jit::GraphFunction> function,const std::vector<torch::jit::NamedValue> & arguments,const std::vector<torch::jit::NamedValue> & kwarguments={})19 static TSOpVector LowerBuiltin(
20 const torch::lazy::Node* node,
21 std::shared_ptr<torch::jit::GraphFunction> function,
22 const std::vector<torch::jit::NamedValue>& arguments,
23 const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
24 return LowerTSBuiltin(function, node->op().op, arguments, kwarguments);
25 }
LowerBuiltin(c10::Symbol sym,std::shared_ptr<torch::jit::GraphFunction> function,const std::vector<torch::jit::NamedValue> & arguments,const std::vector<torch::jit::NamedValue> & kwarguments={})26 static TSOpVector LowerBuiltin(
27 c10::Symbol sym,
28 std::shared_ptr<torch::jit::GraphFunction> function,
29 const std::vector<torch::jit::NamedValue>& arguments,
30 const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
31 return LowerTSBuiltin(function, sym, arguments, kwarguments);
32 }
33
LowerTSBuiltin(std::shared_ptr<torch::jit::GraphFunction> function,c10::Symbol sym,const std::vector<torch::jit::NamedValue> & arguments,const std::vector<torch::jit::NamedValue> & kwarguments)34 TSOpVector LowerTSBuiltin(
35 std::shared_ptr<torch::jit::GraphFunction> function,
36 c10::Symbol sym,
37 const std::vector<torch::jit::NamedValue>& arguments,
38 const std::vector<torch::jit::NamedValue>& kwarguments) {
39 auto builtin =
40 std::make_shared<torch::jit::BuiltinFunction>(sym, std::nullopt);
41 auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
42 auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
43 auto& sv = dynamic_cast<torch::jit::SimpleValue&>(*ret);
44 if (sv.getValue()->type()->kind() == c10::TypeKind::TupleType) {
45 const auto tuple_call_result = sv.asTuple({}, *function);
46 TSOpVector tuple_result;
47 for (const auto& tuple_component : tuple_call_result) {
48 auto tuple_component_sv =
49 dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get());
50 tuple_result.push_back(tuple_component_sv->getValue());
51 }
52 return tuple_result;
53 }
54 return {sv.getValue()};
55 }
56
GenerateClone(torch::jit::Value * val,std::shared_ptr<torch::jit::GraphFunction> function)57 static torch::jit::Value* GenerateClone(
58 torch::jit::Value* val,
59 std::shared_ptr<torch::jit::GraphFunction> function) {
60 std::vector<torch::jit::NamedValue> clone_arguments;
61 clone_arguments.emplace_back(val);
62 TSOpVector cloned = LowerBuiltin(at::aten::clone, function, clone_arguments);
63 TORCH_CHECK_EQ(cloned.size(), 1);
64 return cloned.front();
65 }
66
67 // Node Lowerings
68
69 // Default node lowering
Lower(std::shared_ptr<torch::jit::GraphFunction> function,TSLoweringContext * loctx) const70 TSOpVector TsNode::Lower(
71 std::shared_ptr<torch::jit::GraphFunction> function,
72 TSLoweringContext* loctx) const {
73 std::vector<torch::jit::NamedValue> arguments;
74 for (const torch::lazy::Output& output : operands()) {
75 arguments.emplace_back(loctx->GetOutputOp(output));
76 }
77 return LowerBuiltin(this, function, arguments);
78 }
79
80 // Non-native ops
Lower(std::shared_ptr<torch::jit::GraphFunction> function,torch::lazy::TSLoweringContext * loctx) const81 torch::lazy::TSOpVector Cast::Lower(
82 std::shared_ptr<torch::jit::GraphFunction> function,
83 torch::lazy::TSLoweringContext* loctx) const {
84 std::vector<torch::jit::NamedValue> arguments;
85 arguments.emplace_back(loctx->GetOutputOp(operand(0)));
86 arguments.emplace_back(dtype);
87 return LowerBuiltin(at::aten::to, function, arguments);
88 }
89
Lower(std::shared_ptr<torch::jit::GraphFunction> function,torch::lazy::TSLoweringContext * loctx) const90 torch::lazy::TSOpVector DeviceData::Lower(
91 std::shared_ptr<torch::jit::GraphFunction> function,
92 torch::lazy::TSLoweringContext* loctx) const {
93 auto infoptr = data_->info();
94 auto deviceDataInfoPtr =
95 (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
96 if (GRAPH_DUMP_ENABLED) {
97 LOG(ERROR) << "Lowering device data node, tensor id "
98 << deviceDataInfoPtr->tensor_id << std::endl;
99 }
100 return {loctx->GetParameter(data_)};
101 }
102
Lower(std::shared_ptr<torch::jit::GraphFunction> function,torch::lazy::TSLoweringContext * loctx) const103 torch::lazy::TSOpVector Expand::Lower(
104 std::shared_ptr<torch::jit::GraphFunction> function,
105 torch::lazy::TSLoweringContext* loctx) const {
106 std::vector<torch::jit::NamedValue> arguments;
107 arguments.emplace_back(loctx->GetOutputOp(operand(0)));
108 arguments.emplace_back(size);
109 auto expand_out = LowerBuiltin(this, function, arguments);
110 if (is_scalar_expand) {
111 // The aten::expand operations sets all strides to 0 when the original is
112 // of rank 0. This leads to false positives when checking for internal
113 // memory overlap, because at::has_internal_overlap returns
114 // MemOverlap::YES when a stride is set to 0.
115 TORCH_CHECK_EQ(expand_out.size(), 1);
116 return {GenerateClone(expand_out.front(), function)};
117 }
118 return expand_out;
119 }
120
Lower(std::shared_ptr<torch::jit::GraphFunction> function,torch::lazy::TSLoweringContext * loctx) const121 torch::lazy::TSOpVector Scalar::Lower(
122 std::shared_ptr<torch::jit::GraphFunction> function,
123 torch::lazy::TSLoweringContext* loctx) const {
124 auto options =
125 at::TensorOptions()
126 .device(torch::lazy::getBackend()->EagerFallbackDeviceType())
127 .dtype(shape().scalar_type());
128 return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))};
129 }
130
131 } // namespace lazy
132 } // namespace torch
133