xref: /aosp_15_r20/external/pytorch/test/onnx/autograd_helper.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: onnx"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport torch
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Worker
# Autograd function that is a replica of the autograd function in
# test_utility_funs.py (test_autograd_module_name)
class CustomFunction(torch.autograd.Function):
    """Custom autograd function that clamps its input to be non-negative."""

    @staticmethod
    def forward(ctx, input):
        """Return ``input`` with every element below zero replaced by zero.

        ``input`` is saved on ``ctx`` so that ``backward`` can tell which
        elements were negative.
        """
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` zeroed wherever the saved input was negative."""
        (saved_input,) = ctx.saved_tensors
        # masked_fill is out-of-place, so grad_output itself is never mutated.
        return grad_output.masked_fill(saved_input < 0, 0)
20