xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/options/vision.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/enum.h>
6 #include <torch/types.h>
7 
8 namespace torch {
9 namespace nn {
10 namespace functional {
11 
12 /// Options for `torch::nn::functional::grid_sample`.
13 ///
14 /// Example:
15 /// ```
16 /// namespace F = torch::nn::functional;
17 /// F::grid_sample(input, grid,
18 /// F::GridSampleFuncOptions().mode(torch::kBilinear).padding_mode(torch::kZeros).align_corners(true));
19 /// ```
20 struct TORCH_API GridSampleFuncOptions {
21   typedef std::variant<enumtype::kBilinear, enumtype::kNearest> mode_t;
22   typedef std::
23       variant<enumtype::kZeros, enumtype::kBorder, enumtype::kReflection>
24           padding_mode_t;
25 
26   /// interpolation mode to calculate output values. Default: Bilinear
27   TORCH_ARG(mode_t, mode) = torch::kBilinear;
28   /// padding mode for outside grid values. Default: Zeros
29   TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros;
30   /// Specifies perspective to pixel as point. Default: false
31   TORCH_ARG(std::optional<bool>, align_corners) = std::nullopt;
32 };
33 
34 } // namespace functional
35 } // namespace nn
36 } // namespace torch
37