diff options
Diffstat (limited to 'llama/llm.py')
-rw-r--r-- | llama/llm.py | 74 |
1 files changed, 74 insertions, 0 deletions
diff --git a/llama/llm.py b/llama/llm.py new file mode 100644 index 0000000..7192b31 --- /dev/null +++ b/llama/llm.py @@ -0,0 +1,74 @@ +""" +LLM provides a generalized interface for autoregressive +next-word prediction models. The class can be utilized for tasks such as text +sampling and probability prediction over a vocabulary. +""" + +import torch +from torch import nn + +class LLM(nn.Module): + """ + LLM provides a generalized interface for autoregressive + next-word prediction models. The class can be utilized for tasks such as text + sampling and probability prediction over a vocabulary. + + Attributes: + context_length (int): Length of the context window for the + autoregressive model. Default is -1, which + indicates that this needs to be set. + + max_batch_size (int): The maximum size of a batch that can be processed. + Default is -1, which indicates that this needs to + be set. + + vocab_size (int): The size of the vocabulary used in the model. + Default is -1, which indicates that this needs to + be set. + + padding_idx (int): The index used for padding in mixed-length batches. + Default is -1, which indicates that this needs to be + set. + + eos_token (int): Token index that signifies the end of a sequence during + auto-regressive generation. Default is -1, which + indicates that this needs to be set. + """ + + context_length = -1 + max_batch_size = -1 + vocab_size = -1 + padding_idx = -1 + eos_token = -1 + + def forward(self, context: torch.Tensor, cur_pos: int = 0) -> torch.Tensor: + """ + Computes the log probabilities of the next token given a sequence of + tokens as context. + + Args: + context (torch.Tensor): A tensor of shape (batch_size, context_length) + containing token ids. These tokens serve as the + context for predicting the next token. + + cur_pos (int, optional): The position at which to start the + prediction. If cur_pos is not zero, + the internal cache (if available) will + be used to speed up predictions. + Defaults to 0. + + Returns: + torch.Tensor: A tensor of shape (batch_size, vocab_size) containing + the log probabilities of the next token given the + context. + + Examples: + # Predict the next token for a sequence [1, 2, 3] + log_probs = llm(torch.tensor([[1, 2, 3]], dtype=torch.long), 0) + + # Predict the next token for a sequence [1, 2, 3, 4, 5] using the + # cache starting at position 3 + log_probs = llm(torch.tensor([[4, 5]], dtype=torch.long), 3) + """ + + raise NotImplementedError() |