1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
17
#include <atomic>
#include <unordered_map>
#include <utility>
20
21 #include "tensorflow/core/platform/logging.h"
22
23 #if GOOGLE_CUDA && GOOGLE_TENSORRT
24 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
25
26 namespace tensorflow {
27 namespace tensorrt {
28
29 // set the batch size before constructing the thread to execute engine
getBatchSize() const30 int TRTInt8Calibrator::getBatchSize() const noexcept { return batch_size_; }
31
TRTInt8Calibrator(const std::unordered_map<string,std::pair<void *,size_t>> & dev_buffers,int batch_size,string engine_name)32 TRTInt8Calibrator::TRTInt8Calibrator(
33 const std::unordered_map<string, std::pair<void*, size_t>>& dev_buffers,
34 int batch_size, string engine_name)
35 : batch_size_(batch_size),
36 done_(false),
37 dev_buffers_(dev_buffers),
38 // Make sure setBatch() waits until getBatch() is called (the first time).
39 calib_running_(true),
40 batch_is_set_(false),
41 engine_name_(engine_name) {}
42
TRTInt8Calibrator(const string & calib_data)43 TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data)
44 : batch_size_(0),
45 done_(true),
46 calib_running_(false),
47 batch_is_set_(false),
48 calibration_table_(calib_data) {}
49
setBatch(const std::unordered_map<string,void * > & data,const cudaStream_t stream)50 bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
51 const cudaStream_t stream) {
52 mutex_lock lock(cond_mtx_);
53
54 // Wait while the queue is full or calibration is running.
55 while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock);
56 if (done_) return false;
57 CHECK(!calib_running_ && !batch_is_set_);
58 VLOG(1) << "Set Batch Waiting finished";
59
60 // Sets the batch.
61 for (const auto& it : data) {
62 auto devptr = dev_buffers_.find(it.first);
63 if (devptr == dev_buffers_.end()) {
64 LOG(FATAL) << "FATAL " << engine_name_ << " input name '" << it.first
65 << "' does not match with the buffer names";
66 }
67 const auto& d = devptr->second;
68
69 // TODO(sami,aaroey): Need to figure out a way to ensure synchronization
70 // between stream, perhaps using a tensor?
71 auto status = cudaMemcpyAsync(d.first, it.second, d.second,
72 cudaMemcpyDeviceToDevice, stream);
73 if (status != cudaSuccess) {
74 LOG(FATAL) << "cudaMemcpy " << engine_name_ << " for '" << it.first
75 << "' failed with " << status;
76 }
77 }
78
79 // TODO(Sami, aaorey): Find an alternative way!
80 // we have to wait for the stream before returning!
81 cudaStreamSynchronize(stream);
82 batch_is_set_ = true;
83 cond_.notify_all();
84 return true;
85 }
86
getBatch(void ** bindings,const char ** names,int num_bindings)87 bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
88 int num_bindings) noexcept {
89 mutex_lock lock(cond_mtx_);
90 // Notify finish of last round of calibration.
91 calib_running_ = false;
92 cond_.notify_all();
93
94 // Wait until new batch arrives
95 while ((!batch_is_set_ && !done_)) cond_.wait(lock);
96 if (done_) return false;
97
98 // Gets the batch
99 for (int i = 0; i < num_bindings; i++) {
100 auto it = dev_buffers_.find(names[i]);
101 if (it == dev_buffers_.end()) {
102 LOG(FATAL) << "Calibration engine asked for unknown tensor name '"
103 << names[i] << "' at position " << i;
104 }
105 bindings[i] = it->second.first;
106 }
107 batch_is_set_ = false;
108 calib_running_ = true;
109 return true;
110 }
111
waitAndSetDone()112 void TRTInt8Calibrator::waitAndSetDone() {
113 mutex_lock lock(cond_mtx_);
114 // Wait while the queue is full or calibration is running, so we don't miss
115 // the last batch.
116 while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock);
117 if (!done_) {
118 done_ = true;
119 cond_.notify_all();
120 dev_buffers_.clear();
121 }
122 }
123
readCalibrationCache(std::size_t & length)124 const void* TRTInt8Calibrator::readCalibrationCache(
125 std::size_t& length) noexcept {
126 if (calibration_table_.empty()) return nullptr;
127 length = calibration_table_.size();
128 return calibration_table_.data();
129 }
130
setDone()131 void TRTInt8Calibrator::setDone() {
132 mutex_lock lock(cond_mtx_);
133 done_ = true;
134 cond_.notify_all();
135 }
136
writeCalibrationCache(const void * ptr,std::size_t length)137 void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
138 std::size_t length) noexcept {
139 calibration_table_ = string(static_cast<const char*>(ptr), length);
140 VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr
141 << " length=" << length;
142 }
~TRTInt8Calibrator()143 TRTInt8Calibrator::~TRTInt8Calibrator() {
144 VLOG(1) << "Destroying calibrator for " << engine_name_;
145 }
146
147 } // namespace tensorrt
148 } // namespace tensorflow
149
150 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
151