feat: add tiny dream stable diffusion support (#1283)

Signed-off-by: Gianluca Boiano <morf3089@gmail.com>
This commit is contained in:
Gianluca Boiano 2023-12-24 20:27:24 +01:00 committed by GitHub
parent f7621b2c6c
commit cae7b197ec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 161 additions and 30 deletions

View file

@ -40,6 +40,7 @@ const (
RwkvBackend = "rwkv"
WhisperBackend = "whisper"
StableDiffusionBackend = "stablediffusion"
TinyDreamBackend = "tinydream"
PiperBackend = "piper"
LCHuggingFaceBackend = "langchain-huggingface"
@ -64,6 +65,7 @@ var AutoLoadBackends []string = []string{
RwkvBackend,
WhisperBackend,
StableDiffusionBackend,
TinyDreamBackend,
PiperBackend,
}

36
pkg/tinydream/generate.go Normal file
View file

@ -0,0 +1,36 @@
//go:build tinydream
// +build tinydream
package tinydream
import (
"fmt"
"path/filepath"
tinyDream "github.com/M0Rf30/go-tiny-dream"
)
func GenerateImage(height, width, step, seed int, positive_prompt, negative_prompt, dst, asset_dir string) error {
fmt.Println(dst)
if height > 512 || width > 512 {
return tinyDream.GenerateImage(
1,
step,
seed,
positive_prompt,
negative_prompt,
filepath.Dir(dst),
asset_dir,
)
}
return tinyDream.GenerateImage(
0,
step,
seed,
positive_prompt,
negative_prompt,
filepath.Dir(dst),
asset_dir,
)
}

View file

@ -0,0 +1,10 @@
//go:build !tinydream
// +build !tinydream
package tinydream
import "fmt"
func GenerateImage(height, width, step, seed int, positive_prompt, negative_prompt, dst, asset_dir string) error {
return fmt.Errorf("This version of LocalAI was built without the tinytts tag")
}

View file

@ -0,0 +1,20 @@
package tinydream
import "os"
type TinyDream struct {
assetDir string
}
func New(assetDir string) (*TinyDream, error) {
if _, err := os.Stat(assetDir); err != nil {
return nil, err
}
return &TinyDream{
assetDir: assetDir,
}, nil
}
func (td *TinyDream) GenerateImage(height, width, step, seed int, positive_prompt, negative_prompt, dst string) error {
return GenerateImage(height, width, step, seed, positive_prompt, negative_prompt, dst, td.assetDir)
}