manual merge, build testing starting from here

This commit is contained in:
Dave Lee 2023-06-05 21:27:19 -04:00
commit 8e6e7c456b
No known key found for this signature in database
67 changed files with 2095 additions and 1617 deletions

51
pkg/assets/extract.go Normal file
View file

@ -0,0 +1,51 @@
package assets
import (
"embed"
"fmt"
"io/fs"
"os"
"path/filepath"
)
func ExtractFiles(content embed.FS, extractDir string) error {
// Create the target directory if it doesn't exist
err := os.MkdirAll(extractDir, 0755)
if err != nil {
return fmt.Errorf("failed to create directory: %v", err)
}
// Walk through the embedded FS and extract files
err = fs.WalkDir(content, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
// Reconstruct the directory structure in the target directory
targetFile := filepath.Join(extractDir, path)
if d.IsDir() {
// Create the directory in the target directory
err := os.MkdirAll(targetFile, 0755)
if err != nil {
return fmt.Errorf("failed to create directory: %v", err)
}
return nil
}
// Read the file from the embedded FS
fileData, err := content.ReadFile(path)
if err != nil {
return fmt.Errorf("failed to read file: %v", err)
}
// Create the file in the target directory
err = os.WriteFile(targetFile, fileData, 0644)
if err != nil {
return fmt.Errorf("failed to write file: %v", err)
}
return nil
})
return err
}

View file

@ -0,0 +1,47 @@
package langchain
import (
"context"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/huggingface"
)
type HuggingFace struct {
modelPath string
}
func NewHuggingFace(repoId string) (*HuggingFace, error) {
return &HuggingFace{
modelPath: repoId,
}, nil
}
func (s *HuggingFace) PredictHuggingFace(text string, opts ...PredictOption) (*Predict, error) {
po := NewPredictOptions(opts...)
// Init client
llm, err := huggingface.New()
if err != nil {
return nil, err
}
// Convert from LocalAI to LangChainGo format of options
co := []llms.CallOption{
llms.WithModel(po.Model),
llms.WithMaxTokens(po.MaxTokens),
llms.WithTemperature(po.Temperature),
llms.WithStopWords(po.StopWords),
}
// Call Inference API
ctx := context.Background()
completion, err := llm.Call(ctx, text, co...)
if err != nil {
return nil, err
}
return &Predict{
Completion: completion,
}, nil
}

View file

@ -0,0 +1,57 @@
package langchain
type PredictOptions struct {
Model string `json:"model"`
// MaxTokens is the maximum number of tokens to generate.
MaxTokens int `json:"max_tokens"`
// Temperature is the temperature for sampling, between 0 and 1.
Temperature float64 `json:"temperature"`
// StopWords is a list of words to stop on.
StopWords []string `json:"stop_words"`
}
type PredictOption func(p *PredictOptions)
var DefaultOptions = PredictOptions{
Model: "gpt2",
MaxTokens: 200,
Temperature: 0.96,
StopWords: nil,
}
type Predict struct {
Completion string
}
func SetModel(model string) PredictOption {
return func(o *PredictOptions) {
o.Model = model
}
}
func SetTemperature(temperature float64) PredictOption {
return func(o *PredictOptions) {
o.Temperature = temperature
}
}
func SetMaxTokens(maxTokens int) PredictOption {
return func(o *PredictOptions) {
o.MaxTokens = maxTokens
}
}
func SetStopWords(stopWords []string) PredictOption {
return func(o *PredictOptions) {
o.StopWords = stopWords
}
}
// NewPredictOptions Create a new PredictOptions object with the given options.
func NewPredictOptions(opts ...PredictOption) PredictOptions {
p := DefaultOptions
for _, opt := range opts {
opt(&p)
}
return p
}

View file

