xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/coreml/builders/reshape_op_builder.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/delegates/coreml/builders/reshape_op_builder.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <memory>
#include <string>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/coreml/builders/op_builder.h"
#include "tensorflow/lite/delegates/coreml/builders/op_factory.h"
#include "tensorflow/lite/delegates/coreml/builders/op_validator.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
28 
29 namespace tflite {
30 namespace delegates {
31 namespace coreml {
32 
DebugName()33 const std::string& ReshapeOpBuilder::DebugName() {
34   if (debug_name_.empty()) {
35     SetDebugName("ReshapeOpBuilder", node_id_);
36   }
37   return debug_name_;
38 }
39 
Build()40 CoreML::Specification::NeuralNetworkLayer* ReshapeOpBuilder::Build() {
41   if (layer_ == nullptr) {
42     layer_ = std::make_unique<CoreML::Specification::NeuralNetworkLayer>();
43   }
44   layer_->set_name(DebugName());
45   for (int dim : shape_) {
46     layer_->mutable_reshape()->add_targetshape(dim);
47   }
48   if (need_transpose_)
49     layer_->mutable_reshape()->set_mode(
50         CoreML::Specification::ReshapeLayerParams::CHANNEL_LAST);
51   return layer_.release();
52 }
53 
SetShapeFromTensor(const TfLiteTensor * output_shape,const TfLiteIntArray * input_shape)54 void ReshapeOpBuilder::SetShapeFromTensor(const TfLiteTensor* output_shape,
55                                           const TfLiteIntArray* input_shape) {
56   TfLiteIntArray* shape = TfLiteIntArrayCreate(output_shape->dims->data[0]);
57   std::memcpy(shape->data, GetTensorData<int>(output_shape),
58               shape->size * sizeof(int));
59 
60   SetShapeFromIntArray(shape, input_shape);
61   TfLiteIntArrayFree(shape);
62 }
63 
SetShapeFromIntArray(const TfLiteIntArray * output_shape,const TfLiteIntArray * input_shape)64 void ReshapeOpBuilder::SetShapeFromIntArray(const TfLiteIntArray* output_shape,
65                                             const TfLiteIntArray* input_shape) {
66   // ignore first dimension (batch)
67   std::copy(output_shape->data + 1, output_shape->data + output_shape->size,
68             std::back_inserter(shape_));
69 
70   int64_t reshape_size = 1;
71   int negative_index = -1;
72   for (int i = 0; i < shape_.size(); ++i) {
73     if (shape_[i] == -1) {
74       negative_index = i;
75     } else {
76       reshape_size *= shape_[i];
77     }
78   }
79   if (negative_index >= 0) {
80     int64_t input_size = NumElements(input_shape);
81     shape_[negative_index] = input_size / reshape_size;
82   }
83 
84   if (shape_.size() == 2) {
85     shape_ = {shape_[1], 1, shape_[0]};
86   } else if (shape_.size() == 3) {
87     shape_ = {shape_[2], shape_[0], shape_[1]};
88   }
89   // When channel dimension is changed, reshape should be done with HWC layout.
90   if (shape_[0] != input_shape->data[input_shape->size - 1]) {
91     need_transpose_ = true;
92   }
93 }
94 
RegisterInputs(const TfLiteIntArray * inputs,TfLiteContext * context)95 TfLiteStatus ReshapeOpBuilder::RegisterInputs(const TfLiteIntArray* inputs,
96                                               TfLiteContext* context) {
97   AddInput(inputs->data[0]);
98 
99   if (inputs->size == 2) {
100     SetShapeFromTensor(&context->tensors[inputs->data[1]],
101                        context->tensors[inputs->data[0]].dims);
102   } else {
103     const auto* params = reinterpret_cast<TfLiteReshapeParams*>(builtin_data_);
104     TfLiteIntArray* output_shape = TfLiteIntArrayCreate(params->num_dimensions);
105     std::memcpy(output_shape->data, params->shape,
106                 params->num_dimensions * sizeof(int));
107 
108     SetShapeFromIntArray(output_shape, context->tensors[inputs->data[0]].dims);
109     TfLiteIntArrayFree(output_shape);
110   }
111   return kTfLiteOk;
112 }
113 
RegisterOutputs(const TfLiteIntArray * outputs,TfLiteContext * context)114 TfLiteStatus ReshapeOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
115                                                TfLiteContext* context) {
116   graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context));
117   return kTfLiteOk;
118 }
119 
IsReshapeOpSupported(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context,int coreml_version)120 bool IsReshapeOpSupported(const TfLiteRegistration* registration,
121                           const TfLiteNode* node, TfLiteContext* context,
122                           int coreml_version) {
123   if (coreml_version >= 3) {
124     return false;
125   }
126   if (node->inputs->size == 1) {
127     const auto* params =
128         reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
129     return params->num_dimensions == 3 || params->num_dimensions == 4;
130   }
131 
132   const int kShapeTensor = 1;
133   const TfLiteTensor* shape;
134   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShapeTensor, &shape));
135   if (shape->allocation_type != kTfLiteMmapRo) {
136     TF_LITE_KERNEL_LOG(context, "Reshape has non-const shape.");
137     return false;
138   }
139   const bool is_shape_tensor =
140       shape->dims->size == 1 && shape->type == kTfLiteInt32;
141   return is_shape_tensor &&
142          (shape->dims->data[0] == 3 || shape->dims->data[0] == 4);
143 }
144 
CreateReshapeOpBuilder(GraphBuilder * graph_builder)145 OpBuilder* CreateReshapeOpBuilder(GraphBuilder* graph_builder) {
146   return new ReshapeOpBuilder(graph_builder);
147 }
148 
149 }  // namespace coreml
150 }  // namespace delegates
151 }  // namespace tflite
152