1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/kernel_shape_util.h"
16
17 #include "tensorflow/core/lib/core/errors.h"
18
19 namespace tensorflow {
GetWindowedOutputSizeVerboseV2(int64_t input_size,int64_t filter_size,int64_t dilation_rate,int64_t stride,Padding padding_type,int64_t * output_size,int64_t * padding_before,int64_t * padding_after)20 Status GetWindowedOutputSizeVerboseV2(int64_t input_size, int64_t filter_size,
21 int64_t dilation_rate, int64_t stride,
22 Padding padding_type,
23 int64_t* output_size,
24 int64_t* padding_before,
25 int64_t* padding_after) {
26 if (stride <= 0) {
27 return errors::InvalidArgument("Stride must be > 0, but got ", stride);
28 }
29 if (dilation_rate < 1) {
30 return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
31 dilation_rate);
32 }
33
34 // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2.
35 int64_t effective_filter_size = (filter_size - 1) * dilation_rate + 1;
36 switch (padding_type) {
37 case Padding::VALID:
38 *output_size = (input_size - effective_filter_size + stride) / stride;
39 *padding_before = *padding_after = 0;
40 break;
41 case Padding::EXPLICIT:
42 *output_size = (input_size + *padding_before + *padding_after -
43 effective_filter_size + stride) /
44 stride;
45 break;
46 case Padding::SAME:
47 *output_size = (input_size + stride - 1) / stride;
48 const int64_t padding_needed =
49 std::max(int64_t{0}, (*output_size - 1) * stride +
50 effective_filter_size - input_size);
51 // For odd values of total padding, add more padding at the 'right'
52 // side of the given dimension.
53 *padding_before = padding_needed / 2;
54 *padding_after = padding_needed - *padding_before;
55 break;
56 }
57 if (*output_size < 0) {
58 return errors::InvalidArgument(
59 "Computed output size would be negative: ", *output_size,
60 " [input_size: ", input_size,
61 ", effective_filter_size: ", effective_filter_size,
62 ", stride: ", stride, "]");
63 }
64 return OkStatus();
65 }
66
GetWindowedOutputSizeVerbose(int64_t input_size,int64_t filter_size,int64_t stride,Padding padding_type,int64_t * output_size,int64_t * padding_before,int64_t * padding_after)67 Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size,
68 int64_t stride, Padding padding_type,
69 int64_t* output_size,
70 int64_t* padding_before,
71 int64_t* padding_after) {
72 return GetWindowedOutputSizeVerboseV2(input_size, filter_size,
73 /*dilation_rate=*/1, stride,
74 padding_type, output_size,
75 padding_before, padding_after);
76 }
77
GetWindowedOutputSize(int64_t input_size,int64_t filter_size,int64_t stride,Padding padding_type,int64_t * output_size,int64_t * padding_size)78 Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size,
79 int64_t stride, Padding padding_type,
80 int64_t* output_size, int64_t* padding_size) {
81 if (padding_type == Padding::EXPLICIT) {
82 return errors::Internal(
83 "GetWindowedOutputSize does not handle EXPLICIT padding; call "
84 "GetWindowedOutputSizeVerbose instead");
85 }
86 int64_t padding_after_unused;
87 return GetWindowedOutputSizeVerbose(input_size, filter_size, stride,
88 padding_type, output_size, padding_size,
89 &padding_after_unused);
90 }
91
GetWindowedOutputSizeV2(int64_t input_size,int64_t filter_size,int64_t dilation_rate,int64_t stride,Padding padding_type,int64_t * output_size,int64_t * padding_size)92 Status GetWindowedOutputSizeV2(int64_t input_size, int64_t filter_size,
93 int64_t dilation_rate, int64_t stride,
94 Padding padding_type, int64_t* output_size,
95 int64_t* padding_size) {
96 if (padding_type == Padding::EXPLICIT) {
97 return errors::Internal(
98 "GetWindowedOutputSizeV2 does not handle EXPLICIT padding; call "
99 "GetWindowedOutputSizeVerboseV2 instead");
100 }
101 int64_t padding_after_unused;
102 return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate,
103 stride, padding_type, output_size,
104 padding_size, &padding_after_unused);
105 }
106
Get3dOutputSize(const std::array<int64_t,3> & input,const std::array<int64_t,3> & window,const std::array<int64_t,3> & strides,Padding padding_type,std::array<int64_t,3> * output_ptr,std::array<int64_t,3> * padding_ptr)107 Status Get3dOutputSize(const std::array<int64_t, 3>& input,
108 const std::array<int64_t, 3>& window,
109 const std::array<int64_t, 3>& strides,
110 Padding padding_type, std::array<int64_t, 3>* output_ptr,
111 std::array<int64_t, 3>* padding_ptr) {
112 for (size_t i = 0; i < input.size(); ++i) {
113 TF_RETURN_IF_ERROR(GetWindowedOutputSize(input[i], window[i], strides[i],
114 padding_type, &(*output_ptr)[i],
115 &(*padding_ptr)[i]));
116 }
117 return OkStatus();
118 }
119
Get3dOutputSizeV2(const std::array<int64_t,3> & input,const std::array<int64_t,3> & window,const std::array<int64_t,3> & dilations,const std::array<int64_t,3> & strides,Padding padding_type,std::array<int64_t,3> * output_ptr,std::array<int64_t,3> * padding_ptr)120 Status Get3dOutputSizeV2(const std::array<int64_t, 3>& input,
121 const std::array<int64_t, 3>& window,
122 const std::array<int64_t, 3>& dilations,
123 const std::array<int64_t, 3>& strides,
124 Padding padding_type,
125 std::array<int64_t, 3>* output_ptr,
126 std::array<int64_t, 3>* padding_ptr) {
127 for (size_t i = 0; i < input.size(); ++i) {
128 TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
129 input[i], window[i], dilations[i], strides[i], padding_type,
130 &(*output_ptr)[i], &(*padding_ptr)[i]));
131 }
132 return OkStatus();
133 }
134 } // namespace tensorflow
135