/*
 * Copyright 2019 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker */ 16*14675a02SAndroid Build Coastguard Worker #ifndef FCP_CLIENT_ENGINE_TF_WRAPPER_H_ 17*14675a02SAndroid Build Coastguard Worker #define FCP_CLIENT_ENGINE_TF_WRAPPER_H_ 18*14675a02SAndroid Build Coastguard Worker 19*14675a02SAndroid Build Coastguard Worker #include <functional> 20*14675a02SAndroid Build Coastguard Worker #include <string> 21*14675a02SAndroid Build Coastguard Worker #include <utility> 22*14675a02SAndroid Build Coastguard Worker 23*14675a02SAndroid Build Coastguard Worker #include "google/protobuf/any.pb.h" 24*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h" 25*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h" 26*14675a02SAndroid Build Coastguard Worker #include "absl/time/time.h" 27*14675a02SAndroid Build Coastguard Worker #include "fcp/base/future.h" 28*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h" 29*14675a02SAndroid Build Coastguard Worker #include "fcp/base/scheduler.h" 30*14675a02SAndroid Build Coastguard Worker #include "fcp/client/interruptible_runner.h" 31*14675a02SAndroid Build Coastguard Worker #include "fcp/client/log_manager.h" 32*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/public/session.h" 33*14675a02SAndroid Build Coastguard Worker 34*14675a02SAndroid Build Coastguard Worker namespace fcp { 35*14675a02SAndroid Build Coastguard Worker namespace client { 36*14675a02SAndroid Build Coastguard Worker namespace engine { 37*14675a02SAndroid Build Coastguard Worker 38*14675a02SAndroid Build Coastguard Worker // A class to call into TensorFlow. 39*14675a02SAndroid Build Coastguard Worker // All functions in this interface indicate errors as follows: 40*14675a02SAndroid Build Coastguard Worker // - CANCELLED: interrupted execution 41*14675a02SAndroid Build Coastguard Worker // - INVALID_ARGUMENT: TensorFlow error. 
