aboutsummaryrefslogtreecommitdiff
path: root/openai_wrappers.py
diff options
context:
space:
mode:
Diffstat (limited to 'openai_wrappers.py')
-rw-r--r--openai_wrappers.py68
1 file changed, 68 insertions, 0 deletions
diff --git a/openai_wrappers.py b/openai_wrappers.py
new file mode 100644
index 0000000..cad024a
--- /dev/null
+++ b/openai_wrappers.py
@@ -0,0 +1,68 @@
+import json
+from typing import Any, List, Optional, Generator
+from dataclasses import dataclass
+from enum import Enum, auto
+import openai
+
@dataclass
class Delta:
    """One incremental piece of a streamed chat message.

    Mirrors the ``delta`` object inside a streaming chat-completion
    chunk: either field may be absent (None) in any given chunk.
    """
    # Text fragment appended to the message, if any in this chunk.
    content: Optional[str] = None
    # Role of the message ("assistant" etc.); typically only set in
    # the first chunk of a stream.
    role: Optional[str] = None
+
class FinishReason(Enum):
    """Why the model stopped emitting tokens for a choice."""
    STOP = auto()
    MAX_TOKENS = auto()
    TEMPERATURE = auto()
    # Sentinel for "stream still in progress" (finish_reason is null)
    # or for reason strings this enum does not model.
    NONE = auto()

    @staticmethod
    def from_str(finish_reason_str: Optional[str]) -> "FinishReason":
        """Parse the API's ``finish_reason`` field into an enum member.

        Returns ``NONE`` when the field is null (mid-stream chunks) or
        when the API sends a reason string not modelled here — the
        previous behavior raised ``KeyError`` on e.g. "length" or
        "content_filter", killing the whole stream.
        """
        if finish_reason_str is None:
            return FinishReason.NONE
        try:
            return FinishReason[finish_reason_str.upper()]
        except KeyError:
            # Unknown reason: degrade gracefully rather than crash.
            return FinishReason.NONE
+
@dataclass
class Choice:
    """One alternative completion inside a streamed response chunk."""
    # Incremental content/role update for this choice.
    delta: Delta
    # Why generation ended; None while the stream is still running.
    finish_reason: Optional[FinishReason]
    # Position of this choice in the response's choices list.
    index: int
+
@dataclass
class OpenAIChatResponse:
    """A typed view of one streamed chat-completion chunk."""
    choices: List[Choice]
    created: int
    id: str
    model: str
    object: str

    @staticmethod
    def from_json(data: Any) -> "OpenAIChatResponse":
        """Build an ``OpenAIChatResponse`` from a decoded API chunk.

        ``data`` is the dict-like object the OpenAI SDK yields for each
        streaming update. Marked ``@staticmethod``: without the
        decorator this alternate constructor only worked when invoked
        on the class itself and would silently bind an instance as
        ``data`` if ever called on one.
        """
        choices = []

        for choice in data["choices"]:
            # "content"/"role" may each be absent in any given chunk.
            delta = Delta(
                content=choice["delta"].get("content"),
                role=choice["delta"].get("role"),
            )

            choices.append(Choice(
                delta=delta,
                finish_reason=FinishReason.from_str(choice["finish_reason"]),
                index=choice["index"],
            ))

        return OpenAIChatResponse(
            choices,
            created=data["created"],
            id=data["id"],
            model=data["model"],
            object=data["object"],
        )
+
+OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None]
+
def create_chat_completion(*args, **kwargs) -> OpenAIChatResponseStream:
    """Call ``openai.ChatCompletion.create`` and yield typed chunks.

    All arguments are forwarded unchanged to the SDK; each raw update
    from the resulting stream is wrapped via
    ``OpenAIChatResponse.from_json``.
    """
    # The SDK call happens here, eagerly, exactly as in the original
    # (a genexp evaluates its outermost iterable at creation time).
    raw_stream = openai.ChatCompletion.create(*args, **kwargs)
    return (OpenAIChatResponse.from_json(update) for update in raw_stream)