1 /*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #define ATRACE_TAG (ATRACE_TAG_THERMAL | ATRACE_TAG_HAL)
17
18 #include "virtualtemp_estimator.h"
19
20 #include <android-base/logging.h>
21 #include <android-base/stringprintf.h>
22 #include <dlfcn.h>
23 #include <json/reader.h>
24 #include <utils/Trace.h>
25
26 #include <cmath>
27 #include <sstream>
28 #include <vector>
29
30 namespace thermal {
31 namespace vtestimator {
32 namespace {
getFloatFromValue(const Json::Value & value)33 float getFloatFromValue(const Json::Value &value) {
34 if (value.isString()) {
35 return std::atof(value.asString().c_str());
36 } else {
37 return value.asFloat();
38 }
39 }
40
getInputRangeInfoFromJsonValues(const Json::Value & values,InputRangeInfo * input_range_info)41 bool getInputRangeInfoFromJsonValues(const Json::Value &values, InputRangeInfo *input_range_info) {
42 if (values.size() != 2) {
43 LOG(ERROR) << "Data Range Values size: " << values.size() << "is invalid.";
44 return false;
45 }
46
47 float min_val = getFloatFromValue(values[0]);
48 float max_val = getFloatFromValue(values[1]);
49
50 if (std::isnan(min_val) || std::isnan(max_val)) {
51 LOG(ERROR) << "Illegal data range: thresholds not defined properly " << min_val << " : "
52 << max_val;
53 return false;
54 }
55
56 if (min_val > max_val) {
57 LOG(ERROR) << "Illegal data range: data_min_threshold(" << min_val
58 << ") > data_max_threshold(" << max_val << ")";
59 return false;
60 }
61 input_range_info->min_threshold = min_val;
62 input_range_info->max_threshold = max_val;
63 LOG(INFO) << "Data Range Info: " << input_range_info->min_threshold
64 << " <= val <= " << input_range_info->max_threshold;
65 return true;
66 }
67
CalculateOffset(const std::vector<float> & offset_thresholds,const std::vector<float> & offset_values,const float value)68 float CalculateOffset(const std::vector<float> &offset_thresholds,
69 const std::vector<float> &offset_values, const float value) {
70 for (int i = offset_thresholds.size(); i > 0; --i) {
71 if (offset_thresholds[i - 1] < value) {
72 return offset_values[i - 1];
73 }
74 }
75
76 return 0;
77 }
78 } // namespace
79
DumpTraces()80 VtEstimatorStatus VirtualTempEstimator::DumpTraces() {
81 if (type != kUseMLModel) {
82 return kVtEstimatorUnSupported;
83 }
84
85 if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
86 LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr during DumpTraces\n";
87 return kVtEstimatorInitFailed;
88 }
89
90 std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
91
92 if (!common_instance_->is_initialized) {
93 LOG(ERROR) << "tflite_instance_ not initialized for " << common_instance_->sensor_name;
94 return kVtEstimatorInitFailed;
95 }
96
97 // get model input/output buffers
98 float *model_input = tflite_instance_->input_buffer;
99 float *model_output = tflite_instance_->output_buffer;
100 auto input_buffer_size = tflite_instance_->input_buffer_size;
101 auto output_buffer_size = tflite_instance_->output_buffer_size;
102
103 // In Case of use_prev_samples, inputs are available in order in scratch buffer
104 if (common_instance_->use_prev_samples) {
105 model_input = tflite_instance_->scratch_buffer;
106 }
107
108 // Add traces for model input/output buffers
109 std::string sensor_name = common_instance_->sensor_name;
110 for (size_t i = 0; i < input_buffer_size; ++i) {
111 ATRACE_INT((sensor_name + "_input_" + std::to_string(i)).c_str(),
112 static_cast<int>(model_input[i]));
113 }
114
115 for (size_t i = 0; i < output_buffer_size; ++i) {
116 ATRACE_INT((sensor_name + "_output_" + std::to_string(i)).c_str(),
117 static_cast<int>(model_output[i]));
118 }
119
120 // log input data and output data buffers
121 std::string input_data_str = "model_input_buffer: [";
122 for (size_t i = 0; i < input_buffer_size; ++i) {
123 input_data_str += ::android::base::StringPrintf("%0.2f ", model_input[i]);
124 }
125 input_data_str += "]";
126 LOG(INFO) << input_data_str;
127
128 std::string output_data_str = "model_output_buffer: [";
129 for (size_t i = 0; i < output_buffer_size; ++i) {
130 output_data_str += ::android::base::StringPrintf("%0.2f ", model_output[i]);
131 }
132 output_data_str += "]";
133 LOG(INFO) << output_data_str;
134
135 return kVtEstimatorOk;
136 }
137
LoadTFLiteWrapper()138 void VirtualTempEstimator::LoadTFLiteWrapper() {
139 if (!tflite_instance_) {
140 LOG(ERROR) << "tflite_instance_ is nullptr during LoadTFLiteWrapper";
141 return;
142 }
143
144 std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
145
146 void *mLibHandle = dlopen("/vendor/lib64/libthermal_tflite_wrapper.so", 0);
147 if (mLibHandle == nullptr) {
148 LOG(ERROR) << "Could not load libthermal_tflite_wrapper library with error: " << dlerror();
149 return;
150 }
151
152 tflite_instance_->tflite_methods.create =
153 reinterpret_cast<tflitewrapper_create>(dlsym(mLibHandle, "ThermalTfliteCreate"));
154 if (!tflite_instance_->tflite_methods.create) {
155 LOG(ERROR) << "Could not link and cast tflitewrapper_create with error: " << dlerror();
156 }
157
158 tflite_instance_->tflite_methods.init =
159 reinterpret_cast<tflitewrapper_init>(dlsym(mLibHandle, "ThermalTfliteInit"));
160 if (!tflite_instance_->tflite_methods.init) {
161 LOG(ERROR) << "Could not link and cast tflitewrapper_init with error: " << dlerror();
162 }
163
164 tflite_instance_->tflite_methods.invoke =
165 reinterpret_cast<tflitewrapper_invoke>(dlsym(mLibHandle, "ThermalTfliteInvoke"));
166 if (!tflite_instance_->tflite_methods.invoke) {
167 LOG(ERROR) << "Could not link and cast tflitewrapper_invoke with error: " << dlerror();
168 }
169
170 tflite_instance_->tflite_methods.destroy =
171 reinterpret_cast<tflitewrapper_destroy>(dlsym(mLibHandle, "ThermalTfliteDestroy"));
172 if (!tflite_instance_->tflite_methods.destroy) {
173 LOG(ERROR) << "Could not link and cast tflitewrapper_destroy with error: " << dlerror();
174 }
175
176 tflite_instance_->tflite_methods.get_input_config_size =
177 reinterpret_cast<tflitewrapper_get_input_config_size>(
178 dlsym(mLibHandle, "ThermalTfliteGetInputConfigSize"));
179 if (!tflite_instance_->tflite_methods.get_input_config_size) {
180 LOG(ERROR) << "Could not link and cast tflitewrapper_get_input_config_size with error: "
181 << dlerror();
182 }
183
184 tflite_instance_->tflite_methods.get_input_config =
185 reinterpret_cast<tflitewrapper_get_input_config>(
186 dlsym(mLibHandle, "ThermalTfliteGetInputConfig"));
187 if (!tflite_instance_->tflite_methods.get_input_config) {
188 LOG(ERROR) << "Could not link and cast tflitewrapper_get_input_config with error: "
189 << dlerror();
190 }
191 }
192
VirtualTempEstimator(std::string_view sensor_name,VtEstimationType estimationType,size_t num_linked_sensors)193 VirtualTempEstimator::VirtualTempEstimator(std::string_view sensor_name,
194 VtEstimationType estimationType,
195 size_t num_linked_sensors) {
196 type = estimationType;
197
198 common_instance_ = std::make_unique<VtEstimatorCommonData>(sensor_name, num_linked_sensors);
199 if (estimationType == kUseMLModel) {
200 tflite_instance_ = std::make_unique<VtEstimatorTFLiteData>();
201 LoadTFLiteWrapper();
202 } else if (estimationType == kUseLinearModel) {
203 linear_model_instance_ = std::make_unique<VtEstimatorLinearModelData>();
204 } else {
205 LOG(ERROR) << "Unsupported estimationType [" << estimationType << "]";
206 }
207 }
208
// Logs destruction; members are unique_ptrs and clean up automatically.
VirtualTempEstimator::~VirtualTempEstimator() {
    LOG(INFO) << "VirtualTempEstimator destructor";
}
212
LinearModelInitialize(LinearModelInitData data)213 VtEstimatorStatus VirtualTempEstimator::LinearModelInitialize(LinearModelInitData data) {
214 if (linear_model_instance_ == nullptr || common_instance_ == nullptr) {
215 LOG(ERROR) << "linear_model_instance_ or common_instance_ is nullptr during Initialize";
216 return kVtEstimatorInitFailed;
217 }
218
219 size_t num_linked_sensors = common_instance_->num_linked_sensors;
220 std::unique_lock<std::mutex> lock(linear_model_instance_->mutex);
221
222 if ((num_linked_sensors == 0) || (data.coefficients.size() == 0) ||
223 (data.prev_samples_order == 0)) {
224 LOG(ERROR) << "Invalid num_linked_sensors [" << num_linked_sensors
225 << "] or coefficients.size() [" << data.coefficients.size()
226 << "] or prev_samples_order [" << data.prev_samples_order << "]";
227 return kVtEstimatorInitFailed;
228 }
229
230 if (data.coefficients.size() != (num_linked_sensors * data.prev_samples_order)) {
231 LOG(ERROR) << "In valid args coefficients.size()[" << data.coefficients.size()
232 << "] num_linked_sensors [" << num_linked_sensors << "] prev_samples_order["
233 << data.prev_samples_order << "]";
234 return kVtEstimatorInvalidArgs;
235 }
236
237 common_instance_->use_prev_samples = data.use_prev_samples;
238 common_instance_->prev_samples_order = data.prev_samples_order;
239
240 linear_model_instance_->input_samples.reserve(common_instance_->prev_samples_order);
241 linear_model_instance_->coefficients.reserve(common_instance_->prev_samples_order);
242
243 // Store coefficients
244 for (size_t i = 0; i < data.prev_samples_order; ++i) {
245 std::vector<float> single_order_coefficients;
246 for (size_t j = 0; j < num_linked_sensors; ++j) {
247 single_order_coefficients.emplace_back(data.coefficients[i * num_linked_sensors + j]);
248 }
249 linear_model_instance_->coefficients.emplace_back(single_order_coefficients);
250 }
251
252 common_instance_->offset_thresholds = data.offset_thresholds;
253 common_instance_->offset_values = data.offset_values;
254 common_instance_->is_initialized = true;
255
256 return kVtEstimatorOk;
257 }
258
// Initializes the TFLite-backed estimator: validates the configuration,
// allocates the input/output buffers, creates and initializes the wrapper
// via the dlopen'ed methods, then parses the model's embedded input config
// (sample intervals and per-sensor valid ranges).
// Returns kVtEstimatorOk on success; any failure status means the instance
// is not usable for estimation.
VtEstimatorStatus VirtualTempEstimator::TFliteInitialize(MLModelInitData data) {
    if (!tflite_instance_ || !common_instance_) {
        LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr during Initialize\n";
        return kVtEstimatorInitFailed;
    }

    std::string_view sensor_name = common_instance_->sensor_name;
    size_t num_linked_sensors = common_instance_->num_linked_sensors;
    bool use_prev_samples = data.use_prev_samples;
    size_t prev_samples_order = data.prev_samples_order;
    size_t num_hot_spots = data.num_hot_spots;
    size_t output_label_count = data.output_label_count;

    std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);

    if (data.model_path.empty()) {
        LOG(ERROR) << "Invalid model_path:" << data.model_path << " for " << sensor_name;
        return kVtEstimatorInvalidArgs;
    }

    // prev_samples_order > 1 is only meaningful when previous samples are used.
    if (num_linked_sensors == 0 || prev_samples_order < 1 ||
        (!use_prev_samples && prev_samples_order > 1)) {
        LOG(ERROR) << "Invalid tflite_instance_ config: " << "number of linked sensor: "
                   << num_linked_sensors << " use previous: " << use_prev_samples
                   << " previous sample order: " << prev_samples_order << " for " << sensor_name;
        return kVtEstimatorInitFailed;
    }

    common_instance_->use_prev_samples = data.use_prev_samples;
    common_instance_->prev_samples_order = prev_samples_order;
    tflite_instance_->support_under_sampling = data.support_under_sampling;
    tflite_instance_->enable_input_validation = data.enable_input_validation;
    // The input buffer holds prev_samples_order samples of num_linked_sensors
    // values each (used as a ring buffer by TFliteEstimate).
    tflite_instance_->input_buffer_size = num_linked_sensors * prev_samples_order;
    tflite_instance_->input_buffer = new float[tflite_instance_->input_buffer_size];
    if (common_instance_->use_prev_samples) {
        // Scratch buffer linearizes the ring-ordered samples before invoking
        // the model (see TFliteEstimate / DumpTraces).
        tflite_instance_->scratch_buffer = new float[tflite_instance_->input_buffer_size];
    }

    if (output_label_count < 1 || num_hot_spots < 1) {
        LOG(ERROR) << "Invalid tflite_instance_ config:" << "number of hot spots: " << num_hot_spots
                   << " predicted sample order: " << output_label_count << " for " << sensor_name;
        return kVtEstimatorInitFailed;
    }

    tflite_instance_->output_label_count = output_label_count;
    tflite_instance_->num_hot_spots = num_hot_spots;
    // One predicted time-series (output_label_count steps) per hot spot.
    tflite_instance_->output_buffer_size = output_label_count * num_hot_spots;
    tflite_instance_->output_buffer = new float[tflite_instance_->output_buffer_size];

    // All wrapper entry points must have been resolved by LoadTFLiteWrapper().
    if (!tflite_instance_->tflite_methods.create || !tflite_instance_->tflite_methods.init ||
        !tflite_instance_->tflite_methods.invoke || !tflite_instance_->tflite_methods.destroy ||
        !tflite_instance_->tflite_methods.get_input_config_size ||
        !tflite_instance_->tflite_methods.get_input_config) {
        LOG(ERROR) << "Invalid tflite methods for " << sensor_name;
        return kVtEstimatorInitFailed;
    }

    tflite_instance_->tflite_wrapper =
            tflite_instance_->tflite_methods.create(kNumInputTensors, kNumOutputTensors);
    if (!tflite_instance_->tflite_wrapper) {
        LOG(ERROR) << "Failed to create tflite wrapper for " << sensor_name;
        return kVtEstimatorInitFailed;
    }

    int ret = tflite_instance_->tflite_methods.init(tflite_instance_->tflite_wrapper,
                                                    data.model_path.c_str());
    if (ret) {
        LOG(ERROR) << "Failed to Init tflite_wrapper for " << sensor_name << " (ret: " << ret
                   << ")";
        return kVtEstimatorInitFailed;
    }

    // Retrieve and parse the JSON input config embedded in the model.
    Json::Value input_config;
    if (!GetInputConfig(&input_config)) {
        LOG(ERROR) << "Get Input Config failed for " << sensor_name;
        return kVtEstimatorInitFailed;
    }

    if (!ParseInputConfig(input_config)) {
        LOG(ERROR) << "Parse Input Config failed for " << sensor_name;
        return kVtEstimatorInitFailed;
    }

    // Input validation needs per-sensor ranges from the model config.
    if (tflite_instance_->enable_input_validation && !tflite_instance_->input_range.size()) {
        LOG(ERROR) << "Input ranges missing when input data validation is enabled for "
                   << sensor_name;
        return kVtEstimatorInitFailed;
    }

    common_instance_->offset_thresholds = data.offset_thresholds;
    common_instance_->offset_values = data.offset_values;
    tflite_instance_->model_path = data.model_path;

    common_instance_->is_initialized = true;
    LOG(INFO) << "Successfully initialized VirtualTempEstimator for " << sensor_name;
    return kVtEstimatorOk;
}
356
LinearModelEstimate(const std::vector<float> & thermistors,std::vector<float> * output)357 VtEstimatorStatus VirtualTempEstimator::LinearModelEstimate(const std::vector<float> &thermistors,
358 std::vector<float> *output) {
359 if (linear_model_instance_ == nullptr || common_instance_ == nullptr) {
360 LOG(ERROR) << "linear_model_instance_ or common_instance_ is nullptr during Initialize";
361 return kVtEstimatorInitFailed;
362 }
363
364 std::string_view sensor_name = common_instance_->sensor_name;
365 size_t prev_samples_order = common_instance_->prev_samples_order;
366 size_t num_linked_sensors = common_instance_->num_linked_sensors;
367
368 std::unique_lock<std::mutex> lock(linear_model_instance_->mutex);
369
370 if ((thermistors.size() != num_linked_sensors) || (output == nullptr)) {
371 LOG(ERROR) << "Invalid args Thermistors size[" << thermistors.size()
372 << "] num_linked_sensors[" << num_linked_sensors << "] output[" << output << "]"
373 << " for " << sensor_name;
374 return kVtEstimatorInvalidArgs;
375 }
376
377 if (common_instance_->is_initialized == false) {
378 LOG(ERROR) << "tflite_instance_ not initialized for " << sensor_name;
379 return kVtEstimatorInitFailed;
380 }
381
382 // For the first iteration copy current inputs to all previous inputs
383 // This would allow the estimator to have previous samples from the first iteration itself
384 // and provide a valid predicted value
385 if (common_instance_->cur_sample_count == 0) {
386 for (size_t i = 0; i < prev_samples_order; ++i) {
387 linear_model_instance_->input_samples[i] = thermistors;
388 }
389 }
390
391 size_t cur_sample_index = common_instance_->cur_sample_count % prev_samples_order;
392 linear_model_instance_->input_samples[cur_sample_index] = thermistors;
393
394 // Calculate Weighted Average Value
395 int input_level = cur_sample_index;
396 float estimated_value = 0;
397 for (size_t i = 0; i < prev_samples_order; ++i) {
398 for (size_t j = 0; j < num_linked_sensors; ++j) {
399 estimated_value += linear_model_instance_->coefficients[i][j] *
400 linear_model_instance_->input_samples[input_level][j];
401 }
402 input_level--; // go to previous samples
403 input_level = (input_level >= 0) ? input_level : (prev_samples_order - 1);
404 }
405
406 // Update sample count
407 common_instance_->cur_sample_count++;
408
409 // add offset to estimated value if applicable
410 estimated_value += CalculateOffset(common_instance_->offset_thresholds,
411 common_instance_->offset_values, estimated_value);
412
413 std::vector<float> data = {estimated_value};
414 *output = data;
415 return kVtEstimatorOk;
416 }
417
// Runs one estimation step through the TFLite model.
// Copies |thermistors| into the ring-ordered input buffer (optionally
// validating each value against the model's configured ranges), linearizes
// samples into the scratch buffer when previous samples are used, invokes
// the model, applies offsets, and returns all predictions in |*output|.
// May return kVtEstimatorUnderSampling until enough samples accumulate, or
// kVtEstimatorLowConfidence when an input is out of its valid range.
VtEstimatorStatus VirtualTempEstimator::TFliteEstimate(const std::vector<float> &thermistors,
                                                       std::vector<float> *output) {
    if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
        LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr during Estimate\n";
        return kVtEstimatorInitFailed;
    }

    std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);

    if (!common_instance_->is_initialized) {
        LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
        return kVtEstimatorInitFailed;
    }

    std::string_view sensor_name = common_instance_->sensor_name;
    size_t num_linked_sensors = common_instance_->num_linked_sensors;
    if ((thermistors.size() != num_linked_sensors) || (output == nullptr)) {
        LOG(ERROR) << "Invalid args for " << sensor_name
                   << " thermistors.size(): " << thermistors.size()
                   << " num_linked_sensors: " << num_linked_sensors << " output: " << output;
        return kVtEstimatorInvalidArgs;
    }

    // log input data
    std::string input_data_str = "model_input: [";
    for (size_t i = 0; i < num_linked_sensors; ++i) {
        input_data_str += ::android::base::StringPrintf("%0.2f ", thermistors[i]);
    }
    input_data_str += "]";
    LOG(INFO) << sensor_name << ": " << input_data_str;

    // check time gap between samples and ignore stale previous samples
    // (resetting cur_sample_count restarts accumulation from scratch)
    if (std::chrono::duration_cast<std::chrono::milliseconds>(boot_clock::now() -
                                                              tflite_instance_->prev_sample_time) >=
        tflite_instance_->max_sample_interval) {
        LOG(INFO) << "Ignoring stale previous samples for " << sensor_name;
        common_instance_->cur_sample_count = 0;
    }

    // copy input data into input tensors; the input buffer is a ring of
    // prev_samples_order slots, each num_linked_sensors wide
    size_t prev_samples_order = common_instance_->prev_samples_order;
    size_t cur_sample_index = common_instance_->cur_sample_count % prev_samples_order;
    size_t sample_start_index = cur_sample_index * num_linked_sensors;
    for (size_t i = 0; i < num_linked_sensors; ++i) {
        if (tflite_instance_->enable_input_validation) {
            // Out-of-range input: reset accumulation and report low confidence.
            if (thermistors[i] < tflite_instance_->input_range[i].min_threshold ||
                thermistors[i] > tflite_instance_->input_range[i].max_threshold) {
                LOG(INFO) << "thermistors[" << i << "] value: " << thermistors[i]
                          << " not in range: " << tflite_instance_->input_range[i].min_threshold
                          << " <= val <= " << tflite_instance_->input_range[i].max_threshold
                          << " for " << sensor_name;
                common_instance_->cur_sample_count = 0;
                return kVtEstimatorLowConfidence;
            }
        }
        tflite_instance_->input_buffer[sample_start_index + i] = thermistors[i];
        if (cur_sample_index == 0 && tflite_instance_->support_under_sampling) {
            // fill previous samples if support under sampling
            for (size_t j = 1; j < prev_samples_order; ++j) {
                size_t copy_start_index = j * num_linked_sensors;
                tflite_instance_->input_buffer[copy_start_index + i] = thermistors[i];
            }
        }
    }

    // Update sample count
    common_instance_->cur_sample_count++;
    tflite_instance_->prev_sample_time = boot_clock::now();
    if ((common_instance_->cur_sample_count < prev_samples_order) &&
        !(tflite_instance_->support_under_sampling)) {
        return kVtEstimatorUnderSampling;
    }

    // prepare model input: when previous samples are used, rotate the ring
    // buffer into the scratch buffer so samples are in chronological order
    float *model_input;
    size_t input_buffer_size = tflite_instance_->input_buffer_size;
    size_t output_buffer_size = tflite_instance_->output_buffer_size;
    if (!common_instance_->use_prev_samples) {
        model_input = tflite_instance_->input_buffer;
    } else {
        sample_start_index = ((cur_sample_index + 1) * num_linked_sensors) % input_buffer_size;
        for (size_t i = 0; i < input_buffer_size; ++i) {
            size_t input_index = (sample_start_index + i) % input_buffer_size;
            tflite_instance_->scratch_buffer[i] = tflite_instance_->input_buffer[input_index];
        }
        model_input = tflite_instance_->scratch_buffer;
    }

    int ret = tflite_instance_->tflite_methods.invoke(
            tflite_instance_->tflite_wrapper, model_input, input_buffer_size,
            tflite_instance_->output_buffer, output_buffer_size);
    if (ret) {
        LOG(ERROR) << "Failed to Invoke for " << sensor_name << " (ret: " << ret << ")";
        return kVtEstimatorInvokeFailed;
    }
    tflite_instance_->last_update_time = boot_clock::now();

    // prepare output: apply the configured offset to each predicted value
    std::vector<float> data;
    std::ostringstream model_out_log, predict_log;
    data.reserve(output_buffer_size);
    for (size_t i = 0; i < output_buffer_size; ++i) {
        // add offset to predicted value
        float predicted_value = tflite_instance_->output_buffer[i];
        model_out_log << predicted_value << " ";
        predicted_value += CalculateOffset(common_instance_->offset_thresholds,
                                           common_instance_->offset_values, predicted_value);
        predict_log << predicted_value << " ";
        data.emplace_back(predicted_value);
    }
    LOG(INFO) << sensor_name << ": model_output: [" << model_out_log.str() << "]";
    LOG(INFO) << sensor_name << ": predicted_value: [" << predict_log.str() << "]";
    *output = data;

    return kVtEstimatorOk;
}
534
Estimate(const std::vector<float> & thermistors,std::vector<float> * output)535 VtEstimatorStatus VirtualTempEstimator::Estimate(const std::vector<float> &thermistors,
536 std::vector<float> *output) {
537 if (type == kUseMLModel) {
538 return TFliteEstimate(thermistors, output);
539 } else if (type == kUseLinearModel) {
540 return LinearModelEstimate(thermistors, output);
541 }
542
543 LOG(ERROR) << "Unsupported estimationType [" << type << "]";
544 return kVtEstimatorUnSupported;
545 }
546
TFliteGetMaxPredictWindowMs(size_t * predict_window_ms)547 VtEstimatorStatus VirtualTempEstimator::TFliteGetMaxPredictWindowMs(size_t *predict_window_ms) {
548 if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
549 LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr for predict window\n";
550 return kVtEstimatorInitFailed;
551 }
552
553 if (!common_instance_->is_initialized) {
554 LOG(ERROR) << "tflite_instance_ not initialized for " << common_instance_->sensor_name;
555 return kVtEstimatorInitFailed;
556 }
557
558 size_t window = tflite_instance_->predict_window_ms;
559 if (window == 0) {
560 return kVtEstimatorUnSupported;
561 }
562 *predict_window_ms = window;
563 return kVtEstimatorOk;
564 }
565
GetMaxPredictWindowMs(size_t * predict_window_ms)566 VtEstimatorStatus VirtualTempEstimator::GetMaxPredictWindowMs(size_t *predict_window_ms) {
567 if (type == kUseMLModel) {
568 return TFliteGetMaxPredictWindowMs(predict_window_ms);
569 }
570
571 LOG(ERROR) << "Unsupported estimationType [" << type << "]";
572 return kVtEstimatorUnSupported;
573 }
574
// Predicts the temperature |time_ms| milliseconds from now by linearly
// interpolating between the model's buffered per-step predictions.
// The effective horizon is |time_ms| plus the elapsed time since the last
// model invocation; it must fall within the model's predict window.
// Returns kVtEstimatorUnderSampling before enough samples have accumulated
// and kVtEstimatorUnSupported when the request exceeds the window.
VtEstimatorStatus VirtualTempEstimator::TFlitePredictAfterTimeMs(const size_t time_ms,
                                                                 float *output) {
    if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
        LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr for predict window\n";
        return kVtEstimatorInitFailed;
    }

    if (!common_instance_->is_initialized) {
        LOG(ERROR) << "tflite_instance_ not initialized for " << common_instance_->sensor_name;
        return kVtEstimatorInitFailed;
    }

    std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);

    size_t window = tflite_instance_->predict_window_ms;
    auto sample_interval = tflite_instance_->sample_interval;
    auto last_update_time = tflite_instance_->last_update_time;
    // Time elapsed since the output buffer was last refreshed by an invoke.
    auto request_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(boot_clock::now() -
                                                                                 last_update_time);
    // check for under sampling
    if ((common_instance_->cur_sample_count < common_instance_->prev_samples_order) &&
        !(tflite_instance_->support_under_sampling)) {
        LOG(INFO) << tflite_instance_->model_path
                  << " cannot provide prediction while under sampling";
        return kVtEstimatorUnderSampling;
    }

    // calculate requested time since last update
    request_time_ms = request_time_ms + std::chrono::milliseconds{time_ms};
    if (sample_interval.count() == 0 || window == 0 ||
        window < static_cast<size_t>(request_time_ms.count())) {
        LOG(INFO) << tflite_instance_->model_path << " cannot predict temperature after ("
                  << time_ms << " + " << request_time_ms.count() - time_ms
                  << ") ms since last update with sample interval [" << sample_interval.count()
                  << "] ms and predict window [" << window << "] ms";
        return kVtEstimatorUnSupported;
    }

    // request_step indexes the per-step predictions; the window check above
    // guarantees request_step <= output_label_count - 1.
    size_t request_step = request_time_ms / sample_interval;
    size_t output_label_count = tflite_instance_->output_label_count;
    float *output_buffer = tflite_instance_->output_buffer;
    float prediction;
    if (request_step == output_label_count - 1) {
        // request prediction is on the right boundary of the window
        prediction = output_buffer[output_label_count - 1];
    } else {
        // Linear interpolation between the two adjacent predicted steps.
        float left = output_buffer[request_step], right = output_buffer[request_step + 1];
        prediction = left;
        if (left != right) {
            prediction += (request_time_ms - sample_interval * request_step) * (right - left) /
                          sample_interval;
        }
    }

    *output = prediction;

    return kVtEstimatorOk;
}
633
PredictAfterTimeMs(const size_t time_ms,float * output)634 VtEstimatorStatus VirtualTempEstimator::PredictAfterTimeMs(const size_t time_ms, float *output) {
635 if (type == kUseMLModel) {
636 return TFlitePredictAfterTimeMs(time_ms, output);
637 }
638
639 LOG(ERROR) << "PredictAfterTimeMs not supported for type [" << type << "]";
640 return kVtEstimatorUnSupported;
641 }
642
TFliteGetAllPredictions(std::vector<float> * output)643 VtEstimatorStatus VirtualTempEstimator::TFliteGetAllPredictions(std::vector<float> *output) {
644 if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
645 LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr for predict window\n";
646 return kVtEstimatorInitFailed;
647 }
648
649 std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
650
651 if (!common_instance_->is_initialized) {
652 LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
653 return kVtEstimatorInitFailed;
654 }
655
656 if (output == nullptr) {
657 LOG(ERROR) << "output is nullptr";
658 return kVtEstimatorInvalidArgs;
659 }
660
661 std::vector<float> tflite_output;
662 size_t output_buffer_size = tflite_instance_->output_buffer_size;
663 tflite_output.reserve(output_buffer_size);
664 for (size_t i = 0; i < output_buffer_size; ++i) {
665 tflite_output.emplace_back(tflite_instance_->output_buffer[i]);
666 }
667 *output = tflite_output;
668
669 return kVtEstimatorOk;
670 }
671
GetAllPredictions(std::vector<float> * output)672 VtEstimatorStatus VirtualTempEstimator::GetAllPredictions(std::vector<float> *output) {
673 if (type == kUseMLModel) {
674 return TFliteGetAllPredictions(output);
675 }
676
677 LOG(INFO) << "GetAllPredicts not supported by estimationType [" << type << "]";
678 return kVtEstimatorUnSupported;
679 }
680
TFLiteDumpStatus(std::string_view sensor_name,std::ostringstream * dump_buf)681 VtEstimatorStatus VirtualTempEstimator::TFLiteDumpStatus(std::string_view sensor_name,
682 std::ostringstream *dump_buf) {
683 if (dump_buf == nullptr) {
684 LOG(ERROR) << "dump_buf is nullptr for " << sensor_name;
685 return kVtEstimatorInvalidArgs;
686 }
687
688 if (!common_instance_->is_initialized) {
689 LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
690 return kVtEstimatorInitFailed;
691 }
692
693 std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
694
695 *dump_buf << " Sensor Name: " << sensor_name << std::endl;
696 *dump_buf << " Current Values: ";
697 size_t output_buffer_size = tflite_instance_->output_buffer_size;
698 for (size_t i = 0; i < output_buffer_size; ++i) {
699 // add offset to predicted value
700 float predicted_value = tflite_instance_->output_buffer[i];
701 predicted_value += CalculateOffset(common_instance_->offset_thresholds,
702 common_instance_->offset_values, predicted_value);
703 *dump_buf << predicted_value << ", ";
704 }
705 *dump_buf << std::endl;
706
707 *dump_buf << " Model Path: \"" << tflite_instance_->model_path << "\"" << std::endl;
708
709 return kVtEstimatorOk;
710 }
711
DumpStatus(std::string_view sensor_name,std::ostringstream * dump_buff)712 VtEstimatorStatus VirtualTempEstimator::DumpStatus(std::string_view sensor_name,
713 std::ostringstream *dump_buff) {
714 if (type == kUseMLModel) {
715 return TFLiteDumpStatus(sensor_name, dump_buff);
716 }
717
718 LOG(INFO) << "DumpStatus not supported by estimationType [" << type << "]";
719 return kVtEstimatorUnSupported;
720 }
721
Initialize(const VtEstimationInitData & data)722 VtEstimatorStatus VirtualTempEstimator::Initialize(const VtEstimationInitData &data) {
723 LOG(INFO) << "Initialize VirtualTempEstimator for " << type;
724
725 if (type == kUseMLModel) {
726 return TFliteInitialize(data.ml_model_init_data);
727 } else if (type == kUseLinearModel) {
728 return LinearModelInitialize(data.linear_model_init_data);
729 }
730
731 LOG(ERROR) << "Unsupported estimationType [" << type << "]";
732 return kVtEstimatorUnSupported;
733 }
734
// Parses the model's embedded JSON input config. Recognized sections:
//   "ModelConfig": optional "sample_interval_ms" (also used to derive the
//     prediction window) and optional "max_sample_interval_ms" (staleness
//     cutoff for previous samples).
//   "InputData": per-sensor entries with a "Name" and a "Range" [min, max],
//     one per linked sensor, used for input validation.
// Returns false on any malformed or inconsistent field.
bool VirtualTempEstimator::ParseInputConfig(const Json::Value &input_config) {
    if (!input_config["ModelConfig"].empty()) {
        if (!input_config["ModelConfig"]["sample_interval_ms"].empty()) {
            // read input sample interval
            int sample_interval_ms = input_config["ModelConfig"]["sample_interval_ms"].asInt();
            if (sample_interval_ms <= 0) {
                LOG(ERROR) << "Invalid sample_interval_ms: " << sample_interval_ms;
                return false;
            }

            tflite_instance_->sample_interval = std::chrono::milliseconds{sample_interval_ms};
            LOG(INFO) << "Parsed tflite model input sample_interval: " << sample_interval_ms
                      << " for " << common_instance_->sensor_name;

            // determine predict window: one interval per predicted step beyond
            // the first label
            tflite_instance_->predict_window_ms =
                    sample_interval_ms * (tflite_instance_->output_label_count - 1);
            LOG(INFO) << "Max prediction window size: " << tflite_instance_->predict_window_ms
                      << " ms for " << common_instance_->sensor_name;
        }

        if (!input_config["ModelConfig"]["max_sample_interval_ms"].empty()) {
            // read input max sample interval
            int max_sample_interval_ms =
                    input_config["ModelConfig"]["max_sample_interval_ms"].asInt();
            if (max_sample_interval_ms <= 0) {
                LOG(ERROR) << "Invalid max_sample_interval_ms " << max_sample_interval_ms;
                return false;
            }

            tflite_instance_->max_sample_interval =
                    std::chrono::milliseconds{max_sample_interval_ms};
            LOG(INFO) << "Parsed tflite model max_sample_interval: " << max_sample_interval_ms
                      << " for " << common_instance_->sensor_name;
        }
    }

    if (!input_config["InputData"].empty()) {
        Json::Value input_data = input_config["InputData"];
        // One range entry is required per linked sensor.
        if (input_data.size() != common_instance_->num_linked_sensors) {
            LOG(ERROR) << "Input ranges size: " << input_data.size()
                       << " does not match num_linked_sensors: "
                       << common_instance_->num_linked_sensors;
            return false;
        }

        LOG(INFO) << "Start to parse tflite model input config for "
                  << common_instance_->num_linked_sensors;
        tflite_instance_->input_range.assign(input_data.size(), InputRangeInfo());
        for (Json::Value::ArrayIndex i = 0; i < input_data.size(); ++i) {
            const std::string &name = input_data[i]["Name"].asString();
            LOG(INFO) << "Sensor[" << i << "] Name: " << name;
            if (!getInputRangeInfoFromJsonValues(input_data[i]["Range"],
                                                 &tflite_instance_->input_range[i])) {
                LOG(ERROR) << "Failed to parse tflite model temp range for sensor: [" << name
                           << "]";
                return false;
            }
        }
    }

    return true;
}
798
GetInputConfig(Json::Value * config)799 bool VirtualTempEstimator::GetInputConfig(Json::Value *config) {
800 int config_size = 0;
801 int ret = tflite_instance_->tflite_methods.get_input_config_size(
802 tflite_instance_->tflite_wrapper, &config_size);
803 if (ret || config_size <= 0) {
804 LOG(ERROR) << "Failed to get tflite input config size (ret: " << ret
805 << ") with size: " << config_size;
806 return false;
807 }
808
809 LOG(INFO) << "Model input config_size: " << config_size << " for "
810 << common_instance_->sensor_name;
811
812 char *config_str = new char[config_size];
813 ret = tflite_instance_->tflite_methods.get_input_config(tflite_instance_->tflite_wrapper,
814 config_str, config_size);
815 if (ret) {
816 LOG(ERROR) << "Failed to get tflite input config (ret: " << ret << ")";
817 delete[] config_str;
818 return false;
819 }
820
821 Json::CharReaderBuilder builder;
822 std::unique_ptr<Json::CharReader> reader(builder.newCharReader());
823 std::string errorMessage;
824
825 bool success = true;
826 if (!reader->parse(config_str, config_str + config_size, config, &errorMessage)) {
827 LOG(ERROR) << "Failed to parse tflite JSON input config: " << errorMessage;
828 success = false;
829 }
830 delete[] config_str;
831 return success;
832 }
833
834 } // namespace vtestimator
835 } // namespace thermal
836