#include <torch/csrc/jit/passes/autocast.h>

#include <ATen/autocast_mode.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <optional>

#include <stack>
#include <unordered_set>
#include <vector>
namespace torch::jit {

namespace {

bool autocast_enabled = true;

struct AutocastContext {
  bool gpu_enabled = false;
  bool cpu_enabled = false;
  c10::ScalarType gpu_scalar_type = c10::ScalarType::Undefined;
  c10::ScalarType cpu_scalar_type = c10::ScalarType::Undefined;

  operator bool() const {
    return gpu_enabled || cpu_enabled;
  }
};

struct AutocastScope {
  Value* instance = nullptr;
  AutocastContext context;
  void stack(const AutocastContext& parent_context) {}
};

bool isAutocastNode(Value* value) {
  const auto class_name = getModuleName(value);
  return class_name.has_value() &&
      (*class_name == "__torch__.torch.cuda.amp.autocast_mode.autocast" ||
       *class_name == "__torch__.torch.cpu.amp.autocast_mode.autocast" ||
       *class_name == "__torch__.torch.amp.autocast_mode.autocast");
}

// If we have an autocast instance, parse it and return the autocast scope
//
// This is the pattern we're looking for (this is done after
//  autocast.__init__() has been inlined)
//
// %4 : bool = prim::Constant[value=1]()
// %5 : __torch__.torch.cuda.amp.autocast_mode.autocast = prim::CreateObject()
//  = prim::SetAttr[name="_enabled"](%5, %4)
//
// Notes:
//  1. There's no guarantee that the autocast instance is in the same block
//    as the prim::Enter() node
//  2. `prim::SetAttr` must follow `prim::CreateObject()` in the same block,
//    but there might be other nodes in between
//
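// For reference, a TorchScript source that lowers to a pattern like the one
// above might look as follows (illustrative sketch, not taken from a real
// trace):
//
//    @torch.jit.script
//    def fn(x, y):
//        with torch.cuda.amp.autocast(enabled=True):
//            return torch.mm(x, y)
//
// After autocast.__init__() is inlined, the `with` statement shows up as a
// prim::Enter/prim::Exit pair guarding the body (see handleBlock below).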
std::optional<AutocastScope> parseAutocast(
    Value* value,
    const AutocastContext& context) {
  if (!isAutocastNode(value)) {
    // Not an autocast...
    return std::nullopt;
  }
  if (value->node()->kind() == prim::CreateObject) {
    AutocastScope scope;
    scope.instance = value;
    scope.context = context;
    std::optional<bool> enabled;
    std::string device;
    c10::ScalarType dtype = c10::ScalarType::Undefined;
    for (Use use : value->uses()) {
      // TODO: support runtime flag
      if (use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "_enabled") {
        // Search for `prim::SetAttr[name="_enabled"]`
        enabled = constant_as<bool>(use.user->input(1));
        TORCH_CHECK(
            enabled.has_value(),
            "Autocast _enabled argument must be a constant");
      } else if (
          use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "device") {
        // Search for `prim::SetAttr[name="device"]`
        auto ret = constant_as<std::string>(use.user->input(1));
        TORCH_CHECK(
            ret.has_value(), "Autocast device argument must be a constant");
        device = ret.value();
      } else if (
          use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "fast_dtype") {
        // Search for `prim::SetAttr[name="fast_dtype"]`
        auto ret = constant_as<c10::ScalarType>(use.user->input(1));
        if (ret.has_value()) {
          dtype = ret.value();
        }
      }
    }
    TORCH_CHECK(enabled.has_value(), "Autocast missing _enabled attribute");
    TORCH_CHECK(!device.empty(), "Autocast missing device attribute");
    if (dtype == c10::ScalarType::Undefined) {
      dtype = at::autocast::get_autocast_dtype(c10::Device(device).type());
    }
    TORCH_CHECK(
        dtype != c10::ScalarType::Undefined,
        "Autocast has invalid fast_dtype attribute");
    if (device == "cuda" || device == "mps") {
      scope.context.gpu_enabled = enabled.value();
      scope.context.gpu_scalar_type = dtype;
    } else if (device == "cpu") {
      scope.context.cpu_enabled = enabled.value();
      scope.context.cpu_scalar_type = dtype;
    } else {
      TORCH_INTERNAL_ASSERT(
          false, "unrecognized device for autocast pass: ", device);
    }
    return scope;
  } else {
    // We only support simple and static autocast expressions. For example,
    // the following should report an error (since the autocast would not
    // work as expected)
    //
    //    autocast_on = autocast(enabled=True)
    //    autocast_off = autocast(enabled=False)
    //    with autocast_on if condition else autocast_off:
    //        ...
    //
    // TODO: better error message
    //
    AT_ERROR("Unsupported autocast syntax");
  }

  return std::nullopt;
}

void castTensorInputs(
    Node* node,
    Symbol cast_op,
    const AutocastContext& context) {
  if (!context) {
    return;
  }

  const auto graph = node->owningGraph();

  std::unordered_set<Value*> casted_inputs;
  // We also need to keep the inputs in order, otherwise tracing fails its
  // sanity checks because the casting ops are inserted in random order.
  std::vector<Value*> casted_inputs_ordered;
  for (auto input : node->inputs()) {
    // TODO: update cast_op signature to take dynamic context flags
    auto input_tensor_type = input->type()->cast<TensorType>();
    if (input_tensor_type && input->node()->kind() != cast_op) {
      auto has_inserted = casted_inputs.insert(input);
      if (has_inserted.second) {
        casted_inputs_ordered.push_back(input);
      }
    }
  }

  WithInsertPoint insert_point(node);

  for (auto input : casted_inputs_ordered) {
    if (cast_op == aten::_autocast_to_full_precision) {
      const auto new_input = graph->insert(
          cast_op,
          {input,
           graph->insertConstant(IValue(context.gpu_enabled)),
           graph->insertConstant(IValue(context.cpu_enabled))});
      node->replaceInputWith(input, new_input);
    } else if (cast_op == aten::_autocast_to_reduced_precision) {
      const auto new_input = graph->insert(
          cast_op,
          {input,
           graph->insertConstant(IValue(context.gpu_enabled)),
           graph->insertConstant(IValue(context.cpu_enabled)),
           graph->insertConstant(IValue(context.gpu_scalar_type)),
           graph->insertConstant(IValue(context.cpu_scalar_type))});
      node->replaceInputWith(input, new_input);
    } else {
      TORCH_INTERNAL_ASSERT(
          false, "unrecognized cast_op symbol: ", cast_op.toQualString());
    }
  }
}
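
// Illustrative effect of castTensorInputs (a sketch, not IR captured from an
// actual run): given
//
//    %out : Tensor = aten::mm(%a, %b)
//
// casting with aten::_autocast_to_reduced_precision rewrites the inputs to
//
//    %a.cast : Tensor = aten::_autocast_to_reduced_precision(
//        %a, %gpu_enabled, %cpu_enabled, %gpu_dtype, %cpu_dtype)
//    %b.cast : Tensor = aten::_autocast_to_reduced_precision(
//        %b, %gpu_enabled, %cpu_enabled, %gpu_dtype, %cpu_dtype)
//    %out : Tensor = aten::mm(%a.cast, %b.cast)
//
// with the enabled flags and scalar types inserted as constants taken from
// the current AutocastContext.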

bool hasExplicitDtypeArgument(Node* node) {
  if (node->hasNamedInput("dtype")) {
    Value* dtype_arg = node->namedInput("dtype");
    return dtype_arg->type()->kind() != TypeKind::NoneType;
  }
  return false;
}
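
// Example: `x.sum(dtype=torch.float32)` carries an explicit dtype argument,
// so the fp32_set_opt_dtype handling below leaves the node alone; a plain
// `x.sum()` has dtype=None and remains eligible for casting. (Python calls
// shown for illustration only.)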

void castInputsToWidestType(Node* node, const AutocastContext& context) {
  if (!context) {
    return;
  }
  // Figure out the widest type
  // (really, just looking for any float32 inputs)
  //
  // TODO: revisit this (do we need to consider float64 types?)
  //
  for (auto input : node->inputs()) {
    if (auto tensor_type = input->type()->cast<TensorType>()) {
      const auto dtype = tensor_type->scalarType();
      if (!dtype.has_value() || *dtype == at::ScalarType::Float) {
        castTensorInputs(node, aten::_autocast_to_full_precision, context);
        return;
      }
    }
  }
}
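
// Example (illustrative): for `aten::add(%x_fp16, %y_fp32)` the fp32 input
// makes castInputsToWidestType route every tensor input through
// aten::_autocast_to_full_precision, so the op sees matching fp32 dtypes.
// If no input is fp32 (or of unknown dtype), nothing is inserted.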

// Users can call torch.is_autocast_enabled() or is_autocast_cpu_enabled() to
// determine whether autocasting is enabled. With JIT-scripted functions, we
// actually need to return true if either eager autocast or JIT autocast is
// enabled.
//
// In the case where JIT autocast is enabled, we replace
//    %x : bool = aten::is_autocast_enabled()
// with a constant "True".
//
// More context on eager vs JIT autocasting:
//
// Autocasting actually has two settings: eager autocasting, and JIT
// autocasting. Eager autocasting is the thread-local setting that turns on
// the relevant bit in the dispatcher settings. JIT autocasting is the pass
// implemented in this file, which makes changes to the graph to insert casting
// ops in order to achieve the same behavior as eager autocasting.
//
// If eager autocasting is enabled at the time when a JIT-scripted function is
// invoked, then autocasting will occur regardless of what the JIT-autocasting
// settings are.
void updateAutocastEnabledCheck(Node* node, bool is_jit_enabled) {
  if (!is_jit_enabled) {
    return;
  }

  auto graph = node->owningGraph();

  WithInsertPoint insert_point(node);

  Value* true_constant = graph->insertConstant(IValue(true));
  node->output()->replaceAllUsesWith(true_constant);
  node->destroy();
}
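
// Illustrative rewrite (sketch): when the JIT autocast state is enabled,
//
//    %x : bool = aten::is_autocast_enabled()
//
// becomes a reference to a constant `true` and the original node is
// destroyed, so downstream consumers see the check as statically true.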

// [Note: implicit type promotion in Autocast]
//
// The casting policy below mostly follows pytorch/aten/src/ATen/autocast.cpp,
// with a few exceptions, e.g. `aten::add`, which needs to be put on the
// promotion list for JIT autocast.
// The reason is that in eager amp, some binary ops promote their inputs
// implicitly inside the operation, e.g. `aten::add` with fp16 & fp32 inputs
// casts both to fp32. In backward, autograd casts each dgrad to match the
// scalar_type of the corresponding input in the forward graph, so inputs with
// mismatched scalar_types get dgrads with different scalar_types.
// In JIT, autodiff doesn't do this: the implicit cast is not visible to
// autodiff, so the dgrads for mismatched inputs would end up with the same
// scalar_type. This has caused downstream operations, which expect dgrad to
// have a matching scalar_type, to throw mismatch errors.
//
// TODO: Use the list from AMP eager directly
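//
// Concrete example for the note above (illustrative): in eager mode,
// `aten::add` with fp16 & fp32 inputs computes in fp32, and autograd later
// casts dgrad(x_fp16) back to fp16 while dgrad(y_fp32) stays fp32. JIT
// autodiff does not see that implicit promotion, so without the explicit
// cast inserted by this pass both dgrads would share one scalar_type and
// downstream ops expecting matching dtypes would fail.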
void handleBlock(Block* block, AutocastContext initial_state) {
  std::stack<AutocastScope> autocast_stack;

  std::optional<bool> incompatible_amp = std::nullopt;

  // The current autocast enabled/disabled state
  auto current_state = [&] {
    return autocast_stack.empty() ? initial_state
                                  : autocast_stack.top().context;
  };

  for (Node* node : block->nodes()) {
    switch (node->kind()) {
      case prim::CallFunction:
        // TODO: limit it only to amp related node;
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
        TORCH_INTERNAL_ASSERT(
            !incompatible_amp.has_value() || incompatible_amp.value(),
            "Calls are not expected with AMP & JIT");
        incompatible_amp = true;
        break;

      case prim::CallMethod:
        // TODO: limit it only to amp related node;
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
        if (auto class_type = node->input(0)->type()->cast<ClassType>()) {
          const auto& name = node->s(attr::name);
          const auto& function = class_type->getMethod(name);
          if (!function.isGraphFunction()) {
            TORCH_INTERNAL_ASSERT(
                !incompatible_amp.has_value() || incompatible_amp.value(),
                "Calls are not expected with AMP & JIT");
            incompatible_amp = true;
          }
        } else {
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || incompatible_amp.value(),
              "Unexpected prim::CallMethod form with AMP & JIT");
          incompatible_amp = true;
        }
        break;

      case prim::Enter:
        if (auto autocast_scope =
                parseAutocast(node->input(), current_state())) {
          if (node->hasUses()) {
            // TODO: better error message
            AT_ERROR("`with autocast() as ...` is not supported");
          }
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || !incompatible_amp.value(),
              "Unsupported case by AMP & JIT");
          incompatible_amp = false;
          autocast_stack.push(*autocast_scope);
        }
        break;

      case prim::Exit:
        if (isAutocastNode(node->input(0))) {
          TORCH_INTERNAL_ASSERT(!autocast_stack.empty());
          TORCH_INTERNAL_ASSERT(autocast_stack.top().instance == node->input());
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || !incompatible_amp.value(),
              "Unsupported case by AMP & JIT");
          incompatible_amp = false;
          autocast_stack.pop();
        }
        break;

      case aten::is_autocast_enabled:
        updateAutocastEnabledCheck(node, current_state().gpu_enabled);
        break;

      case aten::is_autocast_cpu_enabled:
        updateAutocastEnabledCheck(node, current_state().cpu_enabled);
        break;

      // CastPolicy::fp16 (cast all inputs to float16)
      case aten::_convolution:
      case aten::conv1d:
      case aten::conv2d:
      case aten::conv3d:
      case aten::conv_tbc:
      case aten::conv_transpose1d:
      case aten::convolution:
      case aten::cudnn_convolution:
      case aten::cudnn_convolution_transpose:
      case aten::prelu:
      case aten::addmm:
      case aten::addmv:
      case aten::addr:
      case aten::matmul:
      case aten::mm:
      case aten::mv:
      case aten::linear:
      case aten::addbmm:
      case aten::baddbmm:
      case aten::bmm:
      case aten::chain_matmul:
      case aten::_thnn_fused_lstm_cell:
      case aten::_thnn_fused_gru_cell:
      case aten::lstm_cell:
      case aten::gru_cell:
      case aten::rnn_tanh_cell:
      case aten::rnn_relu_cell:
        if (!node->schema().is_mutable()) {
          castTensorInputs(
              node, aten::_autocast_to_reduced_precision, current_state());
        }
        break;

      // CastPolicy::fp32 (cast all inputs to float32)
      case aten::native_layer_norm:
      case aten::acos:
      case aten::asin:
      case aten::cosh:
      case aten::erfinv:
      case aten::exp:
      case aten::expm1:
      case aten::log:
      case aten::log10:
      case aten::log2:
      case aten::log1p:
      case aten::reciprocal:
      case aten::rsqrt:
      case aten::sinh:
      case aten::tan:
      case aten::pow:
      case aten::softplus:
      case aten::gelu:
      case aten::layer_norm:
      case aten::group_norm:
      case aten::frobenius_norm:
      case aten::nuclear_norm:
      case aten::cosine_similarity:
      case aten::cosine_embedding_loss:
      case aten::nll_loss:
      case aten::nll_loss2d:
      case aten::hinge_embedding_loss:
      case aten::kl_div:
      case aten::l1_loss:
      case aten::smooth_l1_loss:
      case aten::mse_loss:
      case aten::margin_ranking_loss:
      case aten::multilabel_margin_loss:
      case aten::soft_margin_loss:
      case aten::triplet_margin_loss:
      case aten::multi_margin_loss:
      case aten::binary_cross_entropy_with_logits:
      case aten::dist:
      case aten::pdist:
      case aten::cdist:
      case aten::renorm:
      case aten::logsumexp:
        if (!node->schema().is_mutable()) {
          castTensorInputs(
              node, aten::_autocast_to_full_precision, current_state());
        }
        break;

      // CastPolicy::fp32_set_opt_dtype
      case aten::prod:
      case aten::log_softmax:
      case aten::cumprod:
      case aten::cumsum:
      case aten::sum:
        if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
          castTensorInputs(
              node, aten::_autocast_to_full_precision, current_state());
        }
        break;

      // cast softmax to fp32 only on GPU
      case aten::softmax:
        if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
          auto context = current_state();
          context.cpu_enabled = false;
          castTensorInputs(node, aten::_autocast_to_full_precision, context);
        }
        break;

      // CastPolicy::promote (promote inputs to the widest type)
      case aten::addcdiv:
      case aten::addcmul:
      case aten::atan2:
      case aten::bilinear:
      case aten::cat:
      case aten::cross:
      case aten::dot:
      case aten::equal:
      case aten::index_put:
      case aten::stack:
      case aten::tensordot:
      // add, sub, mul, div were added to autocast jit, because aten implicit
      // type promotion is not visible to JIT and could cause dtype mismatch on
      // backward
      // see [Note: implicit type promotion in Autocast]
      case aten::add:
      case aten::sub:
      case aten::mul:
      case aten::div:
        if (!node->schema().is_mutable()) {
          castInputsToWidestType(node, current_state());
        }
        break;

      // Banned in autocast, see binary_cross_entropy_banned()
      case aten::binary_cross_entropy:
        if (current_state()) {
          AT_ERROR("Unsafe to autocast");
        }
    }

    // process sub-blocks, if any
    for (Block* sub_block : node->blocks()) {
      handleBlock(sub_block, current_state());
    }
  }

  // Sanity check: make sure there's no unbalanced transition
  TORCH_INTERNAL_ASSERT(autocast_stack.empty());
}

} // namespace

bool setAutocastMode(bool value) {
  auto old_value = autocast_enabled;
  autocast_enabled = value;
  return old_value;
}
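
// A minimal usage sketch (assuming setAutocastMode is exposed to Python as
// torch._C._jit_set_autocast_mode, which forwards to this function):
//
//    old = torch._C._jit_set_autocast_mode(False)  # disable the JIT pass
//    ...  # script / run code without JIT autocast graph rewrites
//    torch._C._jit_set_autocast_mode(old)          # restore previous mode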

bool autocastEnabled() {
  return autocast_enabled;
}

void Autocast(const std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("\nBefore Autocast: ", graph);
  if (autocastEnabled()) {
    AutocastContext init = {
        at::autocast::is_autocast_enabled(at::kCUDA),
        at::autocast::is_autocast_enabled(at::kCPU),
        at::autocast::get_autocast_dtype(at::kCUDA),
        at::autocast::get_autocast_dtype(at::kCPU)};
    handleBlock(graph->block(), init);
  }
  GRAPH_DUMP("\nAfter Autocast: ", graph);
}

} // namespace torch::jit