# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch


# pointwise operators can go through a faster pathway

# NOTE(review): not referenced by anything else in this module, and the second
# entry is an empty string -- confirm whether any importer still relies on it.
tensor_magic_methods = ["add", ""]

# Bare names (without the surrounding double underscores) of the binary
# operators that are listed together with their reflected "r"-prefixed variant.
pointwise_magic_methods_with_reverse = (
    "add",
    "sub",
    "mul",
    "floordiv",
    "div",
    "truediv",
    "mod",
    "pow",
    "lshift",
    "rshift",
    "and",
    "or",
    "xor",
)

# Bare names of every magic method collected here. The leading generator
# expands each entry above into the pair (name, "r" + name), e.g. "add", "radd".
pointwise_magic_methods = (
    *(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)),
    "eq",
    "gt",
    "le",
    "lt",
    "ge",
    "gt",  # NOTE(review): duplicate of the "gt" above; kept so contents are unchanged
    "ne",
    "neg",
    "pos",
    "abs",
    "invert",
    "iadd",
    "isub",
    "imul",
    "ifloordiv",
    "idiv",
    "itruediv",
    "imod",
    "ipow",
    "ilshift",
    "irshift",
    "iand",
    "ior",
    "ixor",
    "int",
    "long",
    "float",
    "complex",
)

# The same names in their dunder spelling, e.g. "add" -> "__add__".
pointwise_methods = (*(f"__{m}__" for m in pointwise_magic_methods),)

