1 /*
2  * Copyright 2021 HIMSA II K/S - www.himsa.com.
3  * Represented by EHIMA - www.ehima.com
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 #pragma once
19 
20 #include <base/strings/string_number_conversions.h>
21 #include <bluetooth/log.h>
22 
23 #include <algorithm>
24 #include <map>
25 #include <vector>
26 
27 #include "bta_csis_api.h"
28 #include "bta_gatt_api.h"
29 #include "bta_groups.h"
30 #include "btif/include/btif_storage.h"
31 #include "crypto_toolbox/crypto_toolbox.h"
32 #include "gap_api.h"
33 
34 // Uncomment to debug SIRK calculations
35 // #define CSIS_DEBUG
36 
37 namespace bluetooth {
38 namespace csis {
39 
40 using bluetooth::csis::CsisLockCb;
41 
42 // CSIP additions
43 /* Generic UUID is used when CSIS is not included in any context */
44 static const bluetooth::Uuid kCsisServiceUuid = bluetooth::Uuid::From16Bit(0x1846);
45 static const bluetooth::Uuid kCsisSirkUuid = bluetooth::Uuid::From16Bit(0x2B84);
46 static const bluetooth::Uuid kCsisSizeUuid = bluetooth::Uuid::From16Bit(0x2B85);
47 static const bluetooth::Uuid kCsisLockUuid = bluetooth::Uuid::From16Bit(0x2B86);
48 static const bluetooth::Uuid kCsisRankUuid = bluetooth::Uuid::From16Bit(0x2B87);
49 
/* CSIS-specific ATT application error codes (0x80 - 0x85). */
static constexpr uint8_t kCsisErrorCodeLockDenied{0x80};
static constexpr uint8_t kCsisErrorCodeReleaseNotAllowed{0x81};
static constexpr uint8_t kCsisErrorCodeInvalidValue{0x82};
static constexpr uint8_t kCsisErrorCodeLockAccessSirkRejected{0x83};
static constexpr uint8_t kCsisErrorCodeLockOobSirkOnly{0x84};
static constexpr uint8_t kCsisErrorCodeLockAlreadyGranted{0x85};

/* Value of the SIRK characteristic's type octet marking an encrypted SIRK. */
static constexpr uint8_t kCsisSirkTypeEncrypted{0x00};
/* SIRK characteristic value length: 1 type octet followed by a 16-octet key. */
static constexpr uint8_t kCsisSirkCharLen{17};
59 
/*
 * A characteristic value handle together with the handle of its CCC
 * (Client Characteristic Configuration) descriptor.
 *
 * Both handles default to 0, so a default-constructed pair never holds
 * indeterminate values. 0x0000 is not a valid ATT attribute handle.
 */
struct hdl_pair {
  hdl_pair() = default;
  hdl_pair(uint16_t val_hdl, uint16_t ccc_hdl) : val_hdl(val_hdl), ccc_hdl(ccc_hdl) {}

  uint16_t val_hdl = 0;
  uint16_t ccc_hdl = 0;
};
67 
/* CSIS Types */
/* Default scan duration; the "S" suffix presumably means seconds -- confirm at call sites. */
static constexpr uint8_t kDefaultScanDurationS{5};
/* Set size assumed for a group until its real size is known. */
static constexpr uint8_t kDefaultCsisSetSize{1};
/* Rank value used while an instance's rank has not been read. */
static constexpr uint8_t kUnknownRank{0xff};

/* Enums */
/* Lock state of a single CSIS instance, or the current/target state of a group. */
enum class CsisLockState : uint8_t {
  CSIS_STATE_UNSET = 0x00,
  CSIS_STATE_UNLOCKED = 0x01,
  CSIS_STATE_LOCKED = 0x02,
};

