xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"

#include <atomic>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/platform/logging.h"
22 
23 #if GOOGLE_CUDA && GOOGLE_TENSORRT
24 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
25 
26 namespace tensorflow {
27 namespace tensorrt {
28 
29 // set the batch size before constructing the thread to execute engine
getBatchSize() const30 int TRTInt8Calibrator::getBatchSize() const noexcept { return batch_size_; }
31 
TRTInt8Calibrator(const std::unordered_map<string,std::pair<void *,size_t>> & dev_buffers,int batch_size,string engine_name)32 TRTInt8Calibrator::TRTInt8Calibrator(
33     const std::unordered_map<string, std::pair<void*, size_t>>& dev_buffers,
34     int batch_size, string engine_name)
35     : batch_size_(batch_size),
36       done_(false),
37       dev_buffers_(dev_buffers),
38       // Make sure setBatch() waits until getBatch() is called (the first time).
39       calib_running_(true),
40       batch_is_set_(false),
41       engine_name_(engine_name) {}
42 
TRTInt8Calibrator(const string & calib_data)43 TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data)
44     : batch_size_(0),
45       done_(true),
46       calib_running_(false),
47       batch_is_set_(false),
48       calibration_table_(calib_data) {}
49 
setBatch(const std::unordered_map<string,void * > & data,const cudaStream_t stream)50 bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
51                                  const cudaStream_t stream) {
52   mutex_lock lock(cond_mtx_);
53 
54   // Wait while the queue is full or calibration is running.
55   while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock);
56   if (done_) return false;
57   CHECK(!calib_running_ && !batch_is_set_);
58   VLOG(1) << "Set Batch Waiting finished";
59 
60   // Sets the batch.
61   for (const auto& it : data) {
62     auto devptr = dev_buffers_.find(it.first);
63     if (devptr == dev_buffers_.end()) {
64       LOG(FATAL) << "FATAL " << engine_name_ << " input name '" << it.first
65                  << "' does not match with the buffer names";
66     }
67     const auto& d = devptr->second;
68 
69     // TODO(sami,aaroey): Need to figure out a way to ensure synchronization
70     // between stream, perhaps using a tensor?
71     auto status = cudaMemcpyAsync(d.first, it.second, d.second,
72                                   cudaMemcpyDeviceToDevice, stream);
73     if (status != cudaSuccess) {
74       LOG(FATAL) << "cudaMemcpy " << engine_name_ << " for '" << it.first
75                  << "' failed with " << status;
76     }
77   }
78 
79   // TODO(Sami, aaorey): Find an alternative way!
80   // we have to wait for the stream before returning!
81   cudaStreamSynchronize(stream);
82   batch_is_set_ = true;
83   cond_.notify_all();
84   return true;
85 }
86 
getBatch(void ** bindings,const char ** names,int num_bindings)87 bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
88                                  int num_bindings) noexcept {
89   mutex_lock lock(cond_mtx_);
90   // Notify finish of last round of calibration.
91   calib_running_ = false;
92   cond_.notify_all();
93 
94   // Wait until new batch arrives
95   while ((!batch_is_set_ && !done_)) cond_.wait(lock);
96   if (done_) return false;
97 
98   // Gets the batch
99   for (int i = 0; i < num_bindings; i++) {
100     auto it = dev_buffers_.find(names[i]);
101     if (it == dev_buffers_.end()) {
102       LOG(FATAL) << "Calibration engine asked for unknown tensor name '"
103                  << names[i] << "' at position " << i;
104     }
105     bindings[i] = it->second.first;
106   }
107   batch_is_set_ = false;
108   calib_running_ = true;
109   return true;
110 }
111 
waitAndSetDone()112 void TRTInt8Calibrator::waitAndSetDone() {
113   mutex_lock lock(cond_mtx_);
114   // Wait while the queue is full or calibration is running, so we don't miss
115   // the last batch.
116   while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock);
117   if (!done_) {
118     done_ = true;
119     cond_.notify_all();
120     dev_buffers_.clear();
121   }
122 }
123 
readCalibrationCache(std::size_t & length)124 const void* TRTInt8Calibrator::readCalibrationCache(
125     std::size_t& length) noexcept {
126   if (calibration_table_.empty()) return nullptr;
127   length = calibration_table_.size();
128   return calibration_table_.data();
129 }
130 
setDone()131 void TRTInt8Calibrator::setDone() {
132   mutex_lock lock(cond_mtx_);
133   done_ = true;
134   cond_.notify_all();
135 }
136 
writeCalibrationCache(const void * ptr,std::size_t length)137 void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
138                                               std::size_t length) noexcept {
139   calibration_table_ = string(static_cast<const char*>(ptr), length);
140   VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr
141           << " length=" << length;
142 }
~TRTInt8Calibrator()143 TRTInt8Calibrator::~TRTInt8Calibrator() {
144   VLOG(1) << "Destroying calibrator for " << engine_name_;
145 }
146 
147 }  // namespace tensorrt
148 }  // namespace tensorflow
149 
150 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
151