aboutsummaryrefslogtreecommitdiff
path: root/llama/generate.py
blob: e735ca404881ebf7d0da18bbc0791498a99ed671 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
"""
This module provides an interface for generating and sampling token sequences from a language model.
It allows for the controlled generation of text by specifying parameters such as temperature, top-k,
and top-p, which influence the randomness and quality of the generated sequences.
"""

# pylint: disable=locally-disabled, R0913, R0914

from collections import OrderedDict
from typing import List, Optional, Generator, cast

import torch

from .llm import LLM

# Next-token distribution for one sequence: maps token id -> probability.
# batched_token_probabilities inserts entries in descending order of
# probability, so iterating the dict visits the most likely tokens first.
TokenList = OrderedDict[int, float]

def _sample_internal(llm: LLM, context: torch.Tensor) -> torch.Tensor:
    """
    Run ``llm`` once over a 2-D batch of token ids and return its raw output.

    ``context`` is unpacked as ``(batch_size, seq_len)``. The sequence length
    must not exceed ``llm.context_length`` and the batch size must not exceed
    ``llm.max_batch_size``; both limits are checked with ``assert``. The
    forward pass runs under ``torch.inference_mode()`` so no autograd state
    is recorded.
    """
    rows, width = context.shape

    assert width <= llm.context_length
    assert rows <= llm.max_batch_size

    with torch.inference_mode():
        output = llm(context)
    return output

def _load_context(tokens: List[List[int]], pad_id: int,
                  pad_to_length: Optional[int] = None) -> torch.Tensor:
    """
    Load a batch of token lists into a padded tensor suitable for input to a language model.
    """
    batch_size = len(tokens)

    max_token_len = max((len(tok) for tok in tokens))

    pad_to_length = max_token_len if pad_to_length is None else pad_to_length

    context = torch.full(
        (batch_size, pad_to_length), pad_id, dtype=torch.long
    )

    for dim, toks in enumerate(tokens):
        context[dim, :len(toks)] = torch.tensor(toks, dtype=torch.long)

    return context

def batched_token_probabilities(llm: LLM,
                                tokens: List[List[int]],
                                temperature: float = 1.0) -> List[TokenList]:
    """
    Calculate the next-token probability distribution for each sequence in a batch.

    Args:
    - llm (LLM): An instance of the language model.
    - tokens (List[List[int]]): A list of tokenized input sequences.
    - temperature (float): A temperature parameter to scale the logits before
                           applying softmax. Default is 1.0.

    Returns:
    - List[TokenList]: One ordered dictionary per input sequence, in input
                       order. Each maps token ids to probabilities, inserted
                       in descending order of probability. An empty batch
                       returns an empty list.

    The model output is treated as one row of scores per sequence. This code
    assumes that row comes from the sequence's final input position, which is
    also how generate_batched_token_sequence uses it (TODO confirm against
    LLM). Right-padding a shorter sequence would make that position a padding
    token. To avoid this, sequences are grouped by length and each group runs
    in its own forward pass, so no sequence in a group is padded.
    """
    samples: List[TokenList] = [OrderedDict() for _ in range(len(tokens))]

    # Batch indices grouped by sequence length (insertion-ordered).
    groups: dict[int, List[int]] = {}
    for index, seq in enumerate(tokens):
        groups.setdefault(len(seq), []).append(index)

    for indices in groups.values():
        context = _load_context([tokens[i] for i in indices], llm.padding_idx)

        token_logprobs = _sample_internal(llm, context)

        token_probs = torch.softmax(token_logprobs / temperature, dim=-1)

        for index, probs in zip(indices, token_probs):
            # One sort plus one bulk .tolist() conversion, instead of indexing
            # the tensor once per vocabulary entry from Python.
            sorted_probs, sorted_ids = torch.sort(probs, descending=True)
            samples[index] = OrderedDict(
                zip(sorted_ids.tolist(), sorted_probs.tolist())
            )

    return samples

