aboutsummaryrefslogtreecommitdiff
path: root/src/gpt_chat_cli/openai_wrappers.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpt_chat_cli/openai_wrappers.py')
-rw-r--r--src/gpt_chat_cli/openai_wrappers.py44
1 files changed, 40 insertions, 4 deletions
diff --git a/src/gpt_chat_cli/openai_wrappers.py b/src/gpt_chat_cli/openai_wrappers.py
index 413ec24..6eeba4d 100644
--- a/src/gpt_chat_cli/openai_wrappers.py
+++ b/src/gpt_chat_cli/openai_wrappers.py
@@ -28,6 +28,24 @@ class Choice:
finish_reason: Optional[FinishReason]
index: int
+class Role(Enum):
+ SYSTEM = "system"
+ USER = "user"
+ ASSISTANT = "assistant"
+
+@dataclass
+class ChatMessage:
+ role: Role
+ content: str
+
+ def to_json(self : "ChatMessage"):
+ return {
+ "role": self.role.value,
+ "content": self.content
+ }
+
+ChatHistory = List[ChatMessage]
+
@dataclass
class OpenAIChatResponse:
choices: List[Choice]
@@ -61,13 +79,31 @@ class OpenAIChatResponse:
OpenAIChatResponseStream = Generator[OpenAIChatResponse, None, None]
-def create_chat_completion(*args, **kwargs) \
- -> OpenAIChatResponseStream:
+from .argparsing import CompletionArguments
+
+def create_chat_completion(hist : ChatHistory, args: CompletionArguments) \
+ -> OpenAIChatResponseStream:
+
+ messages = [ msg.to_json() for msg in hist ]
+
+ response = openai.ChatCompletion.create(
+ model=args.model,
+ messages=messages,
+ n=args.n_completions,
+ temperature=args.temperature,
+ presence_penalty=args.presence_penalty,
+ frequency_penalty=args.frequency_penalty,
+ max_tokens=args.max_tokens,
+ top_p=args.top_p,
+ stream=True
+ )
+
return (
- OpenAIChatResponse.from_json(update) \
- for update in openai.ChatCompletion.create(*args, **kwargs)
+ OpenAIChatResponse.from_json( update ) \
+ for update in response
)
+
def list_models() -> List[str]:
model_data = openai.Model.list()