From 1ef6ba2b5285e401535af03007213d074f42bf8f Mon Sep 17 00:00:00 2001
From: Dave Lee
Date: Thu, 1 Jun 2023 23:43:34 -0400
Subject: [PATCH] plumbing for cli context and loader, stash before picking up initializers

---
 apiv2/config.go                           |  5 +-
 apiv2/localai.go                          | 64 +++++++++++++++++++--
 apiv2/localai_nethttp.go                  | 11 +++-
 config/gpt-3.5-turbo-completion.yaml      | 11 ++++
 main.go                                   | 69 +++++++++++++++--------
 openai-openapi/localai_model_patches.yaml | 13 ++++-
 6 files changed, 139 insertions(+), 34 deletions(-)
 create mode 100644 config/gpt-3.5-turbo-completion.yaml

diff --git a/apiv2/config.go b/apiv2/config.go
index 7b024193..ff6aa026 100644
--- a/apiv2/config.go
+++ b/apiv2/config.go
@@ -59,6 +59,10 @@ func (sc SpecificConfig[RequestModel]) GetRequestDefaults() interface{} {
 	return sc.RequestDefaults
 }
 
+func (sc SpecificConfig[RequestModel]) GetRequest() RequestModel {
+	return sc.RequestDefaults
+}
+
 func (sc SpecificConfig[RequestModel]) GetLocalPaths() ConfigLocalPaths {
 	return sc.LocalPaths
 }
@@ -90,7 +94,6 @@ func (cm *ConfigManager) loadConfigFile(path string) (*Config, error) {
 		return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
 	}
 	fmt.Printf("RAW STUB: %+v\n", stub)
-	// fmt.Printf("DUMB SHIT: %+v\n%T\n", EndpointToRequestBodyMap[rawConfig.Registration.Endpoint], EndpointToRequestBodyMap[rawConfig.Registration.Endpoint])
 
 	endpoint := stub.Registration.Endpoint
 
diff --git a/apiv2/localai.go b/apiv2/localai.go
index 1d9d5feb..4bc945a8 100644
--- a/apiv2/localai.go
+++ b/apiv2/localai.go
@@ -5,14 +5,16 @@ import (
 	"fmt"
 	"strings"
 
+	model "github.com/go-skynet/LocalAI/pkg/model"
 	"github.com/mitchellh/mapstructure"
 )
 
 type LocalAIServer struct {
 	configManager *ConfigManager
+	loader        *model.ModelLoader
 }
 
