#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>

#include <utility>
2
DimCast(torch::lazy::Output output)3 static const torch::lazy::DimensionNode* DimCast(torch::lazy::Output output) {
4 return dynamic_cast<const torch::lazy::DimensionNode*>(output.node);
5 }
6
7 namespace torch {
8 namespace lazy {
9
Lower(std::shared_ptr<torch::jit::GraphFunction> function,TSLoweringContext * loctx) const10 TSOpVector SizeNode::Lower(
11 std::shared_ptr<torch::jit::GraphFunction> function,
12 TSLoweringContext* loctx) const {
13 std::vector<torch::jit::NamedValue> arguments;
14 std::vector<torch::jit::NamedValue> kwarguments;
15 arguments.reserve(2);
16 auto index = loctx->graph()->insertConstant(static_cast<int64_t>(this->dim_));
17 arguments.emplace_back(loctx->GetOutputOp(operand(0)));
18 arguments.emplace_back(index);
19 torch::lazy::TSOpVector size_out =
20 torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
21 TORCH_CHECK_EQ(size_out.size(), 1);
22 return size_out;
23 }
24
SizeNode(Value input,size_t dim)25 SizeNode::SizeNode(Value input, size_t dim)
26 : TsNode(
27 OpKind{c10::Symbol::fromQualString("aten::size")},
28 {input},
29 std::vector<Shape>{},
30 1,
31 MHash(dim)),
32 dim_(dim){};
33
getStaticValue() const34 int64_t SizeNode::getStaticValue() const {
35 return dynamic_cast<const TsNode*>(operand(0).node)->shape(0).size(dim_);
36 }
isSymbolic() const37 bool SizeNode::isSymbolic() const {
38 auto symbolic_vec =
39 dynamic_cast<const TsNode*>(operand(0).node)->shape(0).is_symbolic();
40 if (!symbolic_vec.has_value()) {
41 return true;
42 }
43 return symbolic_vec->at(dim_);
44 }
45
ToString() const46 std::string SizeNode::ToString() const {
47 return "SizeNode";
48 }
49
SizeAdd(Value a,Value b)50 SizeAdd::SizeAdd(Value a, Value b)
51 : TsNode(
52 OpKind{c10::Symbol::fromQualString("aten::add")},
53 {a, b},
54 std::vector<Shape>{},
55 1){};
56
getStaticValue() const57 int64_t SizeAdd::getStaticValue() const {
58 return DimCast(operand(0))->getStaticValue() +
59 DimCast(operand(1))->getStaticValue();
60 }
61
isSymbolic() const62 bool SizeAdd::isSymbolic() const {
63 return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic();
64 }
65
ToString() const66 std::string SizeAdd::ToString() const {
67 return "SizeAdd";
68 }
69
SizeMul(Value a,Value b)70 SizeMul::SizeMul(Value a, Value b)
71 : TsNode(
72 OpKind{c10::Symbol::fromQualString("aten::mul")},
73 {a, b},
74 std::vector<Shape>{},
75 1){};
76
getStaticValue() const77 int64_t SizeMul::getStaticValue() const {
78 return DimCast(operand(0))->getStaticValue() *
79 DimCast(operand(1))->getStaticValue();
80 }
81
isSymbolic() const82 bool SizeMul::isSymbolic() const {
83 return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic();
84 }
85
ToString() const86 std::string SizeMul::ToString() const {
87 return "SizeMul";
88 }
89
SizeDiv(Value a,Value b)90 SizeDiv::SizeDiv(Value a, Value b)
91 : TsNode(
92 OpKind{c10::Symbol::fromQualString("aten::div")},
93 {a, b},
94 std::vector<Shape>{},
95 1){};
96
getStaticValue() const97 int64_t SizeDiv::getStaticValue() const {
98 TORCH_CHECK(
99 DimCast(operand(1))->getStaticValue() != 0,
100 "Can't divide a dimension by zero");
101 return DimCast(operand(0))->getStaticValue() /
102 DimCast(operand(1))->getStaticValue();
103 }
104
isSymbolic() const105 bool SizeDiv::isSymbolic() const {
106 return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic();
107 }
108
ToString() const109 std::string SizeDiv::ToString() const {
110 return "SizeDiv";
111 }
112
113 } // namespace lazy
114 } // namespace torch
115