refactor: OpenAI integration (#996)

- Add `OpenAIResponse` to trim the `<think>...</think>` block
- Add an `Enable Streaming` option to fix the issue that some services do not support streaming output

Signed-off-by: leo <longshuang@msn.cn>
This commit is contained in:
leo 2025-02-19 18:01:16 +08:00
parent 69d107430a
commit c3eca0d7fd
No known key found for this signature in database
6 changed files with 145 additions and 85 deletions

View file

@ -1,5 +1,8 @@
using System;
using System.ClientModel;
using System.Collections.Generic;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
using Azure.AI.OpenAI;
using CommunityToolkit.Mvvm.ComponentModel;
@ -8,6 +11,91 @@ using OpenAI.Chat;
namespace SourceGit.Models
{
public partial class OpenAIResponse
{
public OpenAIResponse(Action<string> onUpdate)
{
_onUpdate = onUpdate;
}
public void Append(string text)
{
var buffer = text;
if (_thinkTail.Length > 0)
{
_thinkTail.Append(buffer);
buffer = _thinkTail.ToString();
_thinkTail.Clear();
}
buffer = REG_COT().Replace(buffer, "");
var startIdx = buffer.IndexOf('<', StringComparison.Ordinal);
if (startIdx >= 0)
{
if (startIdx > 0)
OnReceive(buffer.Substring(0, startIdx));
var endIdx = buffer.IndexOf(">", startIdx + 1, StringComparison.Ordinal);
if (endIdx <= startIdx)
{
if (buffer.Length - startIdx <= 15)
_thinkTail.Append(buffer.Substring(startIdx));
else
OnReceive(buffer.Substring(startIdx));
}
else if (endIdx < startIdx + 15)
{
var tag = buffer.Substring(startIdx + 1, endIdx - startIdx - 1);
if (_thinkTags.Contains(tag))
_thinkTail.Append(buffer.Substring(startIdx));
else
OnReceive(buffer.Substring(startIdx));
}
else
{
OnReceive(buffer.Substring(startIdx));
}
}
else
{
OnReceive(buffer);
}
}
public void End()
{
if (_thinkTail.Length > 0)
{
OnReceive(_thinkTail.ToString());
_thinkTail.Clear();
}
}
private void OnReceive(string text)
{
if (!_hasTrimmedStart)
{
text = text.TrimStart();
if (string.IsNullOrEmpty(text))
return;
_hasTrimmedStart = true;
}
_onUpdate.Invoke(text);
}
[GeneratedRegex(@"<(think|thought|thinking|thought_chain)>.*?</\1>", RegexOptions.Singleline)]
private static partial Regex REG_COT();
private Action<string> _onUpdate = null;
private StringBuilder _thinkTail = new StringBuilder();
private HashSet<string> _thinkTags = ["think", "thought", "thinking", "thought_chain"];
private bool _hasTrimmedStart = false;
}
public class OpenAIService : ObservableObject
{
public string Name
@ -42,6 +130,12 @@ namespace SourceGit.Models
set => SetProperty(ref _model, value);
}
public bool Streaming
{
get => _streaming;
set => SetProperty(ref _streaming, value);
}
public string AnalyzeDiffPrompt
{
get => _analyzeDiffPrompt;
@ -89,32 +183,47 @@ namespace SourceGit.Models
public void Chat(string prompt, string question, CancellationToken cancellation, Action<string> onUpdate)
{
Uri server = new(Server);
ApiKeyCredential key = new(ApiKey);
ChatClient client = null;
if (Server.Contains("openai.azure.com/", StringComparison.Ordinal))
var server = new Uri(_server);
var key = new ApiKeyCredential(_apiKey);
var client = null as ChatClient;
if (_server.Contains("openai.azure.com/", StringComparison.Ordinal))
{
var azure = new AzureOpenAIClient(server, key);
client = azure.GetChatClient(Model);
client = azure.GetChatClient(_model);
}
else
{
var openai = new OpenAIClient(key, new() { Endpoint = server });
client = openai.GetChatClient(Model);
client = openai.GetChatClient(_model);
}
var messages = new List<ChatMessage>();
messages.Add(_model.Equals("o1-mini", StringComparison.Ordinal) ? new UserChatMessage(prompt) : new SystemChatMessage(prompt));
messages.Add(new UserChatMessage(question));
try
{
var updates = client.CompleteChatStreaming([
_model.Equals("o1-mini", StringComparison.Ordinal) ? new UserChatMessage(prompt) : new SystemChatMessage(prompt),
new UserChatMessage(question),
], null, cancellation);
var rsp = new OpenAIResponse(onUpdate);
foreach (var update in updates)
if (_streaming)
{
if (update.ContentUpdate.Count > 0)
onUpdate.Invoke(update.ContentUpdate[0].Text);
var updates = client.CompleteChatStreaming(messages, null, cancellation);
foreach (var update in updates)
{
if (update.ContentUpdate.Count > 0)
rsp.Append(update.ContentUpdate[0].Text);
}
}
else
{
var completion = client.CompleteChat(messages, null, cancellation);
if (completion.Value.Content.Count > 0)
rsp.Append(completion.Value.Content[0].Text);
}
rsp.End();
}
catch
{
@ -127,6 +236,7 @@ namespace SourceGit.Models
private string _server;
private string _apiKey;
private string _model;
private bool _streaming = true;
private string _analyzeDiffPrompt;
private string _generateSubjectPrompt;
}