xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/ts_backend/dynamic_ir.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
2 
DimCast(torch::lazy::Output output)3 static const torch::lazy::DimensionNode* DimCast(torch::lazy::Output output) {
4   return dynamic_cast<const torch::lazy::DimensionNode*>(output.node);
5 }
6 
7 namespace torch {
8 namespace lazy {
9 
Lower(std::shared_ptr<torch::jit::GraphFunction> function,TSLoweringContext * loctx) const10 TSOpVector SizeNode::Lower(
11     std::shared_ptr<torch::jit::GraphFunction> function,
12     TSLoweringContext* loctx) const {
13   std::vector<torch::jit::NamedValue> arguments;
14   std::vector<torch::jit::NamedValue> kwarguments;
15   arguments.reserve(2);
16   auto index = loctx->graph()->insertConstant(static_cast<int64_t>(this->dim_));
17   arguments.emplace_back(loctx->GetOutputOp(operand(0)));
18   arguments.emplace_back(index);
19   torch::lazy::TSOpVector size_out =
20       torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
21   TORCH_CHECK_EQ(size_out.size(), 1);
22   return size_out;
23 }
24 
SizeNode(Value input,size_t dim)25 SizeNode::SizeNode(Value input, size_t dim)
26     : TsNode(
27           OpKind{c10::Symbol::fromQualString("aten::size")},
28           {input},
29           std::vector<Shape>{},
30           1,
31           MHash(dim)),
32       dim_(dim){};
33 
getStaticValue() const34 int64_t SizeNode::getStaticValue() const {
35   return dynamic_cast<const TsNode*>(operand(0).node)->shape(0).size(dim_);
36 }
isSymbolic() const37 bool SizeNode::isSymbolic() const {
38   auto symbolic_vec =
39       dynamic_cast<const TsNode*>(operand(0).node)->shape(0).is_symbolic();
40   if (!symbolic_vec.has_value()) {
41     return true;
42   }
43   return symbolic_vec->at(dim_);
44 }
45 
ToString() const46 std::string SizeNode::ToString() const {
47   return "SizeNode";
48 }
49 
SizeAdd(Value a,Value b)50 SizeAdd::SizeAdd(Value a, Value b)
51     : TsNode(
52           OpKind{c10::Symbol::fromQualString("aten::add")},
53           {a, b},
54           std::vector<Shape>{},
55           1){};
56 
getStaticValue() const57 int64_t SizeAdd::getStaticValue() const {
58   return DimCast(operand(0))->getStaticValue() +
59       DimCast(operand(1))->getStaticValue();
60 }
61 
isSymbolic() const62 bool SizeAdd::isSymbolic() const {
63   return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic();
64 }
65 
ToString() const66 std::string SizeAdd::ToString() const {
67   return "SizeAdd";
68 }
69 
SizeMul(Value a,Value b)70 SizeMul::SizeMul(Value a, Value b)
71     : TsNode(
72           OpKind{c10::Symbol::fromQualString("aten::mul")},
73           {a, b},
74           std::vector<Shape>{},
75           1){};
76 
getStaticValue() const77 int64_t SizeMul::getStaticValue() const {
78   return DimCast(operand(0))->getStaticValue() *
79       DimCast(operand(1))->getStaticValue();
80 }
81 
isSymbolic() const82 bool SizeMul::isSymbolic() const {
83   return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic();
84 }
85 
ToString() const86 std::string SizeMul::ToString() const {
87   return "SizeMul";
88 }
89 
SizeDiv(Value a,Value b)90 SizeDiv::SizeDiv(Value a, Value b)
91     : TsNode(
92           OpKind{c10::Symbol::fromQualString("aten::div")},
93           {a, b},
94           std::vector<Shape>{},
95           1){};
96 
getStaticValue() const97 int64_t SizeDiv::getStaticValue() const {
98   TORCH_CHECK(
99       DimCast(operand(1))->getStaticValue() != 0,
100       "Can't divide a dimension by zero");
101   return DimCast(operand(0))->getStaticValue() /
102       DimCast(operand(1))->getStaticValue();
103 }
104 
isSymbolic() const105 bool SizeDiv::isSymbolic() const {
106   return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic();
107 }
108 
ToString() const109 std::string SizeDiv::ToString() const {
110   return "SizeDiv";
111 }
112 
113 } // namespace lazy
114 } // namespace torch
115