xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/lower_grad_of.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 
5 namespace torch::jit {
6 
7 // Graph pass: removes 'grad_of' nodes from `g` (the graph is taken by non-const
8 // reference and nothing is returned), replacing them with conditionals of the form (pseudocode):
9 // if any_defined(inputs):
10 //  outputs = <original_computation>
11 // else:
12 //  outputs = undefineds   (the original computation is skipped on this branch)
13 TORCH_API void LowerGradOf(Graph& g);
14 
15 } // namespace torch::jit
16