1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_bluetooth_sapphire/internal/host/gatt/fake_layer.h"
16
17 #include "pw_bluetooth_sapphire/internal/host/gatt/remote_service.h"
18
19 namespace bt::gatt::testing {
20
TestPeer(pw::async::Dispatcher & pw_dispatcher)21 FakeLayer::TestPeer::TestPeer(pw::async::Dispatcher& pw_dispatcher)
22 : fake_client(pw_dispatcher) {}
23
24 std::pair<RemoteService::WeakPtr, FakeClient::WeakPtr>
AddPeerService(PeerId peer_id,const ServiceData & info,bool notify)25 FakeLayer::AddPeerService(PeerId peer_id,
26 const ServiceData& info,
27 bool notify) {
28 auto [iter, _] = peers_.try_emplace(peer_id, pw_dispatcher_);
29 auto& peer = iter->second;
30
31 PW_CHECK(info.range_start <= info.range_end);
32 auto service =
33 std::make_unique<RemoteService>(info, peer.fake_client.GetWeakPtr());
34 RemoteService::WeakPtr service_weak = service->GetWeakPtr();
35
36 std::vector<att::Handle> removed;
37 ServiceList added;
38 ServiceList modified;
39
40 auto svc_iter = peer.services.find(info.range_start);
41 if (svc_iter != peer.services.end()) {
42 if (svc_iter->second->uuid() == info.type) {
43 modified.push_back(service_weak);
44 } else {
45 removed.push_back(svc_iter->second->handle());
46 added.push_back(service_weak);
47 }
48
49 svc_iter->second->set_service_changed(true);
50 peer.services.erase(svc_iter);
51 } else {
52 added.push_back(service_weak);
53 }
54
55 bt_log(DEBUG,
56 "gatt",
57 "services changed (removed: %zu, added: %zu, modified: %zu)",
58 removed.size(),
59 added.size(),
60 modified.size());
61
62 peer.services.emplace(info.range_start, std::move(service));
63
64 if (notify && remote_service_watchers_.count(peer_id)) {
65 remote_service_watchers_[peer_id](removed, added, modified);
66 }
67
68 return {service_weak, peer.fake_client.AsFakeWeakPtr()};
69 }
70
RemovePeerService(PeerId peer_id,att::Handle handle)71 void FakeLayer::RemovePeerService(PeerId peer_id, att::Handle handle) {
72 auto peer_iter = peers_.find(peer_id);
73 if (peer_iter == peers_.end()) {
74 return;
75 }
76 auto svc_iter = peer_iter->second.services.find(handle);
77 if (svc_iter == peer_iter->second.services.end()) {
78 return;
79 }
80 svc_iter->second->set_service_changed(true);
81 peer_iter->second.services.erase(svc_iter);
82
83 if (remote_service_watchers_.count(peer_id)) {
84 remote_service_watchers_[peer_id](
85 /*removed=*/{handle}, /*added=*/{}, /*modified=*/{});
86 }
87 }
88
AddConnection(PeerId peer_id,std::unique_ptr<Client>,Server::FactoryFunction)89 void FakeLayer::AddConnection(PeerId peer_id,
90 std::unique_ptr<Client>,
91 Server::FactoryFunction) {
92 peers_.try_emplace(peer_id, pw_dispatcher_);
93 }
94
RemoveConnection(PeerId peer_id)95 void FakeLayer::RemoveConnection(PeerId peer_id) { peers_.erase(peer_id); }
96
RegisterPeerMtuListener(PeerMtuListener)97 GATT::PeerMtuListenerId FakeLayer::RegisterPeerMtuListener(PeerMtuListener) {
98 BT_PANIC("implement fake behavior if needed");
99 }
100
UnregisterPeerMtuListener(PeerMtuListenerId)101 bool FakeLayer::UnregisterPeerMtuListener(PeerMtuListenerId) {
102 BT_PANIC("implement fake behavior if needed");
103 }
104
RegisterService(ServicePtr service,ServiceIdCallback callback,ReadHandler read_handler,WriteHandler write_handler,ClientConfigCallback ccc_callback)105 void FakeLayer::RegisterService(ServicePtr service,
106 ServiceIdCallback callback,
107 ReadHandler read_handler,
108 WriteHandler write_handler,
109 ClientConfigCallback ccc_callback) {
110 if (register_service_fails_) {
111 callback(kInvalidId);
112 return;
113 }
114
115 IdType id = next_local_service_id_++;
116 local_services_.try_emplace(id,
117 LocalService{std::move(service),
118 std::move(read_handler),
119 std::move(write_handler),
120 std::move(ccc_callback),
121 {}});
122 callback(id);
123 }
124
UnregisterService(IdType service_id)125 void FakeLayer::UnregisterService(IdType service_id) {
126 local_services_.erase(service_id);
127 }
128
SendUpdate(IdType service_id,IdType chrc_id,PeerId peer_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)129 void FakeLayer::SendUpdate(IdType service_id,
130 IdType chrc_id,
131 PeerId peer_id,
132 ::std::vector<uint8_t> value,
133 IndicationCallback indicate_cb) {
134 auto iter = local_services_.find(service_id);
135 if (iter == local_services_.end()) {
136 indicate_cb(fit::error(att::ErrorCode::kInvalidHandle));
137 return;
138 }
139 iter->second.updates.push_back(
140 Update{chrc_id, std::move(value), std::move(indicate_cb), peer_id});
141 }
142
UpdateConnectedPeers(IdType service_id,IdType chrc_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)143 void FakeLayer::UpdateConnectedPeers(IdType service_id,
144 IdType chrc_id,
145 ::std::vector<uint8_t> value,
146 IndicationCallback indicate_cb) {
147 auto iter = local_services_.find(service_id);
148 if (iter == local_services_.end()) {
149 indicate_cb(fit::error(att::ErrorCode::kInvalidHandle));
150 return;
151 }
152 iter->second.updates.push_back(
153 Update{chrc_id, std::move(value), std::move(indicate_cb), std::nullopt});
154 }
155
SetPersistServiceChangedCCCCallback(PersistServiceChangedCCCCallback callback)156 void FakeLayer::SetPersistServiceChangedCCCCallback(
157 PersistServiceChangedCCCCallback callback) {
158 if (set_persist_service_changed_ccc_cb_cb_) {
159 set_persist_service_changed_ccc_cb_cb_();
160 }
161 persist_service_changed_ccc_cb_ = std::move(callback);
162 }
163
SetRetrieveServiceChangedCCCCallback(RetrieveServiceChangedCCCCallback callback)164 void FakeLayer::SetRetrieveServiceChangedCCCCallback(
165 RetrieveServiceChangedCCCCallback callback) {
166 if (set_retrieve_service_changed_ccc_cb_cb_) {
167 set_retrieve_service_changed_ccc_cb_cb_();
168 }
169 retrieve_service_changed_ccc_cb_ = std::move(callback);
170 }
171
InitializeClient(PeerId peer_id,std::vector<UUID> services_to_discover)172 void FakeLayer::InitializeClient(PeerId peer_id,
173 std::vector<UUID> services_to_discover) {
174 std::vector<UUID> uuids = std::move(services_to_discover);
175 if (initialize_client_cb_) {
176 initialize_client_cb_(peer_id, uuids);
177 }
178
179 auto iter = peers_.find(peer_id);
180 if (iter == peers_.end()) {
181 return;
182 }
183
184 std::vector<RemoteService::WeakPtr> added;
185 if (uuids.empty()) {
186 for (auto& svc_pair : iter->second.services) {
187 added.push_back(svc_pair.second->GetWeakPtr());
188 }
189 } else {
190 for (auto& svc_pair : iter->second.services) {
191 auto uuid_iter =
192 std::find_if(uuids.begin(), uuids.end(), [&svc_pair](auto uuid) {
193 return svc_pair.second->uuid() == uuid;
194 });
195 if (uuid_iter != uuids.end()) {
196 added.push_back(svc_pair.second->GetWeakPtr());
197 }
198 }
199 }
200
201 if (remote_service_watchers_.count(peer_id)) {
202 remote_service_watchers_[peer_id](
203 /*removed=*/{}, /*added=*/added, /*modified=*/{});
204 }
205 }
206
RegisterRemoteServiceWatcherForPeer(PeerId peer_id,RemoteServiceWatcher watcher)207 GATT::RemoteServiceWatcherId FakeLayer::RegisterRemoteServiceWatcherForPeer(
208 PeerId peer_id, RemoteServiceWatcher watcher) {
209 PW_CHECK(remote_service_watchers_.count(peer_id) == 0);
210 remote_service_watchers_[peer_id] = std::move(watcher);
211 // Use the PeerId as the watcher ID because FakeLayer only needs to support 1
212 // watcher per peer.
213 return peer_id.value();
214 }
UnregisterRemoteServiceWatcher(RemoteServiceWatcherId watcher_id)215 bool FakeLayer::UnregisterRemoteServiceWatcher(
216 RemoteServiceWatcherId watcher_id) {
217 bool result = remote_service_watchers_.count(PeerId(watcher_id));
218 remote_service_watchers_.erase(PeerId(watcher_id));
219 return result;
220 }
221
ListServices(PeerId peer_id,std::vector<UUID> uuids,ServiceListCallback callback)222 void FakeLayer::ListServices(PeerId peer_id,
223 std::vector<UUID> uuids,
224 ServiceListCallback callback) {
225 if (pause_list_services_) {
226 return;
227 }
228
229 ServiceList services;
230
231 auto iter = peers_.find(peer_id);
232 if (iter != peers_.end()) {
233 for (auto& svc_pair : iter->second.services) {
234 auto pred = [&](const UUID& uuid) {
235 return svc_pair.second->uuid() == uuid;
236 };
237 if (uuids.empty() ||
238 std::find_if(uuids.begin(), uuids.end(), pred) != uuids.end()) {
239 services.push_back(svc_pair.second->GetWeakPtr());
240 }
241 }
242 }
243
244 callback(list_services_status_, std::move(services));
245 }
246
FindService(PeerId peer_id,IdType service_id)247 RemoteService::WeakPtr FakeLayer::FindService(PeerId peer_id,
248 IdType service_id) {
249 auto peer_iter = peers_.find(peer_id);
250 if (peer_iter == peers_.end()) {
251 return RemoteService::WeakPtr();
252 }
253 auto svc_iter = peer_iter->second.services.find(service_id);
254 if (svc_iter == peer_iter->second.services.end()) {
255 return RemoteService::WeakPtr();
256 }
257 return svc_iter->second->GetWeakPtr();
258 }
259
SetInitializeClientCallback(InitializeClientCallback cb)260 void FakeLayer::SetInitializeClientCallback(InitializeClientCallback cb) {
261 initialize_client_cb_ = std::move(cb);
262 }
263
set_list_services_status(att::Result<> status)264 void FakeLayer::set_list_services_status(att::Result<> status) {
265 list_services_status_ = status;
266 }
267
SetSetPersistServiceChangedCCCCallbackCallback(SetPersistServiceChangedCCCCallbackCallback cb)268 void FakeLayer::SetSetPersistServiceChangedCCCCallbackCallback(
269 SetPersistServiceChangedCCCCallbackCallback cb) {
270 set_persist_service_changed_ccc_cb_cb_ = std::move(cb);
271 }
272
SetSetRetrieveServiceChangedCCCCallbackCallback(SetRetrieveServiceChangedCCCCallbackCallback cb)273 void FakeLayer::SetSetRetrieveServiceChangedCCCCallbackCallback(
274 SetRetrieveServiceChangedCCCCallbackCallback cb) {
275 set_retrieve_service_changed_ccc_cb_cb_ = std::move(cb);
276 }
277
CallPersistServiceChangedCCCCallback(PeerId peer_id,bool notify,bool indicate)278 void FakeLayer::CallPersistServiceChangedCCCCallback(PeerId peer_id,
279 bool notify,
280 bool indicate) {
281 persist_service_changed_ccc_cb_(peer_id,
282 {.notify = notify, .indicate = indicate});
283 }
284
285 std::optional<ServiceChangedCCCPersistedData>
CallRetrieveServiceChangedCCCCallback(PeerId peer_id)286 FakeLayer::CallRetrieveServiceChangedCCCCallback(PeerId peer_id) {
287 return retrieve_service_changed_ccc_cb_(peer_id);
288 }
289
290 } // namespace bt::gatt::testing
291