aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/gpt_chat_cli/argparsing.py5
-rw-r--r--src/gpt_chat_cli/gcli.py131
-rw-r--r--src/gpt_chat_cli/openai_wrappers.py44
3 files changed, 143 insertions, 37 deletions
diff --git a/src/gpt_chat_cli/argparsing.py b/src/gpt_chat_cli/argparsing.py
index e9183a9..04d3645 100644
--- a/src/gpt_chat_cli/argparsing.py
+++ b/src/gpt_chat_cli/argparsing.py
@@ -57,14 +57,12 @@ class CompletionArguments:
frequency_penalty: float
max_tokens: int
top_p: float
- message: str
@dataclass
class DisplayArguments:
adornments: bool
color: bool
-
@dataclass
class DebugArguments:
save_response_to_file: Optional[str]
@@ -77,6 +75,7 @@ class Arguments:
version: bool
list_models: bool
interactive: bool
+ initial_message: Optional[str] = None
debug_args: Optional[DebugArguments] = None
def split_arguments(args: argparse.Namespace) -> Arguments:
@@ -88,7 +87,6 @@ def split_arguments(args: argparse.Namespace) -> Arguments:
frequency_penalty=args.frequency_penalty,
max_tokens=args.max_tokens,
top_p=args.top_p,
- message=args.message
)
display_args = DisplayArguments(
@@ -102,6 +100,7 @@ def split_arguments(args: argparse.Namespace) -> Arguments:
)
return Arguments(
+ initial_message=args.message,
completion_args=completion_args,
display_args=display_args,
debug_args=debug_args,
diff --git a/src/gpt_chat_cli/gcli.py b/src/gpt_chat_cli/gcli.py
index 2d40cf2..9df4459 100644
--- a/src/gpt_chat_cli/gcli.py
+++ b/src/gpt_chat_cli/gcli.py
@@ -14,6 +14,8 @@ from .openai_wrappers import (
OpenAIChatResponse,
OpenAIChatResponseStream,
FinishReason,
+ Role,
+ ChatMessage
)
from .argparsing import (
@@ -31,28 +33,43 @@ from .color import get_color_codes
#### SAVE / REPLAY ####
###########################
-def create_chat_completion_from_args(args : CompletionArguments) \
- -> OpenAIChatResponseStream:
- return create_chat_completion(
- model=args.model,
- messages=[{ "role": "user", "content": args.message }],
- n=args.n_completions,
- temperature=args.temperature,
- presence_penalty=args.presence_penalty,
- frequency_penalty=args.frequency_penalty,
- max_tokens=args.max_tokens,
- top_p=args.top_p,
- stream=True
- )
+
+# def create_chat_completion_from_args(args : CompletionArguments) \
+# -> OpenAIChatResponseStream:
+# return create_chat_completion(
+# model=args.model,
+# messages=[{ "role": "user", "content": args.message }],
+# n=args.n_completions,
+# temperature=args.temperature,
+# presence_penalty=args.presence_penalty,
+# frequency_penalty=args.frequency_penalty,
+# max_tokens=args.max_tokens,
+# top_p=args.top_p,
+# stream=True
+# )
+
+def create_singleton_chat_completion(
+ message : str,
+ completion_args : CompletionArguments
+ ):
+
+ hist = [ ChatMessage( Role.USER, message ) ]
+
+ completion = create_chat_completion(hist, completion_args)
+
+ return completion
def save_response_and_arguments(args : Arguments) -> None:
- completion = create_chat_completion_from_args(args.completion_args)
+
+ message = args.initial_message
+
+ completion = create_singleton_chat_completion(message, args.completion_args)
completion = list(completion)
filename = args.debug_args.save_response_to_file
with open(filename, 'wb') as f:
- pickle.dump((args.completion_args, completion,), f)
+ pickle.dump((message, args.completion_args, completion,), f)
def load_response_and_arguments(args : Arguments) \
-> Tuple[CompletionArguments, OpenAIChatResponseStream]:
@@ -60,9 +77,9 @@ def load_response_and_arguments(args : Arguments) \
filename = args.debug_args.load_response_from_file
with open(filename, 'rb') as f:
- args, completion = pickle.load(f)
+ message, args, completion = pickle.load(f)
- return (args, completion)
+ return (message, args, completion)
#########################
#### PRETTY PRINTING ####
@@ -70,18 +87,24 @@ def load_response_and_arguments(args : Arguments) \
@dataclass
class CumulativeResponse:
- content: str = ""
+ delta_content: str = ""
finish_reason: FinishReason = FinishReason.NONE
+ content: str = ""
- def take_content(self : "CumulativeResponse"):
- chunk = self.content
- self.content = ""
+ def take_delta(self : "CumulativeResponse"):
+ chunk = self.delta_content
+ self.delta_content = ""
return chunk
+ def add_content(self : "CumulativeResponse", new_chunk : str):
+ self.content += new_chunk
+ self.delta_content += new_chunk
+
def print_streamed_response(
display_args : DisplayArguments,
completion : OpenAIChatResponseStream,
- n_completions : int
+ n_completions : int,
+ return_responses : bool = False
) -> None:
"""
Print the response in real time by printing the deltas as they occur. If multiple responses
@@ -103,7 +126,7 @@ def print_streamed_response(
delta = choice.delta
if delta.content:
- cumu_responses[choice.index].content += delta.content
+ cumu_responses[choice.index].add_content(delta.content)
if choice.finish_reason is not FinishReason.NONE:
cumu_responses[choice.index].finish_reason = choice.finish_reason
@@ -118,7 +141,7 @@ def print_streamed_response(
print(PROMPT, end=' ', flush=True)
- content = display_response.take_content()
+ content = display_response.take_delta()
print(f'{COLOR_CODE.WHITE}{content}{COLOR_CODE.RESET}',
sep='', end='', flush=True)
@@ -132,6 +155,9 @@ def print_streamed_response(
else:
print(end='\n', flush=True)
+ if return_responses:
+ return [ cumu_responses[i].content for i in range(n_completions) ]
+
def cmd_version():
print(f'version {VERSION}')
@@ -142,24 +168,69 @@ def cmd_list_models():
def cmd_interactive(args : Arguments):
COLOR_CODE = get_color_codes(no_color = not args.display_args.color)
- print(f'GPT Chat CLI {VERSION}')
- print(f'[{COLOR_CODE.WHITE}#{COLOR_CODE.RESET}]', end=' ', flush=True)
+ completion_args = args.completion_args
+ display_args = args.display_args
+
+ hist = []
+
+ def print_prompt():
+
+ print(f'[{COLOR_CODE.WHITE}#{COLOR_CODE.RESET}]', end=' ', flush=True)
+
+ def prompt_message() -> bool:
+ print_prompt()
+
+ # Control-D closes the input stream
+ try:
+ message = input()
+ except EOFError:
+ print()
+ return False
+
+ hist.append( ChatMessage( Role.USER, message ) )
+
+ return True
+
+ print(f'GPT Chat CLI version {VERSION}')
+ print(f'Press Control-D to exit')
+
+ if args.initial_message:
+ print_prompt()
+ print( args.initial_message )
+ hist.append( ChatMessage( Role.USER, args.initial_message ) )
+ else:
+ prompt_message()
+
+ while True:
+
+ completion = create_chat_completion(hist, completion_args)
+
+ response = print_streamed_response(
+ display_args, completion, 1, return_responses=True,
+ )[0]
+
+ hist.append( ChatMessage(Role.ASSISTANT, response) )
+
+ if not prompt_message():
+ break
def cmd_singleton(args: Arguments):
completion_args = args.completion_args
debug_args : DebugArguments = args.debug_args
+ message = args.initial_message
if debug_args.save_response_to_file:
save_response_and_arguments(args)
return
elif debug_args.load_response_from_file:
- completion_args, completion = load_response_and_arguments(args)
+ message, completion_args, completion = load_response_and_arguments(args)
else:
- if completion_args.message is None:
- completion_args.message = sys.stdin.read()
+ # message is only None is a TTY is not attached
+ if message is None:
+ message = sys.stdin.read()
- completion = create_chat_completion_from_args(completion_args)
+ completion = create_singleton_chat_completion(message, completion_args)
print_streamed_response(
args.display_args,
diff --git a/src/gpt_chat_cli/openai_wrappers.py b/src/gpt_chat_cli/openai_wrappers.py
index 413ec24..6eeba4d 100644
--- a/src/gpt_chat_cli/openai_wrappers.py
+++ b/src/gpt_chat_cli/openai_wrappers.py
@@ -28,6 +28,24 @@ class Choice:
finish_reason: Optional[FinishReason]
index: int
+class Role(Enum):
+ SYSTEM = "system"
+ USER = "user"
+ ASSISTANT = "assistant"
+
+@dataclass
+class ChatMessage:
+ role: Role
+ content: str
+
+ def to_json(self : "ChatMessage"):
+ return {
+ "role": self.role.value,
+ "content": self.content
+ }
+
+ChatHistory = List[ChatMessage]
+
@dataclass
class OpenAIChatResponse:
choices: List[Choice]
@@ -61,13 +79,31 @@ class OpenAIChatResponse:
OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None]
-def create_chat_completion(*args, **kwargs) \
- -> OpenAIChatResponseStream:
+from .argparsing import CompletionArguments
+
+def create_chat_completion(hist : ChatHistory, args: CompletionArguments) \
+ -> OpenAIChatResponseStream:
+
+ messages = [ msg.to_json() for msg in hist ]
+
+ response = openai.ChatCompletion.create(
+ model=args.model,
+ messages=messages,
+ n=args.n_completions,
+ temperature=args.temperature,
+ presence_penalty=args.presence_penalty,
+ frequency_penalty=args.frequency_penalty,
+ max_tokens=args.max_tokens,
+ top_p=args.top_p,
+ stream=True
+ )
+
return (
- OpenAIChatResponse.from_json(update) \
- for update in openai.ChatCompletion.create(*args, **kwargs)
+ OpenAIChatResponse.from_json( update ) \
+ for update in response
)
+
def list_models() -> List[str]:
model_data = openai.Model.list()