def token_probabilities(llm: LLM, tokens: List[int],
                        temperature: float = 1.0) -> TokenList:
    """
    Calculate the next-token probability distribution for a single sequence.

    Args:
    - llm (LLM): An instance of the language model.
    - tokens (List[int]): A tokenized input sequence.
    - temperature (float): Scales the logits before softmax. Forwarded to
                           batched_token_probabilities. Default is 1.0.

    Returns:
    - TokenList: Token ids mapped to probabilities, in descending order of
                 probability.

    See batched_token_probabilities.
    """

    return batched_token_probabilities(llm, [tokens], temperature=temperature)[0]

def sample_batched_token(
    token_logprobs: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    sample_eps: float = 1e-6) -> torch.Tensor:
    """
    Sample one token per batch row from a tensor of logits, with optional
    top-k and top-p (nucleus) filtering.

    Args:
    - token_logprobs (torch.Tensor): A ``(batch, vocab)`` tensor of token
      scores. They are divided by ``temperature`` and passed through softmax.
    - temperature (float): A scaling factor for logits before sampling. A value
      of exactly 0 selects the highest-scoring token of each row (greedy
      decoding) instead of dividing by zero. Default is 1.0.
    - top_k (Optional[int]): If set, sampling is restricted to at most the
      ``top_k`` most likely tokens. Must be at least 1. Default is None
      (no restriction).
    - top_p (float): Sampling is restricted to the smallest set of most likely
      tokens whose cumulative probability reaches ``top_p``. The single most
      likely token is always kept. Default is 1.0 (no restriction).
    - sample_eps (float): Tolerance added to ``top_p`` so that rounding in the
      cumulative sum does not drop tokens (e.g. when ``top_p`` is 1.0).
      Default is 1e-6.

    Returns:
    - torch.Tensor: A 1-D ``torch.long`` CPU tensor holding the sampled token
      id for each item in the batch.

    Raises:
    - ValueError: If ``top_k`` is given and is less than 1.

    Implements top-k and top-p sampling. When both are given, the nucleus is
    computed first and then truncated to its ``top_k`` most likely tokens.

    See:
      - https://arxiv.org/pdf/1805.04833.pdf for top-k sampling
      - https://arxiv.org/pdf/1904.09751.pdf for top-p sampling
    """
    if top_k is not None and top_k < 1:
        raise ValueError("top_k must be at least 1")

    if temperature == 0:
        # Greedy decoding is the limit of temperature -> 0. Dividing by zero
        # would instead produce NaN probabilities that multinomial rejects.
        return torch.argmax(token_logprobs, dim=-1).to(device="cpu", dtype=torch.long)

    batch_size = token_logprobs.shape[0]

    token_probs = torch.softmax(token_logprobs / temperature, dim=-1)

    # Results are collected on the CPU regardless of where the logits live.
    selected_tokens = torch.zeros(batch_size, dtype=torch.long)

    sorted_tokens = torch.argsort(token_probs, descending=True)
    sorted_probs = torch.gather(token_probs, 1, sorted_tokens)
    cumulative = sorted_probs.cumsum(dim=-1)

    # A token belongs to the nucleus when the mass of the tokens ranked above
    # it is still below top_p. This includes the token whose probability makes
    # the running total reach top_p, as in the nucleus-sampling definition.
    # Column 0 is always kept, so at least one candidate exists per row.
    nucleus_mask = torch.ones_like(sorted_probs, dtype=torch.bool)
    nucleus_mask[:, 1:] = cumulative[:, :-1] < top_p + sample_eps

    for i, (tokens, mask, probs) in enumerate(zip(sorted_tokens, nucleus_mask, sorted_probs)):
        nucleus = tokens[mask]
        p = probs[mask]

        # The nucleus is sorted by descending probability, so its prefix is
        # the top-k.
        if top_k is not None and len(nucleus) > top_k:
            nucleus = nucleus[:top_k]
            p = p[:top_k]

        p = p / p.sum()

        token = nucleus[torch.multinomial(p, 1)]

        selected_tokens[i] = token

    return selected_tokens

