Diffstat (limited to 'src')
-rw-r--r--   src/gpt_chat_cli/__init__.py        |   0
-rw-r--r--   src/gpt_chat_cli/argparsing.py      | 193
-rw-r--r--   src/gpt_chat_cli/color.py           |  92
-rw-r--r--   src/gpt_chat_cli/gcli.py            | 142
-rw-r--r--   src/gpt_chat_cli/openai_wrappers.py |  69
5 files changed, 496 insertions(+), 0 deletions(-)
diff --git a/src/gpt_chat_cli/__init__.py b/src/gpt_chat_cli/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/gpt_chat_cli/__init__.py
diff --git a/src/gpt_chat_cli/argparsing.py b/src/gpt_chat_cli/argparsing.py
new file mode 100644
index 0000000..a7d3218
--- /dev/null
+++ b/src/gpt_chat_cli/argparsing.py
@@ -0,0 +1,193 @@
+import argparse
+import os
+import logging
+import openai
+import sys
+from enum import Enum
+
+class AutoDetectedOption(Enum):
+    ON = 'on'
+    OFF = 'off'
+    AUTO = 'auto'
+
+def die_validation_err(err : str):
+    print(err, file=sys.stderr)
+    sys.exit(1)
+
+def validate_args(args: argparse.Namespace) -> None:
+    if not 0 <= args.temperature <= 2:
+        die_validation_err("Temperature must be between 0 and 2.")
+
+    if not -2 <= args.frequency_penalty <= 2:
+        die_validation_err("Frequency penalty must be between -2.0 and 2.0.")
+
+    if not -2 <= args.presence_penalty <= 2:
+        die_validation_err("Presence penalty must be between -2.0 and 2.0.")
+
+    if args.max_tokens < 1:
+        die_validation_err("Max tokens must be greater than or equal to 1.")
+
+    if not 0 <= args.top_p <= 1:
+        die_validation_err("Top_p must be between 0 and 1.")
+
+    if args.n_completions < 1:
+        die_validation_err("Number of completions must be greater than or equal to 1.")
+
+def parse_args():
+
+    GCLI_ENV_PREFIX = "GCLI_"
+
+    debug = os.getenv(f'{GCLI_ENV_PREFIX}DEBUG') is not None
+
+    if debug:
+        logging.warning("Debugging mode and unstable features have been enabled.")
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-m",
+        "--model",
+        default=os.getenv(f'{GCLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"),
+        help="ID of the model to use",
+    )
+
+    parser.add_argument(
+        "-t",
+        "--temperature",
+        type=float,
+        default=os.getenv(f'{GCLI_ENV_PREFIX}TEMPERATURE', 0.5),
+        help=(
+            "What sampling temperature to use, between 0 and 2. Higher values "
+            "like 0.8 will make the output more random, while lower values "
+            "like 0.2 will make it more focused and deterministic."
+        ),
+    )
+
+    parser.add_argument(
+        "-f",
+        "--frequency-penalty",
+        type=float,
+        default=os.getenv(f'{GCLI_ENV_PREFIX}FREQUENCY_PENALTY', 0),
+        help=(
+            "Number between -2.0 and 2.0. Positive values penalize new tokens based "
+            "on their existing frequency in the text so far, decreasing the model's "
+            "likelihood to repeat the same line verbatim."
+        ),
+    )
+
+    parser.add_argument(
+        "-p",
+        "--presence-penalty",
+        type=float,
+        default=os.getenv(f'{GCLI_ENV_PREFIX}PRESENCE_PENALTY', 0),
+        help=(
+            "Number between -2.0 and 2.0. Positive values penalize new tokens based "
+            "on whether they appear in the text so far, increasing the model's "
+            "likelihood to talk about new topics."
+        ),
+    )
+
+    parser.add_argument(
+        "-k",
+        "--max-tokens",
+        type=int,
+        default=os.getenv(f'{GCLI_ENV_PREFIX}MAX_TOKENS', 2048),
+        help=(
+            "The maximum number of tokens to generate in the chat completion. "
+            "Defaults to 2048."
+        ),
+    )
+
+    parser.add_argument(
+        "-s",
+        "--top-p",
+        type=float,
+        default=os.getenv(f'{GCLI_ENV_PREFIX}TOP_P', 1),
+        help=(
+            "An alternative to sampling with temperature, called nucleus sampling, "
+            "where the model considers the results of the tokens with top_p "
+            "probability mass. So 0.1 means only the tokens comprising the top 10%% "
+            "probability mass are considered."
+        ),
+    )
+
+    parser.add_argument(
+        "-n",
+        "--n-completions",
+        type=int,
+        default=os.getenv(f'{GCLI_ENV_PREFIX}N_COMPLETIONS', 1),
+        help="How many chat completion choices to generate for each input message.",
+    )
+
+    parser.add_argument(
+        "--adornments",
+        type=AutoDetectedOption,
+        choices=list(AutoDetectedOption),
+        default=AutoDetectedOption.AUTO,
+        help=(
+            "Show adornments to indicate the model and response."
+            " Can be set to 'on', 'off', or 'auto'."
+        )
+    )
+
+    parser.add_argument(
+        "--color",
+        type=AutoDetectedOption,
+        choices=list(AutoDetectedOption),
+        default=AutoDetectedOption.AUTO,
+        help="Set color to 'on', 'off', or 'auto'.",
+    )
+
+    parser.add_argument(
+        "message",
+        type=str,
+        help=(
+            "The contents of the message. When used in chat mode, this is the initial "
+            "message if provided."
+        ),
+    )
+
+    if debug:
+        group = parser.add_mutually_exclusive_group()
+
+        group.add_argument(
+            '--save-response-to-file',
+            type=str,
+            help="UNSTABLE: save the response to a file so it can be replayed for debugging purposes",
+        )
+
+        group.add_argument(
+            '--load-response-from-file',
+            type=str,
+            help="UNSTABLE: load a response from a file to replay it for debugging purposes",
+        )
+
+    openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY"))
+    if not openai_key:
+        print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr)
+        print("The OpenAI API uses API keys for authentication. Visit the API keys page (https://platform.openai.com/account/api-keys) to retrieve the API key you'll use in your requests.", file=sys.stderr)
+        sys.exit(1)
+
+    openai.api_key = openai_key
+
+    args = parser.parse_args()
+
+    if debug and args.load_response_from_file:
+        logging.warning(f'Ignoring the provided arguments in favor of those used when the response in {args.load_response_from_file} was generated')
+
+    if args.color == AutoDetectedOption.AUTO:
+        if os.getenv("NO_COLOR"):
+            args.color = AutoDetectedOption.OFF
+        else:
+            args.color = AutoDetectedOption.ON
+
+    if args.adornments == AutoDetectedOption.AUTO:
+        args.adornments = AutoDetectedOption.ON
+
+    if not debug:
+        args.load_response_from_file = None
+        args.save_response_to_file = None
+
+    validate_args(args)
+
+    return args
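+
+# Default resolution for the options above: an explicit command-line flag wins,
+# otherwise the matching GCLI_* environment variable is used, otherwise the
+# built-in default applies. For example, with a hypothetical console-script
+# entry point named gpt-chat-cli:
+#     GCLI_MODEL=gpt-4 GCLI_TEMPERATURE=0.2 gpt-chat-cli "Hello"
+# behaves like:
+#     gpt-chat-cli -m gpt-4 -t 0.2 "Hello"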
diff --git a/src/gpt_chat_cli/color.py b/src/gpt_chat_cli/color.py
new file mode 100644
index 0000000..ce1b182
--- /dev/null
+++ b/src/gpt_chat_cli/color.py
@@ -0,0 +1,92 @@
+from typing import ClassVar, Type
+
+class ColorCode:
+    """A superclass to signal that color codes are strings"""
+
+    BLACK: ClassVar[str]
+    RED: ClassVar[str]
+    GREEN: ClassVar[str]
+    YELLOW: ClassVar[str]
+    BLUE: ClassVar[str]
+    MAGENTA: ClassVar[str]
+    CYAN: ClassVar[str]
+    WHITE: ClassVar[str]
+    RESET: ClassVar[str]
+
+    BLACK_BG: ClassVar[str]
+    RED_BG: ClassVar[str]
+    GREEN_BG: ClassVar[str]
+    YELLOW_BG: ClassVar[str]
+    BLUE_BG: ClassVar[str]
+    MAGENTA_BG: ClassVar[str]
+    CYAN_BG: ClassVar[str]
+    WHITE_BG: ClassVar[str]
+
+    BOLD: ClassVar[str]
+    UNDERLINE: ClassVar[str]
+    BLINK: ClassVar[str]
+
+
+class VT100ColorCode(ColorCode):
+    """A class containing VT100 color codes"""
+
+    # Define the color codes
+    BLACK = '\033[30m'
+    RED = '\033[31m'
+    GREEN = '\033[32m'
+    YELLOW = '\033[33m'
+    BLUE = '\033[34m'
+    MAGENTA = '\033[35m'
+    CYAN = '\033[36m'
+    WHITE = '\033[37m'
+    RESET = '\033[0m'
+
+    # Define the background color codes
+    BLACK_BG = '\033[40m'
+    RED_BG = '\033[41m'
+    GREEN_BG = '\033[42m'
+    YELLOW_BG = '\033[43m'
+    BLUE_BG = '\033[44m'
+    MAGENTA_BG = '\033[45m'
+    CYAN_BG = '\033[46m'
+    WHITE_BG = '\033[47m'
+
+    # Define the bold, underline and blink codes
+    BOLD = '\033[1m'
+    UNDERLINE = '\033[4m'
+    BLINK = '\033[5m'
+
+class NoColorColorCode(ColorCode):
+    """A class nullifying color codes to disable color"""
+
+    # Define the color codes
+    BLACK = ''
+    RED = ''
+    GREEN = ''
+    YELLOW = ''
+    BLUE = ''
+    MAGENTA = ''
+    CYAN = ''
+    WHITE = ''
+    RESET = ''
+
+    # Define the background color codes
+    BLACK_BG = ''
+    RED_BG = ''
+    GREEN_BG = ''
+    YELLOW_BG = ''
+    BLUE_BG = ''
+    MAGENTA_BG = ''
+    CYAN_BG = ''
+    WHITE_BG = ''
+
+    # Define the bold, underline and blink codes
+    BOLD = ''
+    UNDERLINE = ''
+    BLINK = ''
+
+def get_color_codes(no_color=False) -> Type[ColorCode]:
+    if no_color:
+        return NoColorColorCode
+    else:
+        return VT100ColorCode
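+
+# A minimal usage sketch: fetch a code table once and interpolate its attributes
+# into output strings. With no_color=True every attribute is the empty string,
+# so the same format strings degrade to plain text:
+#     codes = get_color_codes(no_color=True)
+#     print(f"{codes.GREEN}ok{codes.RESET}")   # prints just "ok"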
diff --git a/src/gpt_chat_cli/gcli.py b/src/gpt_chat_cli/gcli.py
new file mode 100644
index 0000000..ded6d6c
--- /dev/null
+++ b/src/gpt_chat_cli/gcli.py
@@ -0,0 +1,142 @@
+#!/usr/bin/env python3
+
+import argparse
+import sys
+import openai
+import pickle
+
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Tuple
+
+from .openai_wrappers import (
+    create_chat_completion,
+    OpenAIChatResponse,
+    OpenAIChatResponseStream,
+    FinishReason,
+)
+
+from .argparsing import (
+    parse_args,
+    AutoDetectedOption,
+)
+
+from .color import get_color_codes
+
+###########################
+####   SAVE / REPLAY   ####
+###########################
+
+def create_chat_completion_from_args(args : argparse.Namespace) \
+        -> OpenAIChatResponseStream:
+    return create_chat_completion(
+        model=args.model,
+        messages=[{ "role": "user", "content": args.message }],
+        n=args.n_completions,
+        temperature=args.temperature,
+        presence_penalty=args.presence_penalty,
+        frequency_penalty=args.frequency_penalty,
+        max_tokens=args.max_tokens,
+        top_p=args.top_p,
+        stream=True
+    )
+
+def save_response_and_arguments(args : argparse.Namespace) -> None:
+    completion = create_chat_completion_from_args(args)
+    completion = list(completion)
+
+    filename = args.save_response_to_file
+
+    with open(filename, 'wb') as f:
+        pickle.dump((args, completion,), f)
+
+def load_response_and_arguments(args : argparse.Namespace) \
+        -> Tuple[argparse.Namespace, OpenAIChatResponseStream]:
+
+    filename = args.load_response_from_file
+
+    with open(filename, 'rb') as f:
+        args, completion = pickle.load(f)
+
+    return (args, completion)
+
+#########################
+#### PRETTY PRINTING ####
+#########################
+
+@dataclass
+class CumulativeResponse:
+    content: str = ""
+    finish_reason: FinishReason = FinishReason.NONE
+
+    def take_content(self : "CumulativeResponse"):
+        chunk = self.content
+        self.content = ""
+        return chunk
+
+def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatResponseStream):
+    """
+    Print the response in real time by printing the deltas as they occur. If multiple responses
+    are requested, print the first in real time, accumulating the others in the background. Once the
+    first response completes, move on to the second response, printing its deltas in real time.
+    Continue until all responses have been printed.
+    """
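+    # Illustration: with n_completions == 2, deltas for choice index 1 are
+    # buffered in cumu_responses[1] while choice 0 streams to the terminal.
+    # Once choice 0 reports a finish_reason, display_idx advances and the
+    # buffered content for choice 1 is flushed before its remaining deltas
+    # stream live.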
+
+    COLOR_CODE = get_color_codes(no_color = args.color == AutoDetectedOption.OFF)
+    ADORNMENTS = args.adornments == AutoDetectedOption.ON
+    N_COMPLETIONS = args.n_completions
+
+    cumu_responses = defaultdict(CumulativeResponse)
+    display_idx = 0
+    prompt_printed = False
+
+    for update in completion:
+
+        for choice in update.choices:
+            delta = choice.delta
+
+            if delta.content:
+                cumu_responses[choice.index].content += delta.content
+
+            if choice.finish_reason is not FinishReason.NONE:
+                cumu_responses[choice.index].finish_reason = choice.finish_reason
+
+        display_response = cumu_responses[display_idx]
+
+        if not prompt_printed and ADORNMENTS:
+            res_indicator = '' if N_COMPLETIONS == 1 else \
+                    f' {display_idx + 1}/{N_COMPLETIONS}'
+            PROMPT = f'[{COLOR_CODE.GREEN}{update.model}{COLOR_CODE.RESET}{COLOR_CODE.RED}{res_indicator}{COLOR_CODE.RESET}]'
+            prompt_printed = True
+            print(PROMPT, end=' ', flush=True)
+
+        content = display_response.take_content()
+        print(f'{COLOR_CODE.WHITE}{content}{COLOR_CODE.RESET}',
+              sep='', end='', flush=True)
+
+        if display_response.finish_reason is not FinishReason.NONE:
+            if display_idx < N_COMPLETIONS:
+                display_idx += 1
+                prompt_printed = False
+
+                if ADORNMENTS:
+                    print(end='\n\n', flush=True)
+                else:
+                    print(end='\n', flush=True)
+
+def main():
+    args = parse_args()
+
+    if args.save_response_to_file:
+        save_response_and_arguments(args)
+        return
+    elif args.load_response_from_file:
+        args, completion = load_response_and_arguments(args)
+    else:
+        completion = create_chat_completion_from_args(args)
+
+    print_streamed_response(args, completion)
+
+if __name__ == "__main__":
+    main()
diff --git a/src/gpt_chat_cli/openai_wrappers.py b/src/gpt_chat_cli/openai_wrappers.py
new file mode 100644
index 0000000..784a9ce
--- /dev/null
+++ b/src/gpt_chat_cli/openai_wrappers.py
@@ -0,0 +1,69 @@
+import json
+import openai
+
+from typing import Any, List, Optional, Generator
+from dataclasses import dataclass
+from enum import Enum, auto
+
+@dataclass
+class Delta:
+    content: Optional[str] = None
+    role: Optional[str] = None
+
+class FinishReason(Enum):
+    STOP = auto()
+    MAX_TOKENS = auto()
+    TEMPERATURE = auto()
+    NONE = auto()
+
+    @staticmethod
+    def from_str(finish_reason_str : Optional[str]) -> "FinishReason":
+        if finish_reason_str is None:
+            return FinishReason.NONE
+        # The API reports "length" when the completion stopped at the token limit
+        if finish_reason_str == "length":
+            return FinishReason.MAX_TOKENS
+        return FinishReason[finish_reason_str.upper()]
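+    # e.g. FinishReason.from_str("stop") is FinishReason.STOP,
+    #      FinishReason.from_str(None) is FinishReason.NONE, and
+    #      FinishReason.from_str("length") is FinishReason.MAX_TOKENS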
+
+@dataclass
+class Choice:
+    delta: Delta
+    finish_reason: Optional[FinishReason]
+    index: int
+
+@dataclass
+class OpenAIChatResponse:
+    choices: List[Choice]
+    created: int
+    id: str
+    model: str
+    object: str
+
+    @staticmethod
+    def from_json(data: Any) -> "OpenAIChatResponse":
+        choices = []
+
+        for choice in data["choices"]:
+            delta = Delta(
+                content=choice["delta"].get("content"),
+                role=choice["delta"].get("role")
+            )
+
+            choices.append(Choice(
+                delta=delta,
+                finish_reason=FinishReason.from_str(choice["finish_reason"]),
+                index=choice["index"],
+            ))
+
+        return OpenAIChatResponse(
+            choices,
+            created=data["created"],
+            id=data["id"],
+            model=data["model"],
+            object=data["object"],
+        )
+
+OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None]
+
+def create_chat_completion(*args, **kwargs) \
+        -> OpenAIChatResponseStream:
+    return (
+        OpenAIChatResponse.from_json(update)
+        for update in openai.ChatCompletion.create(*args, **kwargs)
+    )
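+
+# A minimal usage sketch (assumes openai.api_key has already been set, as
+# parse_args() does):
+#
+#     stream = create_chat_completion(
+#         model="gpt-3.5-turbo",
+#         messages=[{"role": "user", "content": "Hello"}],
+#         stream=True,
+#     )
+#     for update in stream:
+#         for choice in update.choices:
+#             ...  # choice.delta.content holds the incremental text, if any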