refactor: rewrite OpenAI integration

- use `OpenAI` and `Azure.AI.OpenAI`
- use `developer` role instead of `system` for OpenAI's `o1` series models
- use streaming response
- re-design `AIAssistant`

Signed-off-by: leo <longshuang@msn.cn>
This commit is contained in:
leo 2024-12-24 15:51:27 +08:00
parent cf90e51887
commit a915708db3
No known key found for this signature in database
12 changed files with 270 additions and 212 deletions

View file

@ -1,81 +1,13 @@
using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.ClientModel;
using System.Threading;
using Azure.AI.OpenAI;
using CommunityToolkit.Mvvm.ComponentModel;
using OpenAI;
using OpenAI.Chat;
namespace SourceGit.Models
{
public class OpenAIChatMessage
{
[JsonPropertyName("role")]
public string Role
{
get;
set;
}
[JsonPropertyName("content")]
public string Content
{
get;
set;
}
}
public class OpenAIChatChoice
{
[JsonPropertyName("index")]
public int Index
{
get;
set;
}
[JsonPropertyName("message")]
public OpenAIChatMessage Message
{
get;
set;
}
}
public class OpenAIChatResponse
{
[JsonPropertyName("choices")]
public List<OpenAIChatChoice> Choices
{
get;
set;
} = [];
}
public class OpenAIChatRequest
{
[JsonPropertyName("model")]
public string Model
{
get;
set;
}
[JsonPropertyName("messages")]
public List<OpenAIChatMessage> Messages
{
get;
set;
} = [];
public void AddMessage(string role, string content)
{
Messages.Add(new OpenAIChatMessage { Role = role, Content = content });
}
}
public class OpenAIService : ObservableObject
{
public string Name
@ -147,48 +79,48 @@ namespace SourceGit.Models
""";
}
public OpenAIChatResponse Chat(string prompt, string question, CancellationToken cancellation)
public void Chat(string prompt, string question, CancellationToken cancellation, Action<string> onUpdate)
{
var chat = new OpenAIChatRequest() { Model = Model };
chat.AddMessage("user", prompt);
chat.AddMessage("user", question);
var client = new HttpClient() { Timeout = TimeSpan.FromSeconds(60) };
if (!string.IsNullOrEmpty(ApiKey))
Uri server = new(Server);
ApiKeyCredential key = new(ApiKey);
ChatClient client = null;
if (Server.Contains("openai.azure.com/", StringComparison.Ordinal))
{
if (Server.Contains("openai.azure.com/", StringComparison.Ordinal))
client.DefaultRequestHeaders.Add("api-key", ApiKey);
else
client.DefaultRequestHeaders.Add("Authorization", $"Bearer {ApiKey}");
var azure = new AzureOpenAIClient(server, key);
client = azure.GetChatClient(Model);
}
else
{
var openai = new OpenAIClient(key, new() { Endpoint = server });
client = openai.GetChatClient(Model);
}
var req = new StringContent(JsonSerializer.Serialize(chat, JsonCodeGen.Default.OpenAIChatRequest), Encoding.UTF8, "application/json");
try
{
var task = client.PostAsync(Server, req, cancellation);
task.Wait(cancellation);
var updates = client.CompleteChatStreaming([
ShouldUseDeveloperPrompt() ? new DeveloperChatMessage(prompt) : new SystemChatMessage(prompt),
new UserChatMessage(question),
], null, cancellation);
var rsp = task.Result;
var reader = rsp.Content.ReadAsStringAsync(cancellation);
reader.Wait(cancellation);
var body = reader.Result;
if (!rsp.IsSuccessStatusCode)
foreach (var update in updates)
{
throw new Exception($"AI service returns error code {rsp.StatusCode}. Body: {body ?? string.Empty}");
if (update.ContentUpdate.Count > 0)
onUpdate.Invoke(update.ContentUpdate[0].Text);
}
return JsonSerializer.Deserialize(reader.Result, JsonCodeGen.Default.OpenAIChatResponse);
}
catch
{
if (cancellation.IsCancellationRequested)
return null;
throw;
if (!cancellation.IsCancellationRequested)
throw;
}
}
private bool ShouldUseDeveloperPrompt()
{
return _model.Equals("o1", StringComparison.Ordinal) ||
_model.Equals("o1-mini", StringComparison.Ordinal);
}
private string _name;
private string _server;
private string _apiKey;