Diffstat (limited to 'src/gpt_chat_cli/gcli.py')
 src/gpt_chat_cli/gcli.py | 131
 1 file changed, 101 insertions(+), 30 deletions(-)
diff --git a/src/gpt_chat_cli/gcli.py b/src/gpt_chat_cli/gcli.py
index 2d40cf2..9df4459 100644
--- a/src/gpt_chat_cli/gcli.py
+++ b/src/gpt_chat_cli/gcli.py
@@ -14,6 +14,8 @@ from .openai_wrappers import (
OpenAIChatResponse,
OpenAIChatResponseStream,
FinishReason,
+ Role,
+ ChatMessage
)
from .argparsing import (
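
The two new imports, Role and ChatMessage, are what the rewritten functions below use to carry conversation history. Their definitions are not part of this diff; a minimal sketch of what openai_wrappers.py plausibly provides, inferred from the positional ChatMessage(Role.USER, message) calls in the hunks that follow:

    from dataclasses import dataclass
    from enum import Enum

    class Role(Enum):
        # role strings follow the OpenAI chat API convention
        SYSTEM = "system"
        USER = "user"
        ASSISTANT = "assistant"

    @dataclass
    class ChatMessage:
        role: Role
        content: str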
@@ -31,28 +33,43 @@ from .color import get_color_codes
#### SAVE / REPLAY ####
###########################
-def create_chat_completion_from_args(args : CompletionArguments) \
- -> OpenAIChatResponseStream:
- return create_chat_completion(
- model=args.model,
- messages=[{ "role": "user", "content": args.message }],
- n=args.n_completions,
- temperature=args.temperature,
- presence_penalty=args.presence_penalty,
- frequency_penalty=args.frequency_penalty,
- max_tokens=args.max_tokens,
- top_p=args.top_p,
- stream=True
- )
+
+# def create_chat_completion_from_args(args : CompletionArguments) \
+# -> OpenAIChatResponseStream:
+# return create_chat_completion(
+# model=args.model,
+# messages=[{ "role": "user", "content": args.message }],
+# n=args.n_completions,
+# temperature=args.temperature,
+# presence_penalty=args.presence_penalty,
+# frequency_penalty=args.frequency_penalty,
+# max_tokens=args.max_tokens,
+# top_p=args.top_p,
+# stream=True
+# )
+
+def create_singleton_chat_completion(
+ message : str,
+ completion_args : CompletionArguments
+ ):
+
+ hist = [ ChatMessage( Role.USER, message ) ]
+
+ completion = create_chat_completion(hist, completion_args)
+
+ return completion
def save_response_and_arguments(args : Arguments) -> None:
- completion = create_chat_completion_from_args(args.completion_args)
+
+ message = args.initial_message
+
+ completion = create_singleton_chat_completion(message, args.completion_args)
completion = list(completion)
filename = args.debug_args.save_response_to_file
with open(filename, 'wb') as f:
- pickle.dump((args.completion_args, completion,), f)
+ pickle.dump((message, args.completion_args, completion,), f)
def load_response_and_arguments(args : Arguments) \
-> Tuple[CompletionArguments, OpenAIChatResponseStream]:
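
save_response_and_arguments materializes the stream with list() before pickling, since a live generator cannot be serialized, and the payload grows from two items to three. A hedged round-trip sketch of the new format; note that files written by the old two-tuple code will no longer unpack here:

    # new on-disk payload: (message, completion_args, completion)
    with open(filename, 'wb') as f:
        pickle.dump((message, completion_args, completion), f)

    with open(filename, 'rb') as f:
        message, completion_args, completion = pickle.load(f)  # three items now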
@@ -60,9 +77,9 @@ def load_response_and_arguments(args : Arguments) \
filename = args.debug_args.load_response_from_file
with open(filename, 'rb') as f:
- args, completion = pickle.load(f)
+ message, args, completion = pickle.load(f)
- return (args, completion)
+ return (message, args, completion)
#########################
#### PRETTY PRINTING ####
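
The unchanged def line above still carries the old two-tuple annotation; with the message now threaded through, the signature would presumably read:

    def load_response_and_arguments(args : Arguments) \
        -> Tuple[str, CompletionArguments, OpenAIChatResponseStream]: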
@@ -70,18 +87,24 @@ def load_response_and_arguments(args : Arguments) \
@dataclass
class CumulativeResponse:
- content: str = ""
+ delta_content: str = ""
finish_reason: FinishReason = FinishReason.NONE
+ content: str = ""
- def take_content(self : "CumulativeResponse"):
- chunk = self.content
- self.content = ""
+ def take_delta(self : "CumulativeResponse"):
+ chunk = self.delta_content
+ self.delta_content = ""
return chunk
+ def add_content(self : "CumulativeResponse", new_chunk : str):
+ self.content += new_chunk
+ self.delta_content += new_chunk
+
def print_streamed_response(
display_args : DisplayArguments,
completion : OpenAIChatResponseStream,
- n_completions : int
+ n_completions : int,
+ return_responses : bool = False
) -> None:
"""
Print the response in real time by printing the deltas as they occur. If multiple responses
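
CumulativeResponse now keeps two buffers: delta_content holds only the not-yet-printed text and is drained by take_delta(), while content accumulates the complete reply so the interactive mode can append it to the history. A behavior sketch:

    r = CumulativeResponse()
    r.add_content("Hello, ")
    r.take_delta()   # returns "Hello, " and drains the delta buffer
    r.add_content("world")
    r.take_delta()   # returns "world"
    r.content        # "Hello, world" -- the full transcript is retained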
@@ -103,7 +126,7 @@ def print_streamed_response(
delta = choice.delta
if delta.content:
- cumu_responses[choice.index].content += delta.content
+ cumu_responses[choice.index].add_content(delta.content)
if choice.finish_reason is not FinishReason.NONE:
cumu_responses[choice.index].finish_reason = choice.finish_reason
@@ -118,7 +141,7 @@ def print_streamed_response(
print(PROMPT, end=' ', flush=True)
- content = display_response.take_content()
+ content = display_response.take_delta()
print(f'{COLOR_CODE.WHITE}{content}{COLOR_CODE.RESET}',
sep='', end='', flush=True)
@@ -132,6 +155,9 @@ def print_streamed_response(
else:
print(end='\n', flush=True)
+ if return_responses:
+ return [ cumu_responses[i].content for i in range(n_completions) ]
+
def cmd_version():
print(f'version {VERSION}')
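
With return_responses=True, print_streamed_response both streams the output and hands back the finished texts (so the -> None annotation is now only accurate for the default call). A call sketch, assuming display_args and completion are already constructed:

    responses = print_streamed_response(
        display_args, completion, 1, return_responses=True
    )
    reply = responses[0]  # full text of the single requested completion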
@@ -142,24 +168,69 @@ def cmd_list_models():
def cmd_interactive(args : Arguments):
COLOR_CODE = get_color_codes(no_color = not args.display_args.color)
- print(f'GPT Chat CLI {VERSION}')
- print(f'[{COLOR_CODE.WHITE}#{COLOR_CODE.RESET}]', end=' ', flush=True)
+ completion_args = args.completion_args
+ display_args = args.display_args
+
+ hist = []
+
+ def print_prompt():
+
+ print(f'[{COLOR_CODE.WHITE}#{COLOR_CODE.RESET}]', end=' ', flush=True)
+
+ def prompt_message() -> bool:
+ print_prompt()
+
+ # Control-D closes the input stream
+ try:
+ message = input()
+ except EOFError:
+ print()
+ return False
+
+ hist.append( ChatMessage( Role.USER, message ) )
+
+ return True
+
+ print(f'GPT Chat CLI version {VERSION}')
+ print(f'Press Control-D to exit')
+
+ if args.initial_message:
+ print_prompt()
+ print( args.initial_message )
+ hist.append( ChatMessage( Role.USER, args.initial_message ) )
+ else:
+ prompt_message()
+
+ while True:
+
+ completion = create_chat_completion(hist, completion_args)
+
+ response = print_streamed_response(
+ display_args, completion, 1, return_responses=True,
+ )[0]
+
+ hist.append( ChatMessage(Role.ASSISTANT, response) )
+
+ if not prompt_message():
+ break
def cmd_singleton(args: Arguments):
completion_args = args.completion_args
debug_args : DebugArguments = args.debug_args
+ message = args.initial_message
if debug_args.save_response_to_file:
save_response_and_arguments(args)
return
elif debug_args.load_response_from_file:
- completion_args, completion = load_response_and_arguments(args)
+ message, completion_args, completion = load_response_and_arguments(args)
else:
- if completion_args.message is None:
- completion_args.message = sys.stdin.read()
+ # message is only None if a TTY is not attached
+ if message is None:
+ message = sys.stdin.read()
- completion = create_chat_completion_from_args(completion_args)
+ completion = create_singleton_chat_completion(message, completion_args)
print_streamed_response(
args.display_args,