author     flu0r1ne <flu0r1ne@flu0r1ne.net>   2023-11-01 20:46:01 -0500
committer  flu0r1ne <flu0r1ne@flu0r1ne.net>   2023-11-01 20:46:01 -0500
commit     af5a2996234768921b81d96ffaae00cb88229862 (patch)
tree       5b2a688582652fc8080616ccc0de162198aa8ee0 /llama/generate.py
Initial commit (HEAD, main)
Diffstat (limited to 'llama/generate.py')
-rw-r--r--  llama/generate.py  322
1 file changed, 322 insertions(+), 0 deletions(-)
diff --git a/llama/generate.py b/llama/generate.py
new file mode 100644
index 0000000..e735ca4
--- /dev/null
+++ b/llama/generate.py
@@ -0,0 +1,322 @@
+"""
+This module provides an interface for generating and sampling token sequences from a language model.
+It allows for the controlled generation of text by specifying parameters such as temperature, top-k,
+and top-p, which influence the randomness and quality of the generated sequences.
+"""
+
+# pylint: disable=locally-disabled, R0913, R0914
+
+from collections import OrderedDict
+from typing import List, Optional, Generator, cast
+
+import torch
+
+from .llm import LLM
+
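+# TokenList maps candidate token ids to their probabilities, ordered from most
+# to least likely.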
+TokenList = OrderedDict[int, float]
+
+def _sample_internal(llm: LLM, context: torch.Tensor) -> torch.Tensor:
+ """
+ Run the language model (LLM) on the input context and return the next-token logits.
+ """
+
+ batch_size, seq_len = context.shape
+
+ assert seq_len <= llm.context_length
+ assert batch_size <= llm.max_batch_size
+
+ with torch.inference_mode():
+ return llm(context)
+
+def _load_context(tokens: List[List[int]], pad_id: int,
+ pad_to_length: Optional[int] = None) -> torch.Tensor:
+ """
+ Load a batch of token lists into a padded tensor suitable for input to a language model.
+ """
+ batch_size = len(tokens)
+
+ max_token_len = max((len(tok) for tok in tokens))
+
+ pad_to_length = max_token_len if pad_to_length is None else pad_to_length
+
+ context = torch.full(
+ (batch_size, pad_to_length), pad_id, dtype=torch.long
+ )
+
+ for dim, toks in enumerate(tokens):
+ context[dim, :len(toks)] = torch.tensor(toks, dtype=torch.long)
+
+ return context
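+
+# Illustrative sketch of the padding behaviour (the literal ids below are
+# arbitrary placeholders, not from any real tokenizer):
+#
+#     _load_context([[5, 6, 7], [8]], pad_id=0)
+#     # -> tensor([[5, 6, 7],
+#     #            [8, 0, 0]])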
+
+def batched_token_probabilities(llm: LLM,
+ tokens: List[List[int]],
+ temperature: float = 1.0) -> List[TokenList]:
+ """
+ Calculate the probability distribution over the next token for each sequence in a batch.
+
+ Args:
+ - llm (LLM): An instance of the language model.
+ - tokens (List[List[int]]): A list of tokenized input sequences.
+ - temperature (float): A temperature parameter to scale the logits before
+ applying softmax. Default is 1.0.
+
+ Returns:
+ - List[TokenList]: A list of ordered dictionaries where each dictionary maps
+ token ids to their corresponding probabilities for each
+ sequence in the batch.
+ """
+
+ context = _load_context(tokens, llm.padding_idx)
+
+ token_logprobs = _sample_internal(llm, context)
+
+ token_probs = torch.softmax(token_logprobs / temperature, dim=-1)
+
+ samples: List[TokenList] = [OrderedDict() for _ in range(len(tokens))]
+ for i, p in enumerate(token_probs):
+ for _id in torch.argsort(p, descending=True):
+ samples[i][int(_id)] = float(p[_id])
+
+ return samples
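+
+# Illustrative sketch of intended usage; `llm` is assumed to be an
+# already-constructed LLM instance and the token ids are assumed to come from
+# a matching tokenizer (neither is provided by this module):
+#
+#     probs = batched_token_probabilities(llm, [[1, 306, 4966], [1, 450]])
+#     next_id, next_p = next(iter(probs[0].items()))  # most likely next token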
+
+def token_probabilities(llm: LLM, tokens: List[int]) -> TokenList:
+ """
+ Calculate the probability distribution over the next token for a single sequence.
+ See batched_token_probabilities.
+ """
+
+ return batched_token_probabilities(llm, [ tokens ])[0]
+
+def sample_batched_token(
+ token_logprobs: torch.Tensor,
+ temperature: float = 1.0,
+ top_k: Optional[int] = None,
+ top_p: float = 1.0,
+ sample_eps: float = 1e-6) -> torch.Tensor:
+ """
+ Sample a token from a batch of token logits with optional top-k and top-p filtering.
+
+ Args:
+ - token_logprobs (torch.Tensor): A tensor of token log probabilities.
+ - temperature (float): A scaling factor for logits before sampling. Default
+ is 1.0.
+ - top_k (Optional[int]): If set, the sampling is restricted to the top k
+ tokens. Default is None (no restriction).
+ - top_p (float): If set, the sampling is restricted to the smallest set
+ of tokens with cumulative probability exceeding top_p.
+ Default is 1.0 (no restriction).
+ - sample_eps (float): An epsilon value to avoid precision errors during
+ cumulative probability calculation. Default is 1e-6.
+
+ Returns:
+ - torch.Tensor: A tensor of sampled token ids for each item in the batch.
+
+ Implements both top-k and top-p (nucleus) sampling.
+
+ See:
+ - https://arxiv.org/pdf/1805.04833.pdf for top-k sampling
+ - https://arxiv.org/pdf/1904.09751.pdf for top-p sampling
+ """
+
+ batch_size = token_logprobs.shape[0]
+
+ token_probs = torch.softmax(token_logprobs / temperature, dim=-1)
+
+ selected_tokens = torch.zeros(batch_size, dtype=torch.long)
+
+ sorted_tokens = torch.argsort(token_probs, descending=True)
+ sorted_probs = torch.gather(token_probs, 1, sorted_tokens)
+ nucleus_mask = sorted_probs.cumsum(dim=-1) < top_p + sample_eps
+ nucleus_mask[:,0] = True
+
+ for i, (tokens, mask, probs) in enumerate(zip(sorted_tokens, nucleus_mask, sorted_probs)):
+ nucleus = tokens[mask]
+ p = probs[mask]
+
+ if top_k is not None and len(nucleus) > top_k:
+ nucleus = nucleus[:top_k]
+ p = p[:top_k]
+
+ p /= p.sum(dim=0)
+
+ token = nucleus[torch.multinomial(p, 1)]
+
+ selected_tokens[i] = token
+
+ return selected_tokens
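+
+# Illustrative sketch using random logits in place of real model output (the
+# vocabulary size below is arbitrary):
+#
+#     logits = torch.randn(2, 32000)               # (batch_size, vocab_size)
+#     ids = sample_batched_token(logits, temperature=0.8, top_k=40, top_p=0.95)
+#     # `ids` is a LongTensor of shape (2,): one sampled token id per row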
+
+def generate_batched_token_sequence(llm: LLM,
+ prompts: List[List[int]],
+ max_generation_length: Optional[int] = None,
+ temperature: float = 1.0,
+ top_k: Optional[int] = None,
+ top_p: float = 1.0) -> Generator[List[Optional[int]], None, None]:
+ """
+ Generate a sequence of tokens for each prompt in a batch, yielding one batch of sampled tokens per decoding step.
+
+ Args:
+ - llm (LLM): An instance of the language model.
+ - prompts (List[List[int]]): A list of tokenized input sequences (prompts).
+ - max_generation_length (Optional[int]): The maximum number of tokens to
+ generate for each prompt. If None, generate up to the model's maximum
+ context length. Default is None.
+ - temperature (float): A scaling factor for logits before sampling. Default
+ is 1.0.
+ - top_k (Optional[int]): If set, restricts sampling to the top k most
+ likely tokens. Default is None (no restriction).
+ - top_p (float): If set, restricts sampling to the smallest set of tokens
+ whose cumulative probability exceeds top_p. Default is 1.0
+ (no restriction).
+
+ Yields:
+ - Generator[List[Optional[int]], None, None]: A generator that yields lists
+ of token ids, where each list corresponds to one prompt in the batch.
+ An entry is None if no token was generated for that prompt during that
+ iteration of inference.
+
+ Raises:
+ - AssertionError: If batch size exceeds the maximum allowed by the LLM, or
+ if the requested generation length is too long.
+ """
+
+ batch_size = len(prompts)
+ assert batch_size <= llm.max_batch_size, \
+ "Batch size exceeded the maximum batch size of the LLM"
+
+ prompt_lens = torch.tensor([len(p) for p in prompts], dtype=torch.long)
+ max_prompt_len = int(prompt_lens.max())
+
+ remaining_context = llm.context_length - max_prompt_len
+ if max_generation_length is None:
+ max_generation_length = remaining_context
+ else:
+ assert max_generation_length <= remaining_context, \
+ "Cannot generate more tokens than exist in the context"
+
+ eos = torch.zeros(batch_size, dtype=torch.long)
+ last_pos = 0
+
+ end_pos = max_prompt_len + max_generation_length
+ context = _load_context(prompts, llm.padding_idx, pad_to_length=end_pos)
+
+ start_pos = max_prompt_len
+
+ for pos in range(start_pos, end_pos):
+ log_probs = llm(context[:, last_pos:pos], last_pos)
+
+ sampled = sample_batched_token(
+ log_probs,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p
+ )
+
+ in_prompt = pos < prompt_lens
+ should_replace_mask = (eos == 0) & (~in_prompt)
+
+ yield [int(sampled[i]) if should_replace_mask[i] else None for i in range(batch_size)]
+
+ context[should_replace_mask, pos] = sampled[should_replace_mask]
+ eos[should_replace_mask & (sampled == llm.eos_token)] = pos + 1  # record where EOS was emitted
+ last_pos = pos
+
+ if (eos > 0).all():
+ break
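+
+# Illustrative sketch of streaming batched generation; `llm` and the prompt
+# token ids are assumed to come from elsewhere:
+#
+#     for step in generate_batched_token_sequence(llm, [[1, 15043], [1, 450]],
+#                                                  max_generation_length=32,
+#                                                  temperature=0.8, top_p=0.9):
+#         # `step` holds one entry per prompt; None means no token was
+#         # produced for that prompt on this iteration
+#         print(step)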
+
+def generate_token_sequence(llm: LLM,
+ prompt: List[int],
+ max_generation_length: Optional[int] = None,
+ temperature: float = 1.0,
+ top_k: Optional[int] = None,
+ top_p: float = 1.0) -> Generator[int, None, None]:
+ """
+ Generate a sequence of tokens for a single prompt.
+ See generate_batched_token_sequence.
+ """
+
+ for tokens in generate_batched_token_sequence(llm,
+ [ prompt ],
+ max_generation_length=max_generation_length,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p):
+ yield cast(int, tokens[0])
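+
+# Illustrative sketch of token-by-token streaming for a single prompt; `llm`
+# and `tokenizer` (with encode/decode methods) are assumptions, not part of
+# this module:
+#
+#     for token in generate_token_sequence(llm, tokenizer.encode("Hello"),
+#                                          max_generation_length=16):
+#         print(tokenizer.decode([token]), end="", flush=True)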
+
+def sample_batched_sequence(llm: LLM,
+ prompts: List[List[int]],
+ max_generation_length: Optional[int] = None,
+ temperature: float = 1.0,
+ top_k: Optional[int] = None,
+ top_p: float = 1.0,
+ include_prompt: bool = False) -> List[List[int]]:
+ """
+ Generate and sample a sequence of tokens for each input sequence in a batch.
+
+ Args:
+ - llm (LLM): An instance of the language model.
+ - prompts (List[List[int]]): A list of tokenized input sequences (prompts).
+ - max_generation_length (Optional[int]): The maximum number of tokens to
+ generate for each prompt. Defaults to None, which allows generation up
+ to the model's remaining context length.
+ - temperature (float): A scaling factor for logits before sampling,
+ affecting the randomness of the output. Default is 1.0, with lower values
+ leading to less random samples.
+ - top_k (Optional[int]): Limits the sampling pool to the top k tokens
+ according to the probability distribution. Default is None, indicating no
+ limit.
+ - top_p (float): The cumulative probability threshold for nucleus sampling;
+ allows sampling from a set of high-probability tokens whose cumulative
+ probability exceeds this threshold. Default is 1.0, indicating no limit.
+ - include_prompt (bool): If True, includes the input prompt at the beginning
+ of the generated sequence. Default is False.
+
+ Returns:
+ - List[List[int]]: A list of lists containing the sampled token sequences
+ for each input prompt. The sequences include the generated tokens and,
+ if include_prompt is True, the original input prompt tokens.
+ """
+
+ sampled_seqs: List[List[int]] = [[] for _ in range(len(prompts))]
+
+ if include_prompt:
+ for i, prompt in enumerate(prompts):
+ sampled_seqs[i].extend(prompt)
+
+ for generated_tokens in generate_batched_token_sequence(llm,
+ prompts,
+ max_generation_length,
+ temperature,
+ top_k,
+ top_p):
+ for i, token in enumerate(generated_tokens):
+ if token is not None:
+ sampled_seqs[i].append(token)
+
+ return sampled_seqs
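+
+# Illustrative sketch; `llm` is assumed to be an already-loaded LLM instance
+# and the prompt ids are placeholders for real tokenizer output:
+#
+#     completions = sample_batched_sequence(llm, [[1, 15043], [1, 450, 4799]],
+#                                           max_generation_length=64,
+#                                           temperature=0.7,
+#                                           include_prompt=True)
+#     # completions[i] holds the prompt (if requested) plus the generated ids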
+
+def sample_sequence(llm: LLM,
+ prompt: List[int],
+ max_generation_length: Optional[int] = None,
+ temperature: float = 1.0,
+ top_k: Optional[int] = None,
+ top_p: float = 1.0,
+ include_prompt: bool = False) -> List[int]:
+ """
+ Generate and sample a sequence of tokens for a single input sequence.
+ See sample_batched_sequence for reference.
+ """
+
+ return sample_batched_sequence(
+ llm, [prompt], max_generation_length, temperature,
+ top_k, top_p, include_prompt
+ )[0]
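+
+# Illustrative sketch of end-to-end sampling for one prompt; `llm` and
+# `tokenizer` are assumed to be provided by the caller:
+#
+#     ids = sample_sequence(llm, tokenizer.encode("Once upon a time"),
+#                           max_generation_length=32, top_k=50)
+#     print(tokenizer.decode(ids))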
+
+__all__ = [
+ 'token_probabilities',
+ 'sample_batched_token',
+ 'sample_sequence',
+ 'sample_batched_sequence',
+ 'generate_token_sequence',
+ 'generate_batched_token_sequence',
+]