package openai

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"time"

	"github.com/go-skynet/LocalAI/core/backend"
	"github.com/go-skynet/LocalAI/core/services"
	"github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/schema"
	"github.com/gofiber/fiber/v2"
	"github.com/google/uuid"
	"github.com/rs/zerolog/log"
	"github.com/valyala/fasthttp"
)

// CompletionEndpoint returns a fiber handler implementing the OpenAI
// completions API, supporting both streaming (SSE) and non-streaming modes.
// https://platform.openai.com/docs/api-reference/completions
func CompletionEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *schema.StartupOptions) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		// Generated per request. Previously these were computed once at
		// endpoint-registration time (outside the closure), so every
		// response served by this endpoint shared a single ID and a single
		// creation timestamp.
		id := uuid.New().String()
		created := int(time.Now().Unix())

		modelName, input, err := readInput(c, so, ml, true)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request: %w", err)
		}

		log.Debug().Msgf("`input`: %+v", input)

		if input.Stream {
			log.Debug().Msgf("Stream request received")
			c.Context().SetContentType("text/event-stream")
			//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
			//c.Set("Content-Type", "text/event-stream")
			c.Set("Cache-Control", "no-cache")
			c.Set("Connection", "keep-alive")
			c.Set("Transfer-Encoding", "chunked")

			responses, err := backend.StreamingCompletionGenerationOpenAIRequest(modelName, input, cl, ml, so)
			if err != nil {
				return fmt.Errorf("failed establishing streaming completion request: %w", err)
			}

			c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
				// Relay each generated chunk as an SSE "data:" event.
				for ev := range responses {
					var buf bytes.Buffer
					enc := json.NewEncoder(&buf)
					if err := enc.Encode(ev); err != nil {
						// Don't kill the stream over one bad chunk.
						log.Error().Msgf("failed encoding streamed chunk: %s", err)
						continue
					}
					log.Debug().Msgf("Sending chunk: %s", buf.String())
					// json.Encoder.Encode appends '\n', so the extra "\n"
					// here produces the blank line SSE requires between events.
					fmt.Fprintf(w, "data: %v\n", buf.String())
					w.Flush()
				}

				// Final event carrying the stop reason, then the [DONE]
				// sentinel expected by OpenAI streaming clients.
				resp := &schema.OpenAIResponse{
					ID:      id,
					Created: created,
					Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
					Choices: []schema.Choice{
						{
							Index:        0,
							FinishReason: "stop",
						},
					},
					Object: "text_completion",
				}
				respData, err := json.Marshal(resp)
				if err != nil {
					log.Error().Msgf("failed marshalling final stream response: %s", err)
					return
				}
				w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
				w.WriteString("data: [DONE]\n\n")
				w.Flush()
			}))
			return nil
		}

		// Non-streaming path: generate the full completion in one shot.
		resp, err := backend.CompletionGenerationOpenAIRequest(modelName, input, cl, ml, so)
		if err != nil {
			return fmt.Errorf("error generating completion request: %w", err)
		}

		// Marshal only for debug logging; c.JSON re-serializes below.
		jsonResult, _ := json.Marshal(resp)
		log.Debug().Msgf("Response: %s", jsonResult)

		// Return the prediction in the response body
		return c.JSON(resp)
	}
}