1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21
22 namespace xla {
23
Bind(const DynamicParameter & dynamic_parameter,const DynamicDimension & dynamic_dimension)24 Status DynamicParameterBinding::Bind(
25 const DynamicParameter& dynamic_parameter,
26 const DynamicDimension& dynamic_dimension) {
27 auto result = bindings_.emplace(dynamic_dimension, dynamic_parameter);
28 TF_RET_CHECK(result.second);
29 return OkStatus();
30 }
31
32 std::optional<DynamicParameterBinding::DynamicParameter>
GetBinding(const DynamicDimension & dynamic_dimension) const33 DynamicParameterBinding::GetBinding(
34 const DynamicDimension& dynamic_dimension) const {
35 auto param_iter = bindings_.find(dynamic_dimension);
36 if (param_iter == bindings_.end()) {
37 return std::nullopt;
38 }
39 return param_iter->second;
40 }
41
ToProto() const42 DynamicParameterBindingProto DynamicParameterBinding::ToProto() const {
43 DynamicParameterBindingProto result;
44 for (const auto& binding : bindings_) {
45 const DynamicDimension& dynamic_dimension = binding.first;
46 const DynamicParameter& dynamic_param = binding.second;
47 DynamicParameterBindingProto::Binding binding_proto;
48 binding_proto.set_dynamic_param_num(dynamic_param.parameter_num);
49 for (int64_t i : dynamic_param.parameter_index) {
50 binding_proto.add_dynamic_param_index(i);
51 }
52
53 binding_proto.set_target_param_num(dynamic_dimension.parameter_num);
54
55 for (int64_t i : dynamic_dimension.parameter_index) {
56 binding_proto.add_target_param_index(i);
57 }
58
59 binding_proto.set_target_param_dim_num(dynamic_dimension.dimension);
60 result.add_entries()->Swap(&binding_proto);
61 }
62 return result;
63 }
64
CreateFromProto(const DynamicParameterBindingProto & proto)65 StatusOr<DynamicParameterBinding> DynamicParameterBinding::CreateFromProto(
66 const DynamicParameterBindingProto& proto) {
67 DynamicParameterBinding result;
68 for (const DynamicParameterBindingProto::Binding& binding : proto.entries()) {
69 int64_t dynamic_param_num = binding.dynamic_param_num();
70 ShapeIndex dynamic_param_index(binding.dynamic_param_index().begin(),
71 binding.dynamic_param_index().end());
72 int64_t target_param_num = binding.target_param_num();
73 ShapeIndex target_param_index(binding.target_param_index().begin(),
74 binding.target_param_index().end());
75 int64_t target_dim_num = binding.target_param_dim_num();
76
77 TF_RETURN_IF_ERROR(
78 result.Bind(DynamicParameter{dynamic_param_num, dynamic_param_index},
79 DynamicDimension{target_param_num, target_param_index,
80 target_dim_num}));
81 }
82
83 return result;
84 }
85
ToString() const86 std::string DynamicParameterBinding::ToString() const {
87 std::vector<std::string> pieces;
88 pieces.push_back("DynamicParameterBinding: ");
89 for (const auto& binding : bindings_) {
90 const DynamicDimension& dynamic_dimension = binding.first;
91 const DynamicParameter& dynamic_param = binding.second;
92 pieces.push_back(absl::StrFormat(
93 " -- Input param number %lld at %s has dim %lld as dynamic"
94 " dimension, which is represented by param number %lld at "
95 "%s",
96 dynamic_dimension.parameter_num,
97 dynamic_dimension.parameter_index.ToString(),
98 dynamic_dimension.dimension, dynamic_param.parameter_num,
99 dynamic_param.parameter_index.ToString()));
100 }
101 return absl::StrJoin(pieces, "\n");
102 }
103
ForEachBinding(BindingFn fn) const104 Status DynamicParameterBinding::ForEachBinding(BindingFn fn) const {
105 for (const auto& binding : bindings_) {
106 TF_RETURN_IF_ERROR(fn(binding.second, binding.first));
107 }
108 return OkStatus();
109 }
110
Verify(const HloModule & module) const111 Status DynamicParameterBinding::Verify(const HloModule& module) const {
112 const HloComputation* entry = module.entry_computation();
113 return ForEachBinding([&](const DynamicParameter& dynamic_parameter,
114 const DynamicDimension& dynamic_dimension)
115 -> Status {
116 TF_RET_CHECK(dynamic_parameter.parameter_num >= 0 &&
117 dynamic_parameter.parameter_num < entry->num_parameters());
118 TF_RET_CHECK(dynamic_dimension.parameter_num < entry->num_parameters());
119 TF_RET_CHECK(ShapeUtil::IndexIsValid(
120 entry->parameter_instruction(dynamic_parameter.parameter_num)->shape(),
121 dynamic_parameter.parameter_index));
122 TF_RET_CHECK(ShapeUtil::IndexIsValid(
123 entry->parameter_instruction(dynamic_dimension.parameter_num)->shape(),
124 dynamic_dimension.parameter_index));
125 TF_RET_CHECK(
126 dynamic_dimension.dimension <
127 ShapeUtil::GetSubshape(
128 entry->parameter_instruction(dynamic_dimension.parameter_num)
129 ->shape(),
130 dynamic_dimension.parameter_index)
131 .rank());
132 return OkStatus();
133 });
134 }
135
operator <<(std::ostream & out,const DynamicParameterBinding & binding)136 std::ostream& operator<<(std::ostream& out,
137 const DynamicParameterBinding& binding) {
138 out << binding.ToString();
139 return out;
140 }
141
142 } // namespace xla
143