#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/cache.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_metadata.h>

// Enables caching for dynamic shapes (i.e. disables hashing on shapes).
C10_DEFINE_bool(
    ltc_enable_dynamic_shapes,
    false,
    "Whether dynamic shape is enabled");

namespace torch {
namespace lazy {

static const torch::lazy::Output kNullOutput = torch::lazy::Output();

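// Hashes an Output by the identity (pointer) of its defining node plus the
// output index, for use as an unordered-container key. Note this is identity
// hashing, unlike Output::hash() below, which uses the node's structural hash.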
size_t Output::Hasher::operator()(const Output& output) const {
  return StdHashCombine(
      reinterpret_cast<std::ptrdiff_t>(output.node), output.index);
}

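// Structural hash: combines the defining node's hash with this output's index.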
hash_t Output::hash() const {
  return HashCombine(node->hash(), Hash(index));
}

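// Like hash(), but based on the node's shapeHash(), which incorporates shape
// information (see its use as the key in Node::computeShape below).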
hash_t Output::shapeHash() const {
  return HashCombine(node->shapeHash(), Hash(index));
}

std::string Output::ToString() const {
  std::stringstream ss;
  ss << node->ToString() << ", index=" << index;
  return ss.str();
}

bool Output::operator==(const Value& rhs) const {
  // Either side could be kNullValue, whose node is nullptr.
  return (!node == !rhs.node) &&
      (!node || (node->hash() == rhs.node->hash() && index == rhs.index));
}

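// Value hashes mirror Output's: the producing node's hash (or shape hash)
// combined with the value's output index.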
hash_t Value::hash() const {
  return HashCombine(node->hash(), Hash(index));
}

hash_t Value::shapeHash() const {
  return HashCombine(node->shapeHash(), Hash(index));
}

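// Interns a fully qualified op name (e.g. "aten::add") into an OpKind via
// c10::Symbol::fromQualString.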
OpKind OpKind::Get(const std::string& name) {
  return OpKind(c10::Symbol::fromQualString(name));
}

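// Hashes the qualified string form of the symbol rather than its integer
// value, so the hash does not depend on symbol interning order.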
hash_t OpKind::hash() const {
  return StringHash(op.toQualString());
}

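// Dynamic shapes are on if either the LTC_ENABLE_DYNAMIC_SHAPES environment
// variable is set (checked once, at first call) or the
// ltc_enable_dynamic_shapes flag is true.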
bool Node::enableDynamicShape() {
  static bool enabled = std::getenv("LTC_ENABLE_DYNAMIC_SHAPES") != nullptr;
  return enabled || FLAGS_ltc_enable_dynamic_shapes;
}

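// Base constructor: records the op and output count, and captures user frame
// metadata when debugging is enabled.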
Node::Node(OpKind op, size_t num_outputs)
    : op_(op), num_outputs_(num_outputs), metadata_(GetMetaDataIfDebugging()) {}

Node::Node(
    OpKind op,
    OpList operands,
    std::vector<Shape>&& shapes,
    size_t num_outputs)
    : Node(op, num_outputs) {
  // Move the shapes into the node.
  shapes_.insert(
      shapes_.end(),
      std::make_move_iterator(shapes.begin()),
      std::make_move_iterator(shapes.end()));

  for (auto& operand : operands) {
    // Ideally, optional operands should be filtered by the leaf node classes,
    // but it's just much easier to do it here.
    // TODO(alanwaketan): Find a way to move the below logic to the leaf node
    // classes.
    if (!operand) {
      continue;
    }

    AddOperand(operand.node, operand.index);
  }
}

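// Constructor that defers shape computation to shape_fn; the result is
// memoized through computeShape's shape cache.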
Node::Node(
    OpKind op,
    OpList operands,
    const std::function<Shape()>& shape_fn,
    size_t num_outputs)
    : Node(op, operands, std::vector<Shape>{}, num_outputs) {
  addComputedShape(shape_fn);
}

Node::Node(OpKind op, OpList operands, size_t num_outputs)
    : Node(op, operands, std::vector<Shape>{}, num_outputs) {}

Node::Node(OpKind op, Shape shape, size_t num_outputs) : Node(op, num_outputs) {
  shapes_.push_back(std::move(shape));
}

Node::~Node() = default;

// Retrieves the full shape of the IR Node, i.e. the shapes of all outputs.
c10::ArrayRef<Shape> Node::shapes() const {
  return shapes_;
}

// Retrieves the shape of the output at a given index.
const Shape& Node::shape(size_t output_index) const {
  return shapes_.at(output_index);
}

// Adds the shape computed by shape_fn.
void Node::addComputedShape(const std::function<Shape()>& shape_fn) {
  shapes_.push_back(computeShape(shape_fn));
}

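// Process-wide cache from a node's shapeHash() to its computed Shape, bounded
// by FLAGS_torch_lazy_shape_cache_size, so identical nodes skip recomputation.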
using ShapeCache = Cache<hash_t, Shape, HashReducer>;

// Computes the shape using the provided shape_fn, memoized by shapeHash().
Shape Node::computeShape(const std::function<Shape()>& shape_fn) {
  static ShapeCache* cache = new ShapeCache(FLAGS_torch_lazy_shape_cache_size);

  auto hash = shapeHash();
  auto shape = cache->Get(hash);
  if (shape == nullptr) {
    shape = cache->Add(hash, std::make_shared<Shape>(shape_fn()));
  }
  return *shape;
}

const std::vector<Output>& Node::operands() const {
  return operands_as_outputs_;
}

const Output& Node::operand(size_t i) const {
  return operands_as_outputs_.at(i);
}

const Output& Node::nullable_operand(size_t i) const {
  // We use kNullOutput instead of kNullValue here to avoid implicit casting,
  // which would prevent this method from returning a reference.
  return i < operands_as_outputs_.size() ? operand(i) : kNullOutput;
}

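// Debug string: the output shapes and op kind, plus the output count, scope,
// and short frame info when present.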
std::string Node::ToString() const {
  std::stringstream ss;
  ss << shapes() << " " << op();
  if (num_outputs() > 1) {
    ss << ", num_outputs=" << num_outputs();
  }
  if (!metadata().scope.empty()) {
    ss << ", scope=" << metadata().scope;
  }
  EmitShortFrameInfo(ss, metadata().frame_info);
  return ss.str();
}

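// Retains the operand's NodePtr (keeping the producer alive) and records a
// lightweight Output view of it for traversal.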
void Node::AddOperand(NodePtr node, size_t index) {
  TORCH_CHECK_LT(index, node->num_outputs());
  operands_.push_back(node);
  operands_as_outputs_.emplace_back(operands_.back().get(), index);
}

} // namespace lazy
} // namespace torch