From 69654841fd272a66403031cd6e9311423f9f117b Mon Sep 17 00:00:00 2001
From: flu0r1ne
Date: Fri, 5 May 2023 01:14:50 -0500
Subject: Disable color when redirection occurs. Only save query arguments with save/load API.

---
 src/gpt_chat_cli/gcli.py | 65 +++++++++++++++++++++++++++++-------------------
 1 file changed, 39 insertions(+), 26 deletions(-)

diff --git a/src/gpt_chat_cli/gcli.py b/src/gpt_chat_cli/gcli.py
index dd7dbbb..dcd7c8b 100644
--- a/src/gpt_chat_cli/gcli.py
+++ b/src/gpt_chat_cli/gcli.py
@@ -1,6 +1,5 @@
 #!/bin/env python3

-import argparse
 import sys
 import openai
 import pickle
@@ -18,7 +17,9 @@ from .openai_wrappers import (

 from .argparsing import (
     parse_args,
-    AutoDetectedOption,
+    Arguments,
+    DisplayArguments,
+    CompletionArguments,
 )

 from .color import get_color_codes
@@ -27,7 +28,7 @@ from .color import get_color_codes
 #### SAVE / REPLAY ####
 ###########################

-def create_chat_completion_from_args(args : argparse.Namespace) \
+def create_chat_completion_from_args(args : CompletionArguments) \
     -> OpenAIChatResponseStream:
     return create_chat_completion(
         model=args.model,
@@ -41,19 +42,19 @@ def create_chat_completion_from_args(args : argparse.Namespace) \
         stream=True
     )

-def save_response_and_arguments(args : argparse.Namespace) -> None:
-    completion = create_chat_completion_from_args(args)
+def save_response_and_arguments(args : Arguments) -> None:
+    completion = create_chat_completion_from_args(args.completion_args)
     completion = list(completion)

-    filename = args.save_response_to_file
+    filename = args.debug_args.save_response_to_file

     with open(filename, 'wb') as f:
-        pickle.dump((args, completion,), f)
+        pickle.dump((args.completion_args, completion,), f)

-def load_response_and_arguments(args : argparse.Namespace) \
-    -> Tuple[argparse.Namespace, OpenAIChatResponseStream]:
+def load_response_and_arguments(args : Arguments) \
+    -> Tuple[CompletionArguments, OpenAIChatResponseStream]:

-    filename = args.load_response_from_file
+    filename = args.debug_args.load_response_from_file

     with open(filename, 'rb') as f:
         args, completion = pickle.load(f)
@@ -74,7 +75,11 @@ class CumulativeResponse:
         self.content = ""
         return chunk

-def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatResponseStream):
+def print_streamed_response(
+        display_args : DisplayArguments,
+        completion : OpenAIChatResponseStream,
+        n_completions : int
+    ) -> None:
     """ Print the response in real time by printing the deltas as they
     occur. If multiple responses are requested, print the first in
     real-time, accumulating the others in the background. One the
@@ -82,9 +87,8 @@ def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatRe
     on until all responses have been printed.
     """

-    COLOR_CODE = get_color_codes(no_color = args.color == AutoDetectedOption.OFF)
-    ADORNMENTS = args.adornments == AutoDetectedOption.ON
-    N_COMPLETIONS = args.n_completions
+    COLOR_CODE = get_color_codes(no_color = not display_args.color)
+    adornments = display_args.adornments

     cumu_responses = defaultdict(CumulativeResponse)
     display_idx = 0
@@ -103,9 +107,9 @@ def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatRe

         display_response = cumu_responses[display_idx]

-        if not prompt_printed and ADORNMENTS:
-            res_indicator = '' if N_COMPLETIONS == 1 else \
-                f' {display_idx + 1}/{N_COMPLETIONS}'
+        if not prompt_printed and adornments:
+            res_indicator = '' if n_completions == 1 else \
+                f' {display_idx + 1}/{n_completions}'
             PROMPT = f'[{COLOR_CODE.GREEN}{update.model}{COLOR_CODE.RESET}{COLOR_CODE.RED}{res_indicator}{COLOR_CODE.RESET}]'
             prompt_printed = True
             print(PROMPT, end=' ', flush=True)
@@ -116,11 +120,11 @@ def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatRe
                   sep='', end='', flush=True)

         if display_response.finish_reason is not FinishReason.NONE:
-            if display_idx < N_COMPLETIONS:
+            if display_idx < n_completions:
                 display_idx += 1
                 prompt_printed = False

-            if ADORNMENTS:
+            if adornments:
                 print(end='\n\n', flush=True)
             else:
                 print(end='\n', flush=True)
@@ -128,15 +132,24 @@
 def main():
     args = parse_args()

-    if args.save_response_to_file:
-        save_response_and_arguments(args)
-        return
-    elif args.load_response_from_file:
-        args, completion = load_response_and_arguments(args)
+    completion_args = args.completion_args
+
+    if args.debug_args:
+        debug_args : DebugArguments = args.debug_args
+
+        if debug_args.save_response_to_file:
+            save_response_and_arguments(args)
+            return
+        elif debug_args.load_response_from_file:
+            completion_args, completion = load_response_and_arguments(args)
     else:
-        completion = create_chat_completion_from_args(args)
+        completion = create_chat_completion_from_args(completion_args)

-    print_streamed_response(args, completion)
+    print_streamed_response(
+        args.display_args,
+        completion,
+        completion_args.n_completions
+    )

 if __name__ == "__main__":
     main()
--
cgit v1.2.3
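The hunk at @@ -82,9 +87,8 @@ replaces the AutoDetectedOption checks with a plain boolean, get_color_codes(no_color = not display_args.color), but the resolution of DisplayArguments.color itself happens in argparsing.py and is not part of this diff. A minimal sketch of the behaviour named in the subject line, turning color off when stdout is redirected, is given below; the AutoDetectedOption members, the DisplayArguments fields, and the resolve_color helper are illustrative assumptions rather than code taken from this commit.

import sys
from dataclasses import dataclass
from enum import Enum

class AutoDetectedOption(Enum):
    # Assumed members; the real enum is defined in argparsing.py.
    ON = 'on'
    OFF = 'off'
    AUTO = 'auto'

@dataclass
class DisplayArguments:
    # Assumed shape of the display options consumed by print_streamed_response.
    adornments: bool
    color: bool

def resolve_color(requested: AutoDetectedOption) -> bool:
    # Hypothetical helper: honor an explicit on/off request; otherwise enable
    # color only when stdout is a terminal, so redirecting output to a file
    # or pipe disables color automatically.
    if requested is AutoDetectedOption.ON:
        return True
    if requested is AutoDetectedOption.OFF:
        return False
    return sys.stdout.isatty()

Resolving the tri-state option to a boolean at parse time is what lets print_streamed_response drop its dependence on AutoDetectedOption entirely.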