1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_ 18 19 #include <functional> 20 #include <optional> 21 #include <string> 22 #include <utility> 23 24 #include "absl/container/flat_hash_map.h" 25 #include "tensorflow/compiler/xla/service/hlo.pb.h" 26 #include "tensorflow/compiler/xla/shape_tree.h" 27 #include "tensorflow/compiler/xla/shape_util.h" 28 29 namespace xla { 30 31 class HloModule; 32 // We currently use an explicit API that takes an extra parameter to indicate 33 // the runtime size of a dynamic dimension. DynamicParameterBinding indicates 34 // the relationship between parameter: We can have a dynamic parameter that 35 // points to another target parameter to indicate that the target parameter is 36 // dynamic. 37 // 38 // 39 // TODO(b/119520625): Remove this API once we have more dynamic shape infra 40 // ready. 41 class DynamicParameterBinding { 42 public: 43 // DynamicParameter represents a special parameter that is used to represent 44 // the runtime size of a dimension of another parameter. A dynamic parameter 45 // has to be a scalar value. 46 struct DynamicParameter { 47 // The parameter number of dynamic parameter. 48 int64_t parameter_num; 49 // The index of the parameter. 50 ShapeIndex parameter_index; 51 }; 52 53 // DynamicDimension represents a dimension whose size is determined at 54 // runtime. A DynamicDimension's runtime size is determined by the binded 55 // DynamicParameter using `DynamicParameterBinding::Bind` method. 56 struct DynamicDimension { 57 // The parameter number of dynamic dimension. 58 int64_t parameter_num; 59 // The subshape index of the parameter. 60 ShapeIndex parameter_index; 61 // The dimension number in the subshape. 62 int64_t dimension; 63 64 // "friend" keyword are added so these functions can be found by ADL. 65 template <typename H> AbslHashValueDynamicDimension66 friend H AbslHashValue(H h, const DynamicDimension& m) { 67 return H::combine(std::move(h), m.parameter_num, m.parameter_index, 68 m.dimension); 69 } 70 71 friend bool operator==(const DynamicDimension& lhs, 72 const DynamicDimension& rhs) { 73 return lhs.parameter_num == rhs.parameter_num && 74 lhs.parameter_index == rhs.parameter_index && 75 lhs.dimension == rhs.dimension; 76 } 77 }; 78 79 DynamicParameterBinding() = default; 80 81 virtual ~DynamicParameterBinding() = default; 82 83 // Adds binding which indicates that the dimension indicated by 84 // `dynamic_dimension` is dynamic, and its runtime size is represented by 85 // `dynamic_parameter`. 86 Status Bind(const DynamicParameter& dynamic_parameter, 87 const DynamicDimension& dynamic_dimension); 88 89 // Returns the parameter and the index representing the runtime size of 90 // dimension `dim_num` of parameter `param_num` at `param_index`. 91 // 92 // Returns nullopt if the binding is not set. 93 std::optional<DynamicParameter> GetBinding( 94 const DynamicDimension& dynamic_dimension) const; 95 96 using BindingFn = 97 std::function<Status(const DynamicParameter& dynamic_parameter, 98 const DynamicDimension& dynamic_dimension)>; 99 100 // Iterate through each binding. 101 Status ForEachBinding(BindingFn fn) const; 102 103 DynamicParameterBindingProto ToProto() const; 104 105 static StatusOr<DynamicParameterBinding> CreateFromProto( 106 const DynamicParameterBindingProto& proto); 107 108 std::string ToString() const; 109 110 // Verifies that the given binding is valid for the given module. 111 // Specifically, the binding's parameter and parameter size should be valid. 112 Status Verify(const HloModule& module) const; 113 114 private: 115 // Keeps track of mappings from DynamicDimension to DynamicParameter. The 116 // direction of is chosen so that we can easily query if a dimension is 117 // dynamic and which dynamic parameter represents the real size of that 118 // dimension. 119 absl::flat_hash_map<DynamicDimension, DynamicParameter> bindings_; 120 }; 121 122 std::ostream& operator<<(std::ostream& out, 123 const DynamicParameterBinding& binding); 124 125 } // namespace xla 126 127 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_ 128