aboutsummaryrefslogtreecommitdiff
path: root/src/gpt_chat_cli/argparsing.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpt_chat_cli/argparsing.py')
-rw-r--r--src/gpt_chat_cli/argparsing.py200
1 files changed, 35 insertions, 165 deletions
diff --git a/src/gpt_chat_cli/argparsing.py b/src/gpt_chat_cli/argparsing.py
index b026af8..a96e81c 100644
--- a/src/gpt_chat_cli/argparsing.py
+++ b/src/gpt_chat_cli/argparsing.py
@@ -1,48 +1,10 @@
import argparse
+import argcomplete
import os
-import logging
-import openai
-import sys
-from enum import Enum
+
from dataclasses import dataclass
from typing import Tuple, Optional
-
-def die_validation_err(err : str):
- print(err, file=sys.stderr)
- sys.exit(1)
-
-def validate_args(args: argparse.Namespace, debug : bool = False) -> None:
-
- if not 0 <= args.temperature <= 2:
- die_validation_err("Temperature must be between 0 and 2.")
-
- if not -2 <= args.frequency_penalty <= 2:
- die_validation_err("Frequency penalty must be between -2.0 and 2.0.")
-
- if not -2 <= args.presence_penalty <= 2:
- die_validation_err("Presence penalty must be between -2.0 and 2.0.")
-
- if args.max_tokens < 1:
- die_validation_err("Max tokens must be greater than or equal to 1.")
-
- if not 0 <= args.top_p <= 1:
- die_validation_err("Top_p must be between 0 and 1.")
-
- if args.n_completions < 1:
- die_validation_err("Number of completions must be greater than or equal to 1.")
-
- if args.interactive and args.n_completions != 1:
- die_validation_err("Only a single completion can be used in interactive mode")
-
- if (args.prompt_from_fd or args.prompt_from_file) and args.message:
- die_validation_err("Cannot specify an initial message alongside --prompt_from_fd or --prompt_from_file")
-
- if debug and args.interactive:
-
- if args.interactive and (
- args.save_response_to_file or args.load_response_from_file
- ):
- die_validation_err("Save and load operations cannot be used in interactive mode")
+from enum import Enum
class AutoDetectedOption(Enum):
ON = 'on'
@@ -52,96 +14,49 @@ class AutoDetectedOption(Enum):
def __str__(self : "AutoDetectedOption"):
return self.value
-@dataclass
-class CompletionArguments:
- model: str
- n_completions: int
- temperature: float
- presence_penalty: float
- frequency_penalty: float
- max_tokens: int
- top_p: float
+######################
+## PUBLIC INTERFACE ##
+######################
@dataclass
-class DisplayArguments:
- adornments: bool
- color: bool
+class RawArguments:
+ args : argparse.Namespace
+ debug : bool
+ openai_key : Optional[str] = None
-@dataclass
-class DebugArguments:
- save_response_to_file: Optional[str]
- load_response_from_file: Optional[str]
+def parse_raw_args_or_complete() -> RawArguments:
-@dataclass
-class MessageSource:
- message: Optional[str] = None
- prompt_from_fd: Optional[str] = None
- prompt_from_file: Optional[str] = None
+ parser, debug = _construct_parser()
-@dataclass
-class Arguments:
- completion_args: CompletionArguments
- display_args: DisplayArguments
- version: bool
- list_models: bool
- interactive: bool
- initial_message: MessageSource
- system_message: Optional[str] = None
- debug_args: Optional[DebugArguments] = None
-
-def split_arguments(args: argparse.Namespace) -> Arguments:
- completion_args = CompletionArguments(
- model=args.model,
- n_completions=args.n_completions,
- temperature=args.temperature,
- presence_penalty=args.presence_penalty,
- frequency_penalty=args.frequency_penalty,
- max_tokens=args.max_tokens,
- top_p=args.top_p,
- )
-
- msg_src = MessageSource(
- message = args.message,
- prompt_from_fd = args.prompt_from_fd,
- prompt_from_file = args.prompt_from_file,
- )
+ argcomplete.autocomplete( parser )
- display_args = DisplayArguments(
- adornments=(args.adornments == AutoDetectedOption.ON),
- color=(args.color == AutoDetectedOption.ON),
- )
+ args = parser.parse_args()
- debug_args = DebugArguments(
- save_response_to_file=args.save_response_to_file,
- load_response_from_file=args.load_response_from_file,
- )
+ openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY"))
- return Arguments(
- initial_message=msg_src,
- completion_args=completion_args,
- display_args=display_args,
- debug_args=debug_args,
- version=args.version,
- list_models=args.list_models,
- interactive=args.interactive,
- system_message=args.system_message
+ return RawArguments(
+ args = args,
+ debug = debug,
+ openai_key = openai_key
)
-def parse_args() -> Arguments:
+#####################
+## PRIVATE LOGIC ##
+#####################
- GPT_CLI_ENV_PREFIX = "GPT_CLI_"
+_GPT_CLI_ENV_PREFIX = "GPT_CLI_"
- debug = os.getenv(f'{GPT_CLI_ENV_PREFIX}DEBUG') is not None
+def _construct_parser() \
+ -> Tuple[argparse.ArgumentParser, bool]:
- if debug:
- logging.warning("Debugging mode and unstable features have been enabled.")
+ debug = os.getenv(f'{_GPT_CLI_ENV_PREFIX}DEBUG') is not None
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model",
- default=os.getenv(f'{GPT_CLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"),
+ default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"),
help="ID of the model to use",
)
@@ -149,7 +64,7 @@ def parse_args() -> Arguments:
"-t",
"--temperature",
type=float,
- default=os.getenv(f'{GPT_CLI_ENV_PREFIX}TEMPERATURE', 0.5),
+ default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}TEMPERATURE', 0.5),
help=(
"What sampling temperature to use, between 0 and 2. Higher values "
"like 0.8 will make the output more random, while lower values "
@@ -161,7 +76,7 @@ def parse_args() -> Arguments:
"-f",
"--frequency-penalty",
type=float,
- default=os.getenv(f'{GPT_CLI_ENV_PREFIX}FREQUENCY_PENALTY', 0),
+ default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}FREQUENCY_PENALTY', 0),
help=(
"Number between -2.0 and 2.0. Positive values penalize new tokens based "
"on their existing frequency in the text so far, decreasing the model's "
@@ -173,7 +88,7 @@ def parse_args() -> Arguments:
"-p",
"--presence-penalty",
type=float,
- default=os.getenv(f'{GPT_CLI_ENV_PREFIX}PRESENCE_PENALTY', 0),
+ default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}PRESENCE_PENALTY', 0),
help=(
"Number between -2.0 and 2.0. Positive values penalize new tokens based "
"on whether they appear in the text so far, increasing the model's "
@@ -185,7 +100,7 @@ def parse_args() -> Arguments:
"-k",
"--max-tokens",
type=int,
- default=os.getenv(f'{GPT_CLI_ENV_PREFIX}MAX_TOKENS', 2048),
+ default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}MAX_TOKENS', 2048),
help=(
"The maximum number of tokens to generate in the chat completion. "
"Defaults to 2048."
@@ -196,7 +111,7 @@ def parse_args() -> Arguments:
"-s",
"--top-p",
type=float,
- default=os.getenv(f'{GPT_CLI_ENV_PREFIX}TOP_P', 1),
+ default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}TOP_P', 1),
help=(
"An alternative to sampling with temperature, called nucleus sampling, "
"where the model considers the results of the tokens with top_p "
@@ -209,14 +124,14 @@ def parse_args() -> Arguments:
"-n",
"--n-completions",
type=int,
- default=os.getenv('f{GPT_CLI_ENV_PREFIX}N_COMPLETIONS', 1),
+ default=os.getenv('f{_GPT_CLI_ENV_PREFIX}N_COMPLETIONS', 1),
help="How many chat completion choices to generate for each input message.",
)
parser.add_argument(
"--system-message",
type=str,
- default=os.getenv('f{GPT_CLI_ENV_PREFIX}SYSTEM_MESSAGE'),
+ default=os.getenv('f{_GPT_CLI_ENV_PREFIX}SYSTEM_MESSAGE'),
help="Specify an alternative system message.",
)
@@ -298,49 +213,4 @@ def parse_args() -> Arguments:
help="UNSTABLE: load a response from a file. This can reply a response for debugging purposes",
)
- openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY"))
- if not openai_key:
- print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr)
- print("The OpenAI API uses API keys for authentication. Visit your (API Keys page)[https://platform.openai.com/account/api-keys] to retrieve the API key you'll use in your requests.", file=sys.stderr)
- sys.exit(1)
-
- openai.api_key = openai_key
-
- args = parser.parse_args()
-
- if debug and args.load_response_from_file:
- logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {args.load_response_from_file} was generated')
-
- if args.color == AutoDetectedOption.AUTO:
- if os.getenv("NO_COLOR"):
- args.color = AutoDetectedOption.OFF
- elif sys.stdout.isatty():
- args.color = AutoDetectedOption.ON
- else:
- args.color = AutoDetectedOption.OFF
-
- if args.adornments == AutoDetectedOption.AUTO:
- if sys.stdout.isatty():
- args.adornments = AutoDetectedOption.ON
- else:
- args.adornments = AutoDetectedOption.OFF
-
- initial_message_specified = (
- args.message or
- args.prompt_from_fd or
- args.prompt_from_file
- )
-
- if not initial_message_specified:
- if debug and args.load_response_from_file:
- args.interactive = False
- elif sys.stdin.isatty():
- args.interactive = True
-
- if not debug:
- args.load_response_from_file = None
- args.save_response_to_file = None
-
- validate_args(args, debug=debug)
-
- return split_arguments(args)
+ return parser, debug