# Callables collected as pointwise: first the torch.Tensor attribute for every
# dunder name above, then an explicit list of torch.Tensor methods, torch
# functions and torch.nn.functional functions. The getattr lookups run at
# import time, so a name missing from torch.Tensor raises AttributeError here.
pointwise = (
    *(getattr(torch.Tensor, m) for m in pointwise_methods),
    torch.nn.functional.dropout,
    torch.where,
    torch.Tensor.abs,
    torch.abs,
    torch.Tensor.acos,
    torch.acos,
    torch.Tensor.acosh,
    torch.acosh,
    torch.Tensor.add,
    torch.add,
    torch.Tensor.addcdiv,
    torch.addcdiv,
    torch.Tensor.addcmul,
    torch.addcmul,
    torch.Tensor.addr,
    torch.addr,
    torch.Tensor.angle,
    torch.angle,
    torch.Tensor.asin,
    torch.asin,
    torch.Tensor.asinh,
    torch.asinh,
    torch.Tensor.atan,
    torch.atan,
    torch.Tensor.atan2,
    torch.atan2,
    torch.Tensor.atanh,
    torch.atanh,
    torch.Tensor.bitwise_and,
    torch.bitwise_and,
    torch.Tensor.bitwise_left_shift,
    torch.bitwise_left_shift,
    torch.Tensor.bitwise_not,
    torch.bitwise_not,
    torch.Tensor.bitwise_or,
    torch.bitwise_or,
    torch.Tensor.bitwise_right_shift,
    torch.bitwise_right_shift,
    torch.Tensor.bitwise_xor,
    torch.bitwise_xor,
    torch.Tensor.ceil,
    torch.ceil,
    torch.celu,
    torch.nn.functional.celu,
    torch.Tensor.clamp,
    torch.clamp,
    torch.Tensor.clamp_max,
    torch.clamp_max,
    torch.Tensor.clamp_min,
    torch.clamp_min,
    torch.Tensor.copysign,
    torch.copysign,
    torch.Tensor.cos,
    torch.cos,
    torch.Tensor.cosh,
    torch.cosh,
    torch.Tensor.deg2rad,
    torch.deg2rad,
    torch.Tensor.digamma,
    torch.digamma,
    torch.Tensor.div,
    torch.div,
    torch.dropout,
    torch.nn.functional.dropout,  # NOTE(review): also the first explicit entry above
    torch.nn.functional.elu,
    torch.Tensor.eq,
    torch.eq,
    torch.Tensor.erf,
    torch.erf,
    torch.Tensor.erfc,
    torch.erfc,
    torch.Tensor.erfinv,
    torch.erfinv,
    torch.Tensor.exp,
    torch.exp,
    torch.Tensor.exp2,
    torch.exp2,
    torch.Tensor.expm1,
    torch.expm1,
    torch.feature_dropout,
    torch.Tensor.float_power,
    torch.float_power,
    torch.Tensor.floor,
    torch.floor,
    torch.Tensor.floor_divide,
    torch.floor_divide,
    torch.Tensor.fmod,
    torch.fmod,
    torch.Tensor.frac,
    torch.frac,
    torch.Tensor.frexp,
    torch.frexp,
    torch.Tensor.gcd,
    torch.gcd,
    torch.Tensor.ge,
    torch.ge,
    torch.nn.functional.gelu,
    torch.nn.functional.glu,
    torch.Tensor.gt,
    torch.gt,
    torch.Tensor.hardshrink,
    torch.hardshrink,
    torch.nn.functional.hardshrink,
    torch.nn.functional.hardsigmoid,
    torch.nn.functional.hardswish,
    torch.nn.functional.hardtanh,
    torch.Tensor.heaviside,
    torch.heaviside,
    torch.Tensor.hypot,
    torch.hypot,
    torch.Tensor.i0,
    torch.i0,
    torch.Tensor.igamma,
    torch.igamma,
    torch.Tensor.igammac,
    torch.igammac,
    torch.Tensor.isclose,
    torch.isclose,
    torch.Tensor.isfinite,
    torch.isfinite,
    torch.Tensor.isinf,
    torch.isinf,
    torch.Tensor.isnan,
    torch.isnan,
    torch.Tensor.isneginf,
    torch.isneginf,
    torch.Tensor.isposinf,
    torch.isposinf,
    torch.Tensor.isreal,
    torch.isreal,
    torch.Tensor.kron,
    torch.kron,
    torch.Tensor.lcm,
    torch.lcm,
    torch.Tensor.ldexp,
    torch.ldexp,
    torch.Tensor.le,
    torch.le,
    torch.nn.functional.leaky_relu,
    torch.Tensor.lerp,
    torch.lerp,
    torch.Tensor.lgamma,
    torch.lgamma,
    torch.Tensor.log,
    torch.log,
    torch.Tensor.log10,
    torch.log10,
    torch.Tensor.log1p,
    torch.log1p,
    torch.Tensor.log2,
    torch.log2,
    torch.nn.functional.logsigmoid,
    torch.Tensor.logical_and,
    torch.logical_and,
    torch.Tensor.logical_not,
    torch.logical_not,
    torch.Tensor.logical_or,
    torch.logical_or,
    torch.Tensor.logical_xor,
    torch.logical_xor,
    torch.Tensor.logit,
    torch.logit,
    torch.Tensor.lt,
    torch.lt,
    torch.Tensor.maximum,
    torch.maximum,
    torch.Tensor.minimum,
    torch.minimum,
    torch.nn.functional.mish,
    torch.Tensor.mvlgamma,
    torch.mvlgamma,
    torch.Tensor.nan_to_num,
    torch.nan_to_num,
    torch.Tensor.ne,
    torch.ne,
    torch.Tensor.neg,
    torch.neg,
    torch.Tensor.nextafter,
    torch.nextafter,
    torch.Tensor.outer,
    torch.outer,
    torch.polar,
    torch.Tensor.polygamma,
    torch.polygamma,
    torch.Tensor.positive,
    torch.positive,
    torch.Tensor.pow,
    torch.pow,
    torch.Tensor.prelu,
    torch.prelu,
    torch.nn.functional.prelu,
    torch.Tensor.rad2deg,
    torch.rad2deg,
    torch.Tensor.reciprocal,
    torch.reciprocal,
    torch.Tensor.relu,
    torch.relu,
    torch.nn.functional.relu,
    torch.nn.functional.relu6,
    torch.Tensor.remainder,
    torch.remainder,
    torch.Tensor.round,
    torch.round,
    torch.rrelu,
    torch.nn.functional.rrelu,
    torch.Tensor.rsqrt,
    torch.rsqrt,
    torch.rsub,
    torch.selu,
    torch.nn.functional.selu,
    torch.Tensor.sgn,
    torch.sgn,
    torch.Tensor.sigmoid,
    torch.sigmoid,
    torch.nn.functional.sigmoid,
    torch.Tensor.sign,
    torch.sign,
    torch.Tensor.signbit,
    torch.signbit,
    torch.nn.functional.silu,
    torch.Tensor.sin,
    torch.sin,
    torch.Tensor.sinc,
    torch.sinc,
    torch.Tensor.sinh,
    torch.sinh,
    torch.nn.functional.softplus,
    torch.nn.functional.softshrink,
    torch.Tensor.sqrt,
    torch.sqrt,
    torch.Tensor.square,
    torch.square,
    torch.Tensor.sub,
    torch.sub,
    torch.Tensor.tan,
    torch.tan,
    torch.Tensor.tanh,
    torch.tanh,
    torch.nn.functional.tanh,
    torch.threshold,
    torch.nn.functional.threshold,
    torch.trapz,
    torch.Tensor.true_divide,
    torch.true_divide,
    torch.Tensor.trunc,
    torch.trunc,
    torch.Tensor.xlogy,
    torch.xlogy,
    torch.rand_like,
)