#include <c10/util/irange.h>
#include <torch/nn/modules/_functions.h>

using namespace torch::autograd;

namespace torch {
namespace nn {
namespace functions {

forward(AutogradContext * ctx,const Variable & input,const CrossMapLRN2dOptions & options)10 Variable CrossMapLRN2d::forward(
11 AutogradContext* ctx,
12 const Variable& input,
13 const CrossMapLRN2dOptions& options) {
14 ctx->saved_data["size"] = options.size();
15 ctx->saved_data["alpha"] = options.alpha();
16 ctx->saved_data["beta"] = options.beta();
17 ctx->saved_data["k"] = options.k();
18 ctx->saved_data["scale"] = torch::Tensor();
19
20 TORCH_CHECK(input.dim() == 4);
21
22 ctx->saved_data["scale"] = ctx->saved_data["scale"].toTensor().defined()
23 ? ctx->saved_data["scale"]
24 : torch::empty({0}, input.options());
25
26 torch::Tensor output = torch::empty({0}, input.options());
27
28 int64_t channels = input.size(1);
29
30 output.resize_as_(input);
31 ctx->saved_data["scale"].toTensor().resize_as_(input);
32
33 /// use output storage as temporary buffer
34 auto input_square = output;
35 torch::pow_out(input_square, input, 2);
36
37 int64_t pre_pad =
38 static_cast<int64_t>((ctx->saved_data["size"].toInt() - 1) / 2 + 1);
39 int64_t pre_pad_crop = pre_pad > channels ? channels : pre_pad;
40
41 auto scale_first = ctx->saved_data["scale"].toTensor().select(1, 0);
42 scale_first.zero_();
43
44 /// compute first feature map normalization
45 for (const auto c : c10::irange(pre_pad_crop)) {
46 scale_first.add_(input_square.select(1, c));
47 }
48
49 /// reuse computations for next feature maps normalization
50 /// by adding the next feature map and removing the previous
51 torch::Tensor scale_previous, scale_current, square_next, square_previous;
52
53 for (const auto c : c10::irange(1, channels)) {
54 scale_previous = ctx->saved_data["scale"].toTensor().select(1, c - 1);
55 scale_current = ctx->saved_data["scale"].toTensor().select(1, c);
56 scale_current.copy_(scale_previous);
57
58 if (c < channels - pre_pad + 1) {
59 square_next = input_square.select(1, c + pre_pad - 1);
60 scale_current.add_(square_next, 1);
61 }
62
63 if (c > pre_pad) {
64 square_previous = input_square.select(1, c - pre_pad);
65 scale_current.add_(square_previous, -1);
66 }
67 }
68
69 ctx->saved_data["scale"]
70 .toTensor()
71 .mul_(
72 ctx->saved_data["alpha"].toDouble() / ctx->saved_data["size"].toInt())
73 .add_(ctx->saved_data["k"].toInt());
74
75 torch::pow_out(
76 output,
77 ctx->saved_data["scale"].toTensor(),
78 -ctx->saved_data["beta"].toDouble());
79 output.mul_(input);
80
81 ctx->save_for_backward({input, output});
82 return output;
83 }
84
backward(AutogradContext * ctx,variable_list grad_outputs)85 variable_list CrossMapLRN2d::backward(
86 AutogradContext* ctx,
87 variable_list grad_outputs) {
88 auto grad_output = grad_outputs[0];
89 auto input = ctx->get_saved_variables()[0];
90 auto output = ctx->get_saved_variables()[1];
91 auto grad_input = torch::empty({0}, grad_output.options());
92
93 int64_t batch_size = input.size(0);
94 int64_t channels = input.size(1);
95 int64_t input_height = input.size(2);
96 int64_t input_width = input.size(3);
97
98 auto padded_ratio = torch::empty(
99 {channels + ctx->saved_data["size"].toInt() - 1,
100 input_height,
101 input_width},
102 input.options());
103 auto accum_ratio = torch::empty({input_height, input_width}, input.options());
104 double cache_ratio_value = 2 * ctx->saved_data["alpha"].toDouble() *
105 ctx->saved_data["beta"].toDouble() / ctx->saved_data["size"].toInt();
106 int64_t inversePrePad = static_cast<int64_t>(
107 ctx->saved_data["size"].toInt() -
108 (ctx->saved_data["size"].toInt() - 1) / 2);
109
110 grad_input.resize_as_(input);
111 torch::pow_out(
112 grad_input,
113 ctx->saved_data["scale"].toTensor(),
114 -ctx->saved_data["beta"].toDouble())
115 .mul_(grad_output);
116
117 padded_ratio.zero_();
118 auto padded_ratio_center = padded_ratio.narrow(0, inversePrePad, channels);
119
120 for (const auto n : c10::irange(batch_size)) {
121 torch::mul_out(padded_ratio_center, grad_output[n], output[n]);
122 padded_ratio_center.div_(ctx->saved_data["scale"].toTensor()[n]);
123 torch::sum_out(
124 accum_ratio,
125 padded_ratio.narrow(0, 0, ctx->saved_data["size"].toInt() - 1),
126 0,
127 /*keepdim=*/false);
128 for (const auto c : c10::irange(channels)) {
129 accum_ratio.add_(padded_ratio[c + ctx->saved_data["size"].toInt() - 1]);
130 grad_input[n][c].addcmul_(input[n][c], accum_ratio, -cache_ratio_value);
131 accum_ratio.add_(padded_ratio[c], -1);
132 }
133 }
134
135 return variable_list{
136 grad_input, Variable(), Variable(), Variable(), Variable()};
137 }

} // namespace functions
} // namespace nn
} // namespace torch