From f3e5d6c6615f1ee623fcda83da671a54fafd128c Mon Sep 17 00:00:00 2001 From: flu0r1ne Date: Fri, 5 May 2023 14:48:37 -0500 Subject: Add interactive mode --- src/gpt_chat_cli/gcli.py | 131 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 101 insertions(+), 30 deletions(-) (limited to 'src/gpt_chat_cli/gcli.py') diff --git a/src/gpt_chat_cli/gcli.py b/src/gpt_chat_cli/gcli.py index 2d40cf2..9df4459 100644 --- a/src/gpt_chat_cli/gcli.py +++ b/src/gpt_chat_cli/gcli.py @@ -14,6 +14,8 @@ from .openai_wrappers import ( OpenAIChatResponse, OpenAIChatResponseStream, FinishReason, + Role, + ChatMessage ) from .argparsing import ( @@ -31,28 +33,43 @@ from .color import get_color_codes #### SAVE / REPLAY #### ########################### -def create_chat_completion_from_args(args : CompletionArguments) \ - -> OpenAIChatResponseStream: - return create_chat_completion( - model=args.model, - messages=[{ "role": "user", "content": args.message }], - n=args.n_completions, - temperature=args.temperature, - presence_penalty=args.presence_penalty, - frequency_penalty=args.frequency_penalty, - max_tokens=args.max_tokens, - top_p=args.top_p, - stream=True - ) + +# def create_chat_completion_from_args(args : CompletionArguments) \ +# -> OpenAIChatResponseStream: +# return create_chat_completion( +# model=args.model, +# messages=[{ "role": "user", "content": args.message }], +# n=args.n_completions, +# temperature=args.temperature, +# presence_penalty=args.presence_penalty, +# frequency_penalty=args.frequency_penalty, +# max_tokens=args.max_tokens, +# top_p=args.top_p, +# stream=True +# ) + +def create_singleton_chat_completion( + message : str, + completion_args : CompletionArguments + ): + + hist = [ ChatMessage( Role.USER, message ) ] + + completion = create_chat_completion(hist, completion_args) + + return completion def save_response_and_arguments(args : Arguments) -> None: - completion = create_chat_completion_from_args(args.completion_args) + + message = args.initial_message + + completion = create_singleton_chat_completion(message, args.completion_args) completion = list(completion) filename = args.debug_args.save_response_to_file with open(filename, 'wb') as f: - pickle.dump((args.completion_args, completion,), f) + pickle.dump((message, args.completion_args, completion,), f) def load_response_and_arguments(args : Arguments) \ -> Tuple[CompletionArguments, OpenAIChatResponseStream]: @@ -60,9 +77,9 @@ def load_response_and_arguments(args : Arguments) \ filename = args.debug_args.load_response_from_file with open(filename, 'rb') as f: - args, completion = pickle.load(f) + message, args, completion = pickle.load(f) - return (args, completion) + return (message, args, completion) ######################### #### PRETTY PRINTING #### @@ -70,18 +87,24 @@ def load_response_and_arguments(args : Arguments) \ @dataclass class CumulativeResponse: - content: str = "" + delta_content: str = "" finish_reason: FinishReason = FinishReason.NONE + content: str = "" - def take_content(self : "CumulativeResponse"): - chunk = self.content - self.content = "" + def take_delta(self : "CumulativeResponse"): + chunk = self.delta_content + self.delta_content = "" return chunk + def add_content(self : "CumulativeResponse", new_chunk : str): + self.content += new_chunk + self.delta_content += new_chunk + def print_streamed_response( display_args : DisplayArguments, completion : OpenAIChatResponseStream, - n_completions : int + n_completions : int, + return_responses : bool = False ) -> None: """ Print the response in real time by printing the deltas as they occur. If multiple responses @@ -103,7 +126,7 @@ def print_streamed_response( delta = choice.delta if delta.content: - cumu_responses[choice.index].content += delta.content + cumu_responses[choice.index].add_content(delta.content) if choice.finish_reason is not FinishReason.NONE: cumu_responses[choice.index].finish_reason = choice.finish_reason @@ -118,7 +141,7 @@ def print_streamed_response( print(PROMPT, end=' ', flush=True) - content = display_response.take_content() + content = display_response.take_delta() print(f'{COLOR_CODE.WHITE}{content}{COLOR_CODE.RESET}', sep='', end='', flush=True) @@ -132,6 +155,9 @@ def print_streamed_response( else: print(end='\n', flush=True) + if return_responses: + return [ cumu_responses[i].content for i in range(n_completions) ] + def cmd_version(): print(f'version {VERSION}') @@ -142,24 +168,69 @@ def cmd_list_models(): def cmd_interactive(args : Arguments): COLOR_CODE = get_color_codes(no_color = not args.display_args.color) - print(f'GPT Chat CLI {VERSION}') - print(f'[{COLOR_CODE.WHITE}#{COLOR_CODE.RESET}]', end=' ', flush=True) + completion_args = args.completion_args + display_args = args.display_args + + hist = [] + + def print_prompt(): + + print(f'[{COLOR_CODE.WHITE}#{COLOR_CODE.RESET}]', end=' ', flush=True) + + def prompt_message() -> bool: + print_prompt() + + # Control-D closes the input stream + try: + message = input() + except EOFError: + print() + return False + + hist.append( ChatMessage( Role.USER, message ) ) + + return True + + print(f'GPT Chat CLI version {VERSION}') + print(f'Press Control-D to exit') + + if args.initial_message: + print_prompt() + print( args.initial_message ) + hist.append( ChatMessage( Role.USER, args.initial_message ) ) + else: + prompt_message() + + while True: + + completion = create_chat_completion(hist, completion_args) + + response = print_streamed_response( + display_args, completion, 1, return_responses=True, + )[0] + + hist.append( ChatMessage(Role.ASSISTANT, response) ) + + if not prompt_message(): + break def cmd_singleton(args: Arguments): completion_args = args.completion_args debug_args : DebugArguments = args.debug_args + message = args.initial_message if debug_args.save_response_to_file: save_response_and_arguments(args) return elif debug_args.load_response_from_file: - completion_args, completion = load_response_and_arguments(args) + message, completion_args, completion = load_response_and_arguments(args) else: - if completion_args.message is None: - completion_args.message = sys.stdin.read() + # message is only None is a TTY is not attached + if message is None: + message = sys.stdin.read() - completion = create_chat_completion_from_args(completion_args) + completion = create_singleton_chat_completion(message, completion_args) print_streamed_response( args.display_args, -- cgit v1.2.3