aboutsummaryrefslogtreecommitdiff
path: root/llama/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'llama/model.py')
-rw-r--r--llama/model.py722
1 files changed, 722 insertions, 0 deletions
diff --git a/llama/model.py b/llama/model.py
new file mode 100644
index 0000000..6417c83
--- /dev/null
+++ b/llama/model.py
@@ -0,0 +1,722 @@
+"""
+The core model for the Llama family of LLMs
+"""
+
+import math
+import copy
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Dict
+
+import torch
+import torch.nn.functional as F
+
+from torch import nn
+from .llm import LLM
+
+# pylint: disable=locally-disabled, R0902, R0913
+
+def _round_up_to_multiple(n: int, m: int) -> int:
+ """
+ Round n up to an integer multiple of m
+ """
+ return math.ceil(n / m) * m
+
@dataclass
class LlamaArgs:
    """
    Hyper-parameter bundle used to construct a Llama model.

    Attributes:
        dim (int): Model/embedding dimension (d_model in "Attention is
            All You Need").
        n_layers (int): Number of decoder layers in the stack.
        n_heads (int): Number of attention heads per layer.
        vocab_size (int): Number of tokens in the vocabulary.
        multiple_of (int): The feed-forward hidden dimension is rounded
            up to a multiple of this value.
        norm_eps (float): Epsilon added inside RMS normalization to avoid
            division by zero.
        max_ctx_len (int): Longest context the model supports.
            Defaults to 2048.
        max_batch_size (int): Largest batch the key-value caches can
            hold. Defaults to 1.
        n_groups (Optional[int]): Number of key-value groups for grouped
            query-attention (GQA); None selects plain multi-head
            attention. Defaults to None.
        padding_idx (int): Token id used for padding in embeddings.
            Defaults to -1.
    """

    dim: int
    n_layers: int
    n_heads: int
    vocab_size: int
    multiple_of: int
    norm_eps: float
    max_ctx_len: int = 2048
    max_batch_size: int = 1
    n_groups: Optional[int] = None
    padding_idx: int = -1
+
class RMSNorm(nn.Module):
    """
    Unbiased Root Mean Square (RMS) layer normalization.

    Scales each feature vector by the reciprocal of its root-mean-square
    magnitude, then applies a learnable per-feature gain. Unlike
    LayerNorm, no mean is subtracted and no bias is added.

    Reference:
        "Root Mean Square Layer Normalization",
        https://arxiv.org/pdf/1910.07467.pdf

    Attributes:
        eps (float): Stabilizer added under the square root.
        gain (nn.Parameter): Learnable per-feature scale, initialized to
            ones.
    """

    def __init__(self, d: int, eps: float = 1e-6, dtype: torch.dtype = torch.float):
        """
        Build an RMSNorm over feature dimension d.

        Args:
            d (int): Size of the last (feature) dimension.
            eps (float, optional): Numerical-stability epsilon.
                Defaults to 1e-6.
            dtype (torch.dtype, optional): Data type of the gain
                parameter. Defaults to torch.float.
        """
        super().__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(d, dtype=dtype))

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        """
        Normalize a along its last dimension and apply the gain.

        Args:
            a (torch.Tensor): Input of shape (..., d).

        Returns:
            torch.Tensor: Normalized tensor of the same shape.
        """
        mean_square = torch.mean(a ** 2, dim=-1, keepdim=True)
        return a * torch.rsqrt(self.eps + mean_square) * self.gain
