aboutsummaryrefslogtreecommitdiff
path: root/openai_wrappers.py
blob: 784a9cee8ae8e38893af566474bb5c850e915a24 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import json
import openai

from typing import Any, List, Optional, Generator
from dataclasses import dataclass
from enum import Enum, auto

@dataclass
class Delta:
    """Incremental fragment of a chat message carried by one choice.

    Both fields are optional and default to ``None``; a given chunk may
    supply either one, both, or neither.
    """

    # Text to append to the message being assembled, if present.
    content: Optional[str] = None
    # Author role of the message, if this fragment announces it.
    role: Optional[str] = None

class FinishReason(Enum):
    """Why generation stopped for a choice.

    ``NONE`` represents a ``None`` (missing) reason, e.g. a streaming chunk
    that is not the final one for its choice.
    """

    STOP = auto()
    # Token limit reached; the API spells this reason "length".
    MAX_TOKENS = auto()
    # NOTE(review): not a finish_reason the Chat Completions API documents;
    # kept so existing references keep working.
    TEMPERATURE = auto()
    NONE = auto()
    # Further reasons the Chat Completions API can report. Appended after the
    # original members so those keep their existing auto() values.
    CONTENT_FILTER = auto()
    FUNCTION_CALL = auto()
    TOOL_CALLS = auto()

    @staticmethod
    def from_str(finish_reason_str: Optional[str]) -> "FinishReason":
        """Convert a raw ``finish_reason`` value into a ``FinishReason``.

        Matching is case-insensitive against member names, plus the API
        alias ``"length"`` which maps to ``MAX_TOKENS``.

        Args:
            finish_reason_str: The raw reason string, or ``None``.

        Returns:
            The matching member; ``FinishReason.NONE`` when given ``None``.

        Raises:
            KeyError: If the string does not name a known reason.
        """
        if finish_reason_str is None:
            return FinishReason.NONE
        name = finish_reason_str.upper()
        # "length" has no same-named member; without this alias every
        # truncated completion would raise KeyError here.
        if name == "LENGTH":
            return FinishReason.MAX_TOKENS
        return FinishReason[name]

@dataclass
class Choice:
    """One candidate completion within a chat completion response.

    Attributes:
        delta: The message fragment this choice carries.
        finish_reason: Parsed stop reason for this choice.
        index: The choice's ``index`` as reported by the API.
    """

    delta: Delta
    finish_reason: Optional[FinishReason]
    index: int

@dataclass
class OpenAIChatResponse:
    """A parsed chat completion response (or streaming chunk)."""

    choices: List[Choice]
    created: int
    id: str
    model: str
    object: str

    @classmethod
    def from_json(cls, data: Any) -> "OpenAIChatResponse":
        """Build an instance from a decoded chat completion chunk.

        Previously this lacked a decorator, so calling it on an instance
        bound the instance as ``data``; as a classmethod it works both as
        ``OpenAIChatResponse.from_json(d)`` and ``resp.from_json(d)``.

        Args:
            data: Mapping with ``choices``, ``created``, ``id``, ``model``
                and ``object`` keys. Each choice must have ``index`` and may
                have ``delta`` (with optional ``content``/``role``) and
                ``finish_reason``.

        Returns:
            The parsed response.

        Raises:
            KeyError: If a required key is missing.
        """
        choices = []
        for choice in data["choices"]:
            # Tolerate chunks that omit "delta" or "finish_reason" entirely;
            # they are treated as an empty delta and a None reason.
            delta_data = choice.get("delta") or {}
            choices.append(Choice(
                delta=Delta(
                    content=delta_data.get("content"),
                    role=delta_data.get("role"),
                ),
                finish_reason=FinishReason.from_str(choice.get("finish_reason")),
                index=choice["index"],
            ))

        return cls(
            choices=choices,
            created=data["created"],
            id=data["id"],
            model=data["model"],
            object=data["object"],
        )

# Stream of parsed chat completion chunks returned by create_chat_completion.
OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None]


def create_chat_completion(*args, **kwargs) -> OpenAIChatResponseStream:
    """Start a streaming chat completion and parse each chunk as it arrives.

    All arguments are forwarded to ``openai.ChatCompletion.create``, except
    that ``stream`` defaults to ``True``. This wrapper iterates the result
    chunk by chunk, and a non-streamed response is a single dict-like object
    whose iteration yields key strings, which cannot be parsed.

    Returns:
        A generator yielding one ``OpenAIChatResponse`` per received chunk.
    """
    kwargs.setdefault("stream", True)
    # Issue the request immediately so API errors surface at call time;
    # only the per-chunk parsing is deferred to iteration.
    chunks = openai.ChatCompletion.create(*args, **kwargs)
    return (OpenAIChatResponse.from_json(chunk) for chunk in chunks)