1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/client/engine/tf_wrapper.h"
17
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <utility>
22
23 #include "google/protobuf/any.pb.h"
24 #include "absl/status/status.h"
25 #include "absl/status/statusor.h"
26 #include "fcp/client/diag_codes.pb.h"
27 #include "fcp/client/engine/plan_engine_helpers.h"
28 #include "fcp/client/interruptible_runner.h"
29
30 namespace fcp {
31 namespace client {
32 namespace engine {
33
34 using ::google::protobuf::Any;
35
36 // If `external_config_proto` contains a non-empty config proto, use that.
37 // Otherwise initializes a config proto from a set of defaults.
38 absl::StatusOr<tensorflow::ConfigProto>
InitializeConfigProto(const Any & external_config_proto)39 TensorFlowWrapper::InitializeConfigProto(const Any& external_config_proto) {
40 // Previously, we specified a hardcoded set of options in the ConfigProto by
41 // default. However, if a non-empty ConfigProto is now provided as a
42 // parameter, then we should use it as-is, without overriding any of the
43 // options (otherwise we prevent the caller from having control over the
44 // parameters we set by default).
45 if (external_config_proto.ByteSizeLong() > 0) {
46 // Unpack the external_config_proto parameter if one is provided. In this
47 // case it must be a packed ConfigProto (anything else is an error).
48 // Accordingly, UnpackTo will return false if parsing fails or if the Any is
49 // not of a compatible type.
50 tensorflow::ConfigProto unpacked_config_proto;
51 if (!external_config_proto.UnpackTo(&unpacked_config_proto)) {
52 return absl::InvalidArgumentError("Could not parse ConfigProto.");
53 }
54 if (unpacked_config_proto.ByteSizeLong() > 0) {
55 // The caller-provided, unpacked ConfigProto was not empty, so we use it
56 // in the SessionOptions and we do not specify our default config options
57 // anymore.
58 return unpacked_config_proto;
59 }
60 // We purposely fall through to the next block if the unpacked_config_proto
61 // was empty.
62 }
63
64 // Only if the provided ConfigProto was empty (or if none was provided) do we
65 // still set hardcoded options (this is our "old" behavior, equivalent to what
66 // we did before we supported caller-specified ConfigProtos).
67 //
68 // WARNING: If the need for tuning configuration options further arises again
69 // in the future, we ideally shouldn't update any of the hardcoded ConfigProto
70 // values here anymore. Instead, we should expect our callers to specify any
71 // ConfigProto values they want to use. We only maintain this block of code
72 // for compatibility with callers that don't provide any ConfigProto at all
73 // (yet).
74 //
75 tensorflow::ConfigProto config_proto;
76 config_proto.mutable_graph_options()->set_place_pruned_graph(true);
77 auto mutable_experimental = config_proto.mutable_experimental();
78 mutable_experimental->set_optimize_for_static_graph(true);
79 mutable_experimental->set_disable_output_partition_graphs(true);
80 return config_proto;
81 }
82
Create(const std::string & graph,const Any & config_proto,std::function<bool ()> should_abort,const InterruptibleRunner::TimingConfig & timing_config,LogManager * log_manager)83 absl::StatusOr<std::unique_ptr<TensorFlowWrapper>> TensorFlowWrapper::Create(
84 const std::string& graph, const Any& config_proto,
85 std::function<bool()> should_abort,
86 const InterruptibleRunner::TimingConfig& timing_config,
87 LogManager* log_manager) {
88 // Create a tensorflow::Session.
89 tensorflow::Session* session_ptr;
90 std::unique_ptr<tensorflow::Session> session;
91 tensorflow::SessionOptions session_options;
92 FCP_ASSIGN_OR_RETURN(session_options.config,
93 InitializeConfigProto(config_proto));
94
95 tensorflow::Status status =
96 tensorflow::NewSession(session_options, &session_ptr);
97 if (!status.ok()) {
98 return ToFcpStatus(status, "Error in tensorflow::NewSession()");
99 }
100 session = absl::WrapUnique(session_ptr);
101
102 // Parse GraphDef.
103 tensorflow::GraphDef graph_def;
104 bool parse_result = graph_def.ParseFromString(graph);
105 if (parse_result == false) {
106 return absl::InvalidArgumentError("Could not parse GraphDef.");
107 }
108 // Load graph.
109 status = session->Create(std::move(graph_def));
110 if (!status.ok()) {
111 return ToFcpStatus(status, "Error in Session::Create()");
112 }
113
114 // Create an InterruptibleRunner to execute TF calls in a background thread,
115 // allowing us to abort them if need be.
116 auto interruptible_runner = std::make_unique<InterruptibleRunner>(
117 log_manager, should_abort, timing_config,
118 InterruptibleRunner::DiagnosticsConfig{
119 .interrupted =
120 ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION,
121 .interrupt_timeout = ProdDiagCode::
122 BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT,
123 .interrupted_extended = ProdDiagCode::
124 BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED,
125 .interrupt_timeout_extended = ProdDiagCode::
126 BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT});
127 auto wrapper = absl::WrapUnique(new TensorFlowWrapper(
128 std::move(session), std::move(interruptible_runner), log_manager));
129 return wrapper;
130 }
131
~TensorFlowWrapper()132 TensorFlowWrapper::~TensorFlowWrapper() { FCP_CHECK(CloseAndRelease().ok()); }
133
ToFcpStatus(tensorflow::Status s,const std::string & message_prefix)134 absl::Status TensorFlowWrapper::ToFcpStatus(tensorflow::Status s,
135 const std::string& message_prefix) {
136 if (s.ok()) {
137 return absl::OkStatus();
138 } else if (s.code() == tensorflow::error::OUT_OF_RANGE) {
139 return absl::OutOfRangeError("");
140 } else {
141 return absl::InvalidArgumentError(
142 absl::StrCat(message_prefix, ": ", s.ToString()));
143 }
144 }
145
Run(const std::vector<std::pair<std::string,tensorflow::Tensor>> & inputs,const std::vector<std::string> & output_tensor_names,const std::vector<std::string> & target_node_names,std::vector<tensorflow::Tensor> * outputs)146 absl::Status TensorFlowWrapper::Run(
147 const std::vector<std::pair<std::string, tensorflow::Tensor>>& inputs,
148 const std::vector<std::string>& output_tensor_names,
149 const std::vector<std::string>& target_node_names,
150 std::vector<tensorflow::Tensor>* outputs) {
151 FCP_CHECK(!session_closed_) << "Run() called after session close!";
152
153 auto tensorflow_runnable = [&inputs, &output_tensor_names, &target_node_names,
154 &outputs, this]() -> absl::Status {
155 tensorflow::Status status = this->session_->Run(inputs, output_tensor_names,
156 target_node_names, outputs);
157 if (!status.ok()) {
158 return ToFcpStatus(status, "Error in Session::Run()");
159 }
160 return absl::OkStatus();
161 };
162 auto abort_tensorflow = [this]() {
163 absl::MutexLock _(&session_lock_);
164 // Errors from Close() are expected when interrupting ongoing calls. We
165 // don't call CloseAndRelease() here because that would free the TensorFlow
166 // session while other TensorFlow worker threads may still be using it.
167 session_->Close().IgnoreError();
168 session_closed_ = true;
169 };
170 return interruptible_runner_->Run(tensorflow_runnable, abort_tensorflow);
171 }
172
CloseAndRelease()173 absl::Status TensorFlowWrapper::CloseAndRelease() {
174 absl::MutexLock _(&session_lock_);
175 // If the TensorFlow session hasn't been closed yet, close it.
176 if (!session_closed_) {
177 FCP_ENGINE_RETURN_IF_ERROR(
178 ToFcpStatus(session_->Close(), "Could not close TF session"));
179 session_closed_ = true;
180 }
181 // If the TensorflowSession hasn't been released yet, release it.
182 if (session_) {
183 session_.reset();
184 }
185 return absl::OkStatus();
186 }
187
188 } // namespace engine
189 } // namespace client
190 } // namespace fcp
191