diff options
Diffstat (limited to 'src/gpt_chat_cli/argparsing.py')
-rw-r--r-- | src/gpt_chat_cli/argparsing.py | 73 |
1 files changed, 69 insertions, 4 deletions
diff --git a/src/gpt_chat_cli/argparsing.py b/src/gpt_chat_cli/argparsing.py index a7d3218..7d1d305 100644 --- a/src/gpt_chat_cli/argparsing.py +++ b/src/gpt_chat_cli/argparsing.py @@ -4,12 +4,17 @@ import logging import openai import sys from enum import Enum +from dataclasses import dataclass +from typing import Tuple, Optional class AutoDetectedOption(Enum): ON = 'on' OFF = 'off' AUTO = 'auto' + def __str__(self : "AutoDetectedOption"): + return self.value + def die_validation_err(err : str): print(err, file=sys.stderr) sys.exit(1) @@ -33,7 +38,62 @@ def validate_args(args: argparse.Namespace) -> None: if args.n_completions < 1: die_validation_err("Number of completions must be greater than or equal to 1.") -def parse_args(): +@dataclass +class CompletionArguments: + model: str + n_completions: int + temperature: float + presence_penalty: float + frequency_penalty: float + max_tokens: int + top_p: float + message: str + +@dataclass +class DisplayArguments: + adornments: bool + color: bool + + +@dataclass +class DebugArguments: + save_response_to_file: Optional[str] + load_response_from_file: Optional[str] + +@dataclass +class Arguments: + completion_args: CompletionArguments + display_args: DisplayArguments + debug_args: Optional[DebugArguments] = None + +def split_arguments(args: argparse.Namespace, debug=False) -> Arguments: + completion_args = CompletionArguments( + model=args.model, + n_completions=args.n_completions, + temperature=args.temperature, + presence_penalty=args.presence_penalty, + frequency_penalty=args.frequency_penalty, + max_tokens=args.max_tokens, + top_p=args.top_p, + message=args.message + ) + + display_args = DisplayArguments( + adornments=(args.adornments == AutoDetectedOption.ON), + color=(args.color == AutoDetectedOption.ON), + ) + + if debug: + debug_args = DebugArguments( + save_response_to_file=args.save_response_to_file, + load_response_from_file=args.load_response_from_file, + ) + else: + debug_args = None + + return Arguments( completion_args, display_args, debug_args ) + +def parse_args() -> Arguments: GCLI_ENV_PREFIX = "GCLI_" @@ -178,11 +238,16 @@ def parse_args(): if args.color == AutoDetectedOption.AUTO: if os.getenv("NO_COLOR"): args.color = AutoDetectedOption.OFF - else: + elif sys.stdout.isatty(): args.color = AutoDetectedOption.ON + else: + args.color = AutoDetectedOption.OFF if args.adornments == AutoDetectedOption.AUTO: - args.adornments = AutoDetectedOption.ON + if sys.stdout.isatty(): + args.adornments = AutoDetectedOption.ON + else: + args.adornments = AutoDetectedOption.OFF if not debug: args.load_response_from_file = None @@ -190,4 +255,4 @@ def parse_args(): validate_args(args) - return args + return split_arguments(args, debug=debug) |