mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-22 11:35:00 +00:00
wip
This commit is contained in:
parent
f0e265a96d
commit
78ef045bb3
21 changed files with 485 additions and 336 deletions
|
@ -81,6 +81,10 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
|
||||
}
|
||||
|
||||
if input.ResponseFormat == "json_object" {
|
||||
input.Grammar = grammar.JSONBNF
|
||||
}
|
||||
|
||||
// process functions if we have any defined or if we have a function call string
|
||||
if len(input.Functions) > 0 && config.ShouldUseFunctions() {
|
||||
log.Debug().Msgf("Response needs to process functions")
|
||||
|
@ -140,14 +144,14 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||
}
|
||||
}
|
||||
r := config.Roles[role]
|
||||
contentExists := i.Content != nil && *i.Content != ""
|
||||
contentExists := i.Content != nil && i.StringContent != ""
|
||||
// First attempt to populate content via a chat message specific template
|
||||
if config.TemplateConfig.ChatMessage != "" {
|
||||
chatMessageData := model.ChatMessageTemplateData{
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Role: r,
|
||||
RoleName: role,
|
||||
Content: *i.Content,
|
||||
Content: i.StringContent,
|
||||
MessageIndex: messageIndex,
|
||||
}
|
||||
templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
|
||||
|
@ -166,7 +170,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||
if content == "" {
|
||||
if r != "" {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(r, " ", *i.Content)
|
||||
content = fmt.Sprint(r, " ", i.StringContent)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
j, err := json.Marshal(i.FunctionCall)
|
||||
|
@ -180,7 +184,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||
}
|
||||
} else {
|
||||
if contentExists {
|
||||
content = fmt.Sprint(*i.Content)
|
||||
content = fmt.Sprint(i.StringContent)
|
||||
}
|
||||
if i.FunctionCall != nil {
|
||||
j, err := json.Marshal(i.FunctionCall)
|
||||
|
@ -334,7 +338,11 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
|
|||
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
||||
// Note: This costs (in terms of CPU) another computation
|
||||
config.Grammar = ""
|
||||
predFunc, err := backend.ModelInference(input.Context, predInput, o.Loader, *config, o, nil)
|
||||
images := []string{}
|
||||
for _, m := range input.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
}
|
||||
predFunc, err := backend.ModelInference(input.Context, predInput, images, o.Loader, *config, o, nil)
|
||||
if err != nil {
|
||||
log.Error().Msgf("inference error: %s", err.Error())
|
||||
return
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
config "github.com/go-skynet/LocalAI/api/config"
|
||||
"github.com/go-skynet/LocalAI/api/options"
|
||||
"github.com/go-skynet/LocalAI/api/schema"
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar"
|
||||
model "github.com/go-skynet/LocalAI/pkg/model"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
|
@ -64,6 +65,10 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
|
|||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||
}
|
||||
|
||||
if input.ResponseFormat == "json_object" {
|
||||
input.Grammar = grammar.JSONBNF
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config)
|
||||
|
||||
if input.Stream {
|
||||
|
|
|
@ -23,8 +23,13 @@ func ComputeChoices(
|
|||
n = 1
|
||||
}
|
||||
|
||||
images := []string{}
|
||||
for _, m := range req.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
}
|
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := backend.ModelInference(req.Context, predInput, loader, *config, o, tokenCallback)
|
||||
predFunc, err := backend.ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback)
|
||||
if err != nil {
|
||||
return result, backend.TokenUsage{}, err
|
||||
}
|
||||
|
|
|
@ -2,8 +2,11 @@ package openai
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
@ -61,6 +64,37 @@ func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *sche
|
|||
return modelFile, input, nil
|
||||
}
|
||||
|
||||
// getBase64Image normalises an image reference into a raw base64 payload.
//
// Two input forms are accepted:
//   - an http:// or https:// URL: the image is downloaded into memory and its
//     bytes are encoded with the standard base64 alphabet;
//   - a base64 data URI ("data:<media type>;base64,<payload>"): the header is
//     dropped and the payload is returned as-is, whatever the media type.
//
// Any other input, a failed download, or a non-2xx HTTP response yields an
// error and an empty string.
func getBase64Image(s string) (string, error) {
	if strings.HasPrefix(s, "http://") || strings.HasPrefix(s, "https://") {
		// NOTE(review): s appears to originate from the request payload, so
		// this is a server-side fetch of a caller-chosen URL using the default
		// client (no timeout, no size cap) — confirm whether that is acceptable.
		resp, err := http.Get(s)
		if err != nil {
			return "", err
		}
		defer resp.Body.Close()

		// Without this check an error page would be encoded and handed to the
		// model as if it were image data.
		if resp.StatusCode < 200 || resp.StatusCode > 299 {
			return "", fmt.Errorf("downloading image: unexpected status %q", resp.Status)
		}

		// Read the image data into memory.
		data, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			return "", err
		}

		return base64.StdEncoding.EncodeToString(data), nil
	}

	// Data URI: keep only what follows the ";base64," marker so that every
	// media type (jpeg, png, webp, ...) is handled, not just image/jpeg.
	const marker = ";base64,"
	if strings.HasPrefix(s, "data:") {
		if i := strings.Index(s, marker); i >= 0 {
			return s[i+len(marker):], nil
		}
	}

	return "", fmt.Errorf("not valid string")
}
|
||||
|
||||
func updateConfig(config *config.Config, input *schema.OpenAIRequest) {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
|
@ -129,6 +163,35 @@ func updateConfig(config *config.Config, input *schema.OpenAIRequest) {
|
|||
}
|
||||
}
|
||||
|
||||
// Decode each request's message content
|
||||
index := 0
|
||||
for _, m := range input.Messages {
|
||||
switch content := m.Content.(type) {
|
||||
case string:
|
||||
m.StringContent = content
|
||||
case []interface{}:
|
||||
dat, _ := json.Marshal(content)
|
||||
c := []schema.Content{}
|
||||
json.Unmarshal(dat, &c)
|
||||
for _, pp := range c {
|
||||
if pp.Type == "text" {
|
||||
m.StringContent = pp.Text
|
||||
} else if pp.Type == "image_url" {
|
||||
// Resolve pp.ImageURL to a base64 string: download and encode it if it is a URL, or strip the header if it is a data URI:
|
||||
base64, err := getBase64Image(pp.ImageURL)
|
||||
if err == nil {
|
||||
m.StringImages = append(m.StringImages, base64) // TODO: make sure that we only return base64 stuff
|
||||
// set a placeholder for each image
|
||||
m.StringContent = m.StringContent + fmt.Sprintf("[img-%d]", index)
|
||||
index++
|
||||
} else {
|
||||
fmt.Print("Failed encoding image", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if input.RepeatPenalty != 0 {
|
||||
config.RepeatPenalty = input.RepeatPenalty
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue