xref: /aosp_15_r20/external/pigweed/pw_bluetooth_sapphire/host/sdp/service_record.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
#include "pw_bluetooth_sapphire/internal/host/sdp/service_record.h"

#include <pw_bytes/endian.h>

#include <iterator>
#include <limits>
#include <set>
#include <vector>

#include "pw_bluetooth_sapphire/internal/host/common/log.h"
#include "pw_bluetooth_sapphire/internal/host/sdp/sdp.h"
25 
26 namespace bt::sdp {
27 
28 namespace {
29 
30 // Adds all UUIDs that it finds in |elem| to |out|, recursing through
31 // sequences and alternatives if necessary.
AddAllUUIDs(const DataElement & elem,std::unordered_set<UUID> * out)32 void AddAllUUIDs(const DataElement& elem, std::unordered_set<UUID>* out) {
33   DataElement::Type type = elem.type();
34   if (type == DataElement::Type::kUuid) {
35     out->emplace(*elem.Get<UUID>());
36   } else if (type == DataElement::Type::kSequence ||
37              type == DataElement::Type::kAlternative) {
38     const DataElement* it;
39     for (size_t idx = 0; nullptr != (it = elem.At(idx)); idx++) {
40       AddAllUUIDs(*it, out);
41     }
42   }
43 }
44 
45 }  // namespace
46 
ServiceRecord()47 ServiceRecord::ServiceRecord() {
48   SetAttribute(kServiceId, DataElement(UUID::Generate()));
49 }
50 
// Deep copy. Attribute values and additional protocol lists are duplicated
// with DataElement::Clone(), because each entry is copied through Clone()
// rather than by assigning the containers.
ServiceRecord::ServiceRecord(const ServiceRecord& other) {
  for (const auto& [attr_id, element] : other.attributes_) {
    attributes_.emplace(attr_id, element.Clone());
  }

  for (const auto& [list_id, protocol_list] : other.addl_protocols_) {
    addl_protocols_.emplace(list_id, protocol_list.Clone());
  }

  handle_ = other.handle_;
  security_level_ = other.security_level_;
}
63 
SetAttribute(AttributeId id,DataElement value)64 void ServiceRecord::SetAttribute(AttributeId id, DataElement value) {
65   attributes_.erase(id);
66   attributes_.emplace(id, std::move(value));
67 }
68 
GetAttribute(AttributeId id) const69 const DataElement& ServiceRecord::GetAttribute(AttributeId id) const {
70   auto it = attributes_.find(id);
71   PW_DCHECK(it != attributes_.end(), "attribute %#.4x not set!", id);
72   return it->second;
73 }
74 
HasAttribute(AttributeId id) const75 bool ServiceRecord::HasAttribute(AttributeId id) const {
76   return attributes_.count(id) == 1;
77 }
78 
RemoveAttribute(AttributeId id)79 void ServiceRecord::RemoveAttribute(AttributeId id) { attributes_.erase(id); }
80 
IsProtocolOnly() const81 bool ServiceRecord::IsProtocolOnly() const {
82   // Protocol-only services have exactly:
83   //  - A UUID (generated by constructor)
84   //  - A primary protocol descriptor
85   //  - A service record handle (assigned by the SDP server)
86   if (attributes_.size() != 3) {
87     return false;
88   }
89   for (AttributeId x :
90        {kServiceRecordHandle, kProtocolDescriptorList, kServiceId}) {
91     if (attributes_.count(x) != 1) {
92       return false;
93     }
94   }
95   return true;
96 }
97 
IsRegisterable() const98 bool ServiceRecord::IsRegisterable() const {
99   // Services must at least have a ServiceClassIDList (5.0, Vol 3, Part B, 5.1)
100   if (!HasAttribute(kServiceClassIdList)) {
101     bt_log(TRACE, "sdp", "record missing ServiceClass");
102     return false;
103   }
104   // Class ID list is a data element sequence in which each data element is
105   // a UUID representing the service classes that a given service record
106   // conforms to. (5.0, Vol 3, Part B, 5.1.2)
107   const DataElement& class_id_list = GetAttribute(kServiceClassIdList);
108   if (class_id_list.type() != DataElement::Type::kSequence) {
109     bt_log(TRACE, "sdp", "class ID list isn't a sequence");
110     return false;
111   }
112 
113   size_t idx;
114   const DataElement* elem;
115   for (idx = 0; nullptr != (elem = class_id_list.At(idx)); idx++) {
116     if (elem->type() != DataElement::Type::kUuid) {
117       bt_log(TRACE, "sdp", "class ID list elements are not all UUIDs");
118       return false;
119     }
120   }
121 
122   if (idx == 0) {
123     bt_log(TRACE, "sdp", "no elements in the Class ID list (need at least 1)");
124     return false;
125   }
126 
127   if (!HasAttribute(kBrowseGroupList)) {
128     bt_log(TRACE, "sdp", "record isn't part of a browse group");
129     return false;
130   }
131 
132   return true;
133 }
134 
SetHandle(ServiceHandle handle)135 void ServiceRecord::SetHandle(ServiceHandle handle) {
136   handle_ = handle;
137   SetAttribute(kServiceRecordHandle, DataElement(uint32_t(handle_)));
138 }
139 
GetAttributesInRange(AttributeId start,AttributeId end) const140 std::set<AttributeId> ServiceRecord::GetAttributesInRange(
141     AttributeId start, AttributeId end) const {
142   std::set<AttributeId> attrs;
143   if (start > end) {
144     return attrs;
145   }
146   for (auto it = attributes_.lower_bound(start);
147        it != attributes_.end() && (it->first <= end);
148        ++it) {
149     attrs.emplace(it->first);
150   }
151 
152   return attrs;
153 }
154 
FindUUID(const std::unordered_set<UUID> & uuids) const155 bool ServiceRecord::FindUUID(const std::unordered_set<UUID>& uuids) const {
156   if (uuids.size() == 0) {
157     return true;
158   }
159   // Gather all the UUIDs in the attributes
160   std::unordered_set<UUID> attribute_uuids;
161   for (const auto& it : attributes_) {
162     AddAllUUIDs(it.second, &attribute_uuids);
163   }
164   for (const auto& uuid : uuids) {
165     if (attribute_uuids.count(uuid) == 0) {
166       return false;
167     }
168   }
169   return true;
170 }
171 
SetServiceClassUUIDs(const std::vector<UUID> & classes)172 void ServiceRecord::SetServiceClassUUIDs(const std::vector<UUID>& classes) {
173   std::vector<DataElement> class_uuids;
174   for (const auto& uuid : classes) {
175     class_uuids.emplace_back(DataElement(uuid));
176   }
177   DataElement class_id_list_val(std::move(class_uuids));
178   SetAttribute(kServiceClassIdList, std::move(class_id_list_val));
179 }
180 
AddProtocolDescriptor(const ProtocolListId id,const UUID & uuid,DataElement params)181 void ServiceRecord::AddProtocolDescriptor(const ProtocolListId id,
182                                           const UUID& uuid,
183                                           DataElement params) {
184   std::vector<DataElement> seq;
185   if (id == kPrimaryProtocolList) {
186     auto list_it = attributes_.find(kProtocolDescriptorList);
187     if (list_it != attributes_.end()) {
188       auto v = list_it->second.Get<std::vector<DataElement>>();
189       seq = std::move(*v);
190     }
191   } else if (addl_protocols_.count(id)) {
192     auto v = addl_protocols_[id].Get<std::vector<DataElement>>();
193     seq = std::move(*v);
194   }
195 
196   std::vector<DataElement> protocol_desc;
197   protocol_desc.emplace_back(DataElement(uuid));
198   if (params.type() == DataElement::Type::kSequence) {
199     auto v = params.Get<std::vector<DataElement>>();
200     auto param_seq = std::move(*v);
201     std::move(std::begin(param_seq),
202               std::end(param_seq),
203               std::back_inserter(protocol_desc));
204   } else if (params.type() != DataElement::Type::kNull) {
205     protocol_desc.emplace_back(std::move(params));
206   }
207 
208   seq.emplace_back(DataElement(std::move(protocol_desc)));
209 
210   if (id == kPrimaryProtocolList) {
211     SetAttribute(kProtocolDescriptorList, DataElement(std::move(seq)));
212   } else {
213     addl_protocols_.erase(id);
214     addl_protocols_.emplace(id, DataElement(std::move(seq)));
215 
216     std::vector<DataElement> addl_protocol_seq;
217     for (const auto& it : addl_protocols_) {
218       addl_protocol_seq.emplace_back(it.second.Clone());
219     }
220 
221     SetAttribute(kAdditionalProtocolDescriptorList,
222                  DataElement(std::move(addl_protocol_seq)));
223   }
224 }
225 
AddProfile(const UUID & uuid,uint8_t major,uint8_t minor)226 void ServiceRecord::AddProfile(const UUID& uuid, uint8_t major, uint8_t minor) {
227   std::vector<DataElement> seq;
228   auto list_it = attributes_.find(kBluetoothProfileDescriptorList);
229   if (list_it != attributes_.end()) {
230     auto v = list_it->second.Get<std::vector<DataElement>>();
231     seq = std::move(*v);
232   }
233 
234   std::vector<DataElement> profile_desc;
235   profile_desc.emplace_back(DataElement(uuid));
236   // Safety notes:
237   // 1.) `<<` applies integer promotion of `major` to `int` (32 bits) before
238   // operating. This makes
239   //     it safe to left shift 8 bits, even though 8 is >= `major`'s original
240   //     width.
241   // 2.) Casting to 16 bits is safe because `major` and `minor` are both only 8
242   // bits, so it is only
243   //     possible for 16 bits of the resulting value to be populated.
244   uint16_t profile_version = static_cast<uint16_t>(
245       (major << std::numeric_limits<uint8_t>::digits) | minor);
246   profile_desc.emplace_back(DataElement(profile_version));
247 
248   seq.emplace_back(DataElement(std::move(profile_desc)));
249 
250   SetAttribute(kBluetoothProfileDescriptorList, DataElement(std::move(seq)));
251 }
252 
AddInfo(const std::string & language_code,const std::string & name,const std::string & description,const std::string & provider)253 bool ServiceRecord::AddInfo(const std::string& language_code,
254                             const std::string& name,
255                             const std::string& description,
256                             const std::string& provider) {
257   if ((name.empty() && description.empty() && provider.empty()) ||
258       (language_code.size() != 2)) {
259     return false;
260   }
261   AttributeId base_attrid = 0x0100;
262   std::vector<DataElement> base_attr_list;
263   auto it = attributes_.find(kLanguageBaseAttributeIdList);
264   if (it != attributes_.end()) {
265     auto v = it->second.Get<std::vector<DataElement>>();
266     base_attr_list = std::move(*v);
267 
268     // "%" can't be in pw_assert statements.
269     const size_t list_size_mod_3 = base_attr_list.size() % 3;
270     PW_DCHECK(list_size_mod_3 == 0);
271 
272     // 0x0100 is guaranteed to be taken, start counting from higher.
273     base_attrid = 0x9000;
274   }
275 
276   // Find the first base_attrid that's not taken
277   while (HasAttribute(base_attrid + kServiceNameOffset) ||
278          HasAttribute(base_attrid + kServiceDescriptionOffset) ||
279          HasAttribute(base_attrid + kProviderNameOffset)) {
280     base_attrid++;
281     if (base_attrid == 0xFFFF) {
282       return false;
283     }
284   }
285 
286   // Core Spec v5.0, Vol 3, Part B, Sect 5.1.8: "The LanguageBaseAttributeIDList
287   // attribute consists of a data element sequence in which each element is a
288   // 16-bit unsigned integer."
289   // The language code consists of two byte characters in left-to-right order,
290   // so it may be considered a 16-bit big-endian integer that can be converted
291   // to host byte order.
292   uint16_t lang_encoded = pw::bytes::ConvertOrderFrom(
293       cpp20::endian::big, *((const uint16_t*)(language_code.data())));
294   base_attr_list.emplace_back(DataElement(lang_encoded));
295   base_attr_list.emplace_back(DataElement(uint16_t{106}));  // UTF-8
296   base_attr_list.emplace_back(DataElement(base_attrid));
297 
298   if (!name.empty()) {
299     SetAttribute(base_attrid + kServiceNameOffset, DataElement(name));
300   }
301   if (!description.empty()) {
302     SetAttribute(base_attrid + kServiceDescriptionOffset,
303                  DataElement(description));
304   }
305   if (!provider.empty()) {
306     SetAttribute(base_attrid + kProviderNameOffset, DataElement(provider));
307   }
308 
309   SetAttribute(kLanguageBaseAttributeIdList,
310                DataElement(std::move(base_attr_list)));
311   return true;
312 }
313 
GetInfo() const314 std::vector<ServiceRecord::Information> ServiceRecord::GetInfo() const {
315   if (!HasAttribute(kLanguageBaseAttributeIdList)) {
316     return {};
317   }
318 
319   const auto& base_id_list = GetAttribute(kLanguageBaseAttributeIdList);
320   // Expected to be a sequence.
321   if (base_id_list.type() != DataElement::Type::kSequence) {
322     bt_log(WARN, "sdp", "kLanguageBaseAttributeIdList not a sequence");
323     return {};
324   }
325 
326   std::vector<ServiceRecord::Information> out;
327   const auto& base_id_seq = base_id_list.Get<std::vector<DataElement>>();
328   const size_t list_size_mod_3 = base_id_seq->size() % 3;
329   PW_DCHECK(list_size_mod_3 == 0);
330 
331   for (size_t i = 0; i + 2 < base_id_seq->size(); i += 3) {
332     // Each entry is a triplet of uint16_t (language_code, encoding format, base
333     // attribute ID). Encoding format is always Utf-8 and can be ignored.
334     const std::optional<uint16_t> language = base_id_seq->at(i).Get<uint16_t>();
335     const std::optional<uint16_t> base_attr_id =
336         base_id_seq->at(i + 2).Get<uint16_t>();
337 
338     if (!language || !base_attr_id) {
339       bt_log(WARN, "sdp", "Missing language or base_attr_id");
340       return {};
341     }
342 
343     ServiceRecord::Information info;
344     // The language code is stored in host byte order, but is interpreted as two
345     // byte characters in left-to-right order (big-endian).
346     uint16_t language_be =
347         pw::bytes::ConvertOrderTo(cpp20::endian::big, language.value());
348     info.language_code = std::string(
349         reinterpret_cast<const char*>(&language_be), sizeof(language_be));
350 
351     if (HasAttribute(base_attr_id.value() + kServiceNameOffset)) {
352       std::optional<std::string> name =
353           GetAttribute(base_attr_id.value() + kServiceNameOffset)
354               .Get<std::string>();
355       if (!name) {
356         bt_log(WARN, "sdp", "Invalid name field in information entry");
357         return {};
358       }
359       info.name = std::move(name.value());
360     }
361 
362     if (HasAttribute(base_attr_id.value() + kServiceDescriptionOffset)) {
363       std::optional<std::string> description =
364           GetAttribute(base_attr_id.value() + kServiceDescriptionOffset)
365               .Get<std::string>();
366       if (!description) {
367         bt_log(WARN, "sdp", "Invalid description field in information entry");
368         return {};
369       }
370       info.description = std::move(description.value());
371     }
372 
373     if (HasAttribute(base_attr_id.value() + kProviderNameOffset)) {
374       std::optional<std::string> provider =
375           GetAttribute(base_attr_id.value() + kProviderNameOffset)
376               .Get<std::string>();
377       if (!provider) {
378         bt_log(WARN, "sdp", "Invalid provider field in information entry");
379         return {};
380       }
381       info.provider = std::move(provider.value());
382     }
383 
384     out.emplace_back(std::move(info));
385   }
386 
387   return out;
388 }
389 
ToString() const390 std::string ServiceRecord::ToString() const {
391   std::string str;
392 
393   if (HasAttribute(kBluetoothProfileDescriptorList)) {
394     const DataElement& prof_desc =
395         GetAttribute(kBluetoothProfileDescriptorList);
396     str += "Profile Descriptor: " + prof_desc.ToString() + "\n";
397   }
398 
399   if (HasAttribute(kServiceClassIdList)) {
400     const DataElement& svc_class_list = GetAttribute(kServiceClassIdList);
401     str += "Service Class Id List: " + svc_class_list.ToString();
402   }
403 
404   return str;
405 }
406 
407 }  // namespace bt::sdp
408