1 // Copyright 2023 The Pigweed Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 // use this file except in compliance with the License. You may obtain a copy of 5 // the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 // License for the specific language governing permissions and limitations under 13 // the License. 14 15 #pragma once 16 #include "pw_async/dispatcher.h" 17 #include "pw_bluetooth_sapphire/internal/host/gatt/fake_client.h" 18 #include "pw_bluetooth_sapphire/internal/host/gatt/gatt.h" 19 20 namespace bt::gatt::testing { 21 22 // This is a fake version of the root GATT object that can be injected in unit 23 // tests. 24 class FakeLayer final : public GATT { 25 public: 26 struct Update { 27 IdType chrc_id; 28 std::vector<uint8_t> value; 29 IndicationCallback indicate_cb; 30 std::optional<PeerId> peer; 31 }; 32 33 struct LocalService { 34 std::unique_ptr<Service> service; 35 ReadHandler read_handler; 36 WriteHandler write_handler; 37 ClientConfigCallback ccc_callback; 38 std::vector<Update> updates; 39 }; 40 FakeLayer(pw::async::Dispatcher & pw_dispatcher)41 explicit FakeLayer(pw::async::Dispatcher& pw_dispatcher) 42 : pw_dispatcher_(pw_dispatcher) {} 43 ~FakeLayer() override = default; 44 45 // Create a new peer GATT service. Creates a peer entry if it doesn't already 46 // exist. Replaces an existing service with the same handle if it exists. 47 // Notifies the remote service watcher if |notify| is true. 48 // 49 // Returns the fake remote service and a handle to the fake object. 50 // 51 // NOTE: the remote service watcher can also get triggered by calling 52 // InitializeClient(). 53 std::pair<RemoteService::WeakPtr, FakeClient::WeakPtr> AddPeerService( 54 PeerId peer_id, const ServiceData& info, bool notify = false); 55 56 // Removes the service with start handle of |handle| and notifies service 57 // watcher. 58 void RemovePeerService(PeerId peer_id, att::Handle handle); 59 60 // Assign a callback to be notified when a request is made to initialize the 61 // client. 62 using InitializeClientCallback = 63 fit::function<void(PeerId, std::vector<UUID>)>; 64 void SetInitializeClientCallback(InitializeClientCallback cb); 65 66 // Assign the status that will be returned by the ListServices callback. 67 void set_list_services_status(att::Result<>); 68 69 // Ignore future calls to ListServices(). stop_list_services()70 void stop_list_services() { pause_list_services_ = true; } 71 72 // Assign a callback to be notified when the persist service changed CCC 73 // callback is set. 74 using SetPersistServiceChangedCCCCallbackCallback = fit::function<void()>; 75 void SetSetPersistServiceChangedCCCCallbackCallback( 76 SetPersistServiceChangedCCCCallbackCallback cb); 77 78 // Assign a callback to be notified when the retrieve service changed CCC 79 // callback is set. 80 using SetRetrieveServiceChangedCCCCallbackCallback = fit::function<void()>; 81 void SetSetRetrieveServiceChangedCCCCallbackCallback( 82 SetRetrieveServiceChangedCCCCallbackCallback cb); 83 84 // Directly force the fake layer to call the persist service changed CCC 85 // callback, to test the GAP adapter and peer cache. 86 void CallPersistServiceChangedCCCCallback(PeerId peer_id, 87 bool notify, 88 bool indicate); 89 90 // Directly force the fake layer to call the retrieve service changed CCC 91 // callback, to test the GAP adapter and peer cache. 92 std::optional<ServiceChangedCCCPersistedData> 93 CallRetrieveServiceChangedCCCCallback(PeerId peer_id); 94 FindLocalServiceById(IdType id)95 Service* FindLocalServiceById(IdType id) { 96 return local_services_.count(id) ? local_services_[id].service.get() 97 : nullptr; 98 } 99 local_services()100 std::map<IdType, LocalService>& local_services() { return local_services_; } 101 102 // If true, cause all calls to RegisterService() to fail. set_register_service_fails(bool fails)103 void set_register_service_fails(bool fails) { 104 register_service_fails_ = fails; 105 } 106 107 // GATT overrides: 108 void AddConnection(PeerId peer_id, 109 std::unique_ptr<Client> client, 110 Server::FactoryFunction server_factory) override; 111 void RemoveConnection(PeerId peer_id) override; 112 PeerMtuListenerId RegisterPeerMtuListener(PeerMtuListener listener) override; 113 bool UnregisterPeerMtuListener(PeerMtuListenerId listener_id) override; 114 void RegisterService(ServicePtr service, 115 ServiceIdCallback callback, 116 ReadHandler read_handler, 117 WriteHandler write_handler, 118 ClientConfigCallback ccc_callback) override; 119 void UnregisterService(IdType service_id) override; 120 void SendUpdate(IdType service_id, 121 IdType chrc_id, 122 PeerId peer_id, 123 ::std::vector<uint8_t> value, 124 IndicationCallback indicate_cb) override; 125 void UpdateConnectedPeers(IdType service_id, 126 IdType chrc_id, 127 ::std::vector<uint8_t> value, 128 IndicationCallback indicate_cb) override; 129 void SetPersistServiceChangedCCCCallback( 130 PersistServiceChangedCCCCallback callback) override; 131 void SetRetrieveServiceChangedCCCCallback( 132 RetrieveServiceChangedCCCCallback callback) override; 133 void InitializeClient(PeerId peer_id, 134 std::vector<UUID> services_to_discover) override; 135 RemoteServiceWatcherId RegisterRemoteServiceWatcherForPeer( 136 PeerId peer_id, RemoteServiceWatcher watcher) override; 137 bool UnregisterRemoteServiceWatcher( 138 RemoteServiceWatcherId watcher_id) override; 139 void ListServices(PeerId peer_id, 140 std::vector<UUID> uuids, 141 ServiceListCallback callback) override; 142 RemoteService::WeakPtr FindService(PeerId peer_id, 143 IdType service_id) override; 144 145 using WeakPtr = WeakSelf<FakeLayer>::WeakPtr; GetFakePtr()146 FakeLayer::WeakPtr GetFakePtr() { return weak_fake_.GetWeakPtr(); } 147 148 private: 149 IdType next_local_service_id_ = 150 100; // Start at a random large ID to help catch bugs (e.g. 151 // FIDL IDs mixed up with internal IDs). 152 std::map<IdType, LocalService> local_services_; 153 154 bool register_service_fails_ = false; 155 156 // Test callbacks 157 InitializeClientCallback initialize_client_cb_; 158 SetPersistServiceChangedCCCCallbackCallback 159 set_persist_service_changed_ccc_cb_cb_; 160 SetRetrieveServiceChangedCCCCallbackCallback 161 set_retrieve_service_changed_ccc_cb_cb_; 162 163 // Emulated callbacks 164 std::unordered_map<PeerId, RemoteServiceWatcher> remote_service_watchers_; 165 166 PersistServiceChangedCCCCallback persist_service_changed_ccc_cb_; 167 RetrieveServiceChangedCCCCallback retrieve_service_changed_ccc_cb_; 168 169 att::Result<> list_services_status_ = fit::ok(); 170 bool pause_list_services_ = false; 171 172 // Emulated GATT peer. 173 struct TestPeer { 174 explicit TestPeer(pw::async::Dispatcher& pw_dispatcher); 175 176 FakeClient fake_client; 177 std::unordered_map<IdType, std::unique_ptr<RemoteService>> services; 178 179 BT_DISALLOW_COPY_AND_ASSIGN_ALLOW_MOVE(TestPeer); 180 }; 181 std::unordered_map<PeerId, TestPeer> peers_; 182 183 pw::async::Dispatcher& pw_dispatcher_; 184 WeakSelf<FakeLayer> weak_fake_{this}; 185 186 BT_DISALLOW_COPY_AND_ASSIGN_ALLOW_MOVE(FakeLayer); 187 }; 188 189 } // namespace bt::gatt::testing 190