Llama2 LlamaRMSNorm Explained – LLM Tutorial

February 27, 2024

Llama2 uses the LlamaRMSNorm module to save training time. It is defined as follows:

import torch
import torch.nn as nn

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # mean of squares over the hidden dimension, computed in float32 for numerical stability
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        # rescale by the reciprocal root mean square (no mean subtraction, unlike LayerNorm)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
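
A minimal usage sketch of this module (the hidden size of 8 and the input shape below are illustrative, not the real Llama2 configuration):

norm = LlamaRMSNorm(hidden_size=8)
x = torch.randn(2, 4, 8)   # (batch, sequence, hidden)
y = norm(x)
print(y.shape)             # torch.Size([2, 4, 8])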

It was introduced in the paper Root Mean Square Layer Normalization.

RMSNorm can be viewed as:

\bar{a}_i = \frac{a_i}{\mathrm{RMS}(\mathbf{a})} g_i, \quad \mathrm{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} a_i^2}

where g is the learnable gain (self.weight in the code) and n is the hidden size.
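
As a sketch, the formula can be reproduced with plain tensor operations and checked against the module above (the tensor shape and tolerance are illustrative; eps is added inside the square root just as in the module, and g_i is 1 at initialization):

x = torch.randn(2, 4, 8)
rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)   # RMS(a), with eps for stability
manual = x / rms                                           # a_i / RMS(a), g_i = 1 at init
module = LlamaRMSNorm(hidden_size=8)(x)
print(torch.allclose(manual, module, atol=1e-5))           # True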

From the paper, we can find that RMSNorm achieves comparable performance to LayerNorm but reduces the running time by 7%∼64% on different models.
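
The savings come from RMSNorm dropping LayerNorm's mean subtraction and bias term, so it only rescales the input. A small sketch comparing the two (hidden size and shapes are illustrative): with both modules at initialization (weight 1, bias 0), they coincide on zero-mean inputs because both then divide by the same root mean square.

layer_norm = nn.LayerNorm(8, eps=1e-6)            # (x - mean) / sqrt(var + eps) * weight + bias
rms_norm = LlamaRMSNorm(hidden_size=8, eps=1e-6)  # x / sqrt(mean(x^2) + eps) * weight
x = torch.randn(2, 4, 8)
x_centered = x - x.mean(-1, keepdim=True)
print(torch.allclose(layer_norm(x_centered), rms_norm(x_centered), atol=1e-5))  # True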

To understand LayerNorm, you can read this tutorial:

An Explain to Layer Normalization in Neural Networks – Machine Learning Tutorial