// xref: /aosp_15_r20/external/pigweed/pw_bluetooth_sapphire/host/gatt/fake_layer.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
// Copyright 2023 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
14 
15 #include "pw_bluetooth_sapphire/internal/host/gatt/fake_layer.h"
16 
17 #include "pw_bluetooth_sapphire/internal/host/gatt/remote_service.h"
18 
19 namespace bt::gatt::testing {
20 
TestPeer(pw::async::Dispatcher & pw_dispatcher)21 FakeLayer::TestPeer::TestPeer(pw::async::Dispatcher& pw_dispatcher)
22     : fake_client(pw_dispatcher) {}
23 
24 std::pair<RemoteService::WeakPtr, FakeClient::WeakPtr>
AddPeerService(PeerId peer_id,const ServiceData & info,bool notify)25 FakeLayer::AddPeerService(PeerId peer_id,
26                           const ServiceData& info,
27                           bool notify) {
28   auto [iter, _] = peers_.try_emplace(peer_id, pw_dispatcher_);
29   auto& peer = iter->second;
30 
31   PW_CHECK(info.range_start <= info.range_end);
32   auto service =
33       std::make_unique<RemoteService>(info, peer.fake_client.GetWeakPtr());
34   RemoteService::WeakPtr service_weak = service->GetWeakPtr();
35 
36   std::vector<att::Handle> removed;
37   ServiceList added;
38   ServiceList modified;
39 
40   auto svc_iter = peer.services.find(info.range_start);
41   if (svc_iter != peer.services.end()) {
42     if (svc_iter->second->uuid() == info.type) {
43       modified.push_back(service_weak);
44     } else {
45       removed.push_back(svc_iter->second->handle());
46       added.push_back(service_weak);
47     }
48 
49     svc_iter->second->set_service_changed(true);
50     peer.services.erase(svc_iter);
51   } else {
52     added.push_back(service_weak);
53   }
54 
55   bt_log(DEBUG,
56          "gatt",
57          "services changed (removed: %zu, added: %zu, modified: %zu)",
58          removed.size(),
59          added.size(),
60          modified.size());
61 
62   peer.services.emplace(info.range_start, std::move(service));
63 
64   if (notify && remote_service_watchers_.count(peer_id)) {
65     remote_service_watchers_[peer_id](removed, added, modified);
66   }
67 
68   return {service_weak, peer.fake_client.AsFakeWeakPtr()};
69 }
70 
RemovePeerService(PeerId peer_id,att::Handle handle)71 void FakeLayer::RemovePeerService(PeerId peer_id, att::Handle handle) {
72   auto peer_iter = peers_.find(peer_id);
73   if (peer_iter == peers_.end()) {
74     return;
75   }
76   auto svc_iter = peer_iter->second.services.find(handle);
77   if (svc_iter == peer_iter->second.services.end()) {
78     return;
79   }
80   svc_iter->second->set_service_changed(true);
81   peer_iter->second.services.erase(svc_iter);
82 
83   if (remote_service_watchers_.count(peer_id)) {
84     remote_service_watchers_[peer_id](
85         /*removed=*/{handle}, /*added=*/{}, /*modified=*/{});
86   }
87 }
88 
AddConnection(PeerId peer_id,std::unique_ptr<Client>,Server::FactoryFunction)89 void FakeLayer::AddConnection(PeerId peer_id,
90                               std::unique_ptr<Client>,
91                               Server::FactoryFunction) {
92   peers_.try_emplace(peer_id, pw_dispatcher_);
93 }
94 
RemoveConnection(PeerId peer_id)95 void FakeLayer::RemoveConnection(PeerId peer_id) { peers_.erase(peer_id); }
96 
RegisterPeerMtuListener(PeerMtuListener)97 GATT::PeerMtuListenerId FakeLayer::RegisterPeerMtuListener(PeerMtuListener) {
98   BT_PANIC("implement fake behavior if needed");
99 }
100 
UnregisterPeerMtuListener(PeerMtuListenerId)101 bool FakeLayer::UnregisterPeerMtuListener(PeerMtuListenerId) {
102   BT_PANIC("implement fake behavior if needed");
103 }
104 
RegisterService(ServicePtr service,ServiceIdCallback callback,ReadHandler read_handler,WriteHandler write_handler,ClientConfigCallback ccc_callback)105 void FakeLayer::RegisterService(ServicePtr service,
106                                 ServiceIdCallback callback,
107                                 ReadHandler read_handler,
108                                 WriteHandler write_handler,
109                                 ClientConfigCallback ccc_callback) {
110   if (register_service_fails_) {
111     callback(kInvalidId);
112     return;
113   }
114 
115   IdType id = next_local_service_id_++;
116   local_services_.try_emplace(id,
117                               LocalService{std::move(service),
118                                            std::move(read_handler),
119                                            std::move(write_handler),
120                                            std::move(ccc_callback),
121                                            {}});
122   callback(id);
123 }
124 
UnregisterService(IdType service_id)125 void FakeLayer::UnregisterService(IdType service_id) {
126   local_services_.erase(service_id);
127 }
128 
SendUpdate(IdType service_id,IdType chrc_id,PeerId peer_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)129 void FakeLayer::SendUpdate(IdType service_id,
130                            IdType chrc_id,
131                            PeerId peer_id,
132                            ::std::vector<uint8_t> value,
133                            IndicationCallback indicate_cb) {
134   auto iter = local_services_.find(service_id);
135   if (iter == local_services_.end()) {
136     indicate_cb(fit::error(att::ErrorCode::kInvalidHandle));
137     return;
138   }
139   iter->second.updates.push_back(
140       Update{chrc_id, std::move(value), std::move(indicate_cb), peer_id});
141 }
142 
UpdateConnectedPeers(IdType service_id,IdType chrc_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)143 void FakeLayer::UpdateConnectedPeers(IdType service_id,
144                                      IdType chrc_id,
145                                      ::std::vector<uint8_t> value,
146                                      IndicationCallback indicate_cb) {
147   auto iter = local_services_.find(service_id);
148   if (iter == local_services_.end()) {
149     indicate_cb(fit::error(att::ErrorCode::kInvalidHandle));
150     return;
151   }
152   iter->second.updates.push_back(
153       Update{chrc_id, std::move(value), std::move(indicate_cb), std::nullopt});
154 }
155 
SetPersistServiceChangedCCCCallback(PersistServiceChangedCCCCallback callback)156 void FakeLayer::SetPersistServiceChangedCCCCallback(
157     PersistServiceChangedCCCCallback callback) {
158   if (set_persist_service_changed_ccc_cb_cb_) {
159     set_persist_service_changed_ccc_cb_cb_();
160   }
161   persist_service_changed_ccc_cb_ = std::move(callback);
162 }
163 
SetRetrieveServiceChangedCCCCallback(RetrieveServiceChangedCCCCallback callback)164 void FakeLayer::SetRetrieveServiceChangedCCCCallback(
165     RetrieveServiceChangedCCCCallback callback) {
166   if (set_retrieve_service_changed_ccc_cb_cb_) {
167     set_retrieve_service_changed_ccc_cb_cb_();
168   }
169   retrieve_service_changed_ccc_cb_ = std::move(callback);
170 }
171 
InitializeClient(PeerId peer_id,std::vector<UUID> services_to_discover)172 void FakeLayer::InitializeClient(PeerId peer_id,
173                                  std::vector<UUID> services_to_discover) {
174   std::vector<UUID> uuids = std::move(services_to_discover);
175   if (initialize_client_cb_) {
176     initialize_client_cb_(peer_id, uuids);
177   }
178 
179   auto iter = peers_.find(peer_id);
180   if (iter == peers_.end()) {
181     return;
182   }
183 
184   std::vector<RemoteService::WeakPtr> added;
185   if (uuids.empty()) {
186     for (auto& svc_pair : iter->second.services) {
187       added.push_back(svc_pair.second->GetWeakPtr());
188     }
189   } else {
190     for (auto& svc_pair : iter->second.services) {
191       auto uuid_iter =
192           std::find_if(uuids.begin(), uuids.end(), [&svc_pair](auto uuid) {
193             return svc_pair.second->uuid() == uuid;
194           });
195       if (uuid_iter != uuids.end()) {
196         added.push_back(svc_pair.second->GetWeakPtr());
197       }
198     }
199   }
200 
201   if (remote_service_watchers_.count(peer_id)) {
202     remote_service_watchers_[peer_id](
203         /*removed=*/{}, /*added=*/added, /*modified=*/{});
204   }
205 }
206 
RegisterRemoteServiceWatcherForPeer(PeerId peer_id,RemoteServiceWatcher watcher)207 GATT::RemoteServiceWatcherId FakeLayer::RegisterRemoteServiceWatcherForPeer(
208     PeerId peer_id, RemoteServiceWatcher watcher) {
209   PW_CHECK(remote_service_watchers_.count(peer_id) == 0);
210   remote_service_watchers_[peer_id] = std::move(watcher);
211   // Use the PeerId as the watcher ID because FakeLayer only needs to support 1
212   // watcher per peer.
213   return peer_id.value();
214 }
UnregisterRemoteServiceWatcher(RemoteServiceWatcherId watcher_id)215 bool FakeLayer::UnregisterRemoteServiceWatcher(
216     RemoteServiceWatcherId watcher_id) {
217   bool result = remote_service_watchers_.count(PeerId(watcher_id));
218   remote_service_watchers_.erase(PeerId(watcher_id));
219   return result;
220 }
221 
ListServices(PeerId peer_id,std::vector<UUID> uuids,ServiceListCallback callback)222 void FakeLayer::ListServices(PeerId peer_id,
223                              std::vector<UUID> uuids,
224                              ServiceListCallback callback) {
225   if (pause_list_services_) {
226     return;
227   }
228 
229   ServiceList services;
230 
231   auto iter = peers_.find(peer_id);
232   if (iter != peers_.end()) {
233     for (auto& svc_pair : iter->second.services) {
234       auto pred = [&](const UUID& uuid) {
235         return svc_pair.second->uuid() == uuid;
236       };
237       if (uuids.empty() ||
238           std::find_if(uuids.begin(), uuids.end(), pred) != uuids.end()) {
239         services.push_back(svc_pair.second->GetWeakPtr());
240       }
241     }
242   }
243 
244   callback(list_services_status_, std::move(services));
245 }
246 
FindService(PeerId peer_id,IdType service_id)247 RemoteService::WeakPtr FakeLayer::FindService(PeerId peer_id,
248                                               IdType service_id) {
249   auto peer_iter = peers_.find(peer_id);
250   if (peer_iter == peers_.end()) {
251     return RemoteService::WeakPtr();
252   }
253   auto svc_iter = peer_iter->second.services.find(service_id);
254   if (svc_iter == peer_iter->second.services.end()) {
255     return RemoteService::WeakPtr();
256   }
257   return svc_iter->second->GetWeakPtr();
258 }
259 
SetInitializeClientCallback(InitializeClientCallback cb)260 void FakeLayer::SetInitializeClientCallback(InitializeClientCallback cb) {
261   initialize_client_cb_ = std::move(cb);
262 }
263 
set_list_services_status(att::Result<> status)264 void FakeLayer::set_list_services_status(att::Result<> status) {
265   list_services_status_ = status;
266 }
267 
SetSetPersistServiceChangedCCCCallbackCallback(SetPersistServiceChangedCCCCallbackCallback cb)268 void FakeLayer::SetSetPersistServiceChangedCCCCallbackCallback(
269     SetPersistServiceChangedCCCCallbackCallback cb) {
270   set_persist_service_changed_ccc_cb_cb_ = std::move(cb);
271 }
272 
SetSetRetrieveServiceChangedCCCCallbackCallback(SetRetrieveServiceChangedCCCCallbackCallback cb)273 void FakeLayer::SetSetRetrieveServiceChangedCCCCallbackCallback(
274     SetRetrieveServiceChangedCCCCallbackCallback cb) {
275   set_retrieve_service_changed_ccc_cb_cb_ = std::move(cb);
276 }
277 
CallPersistServiceChangedCCCCallback(PeerId peer_id,bool notify,bool indicate)278 void FakeLayer::CallPersistServiceChangedCCCCallback(PeerId peer_id,
279                                                      bool notify,
280                                                      bool indicate) {
281   persist_service_changed_ccc_cb_(peer_id,
282                                   {.notify = notify, .indicate = indicate});
283 }
284 
285 std::optional<ServiceChangedCCCPersistedData>
CallRetrieveServiceChangedCCCCallback(PeerId peer_id)286 FakeLayer::CallRetrieveServiceChangedCCCCallback(PeerId peer_id) {
287   return retrieve_service_changed_ccc_cb_(peer_id);
288 }
289 
290 }  // namespace bt::gatt::testing
291