1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/runtime/kernel/kernel_includes.h> 12 13 namespace torch { 14 namespace executor { 15 16 namespace native { 17 18 Tensor& sdpa_with_kv_cache_out( 19 KernelRuntimeContext& ctx, 20 const Tensor& q_projected, 21 const Tensor& k_projected, 22 const Tensor& v_projected, 23 Tensor& key_cache, 24 Tensor& value_cache, 25 const int64_t start_pos, 26 const int64_t seq_len, 27 const optional<Tensor>& attn_mask, 28 const double dropout_p, 29 const bool is_causal, 30 // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy 31 const optional<double> scale, 32 Tensor& output); 33 34 Tensor& custom_sdpa_out( 35 RuntimeContext& ctx, 36 const Tensor& q, 37 const Tensor& k, 38 const Tensor& v, 39 const int64_t start_pos, 40 const optional<Tensor>& attn_mask, 41 const double dropout_p, 42 const bool is_causal, 43 // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy 44 const optional<double> scale, 45 Tensor& output); 46 47 Tensor& flash_attention_kernel_out( 48 KernelRuntimeContext& ctx, 49 const Tensor& query, 50 const Tensor& key, 51 const Tensor& value, 52 const optional<Tensor>& attn_mask, 53 const double dropout_p, 54 const bool is_causal, 55 // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy 56 const optional<double> scale, 57 Tensor& output); 58 59 } // namespace native 60 } // namespace executor 61 } // namespace torch 62