xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/GridSamplerUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // See NOTE: [Tensor vs. TensorBase]
4 // https://github.com/pytorch/pytorch/pull/66979
5 #include <ATen/core/TensorBase.h>
6 #include <ATen/native/TensorProperties.h>
7 #include <ATen/native/CanUse32BitIndexMath.h>
8 
9 namespace at::native {
10 
namespace detail {

// Interpolation modes understood by grid_sampler.
// The numeric values are part of the contract: the mode is passed around as a
// plain int64_t and static_cast to this enum (see check_grid_sampler_3d), so
// the enumerator values are spelled out instead of being left implicit.
enum class GridSamplerInterpolation {
  Bilinear = 0,
  Nearest = 1,
  Bicubic = 2,
};

// Padding modes understood by grid_sampler.
// NOTE(review): presumably also converted from a plain integer by callers
// outside this header — confirm; values are pinned explicitly for the same
// reason as GridSamplerInterpolation.
enum class GridSamplerPadding {
  Zeros = 0,
  Border = 1,
  Reflection = 2,
};

} // namespace detail

using detail::GridSamplerInterpolation;
using detail::GridSamplerPadding;
20 
21 // See NOTE [ grid_sampler Native Functions ].
check_grid_sampler_common(const TensorBase & input,const TensorBase & grid)22 inline void check_grid_sampler_common(
23   const TensorBase& input,
24   const TensorBase& grid
25 ) {
26   auto input_opt = input.options();
27   auto grid_opt = grid.options();
28 
29   TORCH_CHECK(
30     input.defined(),
31     "grid_sampler(): expected input to not be undefined");
32   TORCH_CHECK(
33     grid.defined(),
34     "grid_sampler(): expected grid to not be undefined");
35   TORCH_CHECK(
36     input_opt.device() == grid_opt.device(),
37     "grid_sampler(): expected input and grid to be on same device, but input "
38     "is on ", input_opt.device(), " and grid is on ", grid_opt.device());
39   TORCH_CHECK(
40     input_opt.layout() == kStrided && grid_opt.layout() == kStrided,
41     "grid_sampler(): expected input and grid to have torch.strided layout, but "
42     "input has ", input_opt.layout(), " and grid has ", grid_opt.layout());
43   TORCH_CHECK(
44     input.size(0) == grid.size(0),
45     "grid_sampler(): expected grid and input to have same batch size, but got "
46     "input with sizes ", input.sizes(), " and grid with sizes ", grid.sizes());
47   TORCH_CHECK(
48     grid.size(-1) == input.dim() - 2,
49     "grid_sampler(): expected grid to have size ", input.dim() - 2, " in last "
50     "dimension, but got grid with sizes ", grid.sizes());
51 
52   for (const auto i : c10::irange(2, input.dim())) {
53     TORCH_CHECK(input.size(i) > 0,
54       "grid_sampler(): expected input to have non-empty spatial dimensions, "
55       "but input has sizes ", input.sizes(), " with dimension ", i, " being "
56       "empty");
57   }
58 }
59 
60 // See NOTE [ grid_sampler Native Functions ].
check_grid_sampler_2d(const TensorBase & input,const TensorBase & grid)61 inline void check_grid_sampler_2d(
62   const TensorBase& input,
63   const TensorBase& grid
64 ) {
65   TORCH_CHECK(
66     input.dim() == 4 && input.dim() == grid.dim(),
67     "grid_sampler(): expected 4D input and grid with same number of "
68     "dimensions, but got input with sizes ", input.sizes(),
69     " and grid with sizes ", grid.sizes());
70 }
71 
72 // See NOTE [ grid_sampler Native Functions ].
check_grid_sampler_3d(const TensorBase & input,const TensorBase & grid,int64_t interpolation_mode)73 inline void check_grid_sampler_3d(
74   const TensorBase& input,
75   const TensorBase& grid,
76   int64_t interpolation_mode
77 ) {
78   TORCH_CHECK(
79     input.dim() == 5 && input.dim() == grid.dim(),
80     "grid_sampler(): expected 5D input and grid with same number of "
81     "dimensions, but got input with sizes ", input.sizes(),
82     " and grid with sizes ", grid.sizes());
83   TORCH_CHECK(
84     !(input.dim() == 5 &&
85       static_cast<GridSamplerInterpolation>(interpolation_mode) ==
86         GridSamplerInterpolation::Bicubic),
87     "grid_sampler(): bicubic interpolation only supports 4D input");
88 }
89 
90 // See NOTE [ grid_sampler Native Functions ].
91 // cudnn does not support inputs larger than 1024.
cond_cudnn_grid_sampler(const TensorBase & input,const TensorBase & grid)92 inline bool cond_cudnn_grid_sampler(
93   const TensorBase& input,
94   const TensorBase& grid
95 ) {
96   return (
97     at::native::cudnn_is_acceptable(input) &&
98     at::native::cudnn_is_acceptable(grid) &&
99     at::native::canUse32BitIndexMath(input) &&
100     at::native::canUse32BitIndexMath(grid) &&
101     input.dim() == 4 &&
102     input.sym_size(1) <= 1024);
103 }
104 
105 } // namespace at::native
106