diff options
Diffstat (limited to 'src/gpt_chat_cli')
| -rw-r--r-- | src/gpt_chat_cli/argparsing.py | 200 | ||||
| -rw-r--r-- | src/gpt_chat_cli/argvalidation.py | 198 | ||||
| -rw-r--r-- | src/gpt_chat_cli/cmd.py (renamed from src/gpt_chat_cli/gcli.py) | 37 | ||||
| -rw-r--r-- | src/gpt_chat_cli/main.py | 38 | ||||
| -rw-r--r-- | src/gpt_chat_cli/openai_wrappers.py | 5 | ||||
| -rw-r--r-- | src/gpt_chat_cli/version.py | 2 | 
6 files changed, 284 insertions, 196 deletions
| diff --git a/src/gpt_chat_cli/argparsing.py b/src/gpt_chat_cli/argparsing.py index b026af8..a96e81c 100644 --- a/src/gpt_chat_cli/argparsing.py +++ b/src/gpt_chat_cli/argparsing.py @@ -1,48 +1,10 @@  import argparse +import argcomplete  import os -import logging -import openai -import sys -from enum import Enum +  from dataclasses import dataclass  from typing import Tuple, Optional - -def die_validation_err(err : str): -    print(err, file=sys.stderr) -    sys.exit(1) - -def validate_args(args: argparse.Namespace, debug : bool = False) -> None: - -    if not 0 <= args.temperature <= 2: -        die_validation_err("Temperature must be between 0 and 2.") - -    if not -2 <= args.frequency_penalty <= 2: -        die_validation_err("Frequency penalty must be between -2.0 and 2.0.") - -    if not -2 <= args.presence_penalty <= 2: -        die_validation_err("Presence penalty must be between -2.0 and 2.0.") - -    if args.max_tokens < 1: -        die_validation_err("Max tokens must be greater than or equal to 1.") - -    if not 0 <= args.top_p <= 1: -        die_validation_err("Top_p must be between 0 and 1.") - -    if args.n_completions < 1: -        die_validation_err("Number of completions must be greater than or equal to 1.") - -    if args.interactive and args.n_completions != 1: -        die_validation_err("Only a single completion can be used in interactive mode") - -    if (args.prompt_from_fd or args.prompt_from_file) and args.message: -        die_validation_err("Cannot specify an initial message alongside --prompt_from_fd or --prompt_from_file") - -    if debug and args.interactive: - -        if args.interactive and ( -            args.save_response_to_file or args.load_response_from_file -        ): -            die_validation_err("Save and load operations cannot be used in interactive mode") +from enum import Enum  class AutoDetectedOption(Enum):      ON = 'on' @@ -52,96 +14,49 @@ class AutoDetectedOption(Enum):      def __str__(self : "AutoDetectedOption"):          return self.value -@dataclass -class CompletionArguments: -    model: str -    n_completions: int -    temperature: float -    presence_penalty: float -    frequency_penalty: float -    max_tokens: int -    top_p: float +###################### +## PUBLIC INTERFACE ## +######################  @dataclass -class DisplayArguments: -    adornments: bool -    color: bool +class RawArguments: +    args : argparse.Namespace +    debug : bool +    openai_key : Optional[str] = None -@dataclass -class DebugArguments: -    save_response_to_file: Optional[str] -    load_response_from_file: Optional[str] +def parse_raw_args_or_complete() -> RawArguments: -@dataclass -class MessageSource: -    message: Optional[str] = None -    prompt_from_fd: Optional[str] = None -    prompt_from_file: Optional[str] = None +    parser, debug = _construct_parser() -@dataclass -class Arguments: -    completion_args: CompletionArguments -    display_args: DisplayArguments -    version: bool -    list_models: bool -    interactive: bool -    initial_message: MessageSource -    system_message: Optional[str] = None -    debug_args: Optional[DebugArguments] = None - -def split_arguments(args: argparse.Namespace) -> Arguments: -    completion_args = CompletionArguments( -        model=args.model, -        n_completions=args.n_completions, -        temperature=args.temperature, -        presence_penalty=args.presence_penalty, -        frequency_penalty=args.frequency_penalty, -        max_tokens=args.max_tokens, -        top_p=args.top_p, -    ) - -    msg_src = MessageSource( -        message = args.message, -        prompt_from_fd = args.prompt_from_fd, -        prompt_from_file = args.prompt_from_file, -    ) +    argcomplete.autocomplete( parser ) -    display_args = DisplayArguments( -        adornments=(args.adornments == AutoDetectedOption.ON), -        color=(args.color == AutoDetectedOption.ON), -    ) +    args = parser.parse_args() -    debug_args = DebugArguments( -        save_response_to_file=args.save_response_to_file, -        load_response_from_file=args.load_response_from_file, -    ) +    openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY")) -    return Arguments( -           initial_message=msg_src, -           completion_args=completion_args, -           display_args=display_args, -           debug_args=debug_args, -           version=args.version, -           list_models=args.list_models, -           interactive=args.interactive, -           system_message=args.system_message +    return RawArguments( +        args = args, +        debug = debug, +        openai_key = openai_key      ) -def parse_args() -> Arguments: +##################### +##  PRIVATE LOGIC  ## +##################### -    GPT_CLI_ENV_PREFIX = "GPT_CLI_" +_GPT_CLI_ENV_PREFIX = "GPT_CLI_" -    debug = os.getenv(f'{GPT_CLI_ENV_PREFIX}DEBUG') is not None +def _construct_parser() \ +    -> Tuple[argparse.ArgumentParser, bool]: -    if debug: -        logging.warning("Debugging mode and unstable features have been enabled.") +    debug = os.getenv(f'{_GPT_CLI_ENV_PREFIX}DEBUG') is not None      parser = argparse.ArgumentParser()      parser.add_argument(          "-m",          "--model", -        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"), +        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"),          help="ID of the model to use",      ) @@ -149,7 +64,7 @@ def parse_args() -> Arguments:          "-t",          "--temperature",          type=float, -        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}TEMPERATURE', 0.5), +        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}TEMPERATURE', 0.5),          help=(              "What sampling temperature to use, between 0 and 2. Higher values "              "like 0.8 will make the output more random, while lower values " @@ -161,7 +76,7 @@ def parse_args() -> Arguments:          "-f",          "--frequency-penalty",          type=float, -        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}FREQUENCY_PENALTY', 0), +        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}FREQUENCY_PENALTY', 0),          help=(              "Number between -2.0 and 2.0. Positive values penalize new tokens based "              "on their existing frequency in the text so far, decreasing the model's " @@ -173,7 +88,7 @@ def parse_args() -> Arguments:          "-p",          "--presence-penalty",          type=float, -        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}PRESENCE_PENALTY', 0), +        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}PRESENCE_PENALTY', 0),          help=(              "Number between -2.0 and 2.0. Positive values penalize new tokens based "              "on whether they appear in the text so far, increasing the model's " @@ -185,7 +100,7 @@ def parse_args() -> Arguments:          "-k",          "--max-tokens",          type=int, -        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}MAX_TOKENS', 2048), +        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}MAX_TOKENS', 2048),          help=(              "The maximum number of tokens to generate in the chat completion. "              "Defaults to 2048." @@ -196,7 +111,7 @@ def parse_args() -> Arguments:          "-s",          "--top-p",          type=float, -        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}TOP_P', 1), +        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}TOP_P', 1),          help=(              "An alternative to sampling with temperature, called nucleus sampling, "              "where the model considers the results of the tokens with top_p " @@ -209,14 +124,14 @@ def parse_args() -> Arguments:          "-n",          "--n-completions",          type=int, -        default=os.getenv('f{GPT_CLI_ENV_PREFIX}N_COMPLETIONS', 1), +        default=os.getenv('f{_GPT_CLI_ENV_PREFIX}N_COMPLETIONS', 1),          help="How many chat completion choices to generate for each input message.",      )      parser.add_argument(          "--system-message",          type=str, -        default=os.getenv('f{GPT_CLI_ENV_PREFIX}SYSTEM_MESSAGE'), +        default=os.getenv('f{_GPT_CLI_ENV_PREFIX}SYSTEM_MESSAGE'),          help="Specify an alternative system message.",      ) @@ -298,49 +213,4 @@ def parse_args() -> Arguments:              help="UNSTABLE: load a response from a file. This can reply a response for debugging purposes",          ) -    openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY")) -    if not openai_key: -        print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr) -        print("The OpenAI API uses API keys for authentication. Visit your (API Keys page)[https://platform.openai.com/account/api-keys] to retrieve the API key you'll use in your requests.", file=sys.stderr) -        sys.exit(1) - -    openai.api_key = openai_key - -    args = parser.parse_args() - -    if debug and args.load_response_from_file: -        logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {args.load_response_from_file} was generated') - -    if args.color == AutoDetectedOption.AUTO: -        if os.getenv("NO_COLOR"): -            args.color = AutoDetectedOption.OFF -        elif sys.stdout.isatty(): -            args.color = AutoDetectedOption.ON -        else: -            args.color = AutoDetectedOption.OFF - -    if args.adornments == AutoDetectedOption.AUTO: -        if sys.stdout.isatty(): -            args.adornments = AutoDetectedOption.ON -        else: -            args.adornments = AutoDetectedOption.OFF - -    initial_message_specified = ( -        args.message or -        args.prompt_from_fd or -        args.prompt_from_file -    ) - -    if not initial_message_specified: -        if debug and args.load_response_from_file: -            args.interactive = False -        elif sys.stdin.isatty(): -            args.interactive = True - -    if not debug: -        args.load_response_from_file = None -        args.save_response_to_file = None - -    validate_args(args, debug=debug) - -    return split_arguments(args) +    return parser, debug diff --git a/src/gpt_chat_cli/argvalidation.py b/src/gpt_chat_cli/argvalidation.py new file mode 100644 index 0000000..16987a9 --- /dev/null +++ b/src/gpt_chat_cli/argvalidation.py @@ -0,0 +1,198 @@ +import logging + +from enum import Enum +from dataclasses import dataclass +from typing import Tuple, Optional + +import sys +import os + +from .argparsing import ( +    RawArguments, +    AutoDetectedOption, +) + +###################### +## PUBLIC INTERFACE ## +###################### + +@dataclass +class CompletionArguments: +    model: str +    n_completions: int +    temperature: float +    presence_penalty: float +    frequency_penalty: float +    max_tokens: int +    top_p: float + +@dataclass +class DisplayArguments: +    adornments: bool +    color: bool + +@dataclass +class DebugArguments: +    save_response_to_file: Optional[str] +    load_response_from_file: Optional[str] + +@dataclass +class MessageSource: +    message: Optional[str] = None +    prompt_from_fd: Optional[str] = None +    prompt_from_file: Optional[str] = None + +@dataclass +class Arguments: +    completion_args: CompletionArguments +    display_args: DisplayArguments +    version: bool +    list_models: bool +    interactive: bool +    initial_message: MessageSource +    openai_key: str +    system_message: Optional[str] = None +    debug_args: Optional[DebugArguments] = None + +def post_process_raw_args(raw_args : RawArguments) -> Arguments: +    _populate_defaults(raw_args) +    _issue_warnings(raw_args) +    _validate_args(raw_args) +    return _restructure_arguments(raw_args) + +##################### +##  PRIVATE LOGIC  ## +##################### + +def _restructure_arguments(raw_args : RawArguments) -> Arguments: + +    args = raw_args.args +    openai_key = raw_args.openai_key + +    completion_args = CompletionArguments( +        model=args.model, +        n_completions=args.n_completions, +        temperature=args.temperature, +        presence_penalty=args.presence_penalty, +        frequency_penalty=args.frequency_penalty, +        max_tokens=args.max_tokens, +        top_p=args.top_p, +    ) + +    msg_src = MessageSource( +        message = args.message, +        prompt_from_fd = args.prompt_from_fd, +        prompt_from_file = args.prompt_from_file, +    ) + +    display_args = DisplayArguments( +        adornments=(args.adornments == AutoDetectedOption.ON), +        color=(args.color == AutoDetectedOption.ON), +    ) + +    debug_args = DebugArguments( +        save_response_to_file=args.save_response_to_file, +        load_response_from_file=args.load_response_from_file, +    ) + +    return Arguments( +           initial_message=msg_src, +           openai_key=openai_key, +           completion_args=completion_args, +           display_args=display_args, +           debug_args=debug_args, +           version=args.version, +           list_models=args.list_models, +           interactive=args.interactive, +           system_message=args.system_message +    ) + +def _die_validation_err(err : str): +    print(err, file=sys.stderr) +    sys.exit(1) + +def _validate_args(raw_args : RawArguments) -> None: + +    args = raw_args.args +    debug = raw_args.debug + +    if not 0 <= args.temperature <= 2: +        _die_validation_err("Temperature must be between 0 and 2.") + +    if not -2 <= args.frequency_penalty <= 2: +        _die_validation_err("Frequency penalty must be between -2.0 and 2.0.") + +    if not -2 <= args.presence_penalty <= 2: +        _die_validation_err("Presence penalty must be between -2.0 and 2.0.") + +    if args.max_tokens < 1: +        _die_validation_err("Max tokens must be greater than or equal to 1.") + +    if not 0 <= args.top_p <= 1: +        _die_validation_err("Top_p must be between 0 and 1.") + +    if args.n_completions < 1: +        _die_validation_err("Number of completions must be greater than or equal to 1.") + +    if args.interactive and args.n_completions != 1: +        _die_validation_err("Only a single completion can be used in interactive mode") + +    if (args.prompt_from_fd or args.prompt_from_file) and args.message: +        _die_validation_err("Cannot specify an initial message alongside --prompt_from_fd or --prompt_from_file") + +    if debug and args.interactive: + +        if args.interactive and ( +            args.save_response_to_file or args.load_response_from_file +        ): +            _die_validation_err("Save and load operations cannot be used in interactive mode") + +def _issue_warnings(raw_args : RawArguments): + +    if not raw_args.openai_key: +        print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr) +        print("The OpenAI API uses API keys for authentication. Visit your (API Keys page)[https://platform.openai.com/account/api-keys] to retrieve the API key you'll use in your requests.", file=sys.stderr) + +        sys.exit(1) + +    if raw_args.debug: +        logging.warning("Debugging mode and unstable features have been enabled.") + +    if raw_args.debug and raw_args.args.load_response_from_file: +        logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {raw_args.args.load_response_from_file} was generated') + +def _populate_defaults(raw_args : RawArguments): + +    args = raw_args.args +    debug = raw_args.debug + +    if args.color == AutoDetectedOption.AUTO: +        if os.getenv("NO_COLOR"): +            args.color = AutoDetectedOption.OFF +        elif sys.stdout.isatty(): +            args.color = AutoDetectedOption.ON +        else: +            args.color = AutoDetectedOption.OFF + +    if args.adornments == AutoDetectedOption.AUTO: +        if sys.stdout.isatty(): +            args.adornments = AutoDetectedOption.ON +        else: +            args.adornments = AutoDetectedOption.OFF + +    initial_message_specified = ( +        args.message or +        args.prompt_from_fd or +        args.prompt_from_file +    ) + +    if not initial_message_specified: +        if debug and args.load_response_from_file: +            args.interactive = False +        elif sys.stdin.isatty(): +            args.interactive = True + +    if not debug: +        args.load_response_from_file = None +        args.save_response_to_file = None + diff --git a/src/gpt_chat_cli/gcli.py b/src/gpt_chat_cli/cmd.py index 916542b..d7aed6c 100644 --- a/src/gpt_chat_cli/gcli.py +++ b/src/gpt_chat_cli/cmd.py @@ -20,8 +20,7 @@ from .openai_wrappers import (      ChatMessage  ) -from .argparsing import ( -    parse_args, +from .argvalidation import (      Arguments,      DisplayArguments,      CompletionArguments, @@ -212,13 +211,6 @@ def print_streamed_response(  ####    COMMANDS     ####  ######################### -def cmd_version(): -    print(f'version {VERSION}') - -def cmd_list_models(): -    for model in list_models(): -        print(model) -  def surround_ansi_escapes(prompt, start = "\x01", end = "\x02"):          '''          Fixes issue on Linux with the readline module @@ -239,7 +231,14 @@ def surround_ansi_escapes(prompt, start = "\x01", end = "\x02"):          return result -def cmd_interactive(args : Arguments): +def version(): +    print(f'version {VERSION}') + +def list_models(): +    for model in list_models(): +        print(model) + +def interactive(args : Arguments):      enable_emacs_editing() @@ -294,7 +293,7 @@ def cmd_interactive(args : Arguments):          if not prompt_message():              return -def cmd_singleton(args: Arguments): +def singleton(args: Arguments):      completion_args = args.completion_args      debug_args : DebugArguments = args.debug_args @@ -325,19 +324,3 @@ def cmd_singleton(args: Arguments):          completion,          completion_args.n_completions      ) - - -def main(): -    args = parse_args() - -    if args.version: -        cmd_version() -    elif args.list_models: -        cmd_list_models() -    elif args.interactive: -        cmd_interactive(args) -    else: -        cmd_singleton(args) - -if __name__ == "__main__": -    main() diff --git a/src/gpt_chat_cli/main.py b/src/gpt_chat_cli/main.py new file mode 100644 index 0000000..77d2708 --- /dev/null +++ b/src/gpt_chat_cli/main.py @@ -0,0 +1,38 @@ +def main(): +    # defer other imports until autocomplete has finished +    from .argparsing import parse_raw_args_or_complete + +    raw_args = parse_raw_args_or_complete() + +    # post process and validate +    from .argvalidation import ( +        Arguments, +        post_process_raw_args +    ) + +    args = post_process_raw_args( raw_args ) + +    # populate key +    import openai + +    openai.api_key = args.openai_key + +    # execute relevant command +    from .cmd import ( +        version, +        list_models, +        interactive, +        singleton, +    ) + +    if args.version: +        version() +    elif args.list_models: +        list_models() +    elif args.interactive: +        interactive(args) +    else: +        singleton(args) + +if __name__ == "__main__": +    main() diff --git a/src/gpt_chat_cli/openai_wrappers.py b/src/gpt_chat_cli/openai_wrappers.py index 3e1ec06..d478531 100644 --- a/src/gpt_chat_cli/openai_wrappers.py +++ b/src/gpt_chat_cli/openai_wrappers.py @@ -5,6 +5,8 @@ from typing import Any, List, Optional, Generator  from dataclasses import dataclass  from enum import Enum, auto +from .argvalidation import CompletionArguments +  @dataclass  class Delta:      content: Optional[str] = None @@ -21,7 +23,6 @@ class FinishReason(Enum):          if finish_reason_str is None:              return FinishReason.NONE          return FinishReason[finish_reason_str.upper()] -  @dataclass  class Choice:      delta: Delta @@ -79,8 +80,6 @@ class OpenAIChatResponse:  OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None] -from .argparsing import CompletionArguments -  def create_chat_completion(hist : ChatHistory, args: CompletionArguments) \      -> OpenAIChatResponseStream: diff --git a/src/gpt_chat_cli/version.py b/src/gpt_chat_cli/version.py index 856ce1d..b5fd259 100644 --- a/src/gpt_chat_cli/version.py +++ b/src/gpt_chat_cli/version.py @@ -1 +1 @@ -VERSION = '0.1.2' +VERSION = '0.2.2-alpha.1' | 
