aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/gpt_chat_cli/argparsing.py73
-rw-r--r--src/gpt_chat_cli/gcli.py65
2 files changed, 108 insertions, 30 deletions
diff --git a/src/gpt_chat_cli/argparsing.py b/src/gpt_chat_cli/argparsing.py
index a7d3218..7d1d305 100644
--- a/src/gpt_chat_cli/argparsing.py
+++ b/src/gpt_chat_cli/argparsing.py
@@ -4,12 +4,17 @@ import logging
import openai
import sys
from enum import Enum
+from dataclasses import dataclass
+from typing import Tuple, Optional
class AutoDetectedOption(Enum):
ON = 'on'
OFF = 'off'
AUTO = 'auto'
+ def __str__(self : "AutoDetectedOption"):
+ return self.value
+
def die_validation_err(err : str):
print(err, file=sys.stderr)
sys.exit(1)
@@ -33,7 +38,62 @@ def validate_args(args: argparse.Namespace) -> None:
if args.n_completions < 1:
die_validation_err("Number of completions must be greater than or equal to 1.")
-def parse_args():
+@dataclass
+class CompletionArguments:
+ model: str
+ n_completions: int
+ temperature: float
+ presence_penalty: float
+ frequency_penalty: float
+ max_tokens: int
+ top_p: float
+ message: str
+
+@dataclass
+class DisplayArguments:
+ adornments: bool
+ color: bool
+
+
+@dataclass
+class DebugArguments:
+ save_response_to_file: Optional[str]
+ load_response_from_file: Optional[str]
+
+@dataclass
+class Arguments:
+ completion_args: CompletionArguments
+ display_args: DisplayArguments
+ debug_args: Optional[DebugArguments] = None
+
+def split_arguments(args: argparse.Namespace, debug=False) -> Arguments:
+ completion_args = CompletionArguments(
+ model=args.model,
+ n_completions=args.n_completions,
+ temperature=args.temperature,
+ presence_penalty=args.presence_penalty,
+ frequency_penalty=args.frequency_penalty,
+ max_tokens=args.max_tokens,
+ top_p=args.top_p,
+ message=args.message
+ )
+
+ display_args = DisplayArguments(
+ adornments=(args.adornments == AutoDetectedOption.ON),
+ color=(args.color == AutoDetectedOption.ON),
+ )
+
+ if debug:
+ debug_args = DebugArguments(
+ save_response_to_file=args.save_response_to_file,
+ load_response_from_file=args.load_response_from_file,
+ )
+ else:
+ debug_args = None
+
+ return Arguments( completion_args, display_args, debug_args )
+
+def parse_args() -> Arguments:
GCLI_ENV_PREFIX = "GCLI_"
@@ -178,11 +238,16 @@ def parse_args():
if args.color == AutoDetectedOption.AUTO:
if os.getenv("NO_COLOR"):
args.color = AutoDetectedOption.OFF
- else:
+ elif sys.stdout.isatty():
args.color = AutoDetectedOption.ON
+ else:
+ args.color = AutoDetectedOption.OFF
if args.adornments == AutoDetectedOption.AUTO:
- args.adornments = AutoDetectedOption.ON
+ if sys.stdout.isatty():
+ args.adornments = AutoDetectedOption.ON
+ else:
+ args.adornments = AutoDetectedOption.OFF
if not debug:
args.load_response_from_file = None
@@ -190,4 +255,4 @@ def parse_args():
validate_args(args)
- return args
+ return split_arguments(args, debug=debug)
diff --git a/src/gpt_chat_cli/gcli.py b/src/gpt_chat_cli/gcli.py
index dd7dbbb..dcd7c8b 100644
--- a/src/gpt_chat_cli/gcli.py
+++ b/src/gpt_chat_cli/gcli.py
@@ -1,6 +1,5 @@
#!/bin/env python3
-import argparse
import sys
import openai
import pickle
@@ -18,7 +17,9 @@ from .openai_wrappers import (
from .argparsing import (
parse_args,
- AutoDetectedOption,
+ Arguments,
+ DisplayArguments,
+ CompletionArguments,
)
from .color import get_color_codes
@@ -27,7 +28,7 @@ from .color import get_color_codes
#### SAVE / REPLAY ####
###########################
-def create_chat_completion_from_args(args : argparse.Namespace) \
+def create_chat_completion_from_args(args : CompletionArguments) \
-> OpenAIChatResponseStream:
return create_chat_completion(
model=args.model,
@@ -41,19 +42,19 @@ def create_chat_completion_from_args(args : argparse.Namespace) \
stream=True
)
-def save_response_and_arguments(args : argparse.Namespace) -> None:
- completion = create_chat_completion_from_args(args)
+def save_response_and_arguments(args : Arguments) -> None:
+ completion = create_chat_completion_from_args(args.completion_args)
completion = list(completion)
- filename = args.save_response_to_file
+ filename = args.debug_args.save_response_to_file
with open(filename, 'wb') as f:
- pickle.dump((args, completion,), f)
+ pickle.dump((args.completion_args, completion,), f)
-def load_response_and_arguments(args : argparse.Namespace) \
- -> Tuple[argparse.Namespace, OpenAIChatResponseStream]:
+def load_response_and_arguments(args : Arguments) \
+ -> Tuple[CompletionArguments, OpenAIChatResponseStream]:
- filename = args.load_response_from_file
+ filename = args.debug_args.load_response_from_file
with open(filename, 'rb') as f:
args, completion = pickle.load(f)
@@ -74,7 +75,11 @@ class CumulativeResponse:
self.content = ""
return chunk
-def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatResponseStream):
+def print_streamed_response(
+ display_args : DisplayArguments,
+ completion : OpenAIChatResponseStream,
+ n_completions : int
+ ) -> None:
"""
Print the response in real time by printing the deltas as they occur. If multiple responses
are requested, print the first in real-time, accumulating the others in the background. One the
@@ -82,9 +87,8 @@ def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatRe
on until all responses have been printed.
"""
- COLOR_CODE = get_color_codes(no_color = args.color == AutoDetectedOption.OFF)
- ADORNMENTS = args.adornments == AutoDetectedOption.ON
- N_COMPLETIONS = args.n_completions
+ COLOR_CODE = get_color_codes(no_color = not display_args.color)
+ adornments = display_args.adornments
cumu_responses = defaultdict(CumulativeResponse)
display_idx = 0
@@ -103,9 +107,9 @@ def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatRe
display_response = cumu_responses[display_idx]
- if not prompt_printed and ADORNMENTS:
- res_indicator = '' if N_COMPLETIONS == 1 else \
- f' {display_idx + 1}/{N_COMPLETIONS}'
+ if not prompt_printed and adornments:
+ res_indicator = '' if n_completions == 1 else \
+ f' {display_idx + 1}/{n_completions}'
PROMPT = f'[{COLOR_CODE.GREEN}{update.model}{COLOR_CODE.RESET}{COLOR_CODE.RED}{res_indicator}{COLOR_CODE.RESET}]'
prompt_printed = True
print(PROMPT, end=' ', flush=True)
@@ -116,11 +120,11 @@ def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatRe
sep='', end='', flush=True)
if display_response.finish_reason is not FinishReason.NONE:
- if display_idx < N_COMPLETIONS:
+ if display_idx < n_completions:
display_idx += 1
prompt_printed = False
- if ADORNMENTS:
+ if adornments:
print(end='\n\n', flush=True)
else:
print(end='\n', flush=True)
@@ -128,15 +132,24 @@ def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatRe
def main():
args = parse_args()
- if args.save_response_to_file:
- save_response_and_arguments(args)
- return
- elif args.load_response_from_file:
- args, completion = load_response_and_arguments(args)
+ completion_args = args.completion_args
+
+ if args.debug_args:
+ debug_args : DebugArguments = args.debug_args
+
+ if debug_args.save_response_to_file:
+ save_response_and_arguments(args)
+ return
+ elif debug_args.load_response_from_file:
+ completion_args, completion = load_response_and_arguments(args)
else:
- completion = create_chat_completion_from_args(args)
+ completion = create_chat_completion_from_args(completion_args)
- print_streamed_response(args, completion)
+ print_streamed_response(
+ args.display_args,
+ completion,
+ completion_args.n_completions
+ )
if __name__ == "__main__":
main()