xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/UnaryOps.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1
2//
3//  Copyright (c) 2023 Apple Inc. All rights reserved.
4//  Provided subject to the LICENSE file in the top level directory.
5//
6
7#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
8
9namespace executorch {
10namespace backends {
11namespace mps {
12namespace delegate {
13
/// Applies `unaryOpFunction` to `inputTensor` and returns whatever tensor the
/// callback produces.
///
/// @param inputTensor      Tensor handed to the callback.
/// @param mpsGraph         Accepted for call-site compatibility; not used here.
/// @param unaryOpFunction  Callback that builds the unary op on the graph.
/// @return The tensor returned by `unaryOpFunction`.
MPSGraphTensor*
unaryOpTensor(
  MPSGraphTensor* inputTensor,
  MPSGraph* mpsGraph,
  std::function<MPSGraphTensor*(MPSGraphTensor*)> unaryOpFunction) {
  (void)mpsGraph;  // Intentionally unused; kept so existing callers still compile.
  MPSGraphTensor* resultTensor = unaryOpFunction(inputTensor);
  return resultTensor;
}
21
/// Lowers a bitwise-NOT node into the MPS graph.
///
/// - Bool inputs are lowered with `notWithTensor:` (logical NOT).
/// - Any other input dtype is lowered with `bitwiseNOTWithTensor:`. That path
///   is gated on `is_macos_13_or_newer()` and fails with `NotSupported`
///   when the check does not hold.
///
/// The resulting tensor is stored under the node's output id.
Error
MPSGraphBuilder::mpsBitwiseNotOp(NodePtr nodePtr) {
  auto bitwiseNotNode = nodePtr->mpsnode_union_as_MPSBitwiseNot();
  const auto inputId = bitwiseNotNode->input1_id();
  const auto outputId = bitwiseNotNode->output_id();
  ET_LOG(Debug, "%s: %d -> %d", __FUNCTION__, inputId, outputId);

  MPSGraphTensor* inputTensor = getMPSGraphTensor(inputId);
  const bool isBoolInput =
      getScalarType([inputTensor dataType]) == executorch::aten::ScalarType::Bool;

  if (isBoolInput) {
    // Bool tensors are routed to logical NOT rather than bitwiseNOT.
    _idToMPSGraphTensor[outputId] = [_mpsGraph notWithTensor:inputTensor name:nil];
    return Error::Ok;
  }

  // Non-bool path: bitwiseNOT requires the OS version checked below.
  ET_CHECK_OR_RETURN_ERROR(
    is_macos_13_or_newer(), NotSupported,
    "mpsBitwiseNotOp supported by MPS on MacOS13.0+/iOS16.1+");
  _idToMPSGraphTensor[outputId] = [_mpsGraph bitwiseNOTWithTensor:inputTensor name:nil];
  return Error::Ok;
}
43
// REGISTER_UNARY_OP(aot_name, graph_op) defines
// `MPSGraphBuilder::mps<aot_name>Op`. The generated method treats the node as a
// single-input/single-output `_MPSNode1x1` and calls
// `-[MPSGraph <graph_op>WithTensor:name:]` on the node's input tensor.
// It stores the result under the node's output id.
#define REGISTER_UNARY_OP(aot_name, graph_op)                                   \
Error                                                                           \
MPSGraphBuilder::mps##aot_name##Op(NodePtr nodePtr) {                           \
  const auto* unaryNode =                                                       \
      static_cast<const mpsgraph::_MPSNode1x1*>(nodePtr->mpsnode_union());      \
  const auto inputId = unaryNode->input1_id();                                  \
  const auto outputId = unaryNode->output_id();                                 \
  ET_LOG(Debug, "%s: %d -> %d", __FUNCTION__, inputId, outputId);               \
  MPSGraphTensor* sourceTensor = getMPSGraphTensor(inputId);                    \
  auto applyOp = [this](MPSGraphTensor* operand) -> MPSGraphTensor* {           \
    return [_mpsGraph graph_op##WithTensor:operand name:nil];                   \
  };                                                                            \
  _idToMPSGraphTensor[outputId] =                                               \
      unaryOpTensor(sourceTensor, _mpsGraph, applyOp);                          \
  return Error::Ok;                                                             \
}
64
// Each line below expands to an `MPSGraphBuilder::mps<Name>Op` definition that
// forwards to the listed MPSGraph `<op>WithTensor:name:` method.

// Exponentials and logarithms.
REGISTER_UNARY_OP(Exp, exponent)
REGISTER_UNARY_OP(Exp2, exponentBase2)
REGISTER_UNARY_OP(Log, logarithm)
REGISTER_UNARY_OP(Log10, logarithmBase10)
REGISTER_UNARY_OP(Log2, logarithmBase2)

// Roots and reciprocals.
REGISTER_UNARY_OP(Sqrt, squareRoot)
REGISTER_UNARY_OP(Rsqrt, reverseSquareRoot)
REGISTER_UNARY_OP(Reciprocal, reciprocal)

// Sign and magnitude.
REGISTER_UNARY_OP(Neg, negative)
REGISTER_UNARY_OP(Abs, absolute)
REGISTER_UNARY_OP(Sign, sign)

// Rounding.
REGISTER_UNARY_OP(Floor, floor)
REGISTER_UNARY_OP(Ceil, ceil)
REGISTER_UNARY_OP(Round, round)

// Trigonometric.
REGISTER_UNARY_OP(Sin, sin)
REGISTER_UNARY_OP(Cos, cos)
REGISTER_UNARY_OP(Tan, tan)
REGISTER_UNARY_OP(Asin, asin)
REGISTER_UNARY_OP(Acos, acos)
REGISTER_UNARY_OP(Atan, atan)

// Hyperbolic.
REGISTER_UNARY_OP(Sinh, sinh)
REGISTER_UNARY_OP(Cosh, cosh)
REGISTER_UNARY_OP(Tanh, tanh)
REGISTER_UNARY_OP(Asinh, asinh)
REGISTER_UNARY_OP(Acosh, acosh)
REGISTER_UNARY_OP(Atanh, atanh)

// Special functions.
REGISTER_UNARY_OP(Erf, erf)
REGISTER_UNARY_OP(Sigmoid, sigmoid)

// Predicates and logical ops.
REGISTER_UNARY_OP(Isnan, isNaN)
REGISTER_UNARY_OP(Isinf, isInfinite)
REGISTER_UNARY_OP(LogicalNot, not)
96
97
/// Lowers a normal-CDF node into the MPS graph as
///   out = 0.5 * (1 + erf(x * sqrt(1/2)))
/// where `x` is the node's input tensor.
///
/// The three scalar constants are created as shape-[1] tensors in the input's
/// dtype and combined with `x` via MPSGraph's binary arithmetic ops.
/// The result is stored under the node's output id.
Error
MPSGraphBuilder::mpsNormCdfOp(NodePtr nodePtr) {
  const auto* unaryNode =
      static_cast<const mpsgraph::_MPSNode1x1*>(nodePtr->mpsnode_union());
  const auto inputId = unaryNode->input1_id();
  const auto outputId = unaryNode->output_id();
  ET_LOG(Debug, "%s: %d -> %d", __FUNCTION__, inputId, outputId);

  MPSGraphTensor* x = getMPSGraphTensor(inputId);
  const MPSDataType elementType = [x dataType];

  // Builds a shape-[1] constant tensor in the input's dtype.
  auto makeConstant = [&](double value) -> MPSGraphTensor* {
    return [_mpsGraph constantWithScalar:value
                                   shape:@[@1]
                                dataType:elementType];
  };

  // 1/sqrt(2), kept as a float literal.
  const float kSqrtHalf = 0.707106781186547524400844362104849039f;
  MPSGraphTensor* sqrtHalf = makeConstant(kSqrtHalf);
  MPSGraphTensor* one = makeConstant(1.0f);
  MPSGraphTensor* half = makeConstant(0.5f);

  MPSGraphTensor* scaled = [_mpsGraph multiplicationWithPrimaryTensor:x
                                                      secondaryTensor:sqrtHalf
                                                                 name:nil];
  MPSGraphTensor* erfOfScaled = [_mpsGraph erfWithTensor:scaled name:nil];
  MPSGraphTensor* onePlusErf = [_mpsGraph additionWithPrimaryTensor:erfOfScaled
                                                    secondaryTensor:one
                                                               name:nil];
  MPSGraphTensor* cdf = [_mpsGraph multiplicationWithPrimaryTensor:onePlusErf
                                                   secondaryTensor:half
                                                              name:nil];
  _idToMPSGraphTensor[outputId] = cdf;

  return Error::Ok;
}
134
135} // namespace delegate
136} // namespace mps
137} // namespace backends
138} // namespace executorch
139