aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/gpt_chat_cli/argparsing.py200
-rw-r--r--src/gpt_chat_cli/argvalidation.py198
-rw-r--r--src/gpt_chat_cli/cmd.py (renamed from src/gpt_chat_cli/gcli.py)37
-rw-r--r--src/gpt_chat_cli/main.py38
-rw-r--r--src/gpt_chat_cli/openai_wrappers.py5
-rw-r--r--src/gpt_chat_cli/version.py2
6 files changed, 284 insertions, 196 deletions
diff --git a/src/gpt_chat_cli/argparsing.py b/src/gpt_chat_cli/argparsing.py
index b026af8..a96e81c 100644
--- a/src/gpt_chat_cli/argparsing.py
+++ b/src/gpt_chat_cli/argparsing.py
@@ -1,48 +1,10 @@
import argparse
+import argcomplete
import os
-import logging
-import openai
-import sys
-from enum import Enum
+
from dataclasses import dataclass
from typing import Tuple, Optional
-
-def die_validation_err(err : str):
- print(err, file=sys.stderr)
- sys.exit(1)
-
-def validate_args(args: argparse.Namespace, debug : bool = False) -> None:
-
- if not 0 <= args.temperature <= 2:
- die_validation_err("Temperature must be between 0 and 2.")
-
- if not -2 <= args.frequency_penalty <= 2:
- die_validation_err("Frequency penalty must be between -2.0 and 2.0.")
-
- if not -2 <= args.presence_penalty <= 2:
- die_validation_err("Presence penalty must be between -2.0 and 2.0.")
-
- if args.max_tokens < 1:
- die_validation_err("Max tokens must be greater than or equal to 1.")
-
- if not 0 <= args.top_p <= 1:
- die_validation_err("Top_p must be between 0 and 1.")
-
- if args.n_completions < 1:
- die_validation_err("Number of completions must be greater than or equal to 1.")
-
- if args.interactive and args.n_completions != 1:
- die_validation_err("Only a single completion can be used in interactive mode")
-
- if (args.prompt_from_fd or args.prompt_from_file) and args.message:
- die_validation_err("Cannot specify an initial message alongside --prompt_from_fd or --prompt_from_file")
-
- if debug and args.interactive:
-
- if args.interactive and (
- args.save_response_to_file or args.load_response_from_file
- ):
- die_validation_err("Save and load operations cannot be used in interactive mode")
+from enum import Enum
class AutoDetectedOption(Enum):
ON = 'on'
@@ -52,96 +14,49 @@ class AutoDetectedOption(Enum):
def __str__(self : "AutoDetectedOption"):
return self.value
-@dataclass
-class CompletionArguments:
- model: str
- n_completions: int
- temperature: float
- presence_penalty: float
- frequency_penalty: float
- max_tokens: int
- top_p: float
+######################
+## PUBLIC INTERFACE ##
+######################
@dataclass
-class DisplayArguments:
- adornments: bool
- color: bool
+class RawArguments:
+ args : argparse.Namespace
+ debug : bool
+ openai_key : Optional[str] = None
-@dataclass
-class DebugArguments:
- save_response_to_file: Optional[str]
- load_response_from_file: Optional[str]
+def parse_raw_args_or_complete() -> RawArguments:
-@dataclass
-class MessageSource:
- message: Optional[str] = None
- prompt_from_fd: Optional[str] = None
- prompt_from_file: Optional[str] = None
+ parser, debug = _construct_parser()
-@dataclass
-class Arguments:
- completion_args: CompletionArguments
- display_args: DisplayArguments
- version: bool
- list_models: bool
- interactive: bool
- initial_message: MessageSource
- system_message: Optional[str] = None
- debug_args: Optional[DebugArguments] = None
-
-def split_arguments(args: argparse.Namespace) -> Arguments:
- completion_args = CompletionArguments(
- model=args.model,
- n_completions=args.n_completions,
- temperature=args.temperature,
- presence_penalty=args.presence_penalty,
- frequency_penalty=args.frequency_penalty,
- max_tokens=args.max_tokens,
- top_p=args.top_p,
- )
-
- msg_src = MessageSource(
- message = args.message,
- prompt_from_fd = args.prompt_from_fd,
- prompt_from_file = args.prompt_from_file,
- )
+ argcomplete.autocomplete( parser )
- display_args = DisplayArguments(
- adornments=(args.adornments == AutoDetectedOption.ON),
- color=(args.color == AutoDetectedOption.ON),
- )
+ args = parser.parse_args()
- debug_args = DebugArguments(
- save_response_to_file=args.save_response_to_file,
- load_response_from_file=args.load_response_from_file,
- )
+ openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY"))
- return Arguments(
- initial_message=msg_src,
- completion_args=completion_args,
- display_args=display_args,
- debug_args=debug_args,
- version=args.version,
- list_models=args.list_models,
- interactive=args.interactive,
- system_message=args.system_message
+ return RawArguments(
+ args = args,
+ debug = debug,
+ openai_key = openai_key
)
-def parse_args() -> Arguments:
+#####################
+## PRIVATE LOGIC ##
+#####################
- GPT_CLI_ENV_PREFIX = "GPT_CLI_"
+_GPT_CLI_ENV_PREFIX = "GPT_CLI_"
- debug = os.getenv(f'{GPT_CLI_ENV_PREFIX}DEBUG') is not None
+def _construct_parser() \
+ -> Tuple[argparse.ArgumentParser, bool]:
- if debug:
- logging.warning("Debugging mode and unstable features have been enabled.")
+ debug = os.getenv(f'{_GPT_CLI_ENV_PREFIX}DEBUG') is not None
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model",
- default=os.getenv(f'{GPT_CLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"),
+ default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"),
help="ID of the model to use",
)
@@ -149,7 +64,7 @@ def parse_args() -> Arguments:
"-t",
"--temperature",
type=float,
- default=os.getenv(f'{GPT_CLI_ENV_PREFIX}TEMPERATURE', 0.5),
+ default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}TEMPERATURE', 0.5),
help=(
"What sampling temperature to use, between 0 and 2. Higher values "
"like 0.8 will make the output more random, while lower values "
@@ -161,7 +76,7 @@ def parse_args() -> Arguments:
"-f",
"--frequency-penalty",
type=float,
- default=os.getenv(f'{GPT_CLI_ENV_PREFIX}FREQUENCY_PENALTY', 0),
+ default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}FREQUENCY_PENALTY', 0),
help=(
"Number between -2.0 and 2.0. Positive values penalize new tokens based "
"on their existing frequency in the text so far, decreasing the model's "
@@ -173,7 +88,7 @@ def parse_args() -> Arguments:
"-p",
"--presence-penalty",
type=float,
- default=os.getenv(f'{GPT_CLI_ENV_PREFIX}PRESENCE_PENALTY', 0),
+ default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}PRESENCE_PENALTY', 0),
help=(
"Number between -2.0 and 2.0. Positive values penalize new tokens based "
"on whether they appear in the text so far, increasing the model's "
@@ -185,7 +100,7 @@ def parse_args() -> Arguments:
"-k",
"--max-tokens",
type=int,
- default=os.getenv(f'{GPT_CLI_ENV_PREFIX}MAX_TOKENS', 2048),
+ default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}MAX_TOKENS', 2048),
help=(
"The maximum number of tokens to generate in the chat completion. "
"Defaults to 2048."
@@ -196,7 +111,7 @@ def parse_args() -> Arguments:
"-s",
"--top-p",
type=float,
- default=os.getenv(f'{GPT_CLI_ENV_PREFIX}TOP_P', 1),
+ default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}TOP_P', 1),
help=(
"An alternative to sampling with temperature, called nucleus sampling, "
"where the model considers the results of the tokens with top_p "
@@ -209,14 +124,14 @@ def parse_args() -> Arguments:
"-n",
"--n-completions",
type=int,
- default=os.getenv('f{GPT_CLI_ENV_PREFIX}N_COMPLETIONS', 1),
+ default=os.getenv('f{_GPT_CLI_ENV_PREFIX}N_COMPLETIONS', 1),
help="How many chat completion choices to generate for each input message.",
)
parser.add_argument(
"--system-message",
type=str,
- default=os.getenv('f{GPT_CLI_ENV_PREFIX}SYSTEM_MESSAGE'),
+ default=os.getenv('f{_GPT_CLI_ENV_PREFIX}SYSTEM_MESSAGE'),
help="Specify an alternative system message.",
)
@@ -298,49 +213,4 @@ def parse_args() -> Arguments:
help="UNSTABLE: load a response from a file. This can reply a response for debugging purposes",
)
- openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY"))
- if not openai_key:
- print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr)
- print("The OpenAI API uses API keys for authentication. Visit your (API Keys page)[https://platform.openai.com/account/api-keys] to retrieve the API key you'll use in your requests.", file=sys.stderr)
- sys.exit(1)
-
- openai.api_key = openai_key
-
- args = parser.parse_args()
-
- if debug and args.load_response_from_file:
- logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {args.load_response_from_file} was generated')
-
- if args.color == AutoDetectedOption.AUTO:
- if os.getenv("NO_COLOR"):
- args.color = AutoDetectedOption.OFF
- elif sys.stdout.isatty():
- args.color = AutoDetectedOption.ON
- else:
- args.color = AutoDetectedOption.OFF
-
- if args.adornments == AutoDetectedOption.AUTO:
- if sys.stdout.isatty():
- args.adornments = AutoDetectedOption.ON
- else:
- args.adornments = AutoDetectedOption.OFF
-
- initial_message_specified = (
- args.message or
- args.prompt_from_fd or
- args.prompt_from_file
- )
-
- if not initial_message_specified:
- if debug and args.load_response_from_file:
- args.interactive = False
- elif sys.stdin.isatty():
- args.interactive = True
-
- if not debug:
- args.load_response_from_file = None
- args.save_response_to_file = None
-
- validate_args(args, debug=debug)
-
- return split_arguments(args)
+ return parser, debug
diff --git a/src/gpt_chat_cli/argvalidation.py b/src/gpt_chat_cli/argvalidation.py
new file mode 100644
index 0000000..16987a9
--- /dev/null
+++ b/src/gpt_chat_cli/argvalidation.py
@@ -0,0 +1,198 @@
+import logging
+
+from enum import Enum
+from dataclasses import dataclass
+from typing import Tuple, Optional
+
+import sys
+import os
+
+from .argparsing import (
+ RawArguments,
+ AutoDetectedOption,
+)
+
+######################
+## PUBLIC INTERFACE ##
+######################
+
+@dataclass
+class CompletionArguments:
+ model: str
+ n_completions: int
+ temperature: float
+ presence_penalty: float
+ frequency_penalty: float
+ max_tokens: int
+ top_p: float
+
+@dataclass
+class DisplayArguments:
+ adornments: bool
+ color: bool
+
+@dataclass
+class DebugArguments:
+ save_response_to_file: Optional[str]
+ load_response_from_file: Optional[str]
+
+@dataclass
+class MessageSource:
+ message: Optional[str] = None
+ prompt_from_fd: Optional[str] = None
+ prompt_from_file: Optional[str] = None
+
+@dataclass
+class Arguments:
+ completion_args: CompletionArguments
+ display_args: DisplayArguments
+ version: bool
+ list_models: bool
+ interactive: bool
+ initial_message: MessageSource
+ openai_key: str
+ system_message: Optional[str] = None
+ debug_args: Optional[DebugArguments] = None
+
+def post_process_raw_args(raw_args : RawArguments) -> Arguments:
+ _populate_defaults(raw_args)
+ _issue_warnings(raw_args)
+ _validate_args(raw_args)
+ return _restructure_arguments(raw_args)
+
+#####################
+## PRIVATE LOGIC ##
+#####################
+
+def _restructure_arguments(raw_args : RawArguments) -> Arguments:
+
+ args = raw_args.args
+ openai_key = raw_args.openai_key
+
+ completion_args = CompletionArguments(
+ model=args.model,
+ n_completions=args.n_completions,
+ temperature=args.temperature,
+ presence_penalty=args.presence_penalty,
+ frequency_penalty=args.frequency_penalty,
+ max_tokens=args.max_tokens,
+ top_p=args.top_p,
+ )
+
+ msg_src = MessageSource(
+ message = args.message,
+ prompt_from_fd = args.prompt_from_fd,
+ prompt_from_file = args.prompt_from_file,
+ )
+
+ display_args = DisplayArguments(
+ adornments=(args.adornments == AutoDetectedOption.ON),
+ color=(args.color == AutoDetectedOption.ON),
+ )
+
+ debug_args = DebugArguments(
+ save_response_to_file=args.save_response_to_file,
+ load_response_from_file=args.load_response_from_file,
+ )
+
+ return Arguments(
+ initial_message=msg_src,
+ openai_key=openai_key,
+ completion_args=completion_args,
+ display_args=display_args,
+ debug_args=debug_args,
+ version=args.version,
+ list_models=args.list_models,
+ interactive=args.interactive,
+ system_message=args.system_message
+ )
+
+def _die_validation_err(err : str):
+ print(err, file=sys.stderr)
+ sys.exit(1)
+
+def _validate_args(raw_args : RawArguments) -> None:
+
+ args = raw_args.args
+ debug = raw_args.debug
+
+ if not 0 <= args.temperature <= 2:
+ _die_validation_err("Temperature must be between 0 and 2.")
+
+ if not -2 <= args.frequency_penalty <= 2:
+ _die_validation_err("Frequency penalty must be between -2.0 and 2.0.")
+
+ if not -2 <= args.presence_penalty <= 2:
+ _die_validation_err("Presence penalty must be between -2.0 and 2.0.")
+
+ if args.max_tokens < 1:
+ _die_validation_err("Max tokens must be greater than or equal to 1.")
+
+ if not 0 <= args.top_p <= 1:
+ _die_validation_err("Top_p must be between 0 and 1.")
+
+ if args.n_completions < 1:
+ _die_validation_err("Number of completions must be greater than or equal to 1.")
+
+ if args.interactive and args.n_completions != 1:
+ _die_validation_err("Only a single completion can be used in interactive mode")
+
+ if (args.prompt_from_fd or args.prompt_from_file) and args.message:
+ _die_validation_err("Cannot specify an initial message alongside --prompt_from_fd or --prompt_from_file")
+
+ if debug and args.interactive:
+
+ if args.interactive and (
+ args.save_response_to_file or args.load_response_from_file
+ ):
+ _die_validation_err("Save and load operations cannot be used in interactive mode")
+
+def _issue_warnings(raw_args : RawArguments):
+
+ if not raw_args.openai_key:
+ print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr)
+ print("The OpenAI API uses API keys for authentication. Visit your (API Keys page)[https://platform.openai.com/account/api-keys] to retrieve the API key you'll use in your requests.", file=sys.stderr)
+
+ sys.exit(1)
+
+ if raw_args.debug:
+ logging.warning("Debugging mode and unstable features have been enabled.")
+
+ if raw_args.debug and raw_args.args.load_response_from_file:
+ logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {raw_args.args.load_response_from_file} was generated')
+
+def _populate_defaults(raw_args : RawArguments):
+
+ args = raw_args.args
+ debug = raw_args.debug
+
+ if args.color == AutoDetectedOption.AUTO:
+ if os.getenv("NO_COLOR"):
+ args.color = AutoDetectedOption.OFF
+ elif sys.stdout.isatty():
+ args.color = AutoDetectedOption.ON
+ else:
+ args.color = AutoDetectedOption.OFF
+
+ if args.adornments == AutoDetectedOption.AUTO:
+ if sys.stdout.isatty():
+ args.adornments = AutoDetectedOption.ON
+ else:
+ args.adornments = AutoDetectedOption.OFF
+
+ initial_message_specified = (
+ args.message or
+ args.prompt_from_fd or
+ args.prompt_from_file
+ )
+
+ if not initial_message_specified:
+ if debug and args.load_response_from_file:
+ args.interactive = False
+ elif sys.stdin.isatty():
+ args.interactive = True
+
+ if not debug:
+ args.load_response_from_file = None
+ args.save_response_to_file = None
+
diff --git a/src/gpt_chat_cli/gcli.py b/src/gpt_chat_cli/cmd.py
index 916542b..d7aed6c 100644
--- a/src/gpt_chat_cli/gcli.py
+++ b/src/gpt_chat_cli/cmd.py
@@ -20,8 +20,7 @@ from .openai_wrappers import (
ChatMessage
)
-from .argparsing import (
- parse_args,
+from .argvalidation import (
Arguments,
DisplayArguments,
CompletionArguments,
@@ -212,13 +211,6 @@ def print_streamed_response(
#### COMMANDS ####
#########################
-def cmd_version():
- print(f'version {VERSION}')
-
-def cmd_list_models():
- for model in list_models():
- print(model)
-
def surround_ansi_escapes(prompt, start = "\x01", end = "\x02"):
'''
Fixes issue on Linux with the readline module
@@ -239,7 +231,14 @@ def surround_ansi_escapes(prompt, start = "\x01", end = "\x02"):
return result
-def cmd_interactive(args : Arguments):
+def version():
+ print(f'version {VERSION}')
+
+def list_models():
+ for model in list_models():
+ print(model)
+
+def interactive(args : Arguments):
enable_emacs_editing()
@@ -294,7 +293,7 @@ def cmd_interactive(args : Arguments):
if not prompt_message():
return
-def cmd_singleton(args: Arguments):
+def singleton(args: Arguments):
completion_args = args.completion_args
debug_args : DebugArguments = args.debug_args
@@ -325,19 +324,3 @@ def cmd_singleton(args: Arguments):
completion,
completion_args.n_completions
)
-
-
-def main():
- args = parse_args()
-
- if args.version:
- cmd_version()
- elif args.list_models:
- cmd_list_models()
- elif args.interactive:
- cmd_interactive(args)
- else:
- cmd_singleton(args)
-
-if __name__ == "__main__":
- main()
diff --git a/src/gpt_chat_cli/main.py b/src/gpt_chat_cli/main.py
new file mode 100644
index 0000000..77d2708
--- /dev/null
+++ b/src/gpt_chat_cli/main.py
@@ -0,0 +1,38 @@
+def main():
+ # defer other imports until autocomplete has finished
+ from .argparsing import parse_raw_args_or_complete
+
+ raw_args = parse_raw_args_or_complete()
+
+ # post process and validate
+ from .argvalidation import (
+ Arguments,
+ post_process_raw_args
+ )
+
+ args = post_process_raw_args( raw_args )
+
+ # populate key
+ import openai
+
+ openai.api_key = args.openai_key
+
+ # execute relevant command
+ from .cmd import (
+ version,
+ list_models,
+ interactive,
+ singleton,
+ )
+
+ if args.version:
+ version()
+ elif args.list_models:
+ list_models()
+ elif args.interactive:
+ interactive(args)
+ else:
+ singleton(args)
+
+if __name__ == "__main__":
+ main()
diff --git a/src/gpt_chat_cli/openai_wrappers.py b/src/gpt_chat_cli/openai_wrappers.py
index 3e1ec06..d478531 100644
--- a/src/gpt_chat_cli/openai_wrappers.py
+++ b/src/gpt_chat_cli/openai_wrappers.py
@@ -5,6 +5,8 @@ from typing import Any, List, Optional, Generator
from dataclasses import dataclass
from enum import Enum, auto
+from .argvalidation import CompletionArguments
+
@dataclass
class Delta:
content: Optional[str] = None
@@ -21,7 +23,6 @@ class FinishReason(Enum):
if finish_reason_str is None:
return FinishReason.NONE
return FinishReason[finish_reason_str.upper()]
-
@dataclass
class Choice:
delta: Delta
@@ -79,8 +80,6 @@ class OpenAIChatResponse:
OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None]
-from .argparsing import CompletionArguments
-
def create_chat_completion(hist : ChatHistory, args: CompletionArguments) \
-> OpenAIChatResponseStream:
diff --git a/src/gpt_chat_cli/version.py b/src/gpt_chat_cli/version.py
index 856ce1d..b5fd259 100644
--- a/src/gpt_chat_cli/version.py
+++ b/src/gpt_chat_cli/version.py
@@ -1 +1 @@
-VERSION = '0.1.2'
+VERSION = '0.2.2-alpha.1'