author    flu0r1ne <flu0r1ne@flu0r1ne.net>  2023-05-04 19:46:32 -0500
committer flu0r1ne <flu0r1ne@flu0r1ne.net>  2023-05-04 19:46:32 -0500
commit    d02f9ded2a503683b57bcacece6a2cff5484fb81 (patch)
tree      896521aa42194e5055b04cc7d5c6e93b56719f73 /src/gpt_chat_cli/gcli.py
parent    a74933b2d83efb5da4e0f1851d65ad575f04a65d (diff)
Add packaging info
Diffstat (limited to 'src/gpt_chat_cli/gcli.py')
-rw-r--r--  src/gpt_chat_cli/gcli.py  142
1 file changed, 142 insertions, 0 deletions
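
Since the commit's stated purpose is packaging, the layout this diff introduces (a module under src/gpt_chat_cli with relative imports and a main() function) is exactly what a setuptools entry point would hook into. A minimal sketch of such packaging glue, assuming a setup.py and a command name of gpt-chat-cli (neither is part of this diff):

    # setup.py -- hypothetical sketch; the actual packaging files are not shown here
    from setuptools import setup, find_packages

    setup(
        name='gpt-chat-cli',
        packages=find_packages(where='src'),
        package_dir={'': 'src'},
        entry_points={
            'console_scripts': [
                # wires the command to the main() defined in gcli.py below
                'gpt-chat-cli=gpt_chat_cli.gcli:main',
            ],
        },
    )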
diff --git a/src/gpt_chat_cli/gcli.py b/src/gpt_chat_cli/gcli.py
new file mode 100644
index 0000000..ded6d6c
--- /dev/null
+++ b/src/gpt_chat_cli/gcli.py
@@ -0,0 +1,142 @@
+#!/usr/bin/env python3
+
+import argparse
+import sys
+import openai
+import pickle
+
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Tuple
+
+from .openai_wrappers import (
+    create_chat_completion,
+    OpenAIChatResponse,
+    OpenAIChatResponseStream,
+    FinishReason,
+)
+
+from .argparsing import (
+    parse_args,
+    AutoDetectedOption,
+)
+
+from .color import get_color_codes
+
+###########################
+####   SAVE / REPLAY   ####
+###########################
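+# The response stream is materialized and pickled together with the parsed
+# arguments, so a session can be replayed later without calling the API.
+# (Standard pickle caveat: only load files you created yourself.)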
+
+def create_chat_completion_from_args(args : argparse.Namespace) \
+        -> OpenAIChatResponseStream:
+    return create_chat_completion(
+        model=args.model,
+        messages=[{ "role": "user", "content": args.message }],
+        n=args.n_completions,
+        temperature=args.temperature,
+        presence_penalty=args.presence_penalty,
+        frequency_penalty=args.frequency_penalty,
+        max_tokens=args.max_tokens,
+        top_p=args.top_p,
+        stream=True
+    )
+
+def save_response_and_arguments(args : argparse.Namespace) -> None:
+    completion = create_chat_completion_from_args(args)
+    completion = list(completion)
+
+    filename = args.save_response_to_file
+
+    with open(filename, 'wb') as f:
+        pickle.dump((args, completion,), f)
+
+def load_response_and_arguments(args : argparse.Namespace) \
+        -> Tuple[argparse.Namespace, OpenAIChatResponseStream]:
+
+    filename = args.load_response_from_file
+
+    with open(filename, 'rb') as f:
+        args, completion = pickle.load(f)
+
+    return (args, completion)
+
+#########################
+#### PRETTY PRINTING ####
+#########################
+
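+# Per-choice buffer: deltas accumulate here until their choice is displayed;
+# take_content() drains the buffer so each delta is printed exactly once.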
+@dataclass
+class CumulativeResponse:
+    content: str = ""
+    finish_reason: FinishReason = FinishReason.NONE
+
+    def take_content(self : "CumulativeResponse"):
+        chunk = self.content
+        self.content = ""
+        return chunk
+
+def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatResponseStream):
+    """
+    Print the response in real time by printing the deltas as they occur.
+    If multiple completions are requested, print the first in real time
+    while accumulating the others in the background. Once the first
+    completes, move on to the second, again printing its deltas in real
+    time, and continue until all responses have been printed.
+    """
+
+    COLOR_CODE = get_color_codes(no_color = args.color == AutoDetectedOption.OFF)
+    ADORNMENTS = args.adornments == AutoDetectedOption.ON
+    N_COMPLETIONS = args.n_completions
+
+    cumu_responses = defaultdict(CumulativeResponse)
+    display_idx = 0
+    prompt_printed = False
+
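+    # Deltas for every choice are buffered; only the choice at display_idx
+    # is echoed live, the rest wait their turn.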
+    for update in completion:
+
+        for choice in update.choices:
+            delta = choice.delta
+
+            if delta.content:
+                cumu_responses[choice.index].content += delta.content
+
+            if choice.finish_reason is not FinishReason.NONE:
+                cumu_responses[choice.index].finish_reason = choice.finish_reason
+
+        display_response = cumu_responses[display_idx]
+
+        if not prompt_printed and ADORNMENTS:
+            res_indicator = '' if N_COMPLETIONS == 1 else \
+                f' {display_idx + 1}/{N_COMPLETIONS}'
+            PROMPT = f'[{COLOR_CODE.GREEN}{update.model}{COLOR_CODE.RESET}{COLOR_CODE.RED}{res_indicator}{COLOR_CODE.RESET}]'
+            prompt_printed = True
+            print(PROMPT, end=' ', flush=True)
+
+        content = display_response.take_content()
+        print(f'{COLOR_CODE.WHITE}{content}{COLOR_CODE.RESET}',
+              sep='', end='', flush=True)
+
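+        # Once the live choice finishes, advance to the next buffered one.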
+        if display_response.finish_reason is not FinishReason.NONE:
+            if display_idx < N_COMPLETIONS:
+                display_idx += 1
+                prompt_printed = False
+
+            if ADORNMENTS:
+                print(end='\n\n', flush=True)
+            else:
+                print(end='\n', flush=True)
+
+def main():
+    args = parse_args()
+
+    if args.save_response_to_file:
+        save_response_and_arguments(args)
+        return
+    elif args.load_response_from_file:
+        args, completion = load_response_and_arguments(args)
+    else:
+        completion = create_chat_completion_from_args(args)
+
+    print_streamed_response(args, completion)
+
+if __name__ == "__main__":
+    main()
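
The display loop above is, at heart, a small demultiplexer: chunks for n completions arrive interleaved on a single stream, the choice currently on screen is echoed immediately, and the others are buffered until their turn. A standalone sketch of the same technique, using a toy (index, text, done) chunk format in place of the OpenAI wrapper types used in this diff:

    from collections import defaultdict
    from dataclasses import dataclass

    @dataclass
    class Buffer:               # plays the role of CumulativeResponse
        text: str = ''
        done: bool = False

    def demux(chunks, n):
        bufs = defaultdict(Buffer)
        display = 0             # index of the completion being echoed live
        for index, text, done in chunks:
            bufs[index].text += text
            bufs[index].done = bufs[index].done or done
            live = bufs[display]
            print(live.text, end='', flush=True)  # echo whatever the live choice has
            live.text = ''                        # drain, like take_content()
            if live.done and display < n:
                display += 1                      # move on to the next completion
                print()

    # chunks from two completions, interleaved as they would be on the wire
    demux([(0, 'Hel', False), (1, 'Bon', False),
           (0, 'lo', True), (1, 'jour', True)], n=2)
    # prints "Hello" and "Bonjour", each terminated by a newline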