mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-28 06:25:00 +00:00
feat(assistant): Assistant and AssistantFiles api (#1803)
* Initial implementation of assistants api * Move load/save configs to utils * Save assistant and assistantfiles config to disk. * Add tsets for assistant api * Fix models path spelling mistake. * Remove personal go.mod information --------- Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
parent
b7ffe66219
commit
2d7913b3be
8 changed files with 1108 additions and 61 deletions
456
core/http/endpoints/openai/assistant_test.go
Normal file
456
core/http/endpoints/openai/assistant_test.go
Normal file
|
@ -0,0 +1,456 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/go-skynet/LocalAI/core/config"
|
||||
"github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var configsDir string = "/tmp/localai/configs"
|
||||
|
||||
type MockLoader struct {
|
||||
models []string
|
||||
}
|
||||
|
||||
func tearDown() func() {
|
||||
return func() {
|
||||
UploadedFiles = []File{}
|
||||
Assistants = []Assistant{}
|
||||
AssistantFiles = []AssistantFile{}
|
||||
_ = os.Remove(filepath.Join(configsDir, AssistantsConfigFile))
|
||||
_ = os.Remove(filepath.Join(configsDir, AssistantsFileConfigFile))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssistantEndpoints(t *testing.T) {
|
||||
// Preparing the mocked objects
|
||||
cl := &config.BackendConfigLoader{}
|
||||
//configsDir := "/tmp/localai/configs"
|
||||
modelPath := "/tmp/localai/model"
|
||||
var ml = model.NewModelLoader(modelPath)
|
||||
|
||||
appConfig := &config.ApplicationConfig{
|
||||
ConfigsDir: configsDir,
|
||||
UploadLimitMB: 10,
|
||||
UploadDir: "test_dir",
|
||||
ModelPath: modelPath,
|
||||
}
|
||||
|
||||
_ = os.RemoveAll(appConfig.ConfigsDir)
|
||||
_ = os.MkdirAll(appConfig.ConfigsDir, 0755)
|
||||
_ = os.MkdirAll(modelPath, 0755)
|
||||
os.Create(filepath.Join(modelPath, "ggml-gpt4all-j"))
|
||||
|
||||
app := fiber.New(fiber.Config{
|
||||
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
|
||||
})
|
||||
|
||||
// Create a Test Server
|
||||
app.Get("/assistants", ListAssistantsEndpoint(cl, ml, appConfig))
|
||||
app.Post("/assistants", CreateAssistantEndpoint(cl, ml, appConfig))
|
||||
app.Delete("/assistants/:assistant_id", DeleteAssistantEndpoint(cl, ml, appConfig))
|
||||
app.Get("/assistants/:assistant_id", GetAssistantEndpoint(cl, ml, appConfig))
|
||||
app.Post("/assistants/:assistant_id", ModifyAssistantEndpoint(cl, ml, appConfig))
|
||||
|
||||
app.Post("/files", UploadFilesEndpoint(cl, appConfig))
|
||||
app.Get("/assistants/:assistant_id/files", ListAssistantFilesEndpoint(cl, ml, appConfig))
|
||||
app.Post("/assistants/:assistant_id/files", CreateAssistantFileEndpoint(cl, ml, appConfig))
|
||||
app.Delete("/assistants/:assistant_id/files/:file_id", DeleteAssistantFileEndpoint(cl, ml, appConfig))
|
||||
app.Get("/assistants/:assistant_id/files/:file_id", GetAssistantFileEndpoint(cl, ml, appConfig))
|
||||
|
||||
t.Run("CreateAssistantEndpoint", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
ar := &AssistantRequest{
|
||||
Model: "ggml-gpt4all-j",
|
||||
Name: "3.5-turbo",
|
||||
Description: "Test Assistant",
|
||||
Instructions: "You are computer science teacher answering student questions",
|
||||
Tools: []Tool{{Type: Function}},
|
||||
FileIDs: nil,
|
||||
Metadata: nil,
|
||||
}
|
||||
|
||||
resultAssistant, resp, err := createAssistant(app, *ar)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
assert.Equal(t, 1, len(Assistants))
|
||||
//t.Cleanup(cleanupAllAssistants(t, app, []string{resultAssistant.ID}))
|
||||
|
||||
assert.Equal(t, ar.Name, resultAssistant.Name)
|
||||
assert.Equal(t, ar.Model, resultAssistant.Model)
|
||||
assert.Equal(t, ar.Tools, resultAssistant.Tools)
|
||||
assert.Equal(t, ar.Description, resultAssistant.Description)
|
||||
assert.Equal(t, ar.Instructions, resultAssistant.Instructions)
|
||||
assert.Equal(t, ar.FileIDs, resultAssistant.FileIDs)
|
||||
assert.Equal(t, ar.Metadata, resultAssistant.Metadata)
|
||||
})
|
||||
|
||||
t.Run("ListAssistantsEndpoint", func(t *testing.T) {
|
||||
var ids []string
|
||||
var resultAssistant []Assistant
|
||||
for i := 0; i < 4; i++ {
|
||||
ar := &AssistantRequest{
|
||||
Model: "ggml-gpt4all-j",
|
||||
Name: fmt.Sprintf("3.5-turbo-%d", i),
|
||||
Description: fmt.Sprintf("Test Assistant - %d", i),
|
||||
Instructions: fmt.Sprintf("You are computer science teacher answering student questions - %d", i),
|
||||
Tools: []Tool{{Type: Function}},
|
||||
FileIDs: []string{"fid-1234"},
|
||||
Metadata: map[string]string{"meta": "data"},
|
||||
}
|
||||
|
||||
//var err error
|
||||
ra, _, err := createAssistant(app, *ar)
|
||||
// Because we create the assistants so fast all end up with the same created time.
|
||||
time.Sleep(time.Second)
|
||||
resultAssistant = append(resultAssistant, ra)
|
||||
assert.NoError(t, err)
|
||||
ids = append(ids, resultAssistant[i].ID)
|
||||
}
|
||||
|
||||
t.Cleanup(cleanupAllAssistants(t, app, ids))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
reqURL string
|
||||
expectedStatus int
|
||||
expectedResult []Assistant
|
||||
expectedStringResult string
|
||||
}{
|
||||
{
|
||||
name: "Valid Usage - limit only",
|
||||
reqURL: "/assistants?limit=2",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: Assistants[:2], // Expecting the first two assistants
|
||||
},
|
||||
{
|
||||
name: "Valid Usage - order asc",
|
||||
reqURL: "/assistants?order=asc",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: Assistants, // Expecting all assistants in ascending order
|
||||
},
|
||||
{
|
||||
name: "Valid Usage - order desc",
|
||||
reqURL: "/assistants?order=desc",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: []Assistant{Assistants[3], Assistants[2], Assistants[1], Assistants[0]}, // Expecting all assistants in descending order
|
||||
},
|
||||
{
|
||||
name: "Valid Usage - after specific ID",
|
||||
reqURL: "/assistants?after=2",
|
||||
expectedStatus: http.StatusOK,
|
||||
// Note this is correct because it's put in descending order already
|
||||
expectedResult: Assistants[:3], // Expecting assistants after (excluding) ID 2
|
||||
},
|
||||
{
|
||||
name: "Valid Usage - before specific ID",
|
||||
reqURL: "/assistants?before=4",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: Assistants[2:], // Expecting assistants before (excluding) ID 3.
|
||||
},
|
||||
{
|
||||
name: "Invalid Usage - non-integer limit",
|
||||
reqURL: "/assistants?limit=two",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStringResult: "Invalid limit query value: two",
|
||||
},
|
||||
{
|
||||
name: "Invalid Usage - non-existing id in after",
|
||||
reqURL: "/assistants?after=100",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: []Assistant(nil), // Expecting empty list as there are no IDs above 100
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
request := httptest.NewRequest(http.MethodGet, tt.reqURL, nil)
|
||||
response, err := app.Test(request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedStatus, response.StatusCode)
|
||||
if tt.expectedStatus != fiber.StatusOK {
|
||||
all, _ := ioutil.ReadAll(response.Body)
|
||||
assert.Equal(t, tt.expectedStringResult, string(all))
|
||||
} else {
|
||||
var result []Assistant
|
||||
err = json.NewDecoder(response.Body).Decode(&result)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tt.expectedResult, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DeleteAssistantEndpoint", func(t *testing.T) {
|
||||
ar := &AssistantRequest{
|
||||
Model: "ggml-gpt4all-j",
|
||||
Name: "3.5-turbo",
|
||||
Description: "Test Assistant",
|
||||
Instructions: "You are computer science teacher answering student questions",
|
||||
Tools: []Tool{{Type: Function}},
|
||||
FileIDs: nil,
|
||||
Metadata: nil,
|
||||
}
|
||||
|
||||
resultAssistant, _, err := createAssistant(app, *ar)
|
||||
assert.NoError(t, err)
|
||||
|
||||
target := fmt.Sprintf("/assistants/%s", resultAssistant.ID)
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||
_, err = app.Test(deleteReq)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(Assistants))
|
||||
})
|
||||
|
||||
t.Run("GetAssistantEndpoint", func(t *testing.T) {
|
||||
ar := &AssistantRequest{
|
||||
Model: "ggml-gpt4all-j",
|
||||
Name: "3.5-turbo",
|
||||
Description: "Test Assistant",
|
||||
Instructions: "You are computer science teacher answering student questions",
|
||||
Tools: []Tool{{Type: Function}},
|
||||
FileIDs: nil,
|
||||
Metadata: nil,
|
||||
}
|
||||
|
||||
resultAssistant, _, err := createAssistant(app, *ar)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(cleanupAllAssistants(t, app, []string{resultAssistant.ID}))
|
||||
|
||||
target := fmt.Sprintf("/assistants/%s", resultAssistant.ID)
|
||||
request := httptest.NewRequest(http.MethodGet, target, nil)
|
||||
response, err := app.Test(request)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var getAssistant Assistant
|
||||
err = json.NewDecoder(response.Body).Decode(&getAssistant)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, resultAssistant.ID, getAssistant.ID)
|
||||
})
|
||||
|
||||
t.Run("ModifyAssistantEndpoint", func(t *testing.T) {
|
||||
ar := &AssistantRequest{
|
||||
Model: "ggml-gpt4all-j",
|
||||
Name: "3.5-turbo",
|
||||
Description: "Test Assistant",
|
||||
Instructions: "You are computer science teacher answering student questions",
|
||||
Tools: []Tool{{Type: Function}},
|
||||
FileIDs: nil,
|
||||
Metadata: nil,
|
||||
}
|
||||
|
||||
resultAssistant, _, err := createAssistant(app, *ar)
|
||||
assert.NoError(t, err)
|
||||
|
||||
modifiedAr := &AssistantRequest{
|
||||
Model: "ggml-gpt4all-j",
|
||||
Name: "4.0-turbo",
|
||||
Description: "Modified Test Assistant",
|
||||
Instructions: "You are math teacher answering student questions",
|
||||
Tools: []Tool{{Type: CodeInterpreter}},
|
||||
FileIDs: nil,
|
||||
Metadata: nil,
|
||||
}
|
||||
|
||||
modifiedArJson, err := json.Marshal(modifiedAr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
target := fmt.Sprintf("/assistants/%s", resultAssistant.ID)
|
||||
request := httptest.NewRequest(http.MethodPost, target, strings.NewReader(string(modifiedArJson)))
|
||||
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||
|
||||
modifyResponse, err := app.Test(request)
|
||||
assert.NoError(t, err)
|
||||
var getAssistant Assistant
|
||||
err = json.NewDecoder(modifyResponse.Body).Decode(&getAssistant)
|
||||
|
||||
t.Cleanup(cleanupAllAssistants(t, app, []string{getAssistant.ID}))
|
||||
|
||||
assert.Equal(t, resultAssistant.ID, getAssistant.ID) // IDs should match even if contents change
|
||||
assert.Equal(t, modifiedAr.Tools, getAssistant.Tools)
|
||||
assert.Equal(t, modifiedAr.Name, getAssistant.Name)
|
||||
assert.Equal(t, modifiedAr.Instructions, getAssistant.Instructions)
|
||||
assert.Equal(t, modifiedAr.Description, getAssistant.Description)
|
||||
})
|
||||
|
||||
t.Run("CreateAssistantFileEndpoint", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
afr := AssistantFileRequest{FileID: file.ID}
|
||||
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, assistant.ID, af.AssistantID)
|
||||
})
|
||||
t.Run("ListAssistantFilesEndpoint", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
afr := AssistantFileRequest{FileID: file.ID}
|
||||
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, assistant.ID, af.AssistantID)
|
||||
})
|
||||
t.Run("GetAssistantFileEndpoint", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
afr := AssistantFileRequest{FileID: file.ID}
|
||||
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(cleanupAssistantFile(t, app, af.ID, af.AssistantID))
|
||||
|
||||
target := fmt.Sprintf("/assistants/%s/files/%s", assistant.ID, file.ID)
|
||||
request := httptest.NewRequest(http.MethodGet, target, nil)
|
||||
response, err := app.Test(request)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var assistantFile AssistantFile
|
||||
err = json.NewDecoder(response.Body).Decode(&assistantFile)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, af.ID, assistantFile.ID)
|
||||
assert.Equal(t, af.AssistantID, assistantFile.AssistantID)
|
||||
})
|
||||
t.Run("DeleteAssistantFileEndpoint", func(t *testing.T) {
|
||||
t.Cleanup(tearDown())
|
||||
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
afr := AssistantFileRequest{FileID: file.ID}
|
||||
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleanupAssistantFile(t, app, af.ID, af.AssistantID)()
|
||||
|
||||
assert.Empty(t, AssistantFiles)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func createFileAndAssistant(t *testing.T, app *fiber.App, o *config.ApplicationConfig) (File, Assistant, error) {
|
||||
ar := &AssistantRequest{
|
||||
Model: "ggml-gpt4all-j",
|
||||
Name: "3.5-turbo",
|
||||
Description: "Test Assistant",
|
||||
Instructions: "You are computer science teacher answering student questions",
|
||||
Tools: []Tool{{Type: Function}},
|
||||
FileIDs: nil,
|
||||
Metadata: nil,
|
||||
}
|
||||
|
||||
assistant, _, err := createAssistant(app, *ar)
|
||||
if err != nil {
|
||||
return File{}, Assistant{}, err
|
||||
}
|
||||
t.Cleanup(cleanupAllAssistants(t, app, []string{assistant.ID}))
|
||||
|
||||
file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, o)
|
||||
t.Cleanup(func() {
|
||||
_, err := CallFilesDeleteEndpoint(t, app, file.ID)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
return file, assistant, nil
|
||||
}
|
||||
|
||||
func createAssistantFile(app *fiber.App, afr AssistantFileRequest, assistantId string) (AssistantFile, *http.Response, error) {
|
||||
afrJson, err := json.Marshal(afr)
|
||||
if err != nil {
|
||||
return AssistantFile{}, nil, err
|
||||
}
|
||||
|
||||
target := fmt.Sprintf("/assistants/%s/files", assistantId)
|
||||
request := httptest.NewRequest(http.MethodPost, target, strings.NewReader(string(afrJson)))
|
||||
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||
request.Header.Set("OpenAi-Beta", "assistants=v1")
|
||||
|
||||
resp, err := app.Test(request)
|
||||
if err != nil {
|
||||
return AssistantFile{}, resp, err
|
||||
}
|
||||
|
||||
var assistantFile AssistantFile
|
||||
all, err := ioutil.ReadAll(resp.Body)
|
||||
err = json.NewDecoder(strings.NewReader(string(all))).Decode(&assistantFile)
|
||||
if err != nil {
|
||||
return AssistantFile{}, resp, err
|
||||
}
|
||||
|
||||
return assistantFile, resp, nil
|
||||
}
|
||||
|
||||
func createAssistant(app *fiber.App, ar AssistantRequest) (Assistant, *http.Response, error) {
|
||||
assistant, err := json.Marshal(ar)
|
||||
if err != nil {
|
||||
return Assistant{}, nil, err
|
||||
}
|
||||
|
||||
request := httptest.NewRequest(http.MethodPost, "/assistants", strings.NewReader(string(assistant)))
|
||||
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||
request.Header.Set("OpenAi-Beta", "assistants=v1")
|
||||
|
||||
resp, err := app.Test(request)
|
||||
if err != nil {
|
||||
return Assistant{}, resp, err
|
||||
}
|
||||
|
||||
bodyString, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return Assistant{}, resp, err
|
||||
}
|
||||
|
||||
var resultAssistant Assistant
|
||||
err = json.NewDecoder(strings.NewReader(string(bodyString))).Decode(&resultAssistant)
|
||||
|
||||
return resultAssistant, resp, nil
|
||||
}
|
||||
|
||||
func cleanupAllAssistants(t *testing.T, app *fiber.App, ids []string) func() {
|
||||
return func() {
|
||||
for _, assistant := range ids {
|
||||
target := fmt.Sprintf("/assistants/%s", assistant)
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||
_, err := app.Test(deleteReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete assistant %s: %v", assistant, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupAssistantFile(t *testing.T, app *fiber.App, fileId, assistantId string) func() {
|
||||
return func() {
|
||||
target := fmt.Sprintf("/assistants/%s/files/%s", assistantId, fileId)
|
||||
request := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||
request.Header.Set("OpenAi-Beta", "assistants=v1")
|
||||
|
||||
resp, err := app.Test(request)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var dafr DeleteAssistantFileResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&dafr)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, dafr.Deleted)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue