//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

/// Applies `unaryOpFunction` to `inputTensor` and returns the tensor it yields.
/// @param inputTensor  Tensor handed to the callable.
/// @param mpsGraph     Not read by this helper.
/// @param unaryOpFunction Callable that builds the unary op's output tensor.
/// @return Whatever `unaryOpFunction` returns for `inputTensor`.
MPSGraphTensor*
unaryOpTensor(
    MPSGraphTensor* inputTensor,
    MPSGraph* mpsGraph,
    std::function<MPSGraphTensor*(MPSGraphTensor*)> unaryOpFunction) {
  (void)mpsGraph;
  MPSGraphTensor* result = unaryOpFunction(inputTensor);
  return result;
}

/// Lowers an MPSBitwiseNot node into the graph.
/// Bool inputs are lowered with MPSGraph's logical NOT. Every other dtype uses
/// bitwiseNOT, which is gated on is_macos_13_or_newer(); when that check fails
/// the function returns NotSupported and registers no output tensor.
Error
MPSGraphBuilder::mpsBitwiseNotOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSBitwiseNot();
  const auto inputId = graphNode->input1_id();
  const auto outputId = graphNode->output_id();
  ET_LOG(Debug, "%s: %d -> %d", __FUNCTION__, inputId, outputId);

  MPSGraphTensor* operand = getMPSGraphTensor(inputId);
  const bool isBoolOperand =
      getScalarType([operand dataType]) == executorch::aten::ScalarType::Bool;

  if (isBoolOperand) {
    // On booleans, bitwise NOT and logical NOT coincide.
    _idToMPSGraphTensor[outputId] = [_mpsGraph notWithTensor:operand name:nil];
    return Error::Ok;
  }

  ET_CHECK_OR_RETURN_ERROR(
      is_macos_13_or_newer(), NotSupported,
      "mpsBitwiseNotOp supported by MPS on MacOS13.0+/iOS16.1+");
  _idToMPSGraphTensor[outputId] =
      [_mpsGraph bitwiseNOTWithTensor:operand name:nil];
  return Error::Ok;
}

// REGISTER_UNARY_OP(aot_name, graph_op) defines
// MPSGraphBuilder::mps<aot_name>Op. The generated method:
//   - reads the node as a one-input / one-output node (mpsgraph::_MPSNode1x1),
//   - applies [_mpsGraph <graph_op>WithTensor:name:] to the input tensor,
//   - stores the result under the node's output id.
// (Comments are kept outside the macro body on purpose: a // comment on a
// line ending in a backslash would swallow the following continuation line.)
#define REGISTER_UNARY_OP(aot_name, graph_op) \
Error \
MPSGraphBuilder::mps##aot_name##Op(NodePtr nodePtr) { \
  const auto* node = \
      static_cast<const mpsgraph::_MPSNode1x1*>(nodePtr->mpsnode_union()); \
  ET_LOG(Debug, "%s: %d -> %d", __FUNCTION__, node->input1_id(), node->output_id()); \
  MPSGraphTensor* source = getMPSGraphTensor(node->input1_id()); \
  _idToMPSGraphTensor[node->output_id()] = unaryOpTensor( \
      source, \
      _mpsGraph, \
      [this](MPSGraphTensor* operand) -> MPSGraphTensor* { \
        return [_mpsGraph graph_op##WithTensor:operand name:nil]; \
      }); \
  return Error::Ok; \
}

// Elementwise unary ops. Each REGISTER_UNARY_OP(Name, op) line defines
// MPSGraphBuilder::mps<Name>Op, which forwards to [_mpsGraph <op>WithTensor:name:].

// Exponentials and logarithms.
REGISTER_UNARY_OP(Exp, exponent)
REGISTER_UNARY_OP(Exp2, exponentBase2)
REGISTER_UNARY_OP(Log, logarithm)
REGISTER_UNARY_OP(Log10, logarithmBase10)
REGISTER_UNARY_OP(Log2, logarithmBase2)

