#include <torch/csrc/jit/runtime/profiling_record.h>

#include <ATen/core/symbol.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/clear_profiling.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/interpreter.h>

namespace torch::jit {

namespace {

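// A process-wide registry of predicates that decide whether a given node
// should be profiled. Backends (e.g. fusers) add their predicates via
// RegisterProfilingNode below; shouldProfileNode also returns true for any
// node autodiff can handle, so differentiable graphs get the requires_grad
// information they need.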
class ProfileRegistry {
 public:
  static ProfileRegistry* getRegistry() {
    static ProfileRegistry profile_registry_;
    return &profile_registry_;
  }

  void registerProfileNode(const std::function<bool(const Node*)>& func) {
    std::lock_guard<std::mutex> guard(mutex_);
    registry_funcs_.push_back(func);
  }

  bool shouldProfileNode(const Node* node) {
    std::lock_guard<std::mutex> guard(mutex_);
    // to guard differentiable graphs, we want profiling information
    // (in particular requires_grad) for nodes handled by autodiff
    if (isDifferentiable(node)) {
      return true;
    }
    for (const auto& func : registry_funcs_) {
      if (func(node)) {
        return true;
      }
    }
    return false;
  }

 private:
  std::vector<std::function<bool(const Node*)>> registry_funcs_;
  std::mutex mutex_;
};

} // namespace

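// Public entry point for registering an additional "should profile" predicate
// with the process-wide ProfileRegistry.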
void RegisterProfilingNode(const std::function<bool(const Node*)>& func) {
  ProfileRegistry::getRegistry()->registerProfileNode(func);
}

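// Checks `new_sizes` against `sym_shapes`. An unranked symbolic shape matches
// anything; otherwise the ranks must agree. Each static symbol is bound to the
// observed size the first time it is seen and must match that binding on every
// subsequent occurrence; non-static (symbolic) dimensions are skipped.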
bool ShapeSymbolTable::bindSymbolicShapes(
    at::IntArrayRef new_sizes,
    const c10::SymbolicShape& sym_shapes) {
  if (!sym_shapes.rank().has_value()) {
    return true;
  }
  if (*sym_shapes.rank() != new_sizes.size()) {
    return false;
  }
  for (const auto i : c10::irange(new_sizes.size())) {
    auto symbol = (*sym_shapes.sizes())[i];
    if (!symbol.is_static()) {
      continue;
    }

    if (!isBound(symbol)) {
      assign(symbol, new_sizes[i]);
      continue;
    }

    if (getValue(symbol) != new_sizes[i]) {
      return false;
    }
  }
  return true;
}

ProfilingRecord::ProfilingRecord(std::shared_ptr<Graph> g)
    : profiled_graph_(std::move(g)), profiling_count_(getNumProfiledRuns()) {}

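// Creates a prim::profile node (not yet inserted into the graph) that runs the
// callback `fp` over the given inputs.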
ProfileOp* ProfilingRecord::createProfileNode(
    const std::function<void(Stack&)>& fp,
    at::ArrayRef<Value*> inputs) {
  auto pn = new ProfileOp(profiled_graph_.get(), fp);

  for (auto in : inputs) {
    pn->addInput(in);
  }
  return pn;
}

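// Creates a prim::profile_ivalue node that passes `in_val` through; the caller
// is expected to install the profiling callback afterwards.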
ProfileIValueOp* ProfilingRecord::createProfileIValueNode(Value* in_val) {
  auto pn = new ProfileIValueOp(this->profiled_graph_.get(), nullptr);
  pn->addInput(in_val);
  auto pno = pn->addOutput();
  pno->setType(in_val->type());
  return pn;
}

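// Same as above, but profiles several values with a single prim::profile_ivalue
// node; each input gets a matching output.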
ProfileIValueOp* ProfilingRecord::createProfileIValueNode(
    ArrayRef<Value*> inputs) {
  auto pn = new ProfileIValueOp(this->profiled_graph_.get(), nullptr);
  for (auto inp : inputs) {
    pn->addInput(inp);
    auto pno = pn->addOutput();
    pno->setType(inp->type());
  }
  return pn;
}

namespace {
bool isOptionalTensorType(const TypePtr& type) {
  if (type->kind() != c10::TypeKind::OptionalType) {
    return false;
  }
  const auto& kind = type->expectRef<OptionalType>().getElementType()->kind();
  return kind == c10::TypeKind::TensorType;
}
} // namespace

// Inserts profiling nodes.
//
// The prim::profile node profiles Tensor and Optional[Tensor].
//
// It stores two fields:
// 1. attr::seen_none, an integer, which is initially 0 and is set to 1 if the
//    profiled value is ever `None`
// 2. attr::profiled_type, which is the most specific Tensor type that matches
//    all the non-null inputs observed during profiling.
void ProfilingRecord::insertShapeProfile(
    Node* n,
    size_t offset,
    const TypePtr& input_type) {
  Value* i = n->input(offset);
  auto pn = createProfileNode(nullptr, {i});
  auto pno = pn->addOutput();
  pn->ty_(attr::profiled_type, TensorType::get());
  pn->i_(attr::seen_none, 0);
  if (isOptionalTensorType(input_type)) {
    pno->setType(OptionalType::create(TensorType::get()));
  } else if (input_type->kind() == c10::TypeKind::TensorType) {
    pno->setType(TensorType::get());
  } else {
    TORCH_INTERNAL_ASSERT(
        false,
        "Trying to profile an unsupported type (neither Tensor nor Optional[Tensor]): ",
        input_type->str());
  }
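  // The run-time callback: it pops the frame id and the profiled value off the
  // stack, records what it saw (merging the observed tensor type into
  // attr::profiled_type and flagging attr::seen_none for None), and pushes the
  // value back so execution is unaffected.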
  std::function<void(Stack&)> shape_profiler = [this, pn, pno](Stack& stack) {
    int64_t frame_id = 0;
    pop(stack, frame_id);
    IValue v;
    pop(stack, v);

    TensorTypePtr new_tensor_type = nullptr;
    if (v.isTensor()) {
      auto& t = v.toTensor();
      new_tensor_type = tensorTypeInCurrentExecutionContext(t);
    }

    if (v.isTensor() || v.isNone()) {
      std::lock_guard<std::mutex> lock(this->mutex_);
      if (profiling_count_ > 0) {
        GRAPH_DEBUG(
            "In run ",
            frame_id,
            " annotating %",
            pno->debugName(),
            " with ",
            *new_tensor_type);

        if (new_tensor_type != nullptr) {
          if (pn->hasSeenTensor()) {
            const auto& existing_tensor_type =
                pn->ty(attr::profiled_type)->expectRef<TensorType>();
            GRAPH_DEBUG(
                "Existing type for %",
                pno->debugName(),
                ": ",
                existing_tensor_type);
            auto merged_type = new_tensor_type->merge(existing_tensor_type);
            GRAPH_DEBUG(
                "Merged type for %", pno->debugName(), ": ", *merged_type);
            pn->ty_(attr::profiled_type, std::move(merged_type));
          } else {
            pn->setHasSeenTensor(true);
            pn->ty_(attr::profiled_type, std::move(new_tensor_type));
          }
        }
        if (v.isNone()) {
          pn->i_(attr::seen_none, 1);
        }
      }
    }
    // pass the profiled value through unchanged
    push(stack, v);
  };

  pn->setCallback(shape_profiler);
  pn->insertBefore(n);
  n->replaceInput(offset, pn->output());
}

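// A node's inputs get shape profiles if the tensorexpr fuser could handle the
// node, if it is one of the kinds special-cased below (autograd specialization
// and peephole targets), or if a registered backend predicate claims it.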
static bool needsProfiledInputs(Node* n) {
  if (tensorexpr::isSupported(n)) {
    return true;
  }

  switch (n->kind()) {
    // specialize_autogradzero
    case prim::AutogradAdd:
    case prim::AutogradAnyNonZero:
    case prim::AutogradAllNonZero:
    case prim::AutogradAllZero:
    case prim::AutogradZero:
    // peephole
    case aten::dim:
    case aten::size:
    case aten::expand:
    case prim::dtype:
    case prim::device:
    case prim::is_cuda:
    case aten::is_floating_point:
    case aten::type_as:
    // TODO: hack to make `test_lstm_gates_permutations_cuda`
    // pass.
    case aten::t:
    case aten::mm:
      return true;
    default:
      return ProfileRegistry::getRegistry()->shouldProfileNode(n);
  }
}

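// A node's outputs get shape profiles under the same conditions, but only the
// autograd-specialization nodes are special-cased here.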
static bool needsProfiledOutput(Node* n) {
  if (tensorexpr::isSupported(n)) {
    return true;
  }

  switch (n->kind()) {
    case prim::AutogradAdd:
    case prim::AutogradZero:
      return true;
    default:
      return ProfileRegistry::getRegistry()->shouldProfileNode(n);
  }
}

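// Removes the input-less prim::profile node that instrumentGraph appends to
// count down the remaining profiling runs. Iteration starts from the back
// because the counter is appended at the end of the graph.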
void ProfilingRecord::removeProfileCounter(Block* b) {
  for (auto it = b->nodes().rbegin(); it != b->nodes().rend();) {
    auto n = *it;
    if (n->kind() == prim::profile && n->inputs().empty()) {
      it.destroyCurrent();
      // there is only one counter node
      return;
    } else {
      it++;
    }
  }
}

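// Walks the block and inserts shape profiles on every Tensor or
// Optional[Tensor] input of nodes that need profiled inputs (or whose
// producers need profiled outputs), recursing into sub-blocks, and finally on
// the block's own outputs.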
void ProfilingRecord::instrumentBlock(Block* block) {
  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
    auto n = *it;
    for (const auto offset : c10::irange(n->inputs().size())) {
      auto i = n->input(offset);
      if ((needsProfiledInputs(n) || needsProfiledOutput(i->node()))) {
        if (i->type()->kind() == c10::TypeKind::TensorType ||
            isOptionalTensorType(i->type())) {
          insertShapeProfile(n, offset, i->type());
        }
      }
    }

    for (auto b : n->blocks()) {
      instrumentBlock(b);
    }
  }

  // inserting profile nodes on block outputs allows us to eliminate more
  // guards as the use of a guard is now in the same block as opposed to
  // being separated from the definition by block boundaries
  for (size_t offset = 0; offset < block->return_node()->inputs().size();
       offset++) {
    auto i = block->return_node()->input(offset);
    if (i->type()->isSubtypeOf(*TensorType::get()) ||
        isOptionalTensorType(i->type())) {
      insertShapeProfile(block->return_node(), offset, i->type());
    }
  }
}

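// Strips all prim::profile and prim::profile_ivalue nodes (recursively),
// rerouting every use of a profile output to the value it was profiling.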
void ProfilingRecord::removeProfilingNodes(Block* b) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
    if (it->kind() == prim::profile || it->kind() == prim::profile_ivalue) {
      it->output()->replaceAllUsesWith(it->input());
      it.destroyCurrent();
    } else {
      for (Block* ib : it->blocks()) {
        removeProfilingNodes(ib);
      }
    }
  }
}

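// The record is ready once the requested number of profiling runs has been
// observed by the counter node appended in instrumentGraph.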
bool ProfilingRecord::ready() const {
  std::lock_guard<std::mutex> lock(this->mutex_);
  return profiling_count_ == 0;
}

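// Copies `graph`, clears any stale profiling information, instruments the copy
// with shape-profiling nodes, and appends a counter node that decrements
// profiling_count_ once per run until the record is ready.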
std::unique_ptr<ProfilingRecord> ProfilingRecord::instrumentGraph(
    const std::shared_ptr<Graph>& graph) {
  auto new_g = graph->copy();

  auto pr = std::unique_ptr<ProfilingRecord>(new ProfilingRecord(new_g));
  auto raw_pr = pr.get();
  unprofileGraphInputs(new_g);
  unprofileBlock(new_g->block());
  pr->instrumentBlock(new_g->block());

  std::function<void(Stack&)> counter = [raw_pr](Stack& stack) {
    int64_t frame_id = 0;
    pop(stack, frame_id);

    std::lock_guard<std::mutex> lock(raw_pr->mutex_);

    if (raw_pr->profiling_count_ > 0) {
      raw_pr->profiling_count_--;
    }
  };

  auto pop = pr->createProfileNode(counter, {});
  new_g->appendNode(pop);
  GRAPH_DUMP("Instrumented Graph: ", new_g);
  return pr;
}

} // namespace torch::jit