-rw-r--r--   src/gpt_chat_cli/argparsing.py      |   5
-rw-r--r--   src/gpt_chat_cli/gcli.py            | 131
-rw-r--r--   src/gpt_chat_cli/openai_wrappers.py |  44

3 files changed, 143 insertions(+), 37 deletions(-)
diff --git a/src/gpt_chat_cli/argparsing.py b/src/gpt_chat_cli/argparsing.py
index e9183a9..04d3645 100644
--- a/src/gpt_chat_cli/argparsing.py
+++ b/src/gpt_chat_cli/argparsing.py
@@ -57,14 +57,12 @@ class CompletionArguments:
     frequency_penalty: float
     max_tokens: int
     top_p: float
-    message: str
 
 @dataclass
 class DisplayArguments:
     adornments: bool
     color: bool
-
 @dataclass
 class DebugArguments:
     save_response_to_file: Optional[str]
@@ -77,6 +75,7 @@ class Arguments:
     version: bool
     list_models: bool
     interactive: bool
+    initial_message: Optional[str] = None
     debug_args: Optional[DebugArguments] = None
 
 def split_arguments(args: argparse.Namespace) -> Arguments:
@@ -88,7 +87,6 @@ def split_arguments(args: argparse.Namespace) -> Arguments:
         frequency_penalty=args.frequency_penalty,
         max_tokens=args.max_tokens,
         top_p=args.top_p,
-        message=args.message
     )
 
     display_args = DisplayArguments(
@@ -102,6 +100,7 @@
     )
 
     return Arguments(
+        initial_message=args.message,
         completion_args=completion_args,
         display_args=display_args,
         debug_args=debug_args,
diff --git a/src/gpt_chat_cli/gcli.py b/src/gpt_chat_cli/gcli.py
index 2d40cf2..9df4459 100644
--- a/src/gpt_chat_cli/gcli.py
+++ b/src/gpt_chat_cli/gcli.py
@@ -14,6 +14,8 @@ from .openai_wrappers import (
     OpenAIChatResponse,
     OpenAIChatResponseStream,
     FinishReason,
+    Role,
+    ChatMessage
 )
 
 from .argparsing import (
@@ -31,28 +33,43 @@ from .color import get_color_codes
 ###########################
 #### SAVE / REPLAY ####
 ###########################
 
-def create_chat_completion_from_args(args : CompletionArguments) \
-    -> OpenAIChatResponseStream:
-    return create_chat_completion(
-        model=args.model,
-        messages=[{ "role": "user", "content": args.message }],
-        n=args.n_completions,
-        temperature=args.temperature,
-        presence_penalty=args.presence_penalty,
-        frequency_penalty=args.frequency_penalty,
-        max_tokens=args.max_tokens,
-        top_p=args.top_p,
-        stream=True
-    )
+
+# def create_chat_completion_from_args(args : CompletionArguments) \
+#     -> OpenAIChatResponseStream:
+#     return create_chat_completion(
+#         model=args.model,
+#         messages=[{ "role": "user", "content": args.message }],
+#         n=args.n_completions,
+#         temperature=args.temperature,
+#         presence_penalty=args.presence_penalty,
+#         frequency_penalty=args.frequency_penalty,
+#         max_tokens=args.max_tokens,
+#         top_p=args.top_p,
+#         stream=True
+#     )
+
+def create_singleton_chat_completion(
+        message : str,
+        completion_args : CompletionArguments
+    ):
+
+    hist = [ ChatMessage( Role.USER, message ) ]
+
+    completion = create_chat_completion(hist, completion_args)
+
+    return completion
 
 def save_response_and_arguments(args : Arguments) -> None:
 
-    completion = create_chat_completion_from_args(args.completion_args)
+    message = args.initial_message
+
+    completion = create_singleton_chat_completion(message, args.completion_args)
 
     completion = list(completion)
 
     filename = args.debug_args.save_response_to_file
 
     with open(filename, 'wb') as f:
-        pickle.dump((args.completion_args, completion,), f)
+        pickle.dump((message, args.completion_args, completion,), f)
 
 def load_response_and_arguments(args : Arguments) \
     -> Tuple[CompletionArguments, OpenAIChatResponseStream]:
@@ -60,9 +77,9 @@ def load_response_and_arguments(args : Arguments) \
     filename = args.debug_args.load_response_from_file
 
     with open(filename, 'rb') as f:
-        args, completion = pickle.load(f)
+        message, args, completion = pickle.load(f)
 
-    return (args, completion)
+    return (message, args, completion)
 
 #########################
 #### PRETTY PRINTING ####
@@ -70,18 +87,24 @@ def load_response_and_arguments(args : Arguments) \
 
 @dataclass
 class CumulativeResponse:
-    content: str = ""
+    delta_content: str = ""
     finish_reason: FinishReason = FinishReason.NONE
+    content: str = ""
 
-    def take_content(self : "CumulativeResponse"):
-        chunk = self.content
-        self.content = ""
+    def take_delta(self : "CumulativeResponse"):
+        chunk = self.delta_content
+        self.delta_content = ""
         return chunk
 
+    def add_content(self : "CumulativeResponse", new_chunk : str):
+        self.content += new_chunk
+        self.delta_content += new_chunk
+
 def print_streamed_response(
         display_args : DisplayArguments,
         completion : OpenAIChatResponseStream,
-        n_completions : int
+        n_completions : int,
+        return_responses : bool = False
     ) -> None:
     """
     Print the response in real time by printing the deltas as they occur. If multiple responses
@@ -103,7 +126,7 @@ def print_streamed_response(
             delta = choice.delta
 
             if delta.content:
-                cumu_responses[choice.index].content += delta.content
+                cumu_responses[choice.index].add_content(delta.content)
 
             if choice.finish_reason is not FinishReason.NONE:
                 cumu_responses[choice.index].finish_reason = choice.finish_reason
@@ -118,7 +141,7 @@ def print_streamed_response(
                 print(PROMPT, end=' ', flush=True)
 
-            content = display_response.take_content()
+            content = display_response.take_delta()
 
             print(f'{COLOR_CODE.WHITE}{content}{COLOR_CODE.RESET}',
                   sep='', end='', flush=True)
@@ -132,6 +155,9 @@
     else:
         print(end='\n', flush=True)
 
+    if return_responses:
+        return [ cumu_responses[i].content for i in range(n_completions) ]
+
 def cmd_version():
     print(f'version {VERSION}')
 
@@ -142,24 +168,69 @@ def cmd_list_models():
 def cmd_interactive(args : Arguments):
     COLOR_CODE = get_color_codes(no_color = not args.display_args.color)
 
-    print(f'GPT Chat CLI {VERSION}')
-    print(f'[{COLOR_CODE.WHITE}#{COLOR_CODE.RESET}]', end=' ', flush=True)
+    completion_args = args.completion_args
+    display_args = args.display_args
+
+    hist = []
+
+    def print_prompt():
+
+        print(f'[{COLOR_CODE.WHITE}#{COLOR_CODE.RESET}]', end=' ', flush=True)
+
+    def prompt_message() -> bool:
+        print_prompt()
+
+        # Control-D closes the input stream
+        try:
+            message = input()
+        except EOFError:
+            print()
+            return False
+
+        hist.append( ChatMessage( Role.USER, message ) )
+
+        return True
+
+    print(f'GPT Chat CLI version {VERSION}')
+    print(f'Press Control-D to exit')
+
+    if args.initial_message:
+        print_prompt()
+        print( args.initial_message )
+        hist.append( ChatMessage( Role.USER, args.initial_message ) )
+    else:
+        prompt_message()
+
+    while True:
+
+        completion = create_chat_completion(hist, completion_args)
+
+        response = print_streamed_response(
+            display_args, completion, 1, return_responses=True,
+        )[0]
+
+        hist.append( ChatMessage(Role.ASSISTANT, response) )
+
+        if not prompt_message():
+            break
 
 def cmd_singleton(args: Arguments):
     completion_args = args.completion_args
 
     debug_args : DebugArguments = args.debug_args
+    message = args.initial_message
 
     if debug_args.save_response_to_file:
         save_response_and_arguments(args)
         return
     elif debug_args.load_response_from_file:
-        completion_args, completion = load_response_and_arguments(args)
+        message, completion_args, completion = load_response_and_arguments(args)
     else:
-        if completion_args.message is None:
-            completion_args.message = sys.stdin.read()
+        # message is only None if a TTY is not attached
+        if message is None:
+            message = sys.stdin.read()
 
-        completion = create_chat_completion_from_args(completion_args)
+        completion = create_singleton_chat_completion(message, completion_args)
 
     print_streamed_response(
         args.display_args,
diff --git a/src/gpt_chat_cli/openai_wrappers.py b/src/gpt_chat_cli/openai_wrappers.py
index 413ec24..6eeba4d 100644
--- a/src/gpt_chat_cli/openai_wrappers.py
+++ b/src/gpt_chat_cli/openai_wrappers.py
@@ -28,6 +28,24 @@ class Choice:
     finish_reason: Optional[FinishReason]
     index: int
 
+class Role(Enum):
+    SYSTEM = "system"
+    USER = "user"
+    ASSISTANT = "assistant"
+
+@dataclass
+class ChatMessage:
+    role: Role
+    content: str
+
+    def to_json(self : "ChatMessage"):
+        return {
+            "role": self.role.value,
+            "content": self.content
+        }
+
+ChatHistory = List[ChatMessage]
+
 @dataclass
 class OpenAIChatResponse:
     choices: List[Choice]
@@ -61,13 +79,31 @@ class OpenAIChatResponse:
 
 OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None]
 
-def create_chat_completion(*args, **kwargs) \
-    -> OpenAIChatResponseStream:
+from .argparsing import CompletionArguments
+
+def create_chat_completion(hist : ChatHistory, args: CompletionArguments) \
+    -> OpenAIChatResponseStream:
+
+    messages = [ msg.to_json() for msg in hist ]
+
+    response = openai.ChatCompletion.create(
+        model=args.model,
+        messages=messages,
+        n=args.n_completions,
+        temperature=args.temperature,
+        presence_penalty=args.presence_penalty,
+        frequency_penalty=args.frequency_penalty,
+        max_tokens=args.max_tokens,
+        top_p=args.top_p,
+        stream=True
+    )
+
     return (
-        OpenAIChatResponse.from_json(update) \
-        for update in openai.ChatCompletion.create(*args, **kwargs)
+        OpenAIChatResponse.from_json( update ) \
+        for update in response
     )
 
+
 def list_models() -> List[str]:
 
     model_data = openai.Model.list()
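
Usage note: the sketch below shows how the history-based API introduced by this commit can be driven end to end. It is a minimal sketch, not part of the commit; the model name, sampling values, and API-key handling are assumptions.

# Illustrative sketch: build a ChatHistory and stream a completion
# through the refactored create_chat_completion wrapper.
import os
import openai

from gpt_chat_cli.openai_wrappers import (
    create_chat_completion, ChatMessage, Role
)
from gpt_chat_cli.argparsing import CompletionArguments

# The wrappers call the global openai module, so the key must be set
# somewhere; gpt-chat-cli normally does this during argument parsing.
openai.api_key = os.environ["OPENAI_API_KEY"]

# All field values here are assumptions for illustration only.
completion_args = CompletionArguments(
    model="gpt-3.5-turbo",   # assumed model name
    n_completions=1,
    temperature=0.5,         # assumed sampling parameters
    presence_penalty=0.0,
    frequency_penalty=0.0,
    max_tokens=256,
    top_p=1.0,
)

# A conversation is now a list of ChatMessage values rather than a
# single prompt string, so multi-turn context can be replayed.
hist = [ChatMessage(Role.USER, "Hello!")]

# create_chat_completion yields OpenAIChatResponse updates; print each
# delta as it arrives, mirroring what print_streamed_response does.
for update in create_chat_completion(hist, completion_args):
    for choice in update.choices:
        if choice.delta.content:
            print(choice.delta.content, end="", flush=True)
print()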