1*14675a02SAndroid Build Coastguard Worker /* 2*14675a02SAndroid Build Coastguard Worker * Copyright 2021 Google LLC 3*14675a02SAndroid Build Coastguard Worker * 4*14675a02SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License"); 5*14675a02SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License. 6*14675a02SAndroid Build Coastguard Worker * You may obtain a copy of the License at 7*14675a02SAndroid Build Coastguard Worker * 8*14675a02SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0 9*14675a02SAndroid Build Coastguard Worker * 10*14675a02SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software 11*14675a02SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS, 12*14675a02SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13*14675a02SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and 14*14675a02SAndroid Build Coastguard Worker * limitations under the License. 15*14675a02SAndroid Build Coastguard Worker */ 16*14675a02SAndroid Build Coastguard Worker #ifndef FCP_CLIENT_ENGINE_COMMON_H_ 17*14675a02SAndroid Build Coastguard Worker #define FCP_CLIENT_ENGINE_COMMON_H_ 18*14675a02SAndroid Build Coastguard Worker 19*14675a02SAndroid Build Coastguard Worker #include <string> 20*14675a02SAndroid Build Coastguard Worker #include <utility> 21*14675a02SAndroid Build Coastguard Worker #include <vector> 22*14675a02SAndroid Build Coastguard Worker 23*14675a02SAndroid Build Coastguard Worker #include "absl/container/flat_hash_set.h" 24*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h" 25*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/engine.pb.h" 26*14675a02SAndroid Build Coastguard Worker #include "fcp/client/stats.h" 27*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/plan.pb.h" 28*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor.h" 29*14675a02SAndroid Build Coastguard Worker 30*14675a02SAndroid Build Coastguard Worker namespace fcp { 31*14675a02SAndroid Build Coastguard Worker namespace client { 32*14675a02SAndroid Build Coastguard Worker namespace engine { 33*14675a02SAndroid Build Coastguard Worker 34*14675a02SAndroid Build Coastguard Worker enum class PlanOutcome { 35*14675a02SAndroid Build Coastguard Worker kSuccess, 36*14675a02SAndroid Build Coastguard Worker // A TensorFlow error occurred. 37*14675a02SAndroid Build Coastguard Worker kTensorflowError, 38*14675a02SAndroid Build Coastguard Worker // Computation was interrupted. 39*14675a02SAndroid Build Coastguard Worker kInterrupted, 40*14675a02SAndroid Build Coastguard Worker // The input parameters are invalid. 41*14675a02SAndroid Build Coastguard Worker kInvalidArgument, 42*14675a02SAndroid Build Coastguard Worker // An example iterator error occurred. 43*14675a02SAndroid Build Coastguard Worker kExampleIteratorError, 44*14675a02SAndroid Build Coastguard Worker }; 45*14675a02SAndroid Build Coastguard Worker 46*14675a02SAndroid Build Coastguard Worker // The result of a call to `SimplePlanEngine::RunPlan` or 47*14675a02SAndroid Build Coastguard Worker // `TfLitePlanEngine::RunPlan`. 48*14675a02SAndroid Build Coastguard Worker struct PlanResult { 49*14675a02SAndroid Build Coastguard Worker explicit PlanResult(PlanOutcome outcome, absl::Status status); 50*14675a02SAndroid Build Coastguard Worker 51*14675a02SAndroid Build Coastguard Worker // The outcome of the plan execution. 52*14675a02SAndroid Build Coastguard Worker PlanOutcome outcome; 53*14675a02SAndroid Build Coastguard Worker // Only set if `outcome` is `kSuccess`, otherwise this is empty. 54*14675a02SAndroid Build Coastguard Worker std::vector<tensorflow::Tensor> output_tensors; 55*14675a02SAndroid Build Coastguard Worker // Only set if `outcome` is `kSuccess`, otherwise this is empty. 56*14675a02SAndroid Build Coastguard Worker std::vector<std::string> output_names; 57*14675a02SAndroid Build Coastguard Worker // When the outcome is `kSuccess`, the status is ok. Otherwise, this status 58*14675a02SAndroid Build Coastguard Worker // contain the original error status which leads to the PlanOutcome. 59*14675a02SAndroid Build Coastguard Worker absl::Status original_status; 60*14675a02SAndroid Build Coastguard Worker ::fcp::client::ExampleStats example_stats; 61*14675a02SAndroid Build Coastguard Worker 62*14675a02SAndroid Build Coastguard Worker PlanResult(PlanResult&&) = default; 63*14675a02SAndroid Build Coastguard Worker PlanResult& operator=(PlanResult&&) = default; 64*14675a02SAndroid Build Coastguard Worker 65*14675a02SAndroid Build Coastguard Worker // Disallow copy and assign. 66*14675a02SAndroid Build Coastguard Worker PlanResult(const PlanResult&) = delete; 67*14675a02SAndroid Build Coastguard Worker PlanResult& operator=(const PlanResult&) = delete; 68*14675a02SAndroid Build Coastguard Worker }; 69*14675a02SAndroid Build Coastguard Worker 70*14675a02SAndroid Build Coastguard Worker // Validates that the input tensors match what's inside the TensorflowSpec. 71*14675a02SAndroid Build Coastguard Worker absl::Status ValidateTensorflowSpec( 72*14675a02SAndroid Build Coastguard Worker const google::internal::federated::plan::TensorflowSpec& tensorflow_spec, 73*14675a02SAndroid Build Coastguard Worker const absl::flat_hash_set<std::string>& expected_input_tensor_names_set, 74*14675a02SAndroid Build Coastguard Worker const std::vector<std::string>& output_names); 75*14675a02SAndroid Build Coastguard Worker 76*14675a02SAndroid Build Coastguard Worker PhaseOutcome ConvertPlanOutcomeToPhaseOutcome(PlanOutcome plan_outcome); 77*14675a02SAndroid Build Coastguard Worker 78*14675a02SAndroid Build Coastguard Worker absl::Status ConvertPlanOutcomeToStatus(engine::PlanOutcome outcome); 79*14675a02SAndroid Build Coastguard Worker 80*14675a02SAndroid Build Coastguard Worker } // namespace engine 81*14675a02SAndroid Build Coastguard Worker } // namespace client 82*14675a02SAndroid Build Coastguard Worker } // namespace fcp 83*14675a02SAndroid Build Coastguard Worker 84*14675a02SAndroid Build Coastguard Worker #endif // FCP_CLIENT_ENGINE_COMMON_H_ 85