xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/tf_wrapper.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2019 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker #ifndef FCP_CLIENT_ENGINE_TF_WRAPPER_H_
17*14675a02SAndroid Build Coastguard Worker #define FCP_CLIENT_ENGINE_TF_WRAPPER_H_
18*14675a02SAndroid Build Coastguard Worker 
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "google/protobuf/any.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "fcp/base/future.h"
#include "fcp/base/monitoring.h"
#include "fcp/base/scheduler.h"
#include "fcp/client/interruptible_runner.h"
#include "fcp/client/log_manager.h"
#include "tensorflow/core/public/session.h"
33*14675a02SAndroid Build Coastguard Worker 
34*14675a02SAndroid Build Coastguard Worker namespace fcp {
35*14675a02SAndroid Build Coastguard Worker namespace client {
36*14675a02SAndroid Build Coastguard Worker namespace engine {
37*14675a02SAndroid Build Coastguard Worker 
38*14675a02SAndroid Build Coastguard Worker // A class to call into TensorFlow.
39*14675a02SAndroid Build Coastguard Worker // All functions in this interface indicate errors as follows:
40*14675a02SAndroid Build Coastguard Worker // - CANCELLED: interrupted execution
41*14675a02SAndroid Build Coastguard Worker // - INVALID_ARGUMENT: TensorFlow error. The TensorFlow error code and message
42*14675a02SAndroid Build Coastguard Worker //   are included in the Status message.
43*14675a02SAndroid Build Coastguard Worker // - OUT_OF_RANGE: internal abortion, i.e. TensorFlow reporting the model
44*14675a02SAndroid Build Coastguard Worker //   aborted execution.
45*14675a02SAndroid Build Coastguard Worker // This class supports aborting ongoing calls, by polling the provided
46*14675a02SAndroid Build Coastguard Worker // should_abort function.
47*14675a02SAndroid Build Coastguard Worker class TensorFlowWrapper {
48*14675a02SAndroid Build Coastguard Worker  public:
49*14675a02SAndroid Build Coastguard Worker   static absl::StatusOr<std::unique_ptr<TensorFlowWrapper>> Create(
50*14675a02SAndroid Build Coastguard Worker       const std::string& graph, const ::google::protobuf::Any& config_proto,
51*14675a02SAndroid Build Coastguard Worker       std::function<bool()> should_abort,
52*14675a02SAndroid Build Coastguard Worker       const InterruptibleRunner::TimingConfig& timing_config,
53*14675a02SAndroid Build Coastguard Worker       LogManager* log_manager);
54*14675a02SAndroid Build Coastguard Worker 
55*14675a02SAndroid Build Coastguard Worker   // Utility method for creating a ConfigProto from an optionally
56*14675a02SAndroid Build Coastguard Worker   // externally provided value, or from hardcoded defaults. This is a separate
57*14675a02SAndroid Build Coastguard Worker   // method to aid with testing.
58*14675a02SAndroid Build Coastguard Worker   static absl::StatusOr<::tensorflow::ConfigProto> InitializeConfigProto(
59*14675a02SAndroid Build Coastguard Worker       const ::google::protobuf::Any& external_config_proto);
60*14675a02SAndroid Build Coastguard Worker 
61*14675a02SAndroid Build Coastguard Worker   ~TensorFlowWrapper();
62*14675a02SAndroid Build Coastguard Worker 
63*14675a02SAndroid Build Coastguard Worker   // Wrapper around TensorFlow's Session::Run method with full support for
64*14675a02SAndroid Build Coastguard Worker   // feeds, fetches and target node names.
65*14675a02SAndroid Build Coastguard Worker   // Returns OK, OUT_OF_RANGE, INVALID_ARGUMENT, or CANCELLED.
66*14675a02SAndroid Build Coastguard Worker   absl::Status Run(
67*14675a02SAndroid Build Coastguard Worker       const std::vector<std::pair<std::string, tensorflow::Tensor>>& inputs,
68*14675a02SAndroid Build Coastguard Worker       const std::vector<std::string>& output_tensor_names,
69*14675a02SAndroid Build Coastguard Worker       const std::vector<std::string>& target_node_names,
70*14675a02SAndroid Build Coastguard Worker       std::vector<tensorflow::Tensor>* outputs);
71*14675a02SAndroid Build Coastguard Worker 
72*14675a02SAndroid Build Coastguard Worker   // Closes and releases the TensorFlow session. After this is called, no
73*14675a02SAndroid Build Coastguard Worker   // further calls on this TensorFlowWrapper should be made. Subsequent calls to
74*14675a02SAndroid Build Coastguard Worker   // CloseAndRelease() will have no effect.
75*14675a02SAndroid Build Coastguard Worker   absl::Status CloseAndRelease();
76*14675a02SAndroid Build Coastguard Worker 
77*14675a02SAndroid Build Coastguard Worker  private:
TensorFlowWrapper(std::unique_ptr<tensorflow::Session> session,std::unique_ptr<InterruptibleRunner> interruptible_runner,LogManager * log_manager)78*14675a02SAndroid Build Coastguard Worker   TensorFlowWrapper(std::unique_ptr<tensorflow::Session> session,
79*14675a02SAndroid Build Coastguard Worker                     std::unique_ptr<InterruptibleRunner> interruptible_runner,
80*14675a02SAndroid Build Coastguard Worker                     LogManager* log_manager)
81*14675a02SAndroid Build Coastguard Worker       : session_(std::move(session)),
82*14675a02SAndroid Build Coastguard Worker         interruptible_runner_(std::move(interruptible_runner)),
83*14675a02SAndroid Build Coastguard Worker         session_closed_(false) {}
84*14675a02SAndroid Build Coastguard Worker 
85*14675a02SAndroid Build Coastguard Worker   // Converts a TensorFlow status to an absl::Status.
86*14675a02SAndroid Build Coastguard Worker   //
87*14675a02SAndroid Build Coastguard Worker   // Rule:
88*14675a02SAndroid Build Coastguard Worker   // TensorFlow OK status -> absl OK status
89*14675a02SAndroid Build Coastguard Worker   // TensorFlow OUT_OF_RANGE -> absl OUT_OF_RANGE status (this is TF indicating
90*14675a02SAndroid Build Coastguard Worker   //   that the plan decided to abort, e.g. because of convergence)
91*14675a02SAndroid Build Coastguard Worker   // Other TensorFlow status -> absl INVALID_ARGUMENT status with error
92*14675a02SAndroid Build Coastguard Worker   // message being message_prefix + TensorFlow status code + error message.
93*14675a02SAndroid Build Coastguard Worker   static absl::Status ToFcpStatus(tensorflow::Status s,
94*14675a02SAndroid Build Coastguard Worker                                   const std::string& message_prefix);
95*14675a02SAndroid Build Coastguard Worker 
96*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<tensorflow::Session> session_;
97*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<InterruptibleRunner> interruptible_runner_;
98*14675a02SAndroid Build Coastguard Worker   absl::Mutex session_lock_;
99*14675a02SAndroid Build Coastguard Worker   bool session_closed_;
100*14675a02SAndroid Build Coastguard Worker };
101*14675a02SAndroid Build Coastguard Worker 
102*14675a02SAndroid Build Coastguard Worker }  // namespace engine
103*14675a02SAndroid Build Coastguard Worker }  // namespace client
104*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
105*14675a02SAndroid Build Coastguard Worker 
106*14675a02SAndroid Build Coastguard Worker #endif  // FCP_CLIENT_ENGINE_TF_WRAPPER_H_
107