xref: /aosp_15_r20/external/pigweed/pw_bluetooth_sapphire/host/sdp/server.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_bluetooth_sapphire/internal/host/sdp/server.h"
16 
17 #include <cstdint>
18 #include <cstdio>
19 
20 #include "pw_bluetooth_sapphire/internal/host/common/assert.h"
21 #include "pw_bluetooth_sapphire/internal/host/common/log.h"
22 #include "pw_bluetooth_sapphire/internal/host/common/random.h"
23 #include "pw_bluetooth_sapphire/internal/host/l2cap/l2cap_defs.h"
24 #include "pw_bluetooth_sapphire/internal/host/l2cap/types.h"
25 #include "pw_bluetooth_sapphire/internal/host/sdp/data_element.h"
26 #include "pw_bluetooth_sapphire/internal/host/sdp/pdu.h"
27 #include "pw_bluetooth_sapphire/internal/host/sdp/sdp.h"
28 
29 namespace bt::sdp {
30 
31 using RegistrationHandle = Server::RegistrationHandle;
32 
33 namespace {
34 
35 constexpr const char* kInspectRegisteredPsmName = "registered_psms";
36 constexpr const char* kInspectPsmName = "psm";
37 constexpr const char* kInspectRecordName = "record";
38 
IsQueuedPsm(const std::vector<std::pair<l2cap::Psm,ServiceHandle>> * queued_psms,l2cap::Psm psm)39 bool IsQueuedPsm(
40     const std::vector<std::pair<l2cap::Psm, ServiceHandle>>* queued_psms,
41     l2cap::Psm psm) {
42   auto is_queued = [target = psm](const auto& psm_pair) {
43     return target == psm_pair.first;
44   };
45   auto iter = std::find_if(queued_psms->begin(), queued_psms->end(), is_queued);
46   return iter != queued_psms->end();
47 }
48 
49 // Returns true if the |psm| is considered valid.
IsValidPsm(l2cap::Psm psm)50 bool IsValidPsm(l2cap::Psm psm) {
51   // The least significant bit of the most significant octet must be 0
52   // (Core 5.4, Vol 3, Part A, 4.2).
53   constexpr uint16_t MS_OCTET_MASK = 0x0100;
54   if (psm & MS_OCTET_MASK) {
55     return false;
56   }
57 
58   // The least significant bit of all other octets must be 1
59   // (Core 5.4, Vol 3, Part A, 4.2).
60   constexpr uint16_t LOWER_OCTET_MASK = 0x0001;
61   if ((psm & LOWER_OCTET_MASK) != LOWER_OCTET_MASK) {
62     return false;
63   }
64   return true;
65 }
66 
67 // Updates the L2CAP |protocol| with the provided dynamic |new_psm|.
68 // Returns true if the list was updated, false if |protocol| is invalid.
UpdateProtocolWithL2capPsm(DataElement * protocol,l2cap::Psm new_psm)69 bool UpdateProtocolWithL2capPsm(DataElement* protocol, l2cap::Psm new_psm) {
70   bt_log(TRACE,
71          "sdp",
72          "Updating protocol with dynamic PSM: %s",
73          protocol->ToString().c_str());
74 
75   // A valid protocol is a sequence containing a UUID and PSM value (2
76   // elements).
77   auto l2cap_protocol = protocol->Get<std::vector<DataElement>>();
78   if (!l2cap_protocol || (*l2cap_protocol).size() != 2) {
79     return false;
80   }
81 
82   // The protocol should specify the L2CAP UUID.
83   const auto prot_uuid = (*l2cap_protocol).data();
84   if (!prot_uuid || prot_uuid->type() != DataElement::Type::kUuid ||
85       *prot_uuid->Get<UUID>() != protocol::kL2CAP) {
86     return false;
87   }
88 
89   // The second element should be the dynamic PSM. If found, update it.
90   auto dynamic_psm_elem = &(*l2cap_protocol)[1];
91   if (!dynamic_psm_elem->Get<uint16_t>() ||
92       dynamic_psm_elem->Get<uint16_t>() != Server::kDynamicPsm) {
93     bt_log(WARN, "sdp", "Request to update non-dynamic L2CAP PSM. Ignoring");
94     return false;
95   }
96   (*l2cap_protocol)[1] = DataElement(uint16_t{new_psm});
97   protocol->Set(std::move(*l2cap_protocol));
98 
99   bt_log(TRACE,
100          "sdp",
101          "Updated protocol list with dynamic PSM %s",
102          protocol->ToString().c_str());
103   return true;
104 }
105 
106 // Updates the L2CAP |protocol_list| with the dynamic |new_psm|.
107 // |protocol_list| must be a list of protocols- one of which must be L2CAP.
108 // Returns true if the list was updated with the |new_psm|, false otherwise.
UpdateProtocolListWithL2capPsm(DataElement & protocol_list,l2cap::Psm new_psm)109 bool UpdateProtocolListWithL2capPsm(DataElement& protocol_list,
110                                     l2cap::Psm new_psm) {
111   bt_log(TRACE,
112          "sdp",
113          "Updating protocol list with dynamic psm: %s",
114          protocol_list.ToString().c_str());
115 
116   auto protocol_seq = protocol_list.Get<std::vector<DataElement>>();
117   if (!protocol_seq) {
118     bt_log(TRACE, "sdp", "ProtocolDescriptorList is not a valid sequence");
119     return false;
120   }
121 
122   bool updated = false;
123   for (DataElement& protocol : (*protocol_seq)) {
124     if (UpdateProtocolWithL2capPsm(&protocol, new_psm)) {
125       updated = true;
126       break;
127     }
128   }
129 
130   protocol_list.Set(std::move(*protocol_seq));
131   return updated;
132 }
133 
134 // Finds the PSM that is specified in a ProtocolDescriptorList
135 // Returns l2cap::kInvalidPsm if none is found or the list is invalid
FindProtocolListPsm(const DataElement & protocol_list)136 l2cap::Psm FindProtocolListPsm(const DataElement& protocol_list) {
137   bt_log(TRACE,
138          "sdp",
139          "Trying to find PSM from %s",
140          protocol_list.ToString().c_str());
141   const auto* l2cap_protocol = protocol_list.At(0);
142   PW_DCHECK(l2cap_protocol);
143   const auto* prot_uuid = l2cap_protocol->At(0);
144   if (!prot_uuid || prot_uuid->type() != DataElement::Type::kUuid ||
145       *prot_uuid->Get<UUID>() != protocol::kL2CAP) {
146     bt_log(TRACE, "sdp", "ProtocolDescriptorList is not valid or not L2CAP");
147     return l2cap::kInvalidPsm;
148   }
149 
150   const auto* psm_elem = l2cap_protocol->At(1);
151   if (psm_elem && psm_elem->Get<uint16_t>()) {
152     return *psm_elem->Get<uint16_t>();
153   }
154   if (psm_elem) {
155     bt_log(TRACE, "sdp", "ProtocolDescriptorList invalid L2CAP parameter type");
156     return l2cap::kInvalidPsm;
157   }
158 
159   // The PSM is missing, determined by the next protocol.
160   const auto* next_protocol = protocol_list.At(1);
161   if (!next_protocol) {
162     bt_log(TRACE, "sdp", "L2CAP has no PSM and no additional protocol");
163     return l2cap::kInvalidPsm;
164   }
165   const auto* next_protocol_uuid = next_protocol->At(0);
166   if (!next_protocol_uuid ||
167       next_protocol_uuid->type() != DataElement::Type::kUuid) {
168     bt_log(TRACE, "sdp", "L2CAP has no PSM and additional protocol invalid");
169     return l2cap::kInvalidPsm;
170   }
171   UUID protocol_uuid = *next_protocol_uuid->Get<UUID>();
172   // When it's RFCOMM, the L2CAP protocol descriptor omits the PSM parameter
173   // See example in the SPP Spec, v1.2
174   if (protocol_uuid == protocol::kRFCOMM) {
175     return l2cap::kRFCOMM;
176   }
177   bt_log(TRACE, "sdp", "Can't determine L2CAP PSM from protocol");
178   return l2cap::kInvalidPsm;
179 }
180 
PsmFromProtocolList(const DataElement * protocol_list)181 l2cap::Psm PsmFromProtocolList(const DataElement* protocol_list) {
182   const auto* primary_protocol = protocol_list->At(0);
183   if (!primary_protocol) {
184     bt_log(TRACE, "sdp", "ProtocolDescriptorList is not a sequence");
185     return l2cap::kInvalidPsm;
186   }
187 
188   const auto* prot_uuid = primary_protocol->At(0);
189   if (!prot_uuid || prot_uuid->type() != DataElement::Type::kUuid) {
190     bt_log(TRACE, "sdp", "ProtocolDescriptorList is not valid");
191     return l2cap::kInvalidPsm;
192   }
193 
194   // We do nothing for primary protocols that are not L2CAP
195   if (*prot_uuid->Get<UUID>() != protocol::kL2CAP) {
196     return l2cap::kInvalidPsm;
197   }
198 
199   l2cap::Psm psm = FindProtocolListPsm(*protocol_list);
200   if (psm == l2cap::kInvalidPsm) {
201     bt_log(TRACE, "sdp", "Couldn't find PSM from ProtocolDescriptorList");
202     return l2cap::kInvalidPsm;
203   }
204 
205   return psm;
206 }
207 
208 // Sets the browse group list of the record to be the top-level group.
SetBrowseGroupList(ServiceRecord * record)209 void SetBrowseGroupList(ServiceRecord* record) {
210   std::vector<DataElement> browse_list;
211   browse_list.emplace_back(kPublicBrowseRootUuid);
212   record->SetAttribute(kBrowseGroupList, DataElement(std::move(browse_list)));
213 }
214 
215 }  // namespace
216 
// Value of the SDP server record's VersionNumberList entry: version 1.0
// (5.0, Vol 3, Part B, 5.2.3).
constexpr uint16_t kVersion{0x0100};

// ServiceDatabaseState value placed in the SDP server record at creation.
constexpr uint32_t kInitialDbState{0};
222 
223 // Populates the ServiceDiscoveryService record.
MakeServiceDiscoveryService()224 ServiceRecord Server::MakeServiceDiscoveryService() {
225   ServiceRecord sdp;
226   sdp.SetHandle(kSDPHandle);
227 
228   // ServiceClassIDList attribute should have the
229   // ServiceDiscoveryServerServiceClassID
230   // See v5.0, Vol 3, Part B, Sec 5.2.2
231   sdp.SetServiceClassUUIDs({profile::kServiceDiscoveryClass});
232 
233   // The VersionNumberList attribute. See v5.0, Vol 3, Part B, Sec 5.2.3
234   // Version 1.0
235   std::vector<DataElement> version_attribute;
236   version_attribute.emplace_back(kVersion);
237   sdp.SetAttribute(kSDP_VersionNumberList,
238                    DataElement(std::move(version_attribute)));
239 
240   // ServiceDatabaseState attribute. Changes when a service gets added or
241   // removed.
242   sdp.SetAttribute(kSDP_ServiceDatabaseState, DataElement(kInitialDbState));
243 
244   return sdp;
245 }
246 
Server(l2cap::ChannelManager * l2cap)247 Server::Server(l2cap::ChannelManager* l2cap)
248     : l2cap_(l2cap),
249       next_handle_(kFirstUnreservedHandle),
250       db_state_(0),
251       weak_ptr_factory_(this) {
252   PW_CHECK(l2cap_);
253 
254   records_.emplace(kSDPHandle, Server::MakeServiceDiscoveryService());
255 
256   // Register SDP
257   l2cap::ChannelParameters sdp_chan_params;
258   sdp_chan_params.mode = l2cap::RetransmissionAndFlowControlMode::kBasic;
259   l2cap_->RegisterService(
260       l2cap::kSDP,
261       sdp_chan_params,
262       [self = weak_ptr_factory_.GetWeakPtr()](auto channel) {
263         if (self.is_alive())
264           self->AddConnection(channel);
265       });
266 
267   // SDP is used by SDP server.
268   psm_to_service_.emplace(l2cap::kSDP,
269                           std::unordered_set<ServiceHandle>({kSDPHandle}));
270   service_to_psms_.emplace(kSDPHandle,
271                            std::unordered_set<l2cap::Psm>({l2cap::kSDP}));
272 
273   // Update the inspect properties after Server initialization.
274   UpdateInspectProperties();
275 }
276 
~Server()277 Server::~Server() { l2cap_->UnregisterService(l2cap::kSDP); }
278 
AttachInspect(inspect::Node & parent,std::string name)279 void Server::AttachInspect(inspect::Node& parent, std::string name) {
280   inspect_properties_.sdp_server_node = parent.CreateChild(name);
281   UpdateInspectProperties();
282 }
283 
AddConnection(l2cap::Channel::WeakPtr channel)284 bool Server::AddConnection(l2cap::Channel::WeakPtr channel) {
285   PW_CHECK(channel.is_alive());
286   hci_spec::ConnectionHandle handle = channel->link_handle();
287   bt_log(DEBUG, "sdp", "add connection handle %#.4x", handle);
288 
289   l2cap::Channel::UniqueId chan_id = channel->unique_id();
290   auto iter = channels_.find(chan_id);
291   if (iter != channels_.end()) {
292     bt_log(WARN, "sdp", "l2cap channel to %#.4x already connected", handle);
293     return false;
294   }
295 
296   auto self = weak_ptr_factory_.GetWeakPtr();
297   bool activated = channel->Activate(
298       [self, chan_id, max_tx_sdu_size = channel->max_tx_sdu_size()](
299           ByteBufferPtr sdu) {
300         if (self.is_alive()) {
301           auto packet = self->HandleRequest(std::move(sdu), max_tx_sdu_size);
302           if (packet) {
303             self->Send(chan_id, std::move(packet.value()));
304           }
305         }
306       },
307       [self, chan_id] {
308         if (self.is_alive()) {
309           self->OnChannelClosed(chan_id);
310         }
311       });
312   if (!activated) {
313     bt_log(WARN, "sdp", "failed to activate channel (handle %#.4x)", handle);
314     return false;
315   }
316   self->channels_.emplace(chan_id, std::move(channel));
317   return true;
318 }
319 
AddPsmToProtocol(ProtocolQueue * protocols_to_register,l2cap::Psm psm,ServiceHandle handle) const320 bool Server::AddPsmToProtocol(ProtocolQueue* protocols_to_register,
321                               l2cap::Psm psm,
322                               ServiceHandle handle) const {
323   if (psm == l2cap::kInvalidPsm) {
324     return false;
325   }
326 
327   if (IsAllocated(psm)) {
328     bt_log(TRACE, "sdp", "L2CAP PSM %#.4x is already allocated", psm);
329     return false;
330   }
331 
332   auto data = std::make_pair(psm, handle);
333   protocols_to_register->emplace_back(std::move(data));
334   return true;
335 }
336 
GetDynamicPsm(const ProtocolQueue * queued_psms) const337 l2cap::Psm Server::GetDynamicPsm(const ProtocolQueue* queued_psms) const {
338   // Generate a random PSM in the valid range of PSMs.
339   // RNG(Range(MIN, MAX)) = MIN + RNG(MAX-MIN) where MIN = kMinDynamicPSM =
340   // 0x1001. MAX = 0xffff.
341   uint16_t offset = 0;
342   constexpr uint16_t MAX_MINUS_MIN = 0xeffe;
343   random_generator()->GetInt(offset, MAX_MINUS_MIN);
344   uint16_t psm = l2cap::kMinDynamicPsm + offset;
345   // LSB of upper octet must be 0. LSB of lower octet must be 1.
346   constexpr uint16_t UPPER_OCTET_MASK = 0xFEFF;
347   constexpr uint16_t LOWER_OCTET_MASK = 0x0001;
348   psm &= UPPER_OCTET_MASK;
349   psm |= LOWER_OCTET_MASK;
350   bt_log(DEBUG, "sdp", "Trying random dynamic PSM %#.4x", psm);
351 
352   // Check if the PSM is valid (e.g. valid construction, not allocated, & not
353   // queued).
354   if ((IsValidPsm(psm)) && (!IsAllocated(psm)) &&
355       (!IsQueuedPsm(queued_psms, psm))) {
356     bt_log(TRACE, "sdp", "Generated random dynamic PSM %#.4x", psm);
357     return psm;
358   }
359 
360   // Otherwise, fall back to sequentially finding the next available PSM.
361   bool search_wrapped = false;
362   for (uint16_t next_psm = psm + 2; next_psm <= UINT16_MAX; next_psm += 2) {
363     if ((IsValidPsm(next_psm)) && (!IsAllocated(next_psm)) &&
364         (!IsQueuedPsm(queued_psms, next_psm))) {
365       bt_log(TRACE, "sdp", "Generated sequential dynamic PSM %#.4x", next_psm);
366       return next_psm;
367     }
368 
369     // If we reach the max valid PSM, wrap around to the minimum valid dynamic
370     // PSM. Only try this once.
371     if (next_psm == 0xFEFF) {
372       next_psm = l2cap::kMinDynamicPsm;
373       if (search_wrapped) {
374         break;
375       }
376       search_wrapped = true;
377     }
378   }
379   bt_log(WARN, "sdp", "Couldn't find an available dynamic PSM");
380   return l2cap::kInvalidPsm;
381 }
382 
QueueService(ServiceRecord * record,ProtocolQueue * protocols_to_register)383 bool Server::QueueService(ServiceRecord* record,
384                           ProtocolQueue* protocols_to_register) {
385   // ProtocolDescriptorList handling:
386   if (record->HasAttribute(kProtocolDescriptorList)) {
387     const auto& primary_protocol =
388         record->GetAttribute(kProtocolDescriptorList);
389     auto psm = PsmFromProtocolList(&primary_protocol);
390     if (psm == kDynamicPsm) {
391       bt_log(TRACE, "sdp", "Primary protocol contains dynamic PSM");
392       auto primary_protocol_copy = primary_protocol.Clone();
393       psm = GetDynamicPsm(protocols_to_register);
394       if (!UpdateProtocolListWithL2capPsm(primary_protocol_copy, psm)) {
395         return false;
396       }
397       record->SetAttribute(kProtocolDescriptorList,
398                            std::move(primary_protocol_copy));
399     }
400     if (!AddPsmToProtocol(protocols_to_register, psm, record->handle())) {
401       return false;
402     }
403   }
404 
405   // AdditionalProtocolDescriptorList handling:
406   if (record->HasAttribute(kAdditionalProtocolDescriptorList)) {
407     // |additional_list| is a list of ProtocolDescriptorLists.
408     const auto& additional_list =
409         record->GetAttribute(kAdditionalProtocolDescriptorList);
410     size_t attribute_id = 0;
411     const auto* additional = additional_list.At(attribute_id);
412 
413     // If `kAdditionalProtocolDescriptorList` exists, there should be at least
414     // one protocol provided.
415     if (!additional) {
416       bt_log(
417           TRACE, "sdp", "AdditionalProtocolDescriptorList provided but empty");
418       return false;
419     }
420 
421     // Add valid additional PSMs to the register queue. Because some additional
422     // protocols may need dynamic PSM assignment, modify the relevant protocols
423     // and rebuild the list.
424     std::vector<DataElement> additional_protocols;
425     while (additional) {
426       auto psm = PsmFromProtocolList(additional);
427       auto additional_protocol_copy = additional->Clone();
428       if (psm == kDynamicPsm) {
429         bt_log(TRACE, "sdp", "Additional protocol contains dynamic PSM");
430         psm = GetDynamicPsm(protocols_to_register);
431         if (!UpdateProtocolListWithL2capPsm(additional_protocol_copy, psm)) {
432           return l2cap::kInvalidPsm;
433         }
434       }
435       if (!AddPsmToProtocol(protocols_to_register, psm, record->handle())) {
436         return false;
437       }
438 
439       attribute_id++;
440       additional_protocols.emplace_back(std::move(additional_protocol_copy));
441       additional = additional_list.At(attribute_id);
442     }
443     record->SetAttribute(kAdditionalProtocolDescriptorList,
444                          DataElement(std::move(additional_protocols)));
445   }
446 
447   // For some services that depend on OBEX, the L2CAP PSM is specified in the
448   // GoepL2capPsm attribute.
449   bool has_obex = record->FindUUID(std::unordered_set<UUID>({protocol::kOBEX}));
450   if (has_obex && record->HasAttribute(kGoepL2capPsm)) {
451     const auto& attribute = record->GetAttribute(kGoepL2capPsm);
452     if (attribute.Get<uint16_t>()) {
453       auto psm = *attribute.Get<uint16_t>();
454       // If a dynamic PSM was requested, attempt to allocate the next available
455       // PSM.
456       if (psm == kDynamicPsm) {
457         bt_log(TRACE, "sdp", "GoepL2capAttribute contains dynamic PSM");
458         psm = GetDynamicPsm(protocols_to_register);
459         record->SetAttribute(kGoepL2capPsm, DataElement(uint16_t{psm}));
460       }
461       if (!AddPsmToProtocol(protocols_to_register, psm, record->handle())) {
462         return false;
463       }
464     }
465   }
466 
467   return true;
468 }
469 
RegisterService(std::vector<ServiceRecord> records,l2cap::ChannelParameters chan_params,ConnectCallback conn_cb)470 RegistrationHandle Server::RegisterService(std::vector<ServiceRecord> records,
471                                            l2cap::ChannelParameters chan_params,
472                                            ConnectCallback conn_cb) {
473   if (records.empty()) {
474     return 0;
475   }
476 
477   // The PSMs and their ServiceHandles to register.
478   ProtocolQueue protocols_to_register;
479 
480   // The ServiceHandles that are assigned to each ServiceRecord.
481   // There should be one ServiceHandle per ServiceRecord in |records|.
482   std::set<ServiceHandle> assigned_handles;
483 
484   for (auto& record : records) {
485     ServiceHandle next = GetNextHandle();
486     if (!next) {
487       return 0;
488     }
489     // Assign a new handle for the service record.
490     record.SetHandle(next);
491 
492     if (!record.IsProtocolOnly()) {
493       // Place record in a browse group.
494       SetBrowseGroupList(&record);
495 
496       // Validate the |ServiceRecord|.
497       if (!record.IsRegisterable()) {
498         return 0;
499       }
500     }
501 
502     // Attempt to queue the |record| for registration.
503     // Note: Since the validation & queueing operations for ALL the records
504     // occur before registration, multiple ServiceRecords can share the same
505     // PSM.
506     //
507     // If any |record| is not parsable, exit the registration process early.
508     if (!QueueService(&record, &protocols_to_register)) {
509       return 0;
510     }
511 
512     // For every ServiceRecord, there will be one ServiceHandle assigned.
513     assigned_handles.emplace(next);
514   }
515 
516   PW_CHECK(assigned_handles.size() == records.size());
517 
518   // The RegistrationHandle is the smallest ServiceHandle that was assigned.
519   RegistrationHandle reg_handle = *assigned_handles.begin();
520 
521   // Multiple ServiceRecords in |records| can request the same PSM. However,
522   // |l2cap_| expects a single target for each PSM to go to. Consequently,
523   // only the first occurrence of a PSM needs to be registered with the
524   // |l2cap_|.
525   std::unordered_set<l2cap::Psm> psms_to_register;
526 
527   // All PSMs have assigned handles and will be registered.
528   for (auto& [psm, handle] : protocols_to_register) {
529     psm_to_service_[psm].insert(handle);
530     service_to_psms_[handle].insert(psm);
531 
532     // Add unique PSMs to the data domain registration queue.
533     psms_to_register.insert(psm);
534   }
535 
536   for (const auto& psm : psms_to_register) {
537     bt_log(TRACE, "sdp", "Allocating PSM %#.4x for new service", psm);
538     l2cap_->RegisterService(
539         psm,
540         chan_params,
541         [l2cap_psm = psm, conn_cb_shared = conn_cb.share()](
542             l2cap::Channel::WeakPtr channel) mutable {
543           bt_log(TRACE, "sdp", "Channel connected to %#.4x", l2cap_psm);
544           // Build the L2CAP descriptor
545           std::vector<DataElement> protocol_l2cap;
546           protocol_l2cap.emplace_back(protocol::kL2CAP);
547           protocol_l2cap.emplace_back(l2cap_psm);
548           std::vector<DataElement> protocol;
549           protocol.emplace_back(std::move(protocol_l2cap));
550           conn_cb_shared(std::move(channel), DataElement(std::move(protocol)));
551         });
552   }
553 
554   // Store the complete records.
555   for (auto& record : records) {
556     auto [it, success] = records_.emplace(record.handle(), std::move(record));
557     PW_DCHECK(success);
558     const ServiceRecord& placed_record = it->second;
559     if (placed_record.IsProtocolOnly()) {
560       bt_log(TRACE,
561              "sdp",
562              "registered protocol-only service %#.8x, Protocol: %s",
563              placed_record.handle(),
564              bt_str(placed_record.GetAttribute(kProtocolDescriptorList)));
565     } else {
566       bt_log(TRACE,
567              "sdp",
568              "registered service %#.8x, classes: %s",
569              placed_record.handle(),
570              bt_str(placed_record.GetAttribute(kServiceClassIdList)));
571     }
572   }
573 
574   // Store the RegistrationHandle that represents the set of services that were
575   // registered.
576   reg_to_service_[reg_handle] = std::move(assigned_handles);
577 
578   // Update the inspect properties.
579   UpdateInspectProperties();
580 
581   return reg_handle;
582 }
583 
UnregisterService(RegistrationHandle handle)584 bool Server::UnregisterService(RegistrationHandle handle) {
585   if (handle == kNotRegistered) {
586     return false;
587   }
588 
589   auto handles_it = reg_to_service_.extract(handle);
590   if (!handles_it) {
591     return false;
592   }
593 
594   for (const auto& svc_h : handles_it.mapped()) {
595     PW_CHECK(svc_h != kSDPHandle);
596     PW_CHECK(records_.find(svc_h) != records_.end());
597     bt_log(DEBUG, "sdp", "unregistering service (handle: %#.8x)", svc_h);
598 
599     // Unregister any service callbacks from L2CAP
600     auto psms_it = service_to_psms_.extract(svc_h);
601     if (psms_it) {
602       for (const auto& psm : psms_it.mapped()) {
603         bt_log(DEBUG, "sdp", "removing registration for psm %#.4x", psm);
604         l2cap_->UnregisterService(psm);
605         psm_to_service_.erase(psm);
606       }
607     }
608 
609     records_.erase(svc_h);
610   }
611 
612   // Update the inspect properties as the registered PSMs may have changed.
613   UpdateInspectProperties();
614 
615   return true;
616 }
617 
GetRegisteredServices(RegistrationHandle handle) const618 std::vector<ServiceRecord> Server::GetRegisteredServices(
619     RegistrationHandle handle) const {
620   std::vector<ServiceRecord> out;
621   if (handle == kNotRegistered) {
622     return out;
623   }
624 
625   auto service_handles_it = reg_to_service_.find(handle);
626   if (service_handles_it == reg_to_service_.end()) {
627     return out;
628   }
629 
630   for (const auto& service_handle : service_handles_it->second) {
631     auto record_it = records_.find(service_handle);
632     if (record_it != records_.end()) {
633       ServiceRecord record_copy = record_it->second;
634       out.emplace_back(std::move(record_copy));
635     }
636   }
637 
638   return out;
639 }
640 
GetNextHandle()641 ServiceHandle Server::GetNextHandle() {
642   ServiceHandle initial_next_handle = next_handle_;
643   // We expect most of these to be free.
644   // Safeguard against possibly having to wrap-around and reuse handles.
645   while (records_.count(next_handle_)) {
646     if (next_handle_ == kLastHandle) {
647       bt_log(WARN, "sdp", "service handle wrapped to start");
648       next_handle_ = kFirstUnreservedHandle;
649     } else {
650       next_handle_++;
651     }
652     if (next_handle_ == initial_next_handle) {
653       return 0;
654     }
655   }
656   return next_handle_++;
657 }
658 
SearchServices(const std::unordered_set<UUID> & pattern) const659 ServiceSearchResponse Server::SearchServices(
660     const std::unordered_set<UUID>& pattern) const {
661   ServiceSearchResponse resp;
662   std::vector<ServiceHandle> matched;
663   for (const auto& it : records_) {
664     if (it.second.FindUUID(pattern) && !it.second.IsProtocolOnly()) {
665       matched.push_back(it.first);
666     }
667   }
668   bt_log(TRACE, "sdp", "ServiceSearch matched %zu records", matched.size());
669   resp.set_service_record_handle_list(matched);
670   return resp;
671 }
672 
GetServiceAttributes(ServiceHandle handle,const std::list<AttributeRange> & ranges) const673 ServiceAttributeResponse Server::GetServiceAttributes(
674     ServiceHandle handle, const std::list<AttributeRange>& ranges) const {
675   ServiceAttributeResponse resp;
676   const auto& record = records_.at(handle);
677   for (const auto& range : ranges) {
678     auto attrs = record.GetAttributesInRange(range.start, range.end);
679     for (const auto& attr : attrs) {
680       resp.set_attribute(attr, record.GetAttribute(attr).Clone());
681     }
682   }
683   bt_log(TRACE,
684          "sdp",
685          "ServiceAttribute %zu attributes",
686          resp.attributes().size());
687   return resp;
688 }
689 
SearchAllServiceAttributes(const std::unordered_set<UUID> & search_pattern,const std::list<AttributeRange> & attribute_ranges) const690 ServiceSearchAttributeResponse Server::SearchAllServiceAttributes(
691     const std::unordered_set<UUID>& search_pattern,
692     const std::list<AttributeRange>& attribute_ranges) const {
693   ServiceSearchAttributeResponse resp;
694   for (const auto& it : records_) {
695     const auto& rec = it.second;
696     if (rec.IsProtocolOnly()) {
697       continue;
698     }
699     if (rec.FindUUID(search_pattern)) {
700       for (const auto& range : attribute_ranges) {
701         auto attrs = rec.GetAttributesInRange(range.start, range.end);
702         for (const auto& attr : attrs) {
703           resp.SetAttribute(it.first, attr, rec.GetAttribute(attr).Clone());
704         }
705       }
706     }
707   }
708 
709   bt_log(TRACE,
710          "sdp",
711          "ServiceSearchAttribute %zu records",
712          resp.num_attribute_lists());
713   return resp;
714 }
715 
OnChannelClosed(l2cap::Channel::UniqueId channel_id)716 void Server::OnChannelClosed(l2cap::Channel::UniqueId channel_id) {
717   channels_.erase(channel_id);
718 }
719 
HandleRequest(ByteBufferPtr sdu,uint16_t max_tx_sdu_size)720 std::optional<ByteBufferPtr> Server::HandleRequest(ByteBufferPtr sdu,
721                                                    uint16_t max_tx_sdu_size) {
722   PW_DCHECK(sdu);
723   TRACE_DURATION("bluetooth", "sdp::Server::HandleRequest");
724   if (sdu->size() < sizeof(Header)) {
725     bt_log(DEBUG, "sdp", "PDU too short; dropping");
726     return std::nullopt;
727   }
728   PacketView<Header> packet(sdu.get());
729   TransactionId tid =
730       pw::bytes::ConvertOrderFrom(cpp20::endian::big, packet.header().tid);
731   uint16_t param_length = pw::bytes::ConvertOrderFrom(
732       cpp20::endian::big, packet.header().param_length);
733   auto error_response_builder =
734       [tid, max_tx_sdu_size](ErrorCode code) -> ByteBufferPtr {
735     return ErrorResponse(code).GetPDU(
736         0 /* ignored */, tid, max_tx_sdu_size, BufferView());
737   };
738   if (param_length != (sdu->size() - sizeof(Header))) {
739     bt_log(TRACE,
740            "sdp",
741            "request isn't the correct size (%hu != %zu)",
742            param_length,
743            sdu->size() - sizeof(Header));
744     return error_response_builder(ErrorCode::kInvalidSize);
745   }
746   packet.Resize(param_length);
747   switch (packet.header().pdu_id) {
748     case kServiceSearchRequest: {
749       ServiceSearchRequest request(packet.payload_data());
750       if (!request.valid()) {
751         bt_log(DEBUG, "sdp", "ServiceSearchRequest not valid");
752         return error_response_builder(ErrorCode::kInvalidRequestSyntax);
753       }
754       auto resp = SearchServices(request.service_search_pattern());
755 
756       auto bytes = resp.GetPDU(request.max_service_record_count(),
757                                tid,
758                                max_tx_sdu_size,
759                                request.ContinuationState());
760       if (!bytes) {
761         return error_response_builder(ErrorCode::kInvalidContinuationState);
762       }
763       return std::move(bytes);
764     }
765     case kServiceAttributeRequest: {
766       ServiceAttributeRequest request(packet.payload_data());
767       if (!request.valid()) {
768         bt_log(TRACE, "sdp", "ServiceAttributeRequest not valid");
769         return error_response_builder(ErrorCode::kInvalidRequestSyntax);
770       }
771       auto handle = request.service_record_handle();
772       auto record_it = records_.find(handle);
773       if (record_it == records_.end() || record_it->second.IsProtocolOnly()) {
774         bt_log(TRACE,
775                "sdp",
776                "ServiceAttributeRequest can't find handle %#.8x",
777                handle);
778         return error_response_builder(ErrorCode::kInvalidRecordHandle);
779       }
780       auto resp = GetServiceAttributes(handle, request.attribute_ranges());
781       auto bytes = resp.GetPDU(request.max_attribute_byte_count(),
782                                tid,
783                                max_tx_sdu_size,
784                                request.ContinuationState());
785       if (!bytes) {
786         return error_response_builder(ErrorCode::kInvalidContinuationState);
787       }
788       return std::move(bytes);
789     }
790     case kServiceSearchAttributeRequest: {
791       ServiceSearchAttributeRequest request(packet.payload_data());
792       if (!request.valid()) {
793         bt_log(TRACE, "sdp", "ServiceSearchAttributeRequest not valid");
794         return error_response_builder(ErrorCode::kInvalidRequestSyntax);
795       }
796       auto resp = SearchAllServiceAttributes(request.service_search_pattern(),
797                                              request.attribute_ranges());
798       auto bytes = resp.GetPDU(request.max_attribute_byte_count(),
799                                tid,
800                                max_tx_sdu_size,
801                                request.ContinuationState());
802       if (!bytes) {
803         return error_response_builder(ErrorCode::kInvalidContinuationState);
804       }
805       return std::move(bytes);
806     }
807     case kErrorResponse: {
808       bt_log(TRACE, "sdp", "ErrorResponse isn't allowed as a request");
809       return error_response_builder(ErrorCode::kInvalidRequestSyntax);
810     }
811     default: {
812       bt_log(TRACE, "sdp", "unhandled request, returning InvalidRequest");
813       return error_response_builder(ErrorCode::kInvalidRequestSyntax);
814     }
815   }
816 }
817 
Send(l2cap::Channel::UniqueId channel_id,ByteBufferPtr bytes)818 void Server::Send(l2cap::Channel::UniqueId channel_id, ByteBufferPtr bytes) {
819   auto it = channels_.find(channel_id);
820   if (it == channels_.end()) {
821     bt_log(ERROR, "sdp", "can't find peer to respond to; dropping");
822     return;
823   }
824   l2cap::Channel::WeakPtr chan = it->second.get();
825   chan->Send(std::move(bytes));
826 }
827 
UpdateInspectProperties()828 void Server::UpdateInspectProperties() {
829   // Skip update if node has not been attached.
830   if (!inspect_properties_.sdp_server_node) {
831     return;
832   }
833 
834   // Clear the previous inspect data.
835   inspect_properties_.svc_record_properties.clear();
836 
837   for (const auto& svc_record : records_) {
838     auto record_string = svc_record.second.ToString();
839     auto psms_it = service_to_psms_.find(svc_record.first);
840     std::unordered_set<l2cap::Psm> psm_set;
841     if (psms_it != service_to_psms_.end()) {
842       psm_set = psms_it->second;
843     }
844 
845     InspectProperties::InspectServiceRecordProperties svc_rec_props(
846         std::move(record_string), std::move(psm_set));
847     auto& parent = inspect_properties_.sdp_server_node;
848     svc_rec_props.AttachInspect(parent, parent.UniqueName(kInspectRecordName));
849 
850     inspect_properties_.svc_record_properties.push_back(
851         std::move(svc_rec_props));
852   }
853 }
854 
AllocatedPsmsForTest() const855 std::set<l2cap::Psm> Server::AllocatedPsmsForTest() const {
856   std::set<l2cap::Psm> allocated;
857   for (auto it = psm_to_service_.begin(); it != psm_to_service_.end(); ++it) {
858     allocated.insert(it->first);
859   }
860   return allocated;
861 }
862 
863 Server::InspectProperties::InspectServiceRecordProperties::
InspectServiceRecordProperties(std::string record_in,std::unordered_set<l2cap::Psm> psms_in)864     InspectServiceRecordProperties(std::string record_in,
865                                    std::unordered_set<l2cap::Psm> psms_in)
866     : record(std::move(record_in)), psms(std::move(psms_in)) {}
867 
AttachInspect(inspect::Node & parent,std::string name)868 void Server::InspectProperties::InspectServiceRecordProperties::AttachInspect(
869     inspect::Node& parent, std::string name) {
870   node = parent.CreateChild(name);
871   record_property = node.CreateString(kInspectRecordName, record);
872   psms_node = node.CreateChild(kInspectRegisteredPsmName);
873   psm_nodes.clear();
874   for (const auto& psm : psms) {
875     auto psm_node =
876         psms_node.CreateChild(psms_node.UniqueName(kInspectPsmName));
877     auto psm_string =
878         psm_node.CreateString(kInspectPsmName, l2cap::PsmToString(psm));
879     psm_nodes.emplace_back(std::move(psm_node), std::move(psm_string));
880   }
881 }
882 
883 }  // namespace bt::sdp
884