From d02f9ded2a503683b57bcacece6a2cff5484fb81 Mon Sep 17 00:00:00 2001 From: flu0r1ne Date: Thu, 4 May 2023 19:46:32 -0500 Subject: Add packaging info --- src/gpt_chat_cli/openai_wrappers.py | 69 +++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 src/gpt_chat_cli/openai_wrappers.py (limited to 'src/gpt_chat_cli/openai_wrappers.py') diff --git a/src/gpt_chat_cli/openai_wrappers.py b/src/gpt_chat_cli/openai_wrappers.py new file mode 100644 index 0000000..784a9ce --- /dev/null +++ b/src/gpt_chat_cli/openai_wrappers.py @@ -0,0 +1,69 @@ +import json +import openai + +from typing import Any, List, Optional, Generator +from dataclasses import dataclass +from enum import Enum, auto + +@dataclass +class Delta: + content: Optional[str] = None + role: Optional[str] = None + +class FinishReason(Enum): + STOP = auto() + MAX_TOKENS = auto() + TEMPERATURE = auto() + NONE = auto() + + @staticmethod + def from_str(finish_reason_str : Optional[str]) -> "FinishReason": + if finish_reason_str is None: + return FinishReason.NONE + return FinishReason[finish_reason_str.upper()] + +@dataclass +class Choice: + delta: Delta + finish_reason: Optional[FinishReason] + index: int + +@dataclass +class OpenAIChatResponse: + choices: List[Choice] + created: int + id: str + model: str + object: str + + def from_json(data: Any) -> "OpenAIChatResponse": + choices = [] + + for choice in data["choices"]: + delta = Delta( + content=choice["delta"].get("content"), + role=choice["delta"].get("role") + ) + + choices.append(Choice( + delta=delta, + finish_reason=FinishReason.from_str(choice["finish_reason"]), + index=choice["index"], + )) + + return OpenAIChatResponse( + choices, + created=data["created"], + id=data["id"], + model=data["model"], + object=data["object"], + ) + +OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None] + +def create_chat_completion(*args, **kwargs) \ + -> OpenAIChatResponseStream: + return ( + OpenAIChatResponse.from_json(update) \ + for update in openai.ChatCompletion.create(*args, **kwargs) + ) -- cgit v1.2.3