xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/dynamic_window_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
17 
18 #include <string>
19 
20 #include "tensorflow/compiler/xla/literal.h"
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/service/shape_inference.h"
26 
27 namespace xla {
28 namespace {
29 // HloOp wraps an instuction pointer to do arithmetic based on operator
30 // overloading.
31 //
32 // TODO(yunxing): This is only used internally to this file to provide a
33 // convenient way to do operator overloadding.  Find out an idiom and merge this
34 // with hlo_creation_utils.
35 class HloOp {
36  public:
37   HloOp() = default;
HloOp(HloInstruction * inst)38   explicit HloOp(HloInstruction* inst) : inst_(inst) {}
SetName(const std::string & name)39   void SetName(const std::string& name) {
40     inst_->SetAndSanitizeName(name);
41     if (inst_->GetModule() != nullptr) {
42       inst_->UniquifyName(&inst_->GetModule()->instruction_name_uniquer());
43     }
44   }
get()45   HloInstruction* get() { return inst_; }
46 
47  private:
48   HloInstruction* inst_ = nullptr;
49 };
BinaryOp(HloOp x,HloOp y,HloOpcode opcode,const std::string & name="")50 HloOp BinaryOp(HloOp x, HloOp y, HloOpcode opcode,
51                const std::string& name = "") {
52   CHECK_EQ(x.get()->parent(), y.get()->parent());
53   Shape binary_op_shape =
54       ShapeInference::InferBinaryOpShape(opcode, x.get(), y.get()).ValueOrDie();
55   return HloOp(x.get()->parent()->AddInstruction(
56       HloInstruction::CreateBinary(binary_op_shape, opcode, x.get(), y.get()),
57       name));
58 }
operator +(HloOp x,HloOp y)59 HloOp operator+(HloOp x, HloOp y) { return BinaryOp(x, y, HloOpcode::kAdd); }
60 
operator -(HloOp x,HloOp y)61 HloOp operator-(HloOp x, HloOp y) {
62   return BinaryOp(x, y, HloOpcode::kSubtract);
63 }
64 
operator *(HloOp x,HloOp y)65 HloOp operator*(HloOp x, HloOp y) {
66   return BinaryOp(x, y, HloOpcode::kMultiply);
67 }
68 
operator /(HloOp x,HloOp y)69 HloOp operator/(HloOp x, HloOp y) { return BinaryOp(x, y, HloOpcode::kDivide); }
70 
Maximum(HloOp x,HloOp y,const std::string & name="")71 HloOp Maximum(HloOp x, HloOp y, const std::string& name = "") {
72   return BinaryOp(x, y, HloOpcode::kMaximum, name);
73 }
74 
75 template <typename NativeT>
ConstantR0(HloComputation * comp,NativeT value,const std::string & name="")76 HloOp ConstantR0(HloComputation* comp, NativeT value,
77                  const std::string& name = "") {
78   return HloOp(comp->AddInstruction(
79       HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)),
80       name));
81 }
82 
83 template <typename NativeT>
One(HloComputation * comp)84 HloOp One(HloComputation* comp) {
85   return ConstantR0<NativeT>(comp, 1, "one");
86 }
87 
88 template <typename NativeT>
Zero(HloComputation * comp)89 HloOp Zero(HloComputation* comp) {
90   return ConstantR0<NativeT>(comp, 0, "zero");
91 }
92 
EffectiveFilterSize(HloComputation * comp,int64_t window_size,int64_t window_dilation)93 HloOp EffectiveFilterSize(HloComputation* comp, int64_t window_size,
94                           int64_t window_dilation) {
95   return ConstantR0<int32_t>(comp, (window_size - 1) * window_dilation + 1,
96                              "effective_filter_size");
97 }
98 }  // namespace
99 
GetWindowedOutputSize(HloInstruction * input_size,int64_t window_size,int64_t window_dilation,int64_t window_stride,PaddingType padding_type)100 DynamicWindowDims GetWindowedOutputSize(HloInstruction* input_size,
101                                         int64_t window_size,
102                                         int64_t window_dilation,
103                                         int64_t window_stride,
104                                         PaddingType padding_type) {
105   HloComputation* comp = input_size->parent();
106   DynamicWindowDims result;
107 
108   HloOp stride = ConstantR0<int32_t>(comp, window_stride, "stride");
109   HloOp effective_filter_size =
110       EffectiveFilterSize(comp, window_size, window_dilation);
111   if (padding_type == PaddingType::PADDING_VALID) {
112     HloOp output =
113         (HloOp(input_size) + stride - effective_filter_size) / stride;
114     result.output_size = output.get();
115     result.padding_before = Zero<int32_t>(comp).get();
116   } else if (padding_type == PaddingType::PADDING_SAME) {
117     HloOp output = (HloOp(input_size) + stride - One<int32_t>(comp)) / stride;
118     HloOp padding_needed = Maximum(
119         Zero<int32_t>(comp), (output - One<int32_t>(comp)) * stride +
120                                  effective_filter_size - HloOp(input_size));
121     HloOp padding_before = padding_needed / ConstantR0<int32_t>(comp, 2);
122     result.padding_before = padding_before.get();
123     result.output_size = output.get();
124   }
125 
126   return result;
127 }
128 
GetWindowedInputGradSize(HloInstruction * input_size,int64_t window_size,int64_t window_dilation,int64_t window_stride,PaddingType padding_type)129 DynamicWindowDims GetWindowedInputGradSize(HloInstruction* input_size,
130                                            int64_t window_size,
131                                            int64_t window_dilation,
132                                            int64_t window_stride,
133                                            PaddingType padding_type) {
134   HloComputation* comp = input_size->parent();
135   DynamicWindowDims result;
136   HloOp effective_filter_size =
137       ConstantR0<int32_t>(comp, (window_size - 1) * window_dilation + 1);
138   HloOp stride = ConstantR0<int32_t>(comp, window_stride);
139   DynamicWindowDims forward_dims = GetWindowedOutputSize(
140       input_size, window_size, window_dilation, window_stride, padding_type);
141   HloOp output_size =
142       (HloOp(forward_dims.output_size) - One<int32_t>(comp)) * stride +
143       One<int32_t>(comp);
144   HloOp padding_before = effective_filter_size - One<int32_t>(comp) -
145                          HloOp(forward_dims.padding_before);
146   result.output_size = output_size.get();
147   result.padding_before = padding_before.get();
148   return result;
149 }
150 }  // namespace xla
151