xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/kernel_shape_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/kernel_shape_util.h"
16 
17 #include "tensorflow/core/lib/core/errors.h"
18 
19 namespace tensorflow {
GetWindowedOutputSizeVerboseV2(int64_t input_size,int64_t filter_size,int64_t dilation_rate,int64_t stride,Padding padding_type,int64_t * output_size,int64_t * padding_before,int64_t * padding_after)20 Status GetWindowedOutputSizeVerboseV2(int64_t input_size, int64_t filter_size,
21                                       int64_t dilation_rate, int64_t stride,
22                                       Padding padding_type,
23                                       int64_t* output_size,
24                                       int64_t* padding_before,
25                                       int64_t* padding_after) {
26   if (stride <= 0) {
27     return errors::InvalidArgument("Stride must be > 0, but got ", stride);
28   }
29   if (dilation_rate < 1) {
30     return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
31                                    dilation_rate);
32   }
33 
34   // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2.
35   int64_t effective_filter_size = (filter_size - 1) * dilation_rate + 1;
36   switch (padding_type) {
37     case Padding::VALID:
38       *output_size = (input_size - effective_filter_size + stride) / stride;
39       *padding_before = *padding_after = 0;
40       break;
41     case Padding::EXPLICIT:
42       *output_size = (input_size + *padding_before + *padding_after -
43                       effective_filter_size + stride) /
44                      stride;
45       break;
46     case Padding::SAME:
47       *output_size = (input_size + stride - 1) / stride;
48       const int64_t padding_needed =
49           std::max(int64_t{0}, (*output_size - 1) * stride +
50                                    effective_filter_size - input_size);
51       // For odd values of total padding, add more padding at the 'right'
52       // side of the given dimension.
53       *padding_before = padding_needed / 2;
54       *padding_after = padding_needed - *padding_before;
55       break;
56   }
57   if (*output_size < 0) {
58     return errors::InvalidArgument(
59         "Computed output size would be negative: ", *output_size,
60         " [input_size: ", input_size,
61         ", effective_filter_size: ", effective_filter_size,
62         ", stride: ", stride, "]");
63   }
64   return OkStatus();
65 }
66 
GetWindowedOutputSizeVerbose(int64_t input_size,int64_t filter_size,int64_t stride,Padding padding_type,int64_t * output_size,int64_t * padding_before,int64_t * padding_after)67 Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size,
68                                     int64_t stride, Padding padding_type,
69                                     int64_t* output_size,
70                                     int64_t* padding_before,
71                                     int64_t* padding_after) {
72   return GetWindowedOutputSizeVerboseV2(input_size, filter_size,
73                                         /*dilation_rate=*/1, stride,
74                                         padding_type, output_size,
75                                         padding_before, padding_after);
76 }
77 
GetWindowedOutputSize(int64_t input_size,int64_t filter_size,int64_t stride,Padding padding_type,int64_t * output_size,int64_t * padding_size)78 Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size,
79                              int64_t stride, Padding padding_type,
80                              int64_t* output_size, int64_t* padding_size) {
81   if (padding_type == Padding::EXPLICIT) {
82     return errors::Internal(
83         "GetWindowedOutputSize does not handle EXPLICIT padding; call "
84         "GetWindowedOutputSizeVerbose instead");
85   }
86   int64_t padding_after_unused;
87   return GetWindowedOutputSizeVerbose(input_size, filter_size, stride,
88                                       padding_type, output_size, padding_size,
89                                       &padding_after_unused);
90 }
91 
GetWindowedOutputSizeV2(int64_t input_size,int64_t filter_size,int64_t dilation_rate,int64_t stride,Padding padding_type,int64_t * output_size,int64_t * padding_size)92 Status GetWindowedOutputSizeV2(int64_t input_size, int64_t filter_size,
93                                int64_t dilation_rate, int64_t stride,
94                                Padding padding_type, int64_t* output_size,
95                                int64_t* padding_size) {
96   if (padding_type == Padding::EXPLICIT) {
97     return errors::Internal(
98         "GetWindowedOutputSizeV2 does not handle EXPLICIT padding; call "
99         "GetWindowedOutputSizeVerboseV2 instead");
100   }
101   int64_t padding_after_unused;
102   return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate,
103                                         stride, padding_type, output_size,
104                                         padding_size, &padding_after_unused);
105 }
106 
Get3dOutputSize(const std::array<int64_t,3> & input,const std::array<int64_t,3> & window,const std::array<int64_t,3> & strides,Padding padding_type,std::array<int64_t,3> * output_ptr,std::array<int64_t,3> * padding_ptr)107 Status Get3dOutputSize(const std::array<int64_t, 3>& input,
108                        const std::array<int64_t, 3>& window,
109                        const std::array<int64_t, 3>& strides,
110                        Padding padding_type, std::array<int64_t, 3>* output_ptr,
111                        std::array<int64_t, 3>* padding_ptr) {
112   for (size_t i = 0; i < input.size(); ++i) {
113     TF_RETURN_IF_ERROR(GetWindowedOutputSize(input[i], window[i], strides[i],
114                                              padding_type, &(*output_ptr)[i],
115                                              &(*padding_ptr)[i]));
116   }
117   return OkStatus();
118 }
119 
Get3dOutputSizeV2(const std::array<int64_t,3> & input,const std::array<int64_t,3> & window,const std::array<int64_t,3> & dilations,const std::array<int64_t,3> & strides,Padding padding_type,std::array<int64_t,3> * output_ptr,std::array<int64_t,3> * padding_ptr)120 Status Get3dOutputSizeV2(const std::array<int64_t, 3>& input,
121                          const std::array<int64_t, 3>& window,
122                          const std::array<int64_t, 3>& dilations,
123                          const std::array<int64_t, 3>& strides,
124                          Padding padding_type,
125                          std::array<int64_t, 3>* output_ptr,
126                          std::array<int64_t, 3>* padding_ptr) {
127   for (size_t i = 0; i < input.size(); ++i) {
128     TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
129         input[i], window[i], dilations[i], strides[i], padding_type,
130         &(*output_ptr)[i], &(*padding_ptr)[i]));
131   }
132   return OkStatus();
133 }
134 }  // namespace tensorflow
135