xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/_functions.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/autograd/custom_function.h>
4 #include <torch/csrc/autograd/variable.h>
5 #include <torch/nn/options/normalization.h>
6 #include <torch/types.h>
7 
8 namespace torch {
9 namespace nn {
10 namespace functions {
11 
12 class CrossMapLRN2d : public torch::autograd::Function<CrossMapLRN2d> {
13  public:
14   static torch::autograd::Variable forward(
15       torch::autograd::AutogradContext* ctx,
16       const torch::autograd::Variable& input,
17       const CrossMapLRN2dOptions& options);
18 
19   static torch::autograd::variable_list backward(
20       torch::autograd::AutogradContext* ctx,
21       torch::autograd::variable_list grad_output);
22 };
23 
24 } // namespace functions
25 } // namespace nn
26 } // namespace torch
27