# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Import custom ops defined in op_sdpa_aot.cpp. Those ops use PyTorch C++
# APIs for registration, so we need to load the shared library here.
# This is only needed for OSS.

# pyre-unsafe

import logging
from pathlib import Path

import torch

from torch.library import impl

# TODO rename this file to custom_ops_meta_registration.py
try:
    op = torch.ops.llama.sdpa_with_kv_cache.default
    assert op is not None
    op2 = torch.ops.llama.fast_hadamard_transform.default
    assert op2 is not None
except AttributeError:
    # Accessing an unregistered op raises AttributeError; fall back to
    # loading the shared library that registers the ops.
    libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*"))
    assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
    logging.info(f"Loading custom ops library: {libs[0]}")
    torch.ops.load_library(libs[0])
    op = torch.ops.llama.sdpa_with_kv_cache.default
    assert op is not None
    op2 = torch.ops.llama.fast_hadamard_transform.default
    assert op2 is not None

custom_ops_lib = torch.library.Library("llama", "IMPL")


def _validate_params(
    query,
    key,
    value,
    key_cache,
    value_cache,
    start_pos,
    seq_len,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
):
    assert (
        query.dim() == 4
    ), f"Expected query to be 4 dimensional but got {query.dim()} dimensions."
    assert (
        key.dim() == 4
    ), f"Expected key to be 4 dimensional but got {key.dim()} dimensions."
    assert (
        value.dim() == 4
    ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions."

    assert (
        query.dtype == torch.float32
    ), f"Expected query to be float32 but got {query.dtype}"
    assert key.dtype == torch.float32, f"Expected key to be float32 but got {key.dtype}"
    assert (
        value.dtype == torch.float32
    ), f"Expected value to be float32 but got {value.dtype}"

    assert (
        key_cache.dim() == 4
    ), f"Expected key_cache to be 4 dimensional but got {key_cache.dim()}"
    assert (
        value_cache.dim() == 4
    ), f"Expected value_cache to be 4 dimensional but got {value_cache.dim()}"

    assert (
        key_cache.dtype == torch.float32
    ), f"Expected key_cache to be float32 but got {key_cache.dtype}"
    assert (
        value_cache.dtype == torch.float32
    ), f"Expected value_cache to be float32 but got {value_cache.dtype}"

    assert (
        key_cache.size() == value_cache.size()
    ), f"Key cache and value cache must have same size but got {key_cache.size()} and {value_cache.size()}"

    # These asserts are valid, but enabling them would require adding
    # constrain_as_size/constrain_as_value calls to the model, which we
    # don't want to do right now.
    # assert start_pos < key_cache.size(
    #     1
    # ), f"Start position {start_pos} must be less than sequence length {key_cache.size(1)}"
    # assert (start_pos + seq_len) < key_cache.size(
    #     1
    # ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {key_cache.size(1)}"

    if attn_mask is not None:
        assert (
            attn_mask.dim() == 2
        ), f"Expected attn_mask to be 2 dimensional but got {attn_mask.dim()} dimensions."
        assert (attn_mask.dtype == torch.float32) or (
            attn_mask.dtype == torch.float16
        ), f"Expected attn_mask to be float but got {attn_mask.dtype}"
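
# For reference, a sketch of inputs that pass these checks. The layout
# ([batch, seq_len, num_heads, head_dim]) and the concrete sizes are
# illustrative assumptions, not mandated by the op schema:
#
#   q = k = v = torch.empty(1, 1, 8, 64, device="meta")            # float32
#   k_cache = v_cache = torch.empty(1, 128, 8, 64, device="meta")  # float32
#   _validate_params(q, k, v, k_cache, v_cache, 0, 1, None, 0.0, False, None)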


@impl(custom_ops_lib, "sdpa_with_kv_cache", "Meta")
def sdpa_with_kv_cache_meta(
    query,
    key,
    value,
    key_cache,
    value_cache,
    start_pos,
    seq_len,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    _validate_params(
        query,
        key,
        value,
        key_cache,
        value_cache,
        start_pos,
        seq_len,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
    )

    return torch.empty_like(query)


@impl(custom_ops_lib, "fast_hadamard_transform", "Meta")
def fast_hadamard_transform_meta(mat):
    # assert mat.stride(-1) == 1, "input matrix must be contiguous in the last dimension!"
    # assert mat.shape[-1] == 128 or mat.shape[-1] == 14336, "unexpected input size for llama3 demo!"
    # assert mat.is_contiguous(), "input matrix must be contiguous currently!"
    return torch.empty_like(mat)


@impl(custom_ops_lib, "custom_sdpa", "Meta")
def custom_sdpa(
    query,
    key_cache,
    value_cache,
    start_pos,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    seq_len = query.size(1)
    _validate_params(
        query,
        key_cache,
        value_cache,
        key_cache,
        value_cache,
        start_pos,
        seq_len,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
    )

    return torch.empty_like(query)


def _validate_update_cache_params(
    value,
    cache,
    start_pos,
):
    seq_len = value.size(1)
    assert (
        value.dim() == 4
    ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions."

    assert (
        value.dtype == cache.dtype
    ), f"Expected value and cache to be of the same type but got value type {value.dtype} and cache type {cache.dtype}"

    for i in [0, 2, 3]:
        assert value.size(i) == cache.size(
            i
        ), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"

    torch._check_is_size(start_pos)
    # Bound start_pos by the cache's sequence dimension; the real limit
    # cannot be plumbed here from the model config.
    torch._check(start_pos < cache.size(1))
    assert start_pos < cache.size(
        1
    ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"

    torch._check((start_pos + seq_len) < cache.size(1))
    assert (start_pos + seq_len) < cache.size(
        1
    ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"


@impl(custom_ops_lib, "update_quantized_cache", "Meta")
def update_quantized_cache_meta(
    value,
    cache,
    start_pos,
):
    _validate_update_cache_params(
        value,
        cache,
        start_pos,
    )

    # update_quantized_cache doesn't really return anything, but we don't
    # know a better workaround. Should we just return cache instead? That
    # might result in an extra memory allocation.
    return torch.empty((1,), dtype=value.dtype, device="meta")
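

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): exercise the meta kernels
    # directly with meta-device tensors. The [batch, seq_len, num_heads,
    # head_dim] layout and the concrete sizes below are assumptions for
    # demonstration, not requirements imposed by the ops.
    q = torch.empty(1, 1, 8, 64, device="meta")
    k = torch.empty(1, 1, 8, 64, device="meta")
    v = torch.empty(1, 1, 8, 64, device="meta")
    k_cache = torch.empty(1, 128, 8, 64, device="meta")
    v_cache = torch.empty(1, 128, 8, 64, device="meta")

    out = sdpa_with_kv_cache_meta(q, k, v, k_cache, v_cache, 0, 1)
    print(out.shape)  # torch.Size([1, 1, 8, 64])

    ht = fast_hadamard_transform_meta(torch.empty(4, 128, device="meta"))
    print(ht.shape)  # torch.Size([4, 128])

    # update_quantized_cache only requires that value and cache dtypes
    # match; int8 is an assumed example of a quantized cache dtype.
    upd = update_quantized_cache_meta(
        torch.empty(1, 1, 8, 64, dtype=torch.int8, device="meta"),
        torch.empty(1, 128, 8, 64, dtype=torch.int8, device="meta"),
        0,
    )
    print(upd.shape)  # torch.Size([1])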