xref: /aosp_15_r20/external/armnn/src/armnn/layers/QuantizeLayer.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include "QuantizeLayer.hpp"
7 
8 #include "LayerCloneBase.hpp"
9 
10 namespace armnn
11 {
12 
QuantizeLayer(const char * name)13 QuantizeLayer::QuantizeLayer(const char* name)
14 : Layer(1, 1, LayerType::Quantize, name)
15 {}
16 
CreateWorkload(const IWorkloadFactory & factory) const17 std::unique_ptr<IWorkload> QuantizeLayer::CreateWorkload(const IWorkloadFactory& factory) const
18 {
19     QuantizeQueueDescriptor descriptor;
20     SetAdditionalInfo(descriptor);
21 
22     WorkloadInfo info = PrepInfoAndDesc(descriptor);
23 
24     return factory.CreateWorkload(LayerType::Quantize, descriptor, info);
25 }
26 
Clone(Graph & graph) const27 Layer* QuantizeLayer::Clone(Graph& graph) const
28 {
29     QuantizeLayer* clone = CloneBase<QuantizeLayer>(graph, GetName());
30     return clone;
31 }
32 
ValidateTensorShapesFromInputs()33 void QuantizeLayer::ValidateTensorShapesFromInputs()
34 {
35     VerifyLayerConnections(1, CHECK_LOCATION());
36 
37     const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
38 
39     VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
40 
41     auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });
42 
43     ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "QuantizeLayer");
44 }
45 
ExecuteStrategy(IStrategy & strategy) const46 void QuantizeLayer::ExecuteStrategy(IStrategy& strategy) const
47 {
48     strategy.ExecuteStrategy(this, GetParameters(), {}, GetName());
49 }
50 
51 } //namespace armnn