feat: add embeddings for go-llama.cpp backend (#190)

This commit is contained in:
Ettore Di Giacinto 2023-05-05 11:20:06 +02:00 committed by GitHub
parent 714bfcd45b
commit c839b334eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 253 additions and 154 deletions

View file

@ -33,13 +33,21 @@ type OpenAIUsage struct {
TotalTokens int `json:"total_tokens"`
}
// Item is the element type of OpenAIResponse.Data. The embeddings endpoint
// fills it with one embedding vector per entry.
type Item struct {
// Embedding is the vector, serialized under the "embedding" key.
Embedding []float32 `json:"embedding"`
// Index is the position of this entry in the list; it is always serialized,
// including when it is 0.
Index int `json:"index"`
// Object is the entry's type tag (the embeddings endpoint sets "embedding");
// it is omitted from the JSON when empty.
Object string `json:"object,omitempty"`
}
// OpenAIResponse is the JSON envelope returned by the OpenAI-compatible
// endpoints. Every field except Usage is omitted from the output when it holds
// its zero value, so each endpoint only fills in the parts it needs.
type OpenAIResponse struct {
	Created int    `json:"created,omitempty"`
	Object  string `json:"object,omitempty"`
	ID      string `json:"id,omitempty"`
	Model   string `json:"model,omitempty"`

	// Choices carries the generated results, serialized under "choices".
	Choices []Choice `json:"choices,omitempty"`
	// Data carries list-style results such as embeddings, serialized under "data".
	Data []Item `json:"data,omitempty"`

	// Usage is always serialized, even when empty.
	Usage OpenAIUsage `json:"usage"`
}
type Choice struct {
@ -298,6 +306,40 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
}
}
// embeddingsEndpoint builds the fiber handler that serves embedding requests.
// https://platform.openai.com/docs/api-reference/embeddings
//
// The handler reads the request and its model configuration via readConfig,
// asks ModelEmbedding for the function that computes the embedding of the
// request input, runs it, and replies with an OpenAIResponse whose Object is
// "list" and whose Data holds a single "embedding" item at index 0.
// Any error from those steps is returned to fiber unchanged, except the
// readConfig error, which is wrapped with context.
func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		config, input, err := readConfig(cm, c, loader, debug, threads, ctx, f16)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		log.Debug().Msgf("Parameter Config: %+v", config)

		// get the model function to call for the result
		embedFn, err := ModelEmbedding(input.Input, loader, *config)
		if err != nil {
			return err
		}

		embeddings, err := embedFn()
		if err != nil {
			return err
		}

		resp := &OpenAIResponse{
			Model:  input.Model, // we have to return what the user sent here, due to OpenAI spec.
			Data:   []Item{{Embedding: embeddings, Index: 0, Object: "embedding"}},
			Object: "list",
		}

		// This serialized copy exists only for the debug log. A marshal failure
		// (encoding/json rejects NaN/Inf floats, for instance) is logged rather
		// than discarded; it does not abort the request here because c.JSON
		// below performs the serialization that is actually sent.
		jsonResult, marshalErr := json.Marshal(resp)
		if marshalErr != nil {
			log.Debug().Msgf("Response could not be marshalled for logging: %v", marshalErr)
		} else {
			log.Debug().Msgf("Response: %s", jsonResult)
		}

		// Return the embedding in the response body
		return c.JSON(resp)
	}
}
func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {