From 89c77af2d93c8ed3f3f452fff709d91f9228f9dc Mon Sep 17 00:00:00 2001 From: flu0r1ne Date: Thu, 11 May 2023 02:29:52 -0500 Subject: Fix printing when provided initial messages --- src/gpt_chat_cli/cmd.py | 7 +- src/gpt_chat_cli/prompt.py | 170 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 src/gpt_chat_cli/prompt.py (limited to 'src') diff --git a/src/gpt_chat_cli/cmd.py b/src/gpt_chat_cli/cmd.py index 899c705..83cd298 100644 --- a/src/gpt_chat_cli/cmd.py +++ b/src/gpt_chat_cli/cmd.py @@ -297,10 +297,10 @@ def interactive(args : Arguments): initial_message = resolve_initial_message(args.initial_message, interactive=True) - if initial_message: - print( PROMPT, initial_message, sep='', flush=True ) - with prompter as prompt: + if initial_message: + print( prompt.prompt, initial_message, sep='', flush=True ) + while True: try: if initial_message: @@ -339,6 +339,7 @@ def interactive(args : Arguments): hist.append( ChatMessage(Role.ASSISTANT, response) ) except KeyboardInterrupt: # Skip to next prompt + print() continue except EOFError: # Exit on Control-D print() diff --git a/src/gpt_chat_cli/prompt.py b/src/gpt_chat_cli/prompt.py new file mode 100644 index 0000000..ec39b4b --- /dev/null +++ b/src/gpt_chat_cli/prompt.py @@ -0,0 +1,170 @@ +from typing import List, Dict, Tuple +from dataclasses import dataclass + +import os +import sys +import textwrap + +from .color import ( + get_color_codes, + surround_ansi_escapes, + ColorCode +) + +@dataclass +class Command( object ): + name : str + description : str + + # eventually could contain arguments + +PromptResponse = Tuple[Command, None] | Tuple[None, str] + +class PromptUsageException(Exception): + def __init__( self : "PromptUsageException", message : str): + self.message = message + +class Prompt( object ): + + COMMAND_INDICATOR = "/" + SEP = ' ' + + def __init__(self : "Prompt", prompt : str, cmds : List[Command]): + self.prompt = prompt + self._cmds = { cmd.name : cmd for cmd in cmds } + + def print_help( self : "Prompt", file=sys.stdout ): + cmds = list(self._cmds.values()) + cmds.sort(key=lambda cmd: cmd.name) + + COMMAND_HEADER = "Command" + DESCRIPTION_HEADER = "Description" + + MAX_COLUMN_LEN = 80 + SPACING = 2 + SEP = SPACING * ' ' + + # Calculate column widths + name_width = max(len(cmd.name) for cmd in cmds) + name_width = max(len(COMMAND_HEADER), name_width) + + desc_width = max(len(cmd.description) for cmd in cmds) + desc_width = max(len(DESCRIPTION_HEADER), desc_width) + desc_width = min(desc_width, MAX_COLUMN_LEN - SPACING - name_width) + + # Print headers + print(f"{'Command':<{name_width}}{SEP}{'Description':<{desc_width}}", file=file) + print("-" * (name_width + desc_width + SPACING), file=file) + + # Print rows + for cmd in cmds: + name = cmd.name.ljust(name_width) + + desc_lines = textwrap.wrap(cmd.description, desc_width) + desc = '' + for i, desc_line in enumerate(desc_lines): + if i == 0: + desc += desc_line + '\n' + else: + desc += ' ' * (SPACING + name_width) + desc += desc_line + '\n' + + print(f"{name}{SEP}{desc}", file=file, end='') + + print() + + def _parse_response( self : "Prompt", response : str) \ + -> PromptResponse: + + if not response.startswith(Prompt.COMMAND_INDICATOR): + return (None, response) + + cmd_parts = response[1:].split(sep=Prompt.SEP) + cmd_name = cmd_parts[0].strip() + + if cmd_name not in self._cmds: + raise PromptUsageException(f'command not found: {cmd_name}') + + cmd = self._cmds[cmd_name] + + return (cmd, None) + + def input( self : "Prompt" ): + while True: + try: + raw_input = input(self.prompt) + response = self._parse_response( raw_input ) + return response + except PromptUsageException as e: + print(f'error: {e.message}', file=sys.stderr) + +class Prompter( object ): + + _cmds : Dict[str, Command] + + def __init__( self : "Prompter", no_color=False ): + self._cmds = [] + self._COLOR = get_color_codes( no_color=no_color ) + self._PROMPT = surround_ansi_escapes( + f'[{self._COLOR.WHITE}#{self._COLOR.RESET}] ' + ) + self._prev_completer_delims = None + + def add_command( self : "Prompter", name : str, description : str ): + cmd = Command(name, description) + self._cmds.append(cmd) + return cmd + + def _enable_prompting( self : "Prompter" ): + + try: + import readline as rl + + def completer( text, state ): + + # stop completions when slash commands + # are not specified + buffer = rl.get_line_buffer() + + if len(buffer) == 0 or buffer[0] != '/': + return None + + completion_options = [ cmd.name for cmd in self._cmds ] + + options = [ i for i in completion_options if i.startswith(text) ] + + if state < len(options): + return options[state] + else: + return None + + rl.set_completer(completer) + self._prev_completer_delims = rl.get_completer_delims() + rl.set_completer_delims('/') + rl.parse_and_bind('tab: complete') + + except ImportError: + pass + + def _disable_prompting(self : "Prompter"): + + # try our best to clean up + try: + import readline as rl + + rl.set_completer(None) + rl.parse_and_bind('tab: insert-tab') + rl.set_completer_delims(self._prev_completer_delims) + + except ImportError: + pass + + def __enter__(self : "Prompter"): + self._enable_prompting() + return Prompt(self._PROMPT, self._cmds) + + def __exit__(self : "Prompter", exec_type, exec_val, exc_tb): + self._disable_prompting() + + + -- cgit v1.2.3