/* Progress of the member discovery for a group. */
enum class CsisDiscoveryState : uint8_t {
  CSIS_DISCOVERY_IDLE = 0x00,
  CSIS_DISCOVERY_ONGOING = 0x01,
  CSIS_DISCOVERY_COMPLETED = 0x02,
};
85 
86 class GattServiceDevice {
87 public:
88   RawAddress addr;
89   /*
90    * We are making active attempt to connect to this device, 'direct connect'.
91    */
92   bool connecting_actively = false;
93 
94   tCONN_ID conn_id = GATT_INVALID_CONN_ID;
95   uint16_t service_handle = GAP_INVALID_HANDLE;
96   bool is_gatt_service_valid = false;
97 
GattServiceDevice(const RawAddress & addr,bool)98   GattServiceDevice(const RawAddress& addr, bool /*first_connection*/) : addr(addr) {}
99 
GattServiceDevice()100   GattServiceDevice() : GattServiceDevice(RawAddress::kEmpty, false) {}
101 
IsConnected()102   bool IsConnected() const { return conn_id != GATT_INVALID_CONN_ID; }
103 
104   class MatchAddress {
105   private:
106     RawAddress addr;
107 
108   public:
MatchAddress(const RawAddress & addr)109     MatchAddress(const RawAddress& addr) : addr(addr) {}
operator()110     bool operator()(const std::shared_ptr<GattServiceDevice>& other) const {
111       return addr == other->addr;
112     }
113   };
114 
115   class MatchConnId {
116   private:
117     tCONN_ID conn_id;
118 
119   public:
MatchConnId(tCONN_ID conn_id)120     MatchConnId(tCONN_ID conn_id) : conn_id(conn_id) {}
operator()121     bool operator()(const std::shared_ptr<GattServiceDevice>& other) const {
122       return conn_id == other->conn_id;
123     }
124   };
125 };
126 
127 /*
128  * CSIS instance represents single CSIS service on the remote device
129  * along with the handle in database and specific data to control CSIS like:
130  * rank, lock state.
131  *
132  * It also inclues UUID of the primary service which includes that CSIS
133  * instance. If this is 0x0000 it means CSIS is per device and not for specific
134  * service.
135  */
136 class CsisInstance {
137 public:
138   bluetooth::Uuid coordinated_service = bluetooth::groups::kGenericContextUuid;
139 
140   struct SvcData {
141     uint16_t start_handle;
142     uint16_t end_handle;
143     struct hdl_pair sirk_handle;
144     struct hdl_pair lock_handle;
145     uint16_t rank_handle;
146     struct hdl_pair size_handle;
147   } svc_data = {
148           GAP_INVALID_HANDLE,
149           GAP_INVALID_HANDLE,
150           {GAP_INVALID_HANDLE, GAP_INVALID_HANDLE},
151           {GAP_INVALID_HANDLE, GAP_INVALID_HANDLE},
152           GAP_INVALID_HANDLE,
153           {GAP_INVALID_HANDLE, GAP_INVALID_HANDLE},
154   };
155 
CsisInstance(uint16_t start_handle,uint16_t end_handle,const bluetooth::Uuid & uuid)156   CsisInstance(uint16_t start_handle, uint16_t end_handle, const bluetooth::Uuid& uuid)
157       : coordinated_service(uuid),
158         group_id_(bluetooth::groups::kGroupUnknown),
159         rank_(kUnknownRank),
160         lock_state_(CsisLockState::CSIS_STATE_UNSET) {
161     svc_data.start_handle = start_handle;
162     svc_data.end_handle = end_handle;
163   }
164 
SetLockState(CsisLockState state)165   void SetLockState(CsisLockState state) {
166     log::debug("current lock state: {}, new lock state: {}", static_cast<int>(lock_state_),
167                static_cast<int>(state));
168     lock_state_ = state;
169   }
GetLockState(void)170   CsisLockState GetLockState(void) const { return lock_state_; }
GetRank(void)171   uint8_t GetRank(void) const { return rank_; }
SetRank(uint8_t rank)172   void SetRank(uint8_t rank) {
173     log::debug("current rank: {}, new rank: {}", static_cast<int>(rank_), static_cast<int>(rank));
174     rank_ = rank;
175   }
176 
SetGroupId(int group_id)177   void SetGroupId(int group_id) {
178     log::info("set group id: {}, instance handle: 0x{:04x}", group_id, svc_data.start_handle);
179     group_id_ = group_id;
180   }
181 
GetGroupId(void)182   int GetGroupId(void) const { return group_id_; }
183 
HasSameUuid(const CsisInstance & csis_instance)184   bool HasSameUuid(const CsisInstance& csis_instance) const {
185     return csis_instance.coordinated_service == coordinated_service;
186   }
187 
GetUuid(void)188   const bluetooth::Uuid& GetUuid(void) const { return coordinated_service; }
IsForUuid(const bluetooth::Uuid & uuid)189   bool IsForUuid(const bluetooth::Uuid& uuid) const { return coordinated_service == uuid; }
190 
191 private:
192   int group_id_;
193   uint8_t rank_;
194   CsisLockState lock_state_;
195 };
196 
197 /*
198  * Csis Device represents remote device and its all CSIS instances.
199  * It can happen that device can have more than one CSIS service instance
200  * if those instances are included in other services. In this way, coordinated
201  * set is within the context of the primary service which includes the instance.
202  *
203  * CsisDevice contains vector of the instances.
204  */
205 class CsisDevice : public GattServiceDevice {
206 public:
207   using GattServiceDevice::GattServiceDevice;
208 
ClearSvcData()209   void ClearSvcData() {
210     GattServiceDevice::service_handle = GAP_INVALID_HANDLE;
211     GattServiceDevice::is_gatt_service_valid = false;
212 
213     csis_instances_.clear();
214   }
215 
FindValueHandleByCccHandle(uint16_t ccc_handle)216   uint16_t FindValueHandleByCccHandle(uint16_t ccc_handle) {
217     uint16_t val_handle = 0;
218     for (const auto& [_, inst] : csis_instances_) {
219       if (inst->svc_data.sirk_handle.ccc_hdl == ccc_handle) {
220         val_handle = inst->svc_data.sirk_handle.val_hdl;
221       } else if (inst->svc_data.lock_handle.ccc_hdl == ccc_handle) {
222         val_handle = inst->svc_data.lock_handle.val_hdl;
223       } else if (inst->svc_data.size_handle.ccc_hdl == ccc_handle) {
224         val_handle = inst->svc_data.size_handle.val_hdl;
225       }
226       if (val_handle) {
227         break;
228       }
229     }
230     return val_handle;
231   }
232 
GetCsisInstanceByOwningHandle(uint16_t handle)233   std::shared_ptr<CsisInstance> GetCsisInstanceByOwningHandle(uint16_t handle) {
234     uint16_t hdl = 0;
235     for (const auto& [h, inst] : csis_instances_) {
236       if (handle >= inst->svc_data.start_handle && handle <= inst->svc_data.end_handle) {
237         hdl = h;
238         log::verbose("found 0x{:04x}", hdl);
239         break;
240       }
241     }
242     return (hdl > 0) ? csis_instances_.at(hdl) : nullptr;
243   }
244 
GetCsisInstanceByGroupId(int group_id)245   std::shared_ptr<CsisInstance> GetCsisInstanceByGroupId(int group_id) {
246     uint16_t hdl = 0;
247     for (const auto& [handle, inst] : csis_instances_) {
248       if (inst->GetGroupId() == group_id) {
249         hdl = handle;
250         break;
251       }
252     }
253     return (hdl > 0) ? csis_instances_.at(hdl) : nullptr;
254   }
255 
SetCsisInstance(uint16_t handle,std::shared_ptr<CsisInstance> csis_instance)256   void SetCsisInstance(uint16_t handle, std::shared_ptr<CsisInstance> csis_instance) {
257     if (csis_instances_.count(handle)) {
258       log::debug("instance is already here: {}", csis_instance->GetUuid().ToString());
259       return;
260     }
261 
262     csis_instances_.insert({handle, csis_instance});
263     log::debug("instance added: 0x{:04x}, device {}", handle, addr);
264   }
265 
RemoveCsisInstance(int group_id)266   void RemoveCsisInstance(int group_id) {
267     for (auto it = csis_instances_.begin(); it != csis_instances_.end(); it++) {
268       if (it->second->GetGroupId() == group_id) {
269         csis_instances_.erase(it);
270         return;
271       }
272     }
273   }
274 
GetNumberOfCsisInstances(void)275   int GetNumberOfCsisInstances(void) { return csis_instances_.size(); }
276 
ForEachCsisInstance(std::function<void (const std::shared_ptr<CsisInstance> &)> cb)277   void ForEachCsisInstance(std::function<void(const std::shared_ptr<CsisInstance>&)> cb) {
278     for (auto const& kv_pair : csis_instances_) {
279       cb(kv_pair.second);
280     }
281   }
282 
SetExpectedGroupIdMember(int group_id)283   void SetExpectedGroupIdMember(int group_id) {
284     log::info("Expected Group ID: {}, for member: {} is set", group_id, addr);
285     expected_group_id_member_ = group_id;
286   }
287 
SetPairingSirkReadFlag(bool flag)288   void SetPairingSirkReadFlag(bool flag) {
289     log::info("Pairing flag for Group ID: {}, member: {} is set to {}", expected_group_id_member_,
290               addr, flag);
291     pairing_sirk_read_flag_ = flag;
292   }
293 
GetExpectedGroupIdMember()294   inline int GetExpectedGroupIdMember() { return expected_group_id_member_; }
GetPairingSirkReadFlag()295   inline bool GetPairingSirkReadFlag() { return pairing_sirk_read_flag_; }
296 
297 private:
298   /* Instances per start handle  */
299   std::map<uint16_t, std::shared_ptr<CsisInstance>> csis_instances_;
300   int expected_group_id_member_ = bluetooth::groups::kGroupUnknown;
301   bool pairing_sirk_read_flag_ = false;
302 };
303 
304 /*
305  * CSIS group gathers devices which belongs to specific group.
306  * It also contains methond to decode encrypted SIRK and also to
307  * resolve PRSI in order to find out if device belongs to given group
308  */
309 class CsisGroup {
310 public:
CsisGroup(int group_id,const bluetooth::Uuid & uuid)311   CsisGroup(int group_id, const bluetooth::Uuid& uuid)
312       : group_id_(group_id),
313         size_(kDefaultCsisSetSize),
314         uuid_(uuid),
315         member_discovery_state_(CsisDiscoveryState::CSIS_DISCOVERY_IDLE),
316         lock_state_(CsisLockState::CSIS_STATE_UNSET),
317         target_lock_state_(CsisLockState::CSIS_STATE_UNSET),
318         lock_transition_cnt_(0) {
319     devices_.clear();
320     BTIF_STORAGE_FILL_PROPERTY(&model_name, BT_PROPERTY_REMOTE_MODEL_NUM, sizeof(model_name_val),
321                                &model_name_val);
322   }
323 
324   bt_property_t model_name;
325   bt_bdname_t model_name_val = {0};
326 
AddDevice(std::shared_ptr<CsisDevice> csis_device)327   void AddDevice(std::shared_ptr<CsisDevice> csis_device) {
328     auto it =
329             find_if(devices_.begin(), devices_.end(), CsisDevice::MatchAddress(csis_device->addr));
330     if (it != devices_.end()) {
331       return;
332     }
333 
334     devices_.push_back(std::move(csis_device));
335   }
336 
RemoveDevice(const RawAddress & bd_addr)337   void RemoveDevice(const RawAddress& bd_addr) {
338     auto it = find_if(devices_.begin(), devices_.end(), CsisDevice::MatchAddress(bd_addr));
339     if (it != devices_.end()) {
340       devices_.erase(it);
341     }
342   }
343 
GetCurrentSize(void)344   int GetCurrentSize(void) const { return devices_.size(); }
GetUuid()345   bluetooth::Uuid GetUuid() const { return uuid_; }
SetUuid(const bluetooth::Uuid & uuid)346   void SetUuid(const bluetooth::Uuid& uuid) { uuid_ = uuid; }
GetGroupId(void)347   int GetGroupId(void) const { return group_id_; }
GetDesiredSize(void)348   int GetDesiredSize(void) const { return size_; }
SetDesiredSize(int size)349   void SetDesiredSize(int size) { size_ = size; }
IsGroupComplete(void)350   bool IsGroupComplete(void) const { return size_ == (int)devices_.size(); }
IsEmpty(void)351   bool IsEmpty(void) const { return devices_.empty(); }
352 
IsDeviceInTheGroup(std::shared_ptr<CsisDevice> & csis_device)353   bool IsDeviceInTheGroup(std::shared_ptr<CsisDevice>& csis_device) {
354     auto it =
355             find_if(devices_.begin(), devices_.end(), CsisDevice::MatchAddress(csis_device->addr));
356     return it != devices_.end();
357   }
IsRsiMatching(const RawAddress & rsi)358   bool IsRsiMatching(const RawAddress& rsi) const { return is_rsi_match_sirk(rsi, GetSirk()); }
IsSirkBelongsToGroup(Octet16 sirk)359   bool IsSirkBelongsToGroup(Octet16 sirk) const { return sirk_available_ && sirk_ == sirk; }
GetSirk(void)360   Octet16 GetSirk(void) const { return sirk_; }
SetSirk(Octet16 & sirk)361   void SetSirk(Octet16& sirk) {
362     if (sirk_available_) {
363       log::debug("Updating SIRK");
364     }
365     sirk_available_ = true;
366     sirk_ = sirk;
367   }
368 
GetNumOfConnectedDevices(void)369   int GetNumOfConnectedDevices(void) {
370     return std::count_if(devices_.begin(), devices_.end(),
371                          [](auto& d) { return d->IsConnected(); });
372   }
373 
GetDiscoveryState(void)374   CsisDiscoveryState GetDiscoveryState(void) const { return member_discovery_state_; }
SetDiscoveryState(CsisDiscoveryState state)375   void SetDiscoveryState(CsisDiscoveryState state) {
376     log::debug("current discovery state: {}, new discovery state: {}",
377                static_cast<int>(member_discovery_state_), static_cast<int>(state));
378     member_discovery_state_ = state;
379   }
380 
SetCurrentLockState(CsisLockState state)381   void SetCurrentLockState(CsisLockState state) { lock_state_ = state; }
382 
383   void SetTargetLockState(CsisLockState state, CsisLockCb cb = base::DoNothing()) {
384     target_lock_state_ = state;
385     cb_ = std::move(cb);
386     switch (state) {
387       case CsisLockState::CSIS_STATE_LOCKED:
388         lock_transition_cnt_ = GetNumOfConnectedDevices();
389         break;
390       case CsisLockState::CSIS_STATE_UNLOCKED:
391       case CsisLockState::CSIS_STATE_UNSET:
392         lock_transition_cnt_ = 0;
393         break;
394     }
395   }
396 
GetLockCb(void)397   CsisLockCb GetLockCb(void) { return std::move(cb_); }
398 
GetCurrentLockState(void)399   CsisLockState GetCurrentLockState(void) const { return lock_state_; }
GetTargetLockState(void)400   CsisLockState GetTargetLockState(void) const { return target_lock_state_; }
401 
IsAvailableForCsisLockOperation(void)402   bool IsAvailableForCsisLockOperation(void) {
403     int id = group_id_;
404     int number_of_connected = 0;
405     auto iter = std::find_if(devices_.begin(), devices_.end(), [id, &number_of_connected](auto& d) {
406       if (!d->IsConnected()) {
407         log::debug("Device {} is not connected in group {}", d->addr, id);
408         return false;
409       }
410       auto inst = d->GetCsisInstanceByGroupId(id);
411       if (!inst) {
412         log::debug("Instance not available for group {}", id);
413         return false;
414       }
415       number_of_connected++;
416       log::debug("Device {},  lock state: {}", d->addr, (int)inst->GetLockState());
417       return inst->GetLockState() == CsisLockState::CSIS_STATE_LOCKED;
418     });
419 
420     log::debug("Locked set: {}, number of connected {}", iter != devices_.end(),
421                number_of_connected);
422     /* If there is no locked device, we are good to go */
423     if (iter != devices_.end()) {
424       log::warn("Device {} is locked", (*iter)->addr);
425       return false;
426     }
427 
428     return number_of_connected > 0;
429   }
430 
SortByCsisRank(void)431   void SortByCsisRank(void) {
432     int id = group_id_;
433     std::sort(devices_.begin(), devices_.end(), [id](auto& dev1, auto& dev2) {
434       auto inst1 = dev1->GetCsisInstanceByGroupId(id);
435       auto inst2 = dev2->GetCsisInstanceByGroupId(id);
436       if (!inst1 || !inst2) {
437         /* One of the device is not connected */
438         log::debug("Device  {} is not connected.", inst1 == nullptr ? dev1->addr : dev2->addr);
439         return dev1->IsConnected();
440       }
441       return inst1->GetRank() < inst2->GetRank();
442     });
443   }
444 
GetFirstDevice(void)445   std::shared_ptr<CsisDevice> GetFirstDevice(void) { return devices_.front(); }
GetLastDevice(void)446   std::shared_ptr<CsisDevice> GetLastDevice(void) { return devices_.back(); }
GetNextDevice(std::shared_ptr<CsisDevice> & device)447   std::shared_ptr<CsisDevice> GetNextDevice(std::shared_ptr<CsisDevice>& device) {
448     auto iter =
449             std::find_if(devices_.begin(), devices_.end(), CsisDevice::MatchAddress(device->addr));
450 
451     /* If reference device not found */
452     if (iter == devices_.end()) {
453       return nullptr;
454     }
455 
456     iter++;
457     /* If reference device is last in group */
458     if (iter == devices_.end()) {
459       return nullptr;
460     }
461 
462     return *iter;
463   }
GetPrevDevice(std::shared_ptr<CsisDevice> & device)464   std::shared_ptr<CsisDevice> GetPrevDevice(std::shared_ptr<CsisDevice>& device) {
465     auto iter = std::find_if(devices_.rbegin(), devices_.rend(),
466                              CsisDevice::MatchAddress(device->addr));
467 
468     /* If reference device not found */
469     if (iter == devices_.rend()) {
470       return nullptr;
471     }
472 
473     iter++;
474 
475     if (iter == devices_.rend()) {
476       return nullptr;
477     }
478     return *iter;
479   }
480 
GetLockTransitionCnt(void)481   int GetLockTransitionCnt(void) const { return lock_transition_cnt_; }
UpdateLockTransitionCnt(int i)482   int UpdateLockTransitionCnt(int i) {
483     lock_transition_cnt_ += i;
484     return lock_transition_cnt_;
485   }
486 
487   /* Return true if given Autoset Private Address |srpa| matches Set Identity
488    * Resolving Key |sirk| */
is_rsi_match_sirk(const RawAddress & rsi,const Octet16 & sirk)489   static bool is_rsi_match_sirk(const RawAddress& rsi, const Octet16& sirk) {
490     /* use the 3 MSB of bd address as prand */
491     Octet16 rand{};
492     rand[0] = rsi.address[2];
493     rand[1] = rsi.address[1];
494     rand[2] = rsi.address[0];
495 #ifdef CSIS_DEBUG
496     log::info("Prand {}", base::HexEncode(rand.data(), 3));
497     log::info("SIRK {}", base::HexEncode(sirk.data(), 16));
498 #endif
499 
500     /* generate X = E irk(R0, R1, R2) and R is random address 3 LSO */
501     Octet16 x = crypto_toolbox::aes_128(sirk, rand);
502 
503 #ifdef CSIS_DEBUG
504     log::info("X {}", base::HexEncode(x.data(), 16));
505 #endif
506 
507     rand[0] = rsi.address[5];
508     rand[1] = rsi.address[4];
509     rand[2] = rsi.address[3];
510 
511 #ifdef CSIS_DEBUG
512     log::info("Hash {}", base::HexEncode(rand.data(), 3));
513 #endif
514 
515     if (memcmp(x.data(), &rand[0], 3) == 0) {
516       // match
517       return true;
518     }
519     // not a match
520     return false;
521   }
522 
523 private:
524   int group_id_;
525   Octet16 sirk_ = {0};
526   bool sirk_available_ = false;
527   int size_;
528   bluetooth::Uuid uuid_;
529 
530   std::vector<std::shared_ptr<CsisDevice>> devices_;
531   CsisDiscoveryState member_discovery_state_;
532 
533   CsisLockState lock_state_;
534   CsisLockState target_lock_state_;
535   int lock_transition_cnt_;
536 
537   CsisLockCb cb_;
538 };
539 
540 }  // namespace csis
541 }  // namespace bluetooth
542