Llama 2 uses the LlamaRMSNorm module in place of LayerNorm to save training time. In the Hugging Face transformers implementation, it is defined as:
```python
import torch
import torch.nn as nn


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        # learnable per-channel gain, initialized to 1
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # mean of squares over the last (hidden) dimension, computed in float32
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        # divide by the root mean square: x / sqrt(mean(x^2) + eps)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        # scale each channel by the learned gain
        return self.weight * hidden_states
```
RMSNorm was introduced in the paper Root Mean Square Layer Normalization.
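RMSNorm can be viewed as follows. For an input vector $\mathbf{a} = (a_1, \dots, a_n)$ and a learnable gain $g_i$:

$$\bar{a}_i = \frac{a_i}{\mathrm{RMS}(\mathbf{a})}\, g_i, \qquad \mathrm{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} a_i^2}$$

In LlamaRMSNorm, $g$ corresponds to self.weight, and variance_epsilon is added inside the square root for numerical stability.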
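To make the normalization concrete, here is a minimal pure-Python sketch for a single vector: divide each element by the root mean square of the vector, RMS(a) = sqrt(mean(a_i^2)), then multiply by a gain g_i. The function name rms_norm and the choice of plain lists instead of tensors are assumptions for illustration; eps is placed inside the square root, mirroring variance_epsilon in LlamaRMSNorm.

```python
import math


def rms_norm(a, g, eps=0.0):
    """Hypothetical helper: RMSNorm of one vector, a_i / RMS(a) * g_i."""
    n = len(a)
    # RMS(a) = sqrt(mean of squares); eps guards against division by zero
    rms = math.sqrt(sum(x * x for x in a) / n + eps)
    # normalize each element, then apply its per-element gain
    return [x / rms * gi for x, gi in zip(a, g)]


out = rms_norm([3.0, 4.0], g=[1.0, 1.0])
print([round(x, 4) for x in out])  # [0.8485, 1.1314]
```

With a gain of 1, the output always has an RMS of 1: here (0.8485² + 1.1314²) / 2 ≈ 1.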
According to the paper, RMSNorm achieves performance comparable to LayerNorm while reducing running time by 7% to 64% on different models.
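A key difference behind this saving is that LayerNorm first re-centers the input by subtracting its mean and then re-scales it by the standard deviation, while RMSNorm skips the mean computation and only re-scales by the RMS. The pure-Python sketch below puts the two side by side; gains and biases are omitted (set to 1 and 0) as a simplifying assumption, and the function names are hypothetical. It only illustrates the arithmetic and does not measure speed.

```python
import math


def layer_norm(a, eps=0.0):
    # LayerNorm (gain 1, bias 0): re-center by the mean, then re-scale by the std
    n = len(a)
    mean = sum(a) / n
    var = sum((x - mean) ** 2 for x in a) / n
    return [(x - mean) / math.sqrt(var + eps) for x in a]


def rms_norm(a, eps=0.0):
    # RMSNorm (gain 1): no mean subtraction, only re-scale by the RMS
    n = len(a)
    rms = math.sqrt(sum(x * x for x in a) / n + eps)
    return [x / rms for x in a]


shifted = [1.0, 2.0, 3.0]
print([round(x, 4) for x in layer_norm(shifted)])  # [-1.2247, 0.0, 1.2247]
print([round(x, 4) for x in rms_norm(shifted)])    # [0.4629, 0.9258, 1.3887]
```

For an input whose mean is already zero, such as [-1.0, 0.0, 1.0], the two functions return the same values, because the mean subtraction then has no effect and the standard deviation equals the RMS.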
To understand LayerNorm, you can read this tutorial:
An Explain to Layer Normalization in Neural Networks – Machine Learning Tutorial