diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/gpt_chat_cli/argparsing.py | 200 | ||||
-rw-r--r-- | src/gpt_chat_cli/argvalidation.py | 198 | ||||
-rw-r--r-- | src/gpt_chat_cli/cmd.py (renamed from src/gpt_chat_cli/gcli.py) | 37 | ||||
-rw-r--r-- | src/gpt_chat_cli/main.py | 38 | ||||
-rw-r--r-- | src/gpt_chat_cli/openai_wrappers.py | 5 | ||||
-rw-r--r-- | src/gpt_chat_cli/version.py | 2 |
6 files changed, 284 insertions, 196 deletions
diff --git a/src/gpt_chat_cli/argparsing.py b/src/gpt_chat_cli/argparsing.py index b026af8..a96e81c 100644 --- a/src/gpt_chat_cli/argparsing.py +++ b/src/gpt_chat_cli/argparsing.py @@ -1,48 +1,10 @@ import argparse +import argcomplete import os -import logging -import openai -import sys -from enum import Enum + from dataclasses import dataclass from typing import Tuple, Optional - -def die_validation_err(err : str): - print(err, file=sys.stderr) - sys.exit(1) - -def validate_args(args: argparse.Namespace, debug : bool = False) -> None: - - if not 0 <= args.temperature <= 2: - die_validation_err("Temperature must be between 0 and 2.") - - if not -2 <= args.frequency_penalty <= 2: - die_validation_err("Frequency penalty must be between -2.0 and 2.0.") - - if not -2 <= args.presence_penalty <= 2: - die_validation_err("Presence penalty must be between -2.0 and 2.0.") - - if args.max_tokens < 1: - die_validation_err("Max tokens must be greater than or equal to 1.") - - if not 0 <= args.top_p <= 1: - die_validation_err("Top_p must be between 0 and 1.") - - if args.n_completions < 1: - die_validation_err("Number of completions must be greater than or equal to 1.") - - if args.interactive and args.n_completions != 1: - die_validation_err("Only a single completion can be used in interactive mode") - - if (args.prompt_from_fd or args.prompt_from_file) and args.message: - die_validation_err("Cannot specify an initial message alongside --prompt_from_fd or --prompt_from_file") - - if debug and args.interactive: - - if args.interactive and ( - args.save_response_to_file or args.load_response_from_file - ): - die_validation_err("Save and load operations cannot be used in interactive mode") +from enum import Enum class AutoDetectedOption(Enum): ON = 'on' @@ -52,96 +14,49 @@ class AutoDetectedOption(Enum): def __str__(self : "AutoDetectedOption"): return self.value -@dataclass -class CompletionArguments: - model: str - n_completions: int - temperature: float - presence_penalty: float - frequency_penalty: float - max_tokens: int - top_p: float +###################### +## PUBLIC INTERFACE ## +###################### @dataclass -class DisplayArguments: - adornments: bool - color: bool +class RawArguments: + args : argparse.Namespace + debug : bool + openai_key : Optional[str] = None -@dataclass -class DebugArguments: - save_response_to_file: Optional[str] - load_response_from_file: Optional[str] +def parse_raw_args_or_complete() -> RawArguments: -@dataclass -class MessageSource: - message: Optional[str] = None - prompt_from_fd: Optional[str] = None - prompt_from_file: Optional[str] = None + parser, debug = _construct_parser() -@dataclass -class Arguments: - completion_args: CompletionArguments - display_args: DisplayArguments - version: bool - list_models: bool - interactive: bool - initial_message: MessageSource - system_message: Optional[str] = None - debug_args: Optional[DebugArguments] = None - -def split_arguments(args: argparse.Namespace) -> Arguments: - completion_args = CompletionArguments( - model=args.model, - n_completions=args.n_completions, - temperature=args.temperature, - presence_penalty=args.presence_penalty, - frequency_penalty=args.frequency_penalty, - max_tokens=args.max_tokens, - top_p=args.top_p, - ) - - msg_src = MessageSource( - message = args.message, - prompt_from_fd = args.prompt_from_fd, - prompt_from_file = args.prompt_from_file, - ) + argcomplete.autocomplete( parser ) - display_args = DisplayArguments( - adornments=(args.adornments == AutoDetectedOption.ON), - color=(args.color == AutoDetectedOption.ON), - ) + args = parser.parse_args() - debug_args = DebugArguments( - save_response_to_file=args.save_response_to_file, - load_response_from_file=args.load_response_from_file, - ) + openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY")) - return Arguments( - initial_message=msg_src, - completion_args=completion_args, - display_args=display_args, - debug_args=debug_args, - version=args.version, - list_models=args.list_models, - interactive=args.interactive, - system_message=args.system_message + return RawArguments( + args = args, + debug = debug, + openai_key = openai_key ) -def parse_args() -> Arguments: +##################### +## PRIVATE LOGIC ## +##################### - GPT_CLI_ENV_PREFIX = "GPT_CLI_" +_GPT_CLI_ENV_PREFIX = "GPT_CLI_" - debug = os.getenv(f'{GPT_CLI_ENV_PREFIX}DEBUG') is not None +def _construct_parser() \ + -> Tuple[argparse.ArgumentParser, bool]: - if debug: - logging.warning("Debugging mode and unstable features have been enabled.") + debug = os.getenv(f'{_GPT_CLI_ENV_PREFIX}DEBUG') is not None parser = argparse.ArgumentParser() parser.add_argument( "-m", "--model", - default=os.getenv(f'{GPT_CLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"), + default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"), help="ID of the model to use", ) @@ -149,7 +64,7 @@ def parse_args() -> Arguments: "-t", "--temperature", type=float, - default=os.getenv(f'{GPT_CLI_ENV_PREFIX}TEMPERATURE', 0.5), + default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}TEMPERATURE', 0.5), help=( "What sampling temperature to use, between 0 and 2. Higher values " "like 0.8 will make the output more random, while lower values " @@ -161,7 +76,7 @@ def parse_args() -> Arguments: "-f", "--frequency-penalty", type=float, - default=os.getenv(f'{GPT_CLI_ENV_PREFIX}FREQUENCY_PENALTY', 0), + default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}FREQUENCY_PENALTY', 0), help=( "Number between -2.0 and 2.0. Positive values penalize new tokens based " "on their existing frequency in the text so far, decreasing the model's " @@ -173,7 +88,7 @@ def parse_args() -> Arguments: "-p", "--presence-penalty", type=float, - default=os.getenv(f'{GPT_CLI_ENV_PREFIX}PRESENCE_PENALTY', 0), + default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}PRESENCE_PENALTY', 0), help=( "Number between -2.0 and 2.0. Positive values penalize new tokens based " "on whether they appear in the text so far, increasing the model's " @@ -185,7 +100,7 @@ def parse_args() -> Arguments: "-k", "--max-tokens", type=int, - default=os.getenv(f'{GPT_CLI_ENV_PREFIX}MAX_TOKENS', 2048), + default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}MAX_TOKENS', 2048), help=( "The maximum number of tokens to generate in the chat completion. " "Defaults to 2048." @@ -196,7 +111,7 @@ def parse_args() -> Arguments: "-s", "--top-p", type=float, - default=os.getenv(f'{GPT_CLI_ENV_PREFIX}TOP_P', 1), + default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}TOP_P', 1), help=( "An alternative to sampling with temperature, called nucleus sampling, " "where the model considers the results of the tokens with top_p " @@ -209,14 +124,14 @@ def parse_args() -> Arguments: "-n", "--n-completions", type=int, - default=os.getenv('f{GPT_CLI_ENV_PREFIX}N_COMPLETIONS', 1), + default=os.getenv('f{_GPT_CLI_ENV_PREFIX}N_COMPLETIONS', 1), help="How many chat completion choices to generate for each input message.", ) parser.add_argument( "--system-message", type=str, - default=os.getenv('f{GPT_CLI_ENV_PREFIX}SYSTEM_MESSAGE'), + default=os.getenv('f{_GPT_CLI_ENV_PREFIX}SYSTEM_MESSAGE'), help="Specify an alternative system message.", ) @@ -298,49 +213,4 @@ def parse_args() -> Arguments: help="UNSTABLE: load a response from a file. This can reply a response for debugging purposes", ) - openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY")) - if not openai_key: - print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr) - print("The OpenAI API uses API keys for authentication. Visit your (API Keys page)[https://platform.openai.com/account/api-keys] to retrieve the API key you'll use in your requests.", file=sys.stderr) - sys.exit(1) - - openai.api_key = openai_key - - args = parser.parse_args() - - if debug and args.load_response_from_file: - logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {args.load_response_from_file} was generated') - - if args.color == AutoDetectedOption.AUTO: - if os.getenv("NO_COLOR"): - args.color = AutoDetectedOption.OFF - elif sys.stdout.isatty(): - args.color = AutoDetectedOption.ON - else: - args.color = AutoDetectedOption.OFF - - if args.adornments == AutoDetectedOption.AUTO: - if sys.stdout.isatty(): - args.adornments = AutoDetectedOption.ON - else: - args.adornments = AutoDetectedOption.OFF - - initial_message_specified = ( - args.message or - args.prompt_from_fd or - args.prompt_from_file - ) - - if not initial_message_specified: - if debug and args.load_response_from_file: - args.interactive = False - elif sys.stdin.isatty(): - args.interactive = True - - if not debug: - args.load_response_from_file = None - args.save_response_to_file = None - - validate_args(args, debug=debug) - - return split_arguments(args) + return parser, debug diff --git a/src/gpt_chat_cli/argvalidation.py b/src/gpt_chat_cli/argvalidation.py new file mode 100644 index 0000000..16987a9 --- /dev/null +++ b/src/gpt_chat_cli/argvalidation.py @@ -0,0 +1,198 @@ +import logging + +from enum import Enum +from dataclasses import dataclass +from typing import Tuple, Optional + +import sys +import os + +from .argparsing import ( + RawArguments, + AutoDetectedOption, +) + +###################### +## PUBLIC INTERFACE ## +###################### + +@dataclass +class CompletionArguments: + model: str + n_completions: int + temperature: float + presence_penalty: float + frequency_penalty: float + max_tokens: int + top_p: float + +@dataclass +class DisplayArguments: + adornments: bool + color: bool + +@dataclass +class DebugArguments: + save_response_to_file: Optional[str] + load_response_from_file: Optional[str] + +@dataclass +class MessageSource: + message: Optional[str] = None + prompt_from_fd: Optional[str] = None + prompt_from_file: Optional[str] = None + +@dataclass +class Arguments: + completion_args: CompletionArguments + display_args: DisplayArguments + version: bool + list_models: bool + interactive: bool + initial_message: MessageSource + openai_key: str + system_message: Optional[str] = None + debug_args: Optional[DebugArguments] = None + +def post_process_raw_args(raw_args : RawArguments) -> Arguments: + _populate_defaults(raw_args) + _issue_warnings(raw_args) + _validate_args(raw_args) + return _restructure_arguments(raw_args) + +##################### +## PRIVATE LOGIC ## +##################### + +def _restructure_arguments(raw_args : RawArguments) -> Arguments: + + args = raw_args.args + openai_key = raw_args.openai_key + + completion_args = CompletionArguments( + model=args.model, + n_completions=args.n_completions, + temperature=args.temperature, + presence_penalty=args.presence_penalty, + frequency_penalty=args.frequency_penalty, + max_tokens=args.max_tokens, + top_p=args.top_p, + ) + + msg_src = MessageSource( + message = args.message, + prompt_from_fd = args.prompt_from_fd, + prompt_from_file = args.prompt_from_file, + ) + + display_args = DisplayArguments( + adornments=(args.adornments == AutoDetectedOption.ON), + color=(args.color == AutoDetectedOption.ON), + ) + + debug_args = DebugArguments( + save_response_to_file=args.save_response_to_file, + load_response_from_file=args.load_response_from_file, + ) + + return Arguments( + initial_message=msg_src, + openai_key=openai_key, + completion_args=completion_args, + display_args=display_args, + debug_args=debug_args, + version=args.version, + list_models=args.list_models, + interactive=args.interactive, + system_message=args.system_message + ) + +def _die_validation_err(err : str): + print(err, file=sys.stderr) + sys.exit(1) + +def _validate_args(raw_args : RawArguments) -> None: + + args = raw_args.args + debug = raw_args.debug + + if not 0 <= args.temperature <= 2: + _die_validation_err("Temperature must be between 0 and 2.") + + if not -2 <= args.frequency_penalty <= 2: + _die_validation_err("Frequency penalty must be between -2.0 and 2.0.") + + if not -2 <= args.presence_penalty <= 2: + _die_validation_err("Presence penalty must be between -2.0 and 2.0.") + + if args.max_tokens < 1: + _die_validation_err("Max tokens must be greater than or equal to 1.") + + if not 0 <= args.top_p <= 1: + _die_validation_err("Top_p must be between 0 and 1.") + + if args.n_completions < 1: + _die_validation_err("Number of completions must be greater than or equal to 1.") + + if args.interactive and args.n_completions != 1: + _die_validation_err("Only a single completion can be used in interactive mode") + + if (args.prompt_from_fd or args.prompt_from_file) and args.message: + _die_validation_err("Cannot specify an initial message alongside --prompt_from_fd or --prompt_from_file") + + if debug and args.interactive: + + if args.interactive and ( + args.save_response_to_file or args.load_response_from_file + ): + _die_validation_err("Save and load operations cannot be used in interactive mode") + +def _issue_warnings(raw_args : RawArguments): + + if not raw_args.openai_key: + print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr) + print("The OpenAI API uses API keys for authentication. Visit your (API Keys page)[https://platform.openai.com/account/api-keys] to retrieve the API key you'll use in your requests.", file=sys.stderr) + + sys.exit(1) + + if raw_args.debug: + logging.warning("Debugging mode and unstable features have been enabled.") + + if raw_args.debug and raw_args.args.load_response_from_file: + logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {raw_args.args.load_response_from_file} was generated') + +def _populate_defaults(raw_args : RawArguments): + + args = raw_args.args + debug = raw_args.debug + + if args.color == AutoDetectedOption.AUTO: + if os.getenv("NO_COLOR"): + args.color = AutoDetectedOption.OFF + elif sys.stdout.isatty(): + args.color = AutoDetectedOption.ON + else: + args.color = AutoDetectedOption.OFF + + if args.adornments == AutoDetectedOption.AUTO: + if sys.stdout.isatty(): + args.adornments = AutoDetectedOption.ON + else: + args.adornments = AutoDetectedOption.OFF + + initial_message_specified = ( + args.message or + args.prompt_from_fd or + args.prompt_from_file + ) + + if not initial_message_specified: + if debug and args.load_response_from_file: + args.interactive = False + elif sys.stdin.isatty(): + args.interactive = True + + if not debug: + args.load_response_from_file = None + args.save_response_to_file = None + diff --git a/src/gpt_chat_cli/gcli.py b/src/gpt_chat_cli/cmd.py index 916542b..d7aed6c 100644 --- a/src/gpt_chat_cli/gcli.py +++ b/src/gpt_chat_cli/cmd.py @@ -20,8 +20,7 @@ from .openai_wrappers import ( ChatMessage ) -from .argparsing import ( - parse_args, +from .argvalidation import ( Arguments, DisplayArguments, CompletionArguments, @@ -212,13 +211,6 @@ def print_streamed_response( #### COMMANDS #### ######################### -def cmd_version(): - print(f'version {VERSION}') - -def cmd_list_models(): - for model in list_models(): - print(model) - def surround_ansi_escapes(prompt, start = "\x01", end = "\x02"): ''' Fixes issue on Linux with the readline module @@ -239,7 +231,14 @@ def surround_ansi_escapes(prompt, start = "\x01", end = "\x02"): return result -def cmd_interactive(args : Arguments): +def version(): + print(f'version {VERSION}') + +def list_models(): + for model in list_models(): + print(model) + +def interactive(args : Arguments): enable_emacs_editing() @@ -294,7 +293,7 @@ def cmd_interactive(args : Arguments): if not prompt_message(): return -def cmd_singleton(args: Arguments): +def singleton(args: Arguments): completion_args = args.completion_args debug_args : DebugArguments = args.debug_args @@ -325,19 +324,3 @@ def cmd_singleton(args: Arguments): completion, completion_args.n_completions ) - - -def main(): - args = parse_args() - - if args.version: - cmd_version() - elif args.list_models: - cmd_list_models() - elif args.interactive: - cmd_interactive(args) - else: - cmd_singleton(args) - -if __name__ == "__main__": - main() diff --git a/src/gpt_chat_cli/main.py b/src/gpt_chat_cli/main.py new file mode 100644 index 0000000..77d2708 --- /dev/null +++ b/src/gpt_chat_cli/main.py @@ -0,0 +1,38 @@ +def main(): + # defer other imports until autocomplete has finished + from .argparsing import parse_raw_args_or_complete + + raw_args = parse_raw_args_or_complete() + + # post process and validate + from .argvalidation import ( + Arguments, + post_process_raw_args + ) + + args = post_process_raw_args( raw_args ) + + # populate key + import openai + + openai.api_key = args.openai_key + + # execute relevant command + from .cmd import ( + version, + list_models, + interactive, + singleton, + ) + + if args.version: + version() + elif args.list_models: + list_models() + elif args.interactive: + interactive(args) + else: + singleton(args) + +if __name__ == "__main__": + main() diff --git a/src/gpt_chat_cli/openai_wrappers.py b/src/gpt_chat_cli/openai_wrappers.py index 3e1ec06..d478531 100644 --- a/src/gpt_chat_cli/openai_wrappers.py +++ b/src/gpt_chat_cli/openai_wrappers.py @@ -5,6 +5,8 @@ from typing import Any, List, Optional, Generator from dataclasses import dataclass from enum import Enum, auto +from .argvalidation import CompletionArguments + @dataclass class Delta: content: Optional[str] = None @@ -21,7 +23,6 @@ class FinishReason(Enum): if finish_reason_str is None: return FinishReason.NONE return FinishReason[finish_reason_str.upper()] - @dataclass class Choice: delta: Delta @@ -79,8 +80,6 @@ class OpenAIChatResponse: OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None] -from .argparsing import CompletionArguments - def create_chat_completion(hist : ChatHistory, args: CompletionArguments) \ -> OpenAIChatResponseStream: diff --git a/src/gpt_chat_cli/version.py b/src/gpt_chat_cli/version.py index 856ce1d..b5fd259 100644 --- a/src/gpt_chat_cli/version.py +++ b/src/gpt_chat_cli/version.py @@ -1 +1 @@ -VERSION = '0.1.2' +VERSION = '0.2.2-alpha.1' |