def generate_batched_token_sequence(llm: LLM,
                    prompts: List[List[int]],
                    max_generation_length: Optional[int] = None,
                    temperature: float = 1.0,
                    top_k: Optional[int] = None,
                    top_p: float = 1.0) -> Generator[List[Optional[int]], None, None]:
    """
    Generate a sequence of tokens for each prompt in a batch, one position at a time.

    Decoding starts at the length of the shortest prompt. At every position
    the not-yet-processed slice of the context, with its start position, is
    passed to ``llm``, and one token is sampled per row. Then:
      - rows whose prompt still covers the position keep their prompt token
        and report None;
      - other rows receive the sampled token until they have produced
        ``llm.eos_token`` or used up ``max_generation_length`` tokens. After
        that they report None.
    Iteration stops as soon as every row has finished.

    Args:
    - llm (LLM): An instance of the language model.
    - prompts (List[List[int]]): A list of tokenized input sequences (prompts).
      Each prompt must contain at least one token.
    - max_generation_length (Optional[int]): The maximum number of tokens to
      generate for each prompt. If None, it defaults to the context space left
      after the longest prompt (``llm.context_length - longest prompt``).
      Default is None.
    - temperature (float): A scaling factor for logits before sampling. Default
                           is 1.0.
    - top_k (Optional[int]): If set, restricts sampling to the top k most
                             likely tokens. Default is None (no restriction).
    - top_p (float): Restricts sampling to the most likely tokens whose
      cumulative probability reaches top_p. Default is 1.0 (no restriction).

    Yields:
    - List[Optional[int]]: One entry per prompt. The entry is the token
      generated for that prompt at this position, or None if no token was
      generated for it during this iteration. A generated EOS token is
      yielded before that prompt stops. An empty ``prompts`` list yields
      nothing.

    Raises:
    - AssertionError: If the batch size exceeds the maximum allowed by the
      LLM, if a prompt is empty, or if the requested generation length is
      too long.
    """

    batch_size = len(prompts)
    assert batch_size <= llm.max_batch_size, \
        "Batch size exceeded the maximum batch size of the LLM"
    if batch_size == 0:
        return

    prompt_lens = [len(p) for p in prompts]
    min_prompt_len = min(prompt_lens)
    max_prompt_len = max(prompt_lens)
    # Decoding starts at min_prompt_len, and the first forward pass needs at
    # least one real token.
    assert min_prompt_len > 0, "Every prompt must contain at least one token"

    remaining_context = llm.context_length - max_prompt_len
    if max_generation_length is None:
        max_generation_length = remaining_context
    else:
        assert max_generation_length <= remaining_context, \
            "Cannot generate more tokens than exist in the context"
    if max_generation_length <= 0:
        return

    end_pos = max_prompt_len + max_generation_length
    context = _load_context(prompts, llm.padding_idx, pad_to_length=end_pos)

    prompt_len_t = torch.tensor(prompt_lens, dtype=torch.long)
    # First position past each prompt's generation budget.
    generation_end = prompt_len_t + max_generation_length
    # Position just after each row's EOS token; 0 means EOS not seen yet.
    eos = torch.zeros(batch_size, dtype=torch.long)
    last_pos = 0

    for pos in range(min_prompt_len, end_pos):
        # Gradients are never needed while sampling. The mode is scoped to the
        # call so it is not left enabled while this generator is suspended.
        with torch.no_grad():
            log_probs = llm(context[:, last_pos:pos], last_pos)

        sampled = sample_batched_token(
            log_probs,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )

        # A row takes the sampled token only after its prompt is consumed,
        # while it still has budget, and before it has emitted EOS.
        generating = (pos >= prompt_len_t) & (pos < generation_end)
        should_replace_mask = (eos == 0) & generating

        yield [int(sampled[i]) if should_replace_mask[i] else None for i in range(batch_size)]

        context[should_replace_mask, pos] = sampled[should_replace_mask]
        eos[should_replace_mask & (sampled == llm.eos_token)] = pos + 1
        last_pos = pos

        if ((eos > 0) | (generation_end <= pos + 1)).all():
            break

