import logging
from enum import Enum
from dataclasses import dataclass
from typing import NoReturn, Optional, Tuple
import sys
import os

from .argparsing import (
    RawArguments,
    AutoDetectedOption,
)

######################
## PUBLIC INTERFACE ##
######################


@dataclass
class CompletionArguments:
    """Model selection and sampling options, copied from the raw CLI arguments."""

    model: str
    n_completions: int
    temperature: float
    presence_penalty: float
    frequency_penalty: float
    max_tokens: int
    top_p: float


@dataclass
class DisplayArguments:
    """Output-formatting switches, already resolved from AUTO/ON/OFF to booleans."""

    adornments: bool
    color: bool


@dataclass
class DebugArguments:
    """Paths for the debug-only response save/load features.

    Both fields are forced to None by ``_populate_defaults`` unless debug mode
    is enabled.
    """

    save_response_to_file: Optional[str]
    load_response_from_file: Optional[str]


@dataclass
class MessageSource:
    """Where the initial prompt comes from.

    ``message`` may not be combined with ``prompt_from_fd`` or
    ``prompt_from_file`` (enforced by ``_validate_args``).
    """

    message: Optional[str] = None
    prompt_from_fd: Optional[str] = None
    prompt_from_file: Optional[str] = None


@dataclass
class Arguments:
    """Fully processed arguments, as returned by ``post_process_raw_args``."""

    completion_args: CompletionArguments
    display_args: DisplayArguments
    version: bool
    list_models: bool
    interactive: bool
    initial_message: MessageSource
    openai_key: str
    system_message: Optional[str] = None
    debug_args: Optional[DebugArguments] = None


def post_process_raw_args(raw_args: RawArguments) -> Arguments:
    """Turn parsed command-line arguments into a structured ``Arguments``.

    Steps, in order: resolve auto-detected defaults, emit warnings, validate
    ranges and option combinations, then regroup into dataclasses.

    Note that ``raw_args.args`` is mutated in place while defaults are filled
    in. The process exits with status 1 (via ``sys.exit``) when no OpenAI key
    is available or when any argument fails validation.
    """
    _populate_defaults(raw_args)
    _issue_warnings(raw_args)
    _validate_args(raw_args)
    return _restructure_arguments(raw_args)


#####################
##  PRIVATE LOGIC  ##
#####################


def _restructure_arguments(raw_args: RawArguments) -> Arguments:
    """Group the flat ``raw_args.args`` namespace into the public dataclasses."""
    args = raw_args.args
    openai_key = raw_args.openai_key

    completion_args = CompletionArguments(
        model=args.model,
        n_completions=args.n_completions,
        temperature=args.temperature,
        presence_penalty=args.presence_penalty,
        frequency_penalty=args.frequency_penalty,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
    )

    msg_src = MessageSource(
        message=args.message,
        prompt_from_fd=args.prompt_from_fd,
        prompt_from_file=args.prompt_from_file,
    )

    # _populate_defaults has already replaced AUTO, so only ON maps to True.
    display_args = DisplayArguments(
        adornments=(args.adornments == AutoDetectedOption.ON),
        color=(args.color == AutoDetectedOption.ON),
    )

    debug_args = DebugArguments(
        save_response_to_file=args.save_response_to_file,
        load_response_from_file=args.load_response_from_file,
    )

    return Arguments(
        initial_message=msg_src,
        openai_key=openai_key,
        completion_args=completion_args,
        display_args=display_args,
        debug_args=debug_args,
        version=args.version,
        list_models=args.list_models,
        interactive=args.interactive,
        system_message=args.system_message,
    )


def _die_validation_err(err: str) -> NoReturn:
    """Print ``err`` to stderr and terminate the process with exit status 1."""
    print(err, file=sys.stderr)
    sys.exit(1)


def _validate_args(raw_args: RawArguments) -> None:
    """Check numeric ranges and option combinations; exit on the first failure."""
    args = raw_args.args
    debug = raw_args.debug

    if not 0 <= args.temperature <= 2:
        _die_validation_err("Temperature must be between 0 and 2.")

    if not -2 <= args.frequency_penalty <= 2:
        _die_validation_err("Frequency penalty must be between -2.0 and 2.0.")

    if not -2 <= args.presence_penalty <= 2:
        _die_validation_err("Presence penalty must be between -2.0 and 2.0.")

    if args.max_tokens < 1:
        _die_validation_err("Max tokens must be greater than or equal to 1.")

    if not 0 <= args.top_p <= 1:
        _die_validation_err("Top_p must be between 0 and 1.")

    if args.n_completions < 1:
        _die_validation_err("Number of completions must be greater than or equal to 1.")

    if args.interactive and args.n_completions != 1:
        _die_validation_err("Only a single completion can be used in interactive mode")

    # NOTE(review): truthiness test — if prompt_from_fd is ever parsed as an int,
    # fd 0 would be treated as "not given"; confirm against the argument parser.
    if (args.prompt_from_fd or args.prompt_from_file) and args.message:
        _die_validation_err("Cannot specify an initial message alongside --prompt_from_fd or --prompt_from_file")

    # _populate_defaults already clears the save/load paths outside debug mode;
    # the explicit ``debug`` test keeps this check correct on its own.
    if debug and args.interactive and (
        args.save_response_to_file or args.load_response_from_file
    ):
        _die_validation_err("Save and load operations cannot be used in interactive mode")


def _issue_warnings(raw_args: RawArguments) -> None:
    """Report environment problems and debug-mode notices.

    Despite the name, a missing OpenAI key is fatal: it prints guidance to
    stderr and exits with status 1. The debug notices go through the root
    ``logging`` logger at WARNING level.
    """
    if not raw_args.openai_key:
        print("The OPENAI_API_KEY or OPENAI_KEY environment variable must be defined.", file=sys.stderr)
        print(
            "The OpenAI API uses API keys for authentication. "
            "Visit your API Keys page (https://platform.openai.com/account/api-keys) "
            "to retrieve the API key you'll use in your requests.",
            file=sys.stderr,
        )
        sys.exit(1)

    if raw_args.debug:
        logging.warning("Debugging mode and unstable features have been enabled.")

    if raw_args.debug and raw_args.args.load_response_from_file:
        logging.warning(f'Ignoring the provided arguments in favor of those provided when the response in {raw_args.args.load_response_from_file} was generated')


def _populate_defaults(raw_args: RawArguments) -> None:
    """Resolve AUTO options and implicit modes by mutating ``raw_args.args`` in place."""
    args = raw_args.args
    debug = raw_args.debug

    if args.color == AutoDetectedOption.AUTO:
        # A non-empty NO_COLOR disables color; otherwise color only on a terminal.
        if os.getenv("NO_COLOR"):
            args.color = AutoDetectedOption.OFF
        elif sys.stdout.isatty():
            args.color = AutoDetectedOption.ON
        else:
            args.color = AutoDetectedOption.OFF

    if args.adornments == AutoDetectedOption.AUTO:
        if sys.stdout.isatty():
            args.adornments = AutoDetectedOption.ON
        else:
            args.adornments = AutoDetectedOption.OFF

    initial_message_specified = (
        args.message or args.prompt_from_fd or args.prompt_from_file
    )

    if not initial_message_specified:
        if debug and args.load_response_from_file:
            # Replaying a saved response: force non-interactive mode.
            args.interactive = False
        elif sys.stdin.isatty():
            # No prompt supplied and stdin is a terminal: switch to interactive mode.
            args.interactive = True

    if not debug:
        # Save/load are debug-only features; drop them outside debug mode.
        args.load_response_from_file = None
        args.save_response_to_file = None