aboutsummaryrefslogtreecommitdiff
path: root/src/gpt_chat_cli/argparsing.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpt_chat_cli/argparsing.py')
-rw-r--r--src/gpt_chat_cli/argparsing.py73
1 files changed, 69 insertions, 4 deletions
diff --git a/src/gpt_chat_cli/argparsing.py b/src/gpt_chat_cli/argparsing.py
index a7d3218..7d1d305 100644
--- a/src/gpt_chat_cli/argparsing.py
+++ b/src/gpt_chat_cli/argparsing.py
@@ -4,12 +4,17 @@ import logging
import openai
import sys
from enum import Enum
+from dataclasses import dataclass
+from typing import Tuple, Optional
class AutoDetectedOption(Enum):
ON = 'on'
OFF = 'off'
AUTO = 'auto'
+ def __str__(self : "AutoDetectedOption"):
+ return self.value
+
def die_validation_err(err : str):
print(err, file=sys.stderr)
sys.exit(1)
@@ -33,7 +38,62 @@ def validate_args(args: argparse.Namespace) -> None:
if args.n_completions < 1:
die_validation_err("Number of completions must be greater than or equal to 1.")
-def parse_args():
+@dataclass
+class CompletionArguments:
+ model: str
+ n_completions: int
+ temperature: float
+ presence_penalty: float
+ frequency_penalty: float
+ max_tokens: int
+ top_p: float
+ message: str
+
+@dataclass
+class DisplayArguments:
+ adornments: bool
+ color: bool
+
+
+@dataclass
+class DebugArguments:
+ save_response_to_file: Optional[str]
+ load_response_from_file: Optional[str]
+
+@dataclass
+class Arguments:
+ completion_args: CompletionArguments
+ display_args: DisplayArguments
+ debug_args: Optional[DebugArguments] = None
+
+def split_arguments(args: argparse.Namespace, debug=False) -> Arguments:
+ completion_args = CompletionArguments(
+ model=args.model,
+ n_completions=args.n_completions,
+ temperature=args.temperature,
+ presence_penalty=args.presence_penalty,
+ frequency_penalty=args.frequency_penalty,
+ max_tokens=args.max_tokens,
+ top_p=args.top_p,
+ message=args.message
+ )
+
+ display_args = DisplayArguments(
+ adornments=(args.adornments == AutoDetectedOption.ON),
+ color=(args.color == AutoDetectedOption.ON),
+ )
+
+ if debug:
+ debug_args = DebugArguments(
+ save_response_to_file=args.save_response_to_file,
+ load_response_from_file=args.load_response_from_file,
+ )
+ else:
+ debug_args = None
+
+ return Arguments( completion_args, display_args, debug_args )
+
+def parse_args() -> Arguments:
GCLI_ENV_PREFIX = "GCLI_"
@@ -178,11 +238,16 @@ def parse_args():
if args.color == AutoDetectedOption.AUTO:
if os.getenv("NO_COLOR"):
args.color = AutoDetectedOption.OFF
- else:
+ elif sys.stdout.isatty():
args.color = AutoDetectedOption.ON
+ else:
+ args.color = AutoDetectedOption.OFF
if args.adornments == AutoDetectedOption.AUTO:
- args.adornments = AutoDetectedOption.ON
+ if sys.stdout.isatty():
+ args.adornments = AutoDetectedOption.ON
+ else:
+ args.adornments = AutoDetectedOption.OFF
if not debug:
args.load_response_from_file = None
@@ -190,4 +255,4 @@ def parse_args():
validate_args(args)
- return args
+ return split_arguments(args, debug=debug)