1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/aggregation/protocol/aggregation_protocol.h"
18
19 #include <pybind11/pybind11.h>
20
21 #include <cstdint>
22
23 #include "absl/status/status.h"
24 #include "absl/strings/cord.h"
25 #include "fcp/aggregation/protocol/aggregation_protocol_messages.pb.h"
26 #include "fcp/aggregation/protocol/configuration.pb.h"
27 #include "pybind11_abseil/absl_casters.h"
28 #include "pybind11_abseil/status_casters.h"
29 #include "pybind11_protobuf/native_proto_caster.h"
30
31 namespace {
32
33 namespace py = ::pybind11;
34
35 using ::fcp::aggregation::AcceptanceMessage;
36 using ::fcp::aggregation::AggregationProtocol;
37 using ::fcp::aggregation::ServerMessage;
38
39 // Allow AggregationProtocol::Callback to be subclassed in Python. See
40 // https://pybind11.readthedocs.io/en/stable/advanced/classes.html#overriding-virtual-functions-in-python
41 class PyAggregationProtocolCallback : public AggregationProtocol::Callback {
42 public:
OnAcceptClients(int64_t start_client_id,int64_t num_clients,const AcceptanceMessage & message)43 void OnAcceptClients(int64_t start_client_id, int64_t num_clients,
44 const AcceptanceMessage& message) override {
45 PYBIND11_OVERRIDE_PURE(void, AggregationProtocol::Callback, OnAcceptClients,
46 start_client_id, num_clients, message);
47 }
48
OnSendServerMessage(int64_t client_id,const ServerMessage & message)49 void OnSendServerMessage(int64_t client_id,
50 const ServerMessage& message) override {
51 PYBIND11_OVERRIDE_PURE(void, AggregationProtocol::Callback,
52 OnSendServerMessage, client_id, message);
53 }
54
OnCloseClient(int64_t client_id,absl::Status diagnostic_status)55 void OnCloseClient(int64_t client_id,
56 absl::Status diagnostic_status) override {
57 PYBIND11_OVERRIDE_PURE(void, AggregationProtocol::Callback, OnCloseClient,
58 client_id,
59 py::google::DoNotThrowStatus(diagnostic_status));
60 }
61
OnComplete(absl::Cord result)62 void OnComplete(absl::Cord result) override {
63 PYBIND11_OVERRIDE_PURE(void, AggregationProtocol::Callback, OnComplete,
64 result);
65 }
66
OnAbort(absl::Status diagnostic_status)67 void OnAbort(absl::Status diagnostic_status) override {
68 PYBIND11_OVERRIDE_PURE(void, AggregationProtocol::Callback, OnAbort,
69 py::google::DoNotThrowStatus(diagnostic_status));
70 }
71 };
72
73 } // namespace
74
// Entry point of the `aggregation_protocol` Python extension module. Exposes
// AggregationProtocol and, nested inside it, the Callback interface (which
// Python code may subclass through PyAggregationProtocolCallback).
PYBIND11_MODULE(aggregation_protocol, m) {
  // Caster support for absl::Status and protobuf messages is imported first,
  // ahead of any binding that uses those types.
  py::google::ImportStatusModule();
  pybind11_protobuf::ImportNativeProtoCasters();

  // AggregationProtocol: no constructor is bound, only its methods.
  py::class_<AggregationProtocol> protocol_class(m, "AggregationProtocol");
  protocol_class
      .def("Start", &AggregationProtocol::Start)
      .def("AddClients", &AggregationProtocol::AddClients)
      .def("ReceiveClientMessage", &AggregationProtocol::ReceiveClientMessage)
      .def("CloseClient", &AggregationProtocol::CloseClient)
      .def("Complete", &AggregationProtocol::Complete)
      .def("Abort", &AggregationProtocol::Abort)
      .def("GetStatus", &AggregationProtocol::GetStatus);

  // AggregationProtocol.Callback: registered in the scope of the class above
  // and default-constructible, with the trampoline as its alias type.
  py::class_<AggregationProtocol::Callback, PyAggregationProtocolCallback>
      callback_class(protocol_class, "Callback");
  callback_class
      .def(py::init<>())
      .def("OnAcceptClients", &AggregationProtocol::Callback::OnAcceptClients)
      .def("OnSendServerMessage",
           &AggregationProtocol::Callback::OnSendServerMessage)
      .def("OnCloseClient", &AggregationProtocol::Callback::OnCloseClient)
      .def("OnComplete", &AggregationProtocol::Callback::OnComplete)
      .def("OnAbort", &AggregationProtocol::Callback::OnAbort);
}
101