-rw-r--r--  .pylintrc              6
-rw-r--r--  README.md             31
-rw-r--r--  inference_example.py  51
-rw-r--r--  llama/__init__.py     18
-rw-r--r--  llama/generate.py    322
-rw-r--r--  llama/llm.py          74
-rw-r--r--  llama/model.py       722
-rw-r--r--  llama/tokenizer.py    89
-rw-r--r--  llama/utils.py       153
9 files changed, 1466 insertions, 0 deletions
diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 0000000..bb13cc1
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,6 @@
+[MASTER]
+extension-pkg-whitelist=numpy,torch,sentencepiece
+
+[TYPECHECK]
+ignored-modules=numpy,torch,sentencepiece
+ignored-classes=numpy,torch,sentencepiece
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..bc573b4
--- /dev/null
+++ b/README.md
@@ -0,0 +1,31 @@
+MyLlama2
+========
+
+This is a ground-up reimplementation of the Llama 2 family of large language models.
+It follows the same architecture: a decoder-only transformer with grouped-query
+attention (GQA), key-value caching, SwiGLU feed-forward layers, and a SentencePiece
+tokenizer. It is compatible with the original Llama 2 weights.
+Unlike Meta's implementation, it does not use model-parallel layers or any
+distributed processing APIs. Consequently, it runs on a single GPU, and it can also
+run on a CPU, without the need for launcher tools such as `torchrun` or MPI.
+
+This model was created for demonstration purposes, with the intent of sharing it with the
+community. During its development, I identified a few minor issues in FAIR's
+implementation, which I plan to contribute back through pull requests. I believe this
+implementation is more accessible for those new to AI, and I've included references to the papers
+where these concepts were first introduced. However, this code has not been extensively reviewed.
+For production projects, I recommend starting with [Meta's implementation](https://github.com/facebookresearch/llama).
+For high-performance CPU-only inference, consider compiling
+[llama.cpp](https://github.com/ggerganov/llama.cpp) while targeting the native architecture.
+
+Example usage:
+--------------
+
+```bash
+python inference_example.py \
+ llama/llama-2-7b \
+ ./tokenizer.model \
+ --top_p 0.8 \
+ --max_generation_length 100 \
+ --context "Four score and seven years ago our fathers brought forth on this continent, a new nation, conceived in Liberty, and dedicated to the proposition that all men are created equal."
+```
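+
+The same loop can also be driven from Python directly. The sketch below mirrors what
+`inference_example.py` does; the checkpoint and tokenizer paths are placeholders for
+your local files.
+
+```python
+from llama import load_llama_from_checkpoint, generate_token_sequence
+
+# Placeholder paths; point these at your local checkpoint directory and
+# SentencePiece model file.
+llama, tokenizer = load_llama_from_checkpoint('llama/llama-2-7b', './tokenizer.model')
+
+# Encode the prompt with a BOS token and without an EOS token.
+prompt = tokenizer.encode('Four score and seven years ago', True, False)
+
+# Stream tokens as they are sampled (nucleus sampling with top_p=0.8).
+for token in generate_token_sequence(llama, prompt, top_p=0.8, max_generation_length=100):
+    print(tokenizer.id_to_piece(token), end='', flush=True)
+print()
+```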
diff --git a/inference_example.py b/inference_example.py
new file mode 100644
index 0000000..403ceb9
--- /dev/null
+++ b/inference_example.py
@@ -0,0 +1,51 @@
+import argparse
+
+from llama import (
+ load_llama_from_checkpoint,
+ generate_token_sequence,
+)
+
+def main(args: argparse.Namespace) -> None:
+ llama, tokenizer = load_llama_from_checkpoint(
+ args.model_directory,
+ args.tokenizer
+ )
+
+ context: str = args.context
+
+ prompt = tokenizer.encode(context, True, False)
+
+ print(f'Prompt: {context}')
+    print('Generated: ', end='')
+
+ for token in generate_token_sequence(llama,
+ prompt,
+ top_p = args.top_p,
+ max_generation_length = args.max_generation_length):
+ piece = tokenizer.id_to_piece(token)
+ print(piece, end='', flush=True)
+ print()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='Generate text using a Llama model.')
+
+ parser.add_argument('model_directory', type=str,
+ help='Path to the directory containing the Llama model.')
+ parser.add_argument('tokenizer', type=str,
+ help='Path to the tokenizer model file.')
+ parser.add_argument('--context', type=str, default='Hello, world!',
+ help='Initial context to seed the Llama model.')
+ parser.add_argument('--max_generation_length', type=int, default=None,
+ help='Maximum length of the generated sequence.')
+    parser.add_argument('--top_p', type=float, default=0.80,
+                        help='Nucleus (top-p) sampling threshold: the cumulative probability mass to sample from.')
+
+ try:
+ import argcomplete
+ argcomplete.autocomplete(parser)
+ except ImportError:
+ pass
+
+ args = parser.parse_args()
+
+ main(args)
diff --git a/llama/__init__.py b/llama/__init__.py
new file mode 100644
index 0000000..2b9a9ef
--- /dev/null
+++ b/llama/__init__.py
@@ -0,0 +1,18 @@
+"""
+Llama2 model, loading infrastructure, and sampling helpers
+"""
+
+from .model import Llama
+
+from .generate import (
+ sample_sequence,
+ sample_batched_sequence,
+ generate_token_sequence,
+ generate_batched_token_sequence,
+ token_probabilities,
+ batched_token_probabilities
+)
+
+from .utils import (
+ load_llama_from_checkpoint
+)
diff --git a/llama/generate.py b/llama/generate.py
new file mode 100644
index 0000000..e735ca4
--- /dev/null
+++ b/llama/generate.py
@@ -0,0 +1,322 @@
+"""
+This module provides an interface for generating and sampling token sequences from a language model.
+It allows for the controlled generation of text by specifying parameters such as temperature, top-k,
+and top-p, which influence the randomness and quality of the generated sequences.
+"""
+
+# pylint: disable=locally-disabled, R0913, R0914
+
+from collections import OrderedDict
+from typing import List, Optional, Generator, cast
+
+import torch
+
+from .llm import LLM
+
+TokenList = OrderedDict[int, float]
+
+def _sample_internal(llm: LLM, context: torch.Tensor) -> torch.Tensor:
+ """
+ Sample a tensor of logits from the language model (LLM) based on the input context.
+ """
+
+ batch_size, seq_len = context.shape
+
+ assert seq_len <= llm.context_length
+ assert batch_size <= llm.max_batch_size
+
+ with torch.inference_mode():
+ return llm(context)
+
+def _load_context(tokens: List[List[int]], pad_id: int,
+ pad_to_length: Optional[int] = None) -> torch.Tensor:
+ """
+ Load a batch of token lists into a padded tensor suitable for input to a language model.
+ """
+ batch_size = len(tokens)
+
+ max_token_len = max((len(tok) for tok in tokens))
+
+ pad_to_length = max_token_len if pad_to_length is None else pad_to_length
+
+ context = torch.full(
+ (batch_size, pad_to_length), pad_id, dtype=torch.long
+ )
+
+ for dim, toks in enumerate(tokens):
+ context[dim, :len(toks)] = torch.tensor(toks, dtype=torch.long)
+
+ return context
+
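+# For example, _load_context([[1, 2, 3], [4]], pad_id=0) returns
+#     tensor([[1, 2, 3],
+#             [4, 0, 0]])
+# i.e. shorter sequences are right-padded with pad_id up to the longest sequence
+# in the batch (or up to pad_to_length, when given).
+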
+def batched_token_probabilities(llm: LLM,
+ tokens: List[List[int]],
+ temperature: float = 1.0) -> List[TokenList]:
+ """
+    Calculate the probability distribution over the next token for each sequence in a batch.
+
+ Args:
+ - llm (LLM): An instance of the language model.
+ - tokens (List[List[int]]): A list of tokenized input sequences.
+ - temperature (float): A temperature parameter to scale the logits before
+ applying softmax. Default is 1.0.
+
+ Returns:
+ - List[TokenList]: A list of ordered dictionaries where each dictionary maps
+ token ids to their corresponding probabilities for each
+ sequence in the batch.
+ """
+
+ context = _load_context(tokens, llm.padding_idx)
+
+ token_logprobs = _sample_internal(llm, context)
+
+ token_probs = torch.softmax(token_logprobs / temperature, dim=-1)
+
+ samples: List[TokenList] = [OrderedDict() for _ in range(len(tokens))]
+ for i, p in enumerate(token_probs):
+ for _id in torch.argsort(p, descending=True):
+ samples[i][int(_id)] = float(p[_id])
+
+ return samples
+
+def token_probabilities(llm: LLM, tokens: List[int]) -> TokenList:
+ """
+    Calculate the probability distribution over the next token for a single sequence.
+ See batched_token_probabilities.
+ """
+
+ return batched_token_probabilities(llm, [ tokens ])[0]
+
+def sample_batched_token(
+ token_logprobs: torch.Tensor,
+ temperature: float = 1.0,
+ top_k: Optional[int] = None,
+ top_p: float = 1.0,
+ sample_eps: float = 1e-6) -> torch.Tensor:
+ """
+ Sample a token from a batch of token logits with optional top-k and top-p filtering.
+
+
+ Args:
+ - token_logprobs (torch.Tensor): A tensor of token log probabilities.
+ - temperature (float): A scaling factor for logits before sampling. Default
+ is 1.0.
+ - top_k (Optional[int]): If set, the sampling is restricted to the top k
+ tokens. Default is None (no restriction).
+ - top_p (float): If set, the sampling is restricted to the smallest set
+ of tokens with cumulative probability exceeding top_p.
+ Default is 1.0 (no restriction).
+ - sample_eps (float): An epsilon value to avoid precision errors during
+ cumulative probability calculation. Default is 1e-6.
+
+ Returns:
+ - torch.Tensor: A tensor of sampled token ids for each item in the batch.
+
+    Implements top-k and top-p (nucleus) sampling; the two filters can be combined.
+
+ See:
+ - https://arxiv.org/pdf/1805.04833.pdf for top-k sampling
+ - https://arxiv.org/pdf/1904.09751.pdf for top-p sampling
+ """
+
+ batch_size = token_logprobs.shape[0]
+
+ token_probs = torch.softmax(token_logprobs / temperature, dim=-1)
+
+
+ selected_tokens = torch.zeros(batch_size, dtype=torch.long)
+
+ sorted_tokens = torch.argsort(token_probs, descending=True)
+ sorted_probs = torch.gather(token_probs, 1, sorted_tokens)
+ nucleus_mask = sorted_probs.cumsum(dim=-1) < top_p + sample_eps
+ nucleus_mask[:,0] = True
+
+ for i, (tokens, mask, probs) in enumerate(zip(sorted_tokens, nucleus_mask, sorted_probs)):
+ nucleus = tokens[mask]
+ p = probs[mask]
+
+ if top_k is not None and len(nucleus) > top_k:
+ nucleus = nucleus[:top_k]
+ p = p[:top_k]
+
+        p /= p.sum(dim=0)
+
+ token = nucleus[torch.multinomial(p, 1)]
+
+ selected_tokens[i] = token
+
+ return selected_tokens
+
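+# A toy illustration of the sampler above (the probabilities are made up):
+#
+#     logits = torch.log(torch.tensor([[0.5, 0.3, 0.1, 0.1],
+#                                      [0.9, 0.05, 0.03, 0.02]]))
+#     ids = sample_batched_token(logits, top_p=0.85)
+#
+# With this implementation's thresholding, the first row samples from its two
+# most likely tokens (cumulative probability 0.8 < 0.85), while the second row
+# deterministically returns its most likely token. `ids` has shape (2,).
+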
+def generate_batched_token_sequence(llm: LLM,
+ prompts: List[List[int]],
+ max_generation_length: Optional[int] = None,
+ temperature: float = 1.0,
+ top_k: Optional[int] = None,
+ top_p: float = 1.0) -> Generator[List[Optional[int]], None, None]:
+ """
+    Generate a sequence of tokens for each prompt in a batch, one step at a time.
+
+ Args:
+ - llm (LLM): An instance of the language model.
+ - prompts (List[List[int]]): A list of tokenized input sequences (prompts).
+ - max_generation_length (Optional[int]): The maximum number of tokens to
+ generate for each prompt. If None, generate up to the model's maximum
+ context length. Default is None.
+ - temperature (float): A scaling factor for logits before sampling. Default
+ is 1.0.
+ - top_k (Optional[int]): If set, restricts sampling to the top k most
+ likely tokens. Default is None (no restriction).
+ - top_p (float): If set, restricts sampling to a subset of tokens with a
+ cumulative probability greater than top_p. Default is 1.0
+ (no restriction).
+
+ Yields:
+ - Generator[List[Optional[int]], None, None]: A generator that yields lists
+ of token ids, where each list corresponds to one prompt in the batch.
+      An entry is None if no token was generated for that prompt during the
+      iteration (for example, after that sequence has already emitted an
+      end-of-sequence token).
+
+ Raises:
+ - AssertionError: If batch size exceeds the maximum allowed by the LLM, or
+ if the requested generation length is too long.
+ """
+
+ batch_size = len(prompts)
+ assert batch_size <= llm.max_batch_size, \
+ "Batch size exceeded the maximum batch size of the LLM"
+
+ prompt_lens = torch.tensor([len(p) for p in prompts], dtype=torch.long)
+ max_prompt_len = max(prompt_lens)
+
+ remaining_context = llm.context_length - max_prompt_len
+ if max_generation_length is None:
+ max_generation_length = remaining_context
+ else:
+ assert max_generation_length <= remaining_context, \
+ "Cannot generate more tokens than exist in the context"
+
+ eos = torch.zeros(batch_size, dtype=torch.long)
+ last_pos = 0
+
+ end_pos = max_prompt_len + max_generation_length
+ context = _load_context(prompts, llm.padding_idx, pad_to_length=end_pos)
+
+ start_pos = max(prompt_lens)
+
+ for pos in range(start_pos, end_pos):
+ log_probs = llm(context[:, last_pos:pos], last_pos)
+
+ sampled = sample_batched_token(
+ log_probs,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p
+ )
+
+ in_prompt = pos < prompt_lens
+ should_replace_mask = (eos == 0) & (~in_prompt)
+
+ yield [int(sampled[i]) if should_replace_mask[i] else None for i in range(batch_size)]
+
+ context[should_replace_mask, pos] = sampled[should_replace_mask]
+        eos[should_replace_mask & (sampled == llm.eos_token)] = pos + 1
+ last_pos = pos
+
+ if (eos > 0).all():
+ break
+
+def generate_token_sequence(llm: LLM,
+ prompt: List[int],
+ max_generation_length: Optional[int] = None,
+ temperature: float = 1.0,
+ top_k: Optional[int] = None,
+ top_p: float = 1.0) -> Generator[int, None, None]:
+ """
+ Generate a sequence of tokens for a single prompt.
+ See generate_batched_token_sequence.
+ """
+
+ for tokens in generate_batched_token_sequence(llm,
+ [ prompt ],
+ max_generation_length=max_generation_length,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p):
+ yield cast(int, tokens[0])
+
+def sample_batched_sequence(llm: LLM,
+ prompts: List[List[int]],
+ max_generation_length: Optional[int] = None,
+ temperature: float = 1.0,
+ top_k: Optional[int] = None,
+ top_p: float = 1.0,
+ include_prompt: bool = False) -> List[List[int]]:
+ """
+ Generate and sample a sequence of tokens for each input sequence in a batch.
+
+ Args:
+ - llm (LLM): An instance of the language model.
+ - prompts (List[List[int]]): A list of tokenized input sequences (prompts).
+ - max_generation_length (Optional[int]): The maximum number of tokens to
+ generate for each prompt. Defaults to None, which allows the generation
+ up to the model's maximum context length.
+ - temperature (float): A scaling factor for logits before sampling,
+ affecting the randomness of the output. Default is 1.0, with lower values
+ leading to less random samples.
+ - top_k (Optional[int]): Limits the sampling pool to the top k tokens
+ according to the probability distribution. Default is None, indicating no
+ limit.
+ - top_p (float): The cumulative probability threshold for nucleus sampling;
+ allows sampling from a set of high-probability tokens whose cumulative
+ probability exceeds this threshold. Default is 1.0, indicating no limit.
+ - include_prompt (bool): If True, includes the input prompt at the beginning
+ of the generated sequence. Default is False.
+
+ Returns:
+ - List[List[int]]: A list of lists containing the sampled token sequences
+ for each input prompt. The sequences include the generated tokens and,
+ if include_prompt is True, the original input prompt tokens.
+ """
+
+ sampled_seqs: List[List[int]] = [[] for _ in range(len(prompts))]
+
+ if include_prompt:
+ for i, prompt in enumerate(prompts):
+ sampled_seqs[i].extend(prompt)
+
+ for generated_tokens in generate_batched_token_sequence(llm,
+ prompts,
+ max_generation_length,
+ temperature,
+ top_k,
+ top_p):
+ for i, token in enumerate(generated_tokens):
+ if token is not None:
+ sampled_seqs[i].append(token)
+
+ return sampled_seqs
+
+def sample_sequence(llm: LLM,
+ prompts: List[int],
+ max_generation_length: Optional[int] = None,
+ temperature: float = 1.0,
+ top_k: Optional[int] = None,
+ top_p: float = 1.0,
+ include_prompt: bool = False) -> List[int]:
+ """
+ Generate and sample a sequence of tokens for a single input sequence.
+ See sample_batched_sequence for reference.
+ """
+
+ return sample_batched_sequence(
+ llm, [prompts], max_generation_length, temperature,
+ top_k, top_p, include_prompt
+ )[0]
+
+__all__ = [
+ 'token_probabilities',
+ 'sample_batched_token',
+ 'sample_sequence',
+ 'sample_batched_sequence',
+ 'generate_token_sequence',
+ 'generate_batched_token_sequence',
+]
diff --git a/llama/llm.py b/llama/llm.py
new file mode 100644
index 0000000..7192b31
--- /dev/null
+++ b/llama/llm.py
@@ -0,0 +1,74 @@
+"""
+LLM provides a generalized interface for autoregressive
+next-word prediction models. The class can be utilized for tasks such as text
+sampling and probability prediction over a vocabulary.
+"""
+
+import torch
+from torch import nn
+
+class LLM(nn.Module):
+ """
+ LLM provides a generalized interface for autoregressive
+ next-word prediction models. The class can be utilized for tasks such as text
+ sampling and probability prediction over a vocabulary.
+
+ Attributes:
+ context_length (int): Length of the context window for the
+ autoregressive model. Default is -1, which
+ indicates that this needs to be set.
+
+ max_batch_size (int): The maximum size of a batch that can be processed.
+ Default is -1, which indicates that this needs to
+ be set.
+
+ vocab_size (int): The size of the vocabulary used in the model.
+ Default is -1, which indicates that this needs to
+ be set.
+
+ padding_idx (int): The index used for padding in mixed-length batches.
+ Default is -1, which indicates that this needs to be
+ set.
+
+ eos_token (int): Token index that signifies the end of a sequence during
+ auto-regressive generation. Default is -1, which
+ indicates that this needs to be set.
+ """
+
+ context_length = -1
+ max_batch_size = -1
+ vocab_size = -1
+ padding_idx = -1
+ eos_token = -1
+
+ def forward(self, context: torch.Tensor, cur_pos: int = 0) -> torch.Tensor:
+ """
+ Computes the log probabilities of the next token given a sequence of
+ tokens as context.
+
+ Args:
+ context (torch.Tensor): A tensor of shape (batch_size, context_length)
+ containing token ids. These tokens serve as the
+ context for predicting the next token.
+
+ cur_pos (int, optional): The position at which to start the
+ prediction. If cur_pos is not zero,
+ the internal cache (if available) will
+ be used to speed up predictions.
+ Defaults to 0.
+
+ Returns:
+ torch.Tensor: A tensor of shape (batch_size, vocab_size) containing
+ the log probabilities of the next token given the
+ context.
+
+ Examples:
+ # Predict the next token for a sequence [1, 2, 3]
+ log_probs = llm(torch.tensor([[1, 2, 3]], dtype=torch.long), 0)
+
+ # Predict the next token for a sequence [1, 2, 3, 4, 5] using the
+ # cache starting at position 3
+ log_probs = llm(torch.tensor([[4, 5]], dtype=torch.long), 3)
+ """
+
+ raise NotImplementedError()
diff --git a/llama/model.py b/llama/model.py
new file mode 100644
index 0000000..6417c83
--- /dev/null
+++ b/llama/model.py
@@ -0,0 +1,722 @@
+"""
+The core model for the Llama family of LLMs
+"""
+
+import math
+import copy
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Dict
+
+import torch
+import torch.nn.functional as F
+
+from torch import nn
+from .llm import LLM
+
+# pylint: disable=locally-disabled, R0902, R0913
+
+def _round_up_to_multiple(n: int, m: int) -> int:
+ """
+ Round n up to an integer multiple of m
+ """
+ return math.ceil(n / m) * m
+
+@dataclass
+class LlamaArgs:
+ """
+ Arguments class for configuring a LLAMA model.
+
+ Attributes:
+        dim (int): The model dimension, typically referred to as d_model in
+                   the "Attention Is All You Need" paper.
+ n_layers (int): The number of layers in the model.
+ n_heads (int): The number of attention heads in each layer.
+ vocab_size (int): The size of the model's vocabulary.
+ multiple_of (int): Ensures the feed-forward network dimension (d_ff)
+ is a multiple of this factor.
+ norm_eps (float): The epsilon value for RMS normalization, avoiding
+ division by zero.
+ max_ctx_len (int, optional): The maximum context length the model can
+ handle. Defaults to 2048.
+ max_batch_size (int, optional): The maximum batch size supported by the
+ model's cache. Defaults to 1.
+ n_groups (Optional[int], optional): The number of key-value groups in
+ grouped query-attention (GQA), if applicable. Defaults to None.
+ padding_idx (int): The index used for padding in embeddings. Defaults
+ to -1.
+ """
+
+ dim: int
+ n_layers: int
+ n_heads: int
+ vocab_size: int
+ multiple_of: int
+ norm_eps: float
+ max_ctx_len: int = 2048
+ max_batch_size: int = 1
+ n_groups: Optional[int] = None
+ padding_idx: int = -1
+
+class RMSNorm(nn.Module):
+ """
+ Implements an unbiased Root Mean Square (RMS) Layer Normalization.
+
+ Reference:
+ See the paper "Root Mean Square Layer Normalization" at
+ https://arxiv.org/pdf/1910.07467.pdf for more details.
+
+ Attributes:
+ eps (float): A small epsilon value added to the denominator for
+ numerical stability.
+ gain (nn.Parameter): A learnable gain parameter applied after
+ normalization.
+ """
+
+ def __init__(self, d: int, eps: float = 1e-6, dtype: torch.dtype = torch.float):
+ """
+ Initializes the RMSNorm layer.
+
+ Args:
+ d (int): The dimensionality of the input feature space.
+ eps (float, optional): A small epsilon value to add to the
+ denominator for numerical stability. Defaults to 1e-6.
+ dtype (torch.dtype, optional): The data type of the learnable gain
+ parameter. Defaults to torch.float.
+ """
+ super().__init__()
+ self.eps = eps
+ self.gain = nn.Parameter(torch.ones(d, dtype=dtype))
+
+
+ def forward(self, a: torch.Tensor) -> torch.Tensor:
+ """
+ Applies RMS normalization to the input tensor.
+
+ Args:
+ a (torch.Tensor): The input tensor to be normalized.
+
+ Returns:
+ torch.Tensor: The normalized tensor with the same shape as the input.
+ """
+
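+        # RMSNorm scales each feature by the inverse root mean square of the
+        # activations: RMSNorm(a)_i = a_i / sqrt(eps + mean_j(a_j^2)) * g_i.
+        # Unlike LayerNorm, there is no mean subtraction and no bias term.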
+ inverse_rms = torch.rsqrt(self.eps + torch.mean(a ** 2, dim=-1, keepdim=True))
+
+ return a * inverse_rms * self.gain
+
+
+class SwiGLU(nn.Module):
+ """
+ Implements the SwiGLU variant of the Gated Linear Unit (GLU) as part of the
+ FFN layer of a transformer. SwiGLU is a variant of the Gated Linear Unit
+ where the gating mechanism is controlled by a Swish activation function.
+
+ Reference:
+ The SwiGLU activation function is detailed in the paper "GLU Variants Improve Transformer"
+ which can be accessed at https://arxiv.org/pdf/2002.05202.pdf.
+ """
+
+ def __init__(self, dim : int, dim_ff: int, dtype: torch.dtype = torch.float):
+ """
+ Initializes the SwiGLU module.
+
+ Arguments:
+ dim (int): The dimensionality of the input and output tensors.
+ dim_ff (int): The reduced dimensionality of the hidden layer.
+ dtype (torch.dtype, optional): The data type for the weights of
+ the linear transformations. Defaults to torch.float.
+ """
+ super().__init__()
+
+ self.w = nn.Linear(dim, dim_ff, bias=False, dtype=dtype)
+ self.v = nn.Linear(dim, dim_ff, bias=False, dtype=dtype)
+ self.w2 = nn.Linear(dim_ff, dim, bias=False, dtype=dtype)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Applies the SwiGLU feed-forward layer
+
+ Arguments:
+ x (torch.Tensor): The input tensor to the SwiGLU module.
+
+ Returns:
+ torch.Tensor: The output tensor after applying the SwiGLU operation.
+ """
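+        # SwiGLU(x) = w2(silu(w(x)) * v(x)): the SiLU (Swish) gate on w(x)
+        # modulates the linear branch v(x) elementwise before the w2 output
+        # projection back to the model dimension.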
+ return self.w2(F.silu(self.w(x)) * self.v(x))
+
+class RotaryEmbeddings(nn.Module):
+ """
+ Implementation of rotary position embeddings.
+
+    Rotary embeddings are a mechanism for injecting positional information into
+    transformer models. They apply a position-dependent rotation to the query and
+    key vectors in the attention mechanism, so that the attention scores depend on
+    the relative distance between two tokens rather than on absolute positions.
+
+ Args:
+ - dim (int): The dimension of the embeddings.
+ - max_ctx_len (int): The maximum length of the context for which to compute
+ the embeddings.
+ - theta (float, optional): The frequency parameter for computing the rotary
+ embeddings. Defaults to 10000.0.
+
+ Raises:
+ AssertionError: If the dimension is not even.
+
+ References:
+ - RoFormer paper: https://arxiv.org/pdf/2104.09864.pdf
+ """
+
+ embedding_cache: Dict[int, torch.Tensor] = {}
+
+ def __init__(self, dim: int, max_ctx_len: int, theta: float = 10000.0):
+ """
+ Initialize the RotaryEmbeddings module.
+
+ Args:
+ - dim (int): The dimension of the embeddings.
+ - max_ctx_len (int): The maximum length of the context for which
+ to compute the embeddings.
+ - theta (float, optional): The frequency parameter for computing
+ the rotary embeddings. Defaults to 10000.0.
+
+ Raises:
+ AssertionError: If the dimension is not even.
+ """
+ super().__init__()
+
+ assert dim % 2 == 0, "Model dimension should be a multiple of two"
+
+ self.n_coord_pairs = dim // 2
+ self.rots = RotaryEmbeddings.get_embeddings(dim, max_ctx_len, theta)
+
+ @staticmethod
+ def compute_angles(dim: int, max_ctx_len: int, theta: float) -> torch.Tensor:
+ """
+ Compute the rotation angles for the embeddings.
+
+ Arguments:
+ dim (int): The dimension of the embeddings.
+ max_ctx_len (int): The maximum context length.
+ theta (float): The frequency parameter for the embeddings.
+
+ Returns:
+ torch.Tensor: A tensor of shape (max_ctx_len, dim // 2) containing the
+ rotation angles.
+ """
+
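+        # Frequencies follow RoFormer: freqs[k] = theta^(-2k / dim) for
+        # k = 0 .. dim/2 - 1, so low-index coordinate pairs rotate quickly with
+        # position m while high-index pairs rotate slowly.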
+ freqs = theta ** (-torch.arange(0, dim, 2, dtype=torch.float) / dim)
+
+ m = torch.arange(max_ctx_len)
+
+ angles = torch.outer(m, freqs)
+
+ return torch.polar(torch.ones((max_ctx_len, dim // 2)), angles)
+
+ @staticmethod
+ def get_embeddings(dim: int, max_ctx_len: int, theta: float) -> torch.Tensor:
+ """
+ Retrieve or compute and cache the rotary embeddings.
+
+ Args:
+ - dim (int): The dimension of the embeddings.
+ - max_ctx_len (int): The maximum context length.
+ - theta (float): The frequency parameter for the embeddings.
+
+ Returns:
+ - torch.Tensor: A tensor containing the precomputed embeddings.
+ """
+
+ cache = RotaryEmbeddings.embedding_cache
+
+ if dim not in cache:
+
+ cache[dim] = \
+ RotaryEmbeddings.compute_angles(dim, max_ctx_len, theta)
+
+ return cache[dim]
+
+ def forward(self, x: torch.Tensor, cur_pos: int = 0) -> torch.Tensor:
+ """
+ Apply the rotary embeddings to the input tensor.
+
+ Arguments:
+ - x (torch.Tensor): A tensor of shape (batch_size, ctx_len, ..., dim)
+ representing input features.
+ - cur_pos (int, optional): The current position index from which to
+ apply rotations. Defaults to 0.
+
+ Returns:
+ - torch.Tensor: The rotated tensor with the same shape as the input.
+ """
+
+ _batch_size, ctx_len, *dup_dims, dim = x.shape
+
+ rotated = x.view(*x.shape[:-1], self.n_coord_pairs, 2)
+ rotated = torch.view_as_complex(rotated.float())
+
+ broad_shape = [1, ctx_len] + [1] * len(dup_dims) + [ dim // 2 ]
+
+ rotated *= self.rots[cur_pos : cur_pos + ctx_len].view(*broad_shape)
+
+ rotated = torch.view_as_real(rotated)
+ rotated = rotated.view(*x.shape[:-1], dim).type_as(x)
+
+ return rotated
+
+def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+ mask: Optional[torch.Tensor] = None) \
+ -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Compute the scaled dot product attention.
+
+ This function takes as input the query (Q), key (K), value (V) tensors,
+ and an optional mask, and returns the attention output and attention
+ weights.
+
+ Arguments:
+ - q (torch.Tensor): The query tensor of shape (..., seq_len, d_k).
+ - k (torch.Tensor): The key tensor of shape (..., seq_len, d_k).
+ - v (torch.Tensor): The value tensor of shape (..., seq_len, d_v).
+ - mask (Optional[torch.Tensor]): An optional mask tensor to apply to
+ the scores before softmax.
+
+ Returns:
+ - Tuple[torch.Tensor, torch.Tensor]: A tuple consisting of the attention
+ output tensor and the attention weights tensor.
+
+ References:
+ - "Attention Is All You Need": https://arxiv.org/pdf/1706.03762.pdf
+ """
+
+ d_k = q.size(-1)
+
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
+
+ if mask is not None:
+ scores = scores.masked_fill(mask == 0, float("-inf"))
+
+ attn = F.softmax(scores, dim=-1)
+
+ return torch.matmul(attn, v), attn
+
+class LinearCache:
+ """
+ A simple linear-cache. This is used to cache the attention
+ keys and values.
+ """
+
+ def __init__(self, max_batch_size: int, max_context_len: int,
+ tensor_dims: Tuple, dtype: torch.dtype = torch.float):
+ """Initializes the LinearCache with given dimensions and data type."""
+
+ self.max_batch_size = max_batch_size
+ self.max_context_len = max_context_len
+ self.cache = torch.zeros(
+ (max_batch_size, max_context_len, *tensor_dims),
+ dtype=dtype,
+ )
+ self.cached_batch_size = 0
+
+ def get(self, pos: int) -> torch.Tensor:
+ """Retrieves the cached values up to a given sequence position."""
+ return self.cache[:self.cached_batch_size, :pos]
+
+ def set(self, current_pos: int, seq: torch.Tensor) -> None:
+ """Updates the cache with new sequences at the specified position."""
+
+ batch_size, ctx_len, *_ = seq.shape
+
+ self.cache[:batch_size, current_pos:current_pos+ctx_len] = seq
+
+ self.cached_batch_size = batch_size
+
+class GQA(nn.Module):
+ """
+ Group-Query Attention (GQA) module for transformer architectures.
+
+ References:
+ - See "GQA: Training Generalized Multi-Query Transformer Models from
+ Multi-Head Checkpoints" at https://arxiv.org/pdf/2305.13245.pdf
+ """
+
+ def __init__(self, dim: int, n_heads: int,
+ n_groups: Optional[int] = None,
+ query_embedding: Optional[nn.Module] = None,
+ key_embedding: Optional[nn.Module] = None,
+ apply_decoder_mask: bool = False,
+ kv_caches: Optional[Tuple[LinearCache, LinearCache]] = None,
+ dtype: torch.dtype = torch.float):
+ """
+ Initializes the Group-Query Attention (GQA) module.
+
+ Parameters:
+ dim (int): The dimensionality of the input features and the last dimension of
+ the output tensor.
+ n_heads (int): The number of attention heads to use.
+ n_groups (Optional[int]): The number of groups to divide the attention heads
+ into. If not specified, defaults to the number of heads.
+ Must divide `n_heads` evenly.
+ query_embedding (Optional[nn.Module]): An optional module to embed the query
+ vectors, e.g., a positional encoding module.
+ key_embedding (Optional[nn.Module]): An optional module to embed the key vectors,
+ similar to `query_embedding`.
+ apply_decoder_mask (bool): Whether to apply a causal mask to the attention mechanism,
+ useful for decoder self-attention.
+ kv_caches (Optional[Tuple[LinearCache, LinearCache]]): Optional tuple of
+ `LinearCache` instances for
+ caching key and value projections
+ in an autoregressive setting.
+ dtype (torch.dtype): The data type of the module's parameters, e.g., `torch.float32`.
+ The cache tensors should also use this data type.
+ """
+
+ n_groups = n_groups if n_groups else n_heads
+
+ assert dim % n_heads == 0, \
+ "Model dimension should be a multiple of n_heads"
+ assert n_heads % n_groups == 0, \
+ "n_heads should be a multiple of n_groups"
+
+ super().__init__()
+
+ head_dim = dim // n_heads
+
+ self.n_heads = n_heads
+ self.n_groups = n_groups
+ self.head_dim = head_dim
+ self.apply_decoder_mask = apply_decoder_mask
+
+ self.query_embedding = query_embedding
+ self.key_embedding = key_embedding
+
+ self.wq = nn.Linear(
+ dim,
+ n_heads * head_dim,
+ bias=False,
+ dtype=dtype,
+ )
+
+ self.wk = nn.Linear(
+ dim,
+ n_groups * head_dim,
+ bias=False,
+ dtype=dtype,
+ )
+
+ self.wv = nn.Linear(
+ dim,
+ n_groups * head_dim,
+ bias=False,
+ dtype=dtype,
+ )
+
+ self.wo = nn.Linear(
+ n_heads * head_dim,
+ dim,
+ bias=False,
+ dtype=dtype,
+ )
+
+ if kv_caches is not None:
+ self.key_cache = kv_caches[0]
+ self.value_cache = kv_caches[1]
+ self.has_cache = True
+ else:
+ self.has_cache = False
+
+ def forward(self, x: torch.Tensor, cur_pos: int):
+ """
+ Processes the input tensor with Group-Query Attention.
+
+ Arguments:
+ - x (torch.Tensor): The input tensor of shape
+ (batch_size, context_length, dim).
+ - cur_pos (int): The current position in the sequence for which
+ to compute attention. This is relevant when using key-value caches,
+ as it determines the part of the cache to update and utilize.
+
+ Returns:
+ - torch.Tensor: The output tensor after applying Group-Query Attention.
+ """
+
+ batch_size, ctx_len, dim = x.shape
+
+ # Perform key, query, and value projections
+
+ # wq(x) performs all n_heads projections at once, then the result
+ # is reshaped such that the first head_dim results are part of the first
+ # head, the second head_dim results are part of the second head, and so
+ # on.
+ q = self.wq(x).view(batch_size, ctx_len, self.n_heads, self.head_dim)
+ k = self.wk(x).view(batch_size, ctx_len, self.n_groups, self.head_dim)
+ v = self.wv(x).view(batch_size, ctx_len, self.n_groups, self.head_dim)
+
+ # Apply embeddings to the key and query matrices
+ if self.query_embedding:
+ q = self.query_embedding(q, cur_pos)
+
+ if self.key_embedding:
+ k = self.key_embedding(k, cur_pos)
+
+ if self.has_cache:
+ # Add the new embeddings to the cache
+ self.key_cache.set(cur_pos, k)
+ self.value_cache.set(cur_pos, v)
+
+ # Get all the previous embedding from the cache.
+
+ # Note if cur_pos != 0, ctx_len is the length of
+ # the new sequence. In reality, the whole sequence
+ # is cur_pos + ctx_len and both cached results will
+ # be of size (batch_size, ctx_len + cur_pos, n_groups, head_dim)
+ k = self.key_cache.get(cur_pos + ctx_len)
+ v = self.value_cache.get(cur_pos + ctx_len)
+
+        # Avoid the copy when standard multi-head attention (MHA) is used,
+        # i.e. when n_groups == n_heads, as in the 7B and 13B models.
+ if self.n_groups != self.n_heads:
+
+ repeats = self.n_heads // self.n_groups
+
+ # Duplicate grouped attention heads:
+
+            # From: { G_0, G_1, ..., G_{k - 1} }
+            # To:   { G_0, ..., G_0, G_1, ..., G_1, ..., G_{k - 1}, ..., G_{k - 1} }
+ k = torch.repeat_interleave(k, dim=2, repeats=repeats)
+ v = torch.repeat_interleave(v, dim=2, repeats=repeats)
+
+ # Transpose to parallelize attention across heads during batched-matrix
+ # multiplication
+        q = q.transpose(1, 2) # (batch_size, n_heads, ctx_len, head_dim)
+        k = k.transpose(1, 2) # (batch_size, n_heads, cur_pos + ctx_len, head_dim) with caching
+        v = v.transpose(1, 2) # (batch_size, n_heads, cur_pos + ctx_len, head_dim) with caching
+
+ if self.apply_decoder_mask:
+ # Construct attention mask
+
+            # In the decoder architecture, the mask is a lower triangular matrix that
+            # prevents each token from attending to subsequent tokens. More concretely,
+            # for attention score (i, j), token i cannot attend to token j if j > i.
+
+ # When key-value caching is enabled, we are only computing the attention scores
+ # for the new sequence. Thus, the matrix of scores is of size (ctx_len, total_len)
+ # and the only masked entries are (i, j) for j > cached_len + i since row i really
+ # represents token cached_len + i.
+ mask = torch.hstack([
+ torch.ones((ctx_len, cur_pos)),
+ torch.tril(torch.ones((ctx_len, ctx_len))),
+ ])
+ else:
+ mask = None
+
+ # Perform attention
+ x, _ = attention(q, k, v, mask)
+
+ # Concatenate heads
+ x = x.transpose(1, 2) # (batch_size, ctx_len, n_heads, head_dim)
+ x = x.reshape((batch_size, ctx_len, dim))
+
+ # Final linear layer
+ x = self.wo(x)
+
+ return x
+
+class LlamaTransformerLayer(nn.Module):
+ """
+ This constitutes a single transformer block within Meta's Llama architecture.
+
+ The transformer architecture combines group-query attention (GQA) and key-value caching.
+
+    It also uses RMSNorm to reduce internal covariate shift during training, and skip
+    connections, which make deep stacks easier to train.
+ """
+
+ def __init__(self, dim: int, n_heads: int, n_groups: Optional[int], max_context_len: int,
+ max_batch_size: int, round_ff_to_multiple: int, eps: float = 1e-6):
+        """Initializes a layer of the Llama transformer."""
+
+ super().__init__()
+
+ head_dim = dim // n_heads
+
+ self.query_embedding = RotaryEmbeddings(head_dim, max_context_len)
+ self.key_embedding = RotaryEmbeddings(head_dim, max_context_len)
+
+ cache_size = n_groups if n_groups else n_heads
+
+ self.key_cache = LinearCache(
+ max_batch_size, max_context_len, (cache_size, head_dim), dtype=torch.bfloat16
+ )
+ self.value_cache = LinearCache(
+ max_batch_size, max_context_len, (cache_size, head_dim), dtype=torch.bfloat16
+ )
+
+ self.gqa = GQA(
+ dim, n_heads, n_groups,
+ query_embedding=self.query_embedding,
+ key_embedding=self.key_embedding,
+ kv_caches=(self.key_cache, self.value_cache),
+ dtype=torch.bfloat16,
+ apply_decoder_mask=True,
+ )
+
+        # It might have been better to specify the inner "hidden" feed-forward
+        # dimension directly as a hyperparameter. It seems that FAIR took this
+        # ratio from the SwiGLU paper (https://arxiv.org/pdf/2002.05202.pdf),
+        # which is slightly odd, since the ratio was originally introduced only
+        # to keep parameter counts comparable across different feed-forward
+        # configurations.
+ dim_ff = _round_up_to_multiple(4 * int(2 * dim / 3), round_ff_to_multiple)
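+        # For example, dim = 4096 with round_ff_to_multiple = 256 (the 7B
+        # configuration) gives 4 * int(2 * 4096 / 3) = 10920, rounded up to
+        # dim_ff = 11008, matching the 7B checkpoint's feed-forward width.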
+
+ self.feed_forward = SwiGLU(dim, dim_ff, dtype=torch.bfloat16)
+
+ self.attention_norm = RMSNorm(dim, eps, dtype=torch.bfloat16)
+ self.forward_norm = RMSNorm(dim, eps, dtype=torch.bfloat16)
+
+ def forward(self, x: torch.Tensor, cur_pos: int = 0) -> torch.Tensor:
+ """
+ Takes as an input the input embeddings or previous decoder output
+ and produces the output of this decoder
+ """
+
+ # RMS Norm
+ x_norm = self.attention_norm(x)
+ # GQA with a skip connection
+ # See ResNet at https://arxiv.org/pdf/1512.03385.pdf for skip connections
+ h = x + self.gqa(x_norm, cur_pos=cur_pos)
+ # RMS Norm
+ h_norm = self.forward_norm(h)
+ # SwiGLU feed-forward with a skip connection
+ h = h + self.feed_forward(h_norm)
+
+ return h
+
+class LlamaDecoderStack(nn.Module):
+ """
+ The decoder stack is a stack of n_layers of decoders.
+ """
+
+ def __init__(self, args: LlamaArgs):
+ """Initializes the decoder stack"""
+
+ super().__init__()
+
+ layer = LlamaTransformerLayer(
+ args.dim, args.n_heads, args.n_groups, args.max_ctx_len,
+ args.max_batch_size, args.multiple_of, args.norm_eps
+ )
+
+ self.decoders = nn.ModuleList([
+ copy.deepcopy(layer) for _ in range(args.n_layers)
+ ])
+
+ def forward(self, embedding: torch.Tensor, cur_pos: int = 0) -> torch.Tensor:
+        """Apply all decoders in order, returning the output of the last decoder"""
+
+ h = embedding
+
+ for decoder in self.decoders:
+ h = decoder(h, cur_pos)
+
+ return h
+
+class LlamaEmbeddings(nn.Module):
+ """
+ LlamaEmbeddings transform a tensor of token ids into embedding vectors of size dim
+ """
+
+ def __init__(self, vocab_size: int, dim: int, padding_idx: int):
+ """Initializes the LlamaEmbeddings"""
+
+ super().__init__()
+
+ self.vocab_size = vocab_size
+ self.dim = dim
+ self.padding_idx = padding_idx
+
+ self.embedding = nn.Embedding(self.vocab_size, self.dim, dtype=torch.bfloat16)
+
+ def forward(self, context: torch.Tensor) -> torch.Tensor:
+ """Retrieve the embeddings for a token sequence"""
+
+ # The original Llama implementation employs parallel embeddings. This
+ # implicitly produces zero embeddings for padding_idx = -1. This behavior
+ # is seemingly undefined and relies on implementation details within
+ # the parallel embeddings.
+
+ # Since nn.Embedding does not handle negative indices, we must manually
+ # zero out the padded parts of the context.
+ padding_mask = context == torch.tensor(self.padding_idx, dtype=torch.long)
+ context[padding_mask] = torch.tensor(0, dtype=torch.long)
+
+ embeddings = self.embedding(context)
+
+ embeddings[padding_mask] = torch.zeros((self.dim,), dtype=embeddings.dtype)
+
+ return embeddings
+
+class Llama(LLM):
+    """A class representing the Llama family of LLMs"""
+
+ def __init__(self, args : LlamaArgs):
+ """Initialize the Llama model"""
+
+ super().__init__()
+
+        self.context_length = args.max_ctx_len
+        self.max_batch_size = args.max_batch_size
+        self.vocab_size = args.vocab_size
+        self.padding_idx = args.padding_idx
+
+ self.embeddings = LlamaEmbeddings(
+ args.vocab_size, args.dim, args.padding_idx
+ )
+
+ self.decoder_stack = LlamaDecoderStack(args)
+
+ self.output_norm = RMSNorm(
+ args.dim, eps=args.norm_eps, dtype=torch.bfloat16
+ )
+
+ self.vocab_map = nn.Linear(
+ args.dim, args.vocab_size, bias=False, dtype=torch.bfloat16
+ )
+
+ def forward(self, context: torch.Tensor, cur_pos: int = 0) -> torch.Tensor:
+ """
+ Computes the log probabilities of the next token given a sequence of
+ tokens as context.
+
+ Args:
+ context (torch.Tensor): A tensor of shape (batch_size, context_length)
+ containing token ids. These tokens serve as the
+ context for predicting the next token.
+
+ cur_pos (int, optional): The position at which to start the
+ prediction. If cur_pos is not zero,
+ the internal cache (if available) will
+ be used to speed up predictions.
+ Defaults to 0.
+
+ Returns:
+ torch.Tensor: A tensor of shape (batch_size, vocab_size) containing
+ the log probabilities of the next token given the
+ context.
+
+ Examples:
+ # Predict the next token for a sequence [1, 2, 3]
+ log_probs = llm(torch.tensor([[1, 2, 3]], dtype=torch.long), 0)
+
+ # Predict the next token for a sequence [1, 2, 3, 4, 5] using the
+ # cache starting at position 3
+ log_probs = llm(torch.tensor([[4, 5]], dtype=torch.long), 3)
+ """
+
+ embeddings = self.embeddings(context) # ( n_batches, n_embeddings, dim )
+
+ h = self.decoder_stack(embeddings, cur_pos)
+
+ h = self.output_norm(h)
+
+ vocab_logits = self.vocab_map(h)
+
+ return vocab_logits[:,-1]
diff --git a/llama/tokenizer.py b/llama/tokenizer.py
new file mode 100644
index 0000000..937a0b8
--- /dev/null
+++ b/llama/tokenizer.py
@@ -0,0 +1,89 @@
+"""
+Llama Tokenizer
+===============
+This module contains the Tokenizer class that wraps the SentencePiece tokenizer.
+"""
+
+from typing import List
+from sentencepiece import SentencePieceProcessor # type: ignore
+
+class Tokenizer:
+ """
+ Llama Tokenizer Class
+ ---------------------
+ This class provides a wrapper around the SentencePiece tokenizer.
+ It adds some utility functions for easier encoding and decoding.
+
+ Attributes:
+ bos_id (int): The id representing the "beginning of sentence" token.
+ eos_id (int): The id representing the "end of sentence" token.
+ pad_id (int): The id representing the padding token.
+ vocab_size (int): The size of the vocabulary.
+ """
+
+ def __init__(self, model_path: str):
+ """
+ Initialize the Tokenizer.
+
+ Args:
+ model_path (str): The path to the SentencePiece model file.
+
+ Returns:
+ None
+ """
+ sp = SentencePieceProcessor(model_file=model_path)
+
+ self.bos_id: int = sp.bos_id()
+ self.eos_id: int = sp.eos_id()
+ self.pad_id: int = sp.pad_id()
+ self.vocab_size: int = sp.vocab_size()
+
+ self.sp = sp
+
+ def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
+ """
+ Encode a string as a sequence of token IDs.
+
+ Args:
+ s (str): The string to be encoded.
+ bos (bool, optional): Whether to add a "beginning of sentence" token. Defaults to False.
+ eos (bool, optional): Whether to add an "end of sentence" token. Defaults to False.
+
+ Returns:
+ List[int]: The list of token IDs.
+ """
+ tokens = []
+
+ if bos:
+ tokens.append(self.bos_id)
+
+ tokens.extend(self.sp.encode(s))
+
+ if eos:
+ tokens.append(self.eos_id)
+
+ return tokens
+
+ def decode(self, tokens: List[int]) -> str:
+ """
+ Decode a sequence of token IDs to a string.
+
+ Args:
+ tokens (List[int]): The list of token IDs to be decoded.
+
+ Returns:
+ str: The decoded string.
+ """
+ return self.sp.decode(tokens)
+
+ def id_to_piece(self, token: int) -> str:
+ """
+ Convert a token ID to its corresponding token string.
+
+ Args:
+ token (int): The token ID.
+
+ Returns:
+ str: The token string, with SentencePiece's '▁' character replaced by a space.
+ """
+ return self.sp.id_to_piece(token).replace('▁', ' ')
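+
+# Example round trip (illustrative; assumes a SentencePiece model file on disk):
+#
+#     tok = Tokenizer('./tokenizer.model')
+#     ids = tok.encode('Hello, world!', bos=True)
+#     text = tok.decode(ids)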
diff --git a/llama/utils.py b/llama/utils.py
new file mode 100644
index 0000000..5c2af3c
--- /dev/null
+++ b/llama/utils.py
@@ -0,0 +1,153 @@
+"""
+Utilities for loading the Llama model and tokenizer from a checkpoint directory
+and a tokenizer model file.
+"""
+
+import json
+import re
+
+from pathlib import Path
+from typing import Dict, Any, Tuple
+
+import torch
+
+from .model import LlamaArgs, Llama
+from .tokenizer import Tokenizer
+
+ModuleParams = Dict[str, Any]
+
+def _load_model_from_checkpoint(checkpoint_dir: str) \
+ -> Tuple[LlamaArgs, ModuleParams]:
+ """
+ Load the Llama model from a given checkpoint directory.
+
+ Args:
+ checkpoint_dir (str): The path to the directory containing the Llama
+ model checkpoint.
+
+ Returns:
+ Tuple[LlamaArgs, ModuleParams]: A tuple containing:
+ - LlamaArgs: Arguments used for initializing the Llama model.
+ - ModuleParams: PyTorch state dictionary for the Llama model.
+ """
+
+ checkpoint_path = Path(checkpoint_dir)
+ with open(checkpoint_path / "params.json", "r", encoding='utf-8') as f:
+ args = json.loads(f.read())
+
+ args = LlamaArgs(**args)
+
+ checkpoint_paths = list(checkpoint_path.glob('*.pth'))
+ checkpoint_paths = sorted(checkpoint_paths)
+
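+    # Only the first checkpoint shard is loaded here, which assumes a
+    # single-shard checkpoint (as distributed for the 7B model); multi-shard
+    # checkpoints from the larger models are not merged by this loader.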
+ checkpoint = torch.load(checkpoint_paths[0], map_location="cpu")
+
+ return args, checkpoint
+
+# pylint: disable=locally-disabled, R0912
+def _transform_params(params: ModuleParams) -> ModuleParams:
+ """
+ Map the state dictionary keys from the official Llama model to the keys
+ used in this implementation.
+
+ Args:
+ params (ModuleParams): The state dictionary from the official Llama
+ model.
+
+ Returns:
+ ModuleParams: The modified state dictionary to match the keys used in
+ this implementation.
+ """
+
+ new_params = {}
+
+ for label, param in params.items():
+
+ if label == 'tok_embeddings.weight':
+ label = 'embeddings.embedding.weight'
+ elif label == 'norm.weight':
+ label = 'output_norm.gain'
+ elif label == 'output.weight':
+ label = 'vocab_map.weight'
+ else:
+
+ if label in { 'rope.freqs' }:
+ continue
+
+ regex = re.compile(r'layers\.(\d+)\.(.*)')
+
+ m = regex.match(label)
+
+ assert m is not None
+
+ layer_num = m.group(1)
+ sub_label = m.group(2)
+
+ label = f'decoder_stack.decoders.{layer_num}.'
+
+ if sub_label == 'attention.wq.weight':
+ label += 'gqa.wq.weight'
+ elif sub_label == 'attention.wk.weight':
+ label += 'gqa.wk.weight'
+ elif sub_label == 'attention.wv.weight':
+ label += 'gqa.wv.weight'
+ elif sub_label == 'attention.wo.weight':
+ label += 'gqa.wo.weight'
+ elif sub_label == 'feed_forward.w1.weight':
+ label += 'feed_forward.w.weight'
+ elif sub_label == 'feed_forward.w2.weight':
+ label += 'feed_forward.w2.weight'
+ elif sub_label == 'feed_forward.w3.weight':
+ label += 'feed_forward.v.weight'
+ elif sub_label == 'attention_norm.weight':
+ label += 'attention_norm.gain'
+ elif sub_label == 'ffn_norm.weight':
+ label += 'forward_norm.gain'
+ else:
+                assert False, f"Unexpected parameter key: {sub_label}"
+
+ new_params[label] = param
+
+ return new_params
+
+def load_llama_from_checkpoint(checkpoint_dir: str, tokenizer_path: str,
+ max_batch_size: int = 1, max_context_len: int = 2048) \
+ -> Tuple[Llama, Tokenizer]:
+ """
+ Load the Llama model and the tokenizer from specified paths.
+
+ Args:
+ checkpoint_dir (str): Path to the directory containing the model
+ checkpoint.
+ tokenizer_path (str): Path to the tokenizer model file.
+ max_batch_size (int, optional): Maximum batch size the model can accept.
+ Affects cache size. Default is 1.
+ max_context_len (int, optional): Maximum context length the model can
+ handle. Affects cache size. Default is
+ 2048.
+
+ Returns:
+ Tuple[Llama, Tokenizer]: A tuple containing:
+ - Llama: The Llama model loaded from the checkpoint.
+ - Tokenizer: The tokenizer model loaded from the given path.
+ """
+
+ args, checkpoint = _load_model_from_checkpoint(checkpoint_dir)
+
+ tokenizer = Tokenizer(tokenizer_path)
+
+ args.vocab_size = tokenizer.vocab_size
+ args.max_batch_size = max_batch_size
+ args.max_ctx_len = max_context_len
+ args.padding_idx = tokenizer.pad_id
+
+ checkpoint = _transform_params(checkpoint)
+
+ llama_model = Llama(args)
+ llama_model.load_state_dict(checkpoint)
+
+ return llama_model, tokenizer
+
+__all__ = [
+ 'load_llama_from_checkpoint'
+]