The TensorFlow error code and message 42*14675a02SAndroid Build Coastguard Worker // are included in the Status message. 43*14675a02SAndroid Build Coastguard Worker // - OUT_OF_RANGE: internal abortion, i.e. TensorFlow reporting the model 44*14675a02SAndroid Build Coastguard Worker // aborted execution. 45*14675a02SAndroid Build Coastguard Worker // This class supports aborting ongoing calls, by polling the provided 46*14675a02SAndroid Build Coastguard Worker // should_abort function. 47*14675a02SAndroid Build Coastguard Worker class TensorFlowWrapper { 48*14675a02SAndroid Build Coastguard Worker public: 49*14675a02SAndroid Build Coastguard Worker static absl::StatusOr<std::unique_ptr<TensorFlowWrapper>> Create( 50*14675a02SAndroid Build Coastguard Worker const std::string& graph, const ::google::protobuf::Any& config_proto, 51*14675a02SAndroid Build Coastguard Worker std::function<bool()> should_abort, 52*14675a02SAndroid Build Coastguard Worker const InterruptibleRunner::TimingConfig& timing_config, 53*14675a02SAndroid Build Coastguard Worker LogManager* log_manager); 54*14675a02SAndroid Build Coastguard Worker 55*14675a02SAndroid Build Coastguard Worker // Utility method for creating a ConfigProto from an optionally 56*14675a02SAndroid Build Coastguard Worker // externally provided value, or from hardcoded defaults. This is a separate 57*14675a02SAndroid Build Coastguard Worker // method to aid with testing. 
58*14675a02SAndroid Build Coastguard Worker static absl::StatusOr<::tensorflow::ConfigProto> InitializeConfigProto( 59*14675a02SAndroid Build Coastguard Worker const ::google::protobuf::Any& external_config_proto); 60*14675a02SAndroid Build Coastguard Worker 61*14675a02SAndroid Build Coastguard Worker ~TensorFlowWrapper(); 62*14675a02SAndroid Build Coastguard Worker 63*14675a02SAndroid Build Coastguard Worker // Wrapper around TensorFlow's Session::Run method with full support for 64*14675a02SAndroid Build Coastguard Worker // feeds, fetches and target node names. 65*14675a02SAndroid Build Coastguard Worker // Returns OK, OUT_OF_RANGE, INVALID_ARGUMENT, or CANCELLED. 66*14675a02SAndroid Build Coastguard Worker absl::Status Run( 67*14675a02SAndroid Build Coastguard Worker const std::vector<std::pair<std::string, tensorflow::Tensor>>& inputs, 68*14675a02SAndroid Build Coastguard Worker const std::vector<std::string>& output_tensor_names, 69*14675a02SAndroid Build Coastguard Worker const std::vector<std::string>& target_node_names, 70*14675a02SAndroid Build Coastguard Worker std::vector<tensorflow::Tensor>* outputs); 71*14675a02SAndroid Build Coastguard Worker 72*14675a02SAndroid Build Coastguard Worker // Closes and releases the TensorFlow session. After this is called, no 73*14675a02SAndroid Build Coastguard Worker // further calls on this TensorFlowWrapper should be made. Subsequent calls to 74*14675a02SAndroid Build Coastguard Worker // CloseAndRelease() will have no effect. 
75*14675a02SAndroid Build Coastguard Worker absl::Status CloseAndRelease(); 76*14675a02SAndroid Build Coastguard Worker 77*14675a02SAndroid Build Coastguard Worker private: TensorFlowWrapper(std::unique_ptr<tensorflow::Session> session,std::unique_ptr<InterruptibleRunner> interruptible_runner,LogManager * log_manager)78*14675a02SAndroid Build Coastguard Worker TensorFlowWrapper(std::unique_ptr<tensorflow::Session> session, 79*14675a02SAndroid Build Coastguard Worker std::unique_ptr<InterruptibleRunner> interruptible_runner, 80*14675a02SAndroid Build Coastguard Worker LogManager* log_manager) 81*14675a02SAndroid Build Coastguard Worker : session_(std::move(session)), 82*14675a02SAndroid Build Coastguard Worker interruptible_runner_(std::move(interruptible_runner)), 83*14675a02SAndroid Build Coastguard Worker session_closed_(false) {} 84*14675a02SAndroid Build Coastguard Worker 85*14675a02SAndroid Build Coastguard Worker // Converts a TensorFlow status to an absl::Status. 86*14675a02SAndroid Build Coastguard Worker // 87*14675a02SAndroid Build Coastguard Worker // Rule: 88*14675a02SAndroid Build Coastguard Worker // TensorFlow OK status -> absl OK status 89*14675a02SAndroid Build Coastguard Worker // TensorFlow OUT_OF_RANGE -> absl OUT_OF_RANGE status (this is TF indicating 90*14675a02SAndroid Build Coastguard Worker // that the plan decided to abort, e.g. because of convergence) 91*14675a02SAndroid Build Coastguard Worker // Other TensorFlow status -> absl INVALID_ARGUMENT status with error 92*14675a02SAndroid Build Coastguard Worker // message being message_prefix + TensorFlow status code + error message. 
93*14675a02SAndroid Build Coastguard Worker static absl::Status ToFcpStatus(tensorflow::Status s, 94*14675a02SAndroid Build Coastguard Worker const std::string& message_prefix); 95*14675a02SAndroid Build Coastguard Worker 96*14675a02SAndroid Build Coastguard Worker std::unique_ptr<tensorflow::Session> session_; 97*14675a02SAndroid Build Coastguard Worker std::unique_ptr<InterruptibleRunner> interruptible_runner_; 98*14675a02SAndroid Build Coastguard Worker absl::Mutex session_lock_; 99*14675a02SAndroid Build Coastguard Worker bool session_closed_; 100*14675a02SAndroid Build Coastguard Worker }; 101*14675a02SAndroid Build Coastguard Worker 102*14675a02SAndroid Build Coastguard Worker } // namespace engine 103*14675a02SAndroid Build Coastguard Worker } // namespace client 104*14675a02SAndroid Build Coastguard Worker } // namespace fcp 105*14675a02SAndroid Build Coastguard Worker 106*14675a02SAndroid Build Coastguard Worker #endif // FCP_CLIENT_ENGINE_TF_WRAPPER_H_ 107