aboutsummaryrefslogtreecommitdiff
path: root/src/gpt_chat_cli/openai_wrappers.py
blob: d478531c29301827f9c03d424288147d0d2f8796 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import json
import openai

from typing import Any, List, Optional, Generator
from dataclasses import dataclass
from enum import Enum, auto

from .argvalidation import CompletionArguments

@dataclass
class Delta:
    """Incremental message fragment carried by one streamed completion chunk.

    Either field may be absent (None) in any given chunk; in practice the
    role tends to arrive in the first chunk and content in later ones —
    presumed from typical streaming responses, verify against the API.
    """
    content: Optional[str] = None
    role: Optional[str] = None

class FinishReason(Enum):
    """Why the model stopped generating, parsed from the API's
    ``finish_reason`` field of a streamed chunk."""
    STOP = auto()
    MAX_TOKENS = auto()
    TEMPERATURE = auto()
    # Values the chat-completions API actually emits besides "stop":
    # "length" (token limit reached) and "content_filter" (output flagged).
    # Without these members the lookup below raised KeyError mid-stream.
    LENGTH = auto()
    CONTENT_FILTER = auto()
    NONE = auto()

    @staticmethod
    def from_str(finish_reason_str : Optional[str]) -> "FinishReason":
        """Map an API ``finish_reason`` string to an enum member.

        ``None`` (chunk still in progress) and any unrecognized string both
        map to ``FinishReason.NONE`` rather than raising, so new API values
        cannot crash the stream parser.
        """
        if finish_reason_str is None:
            return FinishReason.NONE
        try:
            return FinishReason[finish_reason_str.upper()]
        except KeyError:
            return FinishReason.NONE
@dataclass
class Choice:
    """One of the n completion alternatives inside a single stream chunk."""
    # Incremental payload for this alternative.
    delta: Delta
    # Parsed from the chunk's "finish_reason"; from_json supplies
    # FinishReason.NONE (not None) while the stream is still in progress.
    finish_reason: Optional[FinishReason]
    # Position of this choice within the response's "choices" array.
    index: int

class Role(Enum):
    """Chat participant roles; values are the exact strings the API expects."""
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"

@dataclass
class ChatMessage:
    """A single turn in a conversation: who spoke and what was said."""
    role: Role
    content: str

    def to_json(self) -> dict:
        """Serialize into the {"role": ..., "content": ...} mapping shape
        the chat-completions endpoint accepts."""
        return dict(role=self.role.value, content=self.content)

# An ordered conversation transcript: the message list sent with each request.
ChatHistory = List[ChatMessage]

@dataclass
class OpenAIChatResponse:
    """One parsed chunk of a streamed chat-completions response."""
    # One entry per requested completion (see the n_completions argument).
    choices: List[Choice]
    # Unix timestamp reported by the API.
    created: int
    id: str
    model: str
    object: str

    # Fix: the original defined from_json without @staticmethod, so it only
    # worked when invoked via the class (as in create_chat_completion) and
    # would misbehave if called on an instance.
    @staticmethod
    def from_json(data: Any) -> "OpenAIChatResponse":
        """Build an OpenAIChatResponse from a decoded API response dict.

        Raises KeyError if a required field is absent from ``data``.
        """
        choices = [
            Choice(
                delta=Delta(
                    content=choice["delta"].get("content"),
                    role=choice["delta"].get("role"),
                ),
                finish_reason=FinishReason.from_str(choice["finish_reason"]),
                index=choice["index"],
            )
            for choice in data["choices"]
        ]

        return OpenAIChatResponse(
            choices=choices,
            created=data["created"],
            id=data["id"],
            model=data["model"],
            object=data["object"],
        )

# Lazy stream of parsed response chunks produced while streaming a completion.
OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None]

def create_chat_completion(hist : ChatHistory, args: CompletionArguments) \
    -> OpenAIChatResponseStream:
    """Issue a streaming chat-completion request and yield parsed chunks.

    The HTTP request is started immediately; the returned generator then
    lazily converts each raw stream update into an OpenAIChatResponse.
    NOTE(review): uses the legacy openai<1.0 ChatCompletion interface.
    """
    payload = [message.to_json() for message in hist]

    stream = openai.ChatCompletion.create(
        model=args.model,
        messages=payload,
        n=args.n_completions,
        temperature=args.temperature,
        presence_penalty=args.presence_penalty,
        frequency_penalty=args.frequency_penalty,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        stream=True,
    )

    return (OpenAIChatResponse.from_json(chunk) for chunk in stream)

def is_compatible_model(_id : str) -> bool:
    """Heuristically decide whether a model id names a chat-capable model.

    FIXME: the API exposes no capability flag, so substring-matching the
    id on 'gpt' seems to be the only option currently.
    """
    return _id.find('gpt') >= 0

def list_models() -> List[str]:
    """Return the ids of all chat-compatible models, sorted alphabetically.

    Queries the OpenAI model listing endpoint and keeps only ids accepted
    by is_compatible_model.
    """
    listing = openai.Model.list()

    return sorted(
        entry["id"]
        for entry in listing["data"]
        if is_compatible_model(entry["id"])
    )