diff options
Diffstat (limited to 'src/gpt_chat_cli/argvalidation.py')
-rw-r--r-- | src/gpt_chat_cli/argvalidation.py | 198 |
1 files changed, 198 insertions, 0 deletions
diff --git a/src/gpt_chat_cli/argvalidation.py b/src/gpt_chat_cli/argvalidation.py new file mode 100644 index 0000000..16987a9 --- /dev/null +++ b/src/gpt_chat_cli/argvalidation.py @@ -0,0 +1,198 @@ +import logging + +from enum import Enum +from dataclasses import dataclass +from typing import Tuple, Optional + +import sys +import os + +from .argparsing import ( + RawArguments, + AutoDetectedOption, +) + +###################### +## PUBLIC INTERFACE ## +###################### + +@dataclass +class CompletionArguments: + model: str + n_completions: int + temperature: float + presence_penalty: float + frequency_penalty: float + max_tokens: int + top_p: float + +@dataclass +class DisplayArguments: + adornments: bool + color: bool + +@dataclass +class DebugArguments: + save_response_to_file: Optional[str] + load_response_from_file: Optional[str] + +@dataclass +class MessageSource: + message: Optional[str] = None + prompt_from_fd: Optional[str] = None + prompt_from_file: Optional[str] = None + +@dataclass +class Arguments: + completion_args: CompletionArguments + display_args: DisplayArguments + version: bool + list_models: bool + interactive: bool + initial_message: MessageSource + openai_key: str + system_message: Optional[str] = None + debug_args: Optional[DebugArguments] = None + +def post_process_raw_args(raw_args : RawArguments) -> Arguments: + _populate_defaults(raw_args) + _issue_warnings(raw_args) + _validate_args(raw_args) + return _restructure_arguments(raw_args) + +##################### +## PRIVATE LOGIC ## +##################### + +def _restructure_arguments(raw_args : RawArguments) -> Arguments: + + args = raw_args.args + openai_key = raw_args.openai_key + + completion_args = CompletionArguments( + model=args.model, + n_completions=args.n_completions, + temperature=args.temperature, + presence_penalty=args.presence_penalty, + frequency_penalty=args.frequency_penalty, + max_tokens=args.max_tokens, + top_p=args.top_p, + ) + + msg_src = MessageSource( + message = args.message, + prompt_from_fd = args.prompt_from_fd, + prompt_from_file = args.prompt_from_file, + ) + + display_args = DisplayArguments( + adornments=(args.adornments == AutoDetectedOption.ON), + color=(args.color == AutoDetectedOption.ON), + ) + + debug_args = DebugArguments( + save_response_to_file=args.save_response_to_file, + load_response_from_file=args.load_response_from_file, + ) + + return Arguments( + initial_message=msg_src, + openai_key=openai_key, + completion_args=completion_args, + display_args=display_args, + debug_args=debug_args, + version=args.version, + list_models=args.list_models, + interactive=args.interactive, + system_message=args.system_message + ) + +def _die_validation_err(err : str): + print(err, file=sys.stderr) + sys.exit(1) + +def _validate_args(raw_args : RawArguments) -> None: + + args = raw_args.args + debug = raw_args.debug + + if not 0 <= args.temperature <= 2: + _die_validation_err("Temperature must be between 0 and 2.") + + if not -2 <= args.frequency_penalty <= 2: + _die_validation_err("Frequency penalty must be between -2.0 and 2.0.") + + if not -2 <= args.presence_penalty <= 2: + _die_validation_err("Presence penalty must be between -2.0 and 2.0.") + + if args.max_tokens < 1: + _die_validation_err("Max tokens must be greater than or equal to 1.") + + if not 0 <= args.top_p <= 1: + _die_validation_err("Top_p must be between 0 and 1.") + + if args.n_completions < 1: + _die_validation_err("Number of completions must be greater than or equal to 1.") + + if args.interactive and args.n_completions != 1: + _die_validation_err("Only a single completion can be used in interactive mode") + + if (args.prompt_from_fd or args.prompt_from_file) and args.message: + _die_validation_err("Cannot specify an initial message alongside --prompt_from_fd or --prompt_from_file") + + if debug and args.interactive: + + if args.interactive and ( + args.save_response_to_file or args.load_response_from_file + ): + _die_validation_err("Save and load operations cannot be used in interactive mode") + +def _issue_warnings(raw_args : RawArguments): + + if not raw_args.openai_key: + print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr) + print("The OpenAI API uses API keys for authentication. Visit your (API Keys page)[https://platform.openai.com/account/api-keys] to retrieve the API key you'll use in your requests.", file=sys.stderr) + + sys.exit(1) + + if raw_args.debug: + logging.warning("Debugging mode and unstable features have been enabled.") + + if raw_args.debug and raw_args.args.load_response_from_file: + logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {raw_args.args.load_response_from_file} was generated') + +def _populate_defaults(raw_args : RawArguments): + + args = raw_args.args + debug = raw_args.debug + + if args.color == AutoDetectedOption.AUTO: + if os.getenv("NO_COLOR"): + args.color = AutoDetectedOption.OFF + elif sys.stdout.isatty(): + args.color = AutoDetectedOption.ON + else: + args.color = AutoDetectedOption.OFF + + if args.adornments == AutoDetectedOption.AUTO: + if sys.stdout.isatty(): + args.adornments = AutoDetectedOption.ON + else: + args.adornments = AutoDetectedOption.OFF + + initial_message_specified = ( + args.message or + args.prompt_from_fd or + args.prompt_from_file + ) + + if not initial_message_specified: + if debug and args.load_response_from_file: + args.interactive = False + elif sys.stdin.isatty(): + args.interactive = True + + if not debug: + args.load_response_from_file = None + args.save_response_to_file = None + |