1// Copyright 2020 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert CHANNEL_TILE % 8 == 0 7$assert CHANNEL_TILE >= 8 8$assert ROW_TILE >= 1 9$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 10#include <assert.h> 11 12#include <arm_neon.h> 13 14#include <xnnpack/math.h> 15#include <xnnpack/prelu.h> 16 17 18void xnn_f16_prelu_ukernel__neonfp16arith_${ROW_TILE}x${CHANNEL_TILE}( 19 size_t rows, 20 size_t channels, 21 const void* restrict input, 22 size_t input_stride, 23 const void* restrict weights, 24 void* restrict output, 25 size_t output_stride) XNN_OOB_READS 26{ 27 assert(rows != 0); 28 assert(channels != 0); 29 assert(channels % sizeof(__fp16) == 0); 30 31 const __fp16* i0 = (const __fp16*) input; 32 __fp16* o0 = (__fp16*) output; 33 $for M in range(1, ROW_TILE): 34 const __fp16* i${M} = (const __fp16*) ((uintptr_t) i${M-1} + input_stride); 35 __fp16* o${M} = (__fp16*) ((uintptr_t) o${M-1} + output_stride); 36 37 const size_t input_increment = input_stride * ${ROW_TILE} - channels; 38 const size_t output_increment = output_stride * ${ROW_TILE} - channels; 39 40 do { 41 $for M in range(1, ROW_TILE): 42 $if M % 2 == 0: 43 if XNN_UNPREDICTABLE(rows <= ${M}) { 44 i${M} = i${M-1}; 45 o${M} = o${M-1}; 46 } 47 $else: 48 if XNN_UNPREDICTABLE(rows < ${M+1}) { 49 i${M} = i${M-1}; 50 o${M} = o${M-1}; 51 } 52 53 const __fp16* w = (const __fp16*) weights; 54 size_t c = channels; 55 $if CHANNEL_TILE > 8: 56 for (; c >= ${CHANNEL_TILE} * sizeof(__fp16); c -= ${CHANNEL_TILE} * sizeof(__fp16)) { 57 $for C in range(0, CHANNEL_TILE, 8): 58 const float16x8_t vw${ABC[C:C+8]} = vld1q_f16(w); w += 8; 59 60 $for M in range(ROW_TILE): 61 $for C in range(0, CHANNEL_TILE, 8): 62 const float16x8_t vi${M}x0${ABC[C:C+8]} = vld1q_f16(i${M}); i${M} += 8; 63 64 $for M in range(ROW_TILE): 65 $for C in range(0, CHANNEL_TILE, 8): 66 float16x8_t vacc${M}x0${ABC[C:C+8]} = vmulq_f16(vi${M}x0${ABC[C:C+8]}, vw${ABC[C:C+8]}); 67 const uint16x8_t vm${M}x0${ABC[C:C+8]} = vcltq_s16(vreinterpretq_s16_f16(vi${M}x0${ABC[C:C+8]}), vmovq_n_s16(0)); 68 69 $for M in range(ROW_TILE): 70 $for C in range(0, CHANNEL_TILE, 8): 71 vacc${M}x0${ABC[C:C+8]} = vbslq_f16(vm${M}x0${ABC[C:C+8]}, vacc${M}x0${ABC[C:C+8]}, vi${M}x0${ABC[C:C+8]}); 72 73 $for M in range(ROW_TILE): 74 $for C in range(0, CHANNEL_TILE, 8): 75 vst1q_f16(o${M}, vacc${M}x0${ABC[C:C+8]}); o${M} += 8; 76 } 77 for (; c >= 8 * sizeof(__fp16); c -= 8 * sizeof(__fp16)) { 78 const float16x8_t vw01234567 = vld1q_f16(w); w += 8; 79 80 $for M in range(ROW_TILE): 81 const float16x8_t vi${M}x01234567 = vld1q_f16(i${M}); 82 i${M} += 8; 83 84 $for M in range(ROW_TILE): 85 float16x8_t vacc${M}x01234567 = vmulq_f16(vi${M}x01234567, vw01234567); 86 const uint16x8_t vm${M}x01234567 = vcltq_s16(vreinterpretq_s16_f16(vi${M}x01234567), vmovq_n_s16(0)); 87 88 $for M in range(ROW_TILE): 89 vacc${M}x01234567 = vbslq_f16(vm${M}x01234567, vacc${M}x01234567, vi${M}x01234567); 90 91 $for M in range(ROW_TILE): 92 vst1q_f16(o${M}, vacc${M}x01234567); o${M} += 8; 93 } 94 if XNN_UNLIKELY(c != 0) { 95 const float16x8_t vw01234567 = vld1q_f16(w); 96 97 $for M in range(ROW_TILE): 98 const float16x8_t vi${M}x01234567 = vld1q_f16(i${M}); 99 i${M} = (const __fp16*) ((uintptr_t) i${M} + c); 100 101 $for M in range(ROW_TILE): 102 float16x8_t vacc${M}x01234567 = vmulq_f16(vi${M}x01234567, vw01234567); 103 const uint16x8_t vm${M}x01234567 = vcltq_s16(vreinterpretq_s16_f16(vi${M}x01234567), vmovq_n_s16(0)); 104 105 $for M in range(ROW_TILE): 106 vacc${M}x01234567 = vbslq_f16(vm${M}x01234567, vacc${M}x01234567, vi${M}x01234567); 107 108 $for M in range(ROW_TILE): 109 float16x4_t vacc${M}x0123 = vget_low_f16(vacc${M}x01234567); 110 if (c & (4 * sizeof(__fp16))) { 111 $for M in range(ROW_TILE): 112 vst1_f16(o${M}, vacc${M}x0123); o${M} += 4; 113 114 $for M in range(ROW_TILE): 115 vacc${M}x0123 = vget_high_f16(vacc${M}x01234567); 116 } 117 if (c & (2 * sizeof(__fp16))) { 118 $for M in range(ROW_TILE): 119 vst1_lane_u32((void*) o${M}, vreinterpret_u32_f16(vacc${M}x0123), 0); o${M} += 2; 120 vacc${M}x0123 = vext_f16(vacc${M}x0123, vacc${M}x0123, 2); 121 } 122 if (c & (1 * sizeof(__fp16))) { 123 $for M in range(ROW_TILE): 124 vst1_lane_f16(o${M}, vacc${M}x0123, 0); o${M} += 1; 125 } 126 } 127 $for M in range(ROW_TILE): 128 i${M} = (const __fp16*) ((uintptr_t) i${M} + input_increment); 129 o${M} = (__fp16*) ((uintptr_t) o${M} + output_increment); 130 rows = doz(rows, ${ROW_TILE}); 131 } while (rows != 0); 132} 133