@ -7,10 +7,11 @@ import (
rwkv "github.com/donomii/go-rwkv.cpp"
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-skynet/LocalAI/pkg/langchain"
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
bloomz "github.com/go-skynet/bloomz.cpp"
bert "github.com/go-skynet/go-bert.cpp"
gpt2 "github.com/go-skynet/go-gpt2.cpp"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
llama "github.com/go-skynet/go-llama.cpp"
"github.com/hashicorp/go-multierror"
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
@ -23,61 +24,61 @@ const (
LlamaBackend = "llama"
BloomzBackend = "bloomz"
StarcoderBackend = "starcoder"
StableLMBackend = "stablelm"
GPTJBackend = "gptj"
DollyBackend = "dolly"
RedPajamaBackend = "redpajama"
MPTBackend = "mpt"
GPTNeoXBackend = "gptneox"
ReplitBackend = "replit"
Gpt2Backend = "gpt2"
Gpt4AllLlamaBackend = "gpt4all-llama"
Gpt4AllMptBackend = "gpt4all-mpt"
Gpt4AllJBackend = "gpt4all-j"
Gpt4All = "gpt4all"
BertEmbeddingsBackend = "bert-embeddings"
RwkvBackend = "rwkv"
WhisperBackend = "whisper"
StableDiffusionBackend = "stablediffusion"
LCHuggingFaceBackend = "langchain-huggingface"
)
var backends []string = []string{
LlamaBackend,
Gpt4AllLlamaBackend,
Gpt4AllMptBackend,
Gpt4AllJBackend,
Gpt2Backend,
WhisperBackend,
Gpt4All,
RwkvBackend,
BloomzBackend,
StableLMBackend,
DollyBackend,
RedPajamaBackend,
ReplitBackend,
GPTNeoXBackend,
WhisperBackend,
BertEmbeddingsBackend,
GPTJBackend,
Gpt2Backend,
DollyBackend,
MPTBackend,
ReplitBackend,
StarcoderBackend,
BloomzBackend,
}
var starCoder = func(modelFile string) (interface{}, error) {
return gpt2.NewStarcoder(modelFile)
return transformers.NewStarcoder(modelFile)
}
var redPajama = func(modelFile string) (interface{}, error) {
return gpt2.NewRedPajama(modelFile)
var mpt = func(modelFile string) (interface{}, error) {
return transformers.NewMPT(modelFile)
}
var dolly = func(modelFile string) (interface{}, error) {
return gpt2.NewDolly(modelFile)
return transformers.NewDolly(modelFile)
}
var gptNeoX = func(modelFile string) (interface{}, error) {
return gpt2.NewGPTNeoX(modelFile)
return transformers.NewGPTNeoX(modelFile)
}
var replit = func(modelFile string) (interface{}, error) {
return gpt2.NewReplit(modelFile)
return transformers.NewReplit(modelFile)
}
var stableLM = func(modelFile string) (interface{}, error) {
return gpt2.NewStableLM(modelFile)
var gptJ = func(modelFile string) (interface{}, error) {
return transformers.NewGPTJ(modelFile)
}
var bertEmbeddings = func(modelFile string) (interface{}, error) {
@ -87,8 +88,9 @@ var bertEmbeddings = func(modelFile string) (interface{}, error) {
var bloomzLM = func(modelFile string) (interface{}, error) {
return bloomz.New(modelFile)
}
var gpt2LM = func(modelFile string) (interface{}, error) {
return gpt2.New(modelFile)
var transformersLM = func(modelFile string) (interface{}, error) {
return transformers.New(modelFile)
}
var stableDiffusion = func(assetDir string) (interface{}, error) {
@ -99,6 +101,10 @@ var whisperModel = func(modelFile string) (interface{}, error) {
return whisper.New(modelFile)
}
var lcHuggingFace = func(repoId string) (interface{}, error) {
return langchain.NewHuggingFace(repoId)
}
func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) {
return func(s string) (interface{}, error) {
return llama.New(s, opts...)
@ -130,14 +136,14 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
return ml.LoadModel(modelFile, llamaLM(llamaOpts...))
case BloomzBackend:
return ml.LoadModel(modelFile, bloomzLM)
case StableLMBackend:
return ml.LoadModel(modelFile, stableLM)
case GPTJBackend:
return ml.LoadModel(modelFile, gptJ)
case DollyBackend:
return ml.LoadModel(modelFile, dolly)
case RedPajamaBackend:
return ml.LoadModel(modelFile, redPajama)
case MPTBackend:
return ml.LoadModel(modelFile, mpt)
case Gpt2Backend:
return ml.LoadModel(modelFile, gpt2LM)
return ml.LoadModel(modelFile, transformersLM)
case GPTNeoXBackend:
return ml.LoadModel(modelFile, gptNeoX)
case ReplitBackend:
@ -146,18 +152,16 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
return ml.LoadModel(modelFile, stableDiffusion)
case StarcoderBackend:
return ml.LoadModel(modelFile, starCoder)
case Gpt4AllLlamaBackend:
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.LLaMAType)))
case Gpt4AllMptBackend:
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.MPTType)))
case Gpt4AllJBackend:
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.GPTJType)))
case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All:
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads))))
case BertEmbeddingsBackend:
return ml.LoadModel(modelFile, bertEmbeddings)
case RwkvBackend:
return ml.LoadModel(modelFile, rwkvLM(filepath.Join(ml.ModelPath, modelFile+tokenizerSuffix), threads))
case WhisperBackend:
return ml.LoadModel(modelFile, whisperModel)
case LCHuggingFaceBackend:
return ml.LoadModel(modelFile, lcHuggingFace)
default:
return nil, fmt.Errorf("backend unsupported: %s", backendString)
}

View file

@ -8,6 +8,18 @@ import (
)
func GenerateImage(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst, asset_dir string) error {
if height > 512 || width > 512 {
return stableDiffusion.GenerateImageUpscaled(
height,
width,
step,
seed,
positive_prompt,
negative_prompt,
dst,
asset_dir,
)
}
return stableDiffusion.GenerateImage(
height,
width,