# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.examples.models.llama.llama_transformer import RMSNorm


def replace_rms_norm_with_native_rms_norm(module: torch.nn.Module):
    """Swap each llama ``RMSNorm`` below ``module`` for ``torch.nn.RMSNorm``.

    The submodule tree is walked recursively. A child that is an instance of
    the llama-transformer ``RMSNorm`` is replaced on its parent by a
    ``torch.nn.RMSNorm`` constructed from the child's ``dim`` and ``eps``.
    The child's ``weight`` Parameter object is attached to the new module
    directly, so the scale is shared rather than copied. Replaced children
    are not descended into. ``module`` itself is never replaced; only its
    descendants are.

    Args:
        module: Root of the tree to rewrite. It is mutated in place.

    Returns:
        The same ``module`` object that was passed in.
    """
    # Snapshot the (name, child) pairs. The loop rebinds entries on
    # ``module``, so this makes it explicit that iteration is not
    # affected by those rebinds.
    for attr_name, submodule in list(module.named_children()):
        if not isinstance(submodule, RMSNorm):
            # Not a norm layer: rewrite its own descendants instead.
            replace_rms_norm_with_native_rms_norm(submodule)
            continue

        native_norm = torch.nn.RMSNorm(submodule.dim, eps=submodule.eps)
        # Reuse the existing Parameter so the learned scale carries over.
        native_norm.weight = submodule.weight
        setattr(module, attr_name, native_norm)

    return module