aboutsummaryrefslogtreecommitdiff
path: root/src/gpt_chat_cli/argvalidation.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpt_chat_cli/argvalidation.py')
-rw-r--r--src/gpt_chat_cli/argvalidation.py198
1 files changed, 198 insertions, 0 deletions
diff --git a/src/gpt_chat_cli/argvalidation.py b/src/gpt_chat_cli/argvalidation.py
new file mode 100644
index 0000000..16987a9
--- /dev/null
+++ b/src/gpt_chat_cli/argvalidation.py
@@ -0,0 +1,198 @@
+import logging
+
+from enum import Enum
+from dataclasses import dataclass
+from typing import Tuple, Optional
+
+import sys
+import os
+
+from .argparsing import (
+ RawArguments,
+ AutoDetectedOption,
+)
+
+######################
+## PUBLIC INTERFACE ##
+######################
+
+@dataclass
+class CompletionArguments:
+ model: str
+ n_completions: int
+ temperature: float
+ presence_penalty: float
+ frequency_penalty: float
+ max_tokens: int
+ top_p: float
+
+@dataclass
+class DisplayArguments:
+ adornments: bool
+ color: bool
+
+@dataclass
+class DebugArguments:
+ save_response_to_file: Optional[str]
+ load_response_from_file: Optional[str]
+
+@dataclass
+class MessageSource:
+ message: Optional[str] = None
+ prompt_from_fd: Optional[str] = None
+ prompt_from_file: Optional[str] = None
+
+@dataclass
+class Arguments:
+ completion_args: CompletionArguments
+ display_args: DisplayArguments
+ version: bool
+ list_models: bool
+ interactive: bool
+ initial_message: MessageSource
+ openai_key: str
+ system_message: Optional[str] = None
+ debug_args: Optional[DebugArguments] = None
+
+def post_process_raw_args(raw_args : RawArguments) -> Arguments:
+ _populate_defaults(raw_args)
+ _issue_warnings(raw_args)
+ _validate_args(raw_args)
+ return _restructure_arguments(raw_args)
+
+#####################
+## PRIVATE LOGIC ##
+#####################
+
+def _restructure_arguments(raw_args : RawArguments) -> Arguments:
+
+ args = raw_args.args
+ openai_key = raw_args.openai_key
+
+ completion_args = CompletionArguments(
+ model=args.model,
+ n_completions=args.n_completions,
+ temperature=args.temperature,
+ presence_penalty=args.presence_penalty,
+ frequency_penalty=args.frequency_penalty,
+ max_tokens=args.max_tokens,
+ top_p=args.top_p,
+ )
+
+ msg_src = MessageSource(
+ message = args.message,
+ prompt_from_fd = args.prompt_from_fd,
+ prompt_from_file = args.prompt_from_file,
+ )
+
+ display_args = DisplayArguments(
+ adornments=(args.adornments == AutoDetectedOption.ON),
+ color=(args.color == AutoDetectedOption.ON),
+ )
+
+ debug_args = DebugArguments(
+ save_response_to_file=args.save_response_to_file,
+ load_response_from_file=args.load_response_from_file,
+ )
+
+ return Arguments(
+ initial_message=msg_src,
+ openai_key=openai_key,
+ completion_args=completion_args,
+ display_args=display_args,
+ debug_args=debug_args,
+ version=args.version,
+ list_models=args.list_models,
+ interactive=args.interactive,
+ system_message=args.system_message
+ )
+
+def _die_validation_err(err : str):
+ print(err, file=sys.stderr)
+ sys.exit(1)
+
+def _validate_args(raw_args : RawArguments) -> None:
+
+ args = raw_args.args
+ debug = raw_args.debug
+
+ if not 0 <= args.temperature <= 2:
+ _die_validation_err("Temperature must be between 0 and 2.")
+
+ if not -2 <= args.frequency_penalty <= 2:
+ _die_validation_err("Frequency penalty must be between -2.0 and 2.0.")
+
+ if not -2 <= args.presence_penalty <= 2:
+ _die_validation_err("Presence penalty must be between -2.0 and 2.0.")
+
+ if args.max_tokens < 1:
+ _die_validation_err("Max tokens must be greater than or equal to 1.")
+
+ if not 0 <= args.top_p <= 1:
+ _die_validation_err("Top_p must be between 0 and 1.")
+
+ if args.n_completions < 1:
+ _die_validation_err("Number of completions must be greater than or equal to 1.")
+
+ if args.interactive and args.n_completions != 1:
+ _die_validation_err("Only a single completion can be used in interactive mode")
+
+ if (args.prompt_from_fd or args.prompt_from_file) and args.message:
+ _die_validation_err("Cannot specify an initial message alongside --prompt_from_fd or --prompt_from_file")
+
+ if debug and args.interactive:
+
+ if args.interactive and (
+ args.save_response_to_file or args.load_response_from_file
+ ):
+ _die_validation_err("Save and load operations cannot be used in interactive mode")
+
+def _issue_warnings(raw_args : RawArguments):
+
+ if not raw_args.openai_key:
+ print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr)
+ print("The OpenAI API uses API keys for authentication. Visit your (API Keys page)[https://platform.openai.com/account/api-keys] to retrieve the API key you'll use in your requests.", file=sys.stderr)
+
+ sys.exit(1)
+
+ if raw_args.debug:
+ logging.warning("Debugging mode and unstable features have been enabled.")
+
+ if raw_args.debug and raw_args.args.load_response_from_file:
+ logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {raw_args.args.load_response_from_file} was generated')
+
+def _populate_defaults(raw_args : RawArguments):
+
+ args = raw_args.args
+ debug = raw_args.debug
+
+ if args.color == AutoDetectedOption.AUTO:
+ if os.getenv("NO_COLOR"):
+ args.color = AutoDetectedOption.OFF
+ elif sys.stdout.isatty():
+ args.color = AutoDetectedOption.ON
+ else:
+ args.color = AutoDetectedOption.OFF
+
+ if args.adornments == AutoDetectedOption.AUTO:
+ if sys.stdout.isatty():
+ args.adornments = AutoDetectedOption.ON
+ else:
+ args.adornments = AutoDetectedOption.OFF
+
+ initial_message_specified = (
+ args.message or
+ args.prompt_from_fd or
+ args.prompt_from_file
+ )
+
+ if not initial_message_specified:
+ if debug and args.load_response_from_file:
+ args.interactive = False
+ elif sys.stdin.isatty():
+ args.interactive = True
+
+ if not debug:
+ args.load_response_from_file = None
+ args.save_response_to_file = None
+