xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/protocol/python/aggregation_protocol.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/aggregation/protocol/aggregation_protocol.h"
18 
19 #include <pybind11/pybind11.h>
20 
21 #include <cstdint>
22 
23 #include "absl/status/status.h"
24 #include "absl/strings/cord.h"
25 #include "fcp/aggregation/protocol/aggregation_protocol_messages.pb.h"
26 #include "fcp/aggregation/protocol/configuration.pb.h"
27 #include "pybind11_abseil/absl_casters.h"
28 #include "pybind11_abseil/status_casters.h"
29 #include "pybind11_protobuf/native_proto_caster.h"
30 
31 namespace {
32 
33 namespace py = ::pybind11;
34 
35 using ::fcp::aggregation::AcceptanceMessage;
36 using ::fcp::aggregation::AggregationProtocol;
37 using ::fcp::aggregation::ServerMessage;
38 
39 // Allow AggregationProtocol::Callback to be subclassed in Python. See
40 // https://pybind11.readthedocs.io/en/stable/advanced/classes.html#overriding-virtual-functions-in-python
41 class PyAggregationProtocolCallback : public AggregationProtocol::Callback {
42  public:
OnAcceptClients(int64_t start_client_id,int64_t num_clients,const AcceptanceMessage & message)43   void OnAcceptClients(int64_t start_client_id, int64_t num_clients,
44                        const AcceptanceMessage& message) override {
45     PYBIND11_OVERRIDE_PURE(void, AggregationProtocol::Callback, OnAcceptClients,
46                            start_client_id, num_clients, message);
47   }
48 
OnSendServerMessage(int64_t client_id,const ServerMessage & message)49   void OnSendServerMessage(int64_t client_id,
50                            const ServerMessage& message) override {
51     PYBIND11_OVERRIDE_PURE(void, AggregationProtocol::Callback,
52                            OnSendServerMessage, client_id, message);
53   }
54 
OnCloseClient(int64_t client_id,absl::Status diagnostic_status)55   void OnCloseClient(int64_t client_id,
56                      absl::Status diagnostic_status) override {
57     PYBIND11_OVERRIDE_PURE(void, AggregationProtocol::Callback, OnCloseClient,
58                            client_id,
59                            py::google::DoNotThrowStatus(diagnostic_status));
60   }
61 
OnComplete(absl::Cord result)62   void OnComplete(absl::Cord result) override {
63     PYBIND11_OVERRIDE_PURE(void, AggregationProtocol::Callback, OnComplete,
64                            result);
65   }
66 
OnAbort(absl::Status diagnostic_status)67   void OnAbort(absl::Status diagnostic_status) override {
68     PYBIND11_OVERRIDE_PURE(void, AggregationProtocol::Callback, OnAbort,
69                            py::google::DoNotThrowStatus(diagnostic_status));
70   }
71 };
72 
73 }  // namespace
74 
// Entry point of the `aggregation_protocol` Python extension module.
//
// Registers two Python classes:
//   * AggregationProtocol          - exposes the protocol's member functions.
//   * AggregationProtocol.Callback - nested class that Python code can
//     subclass; virtual calls are routed through
//     PyAggregationProtocolCallback.
PYBIND11_MODULE(aggregation_protocol, m) {
  // Import the helper modules used by the absl::Status and protobuf type
  // casters that the bound signatures below depend on.
  py::google::ImportStatusModule();
  pybind11_protobuf::ImportNativeProtoCasters();

  // No constructor is registered for AggregationProtocol here; only its
  // member functions are exposed.
  py::class_<AggregationProtocol> protocol_class(m, "AggregationProtocol");
  protocol_class.def("Start", &AggregationProtocol::Start);
  protocol_class.def("AddClients", &AggregationProtocol::AddClients);
  protocol_class.def("ReceiveClientMessage",
                     &AggregationProtocol::ReceiveClientMessage);
  protocol_class.def("CloseClient", &AggregationProtocol::CloseClient);
  protocol_class.def("Complete", &AggregationProtocol::Complete);
  protocol_class.def("Abort", &AggregationProtocol::Abort);
  protocol_class.def("GetStatus", &AggregationProtocol::GetStatus);

  // Using `protocol_class` as the scope makes Callback an attribute of the
  // AggregationProtocol Python class rather than of the module.
  using Callback = AggregationProtocol::Callback;
  py::class_<Callback, PyAggregationProtocolCallback> callback_class(
      protocol_class, "Callback");
  callback_class.def(py::init<>());
  callback_class.def("OnAcceptClients", &Callback::OnAcceptClients);
  callback_class.def("OnSendServerMessage", &Callback::OnSendServerMessage);
  callback_class.def("OnCloseClient", &Callback::OnCloseClient);
  callback_class.def("OnComplete", &Callback::OnComplete);
  callback_class.def("OnAbort", &Callback::OnAbort);
}
101