From f9133b5a61123aa7159a41cf373df3fde6030041 Mon Sep 17 00:00:00 2001 From: Dave Lee Date: Thu, 1 Jun 2023 20:44:54 -0400 Subject: [PATCH] big progress checkin. Still quite broken, but now it shows the new direction. Time to start hooking things up again. --- apiv2/config.go | 23 ++ apiv2/localai.go | 232 +++++++----------- apiv2/localai.go.old | 196 --------------- apiv2/localai_nethttp.go | 4 +- apiv2/util.go | 66 +++++ go.mod | 4 +- go.sum | 8 +- openai-openapi/chi-interface.tmpl | 17 ++ openai-openapi/config.yaml | 2 + openai-openapi/endpoint-body-mapping.tmpl | 25 +- openai-openapi/localai_model_patches.yaml | 3 + .../remove_depreciated_openapi.yaml | 12 - 12 files changed, 228 insertions(+), 364 deletions(-) delete mode 100644 apiv2/localai.go.old create mode 100644 apiv2/util.go create mode 100644 openai-openapi/chi-interface.tmpl delete mode 100644 openai-openapi/remove_depreciated_openapi.yaml diff --git a/apiv2/config.go b/apiv2/config.go index fe437015..7b024193 100644 --- a/apiv2/config.go +++ b/apiv2/config.go @@ -43,6 +43,18 @@ type Config interface { GetRegistration() ConfigRegistration } +func (cs ConfigStub) GetRequestDefaults() interface{} { + return nil +} + +func (cs ConfigStub) GetLocalPaths() ConfigLocalPaths { + return cs.LocalPaths +} + +func (cs ConfigStub) GetRegistration() ConfigRegistration { + return cs.Registration +} + func (sc SpecificConfig[RequestModel]) GetRequestDefaults() interface{} { return sc.RequestDefaults } @@ -158,6 +170,17 @@ func (cm *ConfigManager) GetConfig(r ConfigRegistration) (Config, bool) { return v, exists } +// This is a convience function for endpoint functions to use. +// The advantage is it avoids errors in the endpoint string +// Not a clue what the performance cost of this is. +func (cm *ConfigManager) GetConfigForThisEndpoint(m string) (Config, bool) { + endpoint := printCurrentFunctionName(2) + return cm.GetConfig(ConfigRegistration{ + Model: m, + Endpoint: endpoint, + }) +} + func (cm *ConfigManager) listConfigs() []ConfigRegistration { var res []ConfigRegistration for k := range cm.configs { diff --git a/apiv2/localai.go b/apiv2/localai.go index 8d237a67..1d9d5feb 100644 --- a/apiv2/localai.go +++ b/apiv2/localai.go @@ -1,11 +1,8 @@ package apiv2 import ( - "encoding/json" + "context" "fmt" - "io" - "net/http" - "runtime" "strings" "github.com/mitchellh/mapstructure" @@ -15,222 +12,183 @@ type LocalAIServer struct { configManager *ConfigManager } -type Error struct { - Code int `json:"code"` - Message string `json:"message"` -} - -type ModelOnlyRequest struct { - Model string `json:"model" yaml:"model"` -} - -// This function grabs the name of the function that calls it, skipping up the callstack `skip` levels. -// This is probably a go war crime, but NJ method and all. It's an awesome way to index EndpointConfigMap -func printCurrentFunctionName(skip int) string { - pc, _, _, _ := runtime.Caller(skip) - funcName := runtime.FuncForPC(pc).Name() - fmt.Println("Current function:", funcName) - return funcName -} - -func sendError(w http.ResponseWriter, code int, message string) { - localAiError := Error{ - Code: code, - Message: message, - } - w.WriteHeader(code) - json.NewEncoder(w).Encode(localAiError) -} - -// TODO: Is it a good idea to return "" in cases where the model isn't provided? -// Or is that actually an error condition? -// NO is a decent guess as any to start with? -// r *http.Request -func (server *LocalAIServer) getRequestModelName(body []byte) string { - var modelOnlyRequest = ModelOnlyRequest{} - if err := json.Unmarshal(body, &modelOnlyRequest); err != nil { - fmt.Printf("ERR in getRequestModelName, %+v", err) - return "" - } - return modelOnlyRequest.Model -} - -func (server *LocalAIServer) combineRequestAndConfig(endpointName string, body []byte) (interface{}, error) { - model := server.getRequestModelName(body) - - lookup := ConfigRegistration{Model: model, Endpoint: endpointName} - - config, exists := server.configManager.GetConfig(lookup) - - if !exists { - return nil, fmt.Errorf("Config not found for %+v", lookup) - } - - // fmt.Printf("Model: %s\nConfig: %+v\n", model, config) - - request := config.GetRequestDefaults() - // fmt.Printf("BEFORE rD: %T\n%+v\n\n", request, request) - tmpUnmarshal := map[string]interface{}{} - if err := json.Unmarshal(body, &tmpUnmarshal); err != nil { - return nil, fmt.Errorf("error unmarshalling json to temp map\n%w", err) - } - // fmt.Printf("$$$ tmpUnmarshal: %+v\n", tmpUnmarshal) - mapstructure.Decode(tmpUnmarshal, &request) - fmt.Printf("AFTER rD: %T\n%+v\n\n", request, request) - return request, nil -} - -func (server *LocalAIServer) getRequest(w http.ResponseWriter, r *http.Request) (interface{}, error) { - body, err := io.ReadAll(r.Body) - if err != nil { - sendError(w, http.StatusBadRequest, "Failed to read body") - } +func combineRequestAndConfig[RequestType any](configManager *ConfigManager, model string, requestFromInput *RequestType) (*RequestType, error) { splitFnName := strings.Split(printCurrentFunctionName(2), ".") endpointName := splitFnName[len(splitFnName)-1] - return server.combineRequestAndConfig(endpointName, body) + lookup := ConfigRegistration{Model: model, Endpoint: endpointName} + + config, exists := configManager.GetConfig(lookup) + + if !exists { + return nil, fmt.Errorf("Config not found for %+v", lookup) + } + + // fmt.Printf("Model: %s\nConfig: %+v\nrequestFromInput: %+v\n", model, config, requestFromInput) + + request, ok := config.GetRequestDefaults().(RequestType) + + if !ok { + return nil, fmt.Errorf("Config failed casting for %+v", lookup) + } + + // configMergingConfig := GetConfigMergingDecoderConfig(&request) + // configMergingDecoder, err := mapstructure.NewDecoder(&configMergingConfig) + + // if err != nil { + // return nil, err + // } + + // configMergingDecoder.Decode(requestFromInput) + + // TODO try decoding hooks again later. For testing, do a stupid copy + decodeErr := mapstructure.Decode(structToStrippedMap(*requestFromInput), &request) + + if decodeErr != nil { + return nil, decodeErr + } + + fmt.Printf("AFTER rD: %T\n%+v\n\n", request, request) + + return &request, nil } -// CancelFineTune implements ServerInterface -func (*LocalAIServer) CancelFineTune(w http.ResponseWriter, r *http.Request, fineTuneId string) { +// CancelFineTune implements StrictServerInterface +func (*LocalAIServer) CancelFineTune(ctx context.Context, request CancelFineTuneRequestObject) (CancelFineTuneResponseObject, error) { panic("unimplemented") } -// CreateChatCompletion implements ServerInterface -func (server *LocalAIServer) CreateChatCompletion(w http.ResponseWriter, r *http.Request) { - fmt.Println("HIT APIv2 CreateChatCompletion!") +// CreateChatCompletion implements StrictServerInterface +func (las *LocalAIServer) CreateChatCompletion(ctx context.Context, request CreateChatCompletionRequestObject) (CreateChatCompletionResponseObject, error) { - request, err := server.getRequest(w, r) + chatRequest, err := combineRequestAndConfig(las.configManager, request.Body.Model, request.Body) if err != nil { - sendError(w, http.StatusBadRequest, err.Error()) + fmt.Printf("CreateChatCompletion ERROR combining config and input!\n%s\n", err.Error()) + return nil, err } - // fmt.Printf("\n!!! Survived to attempt cast. BEFORE:\n\tType: %T\n\t%+v", request, request) + fmt.Printf("\n===CreateChatCompletion===\n%+v\n", chatRequest) - chatRequest, castSuccess := request.(CreateChatCompletionRequest) - - if !castSuccess { - sendError(w, http.StatusInternalServerError, "Cast Fail???") - return - } - - fmt.Printf("\n\n!! AFTER !!\ntemperature %f\n top_p %f \n %d\n", *chatRequest.Temperature, *chatRequest.TopP, *chatRequest.XLocalaiExtensions.TopK) + fmt.Printf("\n\n!! TYPED CreateChatCompletion !!\ntemperature %f\n top_p %f \n %d\n", *chatRequest.Temperature, *chatRequest.TopP, *chatRequest.XLocalaiExtensions.TopK) fmt.Printf("chatRequest: %+v\nlen(messages): %d", chatRequest, len(chatRequest.Messages)) for i, m := range chatRequest.Messages { fmt.Printf("message #%d: %+v", i, m) } + + return CreateChatCompletion200JSONResponse{}, nil + + // panic("unimplemented") } -// switch chatRequest := requestDefault.(type) { -// case CreateChatCompletionRequest: - -// CreateCompletion implements ServerInterface -func (*LocalAIServer) CreateCompletion(w http.ResponseWriter, r *http.Request) { +// CreateCompletion implements StrictServerInterface +func (*LocalAIServer) CreateCompletion(ctx context.Context, request CreateCompletionRequestObject) (CreateCompletionResponseObject, error) { panic("unimplemented") } -// CreateEdit implements ServerInterface -func (*LocalAIServer) CreateEdit(w http.ResponseWriter, r *http.Request) { +// CreateEdit implements StrictServerInterface +func (*LocalAIServer) CreateEdit(ctx context.Context, request CreateEditRequestObject) (CreateEditResponseObject, error) { panic("unimplemented") } -// CreateEmbedding implements ServerInterface -func (*LocalAIServer) CreateEmbedding(w http.ResponseWriter, r *http.Request) { +// CreateEmbedding implements StrictServerInterface +func (*LocalAIServer) CreateEmbedding(ctx context.Context, request CreateEmbeddingRequestObject) (CreateEmbeddingResponseObject, error) { panic("unimplemented") } -// CreateFile implements ServerInterface -func (*LocalAIServer) CreateFile(w http.ResponseWriter, r *http.Request) { +// CreateFile implements StrictServerInterface +func (*LocalAIServer) CreateFile(ctx context.Context, request CreateFileRequestObject) (CreateFileResponseObject, error) { panic("unimplemented") } -// CreateFineTune implements ServerInterface -func (*LocalAIServer) CreateFineTune(w http.ResponseWriter, r *http.Request) { +// CreateFineTune implements StrictServerInterface +func (*LocalAIServer) CreateFineTune(ctx context.Context, request CreateFineTuneRequestObject) (CreateFineTuneResponseObject, error) { panic("unimplemented") } -// CreateImage implements ServerInterface -func (*LocalAIServer) CreateImage(w http.ResponseWriter, r *http.Request) { +// CreateImage implements StrictServerInterface +func (*LocalAIServer) CreateImage(ctx context.Context, request CreateImageRequestObject) (CreateImageResponseObject, error) { panic("unimplemented") } -// CreateImageEdit implements ServerInterface -func (*LocalAIServer) CreateImageEdit(w http.ResponseWriter, r *http.Request) { +// CreateImageEdit implements StrictServerInterface +func (*LocalAIServer) CreateImageEdit(ctx context.Context, request CreateImageEditRequestObject) (CreateImageEditResponseObject, error) { panic("unimplemented") } -// CreateImageVariation implements ServerInterface -func (*LocalAIServer) CreateImageVariation(w http.ResponseWriter, r *http.Request) { +// CreateImageVariation implements StrictServerInterface +func (*LocalAIServer) CreateImageVariation(ctx context.Context, request CreateImageVariationRequestObject) (CreateImageVariationResponseObject, error) { panic("unimplemented") } -// CreateModeration implements ServerInterface -func (*LocalAIServer) CreateModeration(w http.ResponseWriter, r *http.Request) { +// CreateModeration implements StrictServerInterface +func (*LocalAIServer) CreateModeration(ctx context.Context, request CreateModerationRequestObject) (CreateModerationResponseObject, error) { panic("unimplemented") } -// CreateTranscription implements ServerInterface -func (*LocalAIServer) CreateTranscription(w http.ResponseWriter, r *http.Request) { +// CreateTranscription implements StrictServerInterface +func (*LocalAIServer) CreateTranscription(ctx context.Context, request CreateTranscriptionRequestObject) (CreateTranscriptionResponseObject, error) { panic("unimplemented") } -// CreateTranslation implements ServerInterface -func (*LocalAIServer) CreateTranslation(w http.ResponseWriter, r *http.Request) { +// CreateTranslation implements StrictServerInterface +func (*LocalAIServer) CreateTranslation(ctx context.Context, request CreateTranslationRequestObject) (CreateTranslationResponseObject, error) { panic("unimplemented") } -// DeleteFile implements ServerInterface -func (*LocalAIServer) DeleteFile(w http.ResponseWriter, r *http.Request, fileId string) { +// DeleteFile implements StrictServerInterface +func (*LocalAIServer) DeleteFile(ctx context.Context, request DeleteFileRequestObject) (DeleteFileResponseObject, error) { panic("unimplemented") } -// DeleteModel implements ServerInterface -func (*LocalAIServer) DeleteModel(w http.ResponseWriter, r *http.Request, model string) { +// DeleteModel implements StrictServerInterface +func (*LocalAIServer) DeleteModel(ctx context.Context, request DeleteModelRequestObject) (DeleteModelResponseObject, error) { panic("unimplemented") } -// DownloadFile implements ServerInterface -func (*LocalAIServer) DownloadFile(w http.ResponseWriter, r *http.Request, fileId string) { +// DownloadFile implements StrictServerInterface +func (*LocalAIServer) DownloadFile(ctx context.Context, request DownloadFileRequestObject) (DownloadFileResponseObject, error) { panic("unimplemented") } -// ListFiles implements ServerInterface -func (*LocalAIServer) ListFiles(w http.ResponseWriter, r *http.Request) { +// ListFiles implements StrictServerInterface +func (*LocalAIServer) ListFiles(ctx context.Context, request ListFilesRequestObject) (ListFilesResponseObject, error) { panic("unimplemented") } -// ListFineTuneEvents implements ServerInterface -func (*LocalAIServer) ListFineTuneEvents(w http.ResponseWriter, r *http.Request, fineTuneId string, params ListFineTuneEventsParams) { +// ListFineTuneEvents implements StrictServerInterface +func (*LocalAIServer) ListFineTuneEvents(ctx context.Context, request ListFineTuneEventsRequestObject) (ListFineTuneEventsResponseObject, error) { panic("unimplemented") } -// ListFineTunes implements ServerInterface -func (*LocalAIServer) ListFineTunes(w http.ResponseWriter, r *http.Request) { +// ListFineTunes implements StrictServerInterface +func (*LocalAIServer) ListFineTunes(ctx context.Context, request ListFineTunesRequestObject) (ListFineTunesResponseObject, error) { panic("unimplemented") } -// ListModels implements ServerInterface -func (*LocalAIServer) ListModels(w http.ResponseWriter, r *http.Request) { +// ListModels implements StrictServerInterface +func (*LocalAIServer) ListModels(ctx context.Context, request ListModelsRequestObject) (ListModelsResponseObject, error) { panic("unimplemented") } -// RetrieveFile implements ServerInterface -func (*LocalAIServer) RetrieveFile(w http.ResponseWriter, r *http.Request, fileId string) { +// RetrieveFile implements StrictServerInterface +func (*LocalAIServer) RetrieveFile(ctx context.Context, request RetrieveFileRequestObject) (RetrieveFileResponseObject, error) { panic("unimplemented") } -// RetrieveFineTune implements ServerInterface -func (*LocalAIServer) RetrieveFineTune(w http.ResponseWriter, r *http.Request, fineTuneId string) { +// RetrieveFineTune implements StrictServerInterface +func (*LocalAIServer) RetrieveFineTune(ctx context.Context, request RetrieveFineTuneRequestObject) (RetrieveFineTuneResponseObject, error) { panic("unimplemented") } -// RetrieveModel implements ServerInterface -func (*LocalAIServer) RetrieveModel(w http.ResponseWriter, r *http.Request, model string) { +// RetrieveModel implements StrictServerInterface +func (*LocalAIServer) RetrieveModel(ctx context.Context, request RetrieveModelRequestObject) (RetrieveModelResponseObject, error) { panic("unimplemented") } -var _ ServerInterface = (*LocalAIServer)(nil) +var _ StrictServerInterface = (*LocalAIServer)(nil) + +// var _ ServerInterface = NewStrictHandler((*LocalAIServer)(nil), nil) diff --git a/apiv2/localai.go.old b/apiv2/localai.go.old deleted file mode 100644 index 7a76288f..00000000 --- a/apiv2/localai.go.old +++ /dev/null @@ -1,196 +0,0 @@ -package apiv2 - -import ( - "encoding/json" - "fmt" - "net/http" -) - -type LocalAIServer struct { - configMerger *ConfigMerger -} - -var _ ServerInterface = (*LocalAIServer)(nil) - -type Error struct { - Code int `json:"code"` - Message string `json:"message"` -} - -func sendError(w http.ResponseWriter, code int, message string) { - localAiError := Error{ - Code: code, - Message: message, - } - w.WriteHeader(code) - json.NewEncoder(w).Encode(localAiError) -} - -// It won't work, but it's worth a try. -const nyiErrorMessageFormatString = "%s is not yet implemented by LocalAI\nThere is no need to contact support about this error and retrying will not help.\nExpect an update at https://github.com/go-skynet/LocalAI if this changes!" - -// Do we want or need an additional "wontfix" template that is even stronger than this? -const nyiDepreciatedErrorMessageFormatString = "%s is a depreciated portion of the OpenAI API, and is not yet implemented by LocalAI\nThere is no need to contact support about this error and retrying will not help." - -// CancelFineTune implements ServerInterface -func (*LocalAIServer) CancelFineTune(w http.ResponseWriter, r *http.Request, fineTuneId string) { - sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "Fine Tune")) - return -} - -// CreateAnswer implements ServerInterface -func (*LocalAIServer) CreateAnswer(w http.ResponseWriter, r *http.Request) { - sendError(w, 501, fmt.Sprintf(nyiDepreciatedErrorMessageFormatString, "CreateAnswer")) - return -} - -// CreateChatCompletion implements ServerInterface -func (*LocalAIServer) CreateChatCompletion(w http.ResponseWriter, r *http.Request) { - var chatRequest CreateChatCompletionRequest - if err := json.NewDecoder(r.Body).Decode(&chatRequest); err != nil { - sendError(w, http.StatusBadRequest, "Invalid CreateChatCompletionRequest") - return - } - configMerger.GetConfig(chatRequest.Model) -} - -// CreateClassification implements ServerInterface -func (*LocalAIServer) CreateClassification(w http.ResponseWriter, r *http.Request) { - sendError(w, 501, fmt.Sprintf(nyiDepreciatedErrorMessageFormatString, "CreateClassification")) - return -} - -// CreateCompletion implements ServerInterface -func (*LocalAIServer) CreateCompletion(w http.ResponseWriter, r *http.Request) { - panic("unimplemented") -} - -// CreateEdit implements ServerInterface -func (*LocalAIServer) CreateEdit(w http.ResponseWriter, r *http.Request) { - sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "CreateEdit")) - return -} - -// CreateEmbedding implements ServerInterface -func (*LocalAIServer) CreateEmbedding(w http.ResponseWriter, r *http.Request) { - panic("unimplemented") -} - -// CreateFile implements ServerInterface -func (*LocalAIServer) CreateFile(w http.ResponseWriter, r *http.Request) { - sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "Create File")) - return -} - -// CreateFineTune implements ServerInterface -func (*LocalAIServer) CreateFineTune(w http.ResponseWriter, r *http.Request) { - sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "Create Fine Tune")) - return -} - -// CreateImage implements ServerInterface -func (*LocalAIServer) CreateImage(w http.ResponseWriter, r *http.Request) { - panic("unimplemented") -} - -// CreateImageEdit implements ServerInterface -func (*LocalAIServer) CreateImageEdit(w http.ResponseWriter, r *http.Request) { - panic("unimplemented") -} - -// CreateImageVariation implements ServerInterface -func (*LocalAIServer) CreateImageVariation(w http.ResponseWriter, r *http.Request) { - panic("unimplemented") -} - -// CreateModeration implements ServerInterface -func (*LocalAIServer) CreateModeration(w http.ResponseWriter, r *http.Request) { - sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "CreateModeration")) - return -} - -// CreateSearch implements ServerInterface -func (*LocalAIServer) CreateSearch(w http.ResponseWriter, r *http.Request, engineId string) { - sendError(w, 501, fmt.Sprintf(nyiDepreciatedErrorMessageFormatString, "CreateSearch")) - return -} - -// CreateTranscription implements ServerInterface -func (*LocalAIServer) CreateTranscription(w http.ResponseWriter, r *http.Request) { - panic("unimplemented") -} - -// CreateTranslation implements ServerInterface -func (*LocalAIServer) CreateTranslation(w http.ResponseWriter, r *http.Request) { - panic("unimplemented") -} - -// DeleteFile implements ServerInterface -func (*LocalAIServer) DeleteFile(w http.ResponseWriter, r *http.Request, fileId string) { - sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "DeleteFile")) - return -} - -// DeleteModel implements ServerInterface -func (*LocalAIServer) DeleteModel(w http.ResponseWriter, r *http.Request, model string) { - sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "DeleteModel")) - return -} - -// DownloadFile implements ServerInterface -func (*LocalAIServer) DownloadFile(w http.ResponseWriter, r *http.Request, fileId string) { - sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "DownloadFile")) - return -} - -// ListEngines implements ServerInterface -func (*LocalAIServer) ListEngines(w http.ResponseWriter, r *http.Request) { - sendError(w, 501, fmt.Sprintf(nyiDepreciatedErrorMessageFormatString, "List Engines")) - return -} - -// ListFiles implements ServerInterface -func (*LocalAIServer) ListFiles(w http.ResponseWriter, r *http.Request) { - sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "ListFiles")) - return -} - -// ListFineTuneEvents implements ServerInterface -func (*LocalAIServer) ListFineTuneEvents(w http.ResponseWriter, r *http.Request, fineTuneId string, params ListFineTuneEventsParams) { - sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "List Fine Tune Events")) - return -} - -// ListFineTunes implements ServerInterface -func (*LocalAIServer) ListFineTunes(w http.ResponseWriter, r *http.Request) { - sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "List Fine Tunes")) - return -} - -// ListModels implements ServerInterface -func (*LocalAIServer) ListModels(w http.ResponseWriter, r *http.Request) { - panic("unimplemented") -} - -// RetrieveEngine implements ServerInterface -func (*LocalAIServer) RetrieveEngine(w http.ResponseWriter, r *http.Request, engineId string) { - sendError(w, 501, fmt.Sprintf(nyiDepreciatedErrorMessageFormatString, "RetrieveEngine")) - return -} - -// RetrieveFile implements ServerInterface -func (*LocalAIServer) RetrieveFile(w http.ResponseWriter, r *http.Request, fileId string) { - sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "RetrieveFile")) - return -} - -// RetrieveFineTune implements ServerInterface -func (*LocalAIServer) RetrieveFineTune(w http.ResponseWriter, r *http.Request, fineTuneId string) { - sendError(w, 501, fmt.Sprintf(nyiErrorMessageFormatString, "Retrieve Fine Tune")) - return -} - -// RetrieveModel implements ServerInterface -func (*LocalAIServer) RetrieveModel(w http.ResponseWriter, r *http.Request, model string) { - panic("unimplemented") -} diff --git a/apiv2/localai_nethttp.go b/apiv2/localai_nethttp.go index ab1527b9..d6e8c037 100644 --- a/apiv2/localai_nethttp.go +++ b/apiv2/localai_nethttp.go @@ -7,7 +7,9 @@ func NewLocalAINetHTTPServer(configManager *ConfigManager) *LocalAIServer { configManager: configManager, } - http.Handle("/", Handler(&localAI)) + var middlewares []StrictMiddlewareFunc + + http.Handle("/", Handler(NewStrictHandler(&localAI, middlewares))) http.ListenAndServe(":8085", nil) return &localAI diff --git a/apiv2/util.go b/apiv2/util.go new file mode 100644 index 00000000..353690a8 --- /dev/null +++ b/apiv2/util.go @@ -0,0 +1,66 @@ +package apiv2 + +import ( + "fmt" + "reflect" + "runtime" + // "github.com/mitchellh/mapstructure" +) + +// Not sure if there's a better place for this, so stuff it in a utils.go for now + +// This function grabs the name of the function that calls it, skipping up the callstack `skip` levels. +// This is probably a go war crime, but NJ method and all. It's an awesome way to index EndpointConfigMap +func printCurrentFunctionName(skip int) string { + pc, _, _, _ := runtime.Caller(skip) + funcName := runtime.FuncForPC(pc).Name() + fmt.Println("Current function:", funcName) + return funcName +} + +// This is another dubious one - Decode hooks are hard.... so this function just takes a perfectly good struct and copies all the nonzero fields out to a map[string]interface{} which mapstructure handles correctly already :) +func structToStrippedMap(s interface{}) map[string]interface{} { + m := make(map[string]interface{}) + + // Get the reflect.Value of the struct + v := reflect.ValueOf(s) + + // Get the reflect.Type of the struct + t := v.Type() + + // Iterate over each field of the struct + for i := 0; i < v.NumField(); i++ { + // Get the field's reflect.Value + fieldValue := v.Field(i) + + // Get the field's reflect.StructField + field := t.Field(i) + + // Skip unexported fields and zero values + if !fieldValue.CanInterface() || fieldValue.IsZero() { + continue + } + + // Add the field name and value to the map + m[field.Name] = fieldValue.Interface() + } + + return m +} + +// func NilToEmptyHook(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) { +// fmt.Printf("* to.Kind %+v\ndata: %+v", to.Kind(), data) +// if to.Kind() == reflect.Ptr && reflect.ValueOf(data).IsNil() { +// fmt.Println("!!!!!HIT") +// return reflect.Zero(to).Interface(), nil +// } +// return data, nil +// } + +// func GetConfigMergingDecoderConfig(result interface{}) mapstructure.DecoderConfig { + +// return mapstructure.DecoderConfig{ +// Result: result, +// DecodeHook: mapstructure.ComposeDecodeHookFunc(NilToEmptyHook), +// } +// } diff --git a/go.mod b/go.mod index 1eaed2e4..d45a8358 100644 --- a/go.mod +++ b/go.mod @@ -98,7 +98,7 @@ require ( golang.org/x/net v0.10.0 // indirect golang.org/x/sys v0.8.0 // indirect golang.org/x/text v0.9.0 // indirect - golang.org/x/tools v0.9.1 // indirect + golang.org/x/tools v0.9.2 // indirect google.golang.org/protobuf v1.30.0 // indirect ) @@ -118,4 +118,4 @@ replace github.com/go-skynet/bloomz.cpp => /workspace/bloomz replace github.com/mudler/go-stable-diffusion => /workspace/go-stable-diffusion -replace github.com/deepmap/oapi-codegen v1.12.4 => github.com/dave-gray101/oapi-codegen v0.0.0-20230601032358-055c3446c85e +replace github.com/deepmap/oapi-codegen v1.12.4 => github.com/dave-gray101/oapi-codegen v0.0.0-20230601175843-6acf0cf32d63 diff --git a/go.sum b/go.sum index e9ce1795..be0f0fec 100644 --- a/go.sum +++ b/go.sum @@ -23,8 +23,8 @@ github.com/cppforlife/go-cli-ui v0.0.0-20220622150351-995494831c6c/go.mod h1:ci7 github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/dave-gray101/oapi-codegen v0.0.0-20230601032358-055c3446c85e h1:yTYuNvvxVelDSrUDt9b96CRL0Iyo7IjgjnkbjzBymi4= -github.com/dave-gray101/oapi-codegen v0.0.0-20230601032358-055c3446c85e/go.mod h1:rey/E8Zmlg0o3jo02vrDZMSv6YeWY/I8j3FTeR+78EU= +github.com/dave-gray101/oapi-codegen v0.0.0-20230601175843-6acf0cf32d63 h1:17JdfrnUg7bInlq0HSv0gKxI0iGK5LcoOaezXkBXbx4= +github.com/dave-gray101/oapi-codegen v0.0.0-20230601175843-6acf0cf32d63/go.mod h1:nr56bxUaGXFVFkQHtsrX2OUSN2yjMoEyYRFoA4/Cq2Y= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -281,8 +281,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20201022035929-9cf592e881e9/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= -golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= -golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= +golang.org/x/tools v0.9.2 h1:UXbndbirwCAx6TULftIfie/ygDNCwxEie+IiNP1IcNc= +golang.org/x/tools v0.9.2/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/openai-openapi/chi-interface.tmpl b/openai-openapi/chi-interface.tmpl new file mode 100644 index 00000000..be17195f --- /dev/null +++ b/openai-openapi/chi-interface.tmpl @@ -0,0 +1,17 @@ +// ServerInterface represents all server handlers. +type ServerInterface interface { +{{range .}}{{.SummaryAsComment }} +// ({{.Method}} {{.Path}}) +{{.OperationId}}(w http.ResponseWriter, r *http.Request{{genParamArgs .PathParams}}{{if .RequiresParamObject}}, params {{.OperationId}}Params{{end}}) +{{end}} +} + +// TypedServerInterface is used to give each endpoint a fully typed method signature for cases where we're able to route automatically +type TypedServerInterface interface { +{{range .}}{{.SummaryAsComment }} +// ({{.Method}} {{.Path}}) +{{$reqBody := genDefaultRequestBodyType . -}} +{{- if ne $reqBody "" }}{{$reqBody = printf ", body %s" $reqBody}}{{end -}} +{{.OperationId}}(w http.ResponseWriter{{genParamArgs .PathParams}}{{if .RequiresParamObject}}, params {{.OperationId}}Params{{end}}{{$reqBody}}) +{{end}} +} \ No newline at end of file diff --git a/openai-openapi/config.yaml b/openai-openapi/config.yaml index 7f61899e..5f6fc434 100644 --- a/openai-openapi/config.yaml +++ b/openai-openapi/config.yaml @@ -2,6 +2,7 @@ package: apiv2 generate: models: true chi-server: true + strict-server: true output: apiv2/localai.gen.go output-options: exclude-depreciated: true @@ -10,6 +11,7 @@ output-options: - mapstructure user-templates: endpoint-body-mapping.tmpl: ./openai-openapi/endpoint-body-mapping.tmpl + # chi/chi-interface.tmpl: ./openai-openapi/chi-interface.tmpl # union.tmpl: "// SKIP" # union-and-additional-properties.tmpl: "// SKIP" # additional-properties.tmpl: "// SKIP" \ No newline at end of file diff --git a/openai-openapi/endpoint-body-mapping.tmpl b/openai-openapi/endpoint-body-mapping.tmpl index 0b656add..d3dc8db6 100644 --- a/openai-openapi/endpoint-body-mapping.tmpl +++ b/openai-openapi/endpoint-body-mapping.tmpl @@ -1,19 +1,20 @@ // TEMP: Consider revising this in oapi-codegen to make cleaner or at least renaming.... -//type EndpointSpecificConfig interface { -// GetRequestDefaults() interface{} -//} + + +var EndpointToRequestBodyMap = map[string]interface{}{ +{{range .}}{{$opid := .OperationId -}} + {{ $reqBody := genDefaultRequestBodyType . -}}{{if ne $reqBody "" -}} + "{{$opid}}":{{genDefaultRequestBodyType .}}{}, + {{end -}} +{{end -}} +} + var EndpointConfigMap = map[string]Config{ {{range .}}{{$opid := .OperationId -}} - {{if eq (len .Bodies) 1 -}} - {{with index .Bodies 0}}{{ $typeDef := .TypeDef $opid -}} - "{{$opid}}":SpecificConfig[{{$typeDef.TypeName}}]{}, - {{end -}} + {{ $reqBody := genDefaultRequestBodyType . -}}{{if ne $reqBody "" -}} + "{{$opid}}":SpecificConfig[{{genDefaultRequestBodyType .}}]{}, {{else -}} - {{range .Bodies -}} - {{if and .IsSupported .Default -}} - "{{$opid}}":SpecificConfig[{{.TypeName}}]{},{{break -}} - {{end -}} - {{end -}} + "{{$opid}}":ConfigStub{}, {{end -}} {{end -}} } \ No newline at end of file diff --git a/openai-openapi/localai_model_patches.yaml b/openai-openapi/localai_model_patches.yaml index 106ecaf8..72046790 100644 --- a/openai-openapi/localai_model_patches.yaml +++ b/openai-openapi/localai_model_patches.yaml @@ -58,6 +58,9 @@ components: - type: object nullable: true properties: + model: + type: string + nullable: false mode: type: integer nullable: true diff --git a/openai-openapi/remove_depreciated_openapi.yaml b/openai-openapi/remove_depreciated_openapi.yaml deleted file mode 100644 index 2f2b079c..00000000 --- a/openai-openapi/remove_depreciated_openapi.yaml +++ /dev/null @@ -1,12 +0,0 @@ -#! This doesn't actually work or get used yet. It's a scratch space for an idea of mine - we might want to just strip out and ignore depreciated portions of the spec. Not reliable enough to use yet. - -#@ load("@ytt:overlay", "overlay") - -#@overlay/match by=overlay.all, expects="1+" ---- -paths: - #!overlay/match by=overlay.subset({"deprecated": True}) - #@overlay/remove - *: - *: - depreciated: True \ No newline at end of file