Diffstat (limited to 'llama/utils.py')
-rw-r--r--  llama/utils.py  153
1 file changed, 153 insertions(+), 0 deletions(-)
diff --git a/llama/utils.py b/llama/utils.py
new file mode 100644
index 0000000..5c2af3c
--- /dev/null
+++ b/llama/utils.py
@@ -0,0 +1,153 @@
+"""
+Utilities for loading the Llama model and tokenizer from a checkpoint directory
+and a tokenizer model file.
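+
+Example usage (the paths below are placeholders for an actual checkpoint
+directory and tokenizer model file):
+
+    from llama.utils import load_llama_from_checkpoint
+
+    model, tokenizer = load_llama_from_checkpoint(
+        "checkpoints/llama-7b", "checkpoints/tokenizer.model")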
+"""
+
+import json
+import re
+
+from pathlib import Path
+from typing import Dict, Any, Tuple
+
+import torch
+
+from .model import LlamaArgs, Llama
+from .tokenizer import Tokenizer
+
+ModuleParams = Dict[str, Any]
+
+def _load_model_from_checkpoint(checkpoint_dir: str) \
+ -> Tuple[LlamaArgs, ModuleParams]:
+ """
+ Load the Llama model from a given checkpoint directory.
+
+ Args:
+ checkpoint_dir (str): The path to the directory containing the Llama
+ model checkpoint.
+
+ Returns:
+ Tuple[LlamaArgs, ModuleParams]: A tuple containing:
+ - LlamaArgs: Arguments used for initializing the Llama model.
+ - ModuleParams: PyTorch state dictionary for the Llama model.
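+
+    The checkpoint directory is expected to contain a params.json file and
+    at least one *.pth file; only the first *.pth file (in sorted order) is
+    loaded.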
+ """
+
+ checkpoint_path = Path(checkpoint_dir)
+ with open(checkpoint_path / "params.json", "r", encoding='utf-8') as f:
+ args = json.loads(f.read())
+
+ args = LlamaArgs(**args)
+
+ checkpoint_paths = list(checkpoint_path.glob('*.pth'))
+ checkpoint_paths = sorted(checkpoint_paths)
+
+ checkpoint = torch.load(checkpoint_paths[0], map_location="cpu")
+
+ return args, checkpoint
+
+# pylint: disable=locally-disabled, R0912
+def _transform_params(params: ModuleParams) -> ModuleParams:
+ """
+ Map the state dictionary keys from the official Llama model to the keys
+ used in this implementation.
+
+ Args:
+ params (ModuleParams): The state dictionary from the official Llama
+ model.
+
+ Returns:
+ ModuleParams: The modified state dictionary to match the keys used in
+ this implementation.
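+
+    Example:
+        'layers.0.attention.wq.weight' becomes
+        'decoder_stack.decoders.0.gqa.wq.weight', while 'rope.freqs' is
+        dropped.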
+ """
+
+ new_params = {}
+
+ for label, param in params.items():
+
+ if label == 'tok_embeddings.weight':
+ label = 'embeddings.embedding.weight'
+ elif label == 'norm.weight':
+ label = 'output_norm.gain'
+ elif label == 'output.weight':
+ label = 'vocab_map.weight'
+ else:
+
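+            # The reference checkpoint carries a 'rope.freqs' entry
+            # (precomputed rotary-embedding frequencies) that this
+            # implementation does not use.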
+            if label in {'rope.freqs'}:
+ continue
+
+ regex = re.compile(r'layers\.(\d+)\.(.*)')
+
+ m = regex.match(label)
+
+            assert m is not None, f"Unexpected key: {label}"
+
+ layer_num = m.group(1)
+ sub_label = m.group(2)
+
+ label = f'decoder_stack.decoders.{layer_num}.'
+
+ if sub_label == 'attention.wq.weight':
+ label += 'gqa.wq.weight'
+ elif sub_label == 'attention.wk.weight':
+ label += 'gqa.wk.weight'
+ elif sub_label == 'attention.wv.weight':
+ label += 'gqa.wv.weight'
+ elif sub_label == 'attention.wo.weight':
+ label += 'gqa.wo.weight'
+ elif sub_label == 'feed_forward.w1.weight':
+ label += 'feed_forward.w.weight'
+ elif sub_label == 'feed_forward.w2.weight':
+ label += 'feed_forward.w2.weight'
+ elif sub_label == 'feed_forward.w3.weight':
+ label += 'feed_forward.v.weight'
+ elif sub_label == 'attention_norm.weight':
+ label += 'attention_norm.gain'
+ elif sub_label == 'ffn_norm.weight':
+ label += 'forward_norm.gain'
+ else:
+                assert False, f"Unexpected key suffix: {sub_label}"
+
+ new_params[label] = param
+
+ return new_params
+
+def load_llama_from_checkpoint(checkpoint_dir: str, tokenizer_path: str,
+ max_batch_size: int = 1, max_context_len: int = 2048) \
+ -> Tuple[Llama, Tokenizer]:
+ """
+ Load the Llama model and the tokenizer from specified paths.
+
+ Args:
+ checkpoint_dir (str): Path to the directory containing the model
+ checkpoint.
+ tokenizer_path (str): Path to the tokenizer model file.
+ max_batch_size (int, optional): Maximum batch size the model can accept.
+ Affects cache size. Default is 1.
+ max_context_len (int, optional): Maximum context length the model can
+ handle. Affects cache size. Default is
+ 2048.
+
+ Returns:
+ Tuple[Llama, Tokenizer]: A tuple containing:
+ - Llama: The Llama model loaded from the checkpoint.
+ - Tokenizer: The tokenizer model loaded from the given path.
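+
+    Note:
+        The vocabulary size and padding index of the model are taken from
+        the tokenizer rather than from params.json.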
+ """
+
+ args, checkpoint = _load_model_from_checkpoint(checkpoint_dir)
+
+ tokenizer = Tokenizer(tokenizer_path)
+
+ args.vocab_size = tokenizer.vocab_size
+ args.max_batch_size = max_batch_size
+ args.max_ctx_len = max_context_len
+ args.padding_idx = tokenizer.pad_id
+
+ checkpoint = _transform_params(checkpoint)
+
+ llama_model = Llama(args)
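+    # load_state_dict is strict by default: it raises if any key is missing
+    # or unexpected after the renaming done by _transform_params.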
+ llama_model.load_state_dict(checkpoint)
+
+ return llama_model, tokenizer
+
+__all__ = [
+ 'load_llama_from_checkpoint'
+]