1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include "SoftmaxLayer.hpp"
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker #include "LayerCloneBase.hpp"
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/TypesUtils.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadData.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadFactory.hpp>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker namespace armnn
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker
SoftmaxLayer(const SoftmaxDescriptor & param,const char * name)16*89c4ff92SAndroid Build Coastguard Worker SoftmaxLayer::SoftmaxLayer(const SoftmaxDescriptor ¶m, const char* name)
17*89c4ff92SAndroid Build Coastguard Worker : LayerWithParameters(1, 1, LayerType::Softmax, param, name)
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker }
20*89c4ff92SAndroid Build Coastguard Worker
CreateWorkload(const IWorkloadFactory & factory) const21*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IWorkload> SoftmaxLayer::CreateWorkload(const IWorkloadFactory& factory) const
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker SoftmaxQueueDescriptor descriptor;
24*89c4ff92SAndroid Build Coastguard Worker SetAdditionalInfo(descriptor);
25*89c4ff92SAndroid Build Coastguard Worker
26*89c4ff92SAndroid Build Coastguard Worker return factory.CreateWorkload(LayerType::Softmax, descriptor, PrepInfoAndDesc(descriptor));
27*89c4ff92SAndroid Build Coastguard Worker }
28*89c4ff92SAndroid Build Coastguard Worker
Clone(Graph & graph) const29*89c4ff92SAndroid Build Coastguard Worker SoftmaxLayer* SoftmaxLayer::Clone(Graph& graph) const
30*89c4ff92SAndroid Build Coastguard Worker {
31*89c4ff92SAndroid Build Coastguard Worker return CloneBase<SoftmaxLayer>(graph, m_Param, GetName());
32*89c4ff92SAndroid Build Coastguard Worker }
33*89c4ff92SAndroid Build Coastguard Worker
ValidateTensorShapesFromInputs()34*89c4ff92SAndroid Build Coastguard Worker void SoftmaxLayer::ValidateTensorShapesFromInputs()
35*89c4ff92SAndroid Build Coastguard Worker {
36*89c4ff92SAndroid Build Coastguard Worker VerifyLayerConnections(1, CHECK_LOCATION());
37*89c4ff92SAndroid Build Coastguard Worker
38*89c4ff92SAndroid Build Coastguard Worker const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
39*89c4ff92SAndroid Build Coastguard Worker
40*89c4ff92SAndroid Build Coastguard Worker VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
41*89c4ff92SAndroid Build Coastguard Worker
42*89c4ff92SAndroid Build Coastguard Worker auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(inferredShapes.size() == 1);
45*89c4ff92SAndroid Build Coastguard Worker
46*89c4ff92SAndroid Build Coastguard Worker ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "SoftmaxLayer");
47*89c4ff92SAndroid Build Coastguard Worker }
48*89c4ff92SAndroid Build Coastguard Worker
ExecuteStrategy(IStrategy & strategy) const49*89c4ff92SAndroid Build Coastguard Worker void SoftmaxLayer::ExecuteStrategy(IStrategy& strategy) const
50*89c4ff92SAndroid Build Coastguard Worker {
51*89c4ff92SAndroid Build Coastguard Worker strategy.ExecuteStrategy(this, GetParameters(), {}, GetName());
52*89c4ff92SAndroid Build Coastguard Worker }
53*89c4ff92SAndroid Build Coastguard Worker
54*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
55