This commit is contained in:
Ettore Di Giacinto 2023-11-08 20:05:54 +01:00
parent f0e265a96d
commit 78ef045bb3
21 changed files with 485 additions and 336 deletions

View file

@ -2,8 +2,11 @@ package openai
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"strings"
@ -61,6 +64,37 @@ func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *sche
return modelFile, input, nil
}
// this function check if the string is an URL, if it's an URL downloads the image in memory
// encodes it in base64 and returns the base64 string
func getBase64Image(s string) (string, error) {
if strings.HasPrefix(s, "http") {
// download the image
resp, err := http.Get(s)
if err != nil {
return "", err
}
defer resp.Body.Close()
// read the image data into memory
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
// encode the image data in base64
encoded := base64.StdEncoding.EncodeToString(data)
// return the base64 string
return encoded, nil
}
// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
}
return "", fmt.Errorf("not valid string")
}
func updateConfig(config *config.Config, input *schema.OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo
@ -129,6 +163,35 @@ func updateConfig(config *config.Config, input *schema.OpenAIRequest) {
}
}
// Decode each request's message content
index := 0
for _, m := range input.Messages {
switch content := m.Content.(type) {
case string:
m.StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
for _, pp := range c {
if pp.Type == "text" {
m.StringContent = pp.Text
} else if pp.Type == "image_url" {
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
base64, err := getBase64Image(pp.ImageURL)
if err == nil {
m.StringImages = append(m.StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
m.StringContent = m.StringContent + fmt.Sprintf("[img-%d]", index)
index++
} else {
fmt.Print("Failed encoding image", err)
}
}
}
}
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}