aboutsummaryrefslogtreecommitdiff
path: root/llama/llm.py
diff options
context:
space:
mode:
Diffstat (limited to 'llama/llm.py')
-rw-r--r--llama/llm.py74
1 files changed, 74 insertions, 0 deletions
diff --git a/llama/llm.py b/llama/llm.py
new file mode 100644
index 0000000..7192b31
--- /dev/null
+++ b/llama/llm.py
@@ -0,0 +1,74 @@
+"""
+LLM provides a generalized interface for autoregressive
+next-word prediction models. The class can be utilized for tasks such as text
+sampling and probability prediction over a vocabulary.
+"""
+
+import torch
+from torch import nn
+
+class LLM(nn.Module):
+ """
+ LLM provides a generalized interface for autoregressive
+ next-word prediction models. The class can be utilized for tasks such as text
+ sampling and probability prediction over a vocabulary.
+
+ Attributes:
+ context_length (int): Length of the context window for the
+ autoregressive model. Default is -1, which
+ indicates that this needs to be set.
+
+ max_batch_size (int): The maximum size of a batch that can be processed.
+ Default is -1, which indicates that this needs to
+ be set.
+
+ vocab_size (int): The size of the vocabulary used in the model.
+ Default is -1, which indicates that this needs to
+ be set.
+
+ padding_idx (int): The index used for padding in mixed-length batches.
+ Default is -1, which indicates that this needs to be
+ set.
+
+ eos_token (int): Token index that signifies the end of a sequence during
+ auto-regressive generation. Default is -1, which
+ indicates that this needs to be set.
+ """
+
+ context_length = -1
+ max_batch_size = -1
+ vocab_size = -1
+ padding_idx = -1
+ eos_token = -1
+
+ def forward(self, context: torch.Tensor, cur_pos: int = 0) -> torch.Tensor:
+ """
+ Computes the log probabilities of the next token given a sequence of
+ tokens as context.
+
+ Args:
+ context (torch.Tensor): A tensor of shape (batch_size, context_length)
+ containing token ids. These tokens serve as the
+ context for predicting the next token.
+
+ cur_pos (int, optional): The position at which to start the
+ prediction. If cur_pos is not zero,
+ the internal cache (if available) will
+ be used to speed up predictions.
+ Defaults to 0.
+
+ Returns:
+ torch.Tensor: A tensor of shape (batch_size, vocab_size) containing
+ the log probabilities of the next token given the
+ context.
+
+ Examples:
+ # Predict the next token for a sequence [1, 2, 3]
+ log_probs = llm(torch.tensor([[1, 2, 3]], dtype=torch.long), 0)
+
+ # Predict the next token for a sequence [1, 2, 3, 4, 5] using the
+ # cache starting at position 3
+ log_probs = llm(torch.tensor([[4, 5]], dtype=torch.long), 3)
+ """
+
+ raise NotImplementedError()