/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/ops.h"

#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h"

using ::tensorflow::DataType;
using ::tensorflow::OpDef;
using ::tensorflow::OpDefBuilder;
using ::tensorflow::OpDeprecation;
using ::tensorflow::OpShapeInferenceFn;
using ::tensorflow::Set_TF_Status_from_Status;
using ::tensorflow::Status;
using ::tensorflow::shape_inference::DimensionHandle;
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeHandle;

TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(const char* op_name) {
  auto* result = new OpDefBuilder(op_name);
  return reinterpret_cast<TF_OpDefinitionBuilder*>(result);
}

void TF_DeleteOpDefinitionBuilder(TF_OpDefinitionBuilder* builder) {
  delete reinterpret_cast<OpDefBuilder*>(builder);
}

void TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder* builder,
                                    const char* input_spec) {
  reinterpret_cast<OpDefBuilder*>(builder)->Input(input_spec);
}

void TF_OpDefinitionBuilderAddOutput(TF_OpDefinitionBuilder* builder,
                                     const char* output_spec) {
  reinterpret_cast<OpDefBuilder*>(builder)->Output(output_spec);
}

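// A minimal usage sketch (the op name and argument specs below are
// illustrative, not part of this file):
//
//   TF_OpDefinitionBuilder* b = TF_NewOpDefinitionBuilder("MyExampleOp");
//   TF_OpDefinitionBuilderAddInput(b, "values: int32");
//   TF_OpDefinitionBuilderAddOutput(b, "sum: int32");
//
// The builder is consumed (finalized and deleted) by TF_RegisterOpDefinition
// below, so TF_DeleteOpDefinitionBuilder is only needed when a builder is
// abandoned without being registered.
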
#define DEFINE_BUILDER_BOOL_SETTER(func_name)                             \
  void TF_OpDefinitionBuilder##func_name(TF_OpDefinitionBuilder* builder, \
                                         bool arg_name) {                 \
    reinterpret_cast<OpDefBuilder*>(builder)->func_name();                \
  }

DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative)
DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate)
DEFINE_BUILDER_BOOL_SETTER(SetIsStateful)
DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput)

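// Usage sketch (the builder variable is illustrative):
//
//   TF_OpDefinitionBuilderSetIsStateful(builder, true);
//
// Note that the macro above does not forward the bool argument; the
// underlying OpDefBuilder setters take no parameter, so calling one of these
// functions marks the corresponding property regardless of the value passed.
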
void TF_OpDefinitionBuilderAddAttr(TF_OpDefinitionBuilder* builder,
                                   const char* attr_spec) {
  reinterpret_cast<OpDefBuilder*>(builder)->Attr(attr_spec);
}

void TF_OpDefinitionBuilderDeprecated(TF_OpDefinitionBuilder* builder,
                                      int version, const char* explanation) {
  reinterpret_cast<OpDefBuilder*>(builder)->Deprecated(version, explanation);
}

void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder,
                             TF_Status* status) {
  auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder);
  TF_SetStatus(status, TF_OK, "");
  ::tensorflow::OpRegistry::Global()->Register(
      [cc_builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status {
        Status result = cc_builder->Finalize(op_reg_data);
        delete cc_builder;
        return result;
      });
}

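// Usage sketch for registering the op built above (error handling is
// illustrative):
//
//   TF_Status* status = TF_NewStatus();
//   TF_RegisterOpDefinition(builder, status);
//   if (TF_GetCode(status) != TF_OK) {
//     // Inspect TF_Message(status) and report the failure.
//   }
//   TF_DeleteStatus(status);
//
// As written, this function always reports TF_OK through `status`; any error
// from OpDefBuilder::Finalize() is returned to the op registry when it runs
// the registration lambda.
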
void TF_OpDefinitionBuilderSetShapeInferenceFunction(
    TF_OpDefinitionBuilder* builder,
    void (*shape_inference_func)(TF_ShapeInferenceContext* ctx,
                                 TF_Status* status)) {
  auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder);
  cc_builder->SetShapeFn(
      [shape_inference_func](InferenceContext* ctx) -> tensorflow::Status {
        TF_Status* c_status = TF_NewStatus();
        auto c_ctx = reinterpret_cast<TF_ShapeInferenceContext*>(ctx);
        shape_inference_func(c_ctx, c_status);
        tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status);
        TF_DeleteStatus(c_status);
        return result;
      });
}

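// A minimal shape function sketch, assuming a single input whose shape is
// forwarded unchanged to a single output (the function name is illustrative):
//
//   static void IdentityShapeFn(TF_ShapeInferenceContext* ctx,
//                               TF_Status* status) {
//     TF_ShapeHandle* shape = TF_NewShapeHandle();
//     TF_ShapeInferenceContextGetInput(ctx, 0, shape, status);
//     if (TF_GetCode(status) == TF_OK) {
//       TF_ShapeInferenceContextSetOutput(ctx, 0, shape, status);
//     }
//     TF_DeleteShapeHandle(shape);
//   }
//
//   TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &IdentityShapeFn);
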
TF_ShapeHandle* TF_NewShapeHandle() {
  return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle);
}

TF_ShapeHandle* TF_ShapeInferenceContextScalar(TF_ShapeInferenceContext* ctx) {
  auto* handle = new ShapeHandle;
  *handle = reinterpret_cast<InferenceContext*>(ctx)->Scalar();
  return reinterpret_cast<TF_ShapeHandle*>(handle);
}

TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
    TF_ShapeInferenceContext* ctx, size_t size) {
  auto* handle = new ShapeHandle;
  *handle = reinterpret_cast<InferenceContext*>(ctx)->Vector(size);
  return reinterpret_cast<TF_ShapeHandle*>(handle);
}

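// Usage sketch: both helpers return a newly allocated handle that the caller
// is expected to free with TF_DeleteShapeHandle (the context, index, and size
// are illustrative):
//
//   TF_ShapeHandle* vec = TF_ShapeInferenceContextVectorFromSize(ctx, 3);
//   TF_ShapeInferenceContextSetOutput(ctx, 0, vec, status);
//   TF_DeleteShapeHandle(vec);
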
void TF_ShapeInferenceContextConcatenateShapes(TF_ShapeInferenceContext* ctx,
                                               TF_ShapeHandle* first,
                                               TF_ShapeHandle* second,
                                               TF_ShapeHandle* result,
                                               TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
  Status s = cc_ctx->Concatenate(*reinterpret_cast<ShapeHandle*>(first),
                                 *reinterpret_cast<ShapeHandle*>(second),
                                 reinterpret_cast<ShapeHandle*>(result));
  Set_TF_Status_from_Status(status, s);
}

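// Usage sketch, assuming `shape_a` and `shape_b` are valid TF_ShapeHandle*
// values already obtained from this context (the names are illustrative):
//
//   TF_ShapeHandle* out = TF_NewShapeHandle();
//   TF_ShapeInferenceContextConcatenateShapes(ctx, shape_a, shape_b, out,
//                                             status);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_ShapeInferenceContextSetOutput(ctx, 0, out, status);
//   }
//   TF_DeleteShapeHandle(out);
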
TF_DimensionHandle* TF_NewDimensionHandle() {
  return reinterpret_cast<TF_DimensionHandle*>(new DimensionHandle);
}

int64_t TF_ShapeInferenceContextNumInputs(TF_ShapeInferenceContext* ctx) {
  auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
  return cc_ctx->num_inputs();
}

void TF_ShapeInferenceContextGetInput(TF_ShapeInferenceContext* ctx, int i,
                                      TF_ShapeHandle* handle,
                                      TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
  if (i < 0 || i >= cc_ctx->num_inputs()) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT, "input index out of range");
  }
  if (TF_GetCode(status) == TF_OK) {
    auto* cc_result = reinterpret_cast<ShapeHandle*>(handle);
    *cc_result = cc_ctx->input(i);
  }
}

int TF_ShapeInferenceContextRankKnown(TF_ShapeInferenceContext* ctx,
                                      TF_ShapeHandle* handle) {
  auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
  return cc_ctx->RankKnown(*reinterpret_cast<ShapeHandle*>(handle));
}

void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx, int i,
                                       TF_ShapeHandle* handle,
                                       TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
  if (i < 0 || i >= cc_ctx->num_outputs()) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT, "output index out of range");
  }
  if (TF_GetCode(status) == TF_OK) {
    cc_ctx->set_output(i, *(reinterpret_cast<ShapeHandle*>(handle)));
  }
}

void TF_DeleteShapeHandle(TF_ShapeHandle* handle) {
  if (handle == nullptr) {
    return;
  }

  delete reinterpret_cast<ShapeHandle*>(handle);
}

void TF_DeleteDimensionHandle(TF_DimensionHandle* handle) {
  if (handle == nullptr) {
    return;
  }

  delete reinterpret_cast<DimensionHandle*>(handle);
}

#define DEFINE_TF_GETATTR(func, c_type, cc_type)                         \
  void TF_ShapeInferenceContext_GetAttr##func(                           \
      TF_ShapeInferenceContext* ctx, const char* attr_name, c_type* val, \
      TF_Status* status) {                                               \
    TF_SetStatus(status, TF_OK, "");                                     \
    cc_type v;                                                           \
    auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);             \
    Status s = cc_ctx->GetAttr(attr_name, &v);                           \
    Set_TF_Status_from_Status(status, s);                                \
    if (s.ok()) {                                                        \
      *val = static_cast<c_type>(v);                                     \
    }                                                                    \
  }

DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)

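// Usage sketch for an op whose definition declared a type attribute such as
// TF_OpDefinitionBuilderAddAttr(builder, "T: type") (the attr name is
// illustrative):
//
//   TF_DataType type;
//   TF_ShapeInferenceContext_GetAttrType(ctx, "T", &type, status);
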
#define DEFINE_RANK_FUNC(func_name)                                        \
  void TF_ShapeInferenceContext##func_name(                                \
      TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank, \
      TF_ShapeHandle* result, TF_Status* status) {                         \
    auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);               \
    auto* cc_handle = reinterpret_cast<ShapeHandle*>(handle);              \
    auto* cc_result = reinterpret_cast<ShapeHandle*>(result);              \
    Status s = cc_ctx->func_name(*cc_handle, rank, cc_result);             \
    Set_TF_Status_from_Status(status, s);                                  \
  }

DEFINE_RANK_FUNC(WithRank)
DEFINE_RANK_FUNC(WithRankAtLeast)
DEFINE_RANK_FUNC(WithRankAtMost)

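// Usage sketch: constraining an input shape to be a matrix before inspecting
// it (`input_shape` and `result` are illustrative TF_ShapeHandle* values):
//
//   TF_ShapeInferenceContextWithRank(ctx, input_shape, 2, result, status);
//
// TF_ShapeInferenceContextWithRankAtLeast and
// TF_ShapeInferenceContextWithRankAtMost follow the same calling convention.
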
int64_t TF_ShapeInferenceContextRank(TF_ShapeInferenceContext* ctx,
                                     TF_ShapeHandle* handle) {
  return reinterpret_cast<InferenceContext*>(ctx)->Rank(
      *reinterpret_cast<ShapeHandle*>(handle));
}

void TF_ShapeInferenceContextDim(TF_ShapeInferenceContext* ctx,
                                 TF_ShapeHandle* shape_handle, int64_t i,
                                 TF_DimensionHandle* result) {
  int64_t rank = TF_ShapeInferenceContextRank(ctx, shape_handle);
  auto* cc_result = reinterpret_cast<DimensionHandle*>(result);

  if (i < -rank || i >= rank) {
    *cc_result = DimensionHandle();
    return;
  }

  auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
  auto* cc_shape_handle = reinterpret_cast<ShapeHandle*>(shape_handle);
  *cc_result = cc_ctx->Dim(*cc_shape_handle, i);
}

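// Usage sketch: reading the first dimension of a shape (`shape` is an
// illustrative TF_ShapeHandle*). Negative indices count back from the end,
// and out-of-range indices yield an unknown dimension:
//
//   TF_DimensionHandle* dim = TF_NewDimensionHandle();
//   TF_ShapeInferenceContextDim(ctx, shape, 0, dim);
//   if (TF_DimensionHandleValueKnown(dim)) {
//     int64_t size = TF_DimensionHandleValue(dim);
//   }
//   TF_DeleteDimensionHandle(dim);
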
int TF_DimensionHandleValueKnown(TF_DimensionHandle* dim_handle) {
  return InferenceContext::ValueKnown(
      *reinterpret_cast<DimensionHandle*>(dim_handle));
}

void TF_ShapeInferenceContextSetUnknownShape(TF_ShapeInferenceContext* ctx,
                                             TF_Status* status) {
  Status s = ::tensorflow::shape_inference::UnknownShape(
      reinterpret_cast<InferenceContext*>(ctx));
  Set_TF_Status_from_Status(status, s);
}

void TF_ShapeInferenceContextSubshape(TF_ShapeInferenceContext* ctx,
                                      TF_ShapeHandle* shape_handle,
                                      int64_t start, int64_t end,
                                      TF_ShapeHandle* result,
                                      TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
  auto* cc_result = reinterpret_cast<ShapeHandle*>(result);
  Status s = cc_ctx->Subshape(*reinterpret_cast<ShapeHandle*>(shape_handle),
                              start, end, cc_result);
  Set_TF_Status_from_Status(status, s);
}

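// Usage sketch: taking the dimensions [1, rank) of a shape, e.g. to drop a
// leading batch dimension (`shape` and `result` are illustrative handles):
//
//   int64_t rank = TF_ShapeInferenceContextRank(ctx, shape);
//   TF_ShapeInferenceContextSubshape(ctx, shape, 1, rank, result, status);
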
int64_t TF_DimensionHandleValue(TF_DimensionHandle* dim_handle) {
  return InferenceContext::Value(
      *reinterpret_cast<DimensionHandle*>(dim_handle));
}