Diffstat (limited to 'openai_wrappers.py')
-rw-r--r--  openai_wrappers.py | 69
1 file changed, 0 insertions(+), 69 deletions(-)
diff --git a/openai_wrappers.py b/openai_wrappers.py
deleted file mode 100644
index 784a9ce..0000000
--- a/openai_wrappers.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import json
-import openai
-
-from typing import Any, List, Optional, Generator
-from dataclasses import dataclass
-from enum import Enum, auto
-
-@dataclass
-class Delta:
- content: Optional[str] = None
- role: Optional[str] = None
-
-class FinishReason(Enum):
- STOP = auto()
- MAX_TOKENS = auto()
- TEMPERATURE = auto()
- NONE = auto()
-
- @staticmethod
- def from_str(finish_reason_str : Optional[str]) -> "FinishReason":
- if finish_reason_str is None:
- return FinishReason.NONE
- return FinishReason[finish_reason_str.upper()]
-
-@dataclass
-class Choice:
- delta: Delta
- finish_reason: Optional[FinishReason]
- index: int
-
-@dataclass
-class OpenAIChatResponse:
- choices: List[Choice]
- created: int
- id: str
- model: str
- object: str
-
- def from_json(data: Any) -> "OpenAIChatResponse":
- choices = []
-
- for choice in data["choices"]:
- delta = Delta(
- content=choice["delta"].get("content"),
- role=choice["delta"].get("role")
- )
-
- choices.append(Choice(
- delta=delta,
- finish_reason=FinishReason.from_str(choice["finish_reason"]),
- index=choice["index"],
- ))
-
- return OpenAIChatResponse(
- choices,
- created=data["created"],
- id=data["id"],
- model=data["model"],
- object=data["object"],
- )
-
-OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None]
-
-def create_chat_completion(*args, **kwargs) \
- -> OpenAIChatResponseStream:
- return (
- OpenAIChatResponse.from_json(update) \
- for update in openai.ChatCompletion.create(*args, **kwargs)
- )