+
+
class SwiGLU(nn.Module):
    """
    SwiGLU feed-forward block: a Gated Linear Unit whose gate is a Swish
    (SiLU) activation, followed by a projection back to the model
    dimension. Used as the FFN sub-layer of the transformer.

    Reference:
        "GLU Variants Improve Transformer",
        https://arxiv.org/pdf/2002.05202.pdf
    """

    def __init__(self, dim: int, dim_ff: int, dtype: torch.dtype = torch.float):
        """
        Build the three linear projections of the SwiGLU block.

        Arguments:
            dim (int): Input/output dimensionality.
            dim_ff (int): Hidden (gated) dimensionality.
            dtype (torch.dtype, optional): Weight dtype.
                Defaults to torch.float.
        """
        super().__init__()

        # Gate projection, value projection, and down projection.
        # Creation order matters for reproducible random initialization.
        self.w = nn.Linear(dim, dim_ff, bias=False, dtype=dtype)
        self.v = nn.Linear(dim, dim_ff, bias=False, dtype=dtype)
        self.w2 = nn.Linear(dim_ff, dim, bias=False, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the gated feed-forward transform.

        Arguments:
            x (torch.Tensor): Input of shape (..., dim).

        Returns:
            torch.Tensor: Output of shape (..., dim).
        """
        gate = F.silu(self.w(x))
        return self.w2(gate * self.v(x))
+
class RotaryEmbeddings(nn.Module):
    """
    Rotary position embeddings (RoPE).

    Positions are encoded by rotating consecutive coordinate pairs of the
    query/key vectors by position-dependent angles, so that attention
    scores depend on the relative distance between tokens.

    Args:
        dim (int): Dimension of the embeddings (must be even).
        max_ctx_len (int): Maximum context length for which rotations are
            precomputed.
        theta (float, optional): Base frequency of the rotation angles.
            Defaults to 10000.0.

    Raises:
        AssertionError: If the dimension is odd.

    References:
        - RoFormer paper: https://arxiv.org/pdf/2104.09864.pdf
    """

    # Cache of precomputed rotation tables. Keyed on
    # (dim, max_ctx_len, theta): keying on dim alone (as before) silently
    # returned a table computed for a different context length or base
    # frequency when two configurations shared the same dim.
    embedding_cache: Dict[Tuple[int, int, float], torch.Tensor] = {}

    def __init__(self, dim: int, max_ctx_len: int, theta: float = 10000.0):
        """
        Precompute (or fetch from the cache) the rotation table.

        Args:
            dim (int): Dimension of the embeddings (must be even).
            max_ctx_len (int): Maximum context length to precompute.
            theta (float, optional): Base frequency. Defaults to 10000.0.

        Raises:
            AssertionError: If the dimension is odd.
        """
        super().__init__()

        assert dim % 2 == 0, "Model dimension should be a multiple of two"

        self.n_coord_pairs = dim // 2
        self.rots = RotaryEmbeddings.get_embeddings(dim, max_ctx_len, theta)

    @staticmethod
    def compute_angles(dim: int, max_ctx_len: int, theta: float) -> torch.Tensor:
        """
        Compute the complex unit rotations for every (position, pair).

        Arguments:
            dim (int): Dimension of the embeddings.
            max_ctx_len (int): Maximum context length.
            theta (float): Base frequency.

        Returns:
            torch.Tensor: Complex tensor of shape (max_ctx_len, dim // 2)
            with entries exp(i * position * freq).
        """
        # Frequencies decay geometrically across coordinate pairs
        freqs = theta ** (-torch.arange(0, dim, 2, dtype=torch.float) / dim)

        m = torch.arange(max_ctx_len)

        # angles[p, j] = p * freqs[j]
        angles = torch.outer(m, freqs)

        return torch.polar(torch.ones((max_ctx_len, dim // 2)), angles)

    @staticmethod
    def get_embeddings(dim: int, max_ctx_len: int, theta: float) -> torch.Tensor:
        """
        Retrieve the rotations for this exact configuration, computing and
        caching them on first use.

        Args:
            dim (int): Dimension of the embeddings.
            max_ctx_len (int): Maximum context length.
            theta (float): Base frequency.

        Returns:
            torch.Tensor: The (max_ctx_len, dim // 2) rotation table.
        """
        key = (dim, max_ctx_len, theta)
        cache = RotaryEmbeddings.embedding_cache

        if key not in cache:
            cache[key] = RotaryEmbeddings.compute_angles(dim, max_ctx_len, theta)

        return cache[key]

    def forward(self, x: torch.Tensor, cur_pos: int = 0) -> torch.Tensor:
        """
        Rotate x according to the absolute positions of its tokens.

        Arguments:
            x (torch.Tensor): Tensor of shape (batch_size, ctx_len, ..., dim).
            cur_pos (int, optional): Absolute position of the first token;
                rotations for cur_pos .. cur_pos + ctx_len - 1 are applied.
                Defaults to 0.

        Returns:
            torch.Tensor: Rotated tensor, same shape and dtype as x.
        """
        _batch_size, ctx_len, *dup_dims, dim = x.shape

        # Interpret consecutive coordinate pairs as complex numbers
        rotated = x.view(*x.shape[:-1], self.n_coord_pairs, 2)
        rotated = torch.view_as_complex(rotated.float())

        broad_shape = [1, ctx_len] + [1] * len(dup_dims) + [dim // 2]

        # Out-of-place multiply. The previous in-place `*=` acted on a
        # complex view of `x.float()`; for float32 inputs `.float()` is a
        # no-op returning the same storage, so the caller's tensor was
        # silently mutated.
        rotated = rotated * self.rots[cur_pos : cur_pos + ctx_len].view(*broad_shape)

        rotated = torch.view_as_real(rotated)
        return rotated.view(*x.shape[:-1], dim).type_as(x)
+
def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
              mask: Optional[torch.Tensor] = None) \
        -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Scaled dot-product attention.

    Computes softmax(Q K^T / sqrt(d_k)) V, optionally masking positions
    before the softmax.

    Arguments:
        - q (torch.Tensor): Queries of shape (..., seq_len, d_k).
        - k (torch.Tensor): Keys of shape (..., seq_len, d_k).
        - v (torch.Tensor): Values of shape (..., seq_len, d_v).
        - mask (Optional[torch.Tensor]): Positions where mask == 0 are
          excluded from attention (set to -inf before the softmax).

    Returns:
        - Tuple[torch.Tensor, torch.Tensor]: The attention output and the
          attention weights.

    References:
        - "Attention Is All You Need": https://arxiv.org/pdf/1706.03762.pdf
    """
    # Scale by sqrt(d_k) to keep logits in a range where softmax has
    # usable gradients
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))

    weights = F.softmax(scores, dim=-1)

    return weights @ v, weights
+
class LinearCache:
    """
    Flat key/value cache for incremental attention.

    Holds a preallocated (max_batch_size, max_context_len, *tensor_dims)
    tensor; new positions are written in place and prefixes are read back
    as views.
    """

    def __init__(self, max_batch_size: int, max_context_len: int,
                 tensor_dims: Tuple, dtype: torch.dtype = torch.float):
        """Allocate a zeroed cache of the given shape and dtype."""
        self.max_batch_size = max_batch_size
        self.max_context_len = max_context_len
        self.cache = torch.zeros(
            (max_batch_size, max_context_len) + tuple(tensor_dims),
            dtype=dtype,
        )
        # Number of sequences written by the most recent set()
        self.cached_batch_size = 0

    def get(self, pos: int) -> torch.Tensor:
        """Return the cached entries for positions [0, pos)."""
        return self.cache[:self.cached_batch_size, :pos]

    def set(self, current_pos: int, seq: torch.Tensor) -> None:
        """Write seq into the cache starting at current_pos."""
        n_seqs, n_tokens = seq.shape[0], seq.shape[1]

        self.cache[:n_seqs, current_pos:current_pos + n_tokens] = seq

        self.cached_batch_size = n_seqs
+
class GQA(nn.Module):
    """
    Group-Query Attention (GQA) module for transformer architectures.

    The input is projected into n_heads query heads but only n_groups
    key/value heads; each key/value head is shared by n_heads // n_groups
    query heads. With n_groups == n_heads this reduces to ordinary
    multi-head attention (MHA); with n_groups == 1 it is multi-query
    attention (MQA).

    References:
        - See "GQA: Training Generalized Multi-Query Transformer Models from
          Multi-Head Checkpoints" at https://arxiv.org/pdf/2305.13245.pdf
    """

    def __init__(self, dim: int, n_heads: int,
                 n_groups: Optional[int] = None,
                 query_embedding: Optional[nn.Module] = None,
                 key_embedding: Optional[nn.Module] = None,
                 apply_decoder_mask: bool = False,
                 kv_caches: Optional[Tuple[LinearCache, LinearCache]] = None,
                 dtype: torch.dtype = torch.float):
        """
        Initializes the Group-Query Attention (GQA) module.

        Parameters:
            dim (int): The dimensionality of the input features and the last dimension of
                       the output tensor.
            n_heads (int): The number of attention heads to use.
            n_groups (Optional[int]): The number of groups to divide the attention heads
                                      into. If not specified, defaults to the number of heads.
                                      Must divide `n_heads` evenly.
            query_embedding (Optional[nn.Module]): An optional module to embed the query
                                                   vectors, e.g., a positional encoding module.
            key_embedding (Optional[nn.Module]): An optional module to embed the key vectors,
                                                 similar to `query_embedding`.
            apply_decoder_mask (bool): Whether to apply a causal mask to the attention mechanism,
                                       useful for decoder self-attention.
            kv_caches (Optional[Tuple[LinearCache, LinearCache]]): Optional tuple of
                                                                   `LinearCache` instances for
                                                                   caching key and value projections
                                                                   in an autoregressive setting.
            dtype (torch.dtype): The data type of the module's parameters, e.g., `torch.float32`.
                                 The cache tensors should also use this data type.

        Raises:
            AssertionError: If dim is not a multiple of n_heads, or n_heads
                is not a multiple of n_groups.
        """

        # MHA is the degenerate case of one group per query head
        n_groups = n_groups if n_groups else n_heads

        assert dim % n_heads == 0, \
            "Model dimension should be a multiple of n_heads"
        assert n_heads % n_groups == 0, \
            "n_heads should be a multiple of n_groups"

        super().__init__()

        head_dim = dim // n_heads

        self.n_heads = n_heads
        self.n_groups = n_groups
        self.head_dim = head_dim
        self.apply_decoder_mask = apply_decoder_mask

        # Optional positional embeddings (e.g. rotary) applied to q and k
        self.query_embedding = query_embedding
        self.key_embedding = key_embedding

        # Query projection: n_heads projections of size head_dim, fused
        self.wq = nn.Linear(
            dim,
            n_heads * head_dim,
            bias=False,
            dtype=dtype,
        )

        # Key/value projections: only n_groups projections of size head_dim
        self.wk = nn.Linear(
            dim,
            n_groups * head_dim,
            bias=False,
            dtype=dtype,
        )

        self.wv = nn.Linear(
            dim,
            n_groups * head_dim,
            bias=False,
            dtype=dtype,
        )

        # Output projection back to the model dimension
        self.wo = nn.Linear(
            n_heads * head_dim,
            dim,
            bias=False,
            dtype=dtype,
        )

        if kv_caches is not None:
            self.key_cache = kv_caches[0]
            self.value_cache = kv_caches[1]
            self.has_cache = True
        else:
            self.has_cache = False

    def forward(self, x: torch.Tensor, cur_pos: int) -> torch.Tensor:
        """
        Processes the input tensor with Group-Query Attention.

        Arguments:
            - x (torch.Tensor): The input tensor of shape
              (batch_size, context_length, dim).
            - cur_pos (int): The current position in the sequence for which
              to compute attention. This is relevant when using key-value caches,
              as it determines the part of the cache to update and utilize.
              NOTE(review): with has_cache == False, a non-zero cur_pos still
              widens the causal mask by cur_pos columns while k/v only cover
              ctx_len positions — callers appear to always pass cur_pos=0
              when caching is disabled; confirm.

        Returns:
            - torch.Tensor: The output tensor after applying Group-Query Attention,
              of shape (batch_size, context_length, dim).
        """

        batch_size, ctx_len, dim = x.shape

        # Perform key, query, and value projections

        # wq(x) performs all n_heads projections at once, then the result
        # is reshaped such that the first head_dim results are part of the first
        # head, the second head_dim results are part of the second head, and so
        # on.
        q = self.wq(x).view(batch_size, ctx_len, self.n_heads, self.head_dim)
        k = self.wk(x).view(batch_size, ctx_len, self.n_groups, self.head_dim)
        v = self.wv(x).view(batch_size, ctx_len, self.n_groups, self.head_dim)

        # Apply embeddings to the key and query matrices
        if self.query_embedding:
            q = self.query_embedding(q, cur_pos)

        if self.key_embedding:
            k = self.key_embedding(k, cur_pos)

        if self.has_cache:
            # Add the new embeddings to the cache
            self.key_cache.set(cur_pos, k)
            self.value_cache.set(cur_pos, v)

            # Get all the previous embedding from the cache.

            # Note if cur_pos != 0, ctx_len is the length of
            # the new sequence. In reality, the whole sequence
            # is cur_pos + ctx_len and both cached results will
            # be of size (batch_size, ctx_len + cur_pos, n_groups, head_dim)
            k = self.key_cache.get(cur_pos + ctx_len)
            v = self.value_cache.get(cur_pos + ctx_len)

        # Avoid copy if multi-head attention MHA is used. This is true in the
        # 7B and 13B models.
        if self.n_groups != self.n_heads:

            repeats = self.n_heads // self.n_groups

            # Duplicate grouped attention heads:

            # From: { G_0, G_1, ... G_{k - 1} }
            # To: { G_0, G_0, ... G_0, G_1, ..., G_{k - 1}, G_{k - 1}, ..., G_{k - 1}
            k = torch.repeat_interleave(k, dim=2, repeats=repeats)
            v = torch.repeat_interleave(v, dim=2, repeats=repeats)

        # Transpose to parallelize attention across heads during batched-matrix
        # multiplication
        q = q.transpose(1, 2) # (batch_size, n_heads, ctx_len, head_dim)
        k = k.transpose(1, 2) # (batch_size, n_heads, ctx_len, head_dim)
        v = v.transpose(1, 2) # (batch_size, n_heads, ctx_len, head_dim)

        if self.apply_decoder_mask:
            # Construct attention mask

            # In the decoder architecture, the mask is a lower triangular matrix that prevents
            # previous tokens from attending to subsequent ones. More concretely for attention
            # scores (i, j), token i cannot attend to token j if j > i.

            # When key-value caching is enabled, we are only computing the attention scores
            # for the new sequence. Thus, the matrix of scores is of size (ctx_len, total_len)
            # and the only masked entries are (i, j) for j > cached_len + i since row i really
            # represents token cached_len + i.

            # NOTE(review): the mask is allocated on the default device (CPU);
            # if the model is moved to GPU, masked_fill will receive a CPU
            # mask — confirm device handling upstream.
            mask = torch.hstack([
                torch.ones((ctx_len, cur_pos)),
                torch.tril(torch.ones((ctx_len, ctx_len))),
            ])
        else:
            mask = None

        # Perform attention
        x, _ = attention(q, k, v, mask)

        # Concatenate heads
        x = x.transpose(1, 2) # (batch_size, ctx_len, n_heads, head_dim)
        x = x.reshape((batch_size, ctx_len, dim))

        # Final linear layer
        x = self.wo(x)

        return x
+
class LlamaTransformerLayer(nn.Module):
    """
    One decoder block of Meta's Llama architecture.

    Combines RMS pre-normalization, rotary-embedded group-query attention
    (GQA) with key-value caching, and a SwiGLU feed-forward network, with
    residual (skip) connections around both sub-layers.
    """

    def __init__(self, dim: int, n_heads: int, n_groups: Optional[int], max_context_len: int,
                 max_batch_size: int, round_ff_to_multiple: int, eps: float = 1e-6):
        """Set up the attention, caches, feed-forward, and norm layers."""
        super().__init__()

        head_dim = dim // n_heads

        # Rotary position embeddings, applied per attention head
        self.query_embedding = RotaryEmbeddings(head_dim, max_context_len)
        self.key_embedding = RotaryEmbeddings(head_dim, max_context_len)

        # With GQA, only n_groups key/value heads need to be cached
        kv_heads = n_groups if n_groups else n_heads

        self.key_cache = LinearCache(
            max_batch_size, max_context_len, (kv_heads, head_dim), dtype=torch.bfloat16
        )
        self.value_cache = LinearCache(
            max_batch_size, max_context_len, (kv_heads, head_dim), dtype=torch.bfloat16
        )

        self.gqa = GQA(
            dim, n_heads, n_groups,
            query_embedding=self.query_embedding,
            key_embedding=self.key_embedding,
            kv_caches=(self.key_cache, self.value_cache),
            dtype=torch.bfloat16,
            apply_decoder_mask=True,
        )

        # It might have been better to specify the inner "hidden" feed-forward
        # dimension directly as a hyper parameter. FAIR instead derives it
        # from the 2/3 ratio in the SwiGLU paper
        # (https://arxiv.org/pdf/2002.05202.pdf), rounded up to a
        # hardware-friendly multiple; that ratio originally existed only to
        # make FFN variants parameter-count comparable.
        dim_ff = _round_up_to_multiple(4 * int(2 * dim / 3), round_ff_to_multiple)

        self.feed_forward = SwiGLU(dim, dim_ff, dtype=torch.bfloat16)

        self.attention_norm = RMSNorm(dim, eps, dtype=torch.bfloat16)
        self.forward_norm = RMSNorm(dim, eps, dtype=torch.bfloat16)

    def forward(self, x: torch.Tensor, cur_pos: int = 0) -> torch.Tensor:
        """
        Run one decoder block over x (embeddings or the previous block's
        output), starting at absolute position cur_pos.
        """
        # Pre-norm attention sub-layer with a residual connection
        # (see ResNet, https://arxiv.org/pdf/1512.03385.pdf)
        h = x + self.gqa(self.attention_norm(x), cur_pos=cur_pos)

        # Pre-norm SwiGLU feed-forward sub-layer with a residual connection
        return h + self.feed_forward(self.forward_norm(h))
+
class LlamaDecoderStack(nn.Module):
    """A stack of args.n_layers Llama decoder blocks applied in sequence."""

    def __init__(self, args: LlamaArgs):
        """Build the decoder blocks from the model arguments."""
        super().__init__()

        prototype = LlamaTransformerLayer(
            args.dim, args.n_heads, args.n_groups, args.max_ctx_len,
            args.max_batch_size, args.multiple_of, args.norm_eps
        )

        # Deep-copy the prototype so every layer owns independent
        # parameters and caches (initial values are shared by construction,
        # as in the classic "clones" idiom; weights are normally loaded
        # from a checkpoint afterwards)
        self.decoders = nn.ModuleList(
            copy.deepcopy(prototype) for _ in range(args.n_layers)
        )

    def forward(self, embedding: torch.Tensor, cur_pos: int = 0) -> torch.Tensor:
        """Feed the embeddings through every decoder block in order."""
        h = embedding

        for decoder in self.decoders:
            h = decoder(h, cur_pos)

        return h
+
class LlamaEmbeddings(nn.Module):
    """
    Maps a tensor of token ids to embedding vectors of size dim, producing
    all-zero vectors for padding positions.
    """

    def __init__(self, vocab_size: int, dim: int, padding_idx: int):
        """
        Initializes the LlamaEmbeddings.

        Args:
            vocab_size (int): Number of tokens in the vocabulary.
            dim (int): Embedding dimension.
            padding_idx (int): Token id that marks padding (may be
                negative, e.g. -1).
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.dim = dim
        self.padding_idx = padding_idx

        self.embedding = nn.Embedding(self.vocab_size, self.dim, dtype=torch.bfloat16)

    def forward(self, context: torch.Tensor) -> torch.Tensor:
        """
        Retrieve the embeddings for a token sequence.

        Args:
            context (torch.Tensor): (batch, ctx_len) tensor of token ids.

        Returns:
            torch.Tensor: (batch, ctx_len, dim) embeddings, with zero
            vectors at padding positions.
        """

        # The original Llama implementation employs parallel embeddings,
        # which implicitly produce zero embeddings for padding_idx = -1.
        # nn.Embedding rejects negative indices, so padded slots are
        # remapped to a valid id and their embeddings zeroed afterwards.
        padding_mask = context == self.padding_idx

        # masked_fill returns a copy: the previous in-place assignment
        # (context[padding_mask] = 0) mutated the caller's tensor.
        safe_context = context.masked_fill(padding_mask, 0)

        embeddings = self.embedding(safe_context)

        # Zero the padded positions out-of-place rather than writing into
        # the embedding output in place.
        return embeddings.masked_fill(padding_mask.unsqueeze(-1), 0.0)
+
class Llama(LLM):
    """The Llama family of large language models."""

    def __init__(self, args: LlamaArgs):
        """Build the embeddings, decoder stack, and output projection."""
        super().__init__()

        self.context_length = args.max_ctx_len
        self.max_batch_size = args.max_batch_size

        self.embeddings = LlamaEmbeddings(
            args.vocab_size, args.dim, args.padding_idx
        )

        self.decoder_stack = LlamaDecoderStack(args)

        self.output_norm = RMSNorm(
            args.dim, eps=args.norm_eps, dtype=torch.bfloat16
        )

        # Projects the final hidden state onto vocabulary scores
        self.vocab_map = nn.Linear(
            args.dim, args.vocab_size, bias=False, dtype=torch.bfloat16
        )

    def forward(self, context: torch.Tensor, cur_pos: int = 0) -> torch.Tensor:
        """
        Score the next token for each sequence in the batch.

        Args:
            context (torch.Tensor): (batch_size, ctx_len) tensor of token
                ids serving as the prediction context.
            cur_pos (int, optional): Absolute position at which `context`
                starts. A non-zero value reuses the layers' key-value
                caches for the earlier tokens. Defaults to 0.

        Returns:
            torch.Tensor: (batch_size, vocab_size) tensor of unnormalized
            next-token scores (logits) taken at the final context position.

        Examples:
            # Predict the next token for a sequence [1, 2, 3]
            log_probs = llm(torch.tensor([[1, 2, 3]], dtype=torch.long), 0)

            # Predict the next token for a sequence [1, 2, 3, 4, 5] using the
            # cache starting at position 3
            log_probs = llm(torch.tensor([[4, 5]], dtype=torch.long), 3)
        """
        h = self.embeddings(context)       # (batch, ctx_len, dim)
        h = self.decoder_stack(h, cur_pos)
        h = self.output_norm(h)

        scores = self.vocab_map(h)         # (batch, ctx_len, vocab_size)

        # Only the last position's scores predict the next token
        return scores[:, -1]