aboutsummaryrefslogtreecommitdiff
path: root/src/gpt_chat_cli/argparsing.py
diff options
context:
space:
mode:
authorflu0r1ne <flu0r1ne@flu0r1ne.net>2023-05-04 19:46:32 -0500
committerflu0r1ne <flu0r1ne@flu0r1ne.net>2023-05-04 19:46:32 -0500
commitd02f9ded2a503683b57bcacece6a2cff5484fb81 (patch)
tree896521aa42194e5055b04cc7d5c6e93b56719f73 /src/gpt_chat_cli/argparsing.py
parenta74933b2d83efb5da4e0f1851d65ad575f04a65d (diff)
downloadgpt-chat-cli-d02f9ded2a503683b57bcacece6a2cff5484fb81.tar.xz
gpt-chat-cli-d02f9ded2a503683b57bcacece6a2cff5484fb81.zip
Add packaging info
Diffstat (limited to 'src/gpt_chat_cli/argparsing.py')
-rw-r--r--src/gpt_chat_cli/argparsing.py193
1 files changed, 193 insertions, 0 deletions
diff --git a/src/gpt_chat_cli/argparsing.py b/src/gpt_chat_cli/argparsing.py
new file mode 100644
index 0000000..a7d3218
--- /dev/null
+++ b/src/gpt_chat_cli/argparsing.py
@@ -0,0 +1,193 @@
+import argparse
+import os
+import logging
+import openai
+import sys
+from enum import Enum
+
+class AutoDetectedOption(Enum):
+ ON = 'on'
+ OFF = 'off'
+ AUTO = 'auto'
+
+def die_validation_err(err : str):
+ print(err, file=sys.stderr)
+ sys.exit(1)
+
+def validate_args(args: argparse.Namespace) -> None:
+ if not 0 <= args.temperature <= 2:
+ die_validation_err("Temperature must be between 0 and 2.")
+
+ if not -2 <= args.frequency_penalty <= 2:
+ die_validation_err("Frequency penalty must be between -2.0 and 2.0.")
+
+ if not -2 <= args.presence_penalty <= 2:
+ die_validation_err("Presence penalty must be between -2.0 and 2.0.")
+
+ if args.max_tokens < 1:
+ die_validation_err("Max tokens must be greater than or equal to 1.")
+
+ if not 0 <= args.top_p <= 1:
+ die_validation_err("Top_p must be between 0 and 1.")
+
+ if args.n_completions < 1:
+ die_validation_err("Number of completions must be greater than or equal to 1.")
+
+def parse_args():
+
+ GCLI_ENV_PREFIX = "GCLI_"
+
+ debug = os.getenv(f'{GCLI_ENV_PREFIX}DEBUG') is not None
+
+ if debug:
+ logging.warning("Debugging mode and unstable features have been enabled.")
+
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "-m",
+ "--model",
+ default=os.getenv(f'{GCLI_ENV_PREFIX}MODEL', "gpt-3.5-turbo"),
+ help="ID of the model to use",
+ )
+
+ parser.add_argument(
+ "-t",
+ "--temperature",
+ type=float,
+ default=os.getenv(f'{GCLI_ENV_PREFIX}TEMPERATURE', 0.5),
+ help=(
+ "What sampling temperature to use, between 0 and 2. Higher values "
+ "like 0.8 will make the output more random, while lower values "
+ "like 0.2 will make it more focused and deterministic."
+ ),
+ )
+
+ parser.add_argument(
+ "-f",
+ "--frequency-penalty",
+ type=float,
+ default=os.getenv(f'{GCLI_ENV_PREFIX}FREQUENCY_PENALTY', 0),
+ help=(
+ "Number between -2.0 and 2.0. Positive values penalize new tokens based "
+ "on their existing frequency in the text so far, decreasing the model's "
+ "likelihood to repeat the same line verbatim."
+ ),
+ )
+
+ parser.add_argument(
+ "-p",
+ "--presence-penalty",
+ type=float,
+ default=os.getenv(f'{GCLI_ENV_PREFIX}PRESENCE_PENALTY', 0),
+ help=(
+ "Number between -2.0 and 2.0. Positive values penalize new tokens based "
+ "on whether they appear in the text so far, increasing the model's "
+ "likelihood to talk about new topics."
+ ),
+ )
+
+ parser.add_argument(
+ "-k",
+ "--max-tokens",
+ type=int,
+ default=os.getenv(f'{GCLI_ENV_PREFIX}MAX_TOKENS', 2048),
+ help=(
+ "The maximum number of tokens to generate in the chat completion. "
+ "Defaults to 2048."
+ ),
+ )
+
+ parser.add_argument(
+ "-s",
+ "--top-p",
+ type=float,
+ default=os.getenv(f'{GCLI_ENV_PREFIX}TOP_P', 1),
+ help=(
+ "An alternative to sampling with temperature, called nucleus sampling, "
+ "where the model considers the results of the tokens with top_p "
+ "probability mass. So 0.1 means only the tokens comprising the top 10%% "
+ "probability mass are considered."
+ ),
+ )
+
+ parser.add_argument(
+ "-n",
+ "--n-completions",
+ type=int,
+ default=os.getenv('f{GCLI_ENV_PREFIX}N_COMPLETIONS', 1),
+ help="How many chat completion choices to generate for each input message.",
+ )
+
+ parser.add_argument(
+ "--adornments",
+ type=AutoDetectedOption,
+ choices=list(AutoDetectedOption),
+ default=AutoDetectedOption.AUTO,
+ help=(
+ "Show adornments to indicate the model and response."
+ " Can be set to 'on', 'off', or 'auto'."
+ )
+ )
+
+ parser.add_argument(
+ "--color",
+ type=AutoDetectedOption,
+ choices=list(AutoDetectedOption),
+ default=AutoDetectedOption.AUTO,
+ help="Set color to 'on', 'off', or 'auto'.",
+ )
+
+ parser.add_argument(
+ "message",
+ type=str,
+ help=(
+ "The contents of the message. When used in chat mode, this is the initial "
+ "message if provided."
+ ),
+ )
+
+ if debug:
+ group = parser.add_mutually_exclusive_group()
+
+ group.add_argument(
+ '--save-response-to-file',
+ type=str,
+ help="UNSTABLE: save the response to a file. This can reply a response for debugging purposes",
+ )
+
+ group.add_argument(
+ '--load-response-from-file',
+ type=str,
+ help="UNSTABLE: load a response from a file. This can reply a response for debugging purposes",
+ )
+
+ openai_key = os.getenv("OPENAI_KEY", os.getenv("OPENAI_API_KEY"))
+ if not openai_key:
+ print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr)
+ print("The OpenAI API uses API keys for authentication. Visit your (API Keys page)[https://platform.openai.com/account/api-keys] to retrieve the API key you'll use in your requests.", file=sys.stderr)
+ sys.exit(1)
+
+ openai.api_key = openai_key
+
+ args = parser.parse_args()
+
+ if debug and args.load_response_from_file:
+ logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {args.load_response_from_file} was generated')
+
+ if args.color == AutoDetectedOption.AUTO:
+ if os.getenv("NO_COLOR"):
+ args.color = AutoDetectedOption.OFF
+ else:
+ args.color = AutoDetectedOption.ON
+
+ if args.adornments == AutoDetectedOption.AUTO:
+ args.adornments = AutoDetectedOption.ON
+
+ if not debug:
+ args.load_response_from_file = None
+ args.save_response_to_file = None
+
+ validate_args(args)
+
+ return args