"""Generate text from a prompt using a Llama model checkpoint.

Loads a model and tokenizer from paths given on the command line, encodes
the --context string as a prompt, and streams generated pieces to stdout
as they are produced.
"""

import argparse

from llama import (
    load_llama_from_checkpoint,
    generate_token_sequence,
)

def main(args: argparse.Namespace) -> None:
    """Load the model, then generate and print a token sequence.

    Args:
        args: Parsed CLI arguments; reads ``model_directory``, ``tokenizer``,
            ``context``, ``top_p``, and ``max_generation_length``.
    """
    llama, tokenizer = load_llama_from_checkpoint(
        args.model_directory,
        args.tokenizer
    )

    context: str = args.context

    # Positional flags to encode(): add BOS = True, add EOS = False, so the
    # model continues the prompt rather than treating it as a finished text.
    prompt = tokenizer.encode(context, True, False)

    print(f'Prompt: {context}')
    # Plain string here (was an f-string with no placeholders — ruff F541).
    print('Generated: ', end='')

    # Stream each piece as soon as it is sampled; flush so the user sees
    # incremental output instead of one burst at the end.
    for token in generate_token_sequence(llama,
                                         prompt,
                                         top_p = args.top_p,
                                         max_generation_length = args.max_generation_length):
        piece = tokenizer.id_to_piece(token)
        print(piece, end='', flush=True)
    print()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Generate text using a Llama model.')

    parser.add_argument('model_directory', type=str,
                        help='Path to the directory containing the Llama model.')
    parser.add_argument('tokenizer', type=str,
                        help='Path to the tokenizer model file.')
    parser.add_argument('--context', type=str, default='Hello, world!',
                        help='Initial context to seed the Llama model.')
    parser.add_argument('--max_generation_length', type=int, default=None,
                        help='Maximum length of the generated sequence.')
    parser.add_argument('--top_p', type=float, default=0.80,
                        help='The cumulative distribution function (CDF) to use for sampling.')

    # Optional shell tab-completion: activate only if argcomplete is
    # installed; silently proceed without it otherwise (best-effort).
    try:
        import argcomplete
        argcomplete.autocomplete(parser)
    except ImportError:
        pass

    args = parser.parse_args()

    main(args)