aboutsummaryrefslogtreecommitdiff
path: root/src/gpt_chat_cli/gcli.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpt_chat_cli/gcli.py')
-rw-r--r--src/gpt_chat_cli/gcli.py134
1 files changed, 92 insertions, 42 deletions
diff --git a/src/gpt_chat_cli/gcli.py b/src/gpt_chat_cli/gcli.py
index 1c5555c..5f2a478 100644
--- a/src/gpt_chat_cli/gcli.py
+++ b/src/gpt_chat_cli/gcli.py
@@ -3,10 +3,12 @@
import sys
import openai
import pickle
+import os
+import datetime
from collections import defaultdict
from dataclasses import dataclass
-from typing import Tuple
+from typing import Tuple, Optional
from .openai_wrappers import (
create_chat_completion,
@@ -24,50 +26,102 @@ from .argparsing import (
DisplayArguments,
CompletionArguments,
DebugArguments,
+ MessageSource
)
from .version import VERSION
from .color import get_color_codes
+from .chat_colorizer import ChatColorizer
-import datetime
+###########################
+#### UTILS ####
+###########################
+
+def resolve_initial_message(src: MessageSource, interactive=False) -> str:
+ msg = None
+
+ if src.message:
+ msg = src.message
+ elif src.prompt_from_fd:
+ with os.fdopen(src.prompt_from_fd, "r") as f:
+ msg = f.read()
+ elif src.prompt_from_file:
+ with open(src.prompt_from_file, "r") as f:
+ msg = f.read()
+ elif not interactive:
+ msg = sys.stdin.read()
+
+ return msg
+
+def get_system_message(system_message : Optional[str]):
+
+ if not system_message:
+
+ current_date_time = datetime.datetime.now()
+ system_message = f'The current date is {current_date_time}. When emitting code or producing markdown, ensure to label fenced code blocks with the language in use.'
+
+ return ChatMessage(Role.SYSTEM, system_message)
+
+def enable_emacs_editing():
+ try:
+ import readline
+ except ImportError:
+ pass
###########################
#### SAVE / REPLAY ####
###########################
-def create_singleton_chat_completion(
- message : str,
- completion_args : CompletionArguments
- ):
+@dataclass
+class CompletionContext:
+ message: str
+ completion_args: CompletionArguments
+ system_message: Optional[str] = None
- hist = [ get_system_message(), ChatMessage( Role.USER, message ) ]
+def create_singleton_chat_completion(ctx : CompletionContext):
- completion = create_chat_completion(hist, completion_args)
+ hist = [
+ get_system_message(ctx.system_message),
+ ChatMessage(Role.USER, ctx.message)
+ ]
+
+ completion = create_chat_completion(hist, ctx.completion_args)
return completion
def save_response_and_arguments(args : Arguments) -> None:
- message = args.initial_message
+ message = resolve_initial_message(args.initial_message)
+
+ ctx = CompletionContext(
+ message=message,
+ completion_args=args.completion_args,
+ system_message=args.system_message
+ )
+
+ completion = create_singleton_chat_completion(
+ message,
+ args.completion_args,
+ args.system_message,
+ )
- completion = create_singleton_chat_completion(message, args.completion_args)
completion = list(completion)
filename = args.debug_args.save_response_to_file
with open(filename, 'wb') as f:
- pickle.dump((message, args.completion_args, completion,), f)
+ pickle.dump((ctx, completion,), f)
def load_response_and_arguments(args : Arguments) \
- -> Tuple[CompletionArguments, OpenAIChatResponseStream]:
+ -> Tuple[CompletionContext, OpenAIChatResponseStream]:
filename = args.debug_args.load_response_from_file
with open(filename, 'rb') as f:
- message, args, completion = pickle.load(f)
+ ctx, completion = pickle.load(f)
- return (message, args, completion)
+ return (ctx, completion)
#########################
#### PRETTY PRINTING ####
@@ -88,8 +142,6 @@ class CumulativeResponse:
self.content += new_chunk
self.delta_content += new_chunk
-from .chat_colorizer import ChatColorizer
-
def print_streamed_response(
display_args : DisplayArguments,
completion : OpenAIChatResponseStream,
@@ -156,12 +208,9 @@ def print_streamed_response(
if return_responses:
return [ cumu_responses[i].content for i in range(n_completions) ]
-def get_system_message():
- current_date_time = datetime.datetime.now()
-
- msg = f'The current date is {current_date_time}. When emitting code or producing markdown, ensure to label fenced code blocks with the language in use.'
-
- return ChatMessage( Role.SYSTEM, msg)
+#########################
+#### COMMANDS ####
+#########################
def cmd_version():
print(f'version {VERSION}')
@@ -170,15 +219,6 @@ def cmd_list_models():
for model in list_models():
print(model)
-def enable_emacs_editing():
- try:
- import readline
- # self.old_completer = readline.get_completer()
- # readline.set_completer(self.complete)
- # readline.parse_and_bind(self.completekey+": complete")
- except ImportError:
- pass
-
def cmd_interactive(args : Arguments):
enable_emacs_editing()
@@ -188,7 +228,7 @@ def cmd_interactive(args : Arguments):
completion_args = args.completion_args
display_args = args.display_args
- hist = [ get_system_message() ]
+ hist = [ get_system_message( args.system_message ) ]
PROMPT = f'[{COLOR_CODE.WHITE}#{COLOR_CODE.RESET}] '
@@ -208,9 +248,11 @@ def cmd_interactive(args : Arguments):
print(f'GPT Chat CLI version {VERSION}')
print(f'Press Control-D to exit')
- if args.initial_message:
- print( PROMPT, args.initial_message, sep='' )
- hist.append( ChatMessage( Role.USER, args.initial_message ) )
+ initial_message = resolve_initial_message(args.initial_message, interactive=True)
+
+ if initial_message:
+ print( PROMPT, initial_message, sep='', flush=True )
+ hist.append( ChatMessage( Role.USER, initial_message ) )
else:
if not prompt_message():
return
@@ -225,11 +267,11 @@ def cmd_interactive(args : Arguments):
)[0]
hist.append( ChatMessage(Role.ASSISTANT, response) )
- except:
- pass
+ except KeyboardInterrupt:
+ print()
if not prompt_message():
- break
+ return
def cmd_singleton(args: Arguments):
completion_args = args.completion_args
@@ -241,13 +283,21 @@ def cmd_singleton(args: Arguments):
save_response_and_arguments(args)
return
elif debug_args.load_response_from_file:
- message, completion_args, completion = load_response_and_arguments(args)
+ ctx, completion = load_response_and_arguments(args)
+
+ message = ctx.message
+ completion_args = ctx.completion_args
else:
# message is only None is a TTY is not attached
- if message is None:
- message = sys.stdin.read()
+ message = resolve_initial_message(args.initial_message)
+
+ ctx = CompletionContext(
+ message=message,
+ completion_args=completion_args,
+ system_message=args.system_message
+ )
- completion = create_singleton_chat_completion(message, completion_args)
+ completion = create_singleton_chat_completion(ctx)
print_streamed_response(
args.display_args,