aboutsummaryrefslogtreecommitdiff
path: root/inference_example.py
diff options
context:
space:
mode:
authorflu0r1ne <flu0r1ne@flu0r1ne.net>2023-11-01 20:46:01 -0500
committerflu0r1ne <flu0r1ne@flu0r1ne.net>2023-11-01 20:46:01 -0500
commitaf5a2996234768921b81d96ffaae00cb88229862 (patch)
tree5b2a688582652fc8080616ccc0de162198aa8ee0 /inference_example.py
downloadmyllama2-main.tar.xz
myllama2-main.zip
Initial commitHEADmain
Diffstat (limited to 'inference_example.py')
-rw-r--r--inference_example.py51
1 files changed, 51 insertions, 0 deletions
diff --git a/inference_example.py b/inference_example.py
new file mode 100644
index 0000000..403ceb9
--- /dev/null
+++ b/inference_example.py
@@ -0,0 +1,51 @@
+import argparse
+
+from llama import (
+ load_llama_from_checkpoint,
+ generate_token_sequence,
+)
+
+def main(args: argparse.Namespace) -> None:
+    """Print a model-generated continuation of ``args.context`` to stdout.
+
+    Loads the model and tokenizer with ``load_llama_from_checkpoint``,
+    encodes the context string, echoes it back, and then prints the piece
+    for every token obtained by iterating ``generate_token_sequence``.
+
+    Args:
+        args: Parsed command-line options. The attributes read here are
+            ``model_directory``, ``tokenizer``, ``context``, ``top_p`` and
+            ``max_generation_length``.
+    """
+    model, tok = load_llama_from_checkpoint(args.model_directory, args.tokenizer)
+
+    seed_text: str = args.context
+
+    # NOTE(review): the two positional booleans are presumably the
+    # "prepend BOS" / "append EOS" switches -- confirm against the tokenizer.
+    prompt_ids = tok.encode(seed_text, True, False)
+
+    print(f'Prompt: {seed_text}')
+    print('Generated: ', end='')
+
+    token_stream = generate_token_sequence(
+        model,
+        prompt_ids,
+        top_p=args.top_p,
+        max_generation_length=args.max_generation_length,
+    )
+    for token_id in token_stream:
+        # Flush on every piece so the text shows up incrementally instead
+        # of waiting for the output buffer to fill.
+        print(tok.id_to_piece(token_id), end='', flush=True)
+
+    # Terminate the line that the pieces were written on.
+    print()
+
+if __name__ == "__main__":
+    # Script entry point: build the command-line parser, optionally enable
+    # shell tab-completion, then pass the parsed options to main().
+    parser = argparse.ArgumentParser(description='Generate text using a Llama model.')
+
+    # (name, add_argument keyword options), registered in this order: the
+    # two positionals first, then the optional flags.
+    argument_specs = (
+        ('model_directory', {
+            'type': str,
+            'help': 'Path to the directory containing the Llama model.',
+        }),
+        ('tokenizer', {
+            'type': str,
+            'help': 'Path to the tokenizer model file.',
+        }),
+        ('--context', {
+            'type': str,
+            'default': 'Hello, world!',
+            'help': 'Initial context to seed the Llama model.',
+        }),
+        ('--max_generation_length', {
+            'type': int,
+            'default': None,
+            'help': 'Maximum length of the generated sequence.',
+        }),
+        ('--top_p', {
+            'type': float,
+            'default': 0.80,
+            'help': 'The cumulative distribution function (CDF) to use for sampling.',
+        }),
+    )
+    for arg_name, arg_options in argument_specs:
+        parser.add_argument(arg_name, **arg_options)
+
+    # argcomplete is an optional dependency: when it is not installed the
+    # script simply runs without tab-completion support.
+    try:
+        import argcomplete
+        argcomplete.autocomplete(parser)
+    except ImportError:
+        pass
+
+    main(parser.parse_args())