author    | flu0r1ne <flu0r1ne@flu0r1ne.net> | 2023-05-04 18:10:32 -0500
committer | flu0r1ne <flu0r1ne@flu0r1ne.net> | 2023-05-04 18:10:32 -0500
commit    | 44b15a775e3f5370c81507534df103f63611b956 (patch)
tree      | 6cefb2252ce10f922e21438d1cc9747c0ce67ec2
download  | gpt-chat-cli-44b15a775e3f5370c81507534df103f63611b956.tar.xz
          | gpt-chat-cli-44b15a775e3f5370c81507534df103f63611b956.zip
Add version 0.0.1, which lacks interactive features but streams colorized output
-rw-r--r-- | argparsing.py      | 193
-rw-r--r-- | color.py           |  93
-rw-r--r-- | gcli.py            | 139
-rw-r--r-- | openai_wrappers.py |  68

4 files changed, 493 insertions, 0 deletions
diff --git a/argparsing.py b/argparsing.py
new file mode 100644
index 0000000..a7d3218
--- /dev/null
+++ b/argparsing.py
@@ -0,0 +1,193 @@
+import argparse
+import os
+import logging
+import openai
+import sys
+from enum import Enum
+
+class AutoDetectedOption(Enum):
+    ON = 'on'
+    OFF = 'off'
+    AUTO = 'auto'
+
+def die_validation_err(err : str):
+    print(err, file=sys.stderr)
+    sys.exit(1)
+
+def validate_args(args: argparse.Namespace) -> None:
+    if not 0 <= args.temperature <= 2:
+        die_validation_err("Temperature must be between 0 and 2.")
+
+    if not -2 <= args.frequency_penalty <= 2:
+        die_validation_err("Frequency penalty must be between -2.0 and 2.0.")
+
+    if not -2 <= args.presence_penalty <= 2:
+        die_validation_err("Presence penalty must be between -2.0 and 2.0.")
+
+    if args.max_tokens < 1:
+        die_validation_err("Max tokens must be greater than or equal to 1.")
+
+    if not 0 <= args.top_p <= 1:
+        die_validation_err("Top_p must be between 0 and 1.")
+
+    if args.n_completions < 1:
+        die_validation_err("Number of completions must be greater than or equal to 1.")
+
+def parse_args():
+
+    GCLI_ENV_PREFIX = "GCLI_"
+
+    debug = os.getenv(f'{GCLI_ENV_PREFIX}DEBUG') is not None
+
+    if debug:
+        logging.warning("Debugging mode and unstable features have been enabled.")
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-m",
+        "--model",
+        default=os.getenv(f'{GCLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"),
+        help="ID of the model to use",
+    )
+
+    parser.add_argument(
+        "-t",
+        "--temperature",
+        type=float,
+        default=os.getenv(f'{GCLI_ENV_PREFIX}TEMPERATURE', 0.5),
+        help=(
+            "What sampling temperature to use, between 0 and 2. Higher values "
+            "like 0.8 will make the output more random, while lower values "
+            "like 0.2 will make it more focused and deterministic."
+        ),
+    )
+
+    parser.add_argument(
+        "-f",
+        "--frequency-penalty",
+        type=float,
+        default=os.getenv(f'{GCLI_ENV_PREFIX}FREQUENCY_PENALTY', 0),
+        help=(
+            "Number between -2.0 and 2.0. Positive values penalize new tokens based "
+            "on their existing frequency in the text so far, decreasing the model's "
+            "likelihood to repeat the same line verbatim."
+        ),
+    )
+
+    parser.add_argument(
+        "-p",
+        "--presence-penalty",
+        type=float,
+        default=os.getenv(f'{GCLI_ENV_PREFIX}PRESENCE_PENALTY', 0),
+        help=(
+            "Number between -2.0 and 2.0. Positive values penalize new tokens based "
+            "on whether they appear in the text so far, increasing the model's "
+            "likelihood to talk about new topics."
+        ),
+    )
+
+    parser.add_argument(
+        "-k",
+        "--max-tokens",
+        type=int,
+        default=os.getenv(f'{GCLI_ENV_PREFIX}MAX_TOKENS', 2048),
+        help=(
+            "The maximum number of tokens to generate in the chat completion. "
+            "Defaults to 2048."
+        ),
+    )
+
+    parser.add_argument(
+        "-s",
+        "--top-p",
+        type=float,
+        default=os.getenv(f'{GCLI_ENV_PREFIX}TOP_P', 1),
+        help=(
+            "An alternative to sampling with temperature, called nucleus sampling, "
+            "where the model considers the results of the tokens with top_p "
+            "probability mass. So 0.1 means only the tokens comprising the top 10%% "
+            "probability mass are considered."
+        ),
+    )
+
+    parser.add_argument(
+        "-n",
+        "--n-completions",
+        type=int,
+        default=os.getenv(f'{GCLI_ENV_PREFIX}N_COMPLETIONS', 1),
+        help="How many chat completion choices to generate for each input message.",
+    )
+
+    parser.add_argument(
+        "--adornments",
+        type=AutoDetectedOption,
+        choices=list(AutoDetectedOption),
+        default=AutoDetectedOption.AUTO,
+        help=(
+            "Show adornments to indicate the model and response."
+            " Can be set to 'on', 'off', or 'auto'."
+        ),
+    )
+
+    parser.add_argument(
+        "--color",
+        type=AutoDetectedOption,
+        choices=list(AutoDetectedOption),
+        default=AutoDetectedOption.AUTO,
+        help="Set color to 'on', 'off', or 'auto'.",
+    )
+
+    parser.add_argument(
+        "message",
+        type=str,
+        help=(
+            "The contents of the message. When used in chat mode, this is the initial "
+            "message if provided."
+        ),
+    )
+
+    if debug:
+        group = parser.add_mutually_exclusive_group()
+
+        group.add_argument(
+            '--save-response-to-file',
+            type=str,
+            help="UNSTABLE: save the response to a file so it can be replayed for debugging purposes",
+        )
+
+        group.add_argument(
+            '--load-response-from-file',
+            type=str,
+            help="UNSTABLE: load a response from a file, replaying it for debugging purposes",
+        )
+
+    openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY"))
+    if not openai_key:
+        print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr)
+        print("The OpenAI API uses API keys for authentication. Visit your [API Keys page](https://platform.openai.com/account/api-keys) to retrieve the API key you'll use in your requests.", file=sys.stderr)
+        sys.exit(1)
+
+    openai.api_key = openai_key
+
+    args = parser.parse_args()
+
+    if debug and args.load_response_from_file:
+        logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {args.load_response_from_file} was generated')
+
+    if args.color == AutoDetectedOption.AUTO:
+        if os.getenv("NO_COLOR"):
+            args.color = AutoDetectedOption.OFF
+        else:
+            args.color = AutoDetectedOption.ON
+
+    if args.adornments == AutoDetectedOption.AUTO:
+        args.adornments = AutoDetectedOption.ON
+
+    if not debug:
+        args.load_response_from_file = None
+        args.save_response_to_file = None
+
+    validate_args(args)
+
+    return args
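One subtlety in the pattern above, worth noting for anyone extending it: argparse applies the `type` callable to defaults given as strings, so values pulled from the `GCLI_*` environment variables are coerced exactly like command-line input. A minimal standalone sketch (not part of this commit) demonstrating the behavior:

```python
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument(
    "-t", "--temperature",
    type=float,
    # If GCLI_TEMPERATURE is set, os.getenv returns a string and
    # argparse runs it through float(); otherwise the float 0.5
    # default is used directly.
    default=os.getenv("GCLI_TEMPERATURE", 0.5),
)

args = parser.parse_args([])
print(type(args.temperature), args.temperature)  # <class 'float'> in both cases
```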
+ " Can be set to 'on', 'off', or 'auto'." + ) + ) + + parser.add_argument( + "--color", + type=AutoDetectedOption, + choices=list(AutoDetectedOption), + default=AutoDetectedOption.AUTO, + help="Set color to 'on', 'off', or 'auto'.", + ) + + parser.add_argument( + "message", + type=str, + help=( + "The contents of the message. When used in chat mode, this is the initial " + "message if provided." + ), + ) + + if debug: + group = parser.add_mutually_exclusive_group() + + group.add_argument( + '--save-response-to-file', + type=str, + help="UNSTABLE: save the response to a file. This can reply a response for debugging purposes", + ) + + group.add_argument( + '--load-response-from-file', + type=str, + help="UNSTABLE: load a response from a file. This can reply a response for debugging purposes", + ) + + openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY")) + if not openai_key: + print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr) + print("The OpenAI API uses API keys for authentication. Visit your (API Keys page)[https://platform.openai.com/account/api-keys] to retrieve the API key you'll use in your requests.", file=sys.stderr) + sys.exit(1) + + openai.api_key = openai_key + + args = parser.parse_args() + + if debug and args.load_response_from_file: + logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {args.load_response_from_file} was generated') + + if args.color == AutoDetectedOption.AUTO: + if os.getenv("NO_COLOR"): + args.color = AutoDetectedOption.OFF + else: + args.color = AutoDetectedOption.ON + + if args.adornments == AutoDetectedOption.AUTO: + args.adornments = AutoDetectedOption.ON + + if not debug: + args.load_response_from_file = None + args.save_response_to_file = None + + validate_args(args) + + return args diff --git a/color.py b/color.py new file mode 100644 index 0000000..b868290 --- /dev/null +++ b/color.py @@ -0,0 +1,93 @@ +from typing import Literal + + +class ColorCode: + """A superclass to signal that color codes are strings""" + + BLACK: Literal[str] + RED: Literal[str] + GREEN: Literal[str] + YELLOW: Literal[str] + BLUE: Literal[str] + MAGENTA: Literal[str] + CYAN: Literal[str] + WHITE: Literal[str] + RESET: Literal[str] + + BLACK_BG: Literal[str] + RED_BG: Literal[str] + GREEN_BG: Literal[str] + YELLOW_BG: Literal[str] + BLUE_BG: Literal[str] + MAGENTA_BG: Literal[str] + CYAN_BG: Literal[str] + WHITE_BG: Literal[str] + + BOLD: Literal[str] + UNDERLINE: Literal[str] + BLINK: Literal[str] + + +class VT100ColorCode(ColorCode): + """A class containing VT100 color codes""" + + # Define the color codes + BLACK = '\033[30m' + RED = '\033[31m' + GREEN = '\033[32m' + YELLOW = '\033[33m' + BLUE = '\033[34m' + MAGENTA = '\033[35m' + CYAN = '\033[36m' + WHITE = '\033[37m' + RESET = '\033[0m' + + # Define the background color codes + BLACK_BG = '\033[40m' + RED_BG = '\033[41m' + GREEN_BG = '\033[42m' + YELLOW_BG = '\033[43m' + BLUE_BG = '\033[44m' + MAGENTA_BG = '\033[45m' + CYAN_BG = '\033[46m' + WHITE_BG = '\033[47m' + + # Define the bold, underline and blink codes + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + BLINK = '\033[5m' + +class NoColorColorCode(ColorCode): + """A class nullifying color codes to disable color""" + + # Define the color codes + BLACK = '' + RED = '' + GREEN = '' + YELLOW = '' + BLUE = '' + MAGENTA = '' + CYAN = '' + WHITE = '' + RESET = '' + + # Define the background color codes + BLACK_BG = '' + RED_BG = '' + GREEN_BG = '' + YELLOW_BG = '' + BLUE_BG = 
diff --git a/gcli.py b/gcli.py
new file mode 100644
--- /dev/null
+++ b/gcli.py
@@ -0,0 +1,139 @@
+import argparse
+from collections import defaultdict
+import sys
+import openai
+import pickle
+from dataclasses import dataclass
+from typing import Tuple
+
+from openai_wrappers import (
+    create_chat_completion,
+    OpenAIChatResponse,
+    OpenAIChatResponseStream,
+    FinishReason,
+)
+
+from argparsing import (
+    parse_args,
+    AutoDetectedOption,
+)
+
+from color import get_color_codes
+
+###########################
+####   SAVE / REPLAY   ####
+###########################
+
+def create_chat_completion_from_args(args : argparse.Namespace) \
+        -> OpenAIChatResponseStream:
+    return create_chat_completion(
+        model=args.model,
+        messages=[{ "role": "user", "content": args.message }],
+        n=args.n_completions,
+        temperature=args.temperature,
+        presence_penalty=args.presence_penalty,
+        frequency_penalty=args.frequency_penalty,
+        max_tokens=args.max_tokens,
+        top_p=args.top_p,
+        stream=True
+    )
+
+def save_response_and_arguments(args : argparse.Namespace) -> None:
+    completion = create_chat_completion_from_args(args)
+    completion = list(completion)
+
+    filename = args.save_response_to_file
+
+    with open(filename, 'wb') as f:
+        pickle.dump((args, completion,), f)
+
+def load_response_and_arguments(args : argparse.Namespace) \
+        -> Tuple[argparse.Namespace, OpenAIChatResponseStream]:
+
+    filename = args.load_response_from_file
+
+    with open(filename, 'rb') as f:
+        args, completion = pickle.load(f)
+
+    return (args, completion)
+
+#########################
+#### PRETTY PRINTING ####
+#########################
+
+@dataclass
+class CumulativeResponse:
+    content: str = ""
+    finish_reason: FinishReason = FinishReason.NONE
+
+    def take_content(self) -> str:
+        chunk = self.content
+        self.content = ""
+        return chunk
+
+def print_streamed_response(args : argparse.Namespace, completion : OpenAIChatResponseStream):
+    """
+    Print the response in real time by printing the deltas as they occur.
+    If multiple responses are requested, print the first in real time while
+    accumulating the others in the background. Once the first response
+    completes, move on to the second, printing its deltas in real time.
+    Continue until all responses have been printed.
+    """
+
+    COLOR_CODE = get_color_codes(no_color = args.color == AutoDetectedOption.OFF)
+    ADORNMENTS = args.adornments == AutoDetectedOption.ON
+    N_COMPLETIONS = args.n_completions
+
+    cumu_responses = defaultdict(CumulativeResponse)
+    display_idx = 0
+    prompt_printed = False
+
+    for update in completion:
+
+        for choice in update.choices:
+            delta = choice.delta
+
+            if delta.content:
+                cumu_responses[choice.index].content += delta.content
+
+            if choice.finish_reason is not FinishReason.NONE:
+                cumu_responses[choice.index].finish_reason = choice.finish_reason
+
+        display_response = cumu_responses[display_idx]
+
+        if not prompt_printed and ADORNMENTS:
+            res_indicator = '' if N_COMPLETIONS == 1 else \
+                f' {display_idx + 1}/{N_COMPLETIONS}'
+            PROMPT = f'[{COLOR_CODE.GREEN}{update.model}{COLOR_CODE.RESET}{COLOR_CODE.RED}{res_indicator}{COLOR_CODE.RESET}]'
+            prompt_printed = True
+            print(PROMPT, end=' ', flush=True)
+
+        content = display_response.take_content()
+        print(f'{COLOR_CODE.WHITE}{content}{COLOR_CODE.RESET}',
+              sep='', end='', flush=True)
+
+        if display_response.finish_reason is not FinishReason.NONE:
+            if display_idx < N_COMPLETIONS:
+                display_idx += 1
+                prompt_printed = False
+
+            if ADORNMENTS:
+                print(end='\n\n', flush=True)
+            else:
+                print(end='\n', flush=True)
+
+def main():
+    args = parse_args()
+
+    if args.save_response_to_file:
+        save_response_and_arguments(args)
+        return
+    elif args.load_response_from_file:
+        args, completion = load_response_and_arguments(args)
+    else:
+        completion = create_chat_completion_from_args(args)
+
+    print_streamed_response(args, completion)
+
+if __name__ == "__main__":
+    main()
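The round-robin display in `print_streamed_response` hinges on buffering deltas per choice index in a `defaultdict` and draining the active choice with `take_content`. A toy, self-contained illustration of just that accumulation step (not part of this commit; the dataclass is duplicated here so the snippet runs on its own):

```python
from collections import defaultdict
from dataclasses import dataclass

@dataclass
class CumulativeResponse:
    content: str = ""

    def take_content(self) -> str:
        chunk = self.content
        self.content = ""
        return chunk

# Interleaved deltas for two completion choices, in the order they
# might arrive from a stream.
cumu = defaultdict(CumulativeResponse)
for index, delta in [(0, "Hel"), (1, "Hi"), (0, "lo")]:
    cumu[index].content += delta

print(cumu[0].take_content())  # Hello
print(cumu[1].take_content())  # Hi
print(cumu[0].take_content())  # empty string: choice 0 already drained
```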
+ """ + + COLOR_CODE = get_color_codes(no_color = args.color == AutoDetectedOption.OFF) + ADORNMENTS = args.adornments == AutoDetectedOption.ON + N_COMPLETIONS = args.n_completions + + cumu_responses = defaultdict(CumulativeResponse) + display_idx = 0 + prompt_printed = False + + for update in completion: + + for choice in update.choices: + delta = choice.delta + + if delta.content: + cumu_responses[choice.index].content += delta.content + + if choice.finish_reason is not FinishReason.NONE: + cumu_responses[choice.index].finish_reason = choice.finish_reason + + display_response = cumu_responses[display_idx] + + if not prompt_printed and ADORNMENTS: + res_indicator = '' if N_COMPLETIONS == 1 else \ + f' {display_idx + 1}/{n_completions}' + PROMPT = f'[{COLOR_CODE.GREEN}{update.model}{COLOR_CODE.RESET}{COLOR_CODE.RED}{res_indicator}{COLOR_CODE.RESET}]' + prompt_printed = True + print(PROMPT, end=' ', flush=True) + + + content = display_response.take_content() + print(f'{COLOR_CODE.WHITE}{content}{COLOR_CODE.RESET}', + sep='', end='', flush=True) + + if display_response.finish_reason is not FinishReason.NONE: + if display_idx < N_COMPLETIONS: + display_idx += 1 + prompt_printed = False + + if ADORNMENTS: + print(end='\n\n', flush=True) + else: + print(end='\n', flush=True) + +def main(): + args = parse_args() + + if args.save_response_to_file: + save_response_and_arguments(args) + return + elif args.load_response_from_file: + args, completion = load_response_and_arguments(args) + else: + completion = create_chat_completion_from_args(args) + + print_streamed_response(args, completion) + +if __name__ == "__main__": + main() diff --git a/openai_wrappers.py b/openai_wrappers.py new file mode 100644 index 0000000..cad024a --- /dev/null +++ b/openai_wrappers.py @@ -0,0 +1,68 @@ +import json +from typing import Any, List, Optional, Generator +from dataclasses import dataclass +from enum import Enum, auto +import openai + +@dataclass +class Delta: + content: Optional[str] = None + role: Optional[str] = None + +class FinishReason(Enum): + STOP = auto() + MAX_TOKENS = auto() + TEMPERATURE = auto() + NONE = auto() + + @staticmethod + def from_str(finish_reason_str : Optional[str]) -> "FinishReason": + if finish_reason_str is None: + return FinishReason.NONE + return FinishReason[finish_reason_str.upper()] + +@dataclass +class Choice: + delta: Delta + finish_reason: Optional[FinishReason] + index: int + +@dataclass +class OpenAIChatResponse: + choices: List[Choice] + created: int + id: str + model: str + object: str + + def from_json(data: Any) -> "OpenAIChatResponse": + choices = [] + + for choice in data["choices"]: + delta = Delta( + content=choice["delta"].get("content"), + role=choice["delta"].get("role") + ) + + choices.append(Choice( + delta=delta, + finish_reason=FinishReason.from_str(choice["finish_reason"]), + index=choice["index"], + )) + + return OpenAIChatResponse( + choices, + created=data["created"], + id=data["id"], + model=data["model"], + object=data["object"], + ) + +OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None] + +def create_chat_completion(*args, **kwargs) \ + -> OpenAIChatResponseStream: + return ( + OpenAIChatResponse.from_json(update) \ + for update in openai.ChatCompletion.create(*args, **kwargs) + ) |