1*14675a02SAndroid Build Coastguard Worker /* 2*14675a02SAndroid Build Coastguard Worker * Copyright 2022 Google LLC 3*14675a02SAndroid Build Coastguard Worker * 4*14675a02SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License"); 5*14675a02SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License. 6*14675a02SAndroid Build Coastguard Worker * You may obtain a copy of the License at 7*14675a02SAndroid Build Coastguard Worker * 8*14675a02SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0 9*14675a02SAndroid Build Coastguard Worker * 10*14675a02SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software 11*14675a02SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS, 12*14675a02SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13*14675a02SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and 14*14675a02SAndroid Build Coastguard Worker * limitations under the License. 15*14675a02SAndroid Build Coastguard Worker */ 16*14675a02SAndroid Build Coastguard Worker 17*14675a02SAndroid Build Coastguard Worker #ifndef FCP_AGGREGATION_TENSORFLOW_CONVERTERS_H_ 18*14675a02SAndroid Build Coastguard Worker #define FCP_AGGREGATION_TENSORFLOW_CONVERTERS_H_ 19*14675a02SAndroid Build Coastguard Worker 20*14675a02SAndroid Build Coastguard Worker #include <memory> 21*14675a02SAndroid Build Coastguard Worker 22*14675a02SAndroid Build Coastguard Worker #include "fcp/aggregation/core/datatype.h" 23*14675a02SAndroid Build Coastguard Worker #include "fcp/aggregation/core/tensor.h" 24*14675a02SAndroid Build Coastguard Worker #include "fcp/aggregation/core/tensor_shape.h" 25*14675a02SAndroid Build Coastguard Worker #include "fcp/aggregation/core/tensor_spec.h" 26*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h" 27*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor.h" 28*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor_shape.h" 29*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/types.pb.h" 30*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/protobuf/struct.pb.h" 31*14675a02SAndroid Build Coastguard Worker 32*14675a02SAndroid Build Coastguard Worker namespace fcp::aggregation::tensorflow { 33*14675a02SAndroid Build Coastguard Worker 34*14675a02SAndroid Build Coastguard Worker // Converts Tensorflow DataType to Aggregation DataType. 35*14675a02SAndroid Build Coastguard Worker // Returns an error status if the input data type isn't supported by 36*14675a02SAndroid Build Coastguard Worker // the Aggregation Core. 37*14675a02SAndroid Build Coastguard Worker StatusOr<DataType> ConvertDataType(::tensorflow::DataType dtype); 38*14675a02SAndroid Build Coastguard Worker 39*14675a02SAndroid Build Coastguard Worker // Converts Tensorflow TensorShape to Aggregation TensorShape. 40*14675a02SAndroid Build Coastguard Worker // Note that the Tensorflow shape is expected to be valid (it seems impossible 41*14675a02SAndroid Build Coastguard Worker // to create an invalid shape). 42*14675a02SAndroid Build Coastguard Worker TensorShape ConvertShape(const ::tensorflow::TensorShape& shape); 43*14675a02SAndroid Build Coastguard Worker 44*14675a02SAndroid Build Coastguard Worker // Converts Tensorflow TensorSpecProto to Aggregation TensorSpec. 45*14675a02SAndroid Build Coastguard Worker // Returns an error status if supplied TensorSpecProto data type or shape isn't 46*14675a02SAndroid Build Coastguard Worker // supported by the Aggregation Core. 47*14675a02SAndroid Build Coastguard Worker StatusOr<TensorSpec> ConvertTensorSpec( 48*14675a02SAndroid Build Coastguard Worker const ::tensorflow::TensorSpecProto& spec); 49*14675a02SAndroid Build Coastguard Worker 50*14675a02SAndroid Build Coastguard Worker // Converts Tensorflow Tensor to Aggregation Tensor. 51*14675a02SAndroid Build Coastguard Worker // Returns an error status if supplied Tensor data type or shape isn't 52*14675a02SAndroid Build Coastguard Worker // supported by the Aggregation Core. 53*14675a02SAndroid Build Coastguard Worker // Note that this function consumes the Tensorflow tensor. 54*14675a02SAndroid Build Coastguard Worker StatusOr<Tensor> ConvertTensor(std::unique_ptr<::tensorflow::Tensor> tensor); 55*14675a02SAndroid Build Coastguard Worker 56*14675a02SAndroid Build Coastguard Worker } // namespace fcp::aggregation::tensorflow 57*14675a02SAndroid Build Coastguard Worker 58*14675a02SAndroid Build Coastguard Worker #endif // FCP_AGGREGATION_TENSORFLOW_CONVERTERS_H_ 59