diff options
Diffstat (limited to 'llama')
-rw-r--r-- | llama/__init__.py | 18 | ||||
-rw-r--r-- | llama/generate.py | 322 | ||||
-rw-r--r-- | llama/llm.py | 74 | ||||
-rw-r--r-- | llama/model.py | 722 | ||||
-rw-r--r-- | llama/tokenizer.py | 89 | ||||
-rw-r--r-- | llama/utils.py | 153 |
6 files changed, 1378 insertions, 0 deletions
diff --git a/llama/__init__.py b/llama/__init__.py new file mode 100644 index 0000000..2b9a9ef --- /dev/null +++ b/llama/__init__.py @@ -0,0 +1,18 @@ +""" +Llama2 model, loading infrastructure, and sampling helpers +""" + +from .model import Llama + +from .generate import ( + sample_sequence, + sample_batched_sequence, + generate_token_sequence, + generate_batched_token_sequence, + token_probabilities, + batched_token_probabilities +) + +from .utils import ( + load_llama_from_checkpoint +) diff --git a/llama/generate.py b/llama/generate.py new file mode 100644 index 0000000..e735ca4 --- /dev/null +++ b/llama/generate.py @@ -0,0 +1,322 @@ +""" +This module provides an interface for generating and sampling token sequences from a language model. +It allows for the controlled generation of text by specifying parameters such as temperature, top-k, +and top-p, which influence the randomness and quality of the generated sequences. +""" + +# pylint: disable=locally-disabled, R0913, R0914 + +from collections import OrderedDict +from typing import List, Optional, Generator, cast + +import torch + +from .llm import LLM + +TokenList = OrderedDict[int, float] + +def _sample_internal(llm: LLM, context: torch.Tensor) -> torch.Tensor: + """ + Sample a tensor of logits from the language model (LLM) based on the input context. + """ + + batch_size, seq_len = context.shape + + assert seq_len <= llm.context_length + assert batch_size <= llm.max_batch_size + + with torch.inference_mode(): + return llm(context) + +def _load_context(tokens: List[List[int]], pad_id: int, + pad_to_length: Optional[int] = None) -> torch.Tensor: + """ + Load a batch of token lists into a padded tensor suitable for input to a language model. + """ + batch_size = len(tokens) + + max_token_len = max((len(tok) for tok in tokens)) + + pad_to_length = max_token_len if pad_to_length is None else pad_to_length + + context = torch.full( + (batch_size, pad_to_length), pad_id, dtype=torch.long + ) + + for dim, toks in enumerate(tokens): + context[dim, :len(toks)] = torch.tensor(toks, dtype=torch.long) + + return context + +def batched_token_probabilities(llm: LLM, + tokens: List[List[int]], + temperature: float = 1.0) -> List[TokenList]: + """ + Calculate the probabilities of the next token sequence across a batch of sequences. + + Args: + - llm (LLM): An instance of the language model. + - tokens (List[List[int]]): A list of tokenized input sequences. + - temperature (float): A temperature parameter to scale the logits before + applying softmax. Default is 1.0. + + Returns: + - List[TokenList]: A list of ordered dictionaries where each dictionary maps + token ids to their corresponding probabilities for each + sequence in the batch. + """ + + context = _load_context(tokens, llm.padding_idx) + + token_logprobs = _sample_internal(llm, context) + + token_probs = torch.softmax(token_logprobs / temperature, dim=-1) + + samples: List[TokenList] = [OrderedDict() for _ in range(len(tokens))] + for i, p in enumerate(token_probs): + for _id in torch.argsort(p, descending=True): + samples[i][int(_id)] = float(p[_id]) + + return samples + +def token_probabilities(llm: LLM, tokens: List[int]) -> TokenList: + """ + Calculate the probabilities of the next token sequence. + See batched_token_probabilities. + """ + + return batched_token_probabilities(llm, [ tokens ])[0] + +def sample_batched_token( + token_logprobs: torch.Tensor, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: float = 1.0, + sample_eps: float = 1e-6) -> torch.Tensor: + """ + Sample a token from a batch of token logits with optional top-k and top-p filtering. + + + Args: + - token_logprobs (torch.Tensor): A tensor of token log probabilities. + - temperature (float): A scaling factor for logits before sampling. Default + is 1.0. + - top_k (Optional[int]): If set, the sampling is restricted to the top k + tokens. Default is None (no restriction). + - top_p (float): If set, the sampling is restricted to the smallest set + of tokens with cumulative probability exceeding top_p. + Default is 1.0 (no restriction). + - sample_eps (float): An epsilon value to avoid precision errors during + cumulative probability calculation. Default is 1e-6. + + Returns: + - torch.Tensor: A tensor of sampled token ids for each item in the batch. + + Implements both top-k sampling, top-p sampling, and beam search. + + See: + - https://arxiv.org/pdf/1805.04833.pdf for top-k sampling + - https://arxiv.org/pdf/1904.09751.pdf for top-p sampling + """ + + batch_size = token_logprobs.shape[0] + + token_probs = torch.softmax(token_logprobs / temperature, dim=-1) + + + selected_tokens = torch.zeros(batch_size, dtype=torch.long) + + sorted_tokens = torch.argsort(token_probs, descending=True) + sorted_probs = torch.gather(token_probs, 1, sorted_tokens) + nucleus_mask = sorted_probs.cumsum(dim=-1) < top_p + sample_eps + nucleus_mask[:,0] = True + + for i, (tokens, mask, probs) in enumerate(zip(sorted_tokens, nucleus_mask, sorted_probs)): + nucleus = tokens[mask] + p = probs[mask] + + if top_k is not None and len(nucleus) > top_k: + nucleus = nucleus[:top_k] + p = p[:top_k] + + p /= p.sum(axis=0) + + token = nucleus[torch.multinomial(p, 1)] + + selected_tokens[i] = token + + return selected_tokens + +def generate_batched_token_sequence(llm: LLM, + prompts: List[List[int]], + max_generation_length: Optional[int] = None, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: float = 1.0) -> Generator[List[Optional[int]], None, None]: + """ + Generate a sequence of tokens for each prompt across a sequence of batches. + + Args: + - llm (LLM): An instance of the language model. + - prompts (List[List[int]]): A list of tokenized input sequences (prompts). + - max_generation_length (Optional[int]): The maximum number of tokens to + generate for each prompt. If None, generate up to the model's maximum + context length. Default is None. + - temperature (float): A scaling factor for logits before sampling. Default + is 1.0. + - top_k (Optional[int]): If set, restricts sampling to the top k most + likely tokens. Default is None (no restriction). + - top_p (float): If set, restricts sampling to a subset of tokens with a + cumulative probability greater than top_p. Default is 1.0 + (no restriction). + + Yields: + - Generator[List[Optional[int]], None, None]: A generator that yields lists + of token ids, where each list corresponds to one prompt in the batch. + Yields none if a token was not generated during an iteration of inference. + + Raises: + - AssertionError: If batch size exceeds the maximum allowed by the LLM, or + if the requested generation length is too long. + """ + + batch_size = len(prompts) + assert batch_size <= llm.max_batch_size, \ + "Batch size exceeded the maximum batch size of the LLM" + + prompt_lens = torch.tensor([len(p) for p in prompts], dtype=torch.long) + max_prompt_len = max(prompt_lens) + + remaining_context = llm.context_length - max_prompt_len + if max_generation_length is None: + max_generation_length = remaining_context + else: + assert max_generation_length <= remaining_context, \ + "Cannot generate more tokens than exist in the context" + + eos = torch.zeros(batch_size, dtype=torch.long) + last_pos = 0 + + end_pos = max_prompt_len + max_generation_length + context = _load_context(prompts, llm.padding_idx, pad_to_length=end_pos) + + start_pos = max(prompt_lens) + + for pos in range(start_pos, end_pos): + log_probs = llm(context[:, last_pos:pos], last_pos) + + sampled = sample_batched_token( + log_probs, + temperature=temperature, + top_k=top_k, + top_p=top_p + ) + + in_prompt = pos < prompt_lens + should_replace_mask = (eos == 0) & (~in_prompt) + + yield [int(sampled[i]) if should_replace_mask[i] else None for i in range(batch_size)] + + context[should_replace_mask, pos] = sampled[should_replace_mask] + eos[(eos > 0) & (sampled == llm.eos_token)] = pos + 1 + last_pos = pos + + if (eos > 0).all(): + break + +def generate_token_sequence(llm: LLM, + prompt: List[int], + max_generation_length: Optional[int] = None, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: float = 1.0) -> Generator[int, None, None]: + """ + Generate a sequence of tokens for a single prompt. + See generate_batched_token_sequence. + """ + + for tokens in generate_batched_token_sequence(llm, + [ prompt ], + max_generation_length=max_generation_length, + temperature=temperature, + top_k=top_k, + top_p=top_p): + yield cast(int, tokens[0]) + +def sample_batched_sequence(llm: LLM, + prompts: List[List[int]], + max_generation_length: Optional[int] = None, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: float = 1.0, + include_prompt: bool = False) -> List[List[int]]: + """ + Generate and sample a sequence of tokens for each input sequence in a batch. + + Args: + - llm (LLM): An instance of the language model. + - prompts (List[List[int]]): A list of tokenized input sequences (prompts). + - max_generation_length (Optional[int]): The maximum number of tokens to + generate for each prompt. Defaults to None, which allows the generation + up to the model's maximum context length. + - temperature (float): A scaling factor for logits before sampling, + affecting the randomness of the output. Default is 1.0, with lower values + leading to less random samples. + - top_k (Optional[int]): Limits the sampling pool to the top k tokens + according to the probability distribution. Default is None, indicating no + limit. + - top_p (float): The cumulative probability threshold for nucleus sampling; + allows sampling from a set of high-probability tokens whose cumulative + probability exceeds this threshold. Default is 1.0, indicating no limit. + - include_prompt (bool): If True, includes the input prompt at the beginning + of the generated sequence. Default is False. + + Returns: + - List[List[int]]: A list of lists containing the sampled token sequences + for each input prompt. The sequences include the generated tokens and, + if include_prompt is True, the original input prompt tokens. + """ + + sampled_seqs: List[List[int]] = [[] for _ in range(len(prompts))] + + if include_prompt: + for i, prompt in enumerate(prompts): + sampled_seqs[i].extend(prompt) + + for generated_tokens in generate_batched_token_sequence(llm, + prompts, + max_generation_length, + temperature, + top_k, + top_p): + for i, token in enumerate(generated_tokens): + if token is not None: + sampled_seqs[i].append(token) + + return sampled_seqs + +def sample_sequence(llm: LLM, + prompts: List[int], + max_generation_length: Optional[int] = None, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: float = 1.0, + include_prompt: bool = False) -> List[int]: + """ + Generate and sample a sequence of tokens for a single input sequence. + See sample_batched_sequence for reference. + """ + + return sample_batched_sequence( + llm, [prompts], max_generation_length, temperature, + top_k, top_p, include_prompt + )[0] + +__all__ = [ + 'token_probabilities', + 'sample_batched_token', + 'sample_sequence', + 'sample_batched_sequence', + 'generate_token_sequence', + 'generate_batched_token_sequence', +] diff --git a/llama/llm.py b/llama/llm.py new file mode 100644 index 0000000..7192b31 --- /dev/null +++ b/llama/llm.py @@ -0,0 +1,74 @@ +""" +LLM provides a generalized interface for autoregressive +next-word prediction models. The class can be utilized for tasks such as text +sampling and probability prediction over a vocabulary. +""" + +import torch +from torch import nn + +class LLM(nn.Module): + """ + LLM provides a generalized interface for autoregressive + next-word prediction models. The class can be utilized for tasks such as text + sampling and probability prediction over a vocabulary. + + Attributes: + context_length (int): Length of the context window for the + autoregressive model. Default is -1, which + indicates that this needs to be set. + + max_batch_size (int): The maximum size of a batch that can be processed. + Default is -1, which indicates that this needs to + be set. + + vocab_size (int): The size of the vocabulary used in the model. + Default is -1, which indicates that this needs to + be set. + + padding_idx (int): The index used for padding in mixed-length batches. + Default is -1, which indicates that this needs to be + set. + + eos_token (int): Token index that signifies the end of a sequence during + auto-regressive generation. Default is -1, which + indicates that this needs to be set. + """ + + context_length = -1 + max_batch_size = -1 + vocab_size = -1 + padding_idx = -1 + eos_token = -1 + + def forward(self, context: torch.Tensor, cur_pos: int = 0) -> torch.Tensor: + """ + Computes the log probabilities of the next token given a sequence of + tokens as context. + + Args: + context (torch.Tensor): A tensor of shape (batch_size, context_length) + containing token ids. These tokens serve as the + context for predicting the next token. + + cur_pos (int, optional): The position at which to start the + prediction. If cur_pos is not zero, + the internal cache (if available) will + be used to speed up predictions. + Defaults to 0. + + Returns: + torch.Tensor: A tensor of shape (batch_size, vocab_size) containing + the log probabilities of the next token given the + context. + + Examples: + # Predict the next token for a sequence [1, 2, 3] + log_probs = llm(torch.tensor([[1, 2, 3]], dtype=torch.long), 0) + + # Predict the next token for a sequence [1, 2, 3, 4, 5] using the + # cache starting at position 3 + log_probs = llm(torch.tensor([[4, 5]], dtype=torch.long), 3) + """ + + raise NotImplementedError() diff --git a/llama/model.py b/llama/model.py new file mode 100644 index 0000000..6417c83 --- /dev/null +++ b/llama/model.py @@ -0,0 +1,722 @@ +""" +The core model for the Llama family of LLMs +""" + +import math +import copy + +from dataclasses import dataclass +from typing import Optional, Tuple, Dict + +import torch +import torch.nn.functional as F + +from torch import nn +from .llm import LLM + +# pylint: disable=locally-disabled, R0902, R0913 + +def _round_up_to_multiple(n: int, m: int) -> int: + """ + Round n up to an integer multiple of m + """ + return math.ceil(n / m) * m + +@dataclass +class LlamaArgs: + """ + Arguments class for configuring a LLAMA model. + + Attributes: + dim (int): The model dimension, typically referred to as d_model in + "Attention is All You Need" paper. + n_layers (int): The number of layers in the model. + n_heads (int): The number of attention heads in each layer. + vocab_size (int): The size of the model's vocabulary. + multiple_of (int): Ensures the feed-forward network dimension (d_ff) + is a multiple of this factor. + norm_eps (float): The epsilon value for RMS normalization, avoiding + division by zero. + max_ctx_len (int, optional): The maximum context length the model can + handle. Defaults to 2048. + max_batch_size (int, optional): The maximum batch size supported by the + model's cache. Defaults to 1. + n_groups (Optional[int], optional): The number of key-value groups in + grouped query-attention (GQA), if applicable. Defaults to None. + padding_idx (int): The index used for padding in embeddings. Defaults + to -1. + """ + + dim: int + n_layers: int + n_heads: int + vocab_size: int + multiple_of: int + norm_eps: float + max_ctx_len: int = 2048 + max_batch_size: int = 1 + n_groups: Optional[int] = None + padding_idx: int = -1 + +class RMSNorm(nn.Module): + """ + Implements an unbiased Root Mean Square (RMS) Layer Normalization. + + Reference: + See the paper "Root Mean Square Layer Normalization" at + https://arxiv.org/pdf/1910.07467.pdf for more details. + + Attributes: + eps (float): A small epsilon value added to the denominator for + numerical stability. + gain (nn.Parameter): A learnable gain parameter applied after + normalization. + """ + + def __init__(self, d: int, eps: float = 1e-6, dtype: torch.dtype = torch.float): + """ + Initializes the RMSNorm layer. + + Args: + d (int): The dimensionality of the input feature space. + eps (float, optional): A small epsilon value to add to the + denominator for numerical stability. Defaults to 1e-6. + dtype (torch.dtype, optional): The data type of the learnable gain + parameter. Defaults to torch.float. + """ + super().__init__() + self.eps = eps + self.gain = nn.Parameter(torch.ones(d, dtype=dtype)) + + + def forward(self, a: torch.Tensor) -> torch.Tensor: + """ + Applies RMS normalization to the input tensor. + + Args: + a (torch.Tensor): The input tensor to be normalized. + + Returns: + torch.Tensor: The normalized tensor with the same shape as the input. + """ + + inverse_rms = torch.rsqrt(self.eps + torch.mean(a ** 2, dim=-1, keepdim=True)) + + return a * inverse_rms * self.gain + + +class SwiGLU(nn.Module): + """ + Implements the SwiGLU variant of the Gated Linear Unit (GLU) as part of the + FFN layer of a transformer. SwiGLU is a variant of the Gated Linear Unit + where the gating mechanism is controlled by a Swish activation function. + + Reference: + The SwiGLU activation function is detailed in the paper "GLU Variants Improve Transformer" + which can be accessed at https://arxiv.org/pdf/2002.05202.pdf. + """ + + def __init__(self, dim : int, dim_ff: int, dtype: torch.dtype = torch.float): + """ + Initializes the SwiGLU module. + + Arguments: + dim (int): The dimensionality of the input and output tensors. + dim_ff (int): The reduced dimensionality of the hidden layer. + dtype (torch.dtype, optional): The data type for the weights of + the linear transformations. Defaults to torch.float. + """ + super().__init__() + + self.w = nn.Linear(dim, dim_ff, bias=False, dtype=dtype) + self.v = nn.Linear(dim, dim_ff, bias=False, dtype=dtype) + self.w2 = nn.Linear(dim_ff, dim, bias=False, dtype=dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Applies the SwiGLU feed-forward layer + + Arguments: + x (torch.Tensor): The input tensor to the SwiGLU module. + + Returns: + torch.Tensor: The output tensor after applying the SwiGLU operation. + """ + return self.w2(F.silu(self.w(x)) * self.v(x)) + +class RotaryEmbeddings(nn.Module): + """ + Implementation of rotary position embeddings. + + Rotary embeddings are a mechanism for injecting positional information into + transformer models. These embeddings apply a rotation to the key and value + vectors in the attention mechanism based on their position, with different + "dampening" factors applied based on the relative distance between two tokens. + + Args: + - dim (int): The dimension of the embeddings. + - max_ctx_len (int): The maximum length of the context for which to compute + the embeddings. + - theta (float, optional): The frequency parameter for computing the rotary + embeddings. Defaults to 10000.0. + + Raises: + AssertionError: If the dimension is not even. + + References: + - RoFormer paper: https://arxiv.org/pdf/2104.09864.pdf + """ + + embedding_cache: Dict[int, torch.Tensor] = {} + + def __init__(self, dim: int, max_ctx_len: int, theta: float = 10000.0): + """ + Initialize the RotaryEmbeddings module. + + Args: + - dim (int): The dimension of the embeddings. + - max_ctx_len (int): The maximum length of the context for which + to compute the embeddings. + - theta (float, optional): The frequency parameter for computing + the rotary embeddings. Defaults to 10000.0. + + Raises: + AssertionError: If the dimension is not even. + """ + super().__init__() + + assert dim % 2 == 0, "Model dimension should be a multiple of two" + + self.n_coord_pairs = dim // 2 + self.rots = RotaryEmbeddings.get_embeddings(dim, max_ctx_len, theta) + + @staticmethod + def compute_angles(dim: int, max_ctx_len: int, theta: float) -> torch.Tensor: + """ + Compute the rotation angles for the embeddings. + + Arguments: + dim (int): The dimension of the embeddings. + max_ctx_len (int): The maximum context length. + theta (float): The frequency parameter for the embeddings. + + Returns: + torch.Tensor: A tensor of shape (max_ctx_len, dim // 2) containing the + rotation angles. + """ + + freqs = theta ** (-torch.arange(0, dim, 2, dtype=torch.float) / dim) + + m = torch.arange(max_ctx_len) + + angles = torch.outer(m, freqs) + + return torch.polar(torch.ones((max_ctx_len, dim // 2)), angles) + + @staticmethod + def get_embeddings(dim: int, max_ctx_len: int, theta: float) -> torch.Tensor: + """ + Retrieve or compute and cache the rotary embeddings. + + Args: + - dim (int): The dimension of the embeddings. + - max_ctx_len (int): The maximum context length. + - theta (float): The frequency parameter for the embeddings. + + Returns: + - torch.Tensor: A tensor containing the precomputed embeddings. + """ + + cache = RotaryEmbeddings.embedding_cache + + if dim not in cache: + + cache[dim] = \ + RotaryEmbeddings.compute_angles(dim, max_ctx_len, theta) + + return cache[dim] + + def forward(self, x: torch.Tensor, cur_pos: int = 0) -> torch.Tensor: + """ + Apply the rotary embeddings to the input tensor. + + Arguments: + - x (torch.Tensor): A tensor of shape (batch_size, ctx_len, ..., dim) + representing input features. + - cur_pos (int, optional): The current position index from which to + apply rotations. Defaults to 0. + + Returns: + - torch.Tensor: The rotated tensor with the same shape as the input. + """ + + _batch_size, ctx_len, *dup_dims, dim = x.shape + + rotated = x.view(*x.shape[:-1], self.n_coord_pairs, 2) + rotated = torch.view_as_complex(rotated.float()) + + broad_shape = [1, ctx_len] + [1] * len(dup_dims) + [ dim // 2 ] + + rotated *= self.rots[cur_pos : cur_pos + ctx_len].view(*broad_shape) + + rotated = torch.view_as_real(rotated) + rotated = rotated.view(*x.shape[:-1], dim).type_as(x) + + return rotated + +def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + mask: Optional[torch.Tensor] = None) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the scaled dot product attention. + + This function takes as input the query (Q), key (K), value (V) tensors, + and an optional mask, and returns the attention output and attention + weights. + + Arguments: + - q (torch.Tensor): The query tensor of shape (..., seq_len, d_k). + - k (torch.Tensor): The key tensor of shape (..., seq_len, d_k). + - v (torch.Tensor): The value tensor of shape (..., seq_len, d_v). + - mask (Optional[torch.Tensor]): An optional mask tensor to apply to + the scores before softmax. + + Returns: + - Tuple[torch.Tensor, torch.Tensor]: A tuple consisting of the attention + output tensor and the attention weights tensor. + + References: + - "Attention Is All You Need": https://arxiv.org/pdf/1706.03762.pdf + """ + + d_k = q.size(-1) + + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) + + if mask is not None: + scores = scores.masked_fill(mask == 0, float("-inf")) + + attn = F.softmax(scores, dim=-1) + + return torch.matmul(attn, v), attn + +class LinearCache: + """ + A simple linear-cache. This is used to cache the attention + keys and values. + """ + + def __init__(self, max_batch_size: int, max_context_len: int, + tensor_dims: Tuple, dtype: torch.dtype = torch.float): + """Initializes the LinearCache with given dimensions and data type.""" + + self.max_batch_size = max_batch_size + self.max_context_len = max_context_len + self.cache = torch.zeros( + (max_batch_size, max_context_len, *tensor_dims), + dtype=dtype, + ) + self.cached_batch_size = 0 + + def get(self, pos: int) -> torch.Tensor: + """Retrieves the cached values up to a given sequence position.""" + return self.cache[:self.cached_batch_size, :pos] + + def set(self, current_pos: int, seq: torch.Tensor) -> None: + """Updates the cache with new sequences at the specified position.""" + + batch_size, ctx_len, *_ = seq.shape + + self.cache[:batch_size, current_pos:current_pos+ctx_len] = seq + + self.cached_batch_size = batch_size + +class GQA(nn.Module): + """ + Group-Query Attention (GQA) module for transformer architectures. + + References: + - See "GQA: Training Generalized Multi-Query Transformer Models from + Multi-Head Checkpoints" at https://arxiv.org/pdf/2305.13245.pdf + """ + + def __init__(self, dim: int, n_heads: int, + n_groups: Optional[int] = None, + query_embedding: Optional[nn.Module] = None, + key_embedding: Optional[nn.Module] = None, + apply_decoder_mask: bool = False, + kv_caches: Optional[Tuple[LinearCache, LinearCache]] = None, + dtype: torch.dtype = torch.float): + """ + Initializes the Group-Query Attention (GQA) module. + + Parameters: + dim (int): The dimensionality of the input features and the last dimension of + the output tensor. + n_heads (int): The number of attention heads to use. + n_groups (Optional[int]): The number of groups to divide the attention heads + into. If not specified, defaults to the number of heads. + Must divide `n_heads` evenly. + query_embedding (Optional[nn.Module]): An optional module to embed the query + vectors, e.g., a positional encoding module. + key_embedding (Optional[nn.Module]): An optional module to embed the key vectors, + similar to `query_embedding`. + apply_decoder_mask (bool): Whether to apply a causal mask to the attention mechanism, + useful for decoder self-attention. + kv_caches (Optional[Tuple[LinearCache, LinearCache]]): Optional tuple of + `LinearCache` instances for + caching key and value projections + in an autoregressive setting. + dtype (torch.dtype): The data type of the module's parameters, e.g., `torch.float32`. + The cache tensors should also use this data type. + """ + + n_groups = n_groups if n_groups else n_heads + + assert dim % n_heads == 0, \ + "Model dimension should be a multiple of n_heads" + assert n_heads % n_groups == 0, \ + "n_heads should be a multiple of n_groups" + + super().__init__() + + head_dim = dim // n_heads + + self.n_heads = n_heads + self.n_groups = n_groups + self.head_dim = head_dim + self.apply_decoder_mask = apply_decoder_mask + + self.query_embedding = query_embedding + self.key_embedding = key_embedding + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + dtype=dtype, + ) + + self.wk = nn.Linear( + dim, + n_groups * head_dim, + bias=False, + dtype=dtype, + ) + + self.wv = nn.Linear( + dim, + n_groups * head_dim, + bias=False, + dtype=dtype, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + dtype=dtype, + ) + + if kv_caches is not None: + self.key_cache = kv_caches[0] + self.value_cache = kv_caches[1] + self.has_cache = True + else: + self.has_cache = False + + def forward(self, x: torch.Tensor, cur_pos: int): + """ + Processes the input tensor with Group-Query Attention. + + Arguments: + - x (torch.Tensor): The input tensor of shape + (batch_size, context_length, dim). + - cur_pos (int): The current position in the sequence for which + to compute attention. This is relevant when using key-value caches, + as it determines the part of the cache to update and utilize. + + Returns: + - torch.Tensor: The output tensor after applying Group-Query Attention. + """ + + batch_size, ctx_len, dim = x.shape + + # Perform key, query, and value projections + + # wq(x) performs all n_heads projections at once, then the result + # is reshaped such that the first head_dim results are part of the first + # head, the second head_dim results are part of the second head, and so + # on. + q = self.wq(x).view(batch_size, ctx_len, self.n_heads, self.head_dim) + k = self.wk(x).view(batch_size, ctx_len, self.n_groups, self.head_dim) + v = self.wv(x).view(batch_size, ctx_len, self.n_groups, self.head_dim) + + # Apply embeddings to the key and query matrices + if self.query_embedding: + q = self.query_embedding(q, cur_pos) + + if self.key_embedding: + k = self.key_embedding(k, cur_pos) + + if self.has_cache: + # Add the new embeddings to the cache + self.key_cache.set(cur_pos, k) + self.value_cache.set(cur_pos, v) + + # Get all the previous embedding from the cache. + + # Note if cur_pos != 0, ctx_len is the length of + # the new sequence. In reality, the whole sequence + # is cur_pos + ctx_len and both cached results will + # be of size (batch_size, ctx_len + cur_pos, n_groups, head_dim) + k = self.key_cache.get(cur_pos + ctx_len) + v = self.value_cache.get(cur_pos + ctx_len) + + # Avoid copy if multi-head attention MHA is used. This is true in the + # 7B and 13B models. + if self.n_groups != self.n_heads: + + repeats = self.n_heads // self.n_groups + + # Duplicate grouped attention heads: + + # From: { G_0, G_1, ... G_{k - 1} } + # To: { G_0, G_0, ... G_0, G_1, ..., G_{k - 1}, G_{k - 1}, ..., G_{k - 1} + k = torch.repeat_interleave(k, dim=2, repeats=repeats) + v = torch.repeat_interleave(v, dim=2, repeats=repeats) + + # Transpose to parallelize attention across heads during batched-matrix + # multiplication + q = q.transpose(1, 2) # (batch_size, n_heads, ctx_len, head_dim) + k = k.transpose(1, 2) # (batch_size, n_heads, ctx_len, head_dim) + v = v.transpose(1, 2) # (batch_size, n_heads, ctx_len, head_dim) + + if self.apply_decoder_mask: + # Construct attention mask + + # In the decoder architecture, the mask is a lower triangular matrix that prevents + # previous tokens from attending to subsequent ones. More concretely for attention + # scores (i, j), token i cannot attend to token j if j > i. + + # When key-value caching is enabled, we are only computing the attention scores + # for the new sequence. Thus, the matrix of scores is of size (ctx_len, total_len) + # and the only masked entries are (i, j) for j > cached_len + i since row i really + # represents token cached_len + i. + mask = torch.hstack([ + torch.ones((ctx_len, cur_pos)), + torch.tril(torch.ones((ctx_len, ctx_len))), + ]) + else: + mask = None + + # Perform attention + x, _ = attention(q, k, v, mask) + + # Concatenate heads + x = x.transpose(1, 2) # (batch_size, ctx_len, n_heads, head_dim) + x = x.reshape((batch_size, ctx_len, dim)) + + # Final linear layer + x = self.wo(x) + + return x + +class LlamaTransformerLayer(nn.Module): + """ + This constitutes a single transformer block within Meta's Llama architecture. + + The transformer architecture combines group-query attention (GQA) and key-value caching. + + It also utilizes RMSNorm to decrease co-variance shifts during training and skip connections + which make training easier. + """ + + def __init__(self, dim: int, n_heads: int, n_groups: Optional[int], max_context_len: int, + max_batch_size: int, round_ff_to_multiple: int, eps: float = 1e-6): + """Initializes a layer of the Lamma transformer.""" + + super().__init__() + + head_dim = dim // n_heads + + self.query_embedding = RotaryEmbeddings(head_dim, max_context_len) + self.key_embedding = RotaryEmbeddings(head_dim, max_context_len) + + cache_size = n_groups if n_groups else n_heads + + self.key_cache = LinearCache( + max_batch_size, max_context_len, (cache_size, head_dim), dtype=torch.bfloat16 + ) + self.value_cache = LinearCache( + max_batch_size, max_context_len, (cache_size, head_dim), dtype=torch.bfloat16 + ) + + self.gqa = GQA( + dim, n_heads, n_groups, + query_embedding=self.query_embedding, + key_embedding=self.key_embedding, + kv_caches=(self.key_cache, self.value_cache), + dtype=torch.bfloat16, + apply_decoder_mask=True, + ) + + # It might have been better to specify the inner "hidden" feed-forward + # dimension directly as a hyper parameter. It seems that FAIR chose + # this odd ratio from the [SwiGLU paper](https://arxiv.org/pdf/2002.05202.pdf) + # directly. This seems slightly odd as this ratio was initially used only for + # the purposes of enabling a fair comparison across different feed-forward + # configurations. + dim_ff = _round_up_to_multiple(4 * int(2 * dim / 3), round_ff_to_multiple) + + self.feed_forward = SwiGLU(dim, dim_ff, dtype=torch.bfloat16) + + self.attention_norm = RMSNorm(dim, eps, dtype=torch.bfloat16) + self.forward_norm = RMSNorm(dim, eps, dtype=torch.bfloat16) + + def forward(self, x: torch.Tensor, cur_pos: int = 0) -> torch.Tensor: + """ + Takes as an input the input embeddings or previous decoder output + and produces the output of this decoder + """ + + # RMS Norm + x_norm = self.attention_norm(x) + # GQA with a skip connection + # See ResNet at https://arxiv.org/pdf/1512.03385.pdf for skip connections + h = x + self.gqa(x_norm, cur_pos=cur_pos) + # RMS Norm + h_norm = self.forward_norm(h) + # SwiGLU feed-forward with a skip connection + h = h + self.feed_forward(h_norm) + + return h + +class LlamaDecoderStack(nn.Module): + """ + The decoder stack is a stack of n_layers of decoders. + """ + + def __init__(self, args: LlamaArgs): + """Initializes the decoder stack""" + + super().__init__() + + layer = LlamaTransformerLayer( + args.dim, args.n_heads, args.n_groups, args.max_ctx_len, + args.max_batch_size, args.multiple_of, args.norm_eps + ) + + self.decoders = nn.ModuleList([ + copy.deepcopy(layer) for _ in range(args.n_layers) + ]) + + def forward(self, embedding: torch.Tensor, cur_pos: int = 0) -> torch.Tensor: + """Apply all encoders, obtaining the outputs of the last decoder""" + + h = embedding + + for decoder in self.decoders: + h = decoder(h, cur_pos) + + return h + +class LlamaEmbeddings(nn.Module): + """ + LlamaEmbeddings transform a tensor of token ids into embedding vectors of size dim + """ + + def __init__(self, vocab_size: int, dim: int, padding_idx: int): + """Initializes the LlamaEmbeddings""" + + super().__init__() + + self.vocab_size = vocab_size + self.dim = dim + self.padding_idx = padding_idx + + self.embedding = nn.Embedding(self.vocab_size, self.dim, dtype=torch.bfloat16) + + def forward(self, context: torch.Tensor) -> torch.Tensor: + """Retrieve the embeddings for a token sequence""" + + # The original Llama implementation employs parallel embeddings. This + # implicitly produces zero embeddings for padding_idx = -1. This behavior + # is seemingly undefined and relies on implementation details within + # the parallel embeddings. + + # Since nn.Embedding does not handle negative indices, we must manually + # zero out the padded parts of the context. + padding_mask = context == torch.tensor(self.padding_idx, dtype=torch.long) + context[padding_mask] = torch.tensor(0, dtype=torch.long) + + embeddings = self.embedding(context) + + embeddings[padding_mask] = torch.zeros((self.dim,), dtype=embeddings.dtype) + + return embeddings + +class Llama(LLM): + """An class representing the Llama family of LLMs""" + + def __init__(self, args : LlamaArgs): + """Initialize the Llama model""" + + super().__init__() + + self.context_length = args.max_ctx_len + self.max_batch_size = args.max_batch_size + + self.embeddings = LlamaEmbeddings( + args.vocab_size, args.dim, args.padding_idx + ) + + self.decoder_stack = LlamaDecoderStack(args) + + self.output_norm = RMSNorm( + args.dim, eps=args.norm_eps, dtype=torch.bfloat16 + ) + + self.vocab_map = nn.Linear( + args.dim, args.vocab_size, bias=False, dtype=torch.bfloat16 + ) + + def forward(self, context: torch.Tensor, cur_pos: int = 0) -> torch.Tensor: + """ + Computes the log probabilities of the next token given a sequence of + tokens as context. + + Args: + context (torch.Tensor): A tensor of shape (batch_size, context_length) + containing token ids. These tokens serve as the + context for predicting the next token. + + cur_pos (int, optional): The position at which to start the + prediction. If cur_pos is not zero, + the internal cache (if available) will + be used to speed up predictions. + Defaults to 0. + + Returns: + torch.Tensor: A tensor of shape (batch_size, vocab_size) containing + the log probabilities of the next token given the + context. + + Examples: + # Predict the next token for a sequence [1, 2, 3] + log_probs = llm(torch.tensor([[1, 2, 3]], dtype=torch.long), 0) + + # Predict the next token for a sequence [1, 2, 3, 4, 5] using the + # cache starting at position 3 + log_probs = llm(torch.tensor([[4, 5]], dtype=torch.long), 3) + """ + + embeddings = self.embeddings(context) # ( n_batches, n_embeddings, dim ) + + h = self.decoder_stack(embeddings, cur_pos) + + h = self.output_norm(h) + + vocab_logits = self.vocab_map(h) + + return vocab_logits[:,-1] diff --git a/llama/tokenizer.py b/llama/tokenizer.py new file mode 100644 index 0000000..937a0b8 --- /dev/null +++ b/llama/tokenizer.py @@ -0,0 +1,89 @@ +""" +Llama Tokenizer +=============== +This module contains the Tokenizer class that wraps the SentencePiece tokenizer. +""" + +from typing import List +from sentencepiece import SentencePieceProcessor # type: ignore + +class Tokenizer: + """ + Llama Tokenizer Class + --------------------- + This class provides a wrapper around the SentencePiece tokenizer. + It adds some utility functions for easier encoding and decoding. + + Attributes: + bos_id (int): The id representing the "beginning of sentence" token. + eos_id (int): The id representing the "end of sentence" token. + pad_id (int): The id representing the padding token. + vocab_size (int): The size of the vocabulary. + """ + + def __init__(self, model_path: str): + """ + Initialize the Tokenizer. + + Args: + model_path (str): The path to the SentencePiece model file. + + Returns: + None + """ + sp = SentencePieceProcessor(model_file=model_path) + + self.bos_id: int = sp.bos_id() + self.eos_id: int = sp.eos_id() + self.pad_id: int = sp.pad_id() + self.vocab_size: int = sp.vocab_size() + + self.sp = sp + + def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]: + """ + Encode a string as a sequence of token IDs. + + Args: + s (str): The string to be encoded. + bos (bool, optional): Whether to add a "beginning of sentence" token. Defaults to False. + eos (bool, optional): Whether to add an "end of sentence" token. Defaults to False. + + Returns: + List[int]: The list of token IDs. + """ + tokens = [] + + if bos: + tokens.append(self.bos_id) + + tokens.extend(self.sp.encode(s)) + + if eos: + tokens.append(self.eos_id) + + return tokens + + def decode(self, tokens: List[int]) -> str: + """ + Decode a sequence of token IDs to a string. + + Args: + tokens (List[int]): The list of token IDs to be decoded. + + Returns: + str: The decoded string. + """ + return self.sp.decode(tokens) + + def id_to_piece(self, token: int) -> str: + """ + Convert a token ID to its corresponding token string. + + Args: + token (int): The token ID. + + Returns: + str: The token string, with SentencePiece's '▁' character replaced by a space. + """ + return self.sp.id_to_piece(token).replace('▁', ' ') diff --git a/llama/utils.py b/llama/utils.py new file mode 100644 index 0000000..5c2af3c --- /dev/null +++ b/llama/utils.py @@ -0,0 +1,153 @@ +""" +Utilities for loading the Llama model and tokenizer from a checkpoint directory +and a tokenizer model file. +""" + +import json +import re + +from pathlib import Path +from typing import Dict, Any, Tuple + +import torch + +from .model import LlamaArgs, Llama +from .tokenizer import Tokenizer + +ModuleParams = Dict[str, Any] + +def _load_model_from_checkpoint(checkpoint_dir: str) \ + -> Tuple[LlamaArgs, ModuleParams]: + """ + Load the Llama model from a given checkpoint directory. + + Args: + checkpoint_dir (str): The path to the directory containing the Llama + model checkpoint. + + Returns: + Tuple[LlamaArgs, ModuleParams]: A tuple containing: + - LlamaArgs: Arguments used for initializing the Llama model. + - ModuleParams: PyTorch state dictionary for the Llama model. + """ + + checkpoint_path = Path(checkpoint_dir) + with open(checkpoint_path / "params.json", "r", encoding='utf-8') as f: + args = json.loads(f.read()) + + args = LlamaArgs(**args) + + checkpoint_paths = list(checkpoint_path.glob('*.pth')) + checkpoint_paths = sorted(checkpoint_paths) + + checkpoint = torch.load(checkpoint_paths[0], map_location="cpu") + + return args, checkpoint + +# pylint: disable=locally-disabled, R0912 +def _transform_params(params: ModuleParams) -> ModuleParams: + """ + Map the state dictionary keys from the official Llama model to the keys + used in this implementation. + + Args: + params (ModuleParams): The state dictionary from the official Llama + model. + + Returns: + ModuleParams: The modified state dictionary to match the keys used in + this implementation. + """ + + new_params = {} + + for label, param in params.items(): + + if label == 'tok_embeddings.weight': + label = 'embeddings.embedding.weight' + elif label == 'norm.weight': + label = 'output_norm.gain' + elif label == 'output.weight': + label = 'vocab_map.weight' + else: + + if label in { 'rope.freqs' }: + continue + + regex = re.compile(r'layers\.(\d+)\.(.*)') + + m = regex.match(label) + + assert m is not None + + layer_num = m.group(1) + sub_label = m.group(2) + + label = f'decoder_stack.decoders.{layer_num}.' + + if sub_label == 'attention.wq.weight': + label += 'gqa.wq.weight' + elif sub_label == 'attention.wk.weight': + label += 'gqa.wk.weight' + elif sub_label == 'attention.wv.weight': + label += 'gqa.wv.weight' + elif sub_label == 'attention.wo.weight': + label += 'gqa.wo.weight' + elif sub_label == 'feed_forward.w1.weight': + label += 'feed_forward.w.weight' + elif sub_label == 'feed_forward.w2.weight': + label += 'feed_forward.w2.weight' + elif sub_label == 'feed_forward.w3.weight': + label += 'feed_forward.v.weight' + elif sub_label == 'attention_norm.weight': + label += 'attention_norm.gain' + elif sub_label == 'ffn_norm.weight': + label += 'forward_norm.gain' + else: + assert False, "Key not found" + + new_params[label] = param + + return new_params + +def load_llama_from_checkpoint(checkpoint_dir: str, tokenizer_path: str, + max_batch_size: int = 1, max_context_len: int = 2048) \ + -> Tuple[Llama, Tokenizer]: + """ + Load the Llama model and the tokenizer from specified paths. + + Args: + checkpoint_dir (str): Path to the directory containing the model + checkpoint. + tokenizer_path (str): Path to the tokenizer model file. + max_batch_size (int, optional): Maximum batch size the model can accept. + Affects cache size. Default is 1. + max_context_len (int, optional): Maximum context length the model can + handle. Affects cache size. Default is + 2048. + + Returns: + Tuple[Llama, Tokenizer]: A tuple containing: + - Llama: The Llama model loaded from the checkpoint. + - Tokenizer: The tokenizer model loaded from the given path. + """ + + args, checkpoint = _load_model_from_checkpoint(checkpoint_dir) + + tokenizer = Tokenizer(tokenizer_path) + + args.vocab_size = tokenizer.vocab_size + args.max_batch_size = max_batch_size + args.max_ctx_len = max_context_len + args.padding_idx = tokenizer.pad_id + + checkpoint = _transform_params(checkpoint) + + llama_model = Llama(args) + llama_model.load_state_dict(checkpoint) + + return llama_model, tokenizer + +__all__ = [ + 'load_llama_from_checkpoint' +] |