1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
17 #include "tensorflow/compiler/xla/status_macros.h"
18 #include "tensorflow/core/lib/core/errors.h"
19
20 namespace xla {
21
22 namespace {
23
GetWindowedOutputSize(int64_t input_size,int64_t filter_size,int64_t dilation_rate,int64_t stride,Padding padding_type)24 StatusOr<SpatialDimensionOutputSizeAndPadding> GetWindowedOutputSize(
25 int64_t input_size, int64_t filter_size, int64_t dilation_rate,
26 int64_t stride, Padding padding_type) {
27 if (stride <= 0) {
28 return tensorflow::errors::InvalidArgument("Stride must be > 0, but got ",
29 stride);
30 }
31 if (dilation_rate < 1) {
32 return tensorflow::errors::InvalidArgument(
33 "Dilation rate must be >= 1, but got ", dilation_rate);
34 }
35
36 int64_t effective_filter_size = (filter_size - 1) * dilation_rate + 1;
37 SpatialDimensionOutputSizeAndPadding dim;
38 switch (padding_type) {
39 case Padding::kValid:
40 dim.output_size = (input_size - effective_filter_size + stride) / stride;
41 dim.pad_before = dim.pad_after = 0;
42 break;
43 case Padding::kSame:
44 dim.output_size = (input_size + stride - 1) / stride;
45 const int64_t padding_needed =
46 std::max(int64_t{0}, (dim.output_size - 1) * stride +
47 effective_filter_size - input_size);
48 // For odd values of total padding, add more padding on the "after" side
49 // of the given dimension.
50 dim.pad_before = padding_needed / 2;
51 dim.pad_after = padding_needed - dim.pad_before;
52 break;
53 }
54 if (dim.output_size < 0) {
55 return tensorflow::errors::InvalidArgument(
56 "Computed output size would be negative: ", dim.output_size,
57 " [input_size: ", input_size,
58 ", effective_filter_size: ", effective_filter_size,
59 ", stride: ", stride, "]");
60 }
61 return dim;
62 }
63
64 } // namespace
65
66 StatusOr<SpatialDimensionOutputSizeAndPadding>
ConvGradExtractAndVerifyDimension(int64_t input_size,int64_t filter_size,int64_t output_size,int64_t dilation,int64_t stride,Padding padding)67 ConvGradExtractAndVerifyDimension(int64_t input_size, int64_t filter_size,
68 int64_t output_size, int64_t dilation,
69 int64_t stride, Padding padding) {
70 TF_ASSIGN_OR_RETURN(SpatialDimensionOutputSizeAndPadding output_dim,
71 GetWindowedOutputSize(input_size, filter_size, dilation,
72 stride, padding));
73 if (output_size != output_dim.output_size) {
74 return tensorflow::errors::InvalidArgument(
75 "Size of out_backprop doesn't match computed: ", "actual = ",
76 output_size, ", computed = ", output_dim.output_size,
77 " input: ", input_size, " filter: ", filter_size,
78 " output: ", output_size, " stride: ", stride, " dilation: ", dilation);
79 }
80
81 SpatialDimensionOutputSizeAndPadding dim;
82 int64_t effective_filter_size = (filter_size - 1) * dilation + 1;
83 dim.output_size = (output_dim.output_size - 1) * stride + 1;
84 const auto padded_out_size = input_size + effective_filter_size - 1;
85 dim.pad_before = effective_filter_size - 1 - output_dim.pad_before;
86 dim.pad_after = padded_out_size - dim.output_size - dim.pad_before;
87 VLOG(2) << "expanded_out = " << dim.output_size
88 << ", effective_filter_size = " << effective_filter_size
89 << ", padded_out = " << padded_out_size
90 << ", pad_before = " << dim.pad_before
91 << ", pad_after = " << dim.pad_after << ", dilation = " << dilation
92 << ", strides = " << stride;
93 return dim;
94 }
95
96 } // namespace xla
97