From f4118779b31e21b8cfb20ddf64e6fd355d4d02fc Mon Sep 17 00:00:00 2001 From: flu0r1ne Date: Sat, 6 May 2023 14:56:08 -0500 Subject: Add alternative file sources, specify system message --- src/gpt_chat_cli/gcli.py | 134 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 92 insertions(+), 42 deletions(-) (limited to 'src/gpt_chat_cli/gcli.py') diff --git a/src/gpt_chat_cli/gcli.py b/src/gpt_chat_cli/gcli.py index 1c5555c..5f2a478 100644 --- a/src/gpt_chat_cli/gcli.py +++ b/src/gpt_chat_cli/gcli.py @@ -3,10 +3,12 @@ import sys import openai import pickle +import os +import datetime from collections import defaultdict from dataclasses import dataclass -from typing import Tuple +from typing import Tuple, Optional from .openai_wrappers import ( create_chat_completion, @@ -24,50 +26,102 @@ from .argparsing import ( DisplayArguments, CompletionArguments, DebugArguments, + MessageSource ) from .version import VERSION from .color import get_color_codes +from .chat_colorizer import ChatColorizer -import datetime +########################### +#### UTILS #### +########################### + +def resolve_initial_message(src: MessageSource, interactive=False) -> str: + msg = None + + if src.message: + msg = src.message + elif src.prompt_from_fd: + with os.fdopen(src.prompt_from_fd, "r") as f: + msg = f.read() + elif src.prompt_from_file: + with open(src.prompt_from_file, "r") as f: + msg = f.read() + elif not interactive: + msg = sys.stdin.read() + + return msg + +def get_system_message(system_message : Optional[str]): + + if not system_message: + + current_date_time = datetime.datetime.now() + system_message = f'The current date is {current_date_time}. When emitting code or producing markdown, ensure to label fenced code blocks with the language in use.' + + return ChatMessage(Role.SYSTEM, system_message) + +def enable_emacs_editing(): + try: + import readline + except ImportError: + pass ########################### #### SAVE / REPLAY #### ########################### -def create_singleton_chat_completion( - message : str, - completion_args : CompletionArguments - ): +@dataclass +class CompletionContext: + message: str + completion_args: CompletionArguments + system_message: Optional[str] = None - hist = [ get_system_message(), ChatMessage( Role.USER, message ) ] +def create_singleton_chat_completion(ctx : CompletionContext): - completion = create_chat_completion(hist, completion_args) + hist = [ + get_system_message(ctx.system_message), + ChatMessage(Role.USER, ctx.message) + ] + + completion = create_chat_completion(hist, ctx.completion_args) return completion def save_response_and_arguments(args : Arguments) -> None: - message = args.initial_message + message = resolve_initial_message(args.initial_message) + + ctx = CompletionContext( + message=message, + completion_args=args.completion_args, + system_message=args.system_message + ) + + completion = create_singleton_chat_completion( + message, + args.completion_args, + args.system_message, + ) - completion = create_singleton_chat_completion(message, args.completion_args) completion = list(completion) filename = args.debug_args.save_response_to_file with open(filename, 'wb') as f: - pickle.dump((message, args.completion_args, completion,), f) + pickle.dump((ctx, completion,), f) def load_response_and_arguments(args : Arguments) \ - -> Tuple[CompletionArguments, OpenAIChatResponseStream]: + -> Tuple[CompletionContext, OpenAIChatResponseStream]: filename = args.debug_args.load_response_from_file with open(filename, 'rb') as f: - message, args, completion = pickle.load(f) + ctx, completion = pickle.load(f) - return (message, args, completion) + return (ctx, completion) ######################### #### PRETTY PRINTING #### @@ -88,8 +142,6 @@ class CumulativeResponse: self.content += new_chunk self.delta_content += new_chunk -from .chat_colorizer import ChatColorizer - def print_streamed_response( display_args : DisplayArguments, completion : OpenAIChatResponseStream, @@ -156,12 +208,9 @@ def print_streamed_response( if return_responses: return [ cumu_responses[i].content for i in range(n_completions) ] -def get_system_message(): - current_date_time = datetime.datetime.now() - - msg = f'The current date is {current_date_time}. When emitting code or producing markdown, ensure to label fenced code blocks with the language in use.' - - return ChatMessage( Role.SYSTEM, msg) +######################### +#### COMMANDS #### +######################### def cmd_version(): print(f'version {VERSION}') @@ -170,15 +219,6 @@ def cmd_list_models(): for model in list_models(): print(model) -def enable_emacs_editing(): - try: - import readline - # self.old_completer = readline.get_completer() - # readline.set_completer(self.complete) - # readline.parse_and_bind(self.completekey+": complete") - except ImportError: - pass - def cmd_interactive(args : Arguments): enable_emacs_editing() @@ -188,7 +228,7 @@ def cmd_interactive(args : Arguments): completion_args = args.completion_args display_args = args.display_args - hist = [ get_system_message() ] + hist = [ get_system_message( args.system_message ) ] PROMPT = f'[{COLOR_CODE.WHITE}#{COLOR_CODE.RESET}] ' @@ -208,9 +248,11 @@ def cmd_interactive(args : Arguments): print(f'GPT Chat CLI version {VERSION}') print(f'Press Control-D to exit') - if args.initial_message: - print( PROMPT, args.initial_message, sep='' ) - hist.append( ChatMessage( Role.USER, args.initial_message ) ) + initial_message = resolve_initial_message(args.initial_message, interactive=True) + + if initial_message: + print( PROMPT, initial_message, sep='', flush=True ) + hist.append( ChatMessage( Role.USER, initial_message ) ) else: if not prompt_message(): return @@ -225,11 +267,11 @@ def cmd_interactive(args : Arguments): )[0] hist.append( ChatMessage(Role.ASSISTANT, response) ) - except: - pass + except KeyboardInterrupt: + print() if not prompt_message(): - break + return def cmd_singleton(args: Arguments): completion_args = args.completion_args @@ -241,13 +283,21 @@ def cmd_singleton(args: Arguments): save_response_and_arguments(args) return elif debug_args.load_response_from_file: - message, completion_args, completion = load_response_and_arguments(args) + ctx, completion = load_response_and_arguments(args) + + message = ctx.message + completion_args = ctx.completion_args else: # message is only None is a TTY is not attached - if message is None: - message = sys.stdin.read() + message = resolve_initial_message(args.initial_message) + + ctx = CompletionContext( + message=message, + completion_args=completion_args, + system_message=args.system_message + ) - completion = create_singleton_chat_completion(message, completion_args) + completion = create_singleton_chat_completion(ctx) print_streamed_response( args.display_args, -- cgit v1.2.3