#include <torch/csrc/jit/runtime/profiling_record.h>

#include <ATen/core/symbol.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/clear_profiling.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/interpreter.h>

namespace torch::jit {

namespace {

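// A process-wide, mutex-guarded registry of predicates that decide whether a
// given node should be profiled. Additional predicates are registered via
// RegisterProfilingNode below (e.g. by fuser backends).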
class ProfileRegistry {
 public:
  static ProfileRegistry* getRegistry() {
    static ProfileRegistry profile_registry_;
    return &profile_registry_;
  }

  void registerProfileNode(const std::function<bool(const Node*)>& func) {
    std::lock_guard<std::mutex> guard(mutex_);
    registry_funcs_.push_back(func);
  }

  bool shouldProfileNode(const Node* node) {
    std::lock_guard<std::mutex> guard(mutex_);
    // to guard differentiable graphs, we want profiling information
    // (in particular requires_grad) for nodes handled by autodiff
    if (isDifferentiable(node)) {
      return true;
    }
    for (const auto& func : registry_funcs_) {
      if (func(node)) {
        return true;
      }
    }
    return false;
  }

 private:
  std::vector<std::function<bool(const Node*)>> registry_funcs_;
  std::mutex mutex_;
};

} // namespace

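// Registers an additional predicate that marks nodes as requiring profiling.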
void RegisterProfilingNode(const std::function<bool(const Node*)>& func) {
  ProfileRegistry::getRegistry()->registerProfileNode(func);
}

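// Checks whether the concrete sizes in `new_sizes` are consistent with the
// symbolic shape `sym_shapes`, recording symbol bindings along the way.
// Returns false on a rank mismatch or when a previously bound symbol would
// have to take a different value.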
bool ShapeSymbolTable::bindSymbolicShapes(
    at::IntArrayRef new_sizes,
    const c10::SymbolicShape& sym_shapes) {
  if (!sym_shapes.rank().has_value()) {
    return true;
  }
  if (*sym_shapes.rank() != new_sizes.size()) {
    return false;
  }
  for (const auto i : c10::irange(new_sizes.size())) {
    auto symbol = (*sym_shapes.sizes())[i];
    if (!symbol.is_static()) {
      continue;
    }

    if (!isBound(symbol)) {
      assign(symbol, new_sizes[i]);
      continue;
    }

    if (getValue(symbol) != new_sizes[i]) {
      return false;
    }
  }
  return true;
}

ProfilingRecord::ProfilingRecord(std::shared_ptr<Graph> g)
    : profiled_graph_(std::move(g)), profiling_count_(getNumProfiledRuns()) {}

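// Creates a prim::profile node that runs the callback `fp` over the given
// inputs. The node is not yet inserted into the graph.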
ProfileOp* ProfilingRecord::createProfileNode(
    const std::function<void(Stack&)>& fp,
    at::ArrayRef<Value*> inputs) {
  auto pn = new ProfileOp(profiled_graph_.get(), fp);

  for (auto in : inputs) {
    pn->addInput(in);
  }
  return pn;
}

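// Creates a prim::profile_ivalue node that observes a single value and
// forwards it through an output of the same type; the callback is supplied
// later by the caller.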
ProfileIValueOp* ProfilingRecord::createProfileIValueNode(Value* in_val) {
  auto pn = new ProfileIValueOp(this->profiled_graph_.get(), nullptr);
  pn->addInput(in_val);
  auto pno = pn->addOutput();
  pno->setType(in_val->type());
  return pn;
}

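// Overload that observes several values at once; each input gets a matching
// output of the same type.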
ProfileIValueOp* ProfilingRecord::createProfileIValueNode(
    ArrayRef<Value*> inputs) {
  auto pn = new ProfileIValueOp(this->profiled_graph_.get(), nullptr);
  for (auto inp : inputs) {
    pn->addInput(inp);
    auto pno = pn->addOutput();
    pno->setType(inp->type());
  }
  return pn;
}

namespace {
bool isOptionalTensorType(const TypePtr& type) {
  if (type->kind() != c10::TypeKind::OptionalType) {
    return false;
  }
  const auto& kind = type->expectRef<OptionalType>().getElementType()->kind();
  return kind == c10::TypeKind::TensorType;
}
} // namespace

// Inserts profiling nodes.
//
// The prim::profile node profiles Tensor and Optional[Tensor].
//
// It stores two fields:
// 1. attr::seen_none, an integer, which is initially 0 and is set to 1 if the
// profiled value is ever `None`
// 2. attr::profiled_type, which is the most specific Tensor type that matches
// all the non-None inputs observed during profiling.
void ProfilingRecord::insertShapeProfile(
    Node* n,
    size_t offset,
    const TypePtr& input_type) {
  Value* i = n->input(offset);
  auto pn = createProfileNode(nullptr, {i});
  auto pno = pn->addOutput();
  pn->ty_(attr::profiled_type, TensorType::get());
  pn->i_(attr::seen_none, 0);
  if (isOptionalTensorType(input_type)) {
    pno->setType(OptionalType::create(TensorType::get()));
  } else if (input_type->kind() == c10::TypeKind::TensorType) {
    pno->setType(TensorType::get());
  } else {
    TORCH_INTERNAL_ASSERT(
        false,
        "Trying to profile an unsupported type (neither Tensor nor Optional[Tensor]): ",
        input_type->str());
  }
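  // Runtime callback: pops the frame id and the profiled value, merges the
  // observed tensor type into attr::profiled_type (and records seen_none for
  // None values) while profiling runs remain, then pushes the value back.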
  std::function<void(Stack&)> shape_profiler = [this, pn, pno](Stack& stack) {
    int64_t frame_id = 0;
    pop(stack, frame_id);
    IValue v;
    pop(stack, v);

    TensorTypePtr new_tensor_type = nullptr;
    if (v.isTensor()) {
      auto& t = v.toTensor();
      new_tensor_type = tensorTypeInCurrentExecutionContext(t);
    }

    if (v.isTensor() || v.isNone()) {
      std::lock_guard<std::mutex> lock(this->mutex_);
      if (profiling_count_ > 0) {
        GRAPH_DEBUG(
            "In run ",
            frame_id,
            " annotating %",
            pno->debugName(),
            " with ",
            *new_tensor_type);

        if (new_tensor_type != nullptr) {
          if (pn->hasSeenTensor()) {
            const auto& existing_tensor_type =
                pn->ty(attr::profiled_type)->expectRef<TensorType>();
            GRAPH_DEBUG(
                "Existing type for %",
                pno->debugName(),
                ": ",
                existing_tensor_type);
            auto merged_type = new_tensor_type->merge(existing_tensor_type);
            GRAPH_DEBUG(
                "Merged type for %", pno->debugName(), ": ", *merged_type);
            pn->ty_(attr::profiled_type, std::move(merged_type));
          } else {
            pn->setHasSeenTensor(true);
            pn->ty_(attr::profiled_type, std::move(new_tensor_type));
          }
        }
        if (v.isNone()) {
          pn->i_(attr::seen_none, 1);
        }
      }
    }
    // pass the profiled value back through on the stack
    push(stack, v);
  };

  pn->setCallback(shape_profiler);
  pn->insertBefore(n);
  n->replaceInput(offset, pn->output());
}

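// Returns true if a node's inputs should be profiled: nodes the TensorExpr
// fuser can handle, nodes consumed by specialize_autogradzero or peephole
// passes, and anything a registered predicate claims.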
static bool needsProfiledInputs(Node* n) {
  if (tensorexpr::isSupported(n)) {
    return true;
  }

  switch (n->kind()) {
    // specialize_autogradzero
    case prim::AutogradAdd:
    case prim::AutogradAnyNonZero:
    case prim::AutogradAllNonZero:
    case prim::AutogradAllZero:
    case prim::AutogradZero:
    // peephole
    case aten::dim:
    case aten::size:
    case aten::expand:
    case prim::dtype:
    case prim::device:
    case prim::is_cuda:
    case aten::is_floating_point:
    case aten::type_as:
    // TODO: hack to make `test_lstm_gates_permutations_cuda`
    // pass.
    case aten::t:
    case aten::mm:
      return true;
    default:
      return ProfileRegistry::getRegistry()->shouldProfileNode(n);
  }
}

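// Returns true if a node's outputs should be profiled at their uses.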
static bool needsProfiledOutput(Node* n) {
  if (tensorexpr::isSupported(n)) {
    return true;
  }

  switch (n->kind()) {
    case prim::AutogradAdd:
    case prim::AutogradZero:
      return true;
    default:
      return ProfileRegistry::getRegistry()->shouldProfileNode(n);
  }
}

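// Removes the input-less prim::profile node that decrements the remaining
// profiling-run counter (there is at most one such node, appended by
// instrumentGraph).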
void ProfilingRecord::removeProfileCounter(Block* b) {
  for (auto it = b->nodes().rbegin(); it != b->nodes().rend();) {
    auto n = *it;
    if (n->kind() == prim::profile && n->inputs().empty()) {
      it.destroyCurrent();
      // there is only one counter node
      return;
    } else {
      it++;
    }
  }
}

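// Recursively inserts shape-profiling nodes for the Tensor and
// Optional[Tensor] inputs of nodes that need profiling, and for block
// outputs.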
void ProfilingRecord::instrumentBlock(Block* block) {
  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
    auto n = *it;
    for (const auto offset : c10::irange(n->inputs().size())) {
      auto i = n->input(offset);
      if ((needsProfiledInputs(n) || needsProfiledOutput(i->node()))) {
        if (i->type()->kind() == c10::TypeKind::TensorType ||
            isOptionalTensorType(i->type())) {
          insertShapeProfile(n, offset, i->type());
        }
      }
    }

    for (auto b : n->blocks()) {
      instrumentBlock(b);
    }
  }

  // Inserting profile nodes on block outputs allows us to eliminate more
  // guards, since the use of a guard then sits in the same block as its
  // definition rather than being separated from it by block boundaries.
  for (size_t offset = 0; offset < block->return_node()->inputs().size();
       offset++) {
    auto i = block->return_node()->input(offset);
    if (i->type()->isSubtypeOf(*TensorType::get()) ||
        isOptionalTensorType(i->type())) {
      insertShapeProfile(block->return_node(), offset, i->type());
    }
  }
}

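// Strips prim::profile and prim::profile_ivalue nodes from a block (and its
// nested blocks), rewiring uses of their outputs to the original values.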
void ProfilingRecord::removeProfilingNodes(Block* b) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
    if (it->kind() == prim::profile || it->kind() == prim::profile_ivalue) {
      it->output()->replaceAllUsesWith(it->input());
      it.destroyCurrent();
    } else {
      for (Block* ib : it->blocks()) {
        removeProfilingNodes(ib);
      }
    }
  }
}

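// True once the requested number of profiling runs has completed.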
bool ProfilingRecord::ready() const {
  std::lock_guard<std::mutex> lock(this->mutex_);
  return profiling_count_ == 0;
}

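// Copies `graph`, clears any stale profiling information, instruments the
// copy with profiling nodes, and appends a counter node that decrements the
// number of remaining profiling runs on every execution.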
std::unique_ptr<ProfilingRecord> ProfilingRecord::instrumentGraph(
    const std::shared_ptr<Graph>& graph) {
  auto new_g = graph->copy();

  auto pr = std::unique_ptr<ProfilingRecord>(new ProfilingRecord(new_g));
  auto raw_pr = pr.get();
  unprofileGraphInputs(new_g);
  unprofileBlock(new_g->block());
  pr->instrumentBlock(new_g->block());

  std::function<void(Stack&)> counter = [raw_pr](Stack& stack) {
    int64_t frame_id = 0;
    pop(stack, frame_id);

    std::lock_guard<std::mutex> lock(raw_pr->mutex_);

    if (raw_pr->profiling_count_ > 0) {
      raw_pr->profiling_count_--;
    }
  };

  auto pop = pr->createProfileNode(counter, {});
  new_g->appendNode(pop);
  GRAPH_DUMP("Instrumented Graph: ", new_g);
  return pr;
}

} // namespace torch::jit