mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-20 10:35:01 +00:00
feat(assistant): Assistant and AssistantFiles api (#1803)
* Initial implementation of assistants api * Move load/save configs to utils * Save assistant and assistantfiles config to disk. * Add tsets for assistant api * Fix models path spelling mistake. * Remove personal go.mod information --------- Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
parent
b7ffe66219
commit
2d7913b3be
8 changed files with 1108 additions and 61 deletions
|
@ -20,6 +20,7 @@ type ApplicationConfig struct {
|
||||||
ImageDir string
|
ImageDir string
|
||||||
AudioDir string
|
AudioDir string
|
||||||
UploadDir string
|
UploadDir string
|
||||||
|
ConfigsDir string
|
||||||
CORS bool
|
CORS bool
|
||||||
PreloadJSONModels string
|
PreloadJSONModels string
|
||||||
PreloadModelsFromPath string
|
PreloadModelsFromPath string
|
||||||
|
@ -252,6 +253,12 @@ func WithUploadDir(uploadDir string) AppOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithConfigsDir(configsDir string) AppOption {
|
||||||
|
return func(o *ApplicationConfig) {
|
||||||
|
o.ConfigsDir = configsDir
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithApiKeys(apiKeys []string) AppOption {
|
func WithApiKeys(apiKeys []string) AppOption {
|
||||||
return func(o *ApplicationConfig) {
|
return func(o *ApplicationConfig) {
|
||||||
o.ApiKeys = apiKeys
|
o.ApiKeys = apiKeys
|
||||||
|
|
|
@ -3,6 +3,7 @@ package http
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -155,8 +156,17 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
|
||||||
}{Version: internal.PrintableVersion()})
|
}{Version: internal.PrintableVersion()})
|
||||||
})
|
})
|
||||||
|
|
||||||
// Load upload json
|
// Make sure directories exists
|
||||||
openai.LoadUploadConfig(appConfig.UploadDir)
|
os.MkdirAll(appConfig.ImageDir, 0755)
|
||||||
|
os.MkdirAll(appConfig.AudioDir, 0755)
|
||||||
|
os.MkdirAll(appConfig.UploadDir, 0755)
|
||||||
|
os.MkdirAll(appConfig.ConfigsDir, 0755)
|
||||||
|
os.MkdirAll(appConfig.ModelPath, 0755)
|
||||||
|
|
||||||
|
// Load config jsons
|
||||||
|
utils.LoadConfig(appConfig.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles)
|
||||||
|
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
|
||||||
|
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)
|
||||||
|
|
||||||
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
|
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
|
||||||
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
||||||
|
@ -189,6 +199,26 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
|
||||||
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
|
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
|
||||||
app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
|
app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
// assistant
|
||||||
|
app.Get("/v1/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
|
||||||
|
app.Get("/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/v1/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/assistants", openai.CreateAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
|
||||||
|
app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
|
app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
|
app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
|
app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
|
app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
// files
|
// files
|
||||||
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
|
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
|
||||||
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
|
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
|
||||||
|
|
515
core/http/endpoints/openai/assistant.go
Normal file
515
core/http/endpoints/openai/assistant.go
Normal file
|
@ -0,0 +1,515 @@
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ToolType defines a type for tool options
|
||||||
|
type ToolType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
CodeInterpreter ToolType = "code_interpreter"
|
||||||
|
Retrieval ToolType = "retrieval"
|
||||||
|
Function ToolType = "function"
|
||||||
|
|
||||||
|
MaxCharacterInstructions = 32768
|
||||||
|
MaxCharacterDescription = 512
|
||||||
|
MaxCharacterName = 256
|
||||||
|
MaxToolsSize = 128
|
||||||
|
MaxFileIdSize = 20
|
||||||
|
MaxCharacterMetadataKey = 64
|
||||||
|
MaxCharacterMetadataValue = 512
|
||||||
|
)
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
Type ToolType `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assistant represents the structure of an assistant object from the OpenAI API.
|
||||||
|
type Assistant struct {
|
||||||
|
ID string `json:"id"` // The unique identifier of the assistant.
|
||||||
|
Object string `json:"object"` // Object type, which is "assistant".
|
||||||
|
Created int64 `json:"created"` // The time at which the assistant was created.
|
||||||
|
Model string `json:"model"` // The model ID used by the assistant.
|
||||||
|
Name string `json:"name,omitempty"` // The name of the assistant.
|
||||||
|
Description string `json:"description,omitempty"` // The description of the assistant.
|
||||||
|
Instructions string `json:"instructions,omitempty"` // The system instructions that the assistant uses.
|
||||||
|
Tools []Tool `json:"tools,omitempty"` // A list of tools enabled on the assistant.
|
||||||
|
FileIDs []string `json:"file_ids,omitempty"` // A list of file IDs attached to this assistant.
|
||||||
|
Metadata map[string]string `json:"metadata,omitempty"` // Set of key-value pairs attached to the assistant.
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
Assistants = []Assistant{} // better to return empty array instead of "null"
|
||||||
|
AssistantsConfigFile = "assistants.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AssistantRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Instructions string `json:"instructions,omitempty"`
|
||||||
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
FileIDs []string `json:"file_ids,omitempty"`
|
||||||
|
Metadata map[string]string `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
request := new(AssistantRequest)
|
||||||
|
if err := c.BodyParser(request); err != nil {
|
||||||
|
log.Warn().AnErr("Unable to parse AssistantRequest", err)
|
||||||
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
|
||||||
|
}
|
||||||
|
|
||||||
|
if !modelExists(ml, request.Model) {
|
||||||
|
log.Warn().Msgf("Model: %s was not found in list of models.", request.Model)
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("Model " + request.Model + " not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Tools == nil {
|
||||||
|
request.Tools = []Tool{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.FileIDs == nil {
|
||||||
|
request.FileIDs = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Metadata == nil {
|
||||||
|
request.Metadata = make(map[string]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
id := "asst_" + strconv.FormatInt(generateRandomID(), 10)
|
||||||
|
|
||||||
|
assistant := Assistant{
|
||||||
|
ID: id,
|
||||||
|
Object: "assistant",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
Model: request.Model,
|
||||||
|
Name: request.Name,
|
||||||
|
Description: request.Description,
|
||||||
|
Instructions: request.Instructions,
|
||||||
|
Tools: request.Tools,
|
||||||
|
FileIDs: request.FileIDs,
|
||||||
|
Metadata: request.Metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
Assistants = append(Assistants, assistant)
|
||||||
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsConfigFile, Assistants)
|
||||||
|
return c.Status(fiber.StatusOK).JSON(assistant)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var currentId int64 = 0
|
||||||
|
|
||||||
|
func generateRandomID() int64 {
|
||||||
|
atomic.AddInt64(¤tId, 1)
|
||||||
|
return currentId
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListAssistantsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
// Because we're altering the existing assistants list we should just duplicate it for now.
|
||||||
|
returnAssistants := Assistants
|
||||||
|
// Parse query parameters
|
||||||
|
limitQuery := c.Query("limit", "20")
|
||||||
|
orderQuery := c.Query("order", "desc")
|
||||||
|
afterQuery := c.Query("after")
|
||||||
|
beforeQuery := c.Query("before")
|
||||||
|
|
||||||
|
// Convert string limit to integer
|
||||||
|
limit, err := strconv.Atoi(limitQuery)
|
||||||
|
if err != nil {
|
||||||
|
return c.Status(http.StatusBadRequest).SendString(fmt.Sprintf("Invalid limit query value: %s", limitQuery))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort assistants
|
||||||
|
sort.SliceStable(returnAssistants, func(i, j int) bool {
|
||||||
|
if orderQuery == "asc" {
|
||||||
|
return returnAssistants[i].Created < returnAssistants[j].Created
|
||||||
|
}
|
||||||
|
return returnAssistants[i].Created > returnAssistants[j].Created
|
||||||
|
})
|
||||||
|
|
||||||
|
// After and before cursors
|
||||||
|
if afterQuery != "" {
|
||||||
|
returnAssistants = filterAssistantsAfterID(returnAssistants, afterQuery)
|
||||||
|
}
|
||||||
|
if beforeQuery != "" {
|
||||||
|
returnAssistants = filterAssistantsBeforeID(returnAssistants, beforeQuery)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply limit
|
||||||
|
if limit < len(returnAssistants) {
|
||||||
|
returnAssistants = returnAssistants[:limit]
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(returnAssistants)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterAssistantsBeforeID filters out those assistants whose ID comes before the given ID
|
||||||
|
// We assume that the assistants are already sorted
|
||||||
|
func filterAssistantsBeforeID(assistants []Assistant, id string) []Assistant {
|
||||||
|
idInt, err := strconv.Atoi(id)
|
||||||
|
if err != nil {
|
||||||
|
return assistants // Return original slice if invalid id format is provided
|
||||||
|
}
|
||||||
|
|
||||||
|
var filteredAssistants []Assistant
|
||||||
|
|
||||||
|
for _, assistant := range assistants {
|
||||||
|
aid, err := strconv.Atoi(strings.TrimPrefix(assistant.ID, "asst_"))
|
||||||
|
if err != nil {
|
||||||
|
continue // Skip if invalid id in assistant
|
||||||
|
}
|
||||||
|
|
||||||
|
if aid < idInt {
|
||||||
|
filteredAssistants = append(filteredAssistants, assistant)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredAssistants
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterAssistantsAfterID filters out those assistants whose ID comes after the given ID
|
||||||
|
// We assume that the assistants are already sorted
|
||||||
|
func filterAssistantsAfterID(assistants []Assistant, id string) []Assistant {
|
||||||
|
idInt, err := strconv.Atoi(id)
|
||||||
|
if err != nil {
|
||||||
|
return assistants // Return original slice if invalid id format is provided
|
||||||
|
}
|
||||||
|
|
||||||
|
var filteredAssistants []Assistant
|
||||||
|
|
||||||
|
for _, assistant := range assistants {
|
||||||
|
aid, err := strconv.Atoi(strings.TrimPrefix(assistant.ID, "asst_"))
|
||||||
|
if err != nil {
|
||||||
|
continue // Skip if invalid id in assistant
|
||||||
|
}
|
||||||
|
|
||||||
|
if aid > idInt {
|
||||||
|
filteredAssistants = append(filteredAssistants, assistant)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredAssistants
|
||||||
|
}
|
||||||
|
|
||||||
|
func modelExists(ml *model.ModelLoader, modelName string) (found bool) {
|
||||||
|
found = false
|
||||||
|
models, err := ml.ListModels()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range models {
|
||||||
|
if model == modelName {
|
||||||
|
found = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
type DeleteAssistantResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Deleted bool `json:"deleted"`
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
assistantID := c.Params("assistant_id")
|
||||||
|
if assistantID == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, assistant := range Assistants {
|
||||||
|
if assistant.ID == assistantID {
|
||||||
|
Assistants = append(Assistants[:i], Assistants[i+1:]...)
|
||||||
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsConfigFile, Assistants)
|
||||||
|
return c.Status(fiber.StatusOK).JSON(DeleteAssistantResponse{
|
||||||
|
ID: assistantID,
|
||||||
|
Object: "assistant.deleted",
|
||||||
|
Deleted: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warn().Msgf("Unable to find assistant %s for deletion", assistantID)
|
||||||
|
return c.Status(fiber.StatusNotFound).JSON(DeleteAssistantResponse{
|
||||||
|
ID: assistantID,
|
||||||
|
Object: "assistant.deleted",
|
||||||
|
Deleted: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
assistantID := c.Params("assistant_id")
|
||||||
|
if assistantID == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, assistant := range Assistants {
|
||||||
|
if assistant.ID == assistantID {
|
||||||
|
return c.Status(fiber.StatusOK).JSON(assistant)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant with id: %s", assistantID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type AssistantFile struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
AssistantID string `json:"assistant_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
AssistantFiles []AssistantFile
|
||||||
|
AssistantsFileConfigFile = "assistantsFile.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AssistantFileRequest struct {
|
||||||
|
FileID string `json:"file_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DeleteAssistantFileResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Deleted bool `json:"deleted"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
request := new(AssistantFileRequest)
|
||||||
|
if err := c.BodyParser(request); err != nil {
|
||||||
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
|
||||||
|
}
|
||||||
|
|
||||||
|
assistantID := c.Params("assistant_id")
|
||||||
|
if assistantID == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, assistant := range Assistants {
|
||||||
|
if assistant.ID == assistantID {
|
||||||
|
if len(assistant.FileIDs) > MaxFileIdSize {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Max files %d for assistant %s reached.", MaxFileIdSize, assistant.Name))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range UploadedFiles {
|
||||||
|
if file.ID == request.FileID {
|
||||||
|
assistant.FileIDs = append(assistant.FileIDs, request.FileID)
|
||||||
|
assistantFile := AssistantFile{
|
||||||
|
ID: file.ID,
|
||||||
|
Object: "assistant.file",
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
AssistantID: assistant.ID,
|
||||||
|
}
|
||||||
|
AssistantFiles = append(AssistantFiles, assistantFile)
|
||||||
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles)
|
||||||
|
return c.Status(fiber.StatusOK).JSON(assistantFile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find file_id: %s", request.FileID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find "))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListAssistantFilesEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
type ListAssistantFiles struct {
|
||||||
|
Data []File
|
||||||
|
Object string
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
assistantID := c.Params("assistant_id")
|
||||||
|
if assistantID == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
limitQuery := c.Query("limit", "20")
|
||||||
|
order := c.Query("order", "desc")
|
||||||
|
limit, err := strconv.Atoi(limitQuery)
|
||||||
|
if err != nil || limit < 1 || limit > 100 {
|
||||||
|
limit = 20 // Default to 20 if there's an error or the limit is out of bounds
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort files by CreatedAt depending on the order query parameter
|
||||||
|
if order == "asc" {
|
||||||
|
sort.Slice(AssistantFiles, func(i, j int) bool {
|
||||||
|
return AssistantFiles[i].CreatedAt < AssistantFiles[j].CreatedAt
|
||||||
|
})
|
||||||
|
} else { // default to "desc"
|
||||||
|
sort.Slice(AssistantFiles, func(i, j int) bool {
|
||||||
|
return AssistantFiles[i].CreatedAt > AssistantFiles[j].CreatedAt
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limit the number of files returned
|
||||||
|
var limitedFiles []AssistantFile
|
||||||
|
hasMore := false
|
||||||
|
if len(AssistantFiles) > limit {
|
||||||
|
hasMore = true
|
||||||
|
limitedFiles = AssistantFiles[:limit]
|
||||||
|
} else {
|
||||||
|
limitedFiles = AssistantFiles
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"object": "list",
|
||||||
|
"data": limitedFiles,
|
||||||
|
"first_id": func() string {
|
||||||
|
if len(limitedFiles) > 0 {
|
||||||
|
return limitedFiles[0].ID
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}(),
|
||||||
|
"last_id": func() string {
|
||||||
|
if len(limitedFiles) > 0 {
|
||||||
|
return limitedFiles[len(limitedFiles)-1].ID
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}(),
|
||||||
|
"has_more": hasMore,
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Status(fiber.StatusOK).JSON(response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ModifyAssistantEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
request := new(AssistantRequest)
|
||||||
|
if err := c.BodyParser(request); err != nil {
|
||||||
|
log.Warn().AnErr("Unable to parse AssistantRequest", err)
|
||||||
|
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
|
||||||
|
}
|
||||||
|
|
||||||
|
assistantID := c.Params("assistant_id")
|
||||||
|
if assistantID == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, assistant := range Assistants {
|
||||||
|
if assistant.ID == assistantID {
|
||||||
|
newAssistant := Assistant{
|
||||||
|
ID: assistantID,
|
||||||
|
Object: assistant.Object,
|
||||||
|
Created: assistant.Created,
|
||||||
|
Model: request.Model,
|
||||||
|
Name: request.Name,
|
||||||
|
Description: request.Description,
|
||||||
|
Instructions: request.Instructions,
|
||||||
|
Tools: request.Tools,
|
||||||
|
FileIDs: request.FileIDs, // todo: should probably verify fileids exist
|
||||||
|
Metadata: request.Metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove old one and replace with new one
|
||||||
|
Assistants = append(Assistants[:i], Assistants[i+1:]...)
|
||||||
|
Assistants = append(Assistants, newAssistant)
|
||||||
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsConfigFile, Assistants)
|
||||||
|
return c.Status(fiber.StatusOK).JSON(newAssistant)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant with id: %s", assistantID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
assistantID := c.Params("assistant_id")
|
||||||
|
fileId := c.Params("file_id")
|
||||||
|
if assistantID == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id and file_id are required")
|
||||||
|
}
|
||||||
|
// First remove file from assistant
|
||||||
|
for i, assistant := range Assistants {
|
||||||
|
if assistant.ID == assistantID {
|
||||||
|
for j, fileId := range assistant.FileIDs {
|
||||||
|
if fileId == fileId {
|
||||||
|
Assistants[i].FileIDs = append(Assistants[i].FileIDs[:j], Assistants[i].FileIDs[j+1:]...)
|
||||||
|
|
||||||
|
// Check if the file exists in the assistantFiles slice
|
||||||
|
for i, assistantFile := range AssistantFiles {
|
||||||
|
if assistantFile.ID == fileId {
|
||||||
|
// Remove the file from the assistantFiles slice
|
||||||
|
AssistantFiles = append(AssistantFiles[:i], AssistantFiles[i+1:]...)
|
||||||
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles)
|
||||||
|
return c.Status(fiber.StatusOK).JSON(DeleteAssistantFileResponse{
|
||||||
|
ID: fileId,
|
||||||
|
Object: "assistant.file.deleted",
|
||||||
|
Deleted: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warn().Msgf("Unable to locate file_id: %s in assistants: %s. Continuing to delete assistant file.", fileId, assistantID)
|
||||||
|
for i, assistantFile := range AssistantFiles {
|
||||||
|
if assistantFile.AssistantID == assistantID {
|
||||||
|
|
||||||
|
AssistantFiles = append(AssistantFiles[:i], AssistantFiles[i+1:]...)
|
||||||
|
utils.SaveConfig(appConfig.ConfigsDir, AssistantsFileConfigFile, AssistantFiles)
|
||||||
|
|
||||||
|
return c.Status(fiber.StatusNotFound).JSON(DeleteAssistantFileResponse{
|
||||||
|
ID: fileId,
|
||||||
|
Object: "assistant.file.deleted",
|
||||||
|
Deleted: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Warn().Msgf("Unable to find assistant: %s", assistantID)
|
||||||
|
|
||||||
|
return c.Status(fiber.StatusNotFound).JSON(DeleteAssistantFileResponse{
|
||||||
|
ID: fileId,
|
||||||
|
Object: "assistant.file.deleted",
|
||||||
|
Deleted: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAssistantFileEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
|
return func(c *fiber.Ctx) error {
|
||||||
|
assistantID := c.Params("assistant_id")
|
||||||
|
fileId := c.Params("file_id")
|
||||||
|
if assistantID == "" {
|
||||||
|
return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id and file_id are required")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, assistantFile := range AssistantFiles {
|
||||||
|
if assistantFile.AssistantID == assistantID {
|
||||||
|
if assistantFile.ID == fileId {
|
||||||
|
return c.Status(fiber.StatusOK).JSON(assistantFile)
|
||||||
|
}
|
||||||
|
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant file with file_id: %s", fileId))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant file with assistant_id: %s", assistantID))
|
||||||
|
}
|
||||||
|
}
|
456
core/http/endpoints/openai/assistant_test.go
Normal file
456
core/http/endpoints/openai/assistant_test.go
Normal file
|
@ -0,0 +1,456 @@
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/model"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var configsDir string = "/tmp/localai/configs"
|
||||||
|
|
||||||
|
type MockLoader struct {
|
||||||
|
models []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func tearDown() func() {
|
||||||
|
return func() {
|
||||||
|
UploadedFiles = []File{}
|
||||||
|
Assistants = []Assistant{}
|
||||||
|
AssistantFiles = []AssistantFile{}
|
||||||
|
_ = os.Remove(filepath.Join(configsDir, AssistantsConfigFile))
|
||||||
|
_ = os.Remove(filepath.Join(configsDir, AssistantsFileConfigFile))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssistantEndpoints(t *testing.T) {
|
||||||
|
// Preparing the mocked objects
|
||||||
|
cl := &config.BackendConfigLoader{}
|
||||||
|
//configsDir := "/tmp/localai/configs"
|
||||||
|
modelPath := "/tmp/localai/model"
|
||||||
|
var ml = model.NewModelLoader(modelPath)
|
||||||
|
|
||||||
|
appConfig := &config.ApplicationConfig{
|
||||||
|
ConfigsDir: configsDir,
|
||||||
|
UploadLimitMB: 10,
|
||||||
|
UploadDir: "test_dir",
|
||||||
|
ModelPath: modelPath,
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = os.RemoveAll(appConfig.ConfigsDir)
|
||||||
|
_ = os.MkdirAll(appConfig.ConfigsDir, 0755)
|
||||||
|
_ = os.MkdirAll(modelPath, 0755)
|
||||||
|
os.Create(filepath.Join(modelPath, "ggml-gpt4all-j"))
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{
|
||||||
|
BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB.
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a Test Server
|
||||||
|
app.Get("/assistants", ListAssistantsEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/assistants", CreateAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
app.Delete("/assistants/:assistant_id", DeleteAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
app.Get("/assistants/:assistant_id", GetAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/assistants/:assistant_id", ModifyAssistantEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
app.Post("/files", UploadFilesEndpoint(cl, appConfig))
|
||||||
|
app.Get("/assistants/:assistant_id/files", ListAssistantFilesEndpoint(cl, ml, appConfig))
|
||||||
|
app.Post("/assistants/:assistant_id/files", CreateAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
|
app.Delete("/assistants/:assistant_id/files/:file_id", DeleteAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
|
app.Get("/assistants/:assistant_id/files/:file_id", GetAssistantFileEndpoint(cl, ml, appConfig))
|
||||||
|
|
||||||
|
t.Run("CreateAssistantEndpoint", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
|
ar := &AssistantRequest{
|
||||||
|
Model: "ggml-gpt4all-j",
|
||||||
|
Name: "3.5-turbo",
|
||||||
|
Description: "Test Assistant",
|
||||||
|
Instructions: "You are computer science teacher answering student questions",
|
||||||
|
Tools: []Tool{{Type: Function}},
|
||||||
|
FileIDs: nil,
|
||||||
|
Metadata: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultAssistant, resp, err := createAssistant(app, *ar)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, fiber.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, len(Assistants))
|
||||||
|
//t.Cleanup(cleanupAllAssistants(t, app, []string{resultAssistant.ID}))
|
||||||
|
|
||||||
|
assert.Equal(t, ar.Name, resultAssistant.Name)
|
||||||
|
assert.Equal(t, ar.Model, resultAssistant.Model)
|
||||||
|
assert.Equal(t, ar.Tools, resultAssistant.Tools)
|
||||||
|
assert.Equal(t, ar.Description, resultAssistant.Description)
|
||||||
|
assert.Equal(t, ar.Instructions, resultAssistant.Instructions)
|
||||||
|
assert.Equal(t, ar.FileIDs, resultAssistant.FileIDs)
|
||||||
|
assert.Equal(t, ar.Metadata, resultAssistant.Metadata)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ListAssistantsEndpoint", func(t *testing.T) {
|
||||||
|
var ids []string
|
||||||
|
var resultAssistant []Assistant
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
ar := &AssistantRequest{
|
||||||
|
Model: "ggml-gpt4all-j",
|
||||||
|
Name: fmt.Sprintf("3.5-turbo-%d", i),
|
||||||
|
Description: fmt.Sprintf("Test Assistant - %d", i),
|
||||||
|
Instructions: fmt.Sprintf("You are computer science teacher answering student questions - %d", i),
|
||||||
|
Tools: []Tool{{Type: Function}},
|
||||||
|
FileIDs: []string{"fid-1234"},
|
||||||
|
Metadata: map[string]string{"meta": "data"},
|
||||||
|
}
|
||||||
|
|
||||||
|
//var err error
|
||||||
|
ra, _, err := createAssistant(app, *ar)
|
||||||
|
// Because we create the assistants so fast all end up with the same created time.
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
resultAssistant = append(resultAssistant, ra)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
ids = append(ids, resultAssistant[i].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(cleanupAllAssistants(t, app, ids))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
reqURL string
|
||||||
|
expectedStatus int
|
||||||
|
expectedResult []Assistant
|
||||||
|
expectedStringResult string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid Usage - limit only",
|
||||||
|
reqURL: "/assistants?limit=2",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedResult: Assistants[:2], // Expecting the first two assistants
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid Usage - order asc",
|
||||||
|
reqURL: "/assistants?order=asc",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedResult: Assistants, // Expecting all assistants in ascending order
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid Usage - order desc",
|
||||||
|
reqURL: "/assistants?order=desc",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedResult: []Assistant{Assistants[3], Assistants[2], Assistants[1], Assistants[0]}, // Expecting all assistants in descending order
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid Usage - after specific ID",
|
||||||
|
reqURL: "/assistants?after=2",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
// Note this is correct because it's put in descending order already
|
||||||
|
expectedResult: Assistants[:3], // Expecting assistants after (excluding) ID 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid Usage - before specific ID",
|
||||||
|
reqURL: "/assistants?before=4",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedResult: Assistants[2:], // Expecting assistants before (excluding) ID 3.
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid Usage - non-integer limit",
|
||||||
|
reqURL: "/assistants?limit=two",
|
||||||
|
expectedStatus: http.StatusBadRequest,
|
||||||
|
expectedStringResult: "Invalid limit query value: two",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid Usage - non-existing id in after",
|
||||||
|
reqURL: "/assistants?after=100",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedResult: []Assistant(nil), // Expecting empty list as there are no IDs above 100
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
request := httptest.NewRequest(http.MethodGet, tt.reqURL, nil)
|
||||||
|
response, err := app.Test(request)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expectedStatus, response.StatusCode)
|
||||||
|
if tt.expectedStatus != fiber.StatusOK {
|
||||||
|
all, _ := ioutil.ReadAll(response.Body)
|
||||||
|
assert.Equal(t, tt.expectedStringResult, string(all))
|
||||||
|
} else {
|
||||||
|
var result []Assistant
|
||||||
|
err = json.NewDecoder(response.Body).Decode(&result)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedResult, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DeleteAssistantEndpoint", func(t *testing.T) {
|
||||||
|
ar := &AssistantRequest{
|
||||||
|
Model: "ggml-gpt4all-j",
|
||||||
|
Name: "3.5-turbo",
|
||||||
|
Description: "Test Assistant",
|
||||||
|
Instructions: "You are computer science teacher answering student questions",
|
||||||
|
Tools: []Tool{{Type: Function}},
|
||||||
|
FileIDs: nil,
|
||||||
|
Metadata: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultAssistant, _, err := createAssistant(app, *ar)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
target := fmt.Sprintf("/assistants/%s", resultAssistant.ID)
|
||||||
|
deleteReq := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||||
|
_, err = app.Test(deleteReq)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, len(Assistants))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetAssistantEndpoint", func(t *testing.T) {
|
||||||
|
ar := &AssistantRequest{
|
||||||
|
Model: "ggml-gpt4all-j",
|
||||||
|
Name: "3.5-turbo",
|
||||||
|
Description: "Test Assistant",
|
||||||
|
Instructions: "You are computer science teacher answering student questions",
|
||||||
|
Tools: []Tool{{Type: Function}},
|
||||||
|
FileIDs: nil,
|
||||||
|
Metadata: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultAssistant, _, err := createAssistant(app, *ar)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
t.Cleanup(cleanupAllAssistants(t, app, []string{resultAssistant.ID}))
|
||||||
|
|
||||||
|
target := fmt.Sprintf("/assistants/%s", resultAssistant.ID)
|
||||||
|
request := httptest.NewRequest(http.MethodGet, target, nil)
|
||||||
|
response, err := app.Test(request)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var getAssistant Assistant
|
||||||
|
err = json.NewDecoder(response.Body).Decode(&getAssistant)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, resultAssistant.ID, getAssistant.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ModifyAssistantEndpoint", func(t *testing.T) {
|
||||||
|
ar := &AssistantRequest{
|
||||||
|
Model: "ggml-gpt4all-j",
|
||||||
|
Name: "3.5-turbo",
|
||||||
|
Description: "Test Assistant",
|
||||||
|
Instructions: "You are computer science teacher answering student questions",
|
||||||
|
Tools: []Tool{{Type: Function}},
|
||||||
|
FileIDs: nil,
|
||||||
|
Metadata: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultAssistant, _, err := createAssistant(app, *ar)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
modifiedAr := &AssistantRequest{
|
||||||
|
Model: "ggml-gpt4all-j",
|
||||||
|
Name: "4.0-turbo",
|
||||||
|
Description: "Modified Test Assistant",
|
||||||
|
Instructions: "You are math teacher answering student questions",
|
||||||
|
Tools: []Tool{{Type: CodeInterpreter}},
|
||||||
|
FileIDs: nil,
|
||||||
|
Metadata: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
modifiedArJson, err := json.Marshal(modifiedAr)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
target := fmt.Sprintf("/assistants/%s", resultAssistant.ID)
|
||||||
|
request := httptest.NewRequest(http.MethodPost, target, strings.NewReader(string(modifiedArJson)))
|
||||||
|
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||||
|
|
||||||
|
modifyResponse, err := app.Test(request)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
var getAssistant Assistant
|
||||||
|
err = json.NewDecoder(modifyResponse.Body).Decode(&getAssistant)
|
||||||
|
|
||||||
|
t.Cleanup(cleanupAllAssistants(t, app, []string{getAssistant.ID}))
|
||||||
|
|
||||||
|
assert.Equal(t, resultAssistant.ID, getAssistant.ID) // IDs should match even if contents change
|
||||||
|
assert.Equal(t, modifiedAr.Tools, getAssistant.Tools)
|
||||||
|
assert.Equal(t, modifiedAr.Name, getAssistant.Name)
|
||||||
|
assert.Equal(t, modifiedAr.Instructions, getAssistant.Instructions)
|
||||||
|
assert.Equal(t, modifiedAr.Description, getAssistant.Description)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CreateAssistantFileEndpoint", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
|
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
afr := AssistantFileRequest{FileID: file.ID}
|
||||||
|
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, assistant.ID, af.AssistantID)
|
||||||
|
})
|
||||||
|
t.Run("ListAssistantFilesEndpoint", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
|
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
afr := AssistantFileRequest{FileID: file.ID}
|
||||||
|
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, assistant.ID, af.AssistantID)
|
||||||
|
})
|
||||||
|
t.Run("GetAssistantFileEndpoint", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
|
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
afr := AssistantFileRequest{FileID: file.ID}
|
||||||
|
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
t.Cleanup(cleanupAssistantFile(t, app, af.ID, af.AssistantID))
|
||||||
|
|
||||||
|
target := fmt.Sprintf("/assistants/%s/files/%s", assistant.ID, file.ID)
|
||||||
|
request := httptest.NewRequest(http.MethodGet, target, nil)
|
||||||
|
response, err := app.Test(request)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var assistantFile AssistantFile
|
||||||
|
err = json.NewDecoder(response.Body).Decode(&assistantFile)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, af.ID, assistantFile.ID)
|
||||||
|
assert.Equal(t, af.AssistantID, assistantFile.AssistantID)
|
||||||
|
})
|
||||||
|
t.Run("DeleteAssistantFileEndpoint", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
|
file, assistant, err := createFileAndAssistant(t, app, appConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
afr := AssistantFileRequest{FileID: file.ID}
|
||||||
|
af, _, err := createAssistantFile(app, afr, assistant.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
cleanupAssistantFile(t, app, af.ID, af.AssistantID)()
|
||||||
|
|
||||||
|
assert.Empty(t, AssistantFiles)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func createFileAndAssistant(t *testing.T, app *fiber.App, o *config.ApplicationConfig) (File, Assistant, error) {
|
||||||
|
ar := &AssistantRequest{
|
||||||
|
Model: "ggml-gpt4all-j",
|
||||||
|
Name: "3.5-turbo",
|
||||||
|
Description: "Test Assistant",
|
||||||
|
Instructions: "You are computer science teacher answering student questions",
|
||||||
|
Tools: []Tool{{Type: Function}},
|
||||||
|
FileIDs: nil,
|
||||||
|
Metadata: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
assistant, _, err := createAssistant(app, *ar)
|
||||||
|
if err != nil {
|
||||||
|
return File{}, Assistant{}, err
|
||||||
|
}
|
||||||
|
t.Cleanup(cleanupAllAssistants(t, app, []string{assistant.ID}))
|
||||||
|
|
||||||
|
file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, o)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_, err := CallFilesDeleteEndpoint(t, app, file.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
return file, assistant, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAssistantFile(app *fiber.App, afr AssistantFileRequest, assistantId string) (AssistantFile, *http.Response, error) {
|
||||||
|
afrJson, err := json.Marshal(afr)
|
||||||
|
if err != nil {
|
||||||
|
return AssistantFile{}, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
target := fmt.Sprintf("/assistants/%s/files", assistantId)
|
||||||
|
request := httptest.NewRequest(http.MethodPost, target, strings.NewReader(string(afrJson)))
|
||||||
|
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||||
|
request.Header.Set("OpenAi-Beta", "assistants=v1")
|
||||||
|
|
||||||
|
resp, err := app.Test(request)
|
||||||
|
if err != nil {
|
||||||
|
return AssistantFile{}, resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var assistantFile AssistantFile
|
||||||
|
all, err := ioutil.ReadAll(resp.Body)
|
||||||
|
err = json.NewDecoder(strings.NewReader(string(all))).Decode(&assistantFile)
|
||||||
|
if err != nil {
|
||||||
|
return AssistantFile{}, resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return assistantFile, resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAssistant(app *fiber.App, ar AssistantRequest) (Assistant, *http.Response, error) {
|
||||||
|
assistant, err := json.Marshal(ar)
|
||||||
|
if err != nil {
|
||||||
|
return Assistant{}, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
request := httptest.NewRequest(http.MethodPost, "/assistants", strings.NewReader(string(assistant)))
|
||||||
|
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||||
|
request.Header.Set("OpenAi-Beta", "assistants=v1")
|
||||||
|
|
||||||
|
resp, err := app.Test(request)
|
||||||
|
if err != nil {
|
||||||
|
return Assistant{}, resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyString, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return Assistant{}, resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var resultAssistant Assistant
|
||||||
|
err = json.NewDecoder(strings.NewReader(string(bodyString))).Decode(&resultAssistant)
|
||||||
|
|
||||||
|
return resultAssistant, resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupAllAssistants(t *testing.T, app *fiber.App, ids []string) func() {
|
||||||
|
return func() {
|
||||||
|
for _, assistant := range ids {
|
||||||
|
target := fmt.Sprintf("/assistants/%s", assistant)
|
||||||
|
deleteReq := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||||
|
_, err := app.Test(deleteReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to delete assistant %s: %v", assistant, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupAssistantFile(t *testing.T, app *fiber.App, fileId, assistantId string) func() {
|
||||||
|
return func() {
|
||||||
|
target := fmt.Sprintf("/assistants/%s/files/%s", assistantId, fileId)
|
||||||
|
request := httptest.NewRequest(http.MethodDelete, target, nil)
|
||||||
|
request.Header.Set(fiber.HeaderContentType, "application/json")
|
||||||
|
request.Header.Set("OpenAi-Beta", "assistants=v1")
|
||||||
|
|
||||||
|
resp, err := app.Test(request)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var dafr DeleteAssistantFileResponse
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&dafr)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, dafr.Deleted)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,23 +1,22 @@
|
||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/core/config"
|
"github.com/go-skynet/LocalAI/core/config"
|
||||||
|
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils"
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var uploadedFiles []File
|
var UploadedFiles []File
|
||||||
|
|
||||||
const uploadedFilesFile = "uploadedFiles.json"
|
const UploadedFilesFile = "uploadedFiles.json"
|
||||||
|
|
||||||
// File represents the structure of a file object from the OpenAI API.
|
// File represents the structure of a file object from the OpenAI API.
|
||||||
type File struct {
|
type File struct {
|
||||||
|
@ -29,38 +28,6 @@ type File struct {
|
||||||
Purpose string `json:"purpose"` // The purpose of the file (e.g., "fine-tune", "classifications", etc.)
|
Purpose string `json:"purpose"` // The purpose of the file (e.g., "fine-tune", "classifications", etc.)
|
||||||
}
|
}
|
||||||
|
|
||||||
func saveUploadConfig(uploadDir string) {
|
|
||||||
file, err := json.MarshalIndent(uploadedFiles, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = os.WriteFile(filepath.Join(uploadDir, uploadedFilesFile), file, 0644)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("Failed to save uploadedFiles to file: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadUploadConfig(uploadPath string) {
|
|
||||||
uploadFilePath := filepath.Join(uploadPath, uploadedFilesFile)
|
|
||||||
|
|
||||||
_, err := os.Stat(uploadFilePath)
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
log.Debug().Msgf("No uploadedFiles file found at %s", uploadFilePath)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
file, err := os.ReadFile(uploadFilePath)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("Failed to read file: %s", err)
|
|
||||||
} else {
|
|
||||||
err = json.Unmarshal(file, &uploadedFiles)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Msgf("Failed to JSON unmarshal the file into uploadedFiles: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
|
// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create
|
||||||
func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
|
@ -95,7 +62,7 @@ func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Appli
|
||||||
}
|
}
|
||||||
|
|
||||||
f := File{
|
f := File{
|
||||||
ID: fmt.Sprintf("file-%d", time.Now().Unix()),
|
ID: fmt.Sprintf("file-%d", getNextFileId()),
|
||||||
Object: "file",
|
Object: "file",
|
||||||
Bytes: int(file.Size),
|
Bytes: int(file.Size),
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
|
@ -103,12 +70,19 @@ func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Appli
|
||||||
Purpose: purpose,
|
Purpose: purpose,
|
||||||
}
|
}
|
||||||
|
|
||||||
uploadedFiles = append(uploadedFiles, f)
|
UploadedFiles = append(UploadedFiles, f)
|
||||||
saveUploadConfig(appConfig.UploadDir)
|
utils.SaveConfig(appConfig.UploadDir, UploadedFilesFile, UploadedFiles)
|
||||||
return c.Status(fiber.StatusOK).JSON(f)
|
return c.Status(fiber.StatusOK).JSON(f)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var currentFileId int64 = 0
|
||||||
|
|
||||||
|
func getNextFileId() int64 {
|
||||||
|
atomic.AddInt64(¤tId, 1)
|
||||||
|
return currentId
|
||||||
|
}
|
||||||
|
|
||||||
// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list
|
// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list
|
||||||
func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||||
type ListFiles struct {
|
type ListFiles struct {
|
||||||
|
@ -121,9 +95,9 @@ func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Applica
|
||||||
|
|
||||||
purpose := c.Query("purpose")
|
purpose := c.Query("purpose")
|
||||||
if purpose == "" {
|
if purpose == "" {
|
||||||
listFiles.Data = uploadedFiles
|
listFiles.Data = UploadedFiles
|
||||||
} else {
|
} else {
|
||||||
for _, f := range uploadedFiles {
|
for _, f := range UploadedFiles {
|
||||||
if purpose == f.Purpose {
|
if purpose == f.Purpose {
|
||||||
listFiles.Data = append(listFiles.Data, f)
|
listFiles.Data = append(listFiles.Data, f)
|
||||||
}
|
}
|
||||||
|
@ -140,7 +114,7 @@ func getFileFromRequest(c *fiber.Ctx) (*File, error) {
|
||||||
return nil, fmt.Errorf("file_id parameter is required")
|
return nil, fmt.Errorf("file_id parameter is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, f := range uploadedFiles {
|
for _, f := range UploadedFiles {
|
||||||
if id == f.ID {
|
if id == f.ID {
|
||||||
return &f, nil
|
return &f, nil
|
||||||
}
|
}
|
||||||
|
@ -184,14 +158,14 @@ func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.Appli
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove upload from list
|
// Remove upload from list
|
||||||
for i, f := range uploadedFiles {
|
for i, f := range UploadedFiles {
|
||||||
if f.ID == file.ID {
|
if f.ID == file.ID {
|
||||||
uploadedFiles = append(uploadedFiles[:i], uploadedFiles[i+1:]...)
|
UploadedFiles = append(UploadedFiles[:i], UploadedFiles[i+1:]...)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
saveUploadConfig(appConfig.UploadDir)
|
utils.SaveConfig(appConfig.UploadDir, UploadedFilesFile, UploadedFiles)
|
||||||
return c.JSON(DeleteStatus{
|
return c.JSON(DeleteStatus{
|
||||||
Id: file.ID,
|
Id: file.ID,
|
||||||
Object: "file",
|
Object: "file",
|
||||||
|
|
|
@ -3,6 +3,7 @@ package openai
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -73,6 +74,7 @@ func TestUploadFileExceedSizeLimit(t *testing.T) {
|
||||||
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
|
app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option))
|
||||||
|
|
||||||
t.Run("UploadFilesEndpoint file size exceeds limit", func(t *testing.T) {
|
t.Run("UploadFilesEndpoint file size exceeds limit", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 11, option)
|
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 11, option)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
@ -80,46 +82,54 @@ func TestUploadFileExceedSizeLimit(t *testing.T) {
|
||||||
assert.Contains(t, bodyToString(resp, t), "exceeds upload limit")
|
assert.Contains(t, bodyToString(resp, t), "exceeds upload limit")
|
||||||
})
|
})
|
||||||
t.Run("UploadFilesEndpoint purpose not defined", func(t *testing.T) {
|
t.Run("UploadFilesEndpoint purpose not defined", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
resp, _ := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "", 5, option)
|
resp, _ := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "", 5, option)
|
||||||
|
|
||||||
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
||||||
assert.Contains(t, bodyToString(resp, t), "Purpose is not defined")
|
assert.Contains(t, bodyToString(resp, t), "Purpose is not defined")
|
||||||
})
|
})
|
||||||
t.Run("UploadFilesEndpoint file already exists", func(t *testing.T) {
|
t.Run("UploadFilesEndpoint file already exists", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
f1 := CallFilesUploadEndpointWithCleanup(t, app, "foo.txt", "file", "fine-tune", 5, option)
|
f1 := CallFilesUploadEndpointWithCleanup(t, app, "foo.txt", "file", "fine-tune", 5, option)
|
||||||
|
|
||||||
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 5, option)
|
resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 5, option)
|
||||||
fmt.Println(f1)
|
fmt.Println(f1)
|
||||||
fmt.Printf("ERror: %v", err)
|
fmt.Printf("ERror: %v\n", err)
|
||||||
|
fmt.Printf("resp: %+v\n", resp)
|
||||||
|
|
||||||
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
|
||||||
assert.Contains(t, bodyToString(resp, t), "File already exists")
|
assert.Contains(t, bodyToString(resp, t), "File already exists")
|
||||||
})
|
})
|
||||||
t.Run("UploadFilesEndpoint file uploaded successfully", func(t *testing.T) {
|
t.Run("UploadFilesEndpoint file uploaded successfully", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
|
file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
|
||||||
|
|
||||||
// Check if file exists in the disk
|
// Check if file exists in the disk
|
||||||
filePath := filepath.Join(option.UploadDir, utils2.SanitizeFileName("test.txt"))
|
testName := strings.Split(t.Name(), "/")[1]
|
||||||
|
fileName := testName + "-test.txt"
|
||||||
|
filePath := filepath.Join(option.UploadDir, utils2.SanitizeFileName(fileName))
|
||||||
_, err := os.Stat(filePath)
|
_, err := os.Stat(filePath)
|
||||||
|
|
||||||
assert.False(t, os.IsNotExist(err))
|
assert.False(t, os.IsNotExist(err))
|
||||||
assert.Equal(t, file.Bytes, 5242880)
|
assert.Equal(t, file.Bytes, 5242880)
|
||||||
assert.NotEmpty(t, file.CreatedAt)
|
assert.NotEmpty(t, file.CreatedAt)
|
||||||
assert.Equal(t, file.Filename, "test.txt")
|
assert.Equal(t, file.Filename, fileName)
|
||||||
assert.Equal(t, file.Purpose, "fine-tune")
|
assert.Equal(t, file.Purpose, "fine-tune")
|
||||||
})
|
})
|
||||||
t.Run("ListFilesEndpoint without purpose parameter", func(t *testing.T) {
|
t.Run("ListFilesEndpoint without purpose parameter", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
resp, err := CallListFilesEndpoint(t, app, "")
|
resp, err := CallListFilesEndpoint(t, app, "")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, 200, resp.StatusCode)
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
listFiles := responseToListFile(t, resp)
|
listFiles := responseToListFile(t, resp)
|
||||||
if len(listFiles.Data) != len(uploadedFiles) {
|
if len(listFiles.Data) != len(UploadedFiles) {
|
||||||
t.Errorf("Expected %v files, got %v files", len(uploadedFiles), len(listFiles.Data))
|
t.Errorf("Expected %v files, got %v files", len(UploadedFiles), len(listFiles.Data))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("ListFilesEndpoint with valid purpose parameter", func(t *testing.T) {
|
t.Run("ListFilesEndpoint with valid purpose parameter", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
_ = CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
|
_ = CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option)
|
||||||
|
|
||||||
resp, err := CallListFilesEndpoint(t, app, "fine-tune")
|
resp, err := CallListFilesEndpoint(t, app, "fine-tune")
|
||||||
|
@ -131,6 +141,7 @@ func TestUploadFileExceedSizeLimit(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("ListFilesEndpoint with invalid query parameter", func(t *testing.T) {
|
t.Run("ListFilesEndpoint with invalid query parameter", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
resp, err := CallListFilesEndpoint(t, app, "not-so-fine-tune")
|
resp, err := CallListFilesEndpoint(t, app, "not-so-fine-tune")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 200, resp.StatusCode)
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
@ -142,6 +153,7 @@ func TestUploadFileExceedSizeLimit(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("GetFilesContentsEndpoint get file content", func(t *testing.T) {
|
t.Run("GetFilesContentsEndpoint get file content", func(t *testing.T) {
|
||||||
|
t.Cleanup(tearDown())
|
||||||
req := httptest.NewRequest("GET", "/files", nil)
|
req := httptest.NewRequest("GET", "/files", nil)
|
||||||
resp, _ := app.Test(req)
|
resp, _ := app.Test(req)
|
||||||
assert.Equal(t, 200, resp.StatusCode)
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
@ -175,8 +187,10 @@ func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*htt
|
||||||
}
|
}
|
||||||
|
|
||||||
func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) (*http.Response, error) {
|
func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) (*http.Response, error) {
|
||||||
|
testName := strings.Split(t.Name(), "/")[1]
|
||||||
|
|
||||||
// Create a file that exceeds the limit
|
// Create a file that exceeds the limit
|
||||||
file := createTestFile(t, fileName, fileSize, appConfig)
|
file := createTestFile(t, testName+"-"+fileName, fileSize, appConfig)
|
||||||
|
|
||||||
// Creating a new HTTP Request
|
// Creating a new HTTP Request
|
||||||
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
||||||
|
@ -188,7 +202,8 @@ func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpos
|
||||||
|
|
||||||
func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) File {
|
func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) File {
|
||||||
// Create a file that exceeds the limit
|
// Create a file that exceeds the limit
|
||||||
file := createTestFile(t, fileName, fileSize, appConfig)
|
testName := strings.Split(t.Name(), "/")[1]
|
||||||
|
file := createTestFile(t, testName+"-"+fileName, fileSize, appConfig)
|
||||||
|
|
||||||
// Creating a new HTTP Request
|
// Creating a new HTTP Request
|
||||||
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
body, writer := newMultipartFile(file.Name(), tag, purpose)
|
||||||
|
@ -199,11 +214,12 @@ func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName,
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
f := responseToFile(t, resp)
|
f := responseToFile(t, resp)
|
||||||
|
|
||||||
id := f.ID
|
//id := f.ID
|
||||||
t.Cleanup(func() {
|
//t.Cleanup(func() {
|
||||||
_, err := CallFilesDeleteEndpoint(t, app, id)
|
// _, err := CallFilesDeleteEndpoint(t, app, id)
|
||||||
assert.NoError(t, err)
|
// assert.NoError(t, err)
|
||||||
})
|
// assert.Empty(t, UploadedFiles)
|
||||||
|
//})
|
||||||
|
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
@ -240,7 +256,8 @@ func createTestFile(t *testing.T, name string, sizeMB int, option *config.Applic
|
||||||
t.Fatalf("Error MKDIR: %v", err)
|
t.Fatalf("Error MKDIR: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
file, _ := os.Create(name)
|
file, err := os.Create(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
file.WriteString(strings.Repeat("a", sizeMB*1024*1024)) // sizeMB MB File
|
file.WriteString(strings.Repeat("a", sizeMB*1024*1024)) // sizeMB MB File
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
|
@ -280,7 +297,7 @@ func responseToListFile(t *testing.T, resp *http.Response) ListFiles {
|
||||||
|
|
||||||
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&listFiles)
|
err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&listFiles)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Failed to decode response: %s", err)
|
log.Error().Msgf("Failed to decode response: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return listFiles
|
return listFiles
|
||||||
|
|
7
main.go
7
main.go
|
@ -149,6 +149,12 @@ func main() {
|
||||||
EnvVars: []string{"UPLOAD_PATH"},
|
EnvVars: []string{"UPLOAD_PATH"},
|
||||||
Value: "/tmp/localai/upload",
|
Value: "/tmp/localai/upload",
|
||||||
},
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "config-path",
|
||||||
|
Usage: "Path to store uploads from files api",
|
||||||
|
EnvVars: []string{"CONFIG_PATH"},
|
||||||
|
Value: "/tmp/localai/config",
|
||||||
|
},
|
||||||
&cli.StringFlag{
|
&cli.StringFlag{
|
||||||
Name: "backend-assets-path",
|
Name: "backend-assets-path",
|
||||||
Usage: "Path used to extract libraries that are required by some of the backends in runtime.",
|
Usage: "Path used to extract libraries that are required by some of the backends in runtime.",
|
||||||
|
@ -241,6 +247,7 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
|
||||||
config.WithImageDir(ctx.String("image-path")),
|
config.WithImageDir(ctx.String("image-path")),
|
||||||
config.WithAudioDir(ctx.String("audio-path")),
|
config.WithAudioDir(ctx.String("audio-path")),
|
||||||
config.WithUploadDir(ctx.String("upload-path")),
|
config.WithUploadDir(ctx.String("upload-path")),
|
||||||
|
config.WithConfigsDir(ctx.String("config-path")),
|
||||||
config.WithF16(ctx.Bool("f16")),
|
config.WithF16(ctx.Bool("f16")),
|
||||||
config.WithStringGalleries(ctx.String("galleries")),
|
config.WithStringGalleries(ctx.String("galleries")),
|
||||||
config.WithModelLibraryURL(ctx.String("remote-library")),
|
config.WithModelLibraryURL(ctx.String("remote-library")),
|
||||||
|
|
41
pkg/utils/config.go
Normal file
41
pkg/utils/config.go
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SaveConfig(filePath, fileName string, obj any) {
|
||||||
|
file, err := json.MarshalIndent(obj, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
absolutePath := filepath.Join(filePath, fileName)
|
||||||
|
err = os.WriteFile(absolutePath, file, 0644)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("Failed to save configuration file to %s: %s", absolutePath, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadConfig(filePath, fileName string, obj interface{}) {
|
||||||
|
uploadFilePath := filepath.Join(filePath, fileName)
|
||||||
|
|
||||||
|
_, err := os.Stat(uploadFilePath)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
log.Debug().Msgf("No configuration file found at %s", uploadFilePath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.ReadFile(uploadFilePath)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("Failed to read file: %s", err)
|
||||||
|
} else {
|
||||||
|
err = json.Unmarshal(file, &obj)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Msgf("Failed to JSON unmarshal the file %s: %v", uploadFilePath, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue