1
2 #include <torch/csrc/jit/passes/autocast.h>
3
4 #include <ATen/autocast_mode.h>
5 #include <c10/core/ScalarType.h>
6 #include <c10/util/Exception.h>
7 #include <torch/csrc/jit/ir/ir.h>
8 #include <torch/csrc/jit/jit_log.h>
9 #include <torch/csrc/jit/passes/quantization/helper.h>
10 #include <optional>
11
12 #include <stack>
13 #include <unordered_set>
14 #include <vector>
15
16 namespace torch::jit {
17
18 namespace {
19
// Process-wide on/off switch for the JIT autocast pass. Written by
// setAutocastMode() and read by autocastEnabled(); starts out enabled.
bool autocast_enabled{true};
21
22 struct AutocastContext {
23 bool gpu_enabled = false;
24 bool cpu_enabled = false;
25 c10::ScalarType gpu_scalar_type = c10::ScalarType::Undefined;
26 c10::ScalarType cpu_scalar_type = c10::ScalarType::Undefined;
27
operator booltorch::jit::__anon50c587710111::AutocastContext28 operator bool() const {
29 return gpu_enabled || cpu_enabled;
30 }
31 };
32
33 struct AutocastScope {
34 Value* instance = nullptr;
35 AutocastContext context;
stacktorch::jit::__anon50c587710111::AutocastScope36 void stack(const AutocastContext& parent_context) {}
37 };
38
isAutocastNode(Value * value)39 bool isAutocastNode(Value* value) {
40 const auto class_name = getModuleName(value);
41 return class_name.has_value() &&
42 (*class_name == "__torch__.torch.cuda.amp.autocast_mode.autocast" ||
43 *class_name == "__torch__.torch.cpu.amp.autocast_mode.autocast" ||
44 *class_name == "__torch__.torch.amp.autocast_mode.autocast");
45 }
46
47 // If we have an autocast instance, return it
48 //
49 // This is the pattern we're looking for (this is done after
50 // autocast.__init__() has been inlined)
51 //
52 // %4 : bool = prim::Constant[value=1]()
53 // %5 : __torch__.torch.cuda.amp.autocast_mode.autocast = prim::CreateObject()
54 // = prim::SetAttr[name="_enabled"](%5, %4)
55 //
56 // Notes:
57 // 1. There's no guarantee that the autocast instance is in the same block
58 // as the prim::Enter() node
59 // 2. `prim::SetAttr` must follow `prim::CreateObject()` in the same block,
60 // but there might be other nodes in between
61 //
parseAutocast(Value * value,const AutocastContext & context)62 std::optional<AutocastScope> parseAutocast(
63 Value* value,
64 const AutocastContext& context) {
65 if (!isAutocastNode(value)) {
66 // Not an autocast...
67 return std::nullopt;
68 }
69 if (value->node()->kind() == prim::CreateObject) {
70 AutocastScope scope;
71 scope.instance = value;
72 scope.context = context;
73 std::optional<bool> enabled;
74 std::string device;
75 c10::ScalarType dtype = c10::ScalarType::Undefined;
76 for (Use use : value->uses()) {
77 // TODO: support runtime flag
78 if (use.user->kind() == prim::SetAttr &&
79 use.user->s(attr::name) == "_enabled") {
80 // Search for `prim::SetAttr[name="_enabled"]`
81 enabled = constant_as<bool>(use.user->input(1));
82 TORCH_CHECK(
83 enabled.has_value(),
84 "Autocast _enabled argument must be a constant");
85 } else if (
86 use.user->kind() == prim::SetAttr &&
87 use.user->s(attr::name) == "device") {
88 // Search for `prim::SetAttr[name="device"]`
89 auto ret = constant_as<std::string>(use.user->input(1));
90 TORCH_CHECK(
91 ret.has_value(), "Autocast device argument must be a constant");
92 device = ret.value();
93 } else if (
94 use.user->kind() == prim::SetAttr &&
95 use.user->s(attr::name) == "fast_dtype") {
96 // Search for `prim::SetAttr[name="fast_dtype"]`
97 auto ret = constant_as<c10::ScalarType>(use.user->input(1));
98 if (ret.has_value()) {
99 dtype = ret.value();
100 }
101 }
102 }
103 TORCH_CHECK(enabled.has_value(), "Autocast missing _enabled attribute");
104 TORCH_CHECK(!device.empty(), "Autocast missing device attribute");
105 if (dtype == c10::ScalarType::Undefined) {
106 dtype = at::autocast::get_autocast_dtype(c10::Device(device).type());
107 }
108 TORCH_CHECK(
109 dtype != c10::ScalarType::Undefined,
110 "Autocast has invalid fast_dtype attribute");
111 if (device == "cuda" || device == "mps") {
112 scope.context.gpu_enabled = enabled.value();
113 scope.context.gpu_scalar_type = dtype;
114 } else if (device == "cpu") {
115 scope.context.cpu_enabled = enabled.value();
116 scope.context.cpu_scalar_type = dtype;
117 } else {
118 TORCH_INTERNAL_ASSERT(
119 false, "unrecognized device for autocast pass: ", device);
120 }
121 return scope;
122 } else {
123 // We only support simple and static autocast expressions. For example,
124 // the following should report an error (since the autocast would not
125 // work as expected)
126 //
127 // autocast_on = autocast(enabled=True)
128 // autocast_off = autocast(enabled=False)
129 // with autocast_on if condition else autocast_off:
130 // ...
131 //
132 // TODO: better error message
133 //
134 AT_ERROR("Unsupported autocast syntax");
135 }
136
137 return std::nullopt;
138 }
139
castTensorInputs(Node * node,Symbol cast_op,const AutocastContext & context)140 void castTensorInputs(
141 Node* node,
142 Symbol cast_op,
143 const AutocastContext& context) {
144 if (!context) {
145 return;
146 }
147
148 const auto graph = node->owningGraph();
149
150 std::unordered_set<Value*> casted_inputs;
151 // need to also keep the inputs in order, otherwise tracing fails
152 // sanity checks because casting ops are inserted in random order
153 std::vector<Value*> casted_inputs_ordered;
154 for (auto input : node->inputs()) {
155 // TODO: update cast_op signature to take dynamic context flags
156 auto input_tensor_type = input->type()->cast<TensorType>();
157 if (input_tensor_type && input->node()->kind() != cast_op) {
158 auto has_inserted = casted_inputs.insert(input);
159 if (has_inserted.second) {
160 casted_inputs_ordered.push_back(input);
161 }
162 }
163 }
164
165 WithInsertPoint insert_point(node);
166
167 for (auto input : casted_inputs_ordered) {
168 if (cast_op == aten::_autocast_to_full_precision) {
169 const auto new_input = graph->insert(
170 cast_op,
171 {input,
172 graph->insertConstant(IValue(context.gpu_enabled)),
173 graph->insertConstant(IValue(context.cpu_enabled))});
174 node->replaceInputWith(input, new_input);
175 } else if (cast_op == aten::_autocast_to_reduced_precision) {
176 const auto new_input = graph->insert(
177 cast_op,
178 {input,
179 graph->insertConstant(IValue(context.gpu_enabled)),
180 graph->insertConstant(IValue(context.cpu_enabled)),
181 graph->insertConstant(IValue(context.gpu_scalar_type)),
182 graph->insertConstant(IValue(context.cpu_scalar_type))});
183 node->replaceInputWith(input, new_input);
184 } else {
185 TORCH_INTERNAL_ASSERT(
186 false, "unrecognized cast_op symbol: ", cast_op.toQualString());
187 }
188 }
189 }
190
hasExplicitDtypeArgument(Node * node)191 bool hasExplicitDtypeArgument(Node* node) {
192 if (node->hasNamedInput("dtype")) {
193 Value* dtype_arg = node->namedInput("dtype");
194 return dtype_arg->type()->kind() != TypeKind::NoneType;
195 }
196 return false;
197 }
198
castInputsToWidestType(Node * node,const AutocastContext & context)199 void castInputsToWidestType(Node* node, const AutocastContext& context) {
200 if (!context) {
201 return;
202 }
203 // Figure out the widest type
204 // (really, just looking for any float32 inputs)
205 //
206 // TODO: revisit this (do we need to consider float64 types?)
207 //
208 for (auto input : node->inputs()) {
209 if (auto tensor_type = input->type()->cast<TensorType>()) {
210 const auto dtype = tensor_type->scalarType();
211 if (!dtype.has_value() || *dtype == at::ScalarType::Float) {
212 castTensorInputs(node, aten::_autocast_to_full_precision, context);
213 return;
214 }
215 }
216 }
217 }
218
219 // Users can call torch.is_autocast_enabled() or is_autocast_cpu_enabled() to
220 // determine whether autocasting is enabled. With JIT-scripted functions, we
221 // actually need to return true if eager autocast OR jit autocast are enabled.
222 //
223 // In the case where JIT autocast is enabled, we replace
224 // %x : bool = aten::is_autocast_enabled()
225 // with a constant "True".
226 //
227 // More context on eager vs JIT autocasting:
228 //
229 // Autocasting actually has two settings: eager autocasting, and JIT
230 // autocasting. Eager autocasting is the thread-local setting that turns on
231 // the relevant bit in the dispatcher settings. JIT autocasting is the pass
232 // implemented in this file, which makes changes to the graph to insert casting
233 // ops in order to achieve the same behavior as eager autocasting.
234 //
235 // If eager autocasting is enabled at the time when a JIT-scripted function is
236 // invoked, then autocasting will occur regardless of what the JIT-autocasting
237 // settings are.
updateAutocastEnabledCheck(Node * node,bool is_jit_enabled)238 void updateAutocastEnabledCheck(Node* node, bool is_jit_enabled) {
239 if (!is_jit_enabled) {
240 return;
241 }
242
243 auto graph = node->owningGraph();
244
245 WithInsertPoint insert_point(node);
246
247 Value* true_constant = graph->insertConstant(IValue(true));
248 node->output()->replaceAllUsesWith(true_constant);
249 node->destroy();
250 }
251
252 // [Note: implicit type promotion in Autocast]
253 //
254 // Casting policy below mostly follows pytorch/aten/src/ATen/autocast.cpp, with
255 // a few exceptions, e.g. `aten::add`, which is needed to be put to promotion
256 // list for JIT autocast.
257 // The reason is that in eager amp, some binary ops promote inputs implicitly
258 // inside the operation, e.g. `aten::add` with fp16 & fp32 inputs would both be
259 // casted to fp32. In backward, autograd would cast dgrad to match their
260 // scalar_type in forward graph. So inputs with mismatched scalar_type would
261 // get the different dgrad.
262 // While in JIT, autodiff doesn't do this, so implicit cast is not visible to
263 // autodiff and backward dgrad for mismatched inputs would ended up with dgrads
264 // in the same scalar_type. This has caused downstream operations, which
265 // expects dgrad to be the same scalar type to throw mismatch error.
266 //
267 // TODO: Use the list from AMP eager directly
handleBlock(Block * block,AutocastContext initial_state)268 void handleBlock(Block* block, AutocastContext initial_state) {
269 std::stack<AutocastScope> autocast_stack;
270
271 std::optional<bool> incompatible_amp = std::nullopt;
272
273 // The current autocast enabled/disabled state
274 auto current_state = [&] {
275 return autocast_stack.empty() ? initial_state
276 : autocast_stack.top().context;
277 };
278
279 for (Node* node : block->nodes()) {
280 switch (node->kind()) {
281 case prim::CallFunction:
282 // TODO: limit it only to amp related node;
283 if (current_state() == initial_state) {
284 // if the current autocasting state is the same as the global state,
285 // then autocasting will be done correctly on subsequent method and
286 // function calls
287 if (current_state()) {
288 castTensorInputs(
289 node, aten::_autocast_to_full_precision, current_state());
290 }
291 break;
292 }
293 TORCH_INTERNAL_ASSERT(
294 !incompatible_amp.has_value() || incompatible_amp.value(),
295 "Calls are not expected with AMP & JIT");
296 incompatible_amp = true;
297 break;
298
299 case prim::CallMethod:
300 // TODO: limit it only to amp related node;
301 if (current_state() == initial_state) {
302 // if the current autocasting state is the same as the global state,
303 // then autocasting will be done correctly on subsequent method and
304 // function calls
305 if (current_state()) {
306 castTensorInputs(
307 node, aten::_autocast_to_full_precision, current_state());
308 }
309 break;
310 }
311 if (auto class_type = node->input(0)->type()->cast<ClassType>()) {
312 const auto& name = node->s(attr::name);
313 const auto& function = class_type->getMethod(name);
314 if (!function.isGraphFunction()) {
315 TORCH_INTERNAL_ASSERT(
316 !incompatible_amp.has_value() || incompatible_amp.value(),
317 "Calls are not expected with AMP & JIT");
318 incompatible_amp = true;
319 }
320 } else {
321 TORCH_INTERNAL_ASSERT(
322 !incompatible_amp.has_value() || incompatible_amp.value(),
323 "Unexpected prim::CallMethod form with AMP & JIT");
324 incompatible_amp = true;
325 }
326 break;
327
328 case prim::Enter:
329 if (auto autocast_scope =
330 parseAutocast(node->input(), current_state())) {
331 if (node->hasUses()) {
332 // TODO: better error message
333 AT_ERROR("`with autocast() as ...` is not supported");
334 }
335 TORCH_INTERNAL_ASSERT(
336 !incompatible_amp.has_value() || !incompatible_amp.value(),
337 "Unsupported case by AMP & JIT");
338 incompatible_amp = false;
339 autocast_stack.push(*autocast_scope);
340 }
341 break;
342
343 case prim::Exit:
344 if (isAutocastNode(node->input(0))) {
345 TORCH_INTERNAL_ASSERT(!autocast_stack.empty());
346 TORCH_INTERNAL_ASSERT(autocast_stack.top().instance == node->input());
347 TORCH_INTERNAL_ASSERT(
348 !incompatible_amp.has_value() || !incompatible_amp.value(),
349 "Unsupported case by AMP & JIT");
350 incompatible_amp = false;
351 autocast_stack.pop();
352 }
353 break;
354
355 case aten::is_autocast_enabled:
356 updateAutocastEnabledCheck(node, current_state().gpu_enabled);
357 break;
358
359 case aten::is_autocast_cpu_enabled:
360 updateAutocastEnabledCheck(node, current_state().cpu_enabled);
361 break;
362
363 // CastPolicy::fp16 (cast all inputs to float16)
364 case aten::_convolution:
365 case aten::conv1d:
366 case aten::conv2d:
367 case aten::conv3d:
368 case aten::conv_tbc:
369 case aten::conv_transpose1d:
370 case aten::convolution:
371 case aten::cudnn_convolution:
372 case aten::cudnn_convolution_transpose:
373 case aten::prelu:
374 case aten::addmm:
375 case aten::addmv:
376 case aten::addr:
377 case aten::matmul:
378 case aten::mm:
379 case aten::mv:
380 case aten::linear:
381 case aten::addbmm:
382 case aten::baddbmm:
383 case aten::bmm:
384 case aten::chain_matmul:
385 case aten::_thnn_fused_lstm_cell:
386 case aten::_thnn_fused_gru_cell:
387 case aten::lstm_cell:
388 case aten::gru_cell:
389 case aten::rnn_tanh_cell:
390 case aten::rnn_relu_cell:
391 if (!node->schema().is_mutable()) {
392 castTensorInputs(
393 node, aten::_autocast_to_reduced_precision, current_state());
394 }
395 break;
396
397 // CastPolicy::fp32 (cast all inputs to float32)
398 case aten::native_layer_norm:
399 case aten::acos:
400 case aten::asin:
401 case aten::cosh:
402 case aten::erfinv:
403 case aten::exp:
404 case aten::expm1:
405 case aten::log:
406 case aten::log10:
407 case aten::log2:
408 case aten::log1p:
409 case aten::reciprocal:
410 case aten::rsqrt:
411 case aten::sinh:
412 case aten::tan:
413 case aten::pow:
414 case aten::softplus:
415 case aten::gelu:
416 case aten::layer_norm:
417 case aten::group_norm:
418 case aten::frobenius_norm:
419 case aten::nuclear_norm:
420 case aten::cosine_similarity:
421 case aten::cosine_embedding_loss:
422 case aten::nll_loss:
423 case aten::nll_loss2d:
424 case aten::hinge_embedding_loss:
425 case aten::kl_div:
426 case aten::l1_loss:
427 case aten::smooth_l1_loss:
428 case aten::mse_loss:
429 case aten::margin_ranking_loss:
430 case aten::multilabel_margin_loss:
431 case aten::soft_margin_loss:
432 case aten::triplet_margin_loss:
433 case aten::multi_margin_loss:
434 case aten::binary_cross_entropy_with_logits:
435 case aten::dist:
436 case aten::pdist:
437 case aten::cdist:
438 case aten::renorm:
439 case aten::logsumexp:
440 if (!node->schema().is_mutable()) {
441 castTensorInputs(
442 node, aten::_autocast_to_full_precision, current_state());
443 }
444 break;
445
446 // CastPolicy::fp32_set_opt_dtype
447 case aten::prod:
448 case aten::log_softmax:
449 case aten::cumprod:
450 case aten::cumsum:
451 case aten::sum:
452 if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
453 castTensorInputs(
454 node, aten::_autocast_to_full_precision, current_state());
455 }
456 break;
457
458 // cast softmax to fp32 only on GPU
459 case aten::softmax:
460 if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
461 auto context = current_state();
462 context.cpu_enabled = false;
463 castTensorInputs(node, aten::_autocast_to_full_precision, context);
464 }
465 break;
466
467 // CastPolicy::promote (promote inputs to the widest type)
468 case aten::addcdiv:
469 case aten::addcmul:
470 case aten::atan2:
471 case aten::bilinear:
472 case aten::cat:
473 case aten::cross:
474 case aten::dot:
475 case aten::equal:
476 case aten::index_put:
477 case aten::stack:
478 case aten::tensordot:
479 // add, sub, mul, div were added to autocast jit, because aten implicit
480 // type promotion is not visible to JIT and could cause dtype mismatch on
481 // backward
482 // see [Note: implicit type promotion in Autocast]
483 case aten::add:
484 case aten::sub:
485 case aten::mul:
486 case aten::div:
487 if (!node->schema().is_mutable()) {
488 castInputsToWidestType(node, current_state());
489 }
490 break;
491
492 // Banned in autocast, see binary_cross_entropy_banned()
493 case aten::binary_cross_entropy:
494 if (current_state()) {
495 AT_ERROR("Unsafe to autocast");
496 }
497 }
498
499 // process sub-blocks, if any
500 for (Block* sub_block : node->blocks()) {
501 handleBlock(sub_block, current_state());
502 }
503 }
504
505 // Sanity check: make sure there's no unbalanced transition
506 TORCH_INTERNAL_ASSERT(autocast_stack.empty());
507 }
508
509 } // namespace
510
setAutocastMode(bool value)511 bool setAutocastMode(bool value) {
512 auto old_value = autocast_enabled;
513 autocast_enabled = value;
514 return old_value;
515 }
516
autocastEnabled()517 bool autocastEnabled() {
518 return autocast_enabled;
519 }
520
Autocast(const std::shared_ptr<Graph> & graph)521 void Autocast(const std::shared_ptr<Graph>& graph) {
522 GRAPH_DUMP("\nBefore Autocast: ", graph);
523 if (autocastEnabled()) {
524 AutocastContext init = {
525 at::autocast::is_autocast_enabled(at::kCUDA),
526 at::autocast::is_autocast_enabled(at::kCPU),
527 at::autocast::get_autocast_dtype(at::kCUDA),
528 at::autocast::get_autocast_dtype(at::kCPU)};
529 handleBlock(graph->block(), init);
530 }
531 GRAPH_DUMP("\nAfter Autocast: ", graph);
532 }
533
534 } // namespace torch::jit
535