From c3eca0d7fdea2c28084db3732875becac6874c28 Mon Sep 17 00:00:00 2001 From: leo Date: Wed, 19 Feb 2025 18:01:16 +0800 Subject: [PATCH] refactor: OpenAI integration (#996) - Add `OpenAIResponse` to trim the `...` block - Add an `Enable Streaming` option to fix the issue that some services do not support streaming output Signed-off-by: leo --- src/Commands/GenerateCommitMessage.cs | 81 +++------------ src/Models/OpenAI.cs | 136 +++++++++++++++++++++++--- src/Resources/Locales/en_US.axaml | 1 + src/Resources/Locales/zh_CN.axaml | 1 + src/Resources/Locales/zh_TW.axaml | 7 +- src/Views/Preferences.axaml | 4 + 6 files changed, 145 insertions(+), 85 deletions(-) diff --git a/src/Commands/GenerateCommitMessage.cs b/src/Commands/GenerateCommitMessage.cs index 4b18a561..df61fdd2 100644 --- a/src/Commands/GenerateCommitMessage.cs +++ b/src/Commands/GenerateCommitMessage.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Text; -using System.Text.RegularExpressions; using System.Threading; using Avalonia.Threading; @@ -36,6 +35,8 @@ namespace SourceGit.Commands { try { + _onResponse?.Invoke("Waiting for pre-file analyzing to completed...\n\n"); + var responseBuilder = new StringBuilder(); var summaryBuilder = new StringBuilder(); foreach (var change in _changes) @@ -49,18 +50,17 @@ namespace SourceGit.Commands var rs = new GetDiffContent(_repo, new Models.DiffOption(change, false)).ReadToEnd(); if (rs.IsSuccess) { - var hasFirstValidChar = false; - var thinkingBuffer = new StringBuilder(); _service.Chat( _service.AnalyzeDiffPrompt, $"Here is the `git diff` output: {rs.StdOut}", _cancelToken, update => - ProcessChatResponse(update, ref hasFirstValidChar, thinkingBuffer, - (responseBuilder, text => - _onResponse?.Invoke( - $"Waiting for pre-file analyzing to completed...\n\n{text}")), - (summaryBuilder, null))); + { + responseBuilder.Append(update); + summaryBuilder.Append(update); + + _onResponse?.Invoke($"Waiting for pre-file analyzing to completed...\n\n{responseBuilder}"); + }); } responseBuilder.Append("\n"); @@ -74,15 +74,15 @@ namespace SourceGit.Commands var responseBody = responseBuilder.ToString(); var subjectBuilder = new StringBuilder(); - var hasSubjectFirstValidChar = false; - var subjectThinkingBuffer = new StringBuilder(); _service.Chat( _service.GenerateSubjectPrompt, $"Here are the summaries changes:\n{summaryBuilder}", _cancelToken, update => - ProcessChatResponse(update, ref hasSubjectFirstValidChar, subjectThinkingBuffer, - (subjectBuilder, text => _onResponse?.Invoke($"{text}\n\n{responseBody}")))); + { + subjectBuilder.Append(update); + _onResponse?.Invoke($"{subjectBuilder}\n\n{responseBody}"); + }); } catch (Exception e) { @@ -90,67 +90,10 @@ namespace SourceGit.Commands } } - private void ProcessChatResponse( - string update, - ref bool hasFirstValidChar, - StringBuilder thinkingBuffer, - params (StringBuilder builder, Action callback)[] outputs) - { - if (!hasFirstValidChar) - { - update = update.TrimStart(); - if (string.IsNullOrEmpty(update)) - return; - if (update.StartsWith("<", StringComparison.Ordinal)) - thinkingBuffer.Append(update); - hasFirstValidChar = true; - } - - if (thinkingBuffer.Length > 0) - thinkingBuffer.Append(update); - - if (thinkingBuffer.Length > 15) - { - var match = REG_COT.Match(thinkingBuffer.ToString()); - if (match.Success) - { - update = REG_COT.Replace(thinkingBuffer.ToString(), "").TrimStart(); - if (update.Length > 0) - { - foreach (var output in outputs) - output.builder.Append(update); - thinkingBuffer.Clear(); - } - return; - } - - match = REG_THINK_START.Match(thinkingBuffer.ToString()); - if (!match.Success) - { - foreach (var output in outputs) - output.builder.Append(thinkingBuffer); - thinkingBuffer.Clear(); - return; - } - } - - if (thinkingBuffer.Length == 0) - { - foreach (var output in outputs) - { - output.builder.Append(update); - output.callback?.Invoke(output.builder.ToString()); - } - } - } - private Models.OpenAIService _service; private string _repo; private List _changes; private CancellationToken _cancelToken; private Action _onResponse; - - private static readonly Regex REG_COT = new(@"^<(think|thought|thinking|thought_chain)>(.*?)", RegexOptions.Singleline); - private static readonly Regex REG_THINK_START = new(@"^<(think|thought|thinking|thought_chain)>", RegexOptions.Singleline); } } diff --git a/src/Models/OpenAI.cs b/src/Models/OpenAI.cs index a6648c11..264230c6 100644 --- a/src/Models/OpenAI.cs +++ b/src/Models/OpenAI.cs @@ -1,5 +1,8 @@ using System; using System.ClientModel; +using System.Collections.Generic; +using System.Text; +using System.Text.RegularExpressions; using System.Threading; using Azure.AI.OpenAI; using CommunityToolkit.Mvvm.ComponentModel; @@ -8,6 +11,91 @@ using OpenAI.Chat; namespace SourceGit.Models { + public partial class OpenAIResponse + { + public OpenAIResponse(Action onUpdate) + { + _onUpdate = onUpdate; + } + + public void Append(string text) + { + var buffer = text; + + if (_thinkTail.Length > 0) + { + _thinkTail.Append(buffer); + buffer = _thinkTail.ToString(); + _thinkTail.Clear(); + } + + buffer = REG_COT().Replace(buffer, ""); + + var startIdx = buffer.IndexOf('<', StringComparison.Ordinal); + if (startIdx >= 0) + { + if (startIdx > 0) + OnReceive(buffer.Substring(0, startIdx)); + + var endIdx = buffer.IndexOf(">", startIdx + 1, StringComparison.Ordinal); + if (endIdx <= startIdx) + { + if (buffer.Length - startIdx <= 15) + _thinkTail.Append(buffer.Substring(startIdx)); + else + OnReceive(buffer.Substring(startIdx)); + } + else if (endIdx < startIdx + 15) + { + var tag = buffer.Substring(startIdx + 1, endIdx - startIdx - 1); + if (_thinkTags.Contains(tag)) + _thinkTail.Append(buffer.Substring(startIdx)); + else + OnReceive(buffer.Substring(startIdx)); + } + else + { + OnReceive(buffer.Substring(startIdx)); + } + } + else + { + OnReceive(buffer); + } + } + + public void End() + { + if (_thinkTail.Length > 0) + { + OnReceive(_thinkTail.ToString()); + _thinkTail.Clear(); + } + } + + private void OnReceive(string text) + { + if (!_hasTrimmedStart) + { + text = text.TrimStart(); + if (string.IsNullOrEmpty(text)) + return; + + _hasTrimmedStart = true; + } + + _onUpdate.Invoke(text); + } + + [GeneratedRegex(@"<(think|thought|thinking|thought_chain)>.*?", RegexOptions.Singleline)] + private static partial Regex REG_COT(); + + private Action _onUpdate = null; + private StringBuilder _thinkTail = new StringBuilder(); + private HashSet _thinkTags = ["think", "thought", "thinking", "thought_chain"]; + private bool _hasTrimmedStart = false; + } + public class OpenAIService : ObservableObject { public string Name @@ -42,6 +130,12 @@ namespace SourceGit.Models set => SetProperty(ref _model, value); } + public bool Streaming + { + get => _streaming; + set => SetProperty(ref _streaming, value); + } + public string AnalyzeDiffPrompt { get => _analyzeDiffPrompt; @@ -89,32 +183,47 @@ namespace SourceGit.Models public void Chat(string prompt, string question, CancellationToken cancellation, Action onUpdate) { - Uri server = new(Server); - ApiKeyCredential key = new(ApiKey); - ChatClient client = null; - if (Server.Contains("openai.azure.com/", StringComparison.Ordinal)) + var server = new Uri(_server); + var key = new ApiKeyCredential(_apiKey); + var client = null as ChatClient; + if (_server.Contains("openai.azure.com/", StringComparison.Ordinal)) { var azure = new AzureOpenAIClient(server, key); - client = azure.GetChatClient(Model); + client = azure.GetChatClient(_model); } else { var openai = new OpenAIClient(key, new() { Endpoint = server }); - client = openai.GetChatClient(Model); + client = openai.GetChatClient(_model); } + var messages = new List(); + messages.Add(_model.Equals("o1-mini", StringComparison.Ordinal) ? new UserChatMessage(prompt) : new SystemChatMessage(prompt)); + messages.Add(new UserChatMessage(question)); + try { - var updates = client.CompleteChatStreaming([ - _model.Equals("o1-mini", StringComparison.Ordinal) ? new UserChatMessage(prompt) : new SystemChatMessage(prompt), - new UserChatMessage(question), - ], null, cancellation); + var rsp = new OpenAIResponse(onUpdate); - foreach (var update in updates) + if (_streaming) { - if (update.ContentUpdate.Count > 0) - onUpdate.Invoke(update.ContentUpdate[0].Text); + var updates = client.CompleteChatStreaming(messages, null, cancellation); + + foreach (var update in updates) + { + if (update.ContentUpdate.Count > 0) + rsp.Append(update.ContentUpdate[0].Text); + } } + else + { + var completion = client.CompleteChat(messages, null, cancellation); + + if (completion.Value.Content.Count > 0) + rsp.Append(completion.Value.Content[0].Text); + } + + rsp.End(); } catch { @@ -127,6 +236,7 @@ namespace SourceGit.Models private string _server; private string _apiKey; private string _model; + private bool _streaming = true; private string _analyzeDiffPrompt; private string _generateSubjectPrompt; } diff --git a/src/Resources/Locales/en_US.axaml b/src/Resources/Locales/en_US.axaml index cd266666..0da8b5ed 100644 --- a/src/Resources/Locales/en_US.axaml +++ b/src/Resources/Locales/en_US.axaml @@ -459,6 +459,7 @@ Model Name Server + Enable Streaming APPEARANCE Default Font Font Size diff --git a/src/Resources/Locales/zh_CN.axaml b/src/Resources/Locales/zh_CN.axaml index 4db909c9..983775e7 100644 --- a/src/Resources/Locales/zh_CN.axaml +++ b/src/Resources/Locales/zh_CN.axaml @@ -462,6 +462,7 @@ 模型 配置名称 服务地址 + 启用流式输出 外观配置 缺省字体 字体大小 diff --git a/src/Resources/Locales/zh_TW.axaml b/src/Resources/Locales/zh_TW.axaml index a3fe5c19..2a5ced45 100644 --- a/src/Resources/Locales/zh_TW.axaml +++ b/src/Resources/Locales/zh_TW.axaml @@ -456,12 +456,13 @@ {0} 年前 偏好設定 AI - 伺服器 + 分析變更差異提示詞 API 金鑰 + 產生提交訊息提示詞 模型 名稱 - 分析變更差異提示詞 - 產生提交訊息提示詞 + 伺服器 + 啟用串流輸出 外觀設定 預設字型 字型大小 diff --git a/src/Views/Preferences.axaml b/src/Views/Preferences.axaml index 1d282ad9..3bdd150a 100644 --- a/src/Views/Preferences.axaml +++ b/src/Views/Preferences.axaml @@ -616,6 +616,10 @@ Text="{Binding GenerateSubjectPrompt, Mode=TwoWay}" AcceptsReturn="true" TextWrapping="Wrap"/> + +