aboutsummaryrefslogtreecommitdiff
path: root/inference_example.py
blob: 403ceb993e68e14644f3878cdf1756b0bf48dd3a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse

from llama import (
    load_llama_from_checkpoint,
    generate_token_sequence,
)

def main(args: argparse.Namespace) -> None:
    """Load a Llama checkpoint and stream a completion of ``args.context`` to stdout.

    Args:
        args: Parsed command-line arguments. Reads ``model_directory``,
            ``tokenizer``, ``context``, ``top_p`` and ``max_generation_length``.
    """
    llama, tokenizer = load_llama_from_checkpoint(
        args.model_directory,
        args.tokenizer,
    )

    context: str = args.context

    # NOTE(review): the two positional flags are forwarded unchanged to the
    # tokenizer's encode(); presumably (add BOS, no EOS) — confirm against the
    # tokenizer type returned by load_llama_from_checkpoint.
    prompt = tokenizer.encode(context, True, False)

    print(f'Prompt: {context}')
    print('Generated: ', end='')

    # Tokens arrive one at a time from the generator; flush after each piece so
    # the completion streams to the terminal instead of appearing all at once.
    for token in generate_token_sequence(
        llama,
        prompt,
        top_p=args.top_p,
        max_generation_length=args.max_generation_length,
    ):
        # NOTE(review): if id_to_piece returns raw SentencePiece pieces, word
        # boundaries will print as "\u2581" rather than spaces — confirm.
        piece = tokenizer.id_to_piece(token)
        print(piece, end='', flush=True)
    print()

def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Returns:
        A parser that takes the positional ``model_directory`` and
        ``tokenizer`` paths plus the optional ``--context``,
        ``--max_generation_length`` and ``--top_p`` flags.
    """
    parser = argparse.ArgumentParser(description='Generate text using a Llama model.')

    parser.add_argument('model_directory', type=str,
                        help='Path to the directory containing the Llama model.')
    parser.add_argument('tokenizer', type=str,
                        help='Path to the tokenizer model file.')
    parser.add_argument('--context', type=str, default='Hello, world!',
                        help='Initial context to seed the Llama model.')
    parser.add_argument('--max_generation_length', type=int, default=None,
                        help='Maximum length of the generated sequence.')
    parser.add_argument('--top_p', type=float, default=0.80,
                        help='Cumulative probability threshold for top-p (nucleus) sampling.')
    return parser


if __name__ == "__main__":
    parser = build_parser()

    # Shell tab-completion is optional: enable it only when argcomplete is installed.
    try:
        import argcomplete
    except ImportError:
        pass
    else:
        argcomplete.autocomplete(parser)

    args = parser.parse_args()

    main(args)