1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2tensorrt/segment/union_find.h"
17
18 #include "absl/strings/str_format.h"
19 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
20
21 #if GOOGLE_CUDA && GOOGLE_TENSORRT
22
23 namespace tensorflow {
24 namespace tensorrt {
25 namespace segment {
26
namespace {

// Returns true when `a` and `b` can describe the same value: at least one of
// them is unknown (std::nullopt), or both hold values that compare equal.
template <typename T>
inline bool CheckIfCompatible(const std::optional<T>& a,
                              const std::optional<T>& b) {
  return !a.has_value() || !b.has_value() || *a == *b;
}

// Makes `a` and `b` hold the same value when they are compatible (see
// CheckIfCompatible): a known value on either side is copied to the other
// side, with `a` taking precedence. Returns true on success.
// Returns false and leaves both arguments untouched when both hold different
// values.
// NOTE(review): not referenced elsewhere in this file.
template <typename T>
inline bool UnifyValues(std::optional<T>& a, std::optional<T>& b) {
  // Refuse to unify two different known values. Unconditionally copying `a`
  // into `b` would silently drop `b`'s value while still reporting success.
  if (!CheckIfCompatible(a, b)) {
    return false;
  }
  if (a.has_value()) {
    b = a;
  } else {
    a = b;
  }
  return true;
}

// Returns the merge of two compatible optionals: `a` if it holds a value,
// otherwise `b`. Incompatible inputs are a programming error and DCHECK-fail.
// In builds where the DCHECK is compiled out, `a` wins.
template <typename T>
inline std::optional<T> MergeCompatible(const std::optional<T>& a,
                                        const std::optional<T>& b) {
  DCHECK(CheckIfCompatible(a, b));
  return a.has_value() ? a : b;
}

}  // namespace
55
ClusterBatchSize()56 ClusterBatchSize::ClusterBatchSize()
57 : batch_size_(std::nullopt), max_batch_size_(std::nullopt) {}
58
operator ==(const ClusterBatchSize & other)59 bool ClusterBatchSize::operator==(const ClusterBatchSize& other) {
60 return batch_size_ == other.batch_size_ &&
61 max_batch_size_ == other.max_batch_size_;
62 }
63
SetBatchSize(int batch_size)64 ClusterBatchSize& ClusterBatchSize::SetBatchSize(int batch_size) {
65 SetBatchSize(static_cast<std::optional<int>>(batch_size));
66 return *this;
67 }
68
SetBatchSize(const std::optional<int> & batch_size)69 ClusterBatchSize& ClusterBatchSize::SetBatchSize(
70 const std::optional<int>& batch_size) {
71 batch_size_ = MergeCompatible<int>(batch_size_, batch_size);
72 if (batch_size_.has_value() && batch_size_.value() >= 0) {
73 SetMaxBatchSize(batch_size_);
74 }
75 return *this;
76 }
77
HasBatchSize() const78 bool ClusterBatchSize::HasBatchSize() const { return batch_size_.has_value(); }
79
GetBatchSize() const80 int ClusterBatchSize::GetBatchSize() const {
81 DCHECK(HasBatchSize());
82 return batch_size_.value();
83 }
84
SetMaxBatchSize(int max_batch_size)85 ClusterBatchSize& ClusterBatchSize::SetMaxBatchSize(int max_batch_size) {
86 SetBatchSize(static_cast<std::optional<int>>(max_batch_size));
87 return *this;
88 }
89
SetMaxBatchSize(const std::optional<int> & max_batch_size)90 ClusterBatchSize& ClusterBatchSize::SetMaxBatchSize(
91 const std::optional<int>& max_batch_size) {
92 max_batch_size_ = MergeCompatible<int>(max_batch_size_, max_batch_size);
93 return *this;
94 }
95
GetOptionalMaxBatchSize() const96 std::optional<int> ClusterBatchSize::GetOptionalMaxBatchSize() const {
97 return max_batch_size_;
98 }
99
MergeIfCompatible(const ClusterBatchSize & other)100 bool ClusterBatchSize::MergeIfCompatible(const ClusterBatchSize& other) {
101 if (!CheckIfCompatible(batch_size_, other.batch_size_) ||
102 !CheckIfCompatible(max_batch_size_, other.max_batch_size_)) {
103 return false;
104 }
105
106 SetBatchSize(other.batch_size_);
107 SetMaxBatchSize(other.max_batch_size_);
108 return true;
109 }
110
ToString() const111 string ClusterBatchSize::ToString() const {
112 string s;
113 const auto append_optional_num = [&](const std::optional<int>& num) {
114 if (num.has_value()) {
115 absl::StrAppendFormat(&s, "%d", num.value());
116 } else {
117 absl::StrAppendFormat(&s, "?");
118 }
119 };
120 absl::StrAppendFormat(&s, "batch_size=");
121 append_optional_num(batch_size_);
122 absl::StrAppendFormat(&s, ", max_batch_size=");
123 append_optional_num(max_batch_size_);
124 return s;
125 }
126
ClusterProperty(const ClusterBatchSize & batch_size,const DeviceNameUtils::ParsedName & device_name)127 ClusterProperty::ClusterProperty(const ClusterBatchSize& batch_size,
128 const DeviceNameUtils::ParsedName& device_name)
129 : batch_size_(batch_size), device_name_(device_name) {}
130
Merge(const ClusterProperty & other)131 Status ClusterProperty::Merge(const ClusterProperty& other) {
132 ClusterBatchSize merged_batch_size(batch_size_);
133 if (!merged_batch_size.MergeIfCompatible(other.batch_size_)) {
134 return errors::Internal(
135 "trying to merge clusters with incompatible batch sizes.");
136 }
137
138 std::optional<DeviceNameUtils::ParsedName> merged_device_name =
139 MergeIfCompatible(device_name_, other.device_name_);
140 if (!merged_device_name.has_value()) {
141 return errors::Internal(
142 "trying to merge clusters with incompatible device assignment.");
143 }
144
145 batch_size_ = std::move(merged_batch_size);
146 device_name_ = std::move(merged_device_name.value());
147 return Status::OK();
148 }
149
150 } // namespace segment
151 } // namespace tensorrt
152 } // namespace tensorflow
153
154 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
155