xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/tf_wrapper.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/client/engine/tf_wrapper.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "google/protobuf/any.pb.h"
24 #include "absl/status/status.h"
25 #include "absl/status/statusor.h"
26 #include "fcp/client/diag_codes.pb.h"
27 #include "fcp/client/engine/plan_engine_helpers.h"
28 #include "fcp/client/interruptible_runner.h"
29 
30 namespace fcp {
31 namespace client {
32 namespace engine {
33 
34 using ::google::protobuf::Any;
35 
36 // If `external_config_proto` contains a non-empty config proto, use that.
37 // Otherwise initializes a config proto from a set of defaults.
38 absl::StatusOr<tensorflow::ConfigProto>
InitializeConfigProto(const Any & external_config_proto)39 TensorFlowWrapper::InitializeConfigProto(const Any& external_config_proto) {
40   // Previously, we specified a hardcoded set of options in the ConfigProto by
41   // default. However, if a non-empty ConfigProto is now provided as a
42   // parameter, then we should use it as-is, without overriding any of the
43   // options (otherwise we prevent the caller from having control over the
44   // parameters we set by default).
45   if (external_config_proto.ByteSizeLong() > 0) {
46     // Unpack the external_config_proto parameter if one is provided. In this
47     // case it must be a packed ConfigProto (anything else is an error).
48     // Accordingly, UnpackTo will return false if parsing fails or if the Any is
49     // not of a compatible type.
50     tensorflow::ConfigProto unpacked_config_proto;
51     if (!external_config_proto.UnpackTo(&unpacked_config_proto)) {
52       return absl::InvalidArgumentError("Could not parse ConfigProto.");
53     }
54     if (unpacked_config_proto.ByteSizeLong() > 0) {
55       // The caller-provided, unpacked ConfigProto was not empty, so we use it
56       // in the SessionOptions and we do not specify our default config options
57       // anymore.
58       return unpacked_config_proto;
59     }
60     // We purposely fall through to the next block if the unpacked_config_proto
61     // was empty.
62   }
63 
64   // Only if the provided ConfigProto was empty (or if none was provided) do we
65   // still set hardcoded options (this is our "old" behavior, equivalent to what
66   // we did before we supported caller-specified ConfigProtos).
67   //
68   // WARNING: If the need for tuning configuration options further arises again
69   // in the future, we ideally shouldn't update any of the hardcoded ConfigProto
70   // values here anymore. Instead, we should expect our callers to specify any
71   // ConfigProto values they want to use. We only maintain this block of code
72   // for compatibility with callers that don't provide any ConfigProto at all
73   // (yet).
74   //
75   tensorflow::ConfigProto config_proto;
76   config_proto.mutable_graph_options()->set_place_pruned_graph(true);
77   auto mutable_experimental = config_proto.mutable_experimental();
78   mutable_experimental->set_optimize_for_static_graph(true);
79   mutable_experimental->set_disable_output_partition_graphs(true);
80   return config_proto;
81 }
82 
Create(const std::string & graph,const Any & config_proto,std::function<bool ()> should_abort,const InterruptibleRunner::TimingConfig & timing_config,LogManager * log_manager)83 absl::StatusOr<std::unique_ptr<TensorFlowWrapper>> TensorFlowWrapper::Create(
84     const std::string& graph, const Any& config_proto,
85     std::function<bool()> should_abort,
86     const InterruptibleRunner::TimingConfig& timing_config,
87     LogManager* log_manager) {
88   // Create a tensorflow::Session.
89   tensorflow::Session* session_ptr;
90   std::unique_ptr<tensorflow::Session> session;
91   tensorflow::SessionOptions session_options;
92   FCP_ASSIGN_OR_RETURN(session_options.config,
93                        InitializeConfigProto(config_proto));
94 
95   tensorflow::Status status =
96       tensorflow::NewSession(session_options, &session_ptr);
97   if (!status.ok()) {
98     return ToFcpStatus(status, "Error in tensorflow::NewSession()");
99   }
100   session = absl::WrapUnique(session_ptr);
101 
102   // Parse GraphDef.
103   tensorflow::GraphDef graph_def;
104   bool parse_result = graph_def.ParseFromString(graph);
105   if (parse_result == false) {
106     return absl::InvalidArgumentError("Could not parse GraphDef.");
107   }
108   // Load graph.
109   status = session->Create(std::move(graph_def));
110   if (!status.ok()) {
111     return ToFcpStatus(status, "Error in Session::Create()");
112   }
113 
114   // Create an InterruptibleRunner to execute TF calls in a background thread,
115   // allowing us to abort them if need be.
116   auto interruptible_runner = std::make_unique<InterruptibleRunner>(
117       log_manager, should_abort, timing_config,
118       InterruptibleRunner::DiagnosticsConfig{
119           .interrupted =
120               ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION,
121           .interrupt_timeout = ProdDiagCode::
122               BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT,
123           .interrupted_extended = ProdDiagCode::
124               BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED,
125           .interrupt_timeout_extended = ProdDiagCode::
126               BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT});
127   auto wrapper = absl::WrapUnique(new TensorFlowWrapper(
128       std::move(session), std::move(interruptible_runner), log_manager));
129   return wrapper;
130 }
131 
~TensorFlowWrapper()132 TensorFlowWrapper::~TensorFlowWrapper() { FCP_CHECK(CloseAndRelease().ok()); }
133 
ToFcpStatus(tensorflow::Status s,const std::string & message_prefix)134 absl::Status TensorFlowWrapper::ToFcpStatus(tensorflow::Status s,
135                                             const std::string& message_prefix) {
136   if (s.ok()) {
137     return absl::OkStatus();
138   } else if (s.code() == tensorflow::error::OUT_OF_RANGE) {
139     return absl::OutOfRangeError("");
140   } else {
141     return absl::InvalidArgumentError(
142         absl::StrCat(message_prefix, ": ", s.ToString()));
143   }
144 }
145 
Run(const std::vector<std::pair<std::string,tensorflow::Tensor>> & inputs,const std::vector<std::string> & output_tensor_names,const std::vector<std::string> & target_node_names,std::vector<tensorflow::Tensor> * outputs)146 absl::Status TensorFlowWrapper::Run(
147     const std::vector<std::pair<std::string, tensorflow::Tensor>>& inputs,
148     const std::vector<std::string>& output_tensor_names,
149     const std::vector<std::string>& target_node_names,
150     std::vector<tensorflow::Tensor>* outputs) {
151   FCP_CHECK(!session_closed_) << "Run() called after session close!";
152 
153   auto tensorflow_runnable = [&inputs, &output_tensor_names, &target_node_names,
154                               &outputs, this]() -> absl::Status {
155     tensorflow::Status status = this->session_->Run(inputs, output_tensor_names,
156                                                     target_node_names, outputs);
157     if (!status.ok()) {
158       return ToFcpStatus(status, "Error in Session::Run()");
159     }
160     return absl::OkStatus();
161   };
162   auto abort_tensorflow = [this]() {
163     absl::MutexLock _(&session_lock_);
164     // Errors from Close() are expected when interrupting ongoing calls. We
165     // don't call CloseAndRelease() here because that would free the TensorFlow
166     // session while other TensorFlow worker threads may still be using it.
167     session_->Close().IgnoreError();
168     session_closed_ = true;
169   };
170   return interruptible_runner_->Run(tensorflow_runnable, abort_tensorflow);
171 }
172 
CloseAndRelease()173 absl::Status TensorFlowWrapper::CloseAndRelease() {
174   absl::MutexLock _(&session_lock_);
175   // If the TensorFlow session hasn't been closed yet, close it.
176   if (!session_closed_) {
177     FCP_ENGINE_RETURN_IF_ERROR(
178         ToFcpStatus(session_->Close(), "Could not close TF session"));
179     session_closed_ = true;
180   }
181   // If the TensorflowSession hasn't been released yet, release it.
182   if (session_) {
183     session_.reset();
184   }
185   return absl::OkStatus();
186 }
187 
188 }  // namespace engine
189 }  // namespace client
190 }  // namespace fcp
191