1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
17
18 #include <string>
19
20 #include "tensorflow/compiler/xla/literal.h"
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/service/shape_inference.h"
26
27 namespace xla {
28 namespace {
29 // HloOp wraps an instuction pointer to do arithmetic based on operator
30 // overloading.
31 //
32 // TODO(yunxing): This is only used internally to this file to provide a
33 // convenient way to do operator overloadding. Find out an idiom and merge this
34 // with hlo_creation_utils.
35 class HloOp {
36 public:
37 HloOp() = default;
HloOp(HloInstruction * inst)38 explicit HloOp(HloInstruction* inst) : inst_(inst) {}
SetName(const std::string & name)39 void SetName(const std::string& name) {
40 inst_->SetAndSanitizeName(name);
41 if (inst_->GetModule() != nullptr) {
42 inst_->UniquifyName(&inst_->GetModule()->instruction_name_uniquer());
43 }
44 }
get()45 HloInstruction* get() { return inst_; }
46
47 private:
48 HloInstruction* inst_ = nullptr;
49 };
BinaryOp(HloOp x,HloOp y,HloOpcode opcode,const std::string & name="")50 HloOp BinaryOp(HloOp x, HloOp y, HloOpcode opcode,
51 const std::string& name = "") {
52 CHECK_EQ(x.get()->parent(), y.get()->parent());
53 Shape binary_op_shape =
54 ShapeInference::InferBinaryOpShape(opcode, x.get(), y.get()).ValueOrDie();
55 return HloOp(x.get()->parent()->AddInstruction(
56 HloInstruction::CreateBinary(binary_op_shape, opcode, x.get(), y.get()),
57 name));
58 }
operator +(HloOp x,HloOp y)59 HloOp operator+(HloOp x, HloOp y) { return BinaryOp(x, y, HloOpcode::kAdd); }
60
operator -(HloOp x,HloOp y)61 HloOp operator-(HloOp x, HloOp y) {
62 return BinaryOp(x, y, HloOpcode::kSubtract);
63 }
64
operator *(HloOp x,HloOp y)65 HloOp operator*(HloOp x, HloOp y) {
66 return BinaryOp(x, y, HloOpcode::kMultiply);
67 }
68
operator /(HloOp x,HloOp y)69 HloOp operator/(HloOp x, HloOp y) { return BinaryOp(x, y, HloOpcode::kDivide); }
70
Maximum(HloOp x,HloOp y,const std::string & name="")71 HloOp Maximum(HloOp x, HloOp y, const std::string& name = "") {
72 return BinaryOp(x, y, HloOpcode::kMaximum, name);
73 }
74
75 template <typename NativeT>
ConstantR0(HloComputation * comp,NativeT value,const std::string & name="")76 HloOp ConstantR0(HloComputation* comp, NativeT value,
77 const std::string& name = "") {
78 return HloOp(comp->AddInstruction(
79 HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)),
80 name));
81 }
82
83 template <typename NativeT>
One(HloComputation * comp)84 HloOp One(HloComputation* comp) {
85 return ConstantR0<NativeT>(comp, 1, "one");
86 }
87
88 template <typename NativeT>
Zero(HloComputation * comp)89 HloOp Zero(HloComputation* comp) {
90 return ConstantR0<NativeT>(comp, 0, "zero");
91 }
92
EffectiveFilterSize(HloComputation * comp,int64_t window_size,int64_t window_dilation)93 HloOp EffectiveFilterSize(HloComputation* comp, int64_t window_size,
94 int64_t window_dilation) {
95 return ConstantR0<int32_t>(comp, (window_size - 1) * window_dilation + 1,
96 "effective_filter_size");
97 }
98 } // namespace
99
GetWindowedOutputSize(HloInstruction * input_size,int64_t window_size,int64_t window_dilation,int64_t window_stride,PaddingType padding_type)100 DynamicWindowDims GetWindowedOutputSize(HloInstruction* input_size,
101 int64_t window_size,
102 int64_t window_dilation,
103 int64_t window_stride,
104 PaddingType padding_type) {
105 HloComputation* comp = input_size->parent();
106 DynamicWindowDims result;
107
108 HloOp stride = ConstantR0<int32_t>(comp, window_stride, "stride");
109 HloOp effective_filter_size =
110 EffectiveFilterSize(comp, window_size, window_dilation);
111 if (padding_type == PaddingType::PADDING_VALID) {
112 HloOp output =
113 (HloOp(input_size) + stride - effective_filter_size) / stride;
114 result.output_size = output.get();
115 result.padding_before = Zero<int32_t>(comp).get();
116 } else if (padding_type == PaddingType::PADDING_SAME) {
117 HloOp output = (HloOp(input_size) + stride - One<int32_t>(comp)) / stride;
118 HloOp padding_needed = Maximum(
119 Zero<int32_t>(comp), (output - One<int32_t>(comp)) * stride +
120 effective_filter_size - HloOp(input_size));
121 HloOp padding_before = padding_needed / ConstantR0<int32_t>(comp, 2);
122 result.padding_before = padding_before.get();
123 result.output_size = output.get();
124 }
125
126 return result;
127 }
128
GetWindowedInputGradSize(HloInstruction * input_size,int64_t window_size,int64_t window_dilation,int64_t window_stride,PaddingType padding_type)129 DynamicWindowDims GetWindowedInputGradSize(HloInstruction* input_size,
130 int64_t window_size,
131 int64_t window_dilation,
132 int64_t window_stride,
133 PaddingType padding_type) {
134 HloComputation* comp = input_size->parent();
135 DynamicWindowDims result;
136 HloOp effective_filter_size =
137 ConstantR0<int32_t>(comp, (window_size - 1) * window_dilation + 1);
138 HloOp stride = ConstantR0<int32_t>(comp, window_stride);
139 DynamicWindowDims forward_dims = GetWindowedOutputSize(
140 input_size, window_size, window_dilation, window_stride, padding_type);
141 HloOp output_size =
142 (HloOp(forward_dims.output_size) - One<int32_t>(comp)) * stride +
143 One<int32_t>(comp);
144 HloOp padding_before = effective_filter_size - One<int32_t>(comp) -
145 HloOp(forward_dims.padding_before);
146 result.output_size = output_size.get();
147 result.padding_before = padding_before.get();
148 return result;
149 }
150 } // namespace xla
151