xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21 
22 namespace xla {
23 
Bind(const DynamicParameter & dynamic_parameter,const DynamicDimension & dynamic_dimension)24 Status DynamicParameterBinding::Bind(
25     const DynamicParameter& dynamic_parameter,
26     const DynamicDimension& dynamic_dimension) {
27   auto result = bindings_.emplace(dynamic_dimension, dynamic_parameter);
28   TF_RET_CHECK(result.second);
29   return OkStatus();
30 }
31 
32 std::optional<DynamicParameterBinding::DynamicParameter>
GetBinding(const DynamicDimension & dynamic_dimension) const33 DynamicParameterBinding::GetBinding(
34     const DynamicDimension& dynamic_dimension) const {
35   auto param_iter = bindings_.find(dynamic_dimension);
36   if (param_iter == bindings_.end()) {
37     return std::nullopt;
38   }
39   return param_iter->second;
40 }
41 
ToProto() const42 DynamicParameterBindingProto DynamicParameterBinding::ToProto() const {
43   DynamicParameterBindingProto result;
44   for (const auto& binding : bindings_) {
45     const DynamicDimension& dynamic_dimension = binding.first;
46     const DynamicParameter& dynamic_param = binding.second;
47     DynamicParameterBindingProto::Binding binding_proto;
48     binding_proto.set_dynamic_param_num(dynamic_param.parameter_num);
49     for (int64_t i : dynamic_param.parameter_index) {
50       binding_proto.add_dynamic_param_index(i);
51     }
52 
53     binding_proto.set_target_param_num(dynamic_dimension.parameter_num);
54 
55     for (int64_t i : dynamic_dimension.parameter_index) {
56       binding_proto.add_target_param_index(i);
57     }
58 
59     binding_proto.set_target_param_dim_num(dynamic_dimension.dimension);
60     result.add_entries()->Swap(&binding_proto);
61   }
62   return result;
63 }
64 
CreateFromProto(const DynamicParameterBindingProto & proto)65 StatusOr<DynamicParameterBinding> DynamicParameterBinding::CreateFromProto(
66     const DynamicParameterBindingProto& proto) {
67   DynamicParameterBinding result;
68   for (const DynamicParameterBindingProto::Binding& binding : proto.entries()) {
69     int64_t dynamic_param_num = binding.dynamic_param_num();
70     ShapeIndex dynamic_param_index(binding.dynamic_param_index().begin(),
71                                    binding.dynamic_param_index().end());
72     int64_t target_param_num = binding.target_param_num();
73     ShapeIndex target_param_index(binding.target_param_index().begin(),
74                                   binding.target_param_index().end());
75     int64_t target_dim_num = binding.target_param_dim_num();
76 
77     TF_RETURN_IF_ERROR(
78         result.Bind(DynamicParameter{dynamic_param_num, dynamic_param_index},
79                     DynamicDimension{target_param_num, target_param_index,
80                                      target_dim_num}));
81   }
82 
83   return result;
84 }
85 
ToString() const86 std::string DynamicParameterBinding::ToString() const {
87   std::vector<std::string> pieces;
88   pieces.push_back("DynamicParameterBinding: ");
89   for (const auto& binding : bindings_) {
90     const DynamicDimension& dynamic_dimension = binding.first;
91     const DynamicParameter& dynamic_param = binding.second;
92     pieces.push_back(absl::StrFormat(
93         " -- Input param number %lld at %s has dim %lld as dynamic"
94         " dimension, which is represented by param number %lld at "
95         "%s",
96         dynamic_dimension.parameter_num,
97         dynamic_dimension.parameter_index.ToString(),
98         dynamic_dimension.dimension, dynamic_param.parameter_num,
99         dynamic_param.parameter_index.ToString()));
100   }
101   return absl::StrJoin(pieces, "\n");
102 }
103 
ForEachBinding(BindingFn fn) const104 Status DynamicParameterBinding::ForEachBinding(BindingFn fn) const {
105   for (const auto& binding : bindings_) {
106     TF_RETURN_IF_ERROR(fn(binding.second, binding.first));
107   }
108   return OkStatus();
109 }
110 
Verify(const HloModule & module) const111 Status DynamicParameterBinding::Verify(const HloModule& module) const {
112   const HloComputation* entry = module.entry_computation();
113   return ForEachBinding([&](const DynamicParameter& dynamic_parameter,
114                             const DynamicDimension& dynamic_dimension)
115                             -> Status {
116     TF_RET_CHECK(dynamic_parameter.parameter_num >= 0 &&
117                  dynamic_parameter.parameter_num < entry->num_parameters());
118     TF_RET_CHECK(dynamic_dimension.parameter_num < entry->num_parameters());
119     TF_RET_CHECK(ShapeUtil::IndexIsValid(
120         entry->parameter_instruction(dynamic_parameter.parameter_num)->shape(),
121         dynamic_parameter.parameter_index));
122     TF_RET_CHECK(ShapeUtil::IndexIsValid(
123         entry->parameter_instruction(dynamic_dimension.parameter_num)->shape(),
124         dynamic_dimension.parameter_index));
125     TF_RET_CHECK(
126         dynamic_dimension.dimension <
127         ShapeUtil::GetSubshape(
128             entry->parameter_instruction(dynamic_dimension.parameter_num)
129                 ->shape(),
130             dynamic_dimension.parameter_index)
131             .rank());
132     return OkStatus();
133   });
134 }
135 
operator <<(std::ostream & out,const DynamicParameterBinding & binding)136 std::ostream& operator<<(std::ostream& out,
137                          const DynamicParameterBinding& binding) {
138   out << binding.ToString();
139   return out;
140 }
141 
142 }  // namespace xla
143