def generate_token_sequence(llm: LLM,
                            prompt: List[int],
                            max_generation_length: Optional[int] = None,
                            temperature: float = 1.0,
                            top_k: Optional[int] = None,
                            top_p: float = 1.0) -> Generator[int, None, None]:
    """
    Stream the tokens generated for a single prompt.

    Runs generate_batched_token_sequence on a batch containing only
    ``prompt`` and yields the single entry of each batched result.
    See generate_batched_token_sequence for the meaning of the arguments.
    """
    stream = generate_batched_token_sequence(
        llm,
        [prompt],
        max_generation_length=max_generation_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    for step in stream:
        # The batch has exactly one row.
        (token,) = step
        yield cast(int, token)

def sample_batched_sequence(llm: LLM,
                            prompts: List[List[int]],
                            max_generation_length: Optional[int] = None,
                            temperature: float = 1.0,
                            top_k: Optional[int] = None,
                            top_p: float = 1.0,
                            include_prompt: bool = False) -> List[List[int]]:
    """
    Run generation to completion for every prompt in a batch and collect the output.

    Args:
    - llm (LLM): An instance of the language model.
    - prompts (List[List[int]]): A list of tokenized input sequences (prompts).
    - max_generation_length (Optional[int]): Upper bound on generated tokens
      per prompt. Passed to generate_batched_token_sequence; None lets that
      function choose its default.
    - temperature (float): Scaling factor applied to the logits before
      sampling. Lower values give less random samples. Default is 1.0.
    - top_k (Optional[int]): If set, only the k most probable tokens are
      candidates. Default is None (no limit).
    - top_p (float): Cumulative-probability threshold for nucleus sampling.
      Default is 1.0 (no limit).
    - include_prompt (bool): When True, each output list starts with a copy of
      its prompt. Default is False.

    Returns:
    - List[List[int]]: One token list per prompt, in input order. Each holds
      every non-None token yielded for that prompt, preceded by the prompt
      itself when include_prompt is True.
    """
    sampled_seqs: List[List[int]] = [
        list(prompt) if include_prompt else [] for prompt in prompts
    ]

    token_stream = generate_batched_token_sequence(
        llm,
        prompts,
        max_generation_length=max_generation_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )

    for step in token_stream:
        for seq, token in zip(sampled_seqs, step):
            if token is None:
                continue
            seq.append(token)

    return sampled_seqs

def sample_sequence(llm: LLM,
                    prompts: List[int],
                    max_generation_length: Optional[int] = None,
                    temperature: float = 1.0,
                    top_k: Optional[int] = None,
                    top_p: float = 1.0,
                    include_prompt: bool = False) -> List[int]:
    """
    Run generation to completion for one tokenized prompt.

    Despite its name, ``prompts`` is a single token list. It is wrapped in a
    batch of one for sample_batched_sequence, and that batch's only result is
    returned. See sample_batched_sequence for the meaning of the arguments.
    """
    batch = sample_batched_sequence(
        llm,
        [prompts],
        max_generation_length=max_generation_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        include_prompt=include_prompt,
    )
    (only_sequence,) = batch
    return only_sequence

# Public API of this module. batched_token_probabilities is documented and
# referenced by token_probabilities, so it is exported alongside it.
__all__ = [
    'token_probabilities',
    'batched_token_probabilities',
    'sample_batched_token',
    'sample_sequence',
    'sample_batched_sequence',
    'generate_token_sequence',
    'generate_batched_token_sequence',
]