diff options
Diffstat (limited to 'src/gpt_chat_cli/openai_wrappers.py')
-rw-r--r-- | src/gpt_chat_cli/openai_wrappers.py | 44 |
1 files changed, 40 insertions, 4 deletions
diff --git a/src/gpt_chat_cli/openai_wrappers.py b/src/gpt_chat_cli/openai_wrappers.py index 413ec24..6eeba4d 100644 --- a/src/gpt_chat_cli/openai_wrappers.py +++ b/src/gpt_chat_cli/openai_wrappers.py @@ -28,6 +28,24 @@ class Choice: finish_reason: Optional[FinishReason] index: int +class Role(Enum): + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + +@dataclass +class ChatMessage: + role: Role + content: str + + def to_json(self : "ChatMessage"): + return { + "role": self.role.value, + "content": self.content + } + +ChatHistory = List[ChatMessage] + @dataclass class OpenAIChatResponse: choices: List[Choice] @@ -61,13 +79,31 @@ class OpenAIChatResponse: OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None] -def create_chat_completion(*args, **kwargs) \ - -> OpenAIChatResponseStream: +from .argparsing import CompletionArguments + +def create_chat_completion(hist : ChatHistory, args: CompletionArguments) \ + -> OpenAIChatResponseStream: + + messages = [ msg.to_json() for msg in hist ] + + response = openai.ChatCompletion.create( + model=args.model, + messages=messages, + n=args.n_completions, + temperature=args.temperature, + presence_penalty=args.presence_penalty, + frequency_penalty=args.frequency_penalty, + max_tokens=args.max_tokens, + top_p=args.top_p, + stream=True + ) + return ( - OpenAIChatResponse.from_json(update) \ - for update in openai.ChatCompletion.create(*args, **kwargs) + OpenAIChatResponse.from_json( update ) \ + for update in response ) + def list_models() -> List[str]: model_data = openai.Model.list() |