| author | flu0r1ne <flu0r1ne@flu0r1ne.net> | 2023-05-09 15:57:05 -0500 |
|---|---|---|
| committer | flu0r1ne <flu0r1ne@flu0r1ne.net> | 2023-05-09 15:57:05 -0500 |
| commit | 4f684f7073149c48ea9dd962b64a8bf7a169a1cf (patch) | |
| tree | 3d642519b77bf9a6e451af62e7a50d7ee99f5a1d /src/gpt_chat_cli/argparsing.py | |
| parent | 61a9d6426cdb0e32436b704561b7a3ff47b60b04 (diff) | |
| download | gpt-chat-cli-4f684f7073149c48ea9dd962b64a8bf7a169a1cf.tar.xz gpt-chat-cli-4f684f7073149c48ea9dd962b64a8bf7a169a1cf.zip | |
Rewrote the argparsing functionality to enable autocompletion via the "kislyuk/argcomplete" package.
In practice, this means:
- `argparsing.py` does only the minimal work needed to initialize the argument parser, then attempts dynamic completion (see the sketch after this list).
- `argvalidation.py` processes the raw arguments parsed in `argparsing.py`, validates them, issues warnings (if required), and splits them into logical groupings.
- Commands in `gcli.py` have been moved to `cmd.py`.
- `main.py` provides the initial control path, calling these functions in succession.
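For context, argcomplete hooks the parser before arguments are parsed: when the script is invoked by the shell's completion machinery, `argcomplete.autocomplete(parser)` prints completion candidates and exits early; during a normal run it is a no-op. Below is a minimal, illustrative sketch of the pattern `parse_raw_args_or_complete()` follows — only the `argcomplete.autocomplete(parser)` call and the option names are taken from this commit; the rest is assumed boilerplate, not the repo's actual code.

```python
#!/usr/bin/env python3
# PYTHON_ARGCOMPLETE_OK  -- marker that argcomplete's global completion scans for
import argparse

import argcomplete


def main() -> None:
    # Build the parser fully before hooking completion, so all options
    # are visible to the completer.
    parser = argparse.ArgumentParser(prog="gpt-chat-cli")
    parser.add_argument("-m", "--model", default="gpt-3.5-turbo")

    # Under shell completion this emits candidates and exits early;
    # otherwise it returns immediately and parsing proceeds as usual.
    argcomplete.autocomplete(parser)

    args = parser.parse_args()
    print(args.model)


if __name__ == "__main__":
    main()
```

Completion is then registered in the user's shell, e.g. `eval "$(register-python-argcomplete gpt-chat-cli)"` for bash.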
Diffstat (limited to 'src/gpt_chat_cli/argparsing.py')
-rw-r--r-- | src/gpt_chat_cli/argparsing.py | 200
1 file changed, 35 insertions, 165 deletions
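One pattern worth noting before the diff: every option default is read from a `GPT_CLI_`-prefixed environment variable, so exporting e.g. `GPT_CLI_MODEL` changes the default model without passing a flag. A small standalone illustration of the same convention (the names mirror the diff below; the snippet itself is not from the repo):

```python
import argparse
import os

# The default comes from the GPT_CLI_MODEL environment variable when set,
# falling back to a hard-coded value otherwise -- as in the diff below.
parser = argparse.ArgumentParser(prog="gpt-chat-cli")
parser.add_argument(
    "-m", "--model",
    default=os.getenv("GPT_CLI_MODEL", "gpt-3.5-turbo"),
)

args = parser.parse_args([])
print(args.model)  # "gpt-4" if GPT_CLI_MODEL=gpt-4 is exported
```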
```diff
diff --git a/src/gpt_chat_cli/argparsing.py b/src/gpt_chat_cli/argparsing.py
index b026af8..a96e81c 100644
--- a/src/gpt_chat_cli/argparsing.py
+++ b/src/gpt_chat_cli/argparsing.py
@@ -1,48 +1,10 @@
 import argparse
+import argcomplete
 import os
-import logging
-import openai
-import sys
-from enum import Enum
+
 from dataclasses import dataclass
 from typing import Tuple, Optional
-
-def die_validation_err(err : str):
-    print(err, file=sys.stderr)
-    sys.exit(1)
-
-def validate_args(args: argparse.Namespace, debug : bool = False) -> None:
-
-    if not 0 <= args.temperature <= 2:
-        die_validation_err("Temperature must be between 0 and 2.")
-
-    if not -2 <= args.frequency_penalty <= 2:
-        die_validation_err("Frequency penalty must be between -2.0 and 2.0.")
-
-    if not -2 <= args.presence_penalty <= 2:
-        die_validation_err("Presence penalty must be between -2.0 and 2.0.")
-
-    if args.max_tokens < 1:
-        die_validation_err("Max tokens must be greater than or equal to 1.")
-
-    if not 0 <= args.top_p <= 1:
-        die_validation_err("Top_p must be between 0 and 1.")
-
-    if args.n_completions < 1:
-        die_validation_err("Number of completions must be greater than or equal to 1.")
-
-    if args.interactive and args.n_completions != 1:
-        die_validation_err("Only a single completion can be used in interactive mode")
-
-    if (args.prompt_from_fd or args.prompt_from_file) and args.message:
-        die_validation_err("Cannot specify an initial message alongside --prompt_from_fd or --prompt_from_file")
-
-    if debug and args.interactive:
-
-        if args.interactive and (
-            args.save_response_to_file or args.load_response_from_file
-        ):
-            die_validation_err("Save and load operations cannot be used in interactive mode")
+from enum import Enum
 
 class AutoDetectedOption(Enum):
     ON = 'on'
@@ -52,96 +14,49 @@ class AutoDetectedOption(Enum):
     def __str__(self : "AutoDetectedOption"):
         return self.value
 
-@dataclass
-class CompletionArguments:
-    model: str
-    n_completions: int
-    temperature: float
-    presence_penalty: float
-    frequency_penalty: float
-    max_tokens: int
-    top_p: float
+######################
+## PUBLIC INTERFACE ##
+######################
 
 @dataclass
-class DisplayArguments:
-    adornments: bool
-    color: bool
+class RawArguments:
+    args : argparse.Namespace
+    debug : bool
+    openai_key : Optional[str] = None
 
-@dataclass
-class DebugArguments:
-    save_response_to_file: Optional[str]
-    load_response_from_file: Optional[str]
+def parse_raw_args_or_complete() -> RawArguments:
 
-@dataclass
-class MessageSource:
-    message: Optional[str] = None
-    prompt_from_fd: Optional[str] = None
-    prompt_from_file: Optional[str] = None
+    parser, debug = _construct_parser()
 
-@dataclass
-class Arguments:
-    completion_args: CompletionArguments
-    display_args: DisplayArguments
-    version: bool
-    list_models: bool
-    interactive: bool
-    initial_message: MessageSource
-    system_message: Optional[str] = None
-    debug_args: Optional[DebugArguments] = None
-
-def split_arguments(args: argparse.Namespace) -> Arguments:
-    completion_args = CompletionArguments(
-        model=args.model,
-        n_completions=args.n_completions,
-        temperature=args.temperature,
-        presence_penalty=args.presence_penalty,
-        frequency_penalty=args.frequency_penalty,
-        max_tokens=args.max_tokens,
-        top_p=args.top_p,
-    )
-
-    msg_src = MessageSource(
-        message = args.message,
-        prompt_from_fd = args.prompt_from_fd,
-        prompt_from_file = args.prompt_from_file,
-    )
+    argcomplete.autocomplete( parser )
 
-    display_args = DisplayArguments(
-        adornments=(args.adornments == AutoDetectedOption.ON),
-        color=(args.color == AutoDetectedOption.ON),
-    )
+    args = parser.parse_args()
 
-    debug_args = DebugArguments(
-        save_response_to_file=args.save_response_to_file,
-        load_response_from_file=args.load_response_from_file,
-    )
+    openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY"))
 
-    return Arguments(
-        initial_message=msg_src,
-        completion_args=completion_args,
-        display_args=display_args,
-        debug_args=debug_args,
-        version=args.version,
-        list_models=args.list_models,
-        interactive=args.interactive,
-        system_message=args.system_message
+    return RawArguments(
+        args = args,
+        debug = debug,
+        openai_key = openai_key
     )
 
-def parse_args() -> Arguments:
+#####################
+## PRIVATE LOGIC ##
+#####################
 
-    GPT_CLI_ENV_PREFIX = "GPT_CLI_"
+_GPT_CLI_ENV_PREFIX = "GPT_CLI_"
 
-    debug = os.getenv(f'{GPT_CLI_ENV_PREFIX}DEBUG') is not None
+def _construct_parser() \
+    -> Tuple[argparse.ArgumentParser, bool]:
 
-    if debug:
-        logging.warning("Debugging mode and unstable features have been enabled.")
+    debug = os.getenv(f'{_GPT_CLI_ENV_PREFIX}DEBUG') is not None
 
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
         "-m",
         "--model",
-        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"),
+        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"),
         help="ID of the model to use",
     )
 
@@ -149,7 +64,7 @@ def parse_args() -> Arguments:
         "-t",
         "--temperature",
         type=float,
-        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}TEMPERATURE', 0.5),
+        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}TEMPERATURE', 0.5),
         help=(
             "What sampling temperature to use, between 0 and 2. Higher values "
             "like 0.8 will make the output more random, while lower values "
@@ -161,7 +76,7 @@ def parse_args() -> Arguments:
         "-f",
         "--frequency-penalty",
         type=float,
-        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}FREQUENCY_PENALTY', 0),
+        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}FREQUENCY_PENALTY', 0),
         help=(
             "Number between -2.0 and 2.0. Positive values penalize new tokens based "
             "on their existing frequency in the text so far, decreasing the model's "
@@ -173,7 +88,7 @@ def parse_args() -> Arguments:
         "-p",
         "--presence-penalty",
         type=float,
-        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}PRESENCE_PENALTY', 0),
+        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}PRESENCE_PENALTY', 0),
         help=(
             "Number between -2.0 and 2.0. Positive values penalize new tokens based "
             "on whether they appear in the text so far, increasing the model's "
@@ -185,7 +100,7 @@ def parse_args() -> Arguments:
         "-k",
         "--max-tokens",
         type=int,
-        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}MAX_TOKENS', 2048),
+        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}MAX_TOKENS', 2048),
         help=(
             "The maximum number of tokens to generate in the chat completion. "
             "Defaults to 2048."
@@ -196,7 +111,7 @@ def parse_args() -> Arguments:
         "-s",
         "--top-p",
         type=float,
-        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}TOP_P', 1),
+        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}TOP_P', 1),
         help=(
             "An alternative to sampling with temperature, called nucleus sampling, "
             "where the model considers the results of the tokens with top_p "
@@ -209,14 +124,14 @@ def parse_args() -> Arguments:
         "-n",
         "--n-completions",
         type=int,
-        default=os.getenv('f{GPT_CLI_ENV_PREFIX}N_COMPLETIONS', 1),
+        default=os.getenv('f{_GPT_CLI_ENV_PREFIX}N_COMPLETIONS', 1),
         help="How many chat completion choices to generate for each input message.",
     )
 
     parser.add_argument(
         "--system-message",
         type=str,
-        default=os.getenv('f{GPT_CLI_ENV_PREFIX}SYSTEM_MESSAGE'),
+        default=os.getenv('f{_GPT_CLI_ENV_PREFIX}SYSTEM_MESSAGE'),
         help="Specify an alternative system message.",
     )
 
@@ -298,49 +213,4 @@ def parse_args() -> Arguments:
         help="UNSTABLE: load a response from a file. This can reply a response for debugging purposes",
     )
 
-    openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY"))
-    if not openai_key:
-        print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr)
-        print("The OpenAI API uses API keys for authentication. Visit your (API Keys page)[https://platform.openai.com/account/api-keys] to retrieve the API key you'll use in your requests.", file=sys.stderr)
-        sys.exit(1)
-
-    openai.api_key = openai_key
-
-    args = parser.parse_args()
-
-    if debug and args.load_response_from_file:
-        logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {args.load_response_from_file} was generated')
-
-    if args.color == AutoDetectedOption.AUTO:
-        if os.getenv("NO_COLOR"):
-            args.color = AutoDetectedOption.OFF
-        elif sys.stdout.isatty():
-            args.color = AutoDetectedOption.ON
-        else:
-            args.color = AutoDetectedOption.OFF
-
-    if args.adornments == AutoDetectedOption.AUTO:
-        if sys.stdout.isatty():
-            args.adornments = AutoDetectedOption.ON
-        else:
-            args.adornments = AutoDetectedOption.OFF
-
-    initial_message_specified = (
-        args.message or
-        args.prompt_from_fd or
-        args.prompt_from_file
-    )
-
-    if not initial_message_specified:
-        if debug and args.load_response_from_file:
-            args.interactive = False
-        elif sys.stdin.isatty():
-            args.interactive = True
-
-    if not debug:
-        args.load_response_from_file = None
-        args.save_response_to_file = None
-
-    validate_args(args, debug=debug)
-
-    return split_arguments(args)
+    return parser, debug
```
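A behavioral shift visible in this diff: the old `parse_args()` exited with an error when neither key variable was set, while the new `parse_raw_args_or_complete()` simply records the key (or `None`) in `RawArguments`, deferring the error to the new validation layer. The lookup gives `OPENAI_KEY` precedence over `OPENAI_API_KEY`; a tiny runnable check of that exact expression:

```python
import os

# The fallback expression from parse_raw_args_or_complete(): OPENAI_KEY wins
# when both are set; otherwise OPENAI_API_KEY; otherwise the result is None.
os.environ.pop("OPENAI_KEY", None)
os.environ["OPENAI_API_KEY"] = "sk-example"

key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY"))
assert key == "sk-example"
```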