// Powers and roots.
REGISTER_UNARY_OP(Sqrt, squareRoot)
REGISTER_UNARY_OP(Rsqrt, reverseSquareRoot)
REGISTER_UNARY_OP(Reciprocal, reciprocal)

// Sign and magnitude.
REGISTER_UNARY_OP(Neg, negative)
REGISTER_UNARY_OP(Abs, absolute)
REGISTER_UNARY_OP(Sign, sign)

// Rounding.
// NOTE(review): confirm that MPSGraph's `round` uses the same tie-breaking rule
// (round-half-to-even) expected for aten round; this is not verified here.
REGISTER_UNARY_OP(Floor, floor)
REGISTER_UNARY_OP(Ceil, ceil)
REGISTER_UNARY_OP(Round, round)

// Trigonometric.
REGISTER_UNARY_OP(Sin, sin)
REGISTER_UNARY_OP(Cos, cos)
REGISTER_UNARY_OP(Tan, tan)
REGISTER_UNARY_OP(Asin, asin)
REGISTER_UNARY_OP(Acos, acos)
REGISTER_UNARY_OP(Atan, atan)

// Hyperbolic.
REGISTER_UNARY_OP(Sinh, sinh)
REGISTER_UNARY_OP(Cosh, cosh)
REGISTER_UNARY_OP(Tanh, tanh)
REGISTER_UNARY_OP(Asinh, asinh)
REGISTER_UNARY_OP(Acosh, acosh)
REGISTER_UNARY_OP(Atanh, atanh)

// Special functions.
REGISTER_UNARY_OP(Erf, erf)
REGISTER_UNARY_OP(Sigmoid, sigmoid)

// Predicates and logical ops.
REGISTER_UNARY_OP(Isnan, isNaN)
REGISTER_UNARY_OP(Isinf, isInfinite)
REGISTER_UNARY_OP(LogicalNot, not)


/// Lowers a NormCdf node (one input, one output).
/// Builds 0.5 * (1 + erf(x * (1/sqrt(2)))) in the graph. The three scalar
/// constants are created with shape [1] in the input tensor's dtype and are
/// combined with the input through elementwise multiply/add.
Error
MPSGraphBuilder::mpsNormCdfOp(NodePtr nodePtr) {
  const auto* node =
      static_cast<const mpsgraph::_MPSNode1x1*>(nodePtr->mpsnode_union());
  ET_LOG(Debug, "%s: %d -> %d", __FUNCTION__, node->input1_id(), node->output_id());

  MPSGraphTensor* x = getMPSGraphTensor(node->input1_id());
  const MPSDataType dtype = [x dataType];

  // 1/sqrt(2), rounded to float precision before being handed to MPSGraph.
  const float kInvSqrt2 = 0.707106781186547524400844362104849039f;
  MPSGraphTensor* invSqrt2Const =
      [_mpsGraph constantWithScalar:kInvSqrt2 shape:@[ @1 ] dataType:dtype];
  MPSGraphTensor* oneConst =
      [_mpsGraph constantWithScalar:1.0f shape:@[ @1 ] dataType:dtype];
  MPSGraphTensor* halfConst =
      [_mpsGraph constantWithScalar:0.5f shape:@[ @1 ] dataType:dtype];

  MPSGraphTensor* scaled =
      [_mpsGraph multiplicationWithPrimaryTensor:x
                                 secondaryTensor:invSqrt2Const
                                            name:nil];
  MPSGraphTensor* erfOfScaled = [_mpsGraph erfWithTensor:scaled name:nil];
  MPSGraphTensor* onePlusErf =
      [_mpsGraph additionWithPrimaryTensor:erfOfScaled
                           secondaryTensor:oneConst
                                      name:nil];
  _idToMPSGraphTensor[node->output_id()] =
      [_mpsGraph multiplicationWithPrimaryTensor:onePlusErf
                                 secondaryTensor:halfConst
                                            name:nil];

  return Error::Ok;
}

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch