xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/ir.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/lazy/backend/backend_interface.h>
2 #include <torch/csrc/lazy/core/cache.h>
3 #include <torch/csrc/lazy/core/config.h>
4 #include <torch/csrc/lazy/core/ir.h>
5 #include <torch/csrc/lazy/core/ir_metadata.h>
6 
7 // Defines the ltc_enable_dynamic_shapes flag (default: false), which is
// consulted by Node::enableDynamicShape(). Original note: enables caching for
// dynamic shapes (i.e. disables hashing on shapes).
8 C10_DEFINE_bool(
9     ltc_enable_dynamic_shapes,
10     false,
11     "Whether dynamic shape is enabled");
12 
13 namespace torch {
14 namespace lazy {
15 
16 static const torch::lazy::Output kNullOutput = torch::lazy::Output();
17 
operator ()(const Output & output) const18 size_t Output::Hasher::operator()(const Output& output) const {
19   return StdHashCombine(
20       reinterpret_cast<std::ptrdiff_t>(output.node), output.index);
21 }
22 
hash() const23 hash_t Output::hash() const {
24   return HashCombine(node->hash(), Hash(index));
25 }
26 
shapeHash() const27 hash_t Output::shapeHash() const {
28   return HashCombine(node->shapeHash(), Hash(index));
29 }
30 
ToString() const31 std::string Output::ToString() const {
32   std::stringstream ss;
33   ss << node->ToString() << ", index=" << index;
34   return ss.str();
35 }
36 
operator ==(const Value & rhs) const37 bool Output::operator==(const Value& rhs) const {
38   // Either side could be kNullValue which has node as nullptr
39   return (!node == !rhs.node) &&
40       (!node || (node->hash() == rhs.node->hash() && index == rhs.index));
41 }
42 
hash() const43 hash_t Value::hash() const {
44   return HashCombine(node->hash(), Hash(index));
45 }
46 
shapeHash() const47 hash_t Value::shapeHash() const {
48   return HashCombine(node->shapeHash(), Hash(index));
49 }
50 
Get(const std::string & name)51 OpKind OpKind::Get(const std::string& name) {
52   return OpKind(c10::Symbol::fromQualString(name));
53 }
54 
hash() const55 hash_t OpKind::hash() const {
56   return StringHash(op.toQualString());
57 }
58 
enableDynamicShape()59 bool Node::enableDynamicShape() {
60   static bool enabled = std::getenv("LTC_ENABLE_DYNAMIC_SHAPES") != nullptr;
61   return enabled || FLAGS_ltc_enable_dynamic_shapes;
62 }
63 
Node(OpKind op,size_t num_outputs)64 Node::Node(OpKind op, size_t num_outputs)
65     : op_(op), num_outputs_(num_outputs), metadata_(GetMetaDataIfDebugging()) {}
66 
Node(OpKind op,OpList operands,std::vector<Shape> && shapes,size_t num_outputs)67 Node::Node(
68     OpKind op,
69     OpList operands,
70     std::vector<Shape>&& shapes,
71     size_t num_outputs)
72     : Node(op, num_outputs) {
73   // Move shapes into node
74   shapes_.insert(
75       shapes_.end(),
76       std::make_move_iterator(shapes.begin()),
77       std::make_move_iterator(shapes.end()));
78 
79   for (auto& operand : operands) {
80     // Ideally, optional operands should be filtered by the leaf node classes,
81     // but it's just much easier to do it here.
82     // TODO(alanwaketan): Find a way to move the below logic to the leaf node
83     // classes.
84     if (!operand) {
85       continue;
86     }
87 
88     AddOperand(operand.node, operand.index);
89   }
90 }
91 
Node(OpKind op,OpList operands,const std::function<Shape ()> & shape_fn,size_t num_outputs)92 Node::Node(
93     OpKind op,
94     OpList operands,
95     const std::function<Shape()>& shape_fn,
96     size_t num_outputs)
97     : Node(op, operands, std::vector<Shape>{}, num_outputs) {
98   addComputedShape(shape_fn);
99 }
100 
Node(OpKind op,OpList operands,size_t num_outputs)101 Node::Node(OpKind op, OpList operands, size_t num_outputs)
102     : Node(op, operands, std::vector<Shape>{}, num_outputs) {}
103 
Node(OpKind op,Shape shape,size_t num_outputs)104 Node::Node(OpKind op, Shape shape, size_t num_outputs) : Node(op, num_outputs) {
105   shapes_.push_back(std::move(shape));
106 }
107 
// Destructor, defaulted here in the implementation file rather than at its
// in-class declaration.
108 Node::~Node() = default;
109 
110 // Retrieves the full shape of the IR Node.
shapes() const111 c10::ArrayRef<Shape> Node::shapes() const {
112   return shapes_;
113 }
114 
115 // Retrieves the shape of the output at a given index.
shape(size_t output_index) const116 const Shape& Node::shape(size_t output_index) const {
117   return shapes_.at(output_index);
118 }
119 
120 // Add the shape computed by the shape_fn
121 
addComputedShape(const std::function<Shape ()> & shape_fn)122 void Node::addComputedShape(const std::function<Shape()>& shape_fn) {
123   shapes_.push_back(computeShape(shape_fn));
124 }
125 
126 using ShapeCache = Cache<hash_t, Shape, HashReducer>;
127 
128 // Compute the shape using the provided shape_fn.
computeShape(const std::function<Shape ()> & shape_fn)129 Shape Node::computeShape(const std::function<Shape()>& shape_fn) {
130   static ShapeCache* cache = new ShapeCache(FLAGS_torch_lazy_shape_cache_size);
131 
132   auto hash = shapeHash();
133   auto shape = cache->Get(hash);
134   if (shape == nullptr) {
135     shape = cache->Add(hash, std::make_shared<Shape>(shape_fn()));
136   }
137   return *shape;
138 }
139 
operands() const140 const std::vector<Output>& Node::operands() const {
141   return operands_as_outputs_;
142 }
143 
operand(size_t i) const144 const Output& Node::operand(size_t i) const {
145   return operands_as_outputs_.at(i);
146 }
147 
nullable_operand(size_t i) const148 const Output& Node::nullable_operand(size_t i) const {
149   // We use kNullOutput instead of kNullValue here to avoid implicit casting,
150   // which would prevent this method from returning a reference.
151   return i < operands_as_outputs_.size() ? operand(i) : kNullOutput;
152 }
153 
ToString() const154 std::string Node::ToString() const {
155   std::stringstream ss;
156   ss << shapes() << " " << op();
157   if (num_outputs() > 1) {
158     ss << ", num_outputs=" << num_outputs();
159   }
160   if (!metadata().scope.empty()) {
161     ss << ", scope=" << metadata().scope;
162   }
163   EmitShortFrameInfo(ss, metadata().frame_info);
164   return ss.str();
165 }
166 
AddOperand(NodePtr node,size_t index)167 void Node::AddOperand(NodePtr node, size_t index) {
168   TORCH_CHECK_LT(index, node->num_outputs());
169   operands_.push_back(node);
170   operands_as_outputs_.emplace_back(operands_.back().get(), index);
171 }
172 
173 } // namespace lazy
174 } // namespace torch
175