author    | flu0r1ne <flu0r1ne@flu0r1ne.net> | 2023-11-01 20:46:01 -0500
committer | flu0r1ne <flu0r1ne@flu0r1ne.net> | 2023-11-01 20:46:01 -0500
commit    | af5a2996234768921b81d96ffaae00cb88229862 (patch)
tree      | 5b2a688582652fc8080616ccc0de162198aa8ee0 /llama/generate.py
Diffstat (limited to 'llama/generate.py')
-rw-r--r-- | llama/generate.py | 322
1 files changed, 322 insertions, 0 deletions
diff --git a/llama/generate.py b/llama/generate.py
new file mode 100644
index 0000000..e735ca4
--- /dev/null
+++ b/llama/generate.py
@@ -0,0 +1,322 @@
+"""
+This module provides an interface for generating and sampling token sequences
+from a language model. It allows for the controlled generation of text by
+specifying parameters such as temperature, top-k, and top-p, which influence
+the randomness and quality of the generated sequences.
+"""
+
+# pylint: disable=locally-disabled, R0913, R0914
+
+from collections import OrderedDict
+from typing import List, Optional, Generator, cast
+
+import torch
+
+from .llm import LLM
+
+TokenList = OrderedDict[int, float]
+
+def _sample_internal(llm: LLM, context: torch.Tensor) -> torch.Tensor:
+    """
+    Sample a tensor of logits from the language model (LLM) based on the input context.
+    """
+
+    batch_size, seq_len = context.shape
+
+    assert seq_len <= llm.context_length
+    assert batch_size <= llm.max_batch_size
+
+    with torch.inference_mode():
+        return llm(context)
+
+def _load_context(tokens: List[List[int]], pad_id: int,
+                  pad_to_length: Optional[int] = None) -> torch.Tensor:
+    """
+    Load a batch of token lists into a padded tensor suitable for input to a language model.
+    """
+    batch_size = len(tokens)
+
+    max_token_len = max((len(tok) for tok in tokens))
+
+    pad_to_length = max_token_len if pad_to_length is None else pad_to_length
+
+    context = torch.full(
+        (batch_size, pad_to_length), pad_id, dtype=torch.long
+    )
+
+    for dim, toks in enumerate(tokens):
+        context[dim, :len(toks)] = torch.tensor(toks, dtype=torch.long)
+
+    return context
+
+def batched_token_probabilities(llm: LLM,
+                                tokens: List[List[int]],
+                                temperature: float = 1.0) -> List[TokenList]:
+    """
+    Calculate the probability distribution over the next token for each
+    sequence in a batch.
+
+    Args:
+    - llm (LLM): An instance of the language model.
+    - tokens (List[List[int]]): A list of tokenized input sequences.
+    - temperature (float): A temperature parameter to scale the logits before
+      applying softmax. Default is 1.0.
+
+    Returns:
+    - List[TokenList]: A list of ordered dictionaries where each dictionary
+      maps token ids to their corresponding probabilities for each sequence
+      in the batch.
+    """
+
+    context = _load_context(tokens, llm.padding_idx)
+
+    token_logprobs = _sample_internal(llm, context)
+
+    token_probs = torch.softmax(token_logprobs / temperature, dim=-1)
+
+    samples: List[TokenList] = [OrderedDict() for _ in range(len(tokens))]
+    for i, p in enumerate(token_probs):
+        for _id in torch.argsort(p, descending=True):
+            samples[i][int(_id)] = float(p[_id])
+
+    return samples
+
+def token_probabilities(llm: LLM, tokens: List[int]) -> TokenList:
+    """
+    Calculate the probability distribution over the next token for a single
+    sequence. See batched_token_probabilities.
+    """
+
+    return batched_token_probabilities(llm, [ tokens ])[0]
+
+def sample_batched_token(
+        token_logprobs: torch.Tensor,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: float = 1.0,
+        sample_eps: float = 1e-6) -> torch.Tensor:
+    """
+    Sample a token from a batch of token logits with optional top-k and top-p
+    filtering.
+
+    Args:
+    - token_logprobs (torch.Tensor): A tensor of token log probabilities.
+    - temperature (float): A scaling factor for logits before sampling.
+      Default is 1.0.
+    - top_k (Optional[int]): If set, the sampling is restricted to the top k
+      tokens. Default is None (no restriction).
+    - top_p (float): If set, the sampling is restricted to the smallest set
+      of tokens with cumulative probability exceeding top_p.
+      Default is 1.0 (no restriction).
+    - sample_eps (float): An epsilon value to avoid precision errors during
+      the cumulative probability calculation. Default is 1e-6.
+
+    Returns:
+    - torch.Tensor: A tensor of sampled token ids for each item in the batch.
+
+    Implements both top-k and top-p (nucleus) sampling.
+
+    See:
+    - https://arxiv.org/pdf/1805.04833.pdf for top-k sampling
+    - https://arxiv.org/pdf/1904.09751.pdf for top-p sampling
+    """
+
+    batch_size = token_logprobs.shape[0]
+
+    token_probs = torch.softmax(token_logprobs / temperature, dim=-1)
+
+    selected_tokens = torch.zeros(batch_size, dtype=torch.long)
+
+    sorted_tokens = torch.argsort(token_probs, descending=True)
+    sorted_probs = torch.gather(token_probs, 1, sorted_tokens)
+    nucleus_mask = sorted_probs.cumsum(dim=-1) < top_p + sample_eps
+    nucleus_mask[:, 0] = True
+
+    for i, (tokens, mask, probs) in enumerate(zip(sorted_tokens, nucleus_mask, sorted_probs)):
+        nucleus = tokens[mask]
+        p = probs[mask]
+
+        if top_k is not None and len(nucleus) > top_k:
+            nucleus = nucleus[:top_k]
+            p = p[:top_k]
+
+        p /= p.sum(axis=0)
+
+        token = nucleus[torch.multinomial(p, 1)]
+
+        selected_tokens[i] = token
+
+    return selected_tokens
+
+def generate_batched_token_sequence(llm: LLM,
+                                    prompts: List[List[int]],
+                                    max_generation_length: Optional[int] = None,
+                                    temperature: float = 1.0,
+                                    top_k: Optional[int] = None,
+                                    top_p: float = 1.0) -> Generator[List[Optional[int]], None, None]:
+    """
+    Generate a sequence of tokens for each prompt in a batch, yielding one
+    batch of sampled tokens per inference step.
+
+    Args:
+    - llm (LLM): An instance of the language model.
+    - prompts (List[List[int]]): A list of tokenized input sequences (prompts).
+    - max_generation_length (Optional[int]): The maximum number of tokens to
+      generate for each prompt. If None, generate up to the model's maximum
+      context length. Default is None.
+    - temperature (float): A scaling factor for logits before sampling.
+      Default is 1.0.
+    - top_k (Optional[int]): If set, restricts sampling to the top k most
+      likely tokens. Default is None (no restriction).
+    - top_p (float): If set, restricts sampling to the smallest set of tokens
+      with cumulative probability exceeding top_p. Default is 1.0
+      (no restriction).
+
+    Yields:
+    - Generator[List[Optional[int]], None, None]: A generator that yields lists
+      of token ids, where each list corresponds to one prompt in the batch.
+      An entry is None if a token was not generated for that prompt during an
+      iteration of inference.
+
+    Raises:
+    - AssertionError: If the batch size exceeds the maximum allowed by the LLM,
+      or if the requested generation length is too long.
+    """
+
+    batch_size = len(prompts)
+    assert batch_size <= llm.max_batch_size, \
+        "Batch size exceeded the maximum batch size of the LLM"
+
+    prompt_lens = torch.tensor([len(p) for p in prompts], dtype=torch.long)
+    max_prompt_len = max(prompt_lens)
+
+    remaining_context = llm.context_length - max_prompt_len
+    if max_generation_length is None:
+        max_generation_length = remaining_context
+    else:
+        assert max_generation_length <= remaining_context, \
+            "Cannot generate more tokens than exist in the context"
+
+    eos = torch.zeros(batch_size, dtype=torch.long)
+    last_pos = 0
+
+    end_pos = max_prompt_len + max_generation_length
+    context = _load_context(prompts, llm.padding_idx, pad_to_length=end_pos)
+
+    start_pos = max(prompt_lens)
+
+    for pos in range(start_pos, end_pos):
+        log_probs = llm(context[:, last_pos:pos], last_pos)
+
+        sampled = sample_batched_token(
+            log_probs,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p
+        )
+
+        # Only emit tokens for prompts that are past their prompt text and
+        # have not yet produced an end-of-sequence token.
+        in_prompt = pos < prompt_lens
+        should_replace_mask = (eos == 0) & (~in_prompt)
+
+        yield [int(sampled[i]) if should_replace_mask[i] else None for i in range(batch_size)]
+
+        context[should_replace_mask, pos] = sampled[should_replace_mask]
+        # Record the position at which each still-active sequence emits EOS.
+        eos[(eos == 0) & (sampled == llm.eos_token)] = pos + 1
+        last_pos = pos
+
+        if (eos > 0).all():
+            break
+
+def generate_token_sequence(llm: LLM,
+                            prompt: List[int],
+                            max_generation_length: Optional[int] = None,
+                            temperature: float = 1.0,
+                            top_k: Optional[int] = None,
+                            top_p: float = 1.0) -> Generator[int, None, None]:
+    """
+    Generate a sequence of tokens for a single prompt.
+    See generate_batched_token_sequence.
+    """
+
+    for tokens in generate_batched_token_sequence(llm,
+                                                  [ prompt ],
+                                                  max_generation_length=max_generation_length,
+                                                  temperature=temperature,
+                                                  top_k=top_k,
+                                                  top_p=top_p):
+        yield cast(int, tokens[0])
+
+def sample_batched_sequence(llm: LLM,
+                            prompts: List[List[int]],
+                            max_generation_length: Optional[int] = None,
+                            temperature: float = 1.0,
+                            top_k: Optional[int] = None,
+                            top_p: float = 1.0,
+                            include_prompt: bool = False) -> List[List[int]]:
+    """
+    Generate and sample a sequence of tokens for each input sequence in a batch.
+
+    Args:
+    - llm (LLM): An instance of the language model.
+    - prompts (List[List[int]]): A list of tokenized input sequences (prompts).
+    - max_generation_length (Optional[int]): The maximum number of tokens to
+      generate for each prompt. Defaults to None, which allows generation up
+      to the model's maximum context length.
+    - temperature (float): A scaling factor for logits before sampling,
+      affecting the randomness of the output. Default is 1.0, with lower
+      values leading to less random samples.
+    - top_k (Optional[int]): Limits the sampling pool to the top k tokens
+      according to the probability distribution. Default is None, indicating
+      no limit.
+    - top_p (float): The cumulative probability threshold for nucleus sampling;
+      restricts sampling to the smallest set of high-probability tokens whose
+      cumulative probability exceeds this threshold. Default is 1.0,
+      indicating no limit.
+    - include_prompt (bool): If True, includes the input prompt at the
+      beginning of the generated sequence. Default is False.
+
+    Returns:
+    - List[List[int]]: A list of lists containing the sampled token sequences
+      for each input prompt. The sequences include the generated tokens and,
+      if include_prompt is True, the original input prompt tokens.
+    """
+
+    sampled_seqs: List[List[int]] = [[] for _ in range(len(prompts))]
+
+    if include_prompt:
+        for i, prompt in enumerate(prompts):
+            sampled_seqs[i].extend(prompt)
+
+    for generated_tokens in generate_batched_token_sequence(llm,
+                                                            prompts,
+                                                            max_generation_length,
+                                                            temperature,
+                                                            top_k,
+                                                            top_p):
+        for i, token in enumerate(generated_tokens):
+            if token is not None:
+                sampled_seqs[i].append(token)
+
+    return sampled_seqs
+
+def sample_sequence(llm: LLM,
+                    prompts: List[int],
+                    max_generation_length: Optional[int] = None,
+                    temperature: float = 1.0,
+                    top_k: Optional[int] = None,
+                    top_p: float = 1.0,
+                    include_prompt: bool = False) -> List[int]:
+    """
+    Generate and sample a sequence of tokens for a single input sequence.
+    See sample_batched_sequence for reference.
+    """
+
+    return sample_batched_sequence(
+        llm, [prompts], max_generation_length, temperature,
+        top_k, top_p, include_prompt
+    )[0]
+
+__all__ = [
+    'token_probabilities',
+    'sample_batched_token',
+    'sample_sequence',
+    'sample_batched_sequence',
+    'generate_token_sequence',
+    'generate_batched_token_sequence',
+]