xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/tf_wrapper.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef FCP_CLIENT_ENGINE_TF_WRAPPER_H_
17 #define FCP_CLIENT_ENGINE_TF_WRAPPER_H_
18 
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "google/protobuf/any.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "fcp/base/future.h"
#include "fcp/base/monitoring.h"
#include "fcp/base/scheduler.h"
#include "fcp/client/interruptible_runner.h"
#include "fcp/client/log_manager.h"
#include "tensorflow/core/public/session.h"
33 
34 namespace fcp {
35 namespace client {
36 namespace engine {
37 
38 // A class to call into TensorFlow.
39 // All functions in this interface indicate errors as follows:
40 // - CANCELLED: interrupted execution
41 // - INVALID_ARGUMENT: TensorFlow error. The TensorFlow error code and message
42 //   are included in the Status message.
43 // - OUT_OF_RANGE: internal abortion, i.e. TensorFlow reporting the model
44 //   aborted execution.
45 // This class supports aborting ongoing calls, by polling the provided
46 // should_abort function.
47 class TensorFlowWrapper {
48  public:
49   static absl::StatusOr<std::unique_ptr<TensorFlowWrapper>> Create(
50       const std::string& graph, const ::google::protobuf::Any& config_proto,
51       std::function<bool()> should_abort,
52       const InterruptibleRunner::TimingConfig& timing_config,
53       LogManager* log_manager);
54 
55   // Utility method for creating a ConfigProto from an optionally
56   // externally provided value, or from hardcoded defaults. This is a separate
57   // method to aid with testing.
58   static absl::StatusOr<::tensorflow::ConfigProto> InitializeConfigProto(
59       const ::google::protobuf::Any& external_config_proto);
60 
61   ~TensorFlowWrapper();
62 
63   // Wrapper around TensorFlow's Session::Run method with full support for
64   // feeds, fetches and target node names.
65   // Returns OK, OUT_OF_RANGE, INVALID_ARGUMENT, or CANCELLED.
66   absl::Status Run(
67       const std::vector<std::pair<std::string, tensorflow::Tensor>>& inputs,
68       const std::vector<std::string>& output_tensor_names,
69       const std::vector<std::string>& target_node_names,
70       std::vector<tensorflow::Tensor>* outputs);
71 
72   // Closes and releases the TensorFlow session. After this is called, no
73   // further calls on this TensorFlowWrapper should be made. Subsequent calls to
74   // CloseAndRelease() will have no effect.
75   absl::Status CloseAndRelease();
76 
77  private:
TensorFlowWrapper(std::unique_ptr<tensorflow::Session> session,std::unique_ptr<InterruptibleRunner> interruptible_runner,LogManager * log_manager)78   TensorFlowWrapper(std::unique_ptr<tensorflow::Session> session,
79                     std::unique_ptr<InterruptibleRunner> interruptible_runner,
80                     LogManager* log_manager)
81       : session_(std::move(session)),
82         interruptible_runner_(std::move(interruptible_runner)),
83         session_closed_(false) {}
84 
85   // Converts a TensorFlow status to an absl::Status.
86   //
87   // Rule:
88   // TensorFlow OK status -> absl OK status
89   // TensorFlow OUT_OF_RANGE -> absl OUT_OF_RANGE status (this is TF indicating
90   //   that the plan decided to abort, e.g. because of convergence)
91   // Other TensorFlow status -> absl INVALID_ARGUMENT status with error
92   // message being message_prefix + TensorFlow status code + error message.
93   static absl::Status ToFcpStatus(tensorflow::Status s,
94                                   const std::string& message_prefix);
95 
96   std::unique_ptr<tensorflow::Session> session_;
97   std::unique_ptr<InterruptibleRunner> interruptible_runner_;
98   absl::Mutex session_lock_;
99   bool session_closed_;
100 };
101 
102 }  // namespace engine
103 }  // namespace client
104 }  // namespace fcp
105 
106 #endif  // FCP_CLIENT_ENGINE_TF_WRAPPER_H_
107