xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/dynamic_dimension_inference.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_
18 
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
32 
33 namespace xla {
34 
35 // DynamicDimensionInference analyzes each HLO instruction in a graph and
36 // inferences which dimensions are dynamic and which scalar instructions
37 // represent the runtime real size of those dynamic dimensions.
38 class DynamicDimensionInference {
39  public:
40   enum ShapeCheckMode {
41     kInvalid = 0,
42     // At compile time, pessimisticly assumes runtime shape checks may fail and
43     // returns a compile-time error.
44     kCompileTime,
45     // Insert runtime checks as Hlo ops.
46     kRuntime,
47     // Ignore shape check.
48     kIgnore,
49   };
50   using CustomCallInferenceHandler =
51       std::function<Status(HloInstruction*, DynamicDimensionInference*)>;
52 
53   // Generate an assertion which fails the execution if the instruction value is
54   // false.
55   using AssertionGenerator = std::function<void(HloInstruction*)>;
56 
57   static StatusOr<DynamicDimensionInference> Run(
58       HloModule* module,
59       CustomCallInferenceHandler custom_call_handler = nullptr,
60       ShapeCheckMode shape_check_mode = ShapeCheckMode::kIgnore,
61       const AssertionGenerator& assertion_generator = nullptr);
62 
63   std::string ToString() const;
64 
65   // If the dimension `dim` of instruction `inst` at `index` has a dynamic size,
66   // returns a scalar HloInstruction that represents the runtime size of that
67   // dimension. Otherwise returns nullptr.
68   HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index,
69                                  int64_t dim) const;
70 
71   // Returns dynamic sizes of all dimensions of `inst`'s leaf node at `index`.
72   // Static sizes are represented by nullptr.
73   std::vector<HloInstruction*> GetDynamicSizes(HloInstruction* inst,
74                                                const ShapeIndex& index) const;
75 
76   // Returns if `index` at `inst` contains any dynamic dimension.
77   // Recursively go into tuples.
78   bool HasDynamicDimension(HloInstruction* inst,
79                            ShapeIndexView index = {}) const;
80 
81   // Forward dynamic dimension size at `dim` from `inst` to `new_inst`.
82   Status ForwardDynamicSize(HloInstruction* inst, HloInstruction* new_inst,
83                             const ShapeIndex& index);
84 
85   // Update the dynamic mapping so that we know dimension `dim` of instruction
86   // `inst` at `index` has a dynamic size, and its runtime size is represented
87   // by a scalar instruction `size`.
88   void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index,
89                       int64_t dim, HloInstruction* size);
90 
91   // For all tensors whose dynamic dimension is `replace`, replace them with
92   // `with`.
93   void ReplaceAllDynamicDimensionUsesWith(HloInstruction* replace,
94                                           HloInstruction* with);
95 
96   // Update dynamic dimension inference to analyze `inst`. Useful to
97   // incrementally track new instructions added after initial run.
98   Status Update(HloInstruction* inst);
99 
100   friend class DynamicDimensionInferenceVisitor;
101 
102  private:
103   explicit DynamicDimensionInference(
104       HloModule* module, CustomCallInferenceHandler custom_call_handler,
105       ShapeCheckMode shape_check_mode, AssertionGenerator assertion_generator);
106 
107   // DynamicDimension is used as a key in the dynamic key-value mapping. It
108   // unambiguously represents a dynamic dimension of a instruction at a given
109   // index.
110   struct DynamicDimension {
111     // HloInstruction that holds the dimension.
112     HloInstruction* inst;
113     // Subshape of the instruction that holds the dimension.
114     ShapeIndex index;
115     // The dimension number of the dynamic dimension at given index of a given
116     // instruction.
117     int64_t dim;
118 
119     // Artifacts needed to make this struct able to be used as a `key` in absl
120     // maps. "friend" keywords are added so these functions can be found through
121     // ADL.
122     template <typename H>
AbslHashValueDynamicDimension123     friend H AbslHashValue(H h, const DynamicDimension& m) {
124       return H::combine(std::move(h), m.inst, m.index, m.dim);
125     }
126 
127     friend bool operator==(const DynamicDimension& lhs,
128                            const DynamicDimension& rhs) {
129       return lhs.inst == rhs.inst && lhs.index == rhs.index &&
130              lhs.dim == rhs.dim;
131     }
132 
ToTupleDynamicDimension133     std::tuple<int, int, std::string, int64_t> ToTuple() const {
134       return std::make_tuple(
135           inst && inst->GetModule() ? inst->GetModule()->unique_id() : -1,
136           inst ? inst->unique_id() : -1, index.ToString(), dim);
137     }
138 
139     friend bool operator<(const DynamicDimension& lhs,
140                           const DynamicDimension& rhs) {
141       return lhs.ToTuple() < rhs.ToTuple();
142     }
143   };
144 
145   // Copies the internal mapping from instruction `from` to instruction `to`.
146   // This is useful when an instruction is replaced by the other during the
147   // inferencing process.
148   void CopyMapping(HloInstruction* from, HloInstruction* to);
149 
150   // AnalyzeDynamicDimensions starts the analysis of the dynamic dimensions in
151   // module_.
152   Status AnalyzeDynamicDimensions();
153 
154   // HloModule being analyzed.
155   HloModule* module_;
156 
157   // dynamic_mapping_ holds the result of the analysis. It maps a dynamic
158   // dimension to a scalar HloInstruction that represents the real dynamic size
159   // of the dynamic dimension.
160   using DynamicMapping = std::map<DynamicDimension, HloInstruction*>;
161   DynamicMapping dynamic_mapping_;
162 
163   // A convenient mapping from an hlo to the set of dynamic dimensions that it
164   // holds.
165   using PerHloDynamicDimensions =
166       ConstHloInstructionMap<std::set<DynamicDimension>>;
167   PerHloDynamicDimensions per_hlo_dynamic_dimensions_;
168 
169   // A handler for custom calls.
170   CustomCallInferenceHandler custom_call_handler_;
171 
172   // Indicates what to do at places where shape check is needed.
173   ShapeCheckMode shape_check_mode_;
174 
175   AssertionGenerator assertion_generator_;
176 };
177 
178 }  // namespace xla
179 
180 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_
181