feat: add image generation with ncnn-stablediffusion (#272)

This commit is contained in:
Ettore Di Giacinto 2023-05-16 19:32:53 +02:00 committed by GitHub
parent acd03d15f2
commit 9d051c5d4f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 582 additions and 58 deletions

View file

@ -32,6 +32,7 @@ type Config struct {
MirostatTAU float64 `yaml:"mirostat_tau"`
Mirostat int `yaml:"mirostat"`
NGPULayers int `yaml:"gpu_layers"`
ImageGenerationAssets string `yaml:"asset_dir"`
PromptStrings, InputStrings []string
InputToken [][]int
}
@ -211,12 +212,11 @@ func updateConfig(config *Config, input *OpenAIRequest) {
}
}
}
func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) {
input := new(OpenAIRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return nil, nil, err
return "", nil, err
}
modelFile := input.Model
@ -234,14 +234,14 @@ func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
// If no model was specified, take the first available
if modelFile == "" && !bearerExists {
if modelFile == "" && !bearerExists && randomModel {
models, _ := loader.ListModels()
if len(models) > 0 {
modelFile = models[0]
log.Debug().Msgf("No model specified, using: %s", modelFile)
} else {
log.Debug().Msgf("No model specified, returning error")
return nil, nil, fmt.Errorf("no model specified")
return "", nil, fmt.Errorf("no model specified")
}
}
@ -250,7 +250,10 @@ func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug
log.Debug().Msgf("Using model from bearer token: %s", bearer)
modelFile = bearer
}
return modelFile, input, nil
}
func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
// Load a config file if present after the model name
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
if _, err := os.Stat(modelConfig); err == nil {