feat(video-gen): add endpoint for video generation (#5247)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Parent: a67d22f5f2
Commit: 2c9279a542
23 changed files with 401 additions and 69 deletions
@@ -14,6 +14,7 @@ service Backend {
 	rpc PredictStream(PredictOptions) returns (stream Reply) {}
 	rpc Embedding(PredictOptions) returns (EmbeddingResult) {}
 	rpc GenerateImage(GenerateImageRequest) returns (Result) {}
+	rpc GenerateVideo(GenerateVideoRequest) returns (Result) {}
 	rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
 	rpc TTS(TTSRequest) returns (Result) {}
 	rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}

@@ -301,6 +302,19 @@ message GenerateImageRequest {
   int32 CLIPSkip = 11;
 }
 
+message GenerateVideoRequest {
+  string prompt = 1;
+  string start_image = 2; // Path or base64 encoded image for the start frame
+  string end_image = 3; // Path or base64 encoded image for the end frame
+  int32 width = 4;
+  int32 height = 5;
+  int32 num_frames = 6; // Number of frames to generate
+  int32 fps = 7; // Frames per second
+  int32 seed = 8;
+  float cfg_scale = 9; // Classifier-free guidance scale
+  string dst = 10; // Output path for the generated video
+}
+
 message TTSRequest {
   string text = 1;
   string model = 2;

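An illustrative aside, not part of this diff: the generated Go type for the new message is what the backend helper added below fills in. A minimal sketch of constructing it, assuming the usual protoc-gen-go casing for the fields the diff itself does not exercise (NumFrames, Fps, Seed, CfgScale); Prompt, StartImage, EndImage, Width, Height and Dst appear verbatim in core/backend/video.go further down.

    package example

    import "github.com/mudler/LocalAI/pkg/grpc/proto"

    // newVideoRequest sketches filling in the new message.
    // Field names not used elsewhere in this diff are assumed from
    // standard protoc-gen-go naming of the proto fields above.
    func newVideoRequest() *proto.GenerateVideoRequest {
        return &proto.GenerateVideoRequest{
            Prompt:     "a boat sailing at sunset",
            StartImage: "/tmp/start.png", // path or base64-encoded image
            Width:      512,
            Height:     512,
            NumFrames:  24,
            Fps:        8,
            Seed:       42,
            CfgScale:   7.0,
            Dst:        "/tmp/generated/videos/out.mp4",
        }
    }
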
@@ -43,18 +43,12 @@ func New(opts ...config.AppOption) (*Application, error) {
 	if err != nil {
 		return nil, fmt.Errorf("unable to create ModelPath: %q", err)
 	}
-	if options.ImageDir != "" {
-		err := os.MkdirAll(options.ImageDir, 0750)
+	if options.GeneratedContentDir != "" {
+		err := os.MkdirAll(options.GeneratedContentDir, 0750)
 		if err != nil {
 			return nil, fmt.Errorf("unable to create ImageDir: %q", err)
 		}
 	}
-	if options.AudioDir != "" {
-		err := os.MkdirAll(options.AudioDir, 0750)
-		if err != nil {
-			return nil, fmt.Errorf("unable to create AudioDir: %q", err)
-		}
-	}
 	if options.UploadDir != "" {
 		err := os.MkdirAll(options.UploadDir, 0750)
 		if err != nil {

@@ -35,12 +35,17 @@ func SoundGeneration(
 		return "", nil, fmt.Errorf("could not load sound generation model")
 	}
 
-	if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
+	if err := os.MkdirAll(appConfig.GeneratedContentDir, 0750); err != nil {
 		return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
 	}
 
-	fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "sound_generation", ".wav")
-	filePath := filepath.Join(appConfig.AudioDir, fileName)
+	audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio")
+	if err := os.MkdirAll(audioDir, 0750); err != nil {
+		return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
+	}
+
+	fileName := utils.GenerateUniqueFileName(audioDir, "sound_generation", ".wav")
+	filePath := filepath.Join(audioDir, fileName)
 
 	res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
 		Text: text,

@@ -32,12 +32,13 @@ func ModelTTS(
 		return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model)
 	}
 
-	if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
+	audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio")
+	if err := os.MkdirAll(audioDir, 0750); err != nil {
 		return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
 	}
 
-	fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
-	filePath := filepath.Join(appConfig.AudioDir, fileName)
+	fileName := utils.GenerateUniqueFileName(audioDir, "tts", ".wav")
+	filePath := filepath.Join(audioDir, fileName)
 
 	// We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect.
 	// This should be addressed in a follow up PR soon.

core/backend/video.go (new file, 36 lines)
@@ -0,0 +1,36 @@
+package backend
+
+import (
+	"github.com/mudler/LocalAI/core/config"
+
+	"github.com/mudler/LocalAI/pkg/grpc/proto"
+	model "github.com/mudler/LocalAI/pkg/model"
+)
+
+func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
+
+	opts := ModelOptions(backendConfig, appConfig)
+	inferenceModel, err := loader.Load(
+		opts...,
+	)
+	if err != nil {
+		return nil, err
+	}
+	defer loader.Close()
+
+	fn := func() error {
+		_, err := inferenceModel.GenerateVideo(
+			appConfig.Context,
+			&proto.GenerateVideoRequest{
+				Height:     height,
+				Width:      width,
+				Prompt:     prompt,
+				StartImage: startImage,
+				EndImage:   endImage,
+				Dst:        dst,
+			})
+		return err
+	}
+
+	return fn, nil
+}

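An illustrative usage sketch, not part of this diff: the /video handler added later in the commit calls this helper in exactly this two-step fashion, first obtaining the deferred function and then running it. The ml, cfg and appCfg parameters stand in for an already initialized model loader and configuration.

    package example

    import (
        "github.com/mudler/LocalAI/core/backend"
        "github.com/mudler/LocalAI/core/config"
        model "github.com/mudler/LocalAI/pkg/model"
    )

    // generateClip mirrors the two-step call pattern used by the new /video
    // handler: VideoGeneration returns a closure, and the caller runs it when ready.
    func generateClip(ml *model.ModelLoader, cfg config.BackendConfig, appCfg *config.ApplicationConfig) error {
        fn, err := backend.VideoGeneration(
            512, 512,                   // height, width
            "a boat sailing at sunset", // prompt
            "", "",                     // optional start/end images
            "/tmp/out.mp4",             // destination file the backend writes
            ml, cfg, appCfg,
        )
        if err != nil {
            return err
        }
        return fn() // blocks until the backend has written the video
    }

Returning a closure lets the caller decide when the potentially long-running generation actually executes.
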
@@ -21,8 +21,7 @@ type RunCMD struct {
 
 	ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
 	BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"`
-	ImagePath string `env:"LOCALAI_IMAGE_PATH,IMAGE_PATH" type:"path" default:"/tmp/generated/images" help:"Location for images generated by backends (e.g. stablediffusion)" group:"storage"`
-	AudioPath string `env:"LOCALAI_AUDIO_PATH,AUDIO_PATH" type:"path" default:"/tmp/generated/audio" help:"Location for audio generated by backends (e.g. piper)" group:"storage"`
+	GeneratedContentPath string `env:"LOCALAI_GENERATED_CONTENT_PATH,GENERATED_CONTENT_PATH" type:"path" default:"/tmp/generated/content" help:"Location for generated content (e.g. images, audio, videos)" group:"storage"`
 	UploadPath string `env:"LOCALAI_UPLOAD_PATH,UPLOAD_PATH" type:"path" default:"/tmp/localai/upload" help:"Path to store uploads from files api" group:"storage"`
 	ConfigPath string `env:"LOCALAI_CONFIG_PATH,CONFIG_PATH" default:"/tmp/localai/config" group:"storage"`
 	LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"`

@@ -81,8 +80,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
 		config.WithModelPath(r.ModelsPath),
 		config.WithContextSize(r.ContextSize),
 		config.WithDebug(zerolog.GlobalLevel() <= zerolog.DebugLevel),
-		config.WithImageDir(r.ImagePath),
-		config.WithAudioDir(r.AudioPath),
+		config.WithGeneratedContentDir(r.GeneratedContentPath),
 		config.WithUploadDir(r.UploadPath),
 		config.WithConfigsDir(r.ConfigPath),
 		config.WithDynamicConfigDir(r.LocalaiConfigDir),

@@ -70,7 +70,7 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
 	opts := &config.ApplicationConfig{
 		ModelPath: t.ModelsPath,
 		Context: context.Background(),
-		AudioDir: outputDir,
+		GeneratedContentDir: outputDir,
 		AssetsDestination: t.BackendAssetsPath,
 		ExternalGRPCBackends: externalBackends,
 	}

@@ -38,7 +38,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
 	opts := &config.ApplicationConfig{
 		ModelPath: t.ModelsPath,
 		Context: context.Background(),
-		AudioDir: outputDir,
+		GeneratedContentDir: outputDir,
 		AssetsDestination: t.BackendAssetsPath,
 	}
 	ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend)

@@ -19,10 +19,11 @@ type ApplicationConfig struct {
 	UploadLimitMB, Threads, ContextSize int
 	F16 bool
 	Debug bool
-	ImageDir string
-	AudioDir string
-	UploadDir string
+	GeneratedContentDir string
+
 	ConfigsDir string
+	UploadDir string
+
 	DynamicConfigsDir string
 	DynamicConfigsDirPollInterval time.Duration
 	CORS bool

@@ -279,15 +280,9 @@ func WithDebug(debug bool) AppOption {
 	}
 }
 
-func WithAudioDir(audioDir string) AppOption {
-	return func(o *ApplicationConfig) {
-		o.AudioDir = audioDir
-	}
-}
-
-func WithImageDir(imageDir string) AppOption {
+func WithGeneratedContentDir(generatedContentDir string) AppOption {
 	return func(o *ApplicationConfig) {
-		o.ImageDir = imageDir
+		o.GeneratedContentDir = generatedContentDir
 	}
 }
 

@@ -436,18 +436,19 @@ func (c *BackendConfig) HasTemplate() bool {
 type BackendConfigUsecases int
 
 const (
-	FLAG_ANY              BackendConfigUsecases = 0b00000000000
-	FLAG_CHAT             BackendConfigUsecases = 0b00000000001
-	FLAG_COMPLETION       BackendConfigUsecases = 0b00000000010
-	FLAG_EDIT             BackendConfigUsecases = 0b00000000100
-	FLAG_EMBEDDINGS       BackendConfigUsecases = 0b00000001000
-	FLAG_RERANK           BackendConfigUsecases = 0b00000010000
-	FLAG_IMAGE            BackendConfigUsecases = 0b00000100000
-	FLAG_TRANSCRIPT       BackendConfigUsecases = 0b00001000000
-	FLAG_TTS              BackendConfigUsecases = 0b00010000000
-	FLAG_SOUND_GENERATION BackendConfigUsecases = 0b00100000000
-	FLAG_TOKENIZE         BackendConfigUsecases = 0b01000000000
-	FLAG_VAD              BackendConfigUsecases = 0b10000000000
+	FLAG_ANY              BackendConfigUsecases = 0b000000000000
+	FLAG_CHAT             BackendConfigUsecases = 0b000000000001
+	FLAG_COMPLETION       BackendConfigUsecases = 0b000000000010
+	FLAG_EDIT             BackendConfigUsecases = 0b000000000100
+	FLAG_EMBEDDINGS       BackendConfigUsecases = 0b000000001000
+	FLAG_RERANK           BackendConfigUsecases = 0b000000010000
+	FLAG_IMAGE            BackendConfigUsecases = 0b000000100000
+	FLAG_TRANSCRIPT       BackendConfigUsecases = 0b000001000000
+	FLAG_TTS              BackendConfigUsecases = 0b000010000000
+	FLAG_SOUND_GENERATION BackendConfigUsecases = 0b000100000000
+	FLAG_TOKENIZE         BackendConfigUsecases = 0b001000000000
+	FLAG_VAD              BackendConfigUsecases = 0b010000000000
+	FLAG_VIDEO            BackendConfigUsecases = 0b100000000000
 
 	// Common Subsets
 	FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT

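The usecase flags are single bit positions, so they compose with bitwise OR and are tested with AND, which is exactly what GuessUsecases does a couple of hunks below. A small illustrative helper, not part of this diff:

    package example

    import "github.com/mudler/LocalAI/core/config"

    // wantsVideo reports whether a combined usecase mask asks for video
    // generation, mirroring the (u & FLAG_VIDEO) == FLAG_VIDEO check in
    // GuessUsecases below.
    func wantsVideo(u config.BackendConfigUsecases) bool {
        return (u & config.FLAG_VIDEO) == config.FLAG_VIDEO
    }

    // Masks compose with bitwise OR, e.g. models usable for images or videos.
    var imageOrVideo = config.FLAG_IMAGE | config.FLAG_VIDEO
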
@@ -468,6 +469,7 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
 		"FLAG_TOKENIZE": FLAG_TOKENIZE,
 		"FLAG_VAD": FLAG_VAD,
 		"FLAG_LLM": FLAG_LLM,
+		"FLAG_VIDEO": FLAG_VIDEO,
 	}
 }
 

@@ -532,6 +534,17 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
 			return false
 		}
 
+	}
+	if (u & FLAG_VIDEO) == FLAG_VIDEO {
+		videoBackends := []string{"diffusers", "stablediffusion"}
+		if !slices.Contains(videoBackends, c.Backend) {
+			return false
+		}
+
+		if c.Backend == "diffusers" && c.Diffusers.PipelineType == "" {
+			return false
+		}
+
 	}
 	if (u & FLAG_RERANK) == FLAG_RERANK {
 		if c.Backend != "rerankers" {

@@ -5,6 +5,8 @@ import (
 	"errors"
 	"fmt"
 	"net/http"
+	"os"
+	"path/filepath"
 
 	"github.com/dave-gray101/v2keyauth"
 	"github.com/mudler/LocalAI/pkg/utils"

@@ -153,12 +155,19 @@ func API(application *application.Application) (*fiber.App, error) {
 		Browse: true,
 	}))
 
-	if application.ApplicationConfig().ImageDir != "" {
-		router.Static("/generated-images", application.ApplicationConfig().ImageDir)
-	}
-
-	if application.ApplicationConfig().AudioDir != "" {
-		router.Static("/generated-audio", application.ApplicationConfig().AudioDir)
+	if application.ApplicationConfig().GeneratedContentDir != "" {
+		os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750)
+		audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio")
+		imagePath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "images")
+		videoPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "videos")
+
+		os.MkdirAll(audioPath, 0750)
+		os.MkdirAll(imagePath, 0750)
+		os.MkdirAll(videoPath, 0750)
+
+		router.Static("/generated-audio", audioPath)
+		router.Static("/generated-images", imagePath)
+		router.Static("/generated-videos", videoPath)
 	}
 
 	// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration

@@ -629,8 +629,7 @@ var _ = Describe("API test", func() {
 			application, err := application.New(
 				append(commonOpts,
 					config.WithContext(c),
-					config.WithAudioDir(tmpdir),
-					config.WithImageDir(tmpdir),
+					config.WithGeneratedContentDir(tmpdir),
 					config.WithGalleries(galleries),
 					config.WithModelPath(modelDir),
 					config.WithBackendAssets(backendAssets),

core/http/endpoints/localai/video.go (new file, 205 lines)
@@ -0,0 +1,205 @@
+package localai
+
+import (
+	"bufio"
+	"encoding/base64"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"path/filepath"
+	"strings"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/http/middleware"
+	"github.com/mudler/LocalAI/core/schema"
+
+	"github.com/mudler/LocalAI/core/backend"
+
+	"github.com/gofiber/fiber/v2"
+	model "github.com/mudler/LocalAI/pkg/model"
+	"github.com/rs/zerolog/log"
+)
+
+func downloadFile(url string) (string, error) {
+	// Get the data
+	resp, err := http.Get(url)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	// Create the file
+	out, err := os.CreateTemp("", "video")
+	if err != nil {
+		return "", err
+	}
+	defer out.Close()
+
+	// Write the body to file
+	_, err = io.Copy(out, resp.Body)
+	return out.Name(), err
+}
+
+//
+
+/*
+*
+
+	curl http://localhost:8080/v1/images/generations \
+		-H "Content-Type: application/json" \
+		-d '{
+			"prompt": "A cute baby sea otter",
+			"n": 1,
+			"size": "512x512"
+		}'
+
+*
+*/
+// VideoEndpoint
+// @Summary Creates a video given a prompt.
+// @Param request body schema.OpenAIRequest true "query params"
+// @Success 200 {object} schema.OpenAIResponse "Response"
+// @Router /video [post]
+func VideoEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
+	return func(c *fiber.Ctx) error {
+		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
+		if !ok || input.Model == "" {
+			log.Error().Msg("Video Endpoint - Invalid Input")
+			return fiber.ErrBadRequest
+		}
+
+		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
+		if !ok || config == nil {
+			log.Error().Msg("Video Endpoint - Invalid Config")
+			return fiber.ErrBadRequest
+		}
+
+		src := ""
+		if input.StartImage != "" {
+
+			var fileData []byte
+			var err error
+			// check if input.File is an URL, if so download it and save it
+			// to a temporary file
+			if strings.HasPrefix(input.StartImage, "http://") || strings.HasPrefix(input.StartImage, "https://") {
+				out, err := downloadFile(input.StartImage)
+				if err != nil {
+					return fmt.Errorf("failed downloading file:%w", err)
+				}
+				defer os.RemoveAll(out)
+
+				fileData, err = os.ReadFile(out)
+				if err != nil {
+					return fmt.Errorf("failed reading file:%w", err)
+				}
+
+			} else {
+				// base 64 decode the file and write it somewhere
+				// that we will cleanup
+				fileData, err = base64.StdEncoding.DecodeString(input.StartImage)
+				if err != nil {
+					return err
+				}
+			}
+
+			// Create a temporary file
+			outputFile, err := os.CreateTemp(appConfig.GeneratedContentDir, "b64")
+			if err != nil {
+				return err
+			}
+			// write the base64 result
+			writer := bufio.NewWriter(outputFile)
+			_, err = writer.Write(fileData)
+			if err != nil {
+				outputFile.Close()
+				return err
+			}
+			outputFile.Close()
+			src = outputFile.Name()
+			defer os.RemoveAll(src)
+		}
+
+		log.Debug().Msgf("Parameter Config: %+v", config)
+
+		switch config.Backend {
+		case "stablediffusion":
+			config.Backend = model.StableDiffusionGGMLBackend
+		case "":
+			config.Backend = model.StableDiffusionGGMLBackend
+		}
+
+		width := input.Width
+		height := input.Height
+
+		if width == 0 {
+			width = 512
+		}
+		if height == 0 {
+			height = 512
+		}
+
+		b64JSON := input.ResponseFormat == "b64_json"
+
+		tempDir := ""
+		if !b64JSON {
+			tempDir = filepath.Join(appConfig.GeneratedContentDir, "videos")
+		}
+		// Create a temporary file
+		outputFile, err := os.CreateTemp(tempDir, "b64")
+		if err != nil {
+			return err
+		}
+		outputFile.Close()
+
+		// TODO: use mime type to determine the extension
+		output := outputFile.Name() + ".mp4"
+
+		// Rename the temporary file
+		err = os.Rename(outputFile.Name(), output)
+		if err != nil {
+			return err
+		}
+
+		baseURL := c.BaseURL()
+
+		fn, err := backend.VideoGeneration(height, width, input.Prompt, src, input.EndImage, output, ml, *config, appConfig)
+		if err != nil {
+			return err
+		}
+		if err := fn(); err != nil {
+			return err
+		}
+
+		item := &schema.Item{}
+
+		if b64JSON {
+			defer os.RemoveAll(output)
+			data, err := os.ReadFile(output)
+			if err != nil {
+				return err
+			}
+			item.B64JSON = base64.StdEncoding.EncodeToString(data)
+		} else {
+			base := filepath.Base(output)
+			item.URL = baseURL + "/generated-videos/" + base
+		}
+
+		id := uuid.New().String()
+		created := int(time.Now().Unix())
+		resp := &schema.OpenAIResponse{
+			ID:      id,
+			Created: created,
+			Data:    []schema.Item{*item},
+		}
+
+		jsonResult, _ := json.Marshal(resp)
+		log.Debug().Msgf("Response: %s", jsonResult)
+
+		// Return the prediction in the response body
+		return c.JSON(resp)
+	}
+}

@@ -108,7 +108,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
 		}
 
 		// Create a temporary file
-		outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64")
+		outputFile, err := os.CreateTemp(appConfig.GeneratedContentDir, "b64")
 		if err != nil {
 			return err
 		}

@@ -184,7 +184,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
 
 		tempDir := ""
 		if !b64JSON {
-			tempDir = appConfig.ImageDir
+			tempDir = filepath.Join(appConfig.GeneratedContentDir, "images")
 		}
 		// Create a temporary file
 		outputFile, err := os.CreateTemp(tempDir, "b64")

@@ -192,6 +192,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
 			return err
 		}
 		outputFile.Close()
+
 		output := outputFile.Name() + ".png"
 		// Rename the temporary file
 		err = os.Rename(outputFile.Name(), output)

@@ -59,6 +59,11 @@ func RegisterLocalAIRoutes(router *fiber.App,
 		router.Get("/metrics", localai.LocalAIMetricsEndpoint())
 	}
 
+	router.Post("/video",
+		requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)),
+		requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VideoRequest) }),
+		localai.VideoEndpoint(cl, ml, appConfig))
+
 	// Backend Statistics Module
 	// TODO: Should these use standard middlewares? Refactor later, they are extremely simple.
 	backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now

@@ -24,6 +24,20 @@ type GalleryResponse struct {
 	StatusURL string `json:"status"`
 }
 
+type VideoRequest struct {
+	BasicModelRequest
+	Prompt         string  `json:"prompt" yaml:"prompt"`
+	StartImage     string  `json:"start_image" yaml:"start_image"`
+	EndImage       string  `json:"end_image" yaml:"end_image"`
+	Width          int32   `json:"width" yaml:"width"`
+	Height         int32   `json:"height" yaml:"height"`
+	NumFrames      int32   `json:"num_frames" yaml:"num_frames"`
+	FPS            int32   `json:"fps" yaml:"fps"`
+	Seed           int32   `json:"seed" yaml:"seed"`
+	CFGScale       float32 `json:"cfg_scale" yaml:"cfg_scale"`
+	ResponseFormat string  `json:"response_format" yaml:"response_format"`
+}
+
 // @Description TTS request body
 type TTSRequest struct {
 	BasicModelRequest

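An illustrative client, not part of this diff: the JSON keys follow the struct tags above, the /video path comes from the route registered earlier, and the localhost:8080 base URL simply mirrors the curl sample embedded in video.go; the "model" key is assumed to be the JSON name of the embedded BasicModelRequest field. With the default response format the endpoint replies with a URL under /generated-videos, while "response_format": "b64_json" returns the video bytes base64-encoded instead.

    package main

    import (
        "bytes"
        "fmt"
        "io"
        "net/http"
    )

    func main() {
        // Keys follow the json tags of schema.VideoRequest above; "model" is
        // assumed to be the JSON name of the embedded BasicModelRequest field.
        body := []byte(`{
            "model": "my-video-model",
            "prompt": "a boat sailing at sunset",
            "width": 512,
            "height": 512,
            "num_frames": 24,
            "fps": 8
        }`)

        // The /video route is registered in RegisterLocalAIRoutes earlier in
        // the diff; host and model name here are placeholders.
        resp, err := http.Post("http://localhost:8080/video", "application/json", bytes.NewReader(body))
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()

        out, _ := io.ReadAll(resp.Body)
        fmt.Println(string(out)) // OpenAI-style response carrying a URL or b64_json item
    }
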
@@ -481,8 +481,7 @@ In the help text below, BASEPATH is the location that local-ai is being executed
 |-----------|---------|-------------|----------------------|
 | --models-path | BASEPATH/models | Path containing models used for inferencing | $LOCALAI_MODELS_PATH |
 | --backend-assets-path |/tmp/localai/backend_data | Path used to extract libraries that are required by some of the backends in runtime | $LOCALAI_BACKEND_ASSETS_PATH |
-| --image-path | /tmp/generated/images | Location for images generated by backends (e.g. stablediffusion) | $LOCALAI_IMAGE_PATH |
-| --audio-path | /tmp/generated/audio | Location for audio generated by backends (e.g. piper) | $LOCALAI_AUDIO_PATH |
+| --generated-content-path | /tmp/generated/content | Location for assets generated by backends (e.g. stablediffusion) | $LOCALAI_GENERATED_CONTENT_PATH |
 | --upload-path | /tmp/localai/upload | Path to store uploads from files api | $LOCALAI_UPLOAD_PATH |
 | --config-path | /tmp/localai/config | | $LOCALAI_CONFIG_PATH |
 | --localai-config-dir | BASEPATH/configuration | Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json) | $LOCALAI_CONFIG_DIR |

@@ -39,6 +39,7 @@ type Backend interface {
 	LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
 	PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error
 	GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
+	GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error)

@@ -53,6 +53,10 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {
 	return fmt.Errorf("unimplemented")
 }
 
+func (llm *Base) GenerateVideo(*pb.GenerateVideoRequest) error {
+	return fmt.Errorf("unimplemented")
+}
+
 func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) {
 	return pb.TranscriptResult{}, fmt.Errorf("unimplemented")
 }

@@ -215,6 +215,28 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest,
 	return client.GenerateImage(ctx, in, opts...)
 }
 
+func (c *Client) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error) {
+	if !c.parallel {
+		c.opMutex.Lock()
+		defer c.opMutex.Unlock()
+	}
+	c.setBusy(true)
+	defer c.setBusy(false)
+	c.wdMark()
+	defer c.wdUnMark()
+	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
+		grpc.WithDefaultCallOptions(
+			grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
+			grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
+		))
+	if err != nil {
+		return nil, err
+	}
+	defer conn.Close()
+	client := pb.NewBackendClient(conn)
+	return client.GenerateVideo(ctx, in, opts...)
+}
+
 func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) {
 	if !c.parallel {
 		c.opMutex.Lock()

@@ -47,6 +47,10 @@ func (e *embedBackend) GenerateImage(ctx context.Context, in *pb.GenerateImageRe
 	return e.s.GenerateImage(ctx, in)
 }
 
+func (e *embedBackend) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error) {
+	return e.s.GenerateVideo(ctx, in)
+}
+
 func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) {
 	return e.s.TTS(ctx, in)
 }

@@ -14,6 +14,7 @@ type LLM interface {
 	Load(*pb.ModelOptions) error
 	Embeddings(*pb.PredictOptions) ([]float32, error)
 	GenerateImage(*pb.GenerateImageRequest) error
+	GenerateVideo(*pb.GenerateVideoRequest) error
 	AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
 	TTS(*pb.TTSRequest) error
 	SoundGeneration(*pb.SoundGenerationRequest) error

@@ -75,6 +75,18 @@ func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest)
 	return &pb.Result{Message: "Image generated", Success: true}, nil
 }
 
+func (s *server) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest) (*pb.Result, error) {
+	if s.llm.Locking() {
+		s.llm.Lock()
+		defer s.llm.Unlock()
+	}
+	err := s.llm.GenerateVideo(in)
+	if err != nil {
+		return &pb.Result{Message: fmt.Sprintf("Error generating video: %s", err.Error()), Success: false}, err
+	}
+	return &pb.Result{Message: "Video generated", Success: true}, nil
+}
+
 func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) {
 	if s.llm.Locking() {
 		s.llm.Lock()