# Owner(s): ["module: onnx"]

import torch


# Autograd function that replicates the one in test_utility_funs.py
# (test_autograd_module_name). It computes ReLU with a hand-written
# backward pass.
class CustomFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Keep the input so backward() can tell which elements were negative.
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Pass the upstream gradient through, zeroing it wherever the
        # forward input was negative (those outputs were clamped to 0).
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input