-func combineRequestAndConfig[RequestType any](configManager *ConfigManager, model string, requestFromInput *RequestType) (*RequestType, error) {
+func combineRequestAndConfig[RequestType any](configManager *ConfigManager, model string, requestFromInput *RequestType) (*SpecificConfig[RequestType], error) {
 
 	splitFnName := strings.Split(printCurrentFunctionName(2), ".")
 
@@ -52,7 +54,17 @@ func combineRequestAndConfig[RequestType any](configManager *ConfigManager, mode
 
 	fmt.Printf("AFTER rD: %T\n%+v\n\n", request, request)
 
-	return &request, nil
+	return &SpecificConfig[RequestType]{
+		ConfigStub: ConfigStub{
+			Registration: config.GetRegistration(),
+			LocalPaths:   config.GetLocalPaths(),
+		},
+		RequestDefaults: request,
+	}, nil
+}
+
+func (las *LocalAIServer) loadModel(configStub ConfigStub) {
+	// TODO: resolve configStub.LocalPaths and load the model via las.loader.
 }
 
 // CancelFineTune implements StrictServerInterface
@@ -63,13 +75,15 @@ func (*LocalAIServer) CancelFineTune(ctx context.Context, request CancelFineTune
 
 // CreateChatCompletion implements StrictServerInterface
 func (las *LocalAIServer) CreateChatCompletion(ctx context.Context, request CreateChatCompletionRequestObject) (CreateChatCompletionResponseObject, error) {
-	chatRequest, err := combineRequestAndConfig(las.configManager, request.Body.Model, request.Body)
+	chatRequestConfig, err := combineRequestAndConfig(las.configManager, request.Body.Model, request.Body)
 
 	if err != nil {
 		fmt.Printf("CreateChatCompletion ERROR combining config and input!\n%s\n", err.Error())
 		return nil, err
 	}
 
+	chatRequest := chatRequestConfig.RequestDefaults
+
 	fmt.Printf("\n===CreateChatCompletion===\n%+v\n", chatRequest)
 	fmt.Printf("\n\n!! TYPED CreateChatCompletion !!\ntemperature %f\n top_p %f \n %d\n", *chatRequest.Temperature, *chatRequest.TopP, *chatRequest.XLocalaiExtensions.TopK)
 
@@ -85,8 +99,48 @@ func (las *LocalAIServer) CreateChatCompletion(ctx context.Context, request Crea
 }
 
 // CreateCompletion implements StrictServerInterface
-func (*LocalAIServer) CreateCompletion(ctx context.Context, request CreateCompletionRequestObject) (CreateCompletionResponseObject, error) {
-	panic("unimplemented")
+func (las *LocalAIServer) CreateCompletion(ctx context.Context, request CreateCompletionRequestObject) (CreateCompletionResponseObject, error) {
+
+	modelName := request.Body.Model
+
+	config, err := combineRequestAndConfig(las.configManager, modelName, request.Body)
+
+	if err != nil {
+		fmt.Printf("CreateCompletion ERROR combining config and input!\n%s\n", err.Error())
+		return nil, err
+	}
+
+	req := config.GetRequest()
+
+	fmt.Printf("\n===CreateCompletion===\n%+v\n", req)
+
+	var choices []CreateCompletionResponseChoice
+
+	prompts, err := req.Prompt.AsCreateCompletionRequestPrompt1()
+
+	if err != nil {
+		tokenPrompt, err := req.Prompt.AsCreateCompletionRequestPrompt2()
+		if err == nil {
+			fmt.Printf("Token array prompt of length %d is not yet handled\n", len(tokenPrompt))
+			panic("token array prompts are not implemented yet")
+		}
+		singlePrompt, err := req.Prompt.AsCreateCompletionRequestPrompt0()
+		if err != nil {
+			return nil, err
+		}
+		prompts = []string{singlePrompt}
+	}
+
+	// model := las.loader.LoadModel(modelName, )
+
+	for _, v := range prompts {
+		fmt.Printf("[prompt] %s\n", v)
+	}
+
+	return CreateCompletion200JSONResponse{
+		Model:   modelName,
+		Choices: choices,
+	}, nil
 }
 
 // CreateEdit implements StrictServerInterface
diff --git a/apiv2/localai_nethttp.go b/apiv2/localai_nethttp.go
index d6e8c037..ca571f50 100644
--- a/apiv2/localai_nethttp.go
+++ b/apiv2/localai_nethttp.go
@@ -1,16 +1,21 @@
 package apiv2
 
-import "net/http"
+import (
+	"net/http"
 
-func NewLocalAINetHTTPServer(configManager *ConfigManager) *LocalAIServer {
+	"github.com/go-skynet/LocalAI/pkg/model"
+)
+
+func NewLocalAINetHTTPServer(configManager *ConfigManager, loader *model.ModelLoader, address string) *LocalAIServer {
 	localAI := LocalAIServer{
 		configManager: configManager,
+		loader:        loader,
 	}
 
 	var middlewares []StrictMiddlewareFunc
 
 	http.Handle("/", Handler(NewStrictHandler(&localAI, middlewares)))
 
-	http.ListenAndServe(":8085", nil)
+	go http.ListenAndServe(address, nil)
 
 	return &localAI
 }
diff --git a/config/gpt-3.5-turbo-completion.yaml b/config/gpt-3.5-turbo-completion.yaml
new file mode 100644
index 00000000..f890d09e
--- /dev/null
+++ b/config/gpt-3.5-turbo-completion.yaml
@@ -0,0 +1,11 @@
+registration:
+  model: gpt-3.5-turbo
+  endpoint: CreateCompletion
+local_paths:
+  model: ggml-gpt4all-j
+  template: chat-gpt4all
+request_defaults:
+  top_p: 0.7
+  temperature: 0.2
+  x-localai-extensions:
+    top_k: 80
\ No newline at end of file
diff --git a/main.go b/main.go
index 3483d175..037e496f 100644
--- a/main.go
+++ b/main.go
@@ -25,29 +25,6 @@ func main() {
 
 	log.Log().Msgf("STARTING!")
 
-	// TODO REALLY SHIT TEST MUST BE FIXED BEFORE MERGE
-	v2ConfigManager := apiv2.NewConfigManager()
-	log.Log().Msgf("v2ConfigManager init %+v", v2ConfigManager)
-	registered, cfgErr := v2ConfigManager.LoadConfigDirectory("/workspace/config")
-
-	fmt.Printf("!=!")
-
-	log.Log().Msgf("NEW v2 test cfgErr: %w \nREGISTRATIONS:", cfgErr)
-
-	for i, reg := range registered {
-		log.Log().Msgf("%d: %+v", i, reg)
-
-		testField, exists := v2ConfigManager.GetConfig(reg)
-		if exists {
-			log.Log().Msgf("!! %s: %s", testField.GetRegistration().Endpoint, testField.GetLocalPaths().Model)
%s: %s", testField.GetRegistration().Endpoint, testField.GetLocalPaths().Model) - } - - } - - v2Server := apiv2.NewLocalAINetHTTPServer(v2ConfigManager) - - log.Log().Msgf("NEW v2 test: %+v", v2Server) - app := &cli.App{ Name: "LocalAI", Usage: "OpenAI compatible API for running LLaMA/GPT models locally on CPU with consumer grade hardware.", @@ -72,6 +49,18 @@ func main() { EnvVars: []string{"MODELS_PATH"}, Value: filepath.Join(path, "models"), }, + &cli.StringFlag{ + Name: "template-path", + DefaultText: "Path containing templates used for inferencing", + EnvVars: []string{"TEMPLATES_PATH"}, + Value: filepath.Join(path, "templates"), + }, + &cli.StringFlag{ + Name: "config-path", + DefaultText: "Path containing model/endpoint configurations", + EnvVars: []string{"CONFIG_PATH"}, + Value: filepath.Join(path, "config"), + }, &cli.StringFlag{ Name: "config-file", DefaultText: "Config file", @@ -83,6 +72,12 @@ func main() { EnvVars: []string{"ADDRESS"}, Value: ":8080", }, + &cli.StringFlag{ + Name: "addressv2", + DefaultText: "Bind address for the API server (DEBUG v2 TEST)", + EnvVars: []string{"ADDRESS_V2"}, + Value: ":8085", + }, &cli.StringFlag{ Name: "image-path", DefaultText: "Image directory", @@ -120,7 +115,33 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings. Copyright: "go-skynet authors", Action: func(ctx *cli.Context) error { fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path")) - return api.App(context.Background(), ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path"), ctx.String("templates-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false, ctx.String("image-path")).Listen(ctx.String("address")) + + loader := model.NewModelLoader(ctx.String("models-path"), ctx.String("templates-path")) + + if av2 := ctx.String("addressv2"); av2 != "" { + + v2ConfigManager := apiv2.NewConfigManager() + registered, cfgErr := v2ConfigManager.LoadConfigDirectory(ctx.String("config-path")) + + if cfgErr != nil { + panic("failed to load config directory todo better handler here") + } + + for i, reg := range registered { + log.Log().Msgf("%d: %+v", i, reg) + + testField, exists := v2ConfigManager.GetConfig(reg) + if exists { + log.Log().Msgf("!! %s: %s", testField.GetRegistration().Endpoint, testField.GetLocalPaths().Model) + } + + } + + v2Server := apiv2.NewLocalAINetHTTPServer(v2ConfigManager, loader, ctx.String("addressv2")) + + log.Log().Msgf("NEW v2 test: %+v", v2Server) + } + return api.App(context.Background(), ctx.String("config-file"), loader, ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false, ctx.String("image-path")).Listen(ctx.String("address")) }, } diff --git a/openai-openapi/localai_model_patches.yaml b/openai-openapi/localai_model_patches.yaml index 72046790..bb5defd9 100644 --- a/openai-openapi/localai_model_patches.yaml +++ b/openai-openapi/localai_model_patches.yaml @@ -77,4 +77,15 @@ components: #@overlay/match missing_ok=True x-localai-extensions: $ref: "#/components/schemas/LocalAIImageRequestExtension" - \ No newline at end of file + CreateChatCompletionResponse: + properties: + choices: + items: + #@overlay/match missing_ok=True + x-go-type-name: "CreateChatCompletionResponseChoice" + CreateCompletionResponse: + properties: + choices: + items: + #@overlay/match missing_ok=True + x-go-type-name: "CreateCompletionResponseChoice"