diff options
Diffstat (limited to 'src/gpt_chat_cli')
| -rw-r--r-- | src/gpt_chat_cli/cmd.py | 7 | ||||
| -rw-r--r-- | src/gpt_chat_cli/prompt.py | 170 | 
2 files changed, 174 insertions, 3 deletions
diff --git a/src/gpt_chat_cli/cmd.py b/src/gpt_chat_cli/cmd.py index 899c705..83cd298 100644 --- a/src/gpt_chat_cli/cmd.py +++ b/src/gpt_chat_cli/cmd.py @@ -297,10 +297,10 @@ def interactive(args : Arguments):      initial_message = resolve_initial_message(args.initial_message, interactive=True) -    if initial_message: -        print( PROMPT, initial_message, sep='', flush=True ) -      with prompter as prompt: +        if initial_message: +            print( prompt.prompt, initial_message, sep='', flush=True ) +          while True:              try:                  if initial_message: @@ -339,6 +339,7 @@ def interactive(args : Arguments):                  hist.append( ChatMessage(Role.ASSISTANT, response) )              except KeyboardInterrupt: # Skip to next prompt +                print()                  continue              except EOFError: # Exit on Control-D                  print() diff --git a/src/gpt_chat_cli/prompt.py b/src/gpt_chat_cli/prompt.py new file mode 100644 index 0000000..ec39b4b --- /dev/null +++ b/src/gpt_chat_cli/prompt.py @@ -0,0 +1,170 @@ +from typing import List, Dict, Tuple +from dataclasses import dataclass + +import os +import sys +import textwrap + +from .color import ( +    get_color_codes, +    surround_ansi_escapes, +    ColorCode +) + +@dataclass +class Command( object ): +    name : str +    description : str + +    # eventually could contain arguments + +PromptResponse = Tuple[Command, None] | Tuple[None, str] + +class PromptUsageException(Exception): +    def __init__( self : "PromptUsageException", message : str): +        self.message = message + +class Prompt( object ): + +    COMMAND_INDICATOR = "/" +    SEP = ' ' + +    def __init__(self : "Prompt", prompt : str, cmds : List[Command]): +        self.prompt = prompt +        self._cmds = { cmd.name : cmd for cmd in cmds } + +    def print_help( self : "Prompt", file=sys.stdout ): +        cmds = list(self._cmds.values()) +        cmds.sort(key=lambda cmd: cmd.name) + +        COMMAND_HEADER = "Command" +        DESCRIPTION_HEADER = "Description" + +        MAX_COLUMN_LEN = 80 +        SPACING = 2 +        SEP = SPACING * ' ' + +        # Calculate column widths +        name_width = max(len(cmd.name) for cmd in cmds) +        name_width = max(len(COMMAND_HEADER), name_width) + +        desc_width = max(len(cmd.description) for cmd in cmds) +        desc_width = max(len(DESCRIPTION_HEADER), desc_width) +        desc_width = min(desc_width, MAX_COLUMN_LEN - SPACING - name_width) + +        # Print headers +        print(f"{'Command':<{name_width}}{SEP}{'Description':<{desc_width}}", file=file) +        print("-" * (name_width + desc_width + SPACING), file=file) + +        # Print rows +        for cmd in cmds: +            name = cmd.name.ljust(name_width) + +            desc_lines = textwrap.wrap(cmd.description, desc_width) +            desc = '' +            for i, desc_line in enumerate(desc_lines): +                if i == 0: +                    desc += desc_line + '\n' +                else: +                    desc += ' ' * (SPACING + name_width) +                    desc += desc_line + '\n' + +            print(f"{name}{SEP}{desc}", file=file, end='') + +        print() + +    def _parse_response( self : "Prompt", response : str) \ +            -> PromptResponse: + +        if not response.startswith(Prompt.COMMAND_INDICATOR): +            return (None, response) + +        cmd_parts = response[1:].split(sep=Prompt.SEP) +        cmd_name = cmd_parts[0].strip() + +        if cmd_name not in self._cmds: +            raise PromptUsageException(f'command not found: {cmd_name}') + +        cmd = self._cmds[cmd_name] + +        return (cmd, None) + +    def input( self : "Prompt" ): +        while True: +            try: +                raw_input = input(self.prompt) +                response = self._parse_response( raw_input ) +                return response +            except PromptUsageException as e: +                print(f'error: {e.message}', file=sys.stderr) + +class Prompter( object ): + +    _cmds : Dict[str, Command] + +    def __init__( self : "Prompter", no_color=False ): +        self._cmds = [] +        self._COLOR = get_color_codes( no_color=no_color ) +        self._PROMPT = surround_ansi_escapes( +            f'[{self._COLOR.WHITE}#{self._COLOR.RESET}] ' +        ) +        self._prev_completer_delims = None + +    def add_command( self : "Prompter", name : str, description : str ): +        cmd = Command(name, description) +        self._cmds.append(cmd) +        return cmd + +    def _enable_prompting( self : "Prompter" ): + +        try: +            import readline as rl + +            def completer( text, state ): + +                # stop completions when slash commands +                # are not specified +                buffer = rl.get_line_buffer() + +                if len(buffer) == 0 or buffer[0] != '/': +                    return None + +                completion_options = [ cmd.name for cmd in self._cmds ] + +                options = [ i for i in completion_options if i.startswith(text) ] + +                if state < len(options): +                    return options[state] +                else: +                    return None + +            rl.set_completer(completer) +            self._prev_completer_delims = rl.get_completer_delims() +            rl.set_completer_delims('/') +            rl.parse_and_bind('tab: complete') + +        except ImportError: +            pass + +    def _disable_prompting(self : "Prompter"): + +        # try our best to clean up +        try: +            import readline as rl + +            rl.set_completer(None) +            rl.parse_and_bind('tab: insert-tab') +            rl.set_completer_delims(self._prev_completer_delims) + +        except ImportError: +            pass + +    def __enter__(self : "Prompter"): +        self._enable_prompting() +        return Prompt(self._PROMPT, self._cmds) + +    def __exit__(self : "Prompter", exec_type, exec_val, exc_tb): +        self._disable_prompting() + + +  | 
