diff options
| author | flu0r1ne <flu0r1ne@flu0r1ne.net> | 2023-05-04 18:10:32 -0500 | 
|---|---|---|
| committer | flu0r1ne <flu0r1ne@flu0r1ne.net> | 2023-05-04 18:10:32 -0500 | 
| commit | 44b15a775e3f5370c81507534df103f63611b956 (patch) | |
| tree | 6cefb2252ce10f922e21438d1cc9747c0ce67ec2 /argparsing.py | |
| download | gpt-chat-cli-44b15a775e3f5370c81507534df103f63611b956.tar.xz gpt-chat-cli-44b15a775e3f5370c81507534df103f63611b956.zip | |
Add version 0.0.1 which lacks interactive features but streams colorized output
Diffstat (limited to 'argparsing.py')
| -rw-r--r-- | argparsing.py | 193 | 
1 files changed, 193 insertions, 0 deletions
| diff --git a/argparsing.py b/argparsing.py new file mode 100644 index 0000000..a7d3218 --- /dev/null +++ b/argparsing.py @@ -0,0 +1,193 @@ +import argparse +import os +import logging +import openai +import sys +from enum import Enum + +class AutoDetectedOption(Enum): +    ON = 'on' +    OFF = 'off' +    AUTO = 'auto' + +def die_validation_err(err : str): +    print(err, file=sys.stderr) +    sys.exit(1) + +def validate_args(args: argparse.Namespace) -> None: +    if not 0 <= args.temperature <= 2: +        die_validation_err("Temperature must be between 0 and 2.") + +    if not -2 <= args.frequency_penalty <= 2: +        die_validation_err("Frequency penalty must be between -2.0 and 2.0.") + +    if not -2 <= args.presence_penalty <= 2: +        die_validation_err("Presence penalty must be between -2.0 and 2.0.") + +    if args.max_tokens < 1: +        die_validation_err("Max tokens must be greater than or equal to 1.") + +    if not 0 <= args.top_p <= 1: +        die_validation_err("Top_p must be between 0 and 1.") + +    if args.n_completions < 1: +        die_validation_err("Number of completions must be greater than or equal to 1.") + +def parse_args(): + +    GCLI_ENV_PREFIX = "GCLI_" + +    debug = os.getenv(f'{GCLI_ENV_PREFIX}DEBUG') is not None + +    if debug: +        logging.warning("Debugging mode and unstable features have been enabled.") + +    parser = argparse.ArgumentParser() + +    parser.add_argument( +        "-m", +        "--model", +        default=os.getenv(f'{GCLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"), +        help="ID of the model to use", +    ) + +    parser.add_argument( +        "-t", +        "--temperature", +        type=float, +        default=os.getenv(f'{GCLI_ENV_PREFIX}TEMPERATURE', 0.5), +        help=( +            "What sampling temperature to use, between 0 and 2. Higher values " +            "like 0.8 will make the output more random, while lower values " +            "like 0.2 will make it more focused and deterministic." +        ), +    ) + +    parser.add_argument( +        "-f", +        "--frequency-penalty", +        type=float, +        default=os.getenv(f'{GCLI_ENV_PREFIX}FREQUENCY_PENALTY', 0), +        help=( +            "Number between -2.0 and 2.0. Positive values penalize new tokens based " +            "on their existing frequency in the text so far, decreasing the model's " +            "likelihood to repeat the same line verbatim." +        ), +    ) + +    parser.add_argument( +        "-p", +        "--presence-penalty", +        type=float, +        default=os.getenv(f'{GCLI_ENV_PREFIX}PRESENCE_PENALTY', 0), +        help=( +            "Number between -2.0 and 2.0. Positive values penalize new tokens based " +            "on whether they appear in the text so far, increasing the model's " +            "likelihood to talk about new topics." +        ), +    ) + +    parser.add_argument( +        "-k", +        "--max-tokens", +        type=int, +        default=os.getenv(f'{GCLI_ENV_PREFIX}MAX_TOKENS', 2048), +        help=( +            "The maximum number of tokens to generate in the chat completion. " +            "Defaults to 2048." +        ), +    ) + +    parser.add_argument( +        "-s", +        "--top-p", +        type=float, +        default=os.getenv(f'{GCLI_ENV_PREFIX}TOP_P', 1), +        help=( +            "An alternative to sampling with temperature, called nucleus sampling, " +            "where the model considers the results of the tokens with top_p " +            "probability mass. So 0.1 means only the tokens comprising the top 10%% " +            "probability mass are considered." +        ), +    ) + +    parser.add_argument( +        "-n", +        "--n-completions", +        type=int, +        default=os.getenv('f{GCLI_ENV_PREFIX}N_COMPLETIONS', 1), +        help="How many chat completion choices to generate for each input message.", +    ) + +    parser.add_argument( +        "--adornments", +        type=AutoDetectedOption, +        choices=list(AutoDetectedOption), +        default=AutoDetectedOption.AUTO, +        help=( +            "Show adornments to indicate the model and response." +            " Can be set to 'on', 'off', or 'auto'." +        ) +    ) + +    parser.add_argument( +        "--color", +        type=AutoDetectedOption, +        choices=list(AutoDetectedOption), +        default=AutoDetectedOption.AUTO, +        help="Set color to 'on', 'off', or 'auto'.", +    ) + +    parser.add_argument( +        "message", +        type=str, +        help=( +            "The contents of the message. When used in chat mode, this is the initial " +            "message if provided." +        ), +    ) + +    if debug: +        group = parser.add_mutually_exclusive_group() + +        group.add_argument( +            '--save-response-to-file', +            type=str, +            help="UNSTABLE: save the response to a file. This can reply a response for debugging purposes", +        ) + +        group.add_argument( +            '--load-response-from-file', +            type=str, +            help="UNSTABLE: load a response from a file. This can reply a response for debugging purposes", +        ) + +    openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY")) +    if not openai_key: +        print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr) +        print("The OpenAI API uses API keys for authentication. Visit your (API Keys page)[https://platform.openai.com/account/api-keys] to retrieve the API key you'll use in your requests.", file=sys.stderr) +        sys.exit(1) + +    openai.api_key = openai_key + +    args = parser.parse_args() + +    if debug and args.load_response_from_file: +        logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {args.load_response_from_file} was generated') + +    if args.color == AutoDetectedOption.AUTO: +        if os.getenv("NO_COLOR"): +            args.color = AutoDetectedOption.OFF +        else: +            args.color = AutoDetectedOption.ON + +    if args.adornments == AutoDetectedOption.AUTO: +        args.adornments = AutoDetectedOption.ON + +    if not debug: +        args.load_response_from_file = None +        args.save_response_to_file = None + +    validate_args(args) + +    return args | 
