diff options
author | flu0r1ne <flu0r1ne@flu0r1ne.net> | 2023-11-01 20:46:01 -0500 |
---|---|---|
committer | flu0r1ne <flu0r1ne@flu0r1ne.net> | 2023-11-01 20:46:01 -0500 |
commit | af5a2996234768921b81d96ffaae00cb88229862 (patch) | |
tree | 5b2a688582652fc8080616ccc0de162198aa8ee0 /llama/utils.py | |
download | myllama2-main.tar.xz myllama2-main.zip |
Diffstat (limited to 'llama/utils.py')
-rw-r--r-- | llama/utils.py | 153 |
1 files changed, 153 insertions, 0 deletions
diff --git a/llama/utils.py b/llama/utils.py new file mode 100644 index 0000000..5c2af3c --- /dev/null +++ b/llama/utils.py @@ -0,0 +1,153 @@ +""" +Utilities for loading the Llama model and tokenizer from a checkpoint directory +and a tokenizer model file. +""" + +import json +import re + +from pathlib import Path +from typing import Dict, Any, Tuple + +import torch + +from .model import LlamaArgs, Llama +from .tokenizer import Tokenizer + +ModuleParams = Dict[str, Any] + +def _load_model_from_checkpoint(checkpoint_dir: str) \ + -> Tuple[LlamaArgs, ModuleParams]: + """ + Load the Llama model from a given checkpoint directory. + + Args: + checkpoint_dir (str): The path to the directory containing the Llama + model checkpoint. + + Returns: + Tuple[LlamaArgs, ModuleParams]: A tuple containing: + - LlamaArgs: Arguments used for initializing the Llama model. + - ModuleParams: PyTorch state dictionary for the Llama model. + """ + + checkpoint_path = Path(checkpoint_dir) + with open(checkpoint_path / "params.json", "r", encoding='utf-8') as f: + args = json.loads(f.read()) + + args = LlamaArgs(**args) + + checkpoint_paths = list(checkpoint_path.glob('*.pth')) + checkpoint_paths = sorted(checkpoint_paths) + + checkpoint = torch.load(checkpoint_paths[0], map_location="cpu") + + return args, checkpoint + +# pylint: disable=locally-disabled, R0912 +def _transform_params(params: ModuleParams) -> ModuleParams: + """ + Map the state dictionary keys from the official Llama model to the keys + used in this implementation. + + Args: + params (ModuleParams): The state dictionary from the official Llama + model. + + Returns: + ModuleParams: The modified state dictionary to match the keys used in + this implementation. + """ + + new_params = {} + + for label, param in params.items(): + + if label == 'tok_embeddings.weight': + label = 'embeddings.embedding.weight' + elif label == 'norm.weight': + label = 'output_norm.gain' + elif label == 'output.weight': + label = 'vocab_map.weight' + else: + + if label in { 'rope.freqs' }: + continue + + regex = re.compile(r'layers\.(\d+)\.(.*)') + + m = regex.match(label) + + assert m is not None + + layer_num = m.group(1) + sub_label = m.group(2) + + label = f'decoder_stack.decoders.{layer_num}.' + + if sub_label == 'attention.wq.weight': + label += 'gqa.wq.weight' + elif sub_label == 'attention.wk.weight': + label += 'gqa.wk.weight' + elif sub_label == 'attention.wv.weight': + label += 'gqa.wv.weight' + elif sub_label == 'attention.wo.weight': + label += 'gqa.wo.weight' + elif sub_label == 'feed_forward.w1.weight': + label += 'feed_forward.w.weight' + elif sub_label == 'feed_forward.w2.weight': + label += 'feed_forward.w2.weight' + elif sub_label == 'feed_forward.w3.weight': + label += 'feed_forward.v.weight' + elif sub_label == 'attention_norm.weight': + label += 'attention_norm.gain' + elif sub_label == 'ffn_norm.weight': + label += 'forward_norm.gain' + else: + assert False, "Key not found" + + new_params[label] = param + + return new_params + +def load_llama_from_checkpoint(checkpoint_dir: str, tokenizer_path: str, + max_batch_size: int = 1, max_context_len: int = 2048) \ + -> Tuple[Llama, Tokenizer]: + """ + Load the Llama model and the tokenizer from specified paths. + + Args: + checkpoint_dir (str): Path to the directory containing the model + checkpoint. + tokenizer_path (str): Path to the tokenizer model file. + max_batch_size (int, optional): Maximum batch size the model can accept. + Affects cache size. Default is 1. + max_context_len (int, optional): Maximum context length the model can + handle. Affects cache size. Default is + 2048. + + Returns: + Tuple[Llama, Tokenizer]: A tuple containing: + - Llama: The Llama model loaded from the checkpoint. + - Tokenizer: The tokenizer model loaded from the given path. + """ + + args, checkpoint = _load_model_from_checkpoint(checkpoint_dir) + + tokenizer = Tokenizer(tokenizer_path) + + args.vocab_size = tokenizer.vocab_size + args.max_batch_size = max_batch_size + args.max_ctx_len = max_context_len + args.padding_idx = tokenizer.pad_id + + checkpoint = _transform_params(checkpoint) + + llama_model = Llama(args) + llama_model.load_state_dict(checkpoint) + + return llama_model, tokenizer + +__all__ = [ + 'load_llama_from_checkpoint' +] |