plumbing for cli context and loader, stash before picking up initializers

Dave Lee 2023-06-01 23:43:34 -04:00
parent 45285bb5d8
commit 1ef6ba2b52
No known key found for this signature in database
6 changed files with 139 additions and 34 deletions
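Note on the types involved: this commit changes combineRequestAndConfig to return a generic *SpecificConfig[RequestType] wrapper instead of the bare request, and the handlers below unpack it via RequestDefaults / GetRequest(). The config types themselves are defined elsewhere in the package and are not part of this diff; the sketch below is only a rough guess at their shape, inferred from the fields and accessors used in these hunks (package name, field types, and YAML tags are assumptions).

// Rough sketch only -- the real definitions are not shown in this diff.
// Package name, field types, and YAML tags are assumptions.
package apiv2

type ConfigRegistration struct {
	Model    string `yaml:"model"`
	Endpoint string `yaml:"endpoint"`
}

type ConfigLocalPaths struct {
	Model    string `yaml:"model"`
	Template string `yaml:"template"`
}

// ConfigStub holds the backend-agnostic parts of a model config.
type ConfigStub struct {
	Registration ConfigRegistration `yaml:"registration"`
	LocalPaths   ConfigLocalPaths   `yaml:"local_paths"`
}

func (cs ConfigStub) GetRegistration() ConfigRegistration { return cs.Registration }
func (cs ConfigStub) GetLocalPaths() ConfigLocalPaths     { return cs.LocalPaths }

// SpecificConfig pairs the stub with request defaults typed to one of the
// generated OpenAI request shapes; this is what combineRequestAndConfig
// now returns, and GetRequest() is how CreateCompletion reads it back.
type SpecificConfig[RequestModel any] struct {
	ConfigStub      `yaml:",inline"`
	RequestDefaults RequestModel `yaml:"request_defaults"`
}

func (sc SpecificConfig[RequestModel]) GetRequest() RequestModel { return sc.RequestDefaults }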


@@ -5,14 +5,16 @@ import (
"fmt"
"strings"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/mitchellh/mapstructure"
)
type LocalAIServer struct {
configManager *ConfigManager
loader *model.ModelLoader
}
func combineRequestAndConfig[RequestType any](configManager *ConfigManager, model string, requestFromInput *RequestType) (*RequestType, error) {
func combineRequestAndConfig[RequestType any](configManager *ConfigManager, model string, requestFromInput *RequestType) (*SpecificConfig[RequestType], error) {
splitFnName := strings.Split(printCurrentFunctionName(2), ".")
@@ -52,7 +54,17 @@ func combineRequestAndConfig[RequestType any](configManager *ConfigManager, mode
fmt.Printf("AFTER rD: %T\n%+v\n\n", request, request)
return &request, nil
return &SpecificConfig[RequestType]{
ConfigStub: ConfigStub{
Registration: config.GetRegistration(),
LocalPaths: config.GetLocalPaths(),
},
RequestDefaults: request,
}, nil
}
func (las *LocalAIServer) loadModel(configStub ConfigStub) {
}
// CancelFineTune implements StrictServerInterface
@@ -63,13 +75,15 @@ func (*LocalAIServer) CancelFineTune(ctx context.Context, request CancelFineTune
// CreateChatCompletion implements StrictServerInterface
func (las *LocalAIServer) CreateChatCompletion(ctx context.Context, request CreateChatCompletionRequestObject) (CreateChatCompletionResponseObject, error) {
chatRequest, err := combineRequestAndConfig(las.configManager, request.Body.Model, request.Body)
chatRequestConfig, err := combineRequestAndConfig(las.configManager, request.Body.Model, request.Body)
if err != nil {
fmt.Printf("CreateChatCompletion ERROR combining config and input!\n%s\n", err.Error())
return nil, err
}
chatRequest := chatRequestConfig.RequestDefaults
fmt.Printf("\n===CreateChatCompletion===\n%+v\n", chatRequest)
fmt.Printf("\n\n!! TYPED CreateChatCompletion !!\ntemperature %f\n top_p %f \n %d\n", *chatRequest.Temperature, *chatRequest.TopP, *chatRequest.XLocalaiExtensions.TopK)
@@ -85,8 +99,48 @@ func (las *LocalAIServer) CreateChatCompletion(ctx context.Context, request Crea
}
// CreateCompletion implements StrictServerInterface
func (*LocalAIServer) CreateCompletion(ctx context.Context, request CreateCompletionRequestObject) (CreateCompletionResponseObject, error) {
panic("unimplemented")
func (las *LocalAIServer) CreateCompletion(ctx context.Context, request CreateCompletionRequestObject) (CreateCompletionResponseObject, error) {
modelName := request.Body.Model
config, err := combineRequestAndConfig(las.configManager, modelName, request.Body)
if err != nil {
fmt.Printf("CreateCompletion ERROR combining config and input!\n%s\n", err.Error())
return nil, err
}
req := config.GetRequest()
fmt.Printf("\n===CreateCompletion===\n%+v\n", req)
var choices []CreateCompletionResponseChoice
prompts, err := req.Prompt.AsCreateCompletionRequestPrompt1()
if err != nil {
tokenPrompt, err := req.Prompt.AsCreateCompletionRequestPrompt2()
if err == nil {
fmt.Printf("Scary token array length %d\n", len(tokenPrompt))
panic("Token array is scary and phase 2")
}
singlePrompt, err := req.Prompt.AsCreateCompletionRequestPrompt0()
if err != nil {
return nil, err
}
prompts = []string{singlePrompt}
}
// model := las.loader.LoadModel(modelName, )
for _, v := range prompts {
fmt.Printf("[prompt] %s\n", v)
}
return CreateCompletion200JSONResponse{
Model: modelName,
Choices: choices,
}, nil
}
// CreateEdit implements StrictServerInterface