xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/common.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2021 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker #ifndef FCP_CLIENT_ENGINE_COMMON_H_
17*14675a02SAndroid Build Coastguard Worker #define FCP_CLIENT_ENGINE_COMMON_H_
18*14675a02SAndroid Build Coastguard Worker 
19*14675a02SAndroid Build Coastguard Worker #include <string>
20*14675a02SAndroid Build Coastguard Worker #include <utility>
21*14675a02SAndroid Build Coastguard Worker #include <vector>
22*14675a02SAndroid Build Coastguard Worker 
23*14675a02SAndroid Build Coastguard Worker #include "absl/container/flat_hash_set.h"
24*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
25*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/engine.pb.h"
26*14675a02SAndroid Build Coastguard Worker #include "fcp/client/stats.h"
27*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/plan.pb.h"
28*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor.h"
29*14675a02SAndroid Build Coastguard Worker 
30*14675a02SAndroid Build Coastguard Worker namespace fcp {
31*14675a02SAndroid Build Coastguard Worker namespace client {
32*14675a02SAndroid Build Coastguard Worker namespace engine {
33*14675a02SAndroid Build Coastguard Worker 
// How a plan execution ended. Enumerator values are spelled out explicitly;
// they are identical to the implicit 0-based numbering.
enum class PlanOutcome {
  // The plan executed successfully.
  kSuccess = 0,
  // A TensorFlow error occurred.
  kTensorflowError = 1,
  // Computation was interrupted.
  kInterrupted = 2,
  // The input parameters are invalid.
  kInvalidArgument = 3,
  // An example iterator error occurred.
  kExampleIteratorError = 4,
};
45*14675a02SAndroid Build Coastguard Worker 
46*14675a02SAndroid Build Coastguard Worker // The result of a call to `SimplePlanEngine::RunPlan` or
47*14675a02SAndroid Build Coastguard Worker // `TfLitePlanEngine::RunPlan`.
48*14675a02SAndroid Build Coastguard Worker struct PlanResult {
49*14675a02SAndroid Build Coastguard Worker   explicit PlanResult(PlanOutcome outcome, absl::Status status);
50*14675a02SAndroid Build Coastguard Worker 
51*14675a02SAndroid Build Coastguard Worker   // The outcome of the plan execution.
52*14675a02SAndroid Build Coastguard Worker   PlanOutcome outcome;
53*14675a02SAndroid Build Coastguard Worker   // Only set if `outcome` is `kSuccess`, otherwise this is empty.
54*14675a02SAndroid Build Coastguard Worker   std::vector<tensorflow::Tensor> output_tensors;
55*14675a02SAndroid Build Coastguard Worker   // Only set if `outcome` is `kSuccess`, otherwise this is empty.
56*14675a02SAndroid Build Coastguard Worker   std::vector<std::string> output_names;
57*14675a02SAndroid Build Coastguard Worker   // When the outcome is `kSuccess`, the status is ok. Otherwise, this status
58*14675a02SAndroid Build Coastguard Worker   // contain the original error status which leads to the PlanOutcome.
59*14675a02SAndroid Build Coastguard Worker   absl::Status original_status;
60*14675a02SAndroid Build Coastguard Worker   ::fcp::client::ExampleStats example_stats;
61*14675a02SAndroid Build Coastguard Worker 
62*14675a02SAndroid Build Coastguard Worker   PlanResult(PlanResult&&) = default;
63*14675a02SAndroid Build Coastguard Worker   PlanResult& operator=(PlanResult&&) = default;
64*14675a02SAndroid Build Coastguard Worker 
65*14675a02SAndroid Build Coastguard Worker   // Disallow copy and assign.
66*14675a02SAndroid Build Coastguard Worker   PlanResult(const PlanResult&) = delete;
67*14675a02SAndroid Build Coastguard Worker   PlanResult& operator=(const PlanResult&) = delete;
68*14675a02SAndroid Build Coastguard Worker };
69*14675a02SAndroid Build Coastguard Worker 
70*14675a02SAndroid Build Coastguard Worker // Validates that the input tensors match what's inside the TensorflowSpec.
71*14675a02SAndroid Build Coastguard Worker absl::Status ValidateTensorflowSpec(
72*14675a02SAndroid Build Coastguard Worker     const google::internal::federated::plan::TensorflowSpec& tensorflow_spec,
73*14675a02SAndroid Build Coastguard Worker     const absl::flat_hash_set<std::string>& expected_input_tensor_names_set,
74*14675a02SAndroid Build Coastguard Worker     const std::vector<std::string>& output_names);
75*14675a02SAndroid Build Coastguard Worker 
76*14675a02SAndroid Build Coastguard Worker PhaseOutcome ConvertPlanOutcomeToPhaseOutcome(PlanOutcome plan_outcome);
77*14675a02SAndroid Build Coastguard Worker 
78*14675a02SAndroid Build Coastguard Worker absl::Status ConvertPlanOutcomeToStatus(engine::PlanOutcome outcome);
79*14675a02SAndroid Build Coastguard Worker 
80*14675a02SAndroid Build Coastguard Worker }  // namespace engine
81*14675a02SAndroid Build Coastguard Worker }  // namespace client
82*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
83*14675a02SAndroid Build Coastguard Worker 
84*14675a02SAndroid Build Coastguard Worker #endif  // FCP_CLIENT_ENGINE_COMMON_H_
85