1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #define ATRACE_TAG (ATRACE_TAG_THERMAL | ATRACE_TAG_HAL)
17 
18 #include "virtualtemp_estimator.h"
19 
#include <android-base/logging.h>
#include <android-base/stringprintf.h>
#include <dlfcn.h>
#include <json/reader.h>
#include <utils/Trace.h>

#include <algorithm>
#include <cmath>
#include <sstream>
#include <vector>
29 
30 namespace thermal {
31 namespace vtestimator {
32 namespace {
getFloatFromValue(const Json::Value & value)33 float getFloatFromValue(const Json::Value &value) {
34     if (value.isString()) {
35         return std::atof(value.asString().c_str());
36     } else {
37         return value.asFloat();
38     }
39 }
40 
getInputRangeInfoFromJsonValues(const Json::Value & values,InputRangeInfo * input_range_info)41 bool getInputRangeInfoFromJsonValues(const Json::Value &values, InputRangeInfo *input_range_info) {
42     if (values.size() != 2) {
43         LOG(ERROR) << "Data Range Values size: " << values.size() << "is invalid.";
44         return false;
45     }
46 
47     float min_val = getFloatFromValue(values[0]);
48     float max_val = getFloatFromValue(values[1]);
49 
50     if (std::isnan(min_val) || std::isnan(max_val)) {
51         LOG(ERROR) << "Illegal data range: thresholds not defined properly " << min_val << " : "
52                    << max_val;
53         return false;
54     }
55 
56     if (min_val > max_val) {
57         LOG(ERROR) << "Illegal data range: data_min_threshold(" << min_val
58                    << ") > data_max_threshold(" << max_val << ")";
59         return false;
60     }
61     input_range_info->min_threshold = min_val;
62     input_range_info->max_threshold = max_val;
63     LOG(INFO) << "Data Range Info: " << input_range_info->min_threshold
64               << " <= val <= " << input_range_info->max_threshold;
65     return true;
66 }
67 
// Returns the offset paired with the last threshold (scanning from the end) that is strictly
// below |value|, or 0 when no threshold is below |value|. With thresholds sorted ascending
// (presumably how they are configured — confirm), this is the highest threshold under |value|.
//
// Only the first min(offset_thresholds.size(), offset_values.size()) pairs are considered.
// A mismatched configuration therefore can no longer index past the end of |offset_values|.
float CalculateOffset(const std::vector<float> &offset_thresholds,
                      const std::vector<float> &offset_values, const float value) {
    const size_t pair_count = std::min(offset_thresholds.size(), offset_values.size());
    for (size_t i = pair_count; i > 0; --i) {
        if (offset_thresholds[i - 1] < value) {
            return offset_values[i - 1];
        }
    }

    return 0.0f;
}
78 }  // namespace
79 
DumpTraces()80 VtEstimatorStatus VirtualTempEstimator::DumpTraces() {
81     if (type != kUseMLModel) {
82         return kVtEstimatorUnSupported;
83     }
84 
85     if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
86         LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr during DumpTraces\n";
87         return kVtEstimatorInitFailed;
88     }
89 
90     std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
91 
92     if (!common_instance_->is_initialized) {
93         LOG(ERROR) << "tflite_instance_ not initialized for " << common_instance_->sensor_name;
94         return kVtEstimatorInitFailed;
95     }
96 
97     // get model input/output buffers
98     float *model_input = tflite_instance_->input_buffer;
99     float *model_output = tflite_instance_->output_buffer;
100     auto input_buffer_size = tflite_instance_->input_buffer_size;
101     auto output_buffer_size = tflite_instance_->output_buffer_size;
102 
103     // In Case of use_prev_samples, inputs are available in order in scratch buffer
104     if (common_instance_->use_prev_samples) {
105         model_input = tflite_instance_->scratch_buffer;
106     }
107 
108     // Add traces for model input/output buffers
109     std::string sensor_name = common_instance_->sensor_name;
110     for (size_t i = 0; i < input_buffer_size; ++i) {
111         ATRACE_INT((sensor_name + "_input_" + std::to_string(i)).c_str(),
112                    static_cast<int>(model_input[i]));
113     }
114 
115     for (size_t i = 0; i < output_buffer_size; ++i) {
116         ATRACE_INT((sensor_name + "_output_" + std::to_string(i)).c_str(),
117                    static_cast<int>(model_output[i]));
118     }
119 
120     // log input data and output data buffers
121     std::string input_data_str = "model_input_buffer: [";
122     for (size_t i = 0; i < input_buffer_size; ++i) {
123         input_data_str += ::android::base::StringPrintf("%0.2f ", model_input[i]);
124     }
125     input_data_str += "]";
126     LOG(INFO) << input_data_str;
127 
128     std::string output_data_str = "model_output_buffer: [";
129     for (size_t i = 0; i < output_buffer_size; ++i) {
130         output_data_str += ::android::base::StringPrintf("%0.2f ", model_output[i]);
131     }
132     output_data_str += "]";
133     LOG(INFO) << output_data_str;
134 
135     return kVtEstimatorOk;
136 }
137 
LoadTFLiteWrapper()138 void VirtualTempEstimator::LoadTFLiteWrapper() {
139     if (!tflite_instance_) {
140         LOG(ERROR) << "tflite_instance_ is nullptr during LoadTFLiteWrapper";
141         return;
142     }
143 
144     std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
145 
146     void *mLibHandle = dlopen("/vendor/lib64/libthermal_tflite_wrapper.so", 0);
147     if (mLibHandle == nullptr) {
148         LOG(ERROR) << "Could not load libthermal_tflite_wrapper library with error: " << dlerror();
149         return;
150     }
151 
152     tflite_instance_->tflite_methods.create =
153             reinterpret_cast<tflitewrapper_create>(dlsym(mLibHandle, "ThermalTfliteCreate"));
154     if (!tflite_instance_->tflite_methods.create) {
155         LOG(ERROR) << "Could not link and cast tflitewrapper_create with error: " << dlerror();
156     }
157 
158     tflite_instance_->tflite_methods.init =
159             reinterpret_cast<tflitewrapper_init>(dlsym(mLibHandle, "ThermalTfliteInit"));
160     if (!tflite_instance_->tflite_methods.init) {
161         LOG(ERROR) << "Could not link and cast tflitewrapper_init with error: " << dlerror();
162     }
163 
164     tflite_instance_->tflite_methods.invoke =
165             reinterpret_cast<tflitewrapper_invoke>(dlsym(mLibHandle, "ThermalTfliteInvoke"));
166     if (!tflite_instance_->tflite_methods.invoke) {
167         LOG(ERROR) << "Could not link and cast tflitewrapper_invoke with error: " << dlerror();
168     }
169 
170     tflite_instance_->tflite_methods.destroy =
171             reinterpret_cast<tflitewrapper_destroy>(dlsym(mLibHandle, "ThermalTfliteDestroy"));
172     if (!tflite_instance_->tflite_methods.destroy) {
173         LOG(ERROR) << "Could not link and cast tflitewrapper_destroy with error: " << dlerror();
174     }
175 
176     tflite_instance_->tflite_methods.get_input_config_size =
177             reinterpret_cast<tflitewrapper_get_input_config_size>(
178                     dlsym(mLibHandle, "ThermalTfliteGetInputConfigSize"));
179     if (!tflite_instance_->tflite_methods.get_input_config_size) {
180         LOG(ERROR) << "Could not link and cast tflitewrapper_get_input_config_size with error: "
181                    << dlerror();
182     }
183 
184     tflite_instance_->tflite_methods.get_input_config =
185             reinterpret_cast<tflitewrapper_get_input_config>(
186                     dlsym(mLibHandle, "ThermalTfliteGetInputConfig"));
187     if (!tflite_instance_->tflite_methods.get_input_config) {
188         LOG(ERROR) << "Could not link and cast tflitewrapper_get_input_config with error: "
189                    << dlerror();
190     }
191 }
192 
VirtualTempEstimator(std::string_view sensor_name,VtEstimationType estimationType,size_t num_linked_sensors)193 VirtualTempEstimator::VirtualTempEstimator(std::string_view sensor_name,
194                                            VtEstimationType estimationType,
195                                            size_t num_linked_sensors) {
196     type = estimationType;
197 
198     common_instance_ = std::make_unique<VtEstimatorCommonData>(sensor_name, num_linked_sensors);
199     if (estimationType == kUseMLModel) {
200         tflite_instance_ = std::make_unique<VtEstimatorTFLiteData>();
201         LoadTFLiteWrapper();
202     } else if (estimationType == kUseLinearModel) {
203         linear_model_instance_ = std::make_unique<VtEstimatorLinearModelData>();
204     } else {
205         LOG(ERROR) << "Unsupported estimationType [" << estimationType << "]";
206     }
207 }
208 
~VirtualTempEstimator()209 VirtualTempEstimator::~VirtualTempEstimator() {
210     LOG(INFO) << "VirtualTempEstimator destructor";
211 }
212 
LinearModelInitialize(LinearModelInitData data)213 VtEstimatorStatus VirtualTempEstimator::LinearModelInitialize(LinearModelInitData data) {
214     if (linear_model_instance_ == nullptr || common_instance_ == nullptr) {
215         LOG(ERROR) << "linear_model_instance_ or common_instance_ is nullptr during Initialize";
216         return kVtEstimatorInitFailed;
217     }
218 
219     size_t num_linked_sensors = common_instance_->num_linked_sensors;
220     std::unique_lock<std::mutex> lock(linear_model_instance_->mutex);
221 
222     if ((num_linked_sensors == 0) || (data.coefficients.size() == 0) ||
223         (data.prev_samples_order == 0)) {
224         LOG(ERROR) << "Invalid num_linked_sensors [" << num_linked_sensors
225                    << "] or coefficients.size() [" << data.coefficients.size()
226                    << "] or prev_samples_order [" << data.prev_samples_order << "]";
227         return kVtEstimatorInitFailed;
228     }
229 
230     if (data.coefficients.size() != (num_linked_sensors * data.prev_samples_order)) {
231         LOG(ERROR) << "In valid args coefficients.size()[" << data.coefficients.size()
232                    << "] num_linked_sensors [" << num_linked_sensors << "] prev_samples_order["
233                    << data.prev_samples_order << "]";
234         return kVtEstimatorInvalidArgs;
235     }
236 
237     common_instance_->use_prev_samples = data.use_prev_samples;
238     common_instance_->prev_samples_order = data.prev_samples_order;
239 
240     linear_model_instance_->input_samples.reserve(common_instance_->prev_samples_order);
241     linear_model_instance_->coefficients.reserve(common_instance_->prev_samples_order);
242 
243     // Store coefficients
244     for (size_t i = 0; i < data.prev_samples_order; ++i) {
245         std::vector<float> single_order_coefficients;
246         for (size_t j = 0; j < num_linked_sensors; ++j) {
247             single_order_coefficients.emplace_back(data.coefficients[i * num_linked_sensors + j]);
248         }
249         linear_model_instance_->coefficients.emplace_back(single_order_coefficients);
250     }
251 
252     common_instance_->offset_thresholds = data.offset_thresholds;
253     common_instance_->offset_values = data.offset_values;
254     common_instance_->is_initialized = true;
255 
256     return kVtEstimatorOk;
257 }
258 
TFliteInitialize(MLModelInitData data)259 VtEstimatorStatus VirtualTempEstimator::TFliteInitialize(MLModelInitData data) {
260     if (!tflite_instance_ || !common_instance_) {
261         LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr during Initialize\n";
262         return kVtEstimatorInitFailed;
263     }
264 
265     std::string_view sensor_name = common_instance_->sensor_name;
266     size_t num_linked_sensors = common_instance_->num_linked_sensors;
267     bool use_prev_samples = data.use_prev_samples;
268     size_t prev_samples_order = data.prev_samples_order;
269     size_t num_hot_spots = data.num_hot_spots;
270     size_t output_label_count = data.output_label_count;
271 
272     std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
273 
274     if (data.model_path.empty()) {
275         LOG(ERROR) << "Invalid model_path:" << data.model_path << " for " << sensor_name;
276         return kVtEstimatorInvalidArgs;
277     }
278 
279     if (num_linked_sensors == 0 || prev_samples_order < 1 ||
280         (!use_prev_samples && prev_samples_order > 1)) {
281         LOG(ERROR) << "Invalid tflite_instance_ config: " << "number of linked sensor: "
282                    << num_linked_sensors << " use previous: " << use_prev_samples
283                    << " previous sample order: " << prev_samples_order << " for " << sensor_name;
284         return kVtEstimatorInitFailed;
285     }
286 
287     common_instance_->use_prev_samples = data.use_prev_samples;
288     common_instance_->prev_samples_order = prev_samples_order;
289     tflite_instance_->support_under_sampling = data.support_under_sampling;
290     tflite_instance_->enable_input_validation = data.enable_input_validation;
291     tflite_instance_->input_buffer_size = num_linked_sensors * prev_samples_order;
292     tflite_instance_->input_buffer = new float[tflite_instance_->input_buffer_size];
293     if (common_instance_->use_prev_samples) {
294         tflite_instance_->scratch_buffer = new float[tflite_instance_->input_buffer_size];
295     }
296 
297     if (output_label_count < 1 || num_hot_spots < 1) {
298         LOG(ERROR) << "Invalid tflite_instance_ config:" << "number of hot spots: " << num_hot_spots
299                    << " predicted sample order: " << output_label_count << " for " << sensor_name;
300         return kVtEstimatorInitFailed;
301     }
302 
303     tflite_instance_->output_label_count = output_label_count;
304     tflite_instance_->num_hot_spots = num_hot_spots;
305     tflite_instance_->output_buffer_size = output_label_count * num_hot_spots;
306     tflite_instance_->output_buffer = new float[tflite_instance_->output_buffer_size];
307 
308     if (!tflite_instance_->tflite_methods.create || !tflite_instance_->tflite_methods.init ||
309         !tflite_instance_->tflite_methods.invoke || !tflite_instance_->tflite_methods.destroy ||
310         !tflite_instance_->tflite_methods.get_input_config_size ||
311         !tflite_instance_->tflite_methods.get_input_config) {
312         LOG(ERROR) << "Invalid tflite methods for " << sensor_name;
313         return kVtEstimatorInitFailed;
314     }
315 
316     tflite_instance_->tflite_wrapper =
317             tflite_instance_->tflite_methods.create(kNumInputTensors, kNumOutputTensors);
318     if (!tflite_instance_->tflite_wrapper) {
319         LOG(ERROR) << "Failed to create tflite wrapper for " << sensor_name;
320         return kVtEstimatorInitFailed;
321     }
322 
323     int ret = tflite_instance_->tflite_methods.init(tflite_instance_->tflite_wrapper,
324                                                     data.model_path.c_str());
325     if (ret) {
326         LOG(ERROR) << "Failed to Init tflite_wrapper for " << sensor_name << " (ret: " << ret
327                    << ")";
328         return kVtEstimatorInitFailed;
329     }
330 
331     Json::Value input_config;
332     if (!GetInputConfig(&input_config)) {
333         LOG(ERROR) << "Get Input Config failed for " << sensor_name;
334         return kVtEstimatorInitFailed;
335     }
336 
337     if (!ParseInputConfig(input_config)) {
338         LOG(ERROR) << "Parse Input Config failed for " << sensor_name;
339         return kVtEstimatorInitFailed;
340     }
341 
342     if (tflite_instance_->enable_input_validation && !tflite_instance_->input_range.size()) {
343         LOG(ERROR) << "Input ranges missing when input data validation is enabled for "
344                    << sensor_name;
345         return kVtEstimatorInitFailed;
346     }
347 
348     common_instance_->offset_thresholds = data.offset_thresholds;
349     common_instance_->offset_values = data.offset_values;
350     tflite_instance_->model_path = data.model_path;
351 
352     common_instance_->is_initialized = true;
353     LOG(INFO) << "Successfully initialized VirtualTempEstimator for " << sensor_name;
354     return kVtEstimatorOk;
355 }
356 
LinearModelEstimate(const std::vector<float> & thermistors,std::vector<float> * output)357 VtEstimatorStatus VirtualTempEstimator::LinearModelEstimate(const std::vector<float> &thermistors,
358                                                             std::vector<float> *output) {
359     if (linear_model_instance_ == nullptr || common_instance_ == nullptr) {
360         LOG(ERROR) << "linear_model_instance_ or common_instance_ is nullptr during Initialize";
361         return kVtEstimatorInitFailed;
362     }
363 
364     std::string_view sensor_name = common_instance_->sensor_name;
365     size_t prev_samples_order = common_instance_->prev_samples_order;
366     size_t num_linked_sensors = common_instance_->num_linked_sensors;
367 
368     std::unique_lock<std::mutex> lock(linear_model_instance_->mutex);
369 
370     if ((thermistors.size() != num_linked_sensors) || (output == nullptr)) {
371         LOG(ERROR) << "Invalid args Thermistors size[" << thermistors.size()
372                    << "] num_linked_sensors[" << num_linked_sensors << "] output[" << output << "]"
373                    << " for " << sensor_name;
374         return kVtEstimatorInvalidArgs;
375     }
376 
377     if (common_instance_->is_initialized == false) {
378         LOG(ERROR) << "tflite_instance_ not initialized for " << sensor_name;
379         return kVtEstimatorInitFailed;
380     }
381 
382     // For the first iteration copy current inputs to all previous inputs
383     // This would allow the estimator to have previous samples from the first iteration itself
384     // and provide a valid predicted value
385     if (common_instance_->cur_sample_count == 0) {
386         for (size_t i = 0; i < prev_samples_order; ++i) {
387             linear_model_instance_->input_samples[i] = thermistors;
388         }
389     }
390 
391     size_t cur_sample_index = common_instance_->cur_sample_count % prev_samples_order;
392     linear_model_instance_->input_samples[cur_sample_index] = thermistors;
393 
394     // Calculate Weighted Average Value
395     int input_level = cur_sample_index;
396     float estimated_value = 0;
397     for (size_t i = 0; i < prev_samples_order; ++i) {
398         for (size_t j = 0; j < num_linked_sensors; ++j) {
399             estimated_value += linear_model_instance_->coefficients[i][j] *
400                                linear_model_instance_->input_samples[input_level][j];
401         }
402         input_level--;  // go to previous samples
403         input_level = (input_level >= 0) ? input_level : (prev_samples_order - 1);
404     }
405 
406     // Update sample count
407     common_instance_->cur_sample_count++;
408 
409     // add offset to estimated value if applicable
410     estimated_value += CalculateOffset(common_instance_->offset_thresholds,
411                                        common_instance_->offset_values, estimated_value);
412 
413     std::vector<float> data = {estimated_value};
414     *output = data;
415     return kVtEstimatorOk;
416 }
417 
// Runs one TFLite inference step for this sensor.
//
// |thermistors| must hold exactly num_linked_sensors readings. They are written into slot
// (cur_sample_count % prev_samples_order) of input_buffer. When use_prev_samples is set, the
// slots are then copied into scratch_buffer oldest-first (current sample last) before the
// model is invoked. On success |output| receives output_buffer_size predictions, each
// adjusted by CalculateOffset().
//
// Returns:
//   kVtEstimatorInitFailed    - missing instances or not initialized.
//   kVtEstimatorInvalidArgs   - wrong number of readings or null |output|.
//   kVtEstimatorLowConfidence - a reading fell outside its configured input range; the
//                               sample history is reset.
//   kVtEstimatorUnderSampling - not enough samples yet and under-sampling is not supported.
//   kVtEstimatorInvokeFailed  - the wrapper's invoke call returned non-zero.
//   kVtEstimatorOk            - |output| populated.
VtEstimatorStatus VirtualTempEstimator::TFliteEstimate(const std::vector<float> &thermistors,
                                                       std::vector<float> *output) {
    if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
        LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr during Estimate\n";
        return kVtEstimatorInitFailed;
    }

    // Held for the whole step: the same mutex guards the buffers and sample bookkeeping that
    // the other TFLite entry points read.
    std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);

    if (!common_instance_->is_initialized) {
        LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
        return kVtEstimatorInitFailed;
    }

    std::string_view sensor_name = common_instance_->sensor_name;
    size_t num_linked_sensors = common_instance_->num_linked_sensors;
    if ((thermistors.size() != num_linked_sensors) || (output == nullptr)) {
        LOG(ERROR) << "Invalid args for " << sensor_name
                   << " thermistors.size(): " << thermistors.size()
                   << " num_linked_sensors: " << num_linked_sensors << " output: " << output;
        return kVtEstimatorInvalidArgs;
    }

    // log input data
    std::string input_data_str = "model_input: [";
    for (size_t i = 0; i < num_linked_sensors; ++i) {
        input_data_str += ::android::base::StringPrintf("%0.2f ", thermistors[i]);
    }
    input_data_str += "]";
    LOG(INFO) << sensor_name << ": " << input_data_str;

    // check time gap between samples and ignore stale previous samples
    // (resetting cur_sample_count makes this sample land in slot 0 of a fresh history).
    if (std::chrono::duration_cast<std::chrono::milliseconds>(boot_clock::now() -
                                                              tflite_instance_->prev_sample_time) >=
        tflite_instance_->max_sample_interval) {
        LOG(INFO) << "Ignoring stale previous samples for " << sensor_name;
        common_instance_->cur_sample_count = 0;
    }

    // copy input data into input tensors
    // input_buffer is laid out as prev_samples_order consecutive slots of num_linked_sensors
    // values each; this sample goes into slot cur_sample_index.
    size_t prev_samples_order = common_instance_->prev_samples_order;
    size_t cur_sample_index = common_instance_->cur_sample_count % prev_samples_order;
    size_t sample_start_index = cur_sample_index * num_linked_sensors;
    for (size_t i = 0; i < num_linked_sensors; ++i) {
        if (tflite_instance_->enable_input_validation) {
            // NOTE(review): indexes input_range[i] for every linked sensor; assumes the parsed
            // input config provides at least num_linked_sensors ranges. Confirm in
            // ParseInputConfig().
            // An out-of-range reading aborts the step and restarts the history; readings
            // already copied for lower i are overwritten on the next call.
            if (thermistors[i] < tflite_instance_->input_range[i].min_threshold ||
                thermistors[i] > tflite_instance_->input_range[i].max_threshold) {
                LOG(INFO) << "thermistors[" << i << "] value: " << thermistors[i]
                          << " not in range: " << tflite_instance_->input_range[i].min_threshold
                          << " <= val <= " << tflite_instance_->input_range[i].max_threshold
                          << " for " << sensor_name;
                common_instance_->cur_sample_count = 0;
                return kVtEstimatorLowConfidence;
            }
        }
        tflite_instance_->input_buffer[sample_start_index + i] = thermistors[i];
        if (cur_sample_index == 0 && tflite_instance_->support_under_sampling) {
            // fill previous samples if support under sampling
            // (replicate the reading into every older slot so the model sees a full window).
            for (size_t j = 1; j < prev_samples_order; ++j) {
                size_t copy_start_index = j * num_linked_sensors;
                tflite_instance_->input_buffer[copy_start_index + i] = thermistors[i];
            }
        }
    }

    // Update sample count
    common_instance_->cur_sample_count++;
    tflite_instance_->prev_sample_time = boot_clock::now();
    if ((common_instance_->cur_sample_count < prev_samples_order) &&
        !(tflite_instance_->support_under_sampling)) {
        return kVtEstimatorUnderSampling;
    }

    // prepare model input
    float *model_input;
    size_t input_buffer_size = tflite_instance_->input_buffer_size;
    size_t output_buffer_size = tflite_instance_->output_buffer_size;
    if (!common_instance_->use_prev_samples) {
        model_input = tflite_instance_->input_buffer;
    } else {
        // Unroll the ring into scratch_buffer starting at the slot after the current one
        // (the oldest sample), so the current sample ends up last.
        sample_start_index = ((cur_sample_index + 1) * num_linked_sensors) % input_buffer_size;
        for (size_t i = 0; i < input_buffer_size; ++i) {
            size_t input_index = (sample_start_index + i) % input_buffer_size;
            tflite_instance_->scratch_buffer[i] = tflite_instance_->input_buffer[input_index];
        }
        model_input = tflite_instance_->scratch_buffer;
    }

    int ret = tflite_instance_->tflite_methods.invoke(
            tflite_instance_->tflite_wrapper, model_input, input_buffer_size,
            tflite_instance_->output_buffer, output_buffer_size);
    if (ret) {
        LOG(ERROR) << "Failed to Invoke for " << sensor_name << " (ret: " << ret << ")";
        return kVtEstimatorInvokeFailed;
    }
    // Reference point for TFlitePredictAfterTimeMs().
    tflite_instance_->last_update_time = boot_clock::now();

    // prepare output
    std::vector<float> data;
    std::ostringstream model_out_log, predict_log;
    data.reserve(output_buffer_size);
    for (size_t i = 0; i < output_buffer_size; ++i) {
        // add offset to predicted value
        float predicted_value = tflite_instance_->output_buffer[i];
        model_out_log << predicted_value << " ";
        predicted_value += CalculateOffset(common_instance_->offset_thresholds,
                                           common_instance_->offset_values, predicted_value);
        predict_log << predicted_value << " ";
        data.emplace_back(predicted_value);
    }
    LOG(INFO) << sensor_name << ": model_output: [" << model_out_log.str() << "]";
    LOG(INFO) << sensor_name << ": predicted_value: [" << predict_log.str() << "]";
    *output = data;

    return kVtEstimatorOk;
}
534 
Estimate(const std::vector<float> & thermistors,std::vector<float> * output)535 VtEstimatorStatus VirtualTempEstimator::Estimate(const std::vector<float> &thermistors,
536                                                  std::vector<float> *output) {
537     if (type == kUseMLModel) {
538         return TFliteEstimate(thermistors, output);
539     } else if (type == kUseLinearModel) {
540         return LinearModelEstimate(thermistors, output);
541     }
542 
543     LOG(ERROR) << "Unsupported estimationType [" << type << "]";
544     return kVtEstimatorUnSupported;
545 }
546 
TFliteGetMaxPredictWindowMs(size_t * predict_window_ms)547 VtEstimatorStatus VirtualTempEstimator::TFliteGetMaxPredictWindowMs(size_t *predict_window_ms) {
548     if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
549         LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr for predict window\n";
550         return kVtEstimatorInitFailed;
551     }
552 
553     if (!common_instance_->is_initialized) {
554         LOG(ERROR) << "tflite_instance_ not initialized for " << common_instance_->sensor_name;
555         return kVtEstimatorInitFailed;
556     }
557 
558     size_t window = tflite_instance_->predict_window_ms;
559     if (window == 0) {
560         return kVtEstimatorUnSupported;
561     }
562     *predict_window_ms = window;
563     return kVtEstimatorOk;
564 }
565 
GetMaxPredictWindowMs(size_t * predict_window_ms)566 VtEstimatorStatus VirtualTempEstimator::GetMaxPredictWindowMs(size_t *predict_window_ms) {
567     if (type == kUseMLModel) {
568         return TFliteGetMaxPredictWindowMs(predict_window_ms);
569     }
570 
571     LOG(ERROR) << "Unsupported estimationType [" << type << "]";
572     return kVtEstimatorUnSupported;
573 }
574 
TFlitePredictAfterTimeMs(const size_t time_ms,float * output)575 VtEstimatorStatus VirtualTempEstimator::TFlitePredictAfterTimeMs(const size_t time_ms,
576                                                                  float *output) {
577     if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
578         LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr for predict window\n";
579         return kVtEstimatorInitFailed;
580     }
581 
582     if (!common_instance_->is_initialized) {
583         LOG(ERROR) << "tflite_instance_ not initialized for " << common_instance_->sensor_name;
584         return kVtEstimatorInitFailed;
585     }
586 
587     std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
588 
589     size_t window = tflite_instance_->predict_window_ms;
590     auto sample_interval = tflite_instance_->sample_interval;
591     auto last_update_time = tflite_instance_->last_update_time;
592     auto request_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(boot_clock::now() -
593                                                                                  last_update_time);
594     // check for under sampling
595     if ((common_instance_->cur_sample_count < common_instance_->prev_samples_order) &&
596         !(tflite_instance_->support_under_sampling)) {
597         LOG(INFO) << tflite_instance_->model_path
598                   << " cannot provide prediction while under sampling";
599         return kVtEstimatorUnderSampling;
600     }
601 
602     // calculate requested time since last update
603     request_time_ms = request_time_ms + std::chrono::milliseconds{time_ms};
604     if (sample_interval.count() == 0 || window == 0 ||
605         window < static_cast<size_t>(request_time_ms.count())) {
606         LOG(INFO) << tflite_instance_->model_path << " cannot predict temperature after ("
607                   << time_ms << " + " << request_time_ms.count() - time_ms
608                   << ") ms since last update with sample interval [" << sample_interval.count()
609                   << "] ms and predict window [" << window << "] ms";
610         return kVtEstimatorUnSupported;
611     }
612 
613     size_t request_step = request_time_ms / sample_interval;
614     size_t output_label_count = tflite_instance_->output_label_count;
615     float *output_buffer = tflite_instance_->output_buffer;
616     float prediction;
617     if (request_step == output_label_count - 1) {
618         // request prediction is on the right boundary of the window
619         prediction = output_buffer[output_label_count - 1];
620     } else {
621         float left = output_buffer[request_step], right = output_buffer[request_step + 1];
622         prediction = left;
623         if (left != right) {
624             prediction += (request_time_ms - sample_interval * request_step) * (right - left) /
625                           sample_interval;
626         }
627     }
628 
629     *output = prediction;
630 
631     return kVtEstimatorOk;
632 }
633 
PredictAfterTimeMs(const size_t time_ms,float * output)634 VtEstimatorStatus VirtualTempEstimator::PredictAfterTimeMs(const size_t time_ms, float *output) {
635     if (type == kUseMLModel) {
636         return TFlitePredictAfterTimeMs(time_ms, output);
637     }
638 
639     LOG(ERROR) << "PredictAfterTimeMs not supported for type [" << type << "]";
640     return kVtEstimatorUnSupported;
641 }
642 
TFliteGetAllPredictions(std::vector<float> * output)643 VtEstimatorStatus VirtualTempEstimator::TFliteGetAllPredictions(std::vector<float> *output) {
644     if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
645         LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr for predict window\n";
646         return kVtEstimatorInitFailed;
647     }
648 
649     std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
650 
651     if (!common_instance_->is_initialized) {
652         LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
653         return kVtEstimatorInitFailed;
654     }
655 
656     if (output == nullptr) {
657         LOG(ERROR) << "output is nullptr";
658         return kVtEstimatorInvalidArgs;
659     }
660 
661     std::vector<float> tflite_output;
662     size_t output_buffer_size = tflite_instance_->output_buffer_size;
663     tflite_output.reserve(output_buffer_size);
664     for (size_t i = 0; i < output_buffer_size; ++i) {
665         tflite_output.emplace_back(tflite_instance_->output_buffer[i]);
666     }
667     *output = tflite_output;
668 
669     return kVtEstimatorOk;
670 }
671 
GetAllPredictions(std::vector<float> * output)672 VtEstimatorStatus VirtualTempEstimator::GetAllPredictions(std::vector<float> *output) {
673     if (type == kUseMLModel) {
674         return TFliteGetAllPredictions(output);
675     }
676 
677     LOG(INFO) << "GetAllPredicts not supported by estimationType [" << type << "]";
678     return kVtEstimatorUnSupported;
679 }
680 
TFLiteDumpStatus(std::string_view sensor_name,std::ostringstream * dump_buf)681 VtEstimatorStatus VirtualTempEstimator::TFLiteDumpStatus(std::string_view sensor_name,
682                                                          std::ostringstream *dump_buf) {
683     if (dump_buf == nullptr) {
684         LOG(ERROR) << "dump_buf is nullptr for " << sensor_name;
685         return kVtEstimatorInvalidArgs;
686     }
687 
688     if (!common_instance_->is_initialized) {
689         LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
690         return kVtEstimatorInitFailed;
691     }
692 
693     std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
694 
695     *dump_buf << " Sensor Name: " << sensor_name << std::endl;
696     *dump_buf << "  Current Values: ";
697     size_t output_buffer_size = tflite_instance_->output_buffer_size;
698     for (size_t i = 0; i < output_buffer_size; ++i) {
699         // add offset to predicted value
700         float predicted_value = tflite_instance_->output_buffer[i];
701         predicted_value += CalculateOffset(common_instance_->offset_thresholds,
702                                            common_instance_->offset_values, predicted_value);
703         *dump_buf << predicted_value << ", ";
704     }
705     *dump_buf << std::endl;
706 
707     *dump_buf << "  Model Path: \"" << tflite_instance_->model_path << "\"" << std::endl;
708 
709     return kVtEstimatorOk;
710 }
711 
DumpStatus(std::string_view sensor_name,std::ostringstream * dump_buff)712 VtEstimatorStatus VirtualTempEstimator::DumpStatus(std::string_view sensor_name,
713                                                    std::ostringstream *dump_buff) {
714     if (type == kUseMLModel) {
715         return TFLiteDumpStatus(sensor_name, dump_buff);
716     }
717 
718     LOG(INFO) << "DumpStatus not supported by estimationType [" << type << "]";
719     return kVtEstimatorUnSupported;
720 }
721 
Initialize(const VtEstimationInitData & data)722 VtEstimatorStatus VirtualTempEstimator::Initialize(const VtEstimationInitData &data) {
723     LOG(INFO) << "Initialize VirtualTempEstimator for " << type;
724 
725     if (type == kUseMLModel) {
726         return TFliteInitialize(data.ml_model_init_data);
727     } else if (type == kUseLinearModel) {
728         return LinearModelInitialize(data.linear_model_init_data);
729     }
730 
731     LOG(ERROR) << "Unsupported estimationType [" << type << "]";
732     return kVtEstimatorUnSupported;
733 }
734 
ParseInputConfig(const Json::Value & input_config)735 bool VirtualTempEstimator::ParseInputConfig(const Json::Value &input_config) {
736     if (!input_config["ModelConfig"].empty()) {
737         if (!input_config["ModelConfig"]["sample_interval_ms"].empty()) {
738             // read input sample interval
739             int sample_interval_ms = input_config["ModelConfig"]["sample_interval_ms"].asInt();
740             if (sample_interval_ms <= 0) {
741                 LOG(ERROR) << "Invalid sample_interval_ms: " << sample_interval_ms;
742                 return false;
743             }
744 
745             tflite_instance_->sample_interval = std::chrono::milliseconds{sample_interval_ms};
746             LOG(INFO) << "Parsed tflite model input sample_interval: " << sample_interval_ms
747                       << " for " << common_instance_->sensor_name;
748 
749             // determine predict window
750             tflite_instance_->predict_window_ms =
751                     sample_interval_ms * (tflite_instance_->output_label_count - 1);
752             LOG(INFO) << "Max prediction window size: " << tflite_instance_->predict_window_ms
753                       << " ms for " << common_instance_->sensor_name;
754         }
755 
756         if (!input_config["ModelConfig"]["max_sample_interval_ms"].empty()) {
757             // read input max sample interval
758             int max_sample_interval_ms =
759                     input_config["ModelConfig"]["max_sample_interval_ms"].asInt();
760             if (max_sample_interval_ms <= 0) {
761                 LOG(ERROR) << "Invalid max_sample_interval_ms " << max_sample_interval_ms;
762                 return false;
763             }
764 
765             tflite_instance_->max_sample_interval =
766                     std::chrono::milliseconds{max_sample_interval_ms};
767             LOG(INFO) << "Parsed tflite model max_sample_interval: " << max_sample_interval_ms
768                       << " for " << common_instance_->sensor_name;
769         }
770     }
771 
772     if (!input_config["InputData"].empty()) {
773         Json::Value input_data = input_config["InputData"];
774         if (input_data.size() != common_instance_->num_linked_sensors) {
775             LOG(ERROR) << "Input ranges size: " << input_data.size()
776                        << " does not match num_linked_sensors: "
777                        << common_instance_->num_linked_sensors;
778             return false;
779         }
780 
781         LOG(INFO) << "Start to parse tflite model input config for "
782                   << common_instance_->num_linked_sensors;
783         tflite_instance_->input_range.assign(input_data.size(), InputRangeInfo());
784         for (Json::Value::ArrayIndex i = 0; i < input_data.size(); ++i) {
785             const std::string &name = input_data[i]["Name"].asString();
786             LOG(INFO) << "Sensor[" << i << "] Name: " << name;
787             if (!getInputRangeInfoFromJsonValues(input_data[i]["Range"],
788                                                  &tflite_instance_->input_range[i])) {
789                 LOG(ERROR) << "Failed to parse tflite model temp range for sensor: [" << name
790                            << "]";
791                 return false;
792             }
793         }
794     }
795 
796     return true;
797 }
798 
GetInputConfig(Json::Value * config)799 bool VirtualTempEstimator::GetInputConfig(Json::Value *config) {
800     int config_size = 0;
801     int ret = tflite_instance_->tflite_methods.get_input_config_size(
802             tflite_instance_->tflite_wrapper, &config_size);
803     if (ret || config_size <= 0) {
804         LOG(ERROR) << "Failed to get tflite input config size (ret: " << ret
805                    << ") with size: " << config_size;
806         return false;
807     }
808 
809     LOG(INFO) << "Model input config_size: " << config_size << " for "
810               << common_instance_->sensor_name;
811 
812     char *config_str = new char[config_size];
813     ret = tflite_instance_->tflite_methods.get_input_config(tflite_instance_->tflite_wrapper,
814                                                             config_str, config_size);
815     if (ret) {
816         LOG(ERROR) << "Failed to get tflite input config (ret: " << ret << ")";
817         delete[] config_str;
818         return false;
819     }
820 
821     Json::CharReaderBuilder builder;
822     std::unique_ptr<Json::CharReader> reader(builder.newCharReader());
823     std::string errorMessage;
824 
825     bool success = true;
826     if (!reader->parse(config_str, config_str + config_size, config, &errorMessage)) {
827         LOG(ERROR) << "Failed to parse tflite JSON input config: " << errorMessage;
828         success = false;
829     }
830     delete[] config_str;
831     return success;
832 }
833 
834 }  // namespace vtestimator
835 }  // namespace thermal
836