path: root/llama/utils.py
"""
Utilities for loading the Llama model and tokenizer from a checkpoint directory
and a tokenizer model file.
"""

import json
import re

from pathlib import Path
from typing import Dict, Any, Tuple

import torch

from .model import LlamaArgs, Llama
from .tokenizer import Tokenizer

ModuleParams = Dict[str, Any]

def _load_model_from_checkpoint(checkpoint_dir: str) \
        -> Tuple[LlamaArgs, ModuleParams]:
    """
    Load the Llama model from a given checkpoint directory.

    Args:
        checkpoint_dir (str): The path to the directory containing the Llama
        model checkpoint.

    Returns:
        Tuple[LlamaArgs, ModuleParams]: A tuple containing:
            - LlamaArgs: Arguments used for initializing the Llama model.
            - ModuleParams: PyTorch state dictionary for the Llama model.
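
    Note:
        The directory is expected to look roughly like this (the shard name
        is illustrative; any '*.pth' files are picked up):

            <checkpoint_dir>/
                params.json
                consolidated.00.pth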
    """

    checkpoint_path = Path(checkpoint_dir)
    with open(checkpoint_path / "params.json", "r", encoding='utf-8') as f:
        args = json.load(f)

    args = LlamaArgs(**args)

    checkpoint_paths = sorted(checkpoint_path.glob('*.pth'))

    # Only the first checkpoint shard is loaded; multi-shard checkpoints
    # are not merged here.
    checkpoint = torch.load(checkpoint_paths[0], map_location="cpu")

    return args, checkpoint

# pylint: disable=locally-disabled, R0912
def _transform_params(params: ModuleParams) -> ModuleParams:
    """
    Map the state dictionary keys from the official Llama model to the keys
    used in this implementation.

    Args:
        params (ModuleParams): The state dictionary from the official Llama
        model.

    Returns:
        ModuleParams: The modified state dictionary to match the keys used in
        this implementation.
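
    Example renames (the layer index is illustrative; see the mapping below):
        'tok_embeddings.weight'        -> 'embeddings.embedding.weight'
        'layers.0.attention.wq.weight' -> 'decoder_stack.decoders.0.gqa.wq.weight'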
    """

    new_params = {}

    for label, param in params.items():

        # Top-level (non-layer) parameters are renamed one-to-one.
        if label == 'tok_embeddings.weight':
            label = 'embeddings.embedding.weight'
        elif label == 'norm.weight':
            label = 'output_norm.gain'
        elif label == 'output.weight':
            label = 'vocab_map.weight'
        else:

            # 'rope.freqs' has no counterpart in this implementation and is
            # dropped.
            if label in { 'rope.freqs' }:
                continue

            # Remaining keys are per-layer parameters of the form
            # 'layers.<index>.<suffix>'.
            regex = re.compile(r'layers\.(\d+)\.(.*)')

            m = regex.match(label)

            assert m is not None

            layer_num = m.group(1)
            sub_label = m.group(2)

            label = f'decoder_stack.decoders.{layer_num}.'

            if sub_label == 'attention.wq.weight':
                label += 'gqa.wq.weight'
            elif sub_label == 'attention.wk.weight':
                label += 'gqa.wk.weight'
            elif sub_label == 'attention.wv.weight':
                label += 'gqa.wv.weight'
            elif sub_label == 'attention.wo.weight':
                label += 'gqa.wo.weight'
            elif sub_label == 'feed_forward.w1.weight':
                label += 'feed_forward.w.weight'
            elif sub_label == 'feed_forward.w2.weight':
                label += 'feed_forward.w2.weight'
            elif sub_label == 'feed_forward.w3.weight':
                label += 'feed_forward.v.weight'
            elif sub_label == 'attention_norm.weight':
                label += 'attention_norm.gain'
            elif sub_label == 'ffn_norm.weight':
                label += 'forward_norm.gain'
            else:
                assert False, "Key not found"

        new_params[label] = param

    return new_params

def load_llama_from_checkpoint(checkpoint_dir: str, tokenizer_path: str,
                               max_batch_size: int = 1, max_context_len: int = 2048) \
        -> Tuple[Llama, Tokenizer]:
    """
    Load the Llama model and the tokenizer from specified paths.

    Args:
        checkpoint_dir (str): Path to the directory containing the model
                              checkpoint.
        tokenizer_path (str): Path to the tokenizer model file.
        max_batch_size (int, optional): Maximum batch size the model can accept.
                                        Affects cache size. Default is 1.
        max_context_len (int, optional): Maximum context length the model can
                                        handle. Affects cache size. Default is
                                        2048.

    Returns:
        Tuple[Llama, Tokenizer]: A tuple containing:
            - Llama: The Llama model loaded from the checkpoint.
            - Tokenizer: The tokenizer model loaded from the given path.
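
    Example:
        A minimal sketch; the paths are illustrative and must point to a
        local checkpoint directory and tokenizer model file:

            model, tokenizer = load_llama_from_checkpoint(
                "checkpoints/llama-7b", "checkpoints/tokenizer.model")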
    """

    args, checkpoint = _load_model_from_checkpoint(checkpoint_dir)

    tokenizer = Tokenizer(tokenizer_path)

    args.vocab_size = tokenizer.vocab_size
    args.max_batch_size = max_batch_size
    args.max_ctx_len = max_context_len
    args.padding_idx = tokenizer.pad_id

    checkpoint = _transform_params(checkpoint)

    llama_model = Llama(args)
    llama_model.load_state_dict(checkpoint)

    return llama_model, tokenizer

__all__ = [
    'load_llama_from_checkpoint'
]