diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/gpt_chat_cli/argparsing.py | 73 | ||||
-rw-r--r-- | src/gpt_chat_cli/gcli.py | 65 |
2 files changed, 108 insertions, 30 deletions
diff --git a/src/gpt_chat_cli/argparsing.py b/src/gpt_chat_cli/argparsing.py index a7d3218..7d1d305 100644 --- a/src/gpt_chat_cli/argparsing.py +++ b/src/gpt_chat_cli/argparsing.py @@ -4,12 +4,17 @@ import logging import openai import sys from enum import Enum +from dataclasses import dataclass +from typing import Tuple, Optional class AutoDetectedOption(Enum): ON = 'on' OFF = 'off' AUTO = 'auto' + def __str__(self : "AutoDetectedOption"): + return self.value + def die_validation_err(err : str): print(err, file=sys.stderr) sys.exit(1) @@ -33,7 +38,62 @@ def validate_args(args: argparse.Namespace) -> None: if args.n_completions < 1: die_validation_err("Number of completions must be greater than or equal to 1.") -def parse_args(): +@dataclass +class CompletionArguments: + model: str + n_completions: int + temperature: float + presence_penalty: float + frequency_penalty: float + max_tokens: int + top_p: float + message: str + +@dataclass +class DisplayArguments: + adornments: bool + color: bool + + +@dataclass +class DebugArguments: + save_response_to_file: Optional[str] + load_response_from_file: Optional[str] + +@dataclass +class Arguments: + completion_args: CompletionArguments + display_args: DisplayArguments + debug_args: Optional[DebugArguments] = None + +def split_arguments(args: argparse.Namespace, debug=False) -> Arguments: + completion_args = CompletionArguments( + model=args.model, + n_completions=args.n_completions, + temperature=args.temperature, + presence_penalty=args.presence_penalty, + frequency_penalty=args.frequency_penalty, + max_tokens=args.max_tokens, + top_p=args.top_p, + message=args.message + ) + + display_args = DisplayArguments( + adornments=(args.adornments == AutoDetectedOption.ON), + color=(args.color == AutoDetectedOption.ON), + ) + + if debug: + debug_args = DebugArguments( + save_response_to_file=args.save_response_to_file, + load_response_from_file=args.load_response_from_file, + ) + else: + debug_args = None + + return Arguments( completion_args, display_args, debug_args ) + +def parse_args() -> Arguments: GCLI_ENV_PREFIX = "GCLI_" @@ -178,11 +238,16 @@ def parse_args(): if args.color == AutoDetectedOption.AUTO: if os.getenv("NO_COLOR"): args.color = AutoDetectedOption.OFF - else: + elif sys.stdout.isatty(): args.color = AutoDetectedOption.ON + else: + args.color = AutoDetectedOption.OFF if args.adornments == AutoDetectedOption.AUTO: - args.adornments = AutoDetectedOption.ON + if sys.stdout.isatty(): + args.adornments = AutoDetectedOption.ON + else: + args.adornments = AutoDetectedOption.OFF if not debug: args.load_response_from_file = None @@ -190,4 +255,4 @@ def parse_args(): validate_args(args) - return args + return split_arguments(args, debug=debug) diff --git a/src/gpt_chat_cli/gcli.py b/src/gpt_chat_cli/gcli.py index dd7dbbb..dcd7c8b 100644 --- a/src/gpt_chat_cli/gcli.py +++ b/src/gpt_chat_cli/gcli.py @@ -1,6 +1,5 @@ #!/bin/env python3 -import argparse import sys import openai import pickle @@ -18,7 +17,9 @@ from .openai_wrappers import ( from .argparsing import ( parse_args, - AutoDetectedOption, + Arguments, + DisplayArguments, + CompletionArguments, ) from .color import get_color_codes @@ -27,7 +28,7 @@ from .color import get_color_codes #### SAVE / REPLAY #### ########################### -def create_chat_completion_from_args(args : argparse.Namespace) \ +def create_chat_completion_from_args(args : CompletionArguments) \ -> OpenAIChatResponseStream: return create_chat_completion( model=args.model, @@ -41,19 +42,19 @@ def create_chat_completion_from_args(args : argparse.Namespace) \ stream=True ) -def save_response_and_arguments(args : argparse.Namespace) -> None: - completion = create_chat_completion_from_args(args) +def save_response_and_arguments(args : Arguments) -> None: + completion = create_chat_completion_from_args(args.completion_args) completion = list(completion) - filename = args.save_response_to_file + filename = args.debug_args.save_response_to_file with open(filename, 'wb') as f: - pickle.dump((args, completion,), f) + pickle.dump((args.completion_args, completion,), f) -def load_response_and_arguments(args : argparse.Namespace) \ - -> Tuple[argparse.Namespace, OpenAIChatResponseStream]: +def load_response_and_arguments(args : Arguments) \ + -> Tuple[CompletionArguments, OpenAIChatResponseStream]: - filename = args.load_response_from_file + filename = args.debug_args.load_response_from_file with open(filename, 'rb') as f: args, completion = pickle.load(f) @@ -74,7 +75,11 @@ class CumulativeResponse: self.content = "" return chunk -def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatResponseStream): +def print_streamed_response( + display_args : DisplayArguments, + completion : OpenAIChatResponseStream, + n_completions : int + ) -> None: """ Print the response in real time by printing the deltas as they occur. If multiple responses are requested, print the first in real-time, accumulating the others in the background. One the @@ -82,9 +87,8 @@ def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatRe on until all responses have been printed. """ - COLOR_CODE = get_color_codes(no_color = args.color == AutoDetectedOption.OFF) - ADORNMENTS = args.adornments == AutoDetectedOption.ON - N_COMPLETIONS = args.n_completions + COLOR_CODE = get_color_codes(no_color = not display_args.color) + adornments = display_args.adornments cumu_responses = defaultdict(CumulativeResponse) display_idx = 0 @@ -103,9 +107,9 @@ def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatRe display_response = cumu_responses[display_idx] - if not prompt_printed and ADORNMENTS: - res_indicator = '' if N_COMPLETIONS == 1 else \ - f' {display_idx + 1}/{N_COMPLETIONS}' + if not prompt_printed and adornments: + res_indicator = '' if n_completions == 1 else \ + f' {display_idx + 1}/{n_completions}' PROMPT = f'[{COLOR_CODE.GREEN}{update.model}{COLOR_CODE.RESET}{COLOR_CODE.RED}{res_indicator}{COLOR_CODE.RESET}]' prompt_printed = True print(PROMPT, end=' ', flush=True) @@ -116,11 +120,11 @@ def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatRe sep='', end='', flush=True) if display_response.finish_reason is not FinishReason.NONE: - if display_idx < N_COMPLETIONS: + if display_idx < n_completions: display_idx += 1 prompt_printed = False - if ADORNMENTS: + if adornments: print(end='\n\n', flush=True) else: print(end='\n', flush=True) @@ -128,15 +132,24 @@ def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatRe def main(): args = parse_args() - if args.save_response_to_file: - save_response_and_arguments(args) - return - elif args.load_response_from_file: - args, completion = load_response_and_arguments(args) + completion_args = args.completion_args + + if args.debug_args: + debug_args : DebugArguments = args.debug_args + + if debug_args.save_response_to_file: + save_response_and_arguments(args) + return + elif debug_args.load_response_from_file: + completion_args, completion = load_response_and_arguments(args) else: - completion = create_chat_completion_from_args(args) + completion = create_chat_completion_from_args(completion_args) - print_streamed_response(args, completion) + print_streamed_response( + args.display_args, + completion, + completion_args.n_completions + ) if __name__ == "__main__": main() |