1 #pragma once
2
3 // See NOTE: [Tensor vs. TensorBase]
4 // https://github.com/pytorch/pytorch/pull/66979
5 #include <ATen/core/TensorBase.h>
6 #include <ATen/native/TensorProperties.h>
7 #include <ATen/native/CanUse32BitIndexMath.h>
8
9 namespace at::native {
10
namespace detail {

// Interpolation modes understood by the grid_sampler helpers. The integer
// `interpolation_mode` argument is static_cast to this enum (see
// check_grid_sampler_3d), so enumerator values are spelled out explicitly;
// they are identical to the implicit 0, 1, 2 numbering.
enum class GridSamplerInterpolation : int {
  Bilinear = 0,
  Nearest = 1,
  Bicubic = 2,
};

// Padding modes for out-of-range grid locations. Values are spelled out
// explicitly and are identical to the implicit 0, 1, 2 numbering.
enum class GridSamplerPadding : int {
  Zeros = 0,
  Border = 1,
  Reflection = 2,
};

} // namespace detail

// Re-export both enums into the enclosing namespace for unqualified use.
using detail::GridSamplerInterpolation;
using detail::GridSamplerPadding;
20
21 // See NOTE [ grid_sampler Native Functions ].
check_grid_sampler_common(const TensorBase & input,const TensorBase & grid)22 inline void check_grid_sampler_common(
23 const TensorBase& input,
24 const TensorBase& grid
25 ) {
26 auto input_opt = input.options();
27 auto grid_opt = grid.options();
28
29 TORCH_CHECK(
30 input.defined(),
31 "grid_sampler(): expected input to not be undefined");
32 TORCH_CHECK(
33 grid.defined(),
34 "grid_sampler(): expected grid to not be undefined");
35 TORCH_CHECK(
36 input_opt.device() == grid_opt.device(),
37 "grid_sampler(): expected input and grid to be on same device, but input "
38 "is on ", input_opt.device(), " and grid is on ", grid_opt.device());
39 TORCH_CHECK(
40 input_opt.layout() == kStrided && grid_opt.layout() == kStrided,
41 "grid_sampler(): expected input and grid to have torch.strided layout, but "
42 "input has ", input_opt.layout(), " and grid has ", grid_opt.layout());
43 TORCH_CHECK(
44 input.size(0) == grid.size(0),
45 "grid_sampler(): expected grid and input to have same batch size, but got "
46 "input with sizes ", input.sizes(), " and grid with sizes ", grid.sizes());
47 TORCH_CHECK(
48 grid.size(-1) == input.dim() - 2,
49 "grid_sampler(): expected grid to have size ", input.dim() - 2, " in last "
50 "dimension, but got grid with sizes ", grid.sizes());
51
52 for (const auto i : c10::irange(2, input.dim())) {
53 TORCH_CHECK(input.size(i) > 0,
54 "grid_sampler(): expected input to have non-empty spatial dimensions, "
55 "but input has sizes ", input.sizes(), " with dimension ", i, " being "
56 "empty");
57 }
58 }
59
60 // See NOTE [ grid_sampler Native Functions ].
check_grid_sampler_2d(const TensorBase & input,const TensorBase & grid)61 inline void check_grid_sampler_2d(
62 const TensorBase& input,
63 const TensorBase& grid
64 ) {
65 TORCH_CHECK(
66 input.dim() == 4 && input.dim() == grid.dim(),
67 "grid_sampler(): expected 4D input and grid with same number of "
68 "dimensions, but got input with sizes ", input.sizes(),
69 " and grid with sizes ", grid.sizes());
70 }
71
72 // See NOTE [ grid_sampler Native Functions ].
check_grid_sampler_3d(const TensorBase & input,const TensorBase & grid,int64_t interpolation_mode)73 inline void check_grid_sampler_3d(
74 const TensorBase& input,
75 const TensorBase& grid,
76 int64_t interpolation_mode
77 ) {
78 TORCH_CHECK(
79 input.dim() == 5 && input.dim() == grid.dim(),
80 "grid_sampler(): expected 5D input and grid with same number of "
81 "dimensions, but got input with sizes ", input.sizes(),
82 " and grid with sizes ", grid.sizes());
83 TORCH_CHECK(
84 !(input.dim() == 5 &&
85 static_cast<GridSamplerInterpolation>(interpolation_mode) ==
86 GridSamplerInterpolation::Bicubic),
87 "grid_sampler(): bicubic interpolation only supports 4D input");
88 }
89
90 // See NOTE [ grid_sampler Native Functions ].
91 // cudnn does not support inputs larger than 1024.
cond_cudnn_grid_sampler(const TensorBase & input,const TensorBase & grid)92 inline bool cond_cudnn_grid_sampler(
93 const TensorBase& input,
94 const TensorBase& grid
95 ) {
96 return (
97 at::native::cudnn_is_acceptable(input) &&
98 at::native::cudnn_is_acceptable(grid) &&
99 at::native::canUse32BitIndexMath(input) &&
100 at::native::canUse32BitIndexMath(grid) &&
101 input.dim() == 4 &&
102 input.sym_size(1) <= 1024);
103 }
104
105 } // namespace at::native
106