/*
 * Copyright 2019 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FCP_CLIENT_ENGINE_TF_WRAPPER_H_
#define FCP_CLIENT_ENGINE_TF_WRAPPER_H_

#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "google/protobuf/any.pb.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include "fcp/base/future.h"
#include "fcp/base/monitoring.h"
#include "fcp/base/scheduler.h"
#include "fcp/client/interruptible_runner.h"
#include "fcp/client/log_manager.h"
#include "tensorflow/core/public/session.h"

namespace fcp {
namespace client {
namespace engine {

// A class to call into TensorFlow.
// All functions in this interface indicate errors as follows:
// - CANCELLED: interrupted execution
// - INVALID_ARGUMENT: TensorFlow error. The TensorFlow error code and message
//   are included in the Status message.
// - OUT_OF_RANGE: internal abortion, i.e. TensorFlow reporting the model
//   aborted execution.
// This class supports aborting ongoing calls, by polling the provided
// should_abort function.
47 class TensorFlowWrapper { 48 public: 49 static absl::StatusOr<std::unique_ptr<TensorFlowWrapper>> Create( 50 const std::string& graph, const ::google::protobuf::Any& config_proto, 51 std::function<bool()> should_abort, 52 const InterruptibleRunner::TimingConfig& timing_config, 53 LogManager* log_manager); 54 55 // Utility method for creating a ConfigProto from an optionally 56 // externally provided value, or from hardcoded defaults. This is a separate 57 // method to aid with testing. 58 static absl::StatusOr<::tensorflow::ConfigProto> InitializeConfigProto( 59 const ::google::protobuf::Any& external_config_proto); 60 61 ~TensorFlowWrapper(); 62 63 // Wrapper around TensorFlow's Session::Run method with full support for 64 // feeds, fetches and target node names. 65 // Returns OK, OUT_OF_RANGE, INVALID_ARGUMENT, or CANCELLED. 66 absl::Status Run( 67 const std::vector<std::pair<std::string, tensorflow::Tensor>>& inputs, 68 const std::vector<std::string>& output_tensor_names, 69 const std::vector<std::string>& target_node_names, 70 std::vector<tensorflow::Tensor>* outputs); 71 72 // Closes and releases the TensorFlow session. After this is called, no 73 // further calls on this TensorFlowWrapper should be made. Subsequent calls to 74 // CloseAndRelease() will have no effect. 75 absl::Status CloseAndRelease(); 76 77 private: TensorFlowWrapper(std::unique_ptr<tensorflow::Session> session,std::unique_ptr<InterruptibleRunner> interruptible_runner,LogManager * log_manager)78 TensorFlowWrapper(std::unique_ptr<tensorflow::Session> session, 79 std::unique_ptr<InterruptibleRunner> interruptible_runner, 80 LogManager* log_manager) 81 : session_(std::move(session)), 82 interruptible_runner_(std::move(interruptible_runner)), 83 session_closed_(false) {} 84 85 // Converts a TensorFlow status to an absl::Status. 
86 // 87 // Rule: 88 // TensorFlow OK status -> absl OK status 89 // TensorFlow OUT_OF_RANGE -> absl OUT_OF_RANGE status (this is TF indicating 90 // that the plan decided to abort, e.g. because of convergence) 91 // Other TensorFlow status -> absl INVALID_ARGUMENT status with error 92 // message being message_prefix + TensorFlow status code + error message. 93 static absl::Status ToFcpStatus(tensorflow::Status s, 94 const std::string& message_prefix); 95 96 std::unique_ptr<tensorflow::Session> session_; 97 std::unique_ptr<InterruptibleRunner> interruptible_runner_; 98 absl::Mutex session_lock_; 99 bool session_closed_; 100 }; 101 102 } // namespace engine 103 } // namespace client 104 } // namespace fcp 105 106 #endif // FCP_CLIENT_ENGINE_TF_WRAPPER_H_ 107