feat: add image generation with ncnn-stablediffusion (#272)

This commit is contained in:
Ettore Di Giacinto 2023-05-16 19:32:53 +02:00 committed by GitHub
parent acd03d15f2
commit 9d051c5d4f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 582 additions and 58 deletions

View file

@ -8,11 +8,12 @@ import (
"github.com/donomii/go-rwkv.cpp"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
"github.com/go-skynet/bloomz.cpp"
bert "github.com/go-skynet/go-bert.cpp"
gpt2 "github.com/go-skynet/go-gpt2.cpp"
llama "github.com/go-skynet/go-llama.cpp"
gpt4all "github.com/nomic/gpt4all/gpt4all-bindings/golang"
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
)
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
@ -38,6 +39,45 @@ func defaultLLamaOpts(c Config) []llama.ModelOption {
return llamaOpts
}
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config) (func() error, error) {
if c.Backend != model.StableDiffusionBackend {
return nil, fmt.Errorf("endpoint only working with stablediffusion models")
}
inferenceModel, err := loader.BackendLoader(c.Backend, c.ImageGenerationAssets, []llama.ModelOption{}, uint32(c.Threads))
if err != nil {
return nil, err
}
var fn func() error
switch model := inferenceModel.(type) {
case *stablediffusion.StableDiffusion:
fn = func() error {
return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst)
}
default:
fn = func() error {
return fmt.Errorf("creation of images not supported by the backend")
}
}
return func() error {
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
mutexMap.Lock()
l, ok := mutexes[c.Backend]
if !ok {
m := &sync.Mutex{}
mutexes[c.Backend] = m
l = m
}
mutexMap.Unlock()
l.Lock()
defer l.Unlock()
return fn()
}, nil
}
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) {
if !c.Embeddings {
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")