From 4f684f7073149c48ea9dd962b64a8bf7a169a1cf Mon Sep 17 00:00:00 2001
From: flu0r1ne <flu0r1ne@flu0r1ne.net>
Date: Tue, 9 May 2023 15:57:05 -0500
Subject: Rewrote the argparsing functionality to enable autocompletion via
 the "kislyuk/argcomplete" package.

This essentially means:

- `argparsing.py` does a minimal amount of work to initialize the arg
  parser, then attempts dynamic completion.
- `argvalidation.py` processes the raw arguments parsed in `argparsing.py`,
  validates them, issues warnings (if required), and splits them into
  logical groupings.
- Commands in `gcli.py` have been moved to `cmd.py`.
- `main.py` provides an initial control path that calls these functions in
  succession.
---
 pyproject.toml                      |   2 +-
 setup.py                            |   7 +-
 src/gpt_chat_cli/argparsing.py      | 200 ++++-----------------
 src/gpt_chat_cli/argvalidation.py   | 198 +++++++++++++++++++++
 src/gpt_chat_cli/cmd.py             | 326 ++++++++++++++++++++++++++++++++++
 src/gpt_chat_cli/gcli.py            | 343 ------------------------------------
 src/gpt_chat_cli/main.py            |  38 ++++
 src/gpt_chat_cli/openai_wrappers.py |   5 +-
 src/gpt_chat_cli/version.py         |   2 +-
 9 files changed, 605 insertions(+), 516 deletions(-)
 create mode 100644 src/gpt_chat_cli/argvalidation.py
 create mode 100644 src/gpt_chat_cli/cmd.py
 delete mode 100644 src/gpt_chat_cli/gcli.py
 create mode 100644 src/gpt_chat_cli/main.py

diff --git a/pyproject.toml b/pyproject.toml
index a5ee66f..ecf72f0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "gpt-chat-cli"
-version = "0.1.2"
+version = "0.2.2-alpha.1"
 authors = [
   { name="Flu0r1ne", email="flu0r1ne@flu0r1ne.net" },
 ]
diff --git a/setup.py b/setup.py
index 269dc73..07740e6 100644
--- a/setup.py
+++ b/setup.py
@@ -2,9 +2,9 @@ import setuptools
 
 setuptools.setup(
     name='gpt-chat-cli',
-    version='0.1.2',
+    version='0.2.2-alpha.1',
     entry_points = {
-        'console_scripts': ['gpt-chat-cli=gpt_chat_cli.gcli:main'],
+        'console_scripts': ['gpt-chat-cli=gpt_chat_cli.main:main'],
     },
     author='Flu0r1ne',
     description='A simple ChatGPT CLI',
@@ -13,7 +13,8 @@ setuptools.setup(
     install_requires=[
         'setuptools',
         'openai >= 0.27.6',
-        'pygments >= 0.15.0'
+        'pygments >= 0.15.0',
+        'argcomplete >= 3.0.8',
     ],
     python_requires='>=3.7'
 )
diff --git a/src/gpt_chat_cli/argparsing.py b/src/gpt_chat_cli/argparsing.py
index b026af8..a96e81c 100644
--- a/src/gpt_chat_cli/argparsing.py
+++ b/src/gpt_chat_cli/argparsing.py
@@ -1,48 +1,10 @@
 import argparse
+import argcomplete
 import os
-import logging
-import openai
-import sys
-from enum import Enum
+
 from dataclasses import dataclass
 from typing import Tuple, Optional
-
-def die_validation_err(err : str):
-    print(err, file=sys.stderr)
-    sys.exit(1)
-
-def validate_args(args: argparse.Namespace, debug : bool = False) -> None:
-
-    if not 0 <= args.temperature <= 2:
-        die_validation_err("Temperature must be between 0 and 2.")
-
-    if not -2 <= args.frequency_penalty <= 2:
-        die_validation_err("Frequency penalty must be between -2.0 and 2.0.")
-
-    if not -2 <= args.presence_penalty <= 2:
-        die_validation_err("Presence penalty must be between -2.0 and 2.0.")
-
-    if args.max_tokens < 1:
-        die_validation_err("Max tokens must be greater than or equal to 1.")
-
-    if not 0 <= args.top_p <= 1:
-        die_validation_err("Top_p must be between 0 and 1.")
-
-    if args.n_completions < 1:
-        die_validation_err("Number of completions must be greater than or equal to 1.")
-
-    if args.interactive and args.n_completions != 1:
-        die_validation_err("Only a single completion can be used in interactive mode")
-
-    if
(args.prompt_from_fd or args.prompt_from_file) and args.message: - die_validation_err("Cannot specify an initial message alongside --prompt_from_fd or --prompt_from_file") - - if debug and args.interactive: - - if args.interactive and ( - args.save_response_to_file or args.load_response_from_file - ): - die_validation_err("Save and load operations cannot be used in interactive mode") +from enum import Enum class AutoDetectedOption(Enum): ON = 'on' @@ -52,96 +14,49 @@ class AutoDetectedOption(Enum): def __str__(self : "AutoDetectedOption"): return self.value -@dataclass -class CompletionArguments: - model: str - n_completions: int - temperature: float - presence_penalty: float - frequency_penalty: float - max_tokens: int - top_p: float +###################### +## PUBLIC INTERFACE ## +###################### @dataclass -class DisplayArguments: - adornments: bool - color: bool +class RawArguments: + args : argparse.Namespace + debug : bool + openai_key : Optional[str] = None -@dataclass -class DebugArguments: - save_response_to_file: Optional[str] - load_response_from_file: Optional[str] +def parse_raw_args_or_complete() -> RawArguments: -@dataclass -class MessageSource: - message: Optional[str] = None - prompt_from_fd: Optional[str] = None - prompt_from_file: Optional[str] = None + parser, debug = _construct_parser() -@dataclass -class Arguments: - completion_args: CompletionArguments - display_args: DisplayArguments - version: bool - list_models: bool - interactive: bool - initial_message: MessageSource - system_message: Optional[str] = None - debug_args: Optional[DebugArguments] = None - -def split_arguments(args: argparse.Namespace) -> Arguments: - completion_args = CompletionArguments( - model=args.model, - n_completions=args.n_completions, - temperature=args.temperature, - presence_penalty=args.presence_penalty, - frequency_penalty=args.frequency_penalty, - max_tokens=args.max_tokens, - top_p=args.top_p, - ) - - msg_src = MessageSource( - message = args.message, - prompt_from_fd = args.prompt_from_fd, - prompt_from_file = args.prompt_from_file, - ) + argcomplete.autocomplete( parser ) - display_args = DisplayArguments( - adornments=(args.adornments == AutoDetectedOption.ON), - color=(args.color == AutoDetectedOption.ON), - ) + args = parser.parse_args() - debug_args = DebugArguments( - save_response_to_file=args.save_response_to_file, - load_response_from_file=args.load_response_from_file, - ) + openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY")) - return Arguments( - initial_message=msg_src, - completion_args=completion_args, - display_args=display_args, - debug_args=debug_args, - version=args.version, - list_models=args.list_models, - interactive=args.interactive, - system_message=args.system_message + return RawArguments( + args = args, + debug = debug, + openai_key = openai_key ) -def parse_args() -> Arguments: +##################### +## PRIVATE LOGIC ## +##################### - GPT_CLI_ENV_PREFIX = "GPT_CLI_" +_GPT_CLI_ENV_PREFIX = "GPT_CLI_" - debug = os.getenv(f'{GPT_CLI_ENV_PREFIX}DEBUG') is not None +def _construct_parser() \ + -> Tuple[argparse.ArgumentParser, bool]: - if debug: - logging.warning("Debugging mode and unstable features have been enabled.") + debug = os.getenv(f'{_GPT_CLI_ENV_PREFIX}DEBUG') is not None parser = argparse.ArgumentParser() parser.add_argument( "-m", "--model", - default=os.getenv(f'{GPT_CLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"), + default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"), help="ID of the model to use", ) @@ -149,7 
+64,7 @@ def parse_args() -> Arguments:
         "-t", "--temperature",
         type=float,
-        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}TEMPERATURE', 0.5),
+        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}TEMPERATURE', 0.5),
         help=(
             "What sampling temperature to use, between 0 and 2. Higher values "
             "like 0.8 will make the output more random, while lower values "
@@ -161,7 +76,7 @@ def parse_args() -> Arguments:
         "-f", "--frequency-penalty",
         type=float,
-        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}FREQUENCY_PENALTY', 0),
+        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}FREQUENCY_PENALTY', 0),
         help=(
             "Number between -2.0 and 2.0. Positive values penalize new tokens based "
             "on their existing frequency in the text so far, decreasing the model's "
@@ -173,7 +88,7 @@ def parse_args() -> Arguments:
         "-p", "--presence-penalty",
         type=float,
-        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}PRESENCE_PENALTY', 0),
+        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}PRESENCE_PENALTY', 0),
         help=(
             "Number between -2.0 and 2.0. Positive values penalize new tokens based "
             "on whether they appear in the text so far, increasing the model's "
@@ -185,7 +100,7 @@ def parse_args() -> Arguments:
         "-k", "--max-tokens",
         type=int,
-        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}MAX_TOKENS', 2048),
+        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}MAX_TOKENS', 2048),
         help=(
             "The maximum number of tokens to generate in the chat completion. "
             "Defaults to 2048."
@@ -196,7 +111,7 @@ def parse_args() -> Arguments:
         "-s", "--top-p",
         type=float,
-        default=os.getenv(f'{GPT_CLI_ENV_PREFIX}TOP_P', 1),
+        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}TOP_P', 1),
         help=(
             "An alternative to sampling with temperature, called nucleus sampling, "
             "where the model considers the results of the tokens with top_p "
@@ -209,14 +124,14 @@ def parse_args() -> Arguments:
         "-n", "--n-completions",
         type=int,
-        default=os.getenv('f{GPT_CLI_ENV_PREFIX}N_COMPLETIONS', 1),
+        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}N_COMPLETIONS', 1),
         help="How many chat completion choices to generate for each input message.",
     )
 
     parser.add_argument(
         "--system-message",
         type=str,
-        default=os.getenv('f{GPT_CLI_ENV_PREFIX}SYSTEM_MESSAGE'),
+        default=os.getenv(f'{_GPT_CLI_ENV_PREFIX}SYSTEM_MESSAGE'),
         help="Specify an alternative system message.",
     )
 
@@ -298,49 +213,4 @@ def parse_args() -> Arguments:
         help="UNSTABLE: load a response from a file. This can replay a response for debugging purposes",
     )
 
-    openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY"))
-    if not openai_key:
-        print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr)
-        print("The OpenAI API uses API keys for authentication.
Visit your (API Keys page)[https://platform.openai.com/account/api-keys] to retrieve the API key you'll use in your requests.", file=sys.stderr) - sys.exit(1) - - openai.api_key = openai_key - - args = parser.parse_args() - - if debug and args.load_response_from_file: - logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {args.load_response_from_file} was generated') - - if args.color == AutoDetectedOption.AUTO: - if os.getenv("NO_COLOR"): - args.color = AutoDetectedOption.OFF - elif sys.stdout.isatty(): - args.color = AutoDetectedOption.ON - else: - args.color = AutoDetectedOption.OFF - - if args.adornments == AutoDetectedOption.AUTO: - if sys.stdout.isatty(): - args.adornments = AutoDetectedOption.ON - else: - args.adornments = AutoDetectedOption.OFF - - initial_message_specified = ( - args.message or - args.prompt_from_fd or - args.prompt_from_file - ) - - if not initial_message_specified: - if debug and args.load_response_from_file: - args.interactive = False - elif sys.stdin.isatty(): - args.interactive = True - - if not debug: - args.load_response_from_file = None - args.save_response_to_file = None - - validate_args(args, debug=debug) - - return split_arguments(args) + return parser, debug diff --git a/src/gpt_chat_cli/argvalidation.py b/src/gpt_chat_cli/argvalidation.py new file mode 100644 index 0000000..16987a9 --- /dev/null +++ b/src/gpt_chat_cli/argvalidation.py @@ -0,0 +1,198 @@ +import logging + +from enum import Enum +from dataclasses import dataclass +from typing import Tuple, Optional + +import sys +import os + +from .argparsing import ( + RawArguments, + AutoDetectedOption, +) + +###################### +## PUBLIC INTERFACE ## +###################### + +@dataclass +class CompletionArguments: + model: str + n_completions: int + temperature: float + presence_penalty: float + frequency_penalty: float + max_tokens: int + top_p: float + +@dataclass +class DisplayArguments: + adornments: bool + color: bool + +@dataclass +class DebugArguments: + save_response_to_file: Optional[str] + load_response_from_file: Optional[str] + +@dataclass +class MessageSource: + message: Optional[str] = None + prompt_from_fd: Optional[str] = None + prompt_from_file: Optional[str] = None + +@dataclass +class Arguments: + completion_args: CompletionArguments + display_args: DisplayArguments + version: bool + list_models: bool + interactive: bool + initial_message: MessageSource + openai_key: str + system_message: Optional[str] = None + debug_args: Optional[DebugArguments] = None + +def post_process_raw_args(raw_args : RawArguments) -> Arguments: + _populate_defaults(raw_args) + _issue_warnings(raw_args) + _validate_args(raw_args) + return _restructure_arguments(raw_args) + +##################### +## PRIVATE LOGIC ## +##################### + +def _restructure_arguments(raw_args : RawArguments) -> Arguments: + + args = raw_args.args + openai_key = raw_args.openai_key + + completion_args = CompletionArguments( + model=args.model, + n_completions=args.n_completions, + temperature=args.temperature, + presence_penalty=args.presence_penalty, + frequency_penalty=args.frequency_penalty, + max_tokens=args.max_tokens, + top_p=args.top_p, + ) + + msg_src = MessageSource( + message = args.message, + prompt_from_fd = args.prompt_from_fd, + prompt_from_file = args.prompt_from_file, + ) + + display_args = DisplayArguments( + adornments=(args.adornments == AutoDetectedOption.ON), + color=(args.color == AutoDetectedOption.ON), + ) + + debug_args = DebugArguments( + 
save_response_to_file=args.save_response_to_file, + load_response_from_file=args.load_response_from_file, + ) + + return Arguments( + initial_message=msg_src, + openai_key=openai_key, + completion_args=completion_args, + display_args=display_args, + debug_args=debug_args, + version=args.version, + list_models=args.list_models, + interactive=args.interactive, + system_message=args.system_message + ) + +def _die_validation_err(err : str): + print(err, file=sys.stderr) + sys.exit(1) + +def _validate_args(raw_args : RawArguments) -> None: + + args = raw_args.args + debug = raw_args.debug + + if not 0 <= args.temperature <= 2: + _die_validation_err("Temperature must be between 0 and 2.") + + if not -2 <= args.frequency_penalty <= 2: + _die_validation_err("Frequency penalty must be between -2.0 and 2.0.") + + if not -2 <= args.presence_penalty <= 2: + _die_validation_err("Presence penalty must be between -2.0 and 2.0.") + + if args.max_tokens < 1: + _die_validation_err("Max tokens must be greater than or equal to 1.") + + if not 0 <= args.top_p <= 1: + _die_validation_err("Top_p must be between 0 and 1.") + + if args.n_completions < 1: + _die_validation_err("Number of completions must be greater than or equal to 1.") + + if args.interactive and args.n_completions != 1: + _die_validation_err("Only a single completion can be used in interactive mode") + + if (args.prompt_from_fd or args.prompt_from_file) and args.message: + _die_validation_err("Cannot specify an initial message alongside --prompt_from_fd or --prompt_from_file") + + if debug and args.interactive: + + if args.interactive and ( + args.save_response_to_file or args.load_response_from_file + ): + _die_validation_err("Save and load operations cannot be used in interactive mode") + +def _issue_warnings(raw_args : RawArguments): + + if not raw_args.openai_key: + print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr) + print("The OpenAI API uses API keys for authentication. 
Visit your (API Keys page)[https://platform.openai.com/account/api-keys] to retrieve the API key you'll use in your requests.", file=sys.stderr) + + sys.exit(1) + + if raw_args.debug: + logging.warning("Debugging mode and unstable features have been enabled.") + + if raw_args.debug and raw_args.args.load_response_from_file: + logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {raw_args.args.load_response_from_file} was generated') + +def _populate_defaults(raw_args : RawArguments): + + args = raw_args.args + debug = raw_args.debug + + if args.color == AutoDetectedOption.AUTO: + if os.getenv("NO_COLOR"): + args.color = AutoDetectedOption.OFF + elif sys.stdout.isatty(): + args.color = AutoDetectedOption.ON + else: + args.color = AutoDetectedOption.OFF + + if args.adornments == AutoDetectedOption.AUTO: + if sys.stdout.isatty(): + args.adornments = AutoDetectedOption.ON + else: + args.adornments = AutoDetectedOption.OFF + + initial_message_specified = ( + args.message or + args.prompt_from_fd or + args.prompt_from_file + ) + + if not initial_message_specified: + if debug and args.load_response_from_file: + args.interactive = False + elif sys.stdin.isatty(): + args.interactive = True + + if not debug: + args.load_response_from_file = None + args.save_response_to_file = None + diff --git a/src/gpt_chat_cli/cmd.py b/src/gpt_chat_cli/cmd.py new file mode 100644 index 0000000..d7aed6c --- /dev/null +++ b/src/gpt_chat_cli/cmd.py @@ -0,0 +1,326 @@ +#!/bin/env python3 + +import sys +import openai +import pickle +import os +import datetime + +from collections import defaultdict +from dataclasses import dataclass +from typing import Tuple, Optional + +from .openai_wrappers import ( + create_chat_completion, + list_models, + OpenAIChatResponse, + OpenAIChatResponseStream, + FinishReason, + Role, + ChatMessage +) + +from .argvalidation import ( + Arguments, + DisplayArguments, + CompletionArguments, + DebugArguments, + MessageSource +) + +from .version import VERSION +from .color import get_color_codes +from .chat_colorizer import ChatColorizer + +########################### +#### UTILS #### +########################### + +def resolve_initial_message(src: MessageSource, interactive=False) -> str: + msg = None + + if src.message: + msg = src.message + elif src.prompt_from_fd: + with os.fdopen(src.prompt_from_fd, "r") as f: + msg = f.read() + elif src.prompt_from_file: + with open(src.prompt_from_file, "r") as f: + msg = f.read() + elif not interactive: + msg = sys.stdin.read() + + return msg + +def get_system_message(system_message : Optional[str]): + + if not system_message: + + current_date_time = datetime.datetime.now() + + system_message = f'The current date is {current_date_time}. When emitting code or producing markdown, ensure to label fenced code blocks with the language in use.' 
+
+    return ChatMessage(Role.SYSTEM, system_message)
+
+def enable_emacs_editing():
+    try:
+        import readline
+    except ImportError:
+        pass
+
+###########################
+####   SAVE / REPLAY   ####
+###########################
+
+@dataclass
+class CompletionContext:
+    message: str
+    completion_args: CompletionArguments
+    system_message: Optional[str] = None
+
+def create_singleton_chat_completion(ctx : CompletionContext):
+
+    hist = [
+        get_system_message(ctx.system_message),
+        ChatMessage(Role.USER, ctx.message)
+    ]
+
+    completion = create_chat_completion(hist, ctx.completion_args)
+
+    return completion
+
+def save_response_and_arguments(args : Arguments) -> None:
+
+    message = resolve_initial_message(args.initial_message)
+
+    ctx = CompletionContext(
+        message=message,
+        completion_args=args.completion_args,
+        system_message=args.system_message
+    )
+
+    completion = create_singleton_chat_completion(ctx)
+
+    completion = list(completion)
+
+    filename = args.debug_args.save_response_to_file
+
+    with open(filename, 'wb') as f:
+        pickle.dump((ctx, completion,), f)
+
+def load_response_and_arguments(args : Arguments) \
+    -> Tuple[CompletionContext, OpenAIChatResponseStream]:
+
+    filename = args.debug_args.load_response_from_file
+
+    with open(filename, 'rb') as f:
+        ctx, completion = pickle.load(f)
+
+    return (ctx, completion)
+
+#########################
+#### PRETTY PRINTING ####
+#########################
+
+@dataclass
+class CumulativeResponse:
+    delta_content: str = ""
+    finish_reason: FinishReason = FinishReason.NONE
+    content: str = ""
+
+    def take_delta(self : "CumulativeResponse"):
+        chunk = self.delta_content
+        self.delta_content = ""
+        return chunk
+
+    def add_content(self : "CumulativeResponse", new_chunk : str):
+        self.content += new_chunk
+        self.delta_content += new_chunk
+
+def print_streamed_response(
+        display_args : DisplayArguments,
+        completion : OpenAIChatResponseStream,
+        n_completions : int,
+        return_responses : bool = False
+    ) -> None:
+    """
+    Print the response in real time by printing the deltas as they occur. If multiple responses
+    are requested, print the first in real time, accumulating the others in the background. Once the
+    first response completes, move on to the second response, printing the deltas in real time. Continue
+    on until all responses have been printed.
+    """
+
+    no_color = not display_args.color
+
+    COLOR_CODE = get_color_codes(no_color = no_color)
+    adornments = display_args.adornments
+
+    cumu_responses = defaultdict(CumulativeResponse)
+    display_idx = 0
+    prompt_printed = False
+
+    chat_colorizer = ChatColorizer(no_color = no_color)
+
+    for update in completion:
+
+        for choice in update.choices:
+            delta = choice.delta
+
+            if delta.content:
+                cumu_responses[choice.index].add_content(delta.content)
+
+            if choice.finish_reason is not FinishReason.NONE:
+                cumu_responses[choice.index].finish_reason = choice.finish_reason
+
+        display_response = cumu_responses[display_idx]
+
+        if not prompt_printed and adornments:
+            res_indicator = '' if n_completions == 1 else \
+                f' {display_idx + 1}/{n_completions}'
+            PROMPT = f'[{COLOR_CODE.GREEN}{update.model}{COLOR_CODE.RESET}{COLOR_CODE.RED}{res_indicator}{COLOR_CODE.RESET}]'
+            prompt_printed = True
+            print(PROMPT, end=' ', flush=True)
+
+        content = display_response.take_delta()
+        chat_colorizer.add_chunk( content )
+
+        chat_colorizer.print()
+
+        if display_response.finish_reason is not FinishReason.NONE:
+            chat_colorizer.finish()
+            chat_colorizer.print()
+            chat_colorizer = ChatColorizer( no_color=no_color )
+
+            if display_idx < n_completions:
+                display_idx += 1
+                prompt_printed = False
+
+            if adornments:
+                print(end='\n\n', flush=True)
+            else:
+                print(end='\n', flush=True)
+
+    if return_responses:
+        return [ cumu_responses[i].content for i in range(n_completions) ]
+
+#########################
+####    COMMANDS     ####
+#########################
+
+def surround_ansi_escapes(prompt, start = "\x01", end = "\x02"):
+    '''
+    Fixes an issue on Linux with the readline module.
+    See: https://github.com/python/cpython/issues/61539
+    '''
+    escaped = False
+    result = ""
+
+    for c in prompt:
+        if c == "\x1b" and not escaped:
+            result += start + c
+            escaped = True
+        elif c.isalpha() and escaped:
+            result += c + end
+            escaped = False
+        else:
+            result += c
+
+    return result
+
+def version():
+    print(f'version {VERSION}')
+
+def list_models():
+    # import under an alias so the command does not shadow itself
+    # and recurse indefinitely
+    from .openai_wrappers import list_models as openai_list_models
+
+    for model in openai_list_models():
+        print(model)
+
+def interactive(args : Arguments):
+
+    enable_emacs_editing()
+
+    COLOR_CODE = get_color_codes(no_color = not args.display_args.color)
+
+    completion_args = args.completion_args
+    display_args = args.display_args
+
+    hist = [ get_system_message( args.system_message ) ]
+
+    PROMPT = f'[{COLOR_CODE.WHITE}#{COLOR_CODE.RESET}] '
+    PROMPT = surround_ansi_escapes(PROMPT)
+
+    def prompt_message() -> bool:
+
+        # Control-D closes the input stream
+        try:
+            message = input( PROMPT )
+        except (EOFError, KeyboardInterrupt):
+            print()
+            return False
+
+        hist.append( ChatMessage( Role.USER, message ) )
+
+        return True
+
+    print(f'GPT Chat CLI version {VERSION}')
+    print(f'Press Control-D to exit')
+
+    initial_message = resolve_initial_message(args.initial_message, interactive=True)
+
+    if initial_message:
+        print( PROMPT, initial_message, sep='', flush=True )
+        hist.append( ChatMessage( Role.USER, initial_message ) )
+    else:
+        if not prompt_message():
+            return
+
+    while True:
+
+        completion = create_chat_completion(hist, completion_args)
+
+        try:
+            response = print_streamed_response(
+                display_args, completion, 1, return_responses=True,
+            )[0]
+
+            hist.append( ChatMessage(Role.ASSISTANT, response) )
+        except KeyboardInterrupt:
+            print()
+
+        if not prompt_message():
+            return
+
+def singleton(args: Arguments):
+    completion_args = args.completion_args
+
+    debug_args : DebugArguments = args.debug_args
+    message = args.initial_message
+
+    if
debug_args.save_response_to_file: + save_response_and_arguments(args) + return + elif debug_args.load_response_from_file: + ctx, completion = load_response_and_arguments(args) + + message = ctx.message + completion_args = ctx.completion_args + else: + # message is only None is a TTY is not attached + message = resolve_initial_message(args.initial_message) + + ctx = CompletionContext( + message=message, + completion_args=completion_args, + system_message=args.system_message + ) + + completion = create_singleton_chat_completion(ctx) + + print_streamed_response( + args.display_args, + completion, + completion_args.n_completions + ) diff --git a/src/gpt_chat_cli/gcli.py b/src/gpt_chat_cli/gcli.py deleted file mode 100644 index 916542b..0000000 --- a/src/gpt_chat_cli/gcli.py +++ /dev/null @@ -1,343 +0,0 @@ -#!/bin/env python3 - -import sys -import openai -import pickle -import os -import datetime - -from collections import defaultdict -from dataclasses import dataclass -from typing import Tuple, Optional - -from .openai_wrappers import ( - create_chat_completion, - list_models, - OpenAIChatResponse, - OpenAIChatResponseStream, - FinishReason, - Role, - ChatMessage -) - -from .argparsing import ( - parse_args, - Arguments, - DisplayArguments, - CompletionArguments, - DebugArguments, - MessageSource -) - -from .version import VERSION -from .color import get_color_codes -from .chat_colorizer import ChatColorizer - -########################### -#### UTILS #### -########################### - -def resolve_initial_message(src: MessageSource, interactive=False) -> str: - msg = None - - if src.message: - msg = src.message - elif src.prompt_from_fd: - with os.fdopen(src.prompt_from_fd, "r") as f: - msg = f.read() - elif src.prompt_from_file: - with open(src.prompt_from_file, "r") as f: - msg = f.read() - elif not interactive: - msg = sys.stdin.read() - - return msg - -def get_system_message(system_message : Optional[str]): - - if not system_message: - - current_date_time = datetime.datetime.now() - - system_message = f'The current date is {current_date_time}. When emitting code or producing markdown, ensure to label fenced code blocks with the language in use.' 
- - return ChatMessage(Role.SYSTEM, system_message) - -def enable_emacs_editing(): - try: - import readline - except ImportError: - pass - -########################### -#### SAVE / REPLAY #### -########################### - -@dataclass -class CompletionContext: - message: str - completion_args: CompletionArguments - system_message: Optional[str] = None - -def create_singleton_chat_completion(ctx : CompletionContext): - - hist = [ - get_system_message(ctx.system_message), - ChatMessage(Role.USER, ctx.message) - ] - - completion = create_chat_completion(hist, ctx.completion_args) - - return completion - -def save_response_and_arguments(args : Arguments) -> None: - - message = resolve_initial_message(args.initial_message) - - ctx = CompletionContext( - message=message, - completion_args=args.completion_args, - system_message=args.system_message - ) - - completion = create_singleton_chat_completion( - message, - args.completion_args, - args.system_message, - ) - - completion = list(completion) - - filename = args.debug_args.save_response_to_file - - with open(filename, 'wb') as f: - pickle.dump((ctx, completion,), f) - -def load_response_and_arguments(args : Arguments) \ - -> Tuple[CompletionContext, OpenAIChatResponseStream]: - - filename = args.debug_args.load_response_from_file - - with open(filename, 'rb') as f: - ctx, completion = pickle.load(f) - - return (ctx, completion) - -######################### -#### PRETTY PRINTING #### -######################### - -@dataclass -class CumulativeResponse: - delta_content: str = "" - finish_reason: FinishReason = FinishReason.NONE - content: str = "" - - def take_delta(self : "CumulativeResponse"): - chunk = self.delta_content - self.delta_content = "" - return chunk - - def add_content(self : "CumulativeResponse", new_chunk : str): - self.content += new_chunk - self.delta_content += new_chunk - -def print_streamed_response( - display_args : DisplayArguments, - completion : OpenAIChatResponseStream, - n_completions : int, - return_responses : bool = False - ) -> None: - """ - Print the response in real time by printing the deltas as they occur. If multiple responses - are requested, print the first in real-time, accumulating the others in the background. One the - first response completes, move on to the second response printing the deltas in real time. Continue - on until all responses have been printed. 
- """ - - no_color = not display_args.color - - COLOR_CODE = get_color_codes(no_color = no_color) - adornments = display_args.adornments - - cumu_responses = defaultdict(CumulativeResponse) - display_idx = 0 - prompt_printed = False - - chat_colorizer = ChatColorizer(no_color = no_color) - - for update in completion: - - for choice in update.choices: - delta = choice.delta - - if delta.content: - cumu_responses[choice.index].add_content(delta.content) - - if choice.finish_reason is not FinishReason.NONE: - cumu_responses[choice.index].finish_reason = choice.finish_reason - - display_response = cumu_responses[display_idx] - - if not prompt_printed and adornments: - res_indicator = '' if n_completions == 1 else \ - f' {display_idx + 1}/{n_completions}' - PROMPT = f'[{COLOR_CODE.GREEN}{update.model}{COLOR_CODE.RESET}{COLOR_CODE.RED}{res_indicator}{COLOR_CODE.RESET}]' - prompt_printed = True - print(PROMPT, end=' ', flush=True) - - content = display_response.take_delta() - chat_colorizer.add_chunk( content ) - - chat_colorizer.print() - - if display_response.finish_reason is not FinishReason.NONE: - chat_colorizer.finish() - chat_colorizer.print() - chat_colorizer = ChatColorizer( no_color=no_color ) - - if display_idx < n_completions: - display_idx += 1 - prompt_printed = False - - if adornments: - print(end='\n\n', flush=True) - else: - print(end='\n', flush=True) - - if return_responses: - return [ cumu_responses[i].content for i in range(n_completions) ] - -######################### -#### COMMANDS #### -######################### - -def cmd_version(): - print(f'version {VERSION}') - -def cmd_list_models(): - for model in list_models(): - print(model) - -def surround_ansi_escapes(prompt, start = "\x01", end = "\x02"): - ''' - Fixes issue on Linux with the readline module - See: https://github.com/python/cpython/issues/61539 - ''' - escaped = False - result = "" - - for c in prompt: - if c == "\x1b" and not escaped: - result += start + c - escaped = True - elif c.isalpha() and escaped: - result += c + end - escaped = False - else: - result += c - - return result - -def cmd_interactive(args : Arguments): - - enable_emacs_editing() - - COLOR_CODE = get_color_codes(no_color = not args.display_args.color) - - completion_args = args.completion_args - display_args = args.display_args - - hist = [ get_system_message( args.system_message ) ] - - PROMPT = f'[{COLOR_CODE.WHITE}#{COLOR_CODE.RESET}] ' - PROMPT = surround_ansi_escapes(PROMPT) - - def prompt_message() -> bool: - - # Control-D closes the input stream - try: - message = input( PROMPT ) - except (EOFError, KeyboardInterrupt): - print() - return False - - hist.append( ChatMessage( Role.USER, message ) ) - - return True - - print(f'GPT Chat CLI version {VERSION}') - print(f'Press Control-D to exit') - - initial_message = resolve_initial_message(args.initial_message, interactive=True) - - if initial_message: - print( PROMPT, initial_message, sep='', flush=True ) - hist.append( ChatMessage( Role.USER, initial_message ) ) - else: - if not prompt_message(): - return - - while True: - - completion = create_chat_completion(hist, completion_args) - - try: - response = print_streamed_response( - display_args, completion, 1, return_responses=True, - )[0] - - hist.append( ChatMessage(Role.ASSISTANT, response) ) - except KeyboardInterrupt: - print() - - if not prompt_message(): - return - -def cmd_singleton(args: Arguments): - completion_args = args.completion_args - - debug_args : DebugArguments = args.debug_args - message = args.initial_message - - if 
debug_args.save_response_to_file: - save_response_and_arguments(args) - return - elif debug_args.load_response_from_file: - ctx, completion = load_response_and_arguments(args) - - message = ctx.message - completion_args = ctx.completion_args - else: - # message is only None is a TTY is not attached - message = resolve_initial_message(args.initial_message) - - ctx = CompletionContext( - message=message, - completion_args=completion_args, - system_message=args.system_message - ) - - completion = create_singleton_chat_completion(ctx) - - print_streamed_response( - args.display_args, - completion, - completion_args.n_completions - ) - - -def main(): - args = parse_args() - - if args.version: - cmd_version() - elif args.list_models: - cmd_list_models() - elif args.interactive: - cmd_interactive(args) - else: - cmd_singleton(args) - -if __name__ == "__main__": - main() diff --git a/src/gpt_chat_cli/main.py b/src/gpt_chat_cli/main.py new file mode 100644 index 0000000..77d2708 --- /dev/null +++ b/src/gpt_chat_cli/main.py @@ -0,0 +1,38 @@ +def main(): + # defer other imports until autocomplete has finished + from .argparsing import parse_raw_args_or_complete + + raw_args = parse_raw_args_or_complete() + + # post process and validate + from .argvalidation import ( + Arguments, + post_process_raw_args + ) + + args = post_process_raw_args( raw_args ) + + # populate key + import openai + + openai.api_key = args.openai_key + + # execute relevant command + from .cmd import ( + version, + list_models, + interactive, + singleton, + ) + + if args.version: + version() + elif args.list_models: + list_models() + elif args.interactive: + interactive(args) + else: + singleton(args) + +if __name__ == "__main__": + main() diff --git a/src/gpt_chat_cli/openai_wrappers.py b/src/gpt_chat_cli/openai_wrappers.py index 3e1ec06..d478531 100644 --- a/src/gpt_chat_cli/openai_wrappers.py +++ b/src/gpt_chat_cli/openai_wrappers.py @@ -5,6 +5,8 @@ from typing import Any, List, Optional, Generator from dataclasses import dataclass from enum import Enum, auto +from .argvalidation import CompletionArguments + @dataclass class Delta: content: Optional[str] = None @@ -21,7 +23,6 @@ class FinishReason(Enum): if finish_reason_str is None: return FinishReason.NONE return FinishReason[finish_reason_str.upper()] - @dataclass class Choice: delta: Delta @@ -79,8 +80,6 @@ class OpenAIChatResponse: OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None] -from .argparsing import CompletionArguments - def create_chat_completion(hist : ChatHistory, args: CompletionArguments) \ -> OpenAIChatResponseStream: diff --git a/src/gpt_chat_cli/version.py b/src/gpt_chat_cli/version.py index 856ce1d..b5fd259 100644 --- a/src/gpt_chat_cli/version.py +++ b/src/gpt_chat_cli/version.py @@ -1 +1 @@ -VERSION = '0.1.2' +VERSION = '0.2.2-alpha.1' -- cgit v1.2.3
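Note on using the new completions: argcomplete only produces suggestions once it is hooked into the shell, so the patch alone is not enough. The snippet below is illustrative rather than part of this change; it assumes bash and the `gpt-chat-cli` entry point installed by setup.py:

    # per shell session, or add to ~/.bashrc
    eval "$(register-python-argcomplete gpt-chat-cli)"

Global activation via `activate-global-python-argcomplete` should also work, but generally requires the launcher script to contain the `# PYTHON_ARGCOMPLETE_OK` marker so the global hook recognizes it; per-command registration as above avoids that requirement.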