From 2c9279a54218a61285ec8984110a1a623545f8f5 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 26 Apr 2025 18:05:01 +0200 Subject: [PATCH] feat(video-gen): add endpoint for video generation (#5247) Signed-off-by: Ettore Di Giacinto --- backend/backend.proto | 14 ++ core/application/startup.go | 10 +- core/backend/soundgeneration.go | 11 +- core/backend/tts.go | 7 +- core/backend/video.go | 36 ++++ core/cli/run.go | 6 +- core/cli/soundgeneration.go | 2 +- core/cli/tts.go | 8 +- core/config/application_config.go | 39 ++-- core/config/backend_config.go | 37 ++-- core/http/app.go | 19 +- core/http/app_test.go | 3 +- core/http/endpoints/localai/video.go | 205 +++++++++++++++++++ core/http/endpoints/openai/image.go | 7 +- core/http/routes/localai.go | 5 + core/schema/localai.go | 14 ++ docs/content/docs/advanced/advanced-usage.md | 3 +- pkg/grpc/backend.go | 1 + pkg/grpc/base/base.go | 4 + pkg/grpc/client.go | 22 ++ pkg/grpc/embed.go | 4 + pkg/grpc/interface.go | 1 + pkg/grpc/server.go | 12 ++ 23 files changed, 401 insertions(+), 69 deletions(-) create mode 100644 core/backend/video.go create mode 100644 core/http/endpoints/localai/video.go diff --git a/backend/backend.proto b/backend/backend.proto index d5028efa..cdf09bf2 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -14,6 +14,7 @@ service Backend { rpc PredictStream(PredictOptions) returns (stream Reply) {} rpc Embedding(PredictOptions) returns (EmbeddingResult) {} rpc GenerateImage(GenerateImageRequest) returns (Result) {} + rpc GenerateVideo(GenerateVideoRequest) returns (Result) {} rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {} rpc TTS(TTSRequest) returns (Result) {} rpc SoundGeneration(SoundGenerationRequest) returns (Result) {} @@ -301,6 +302,19 @@ message GenerateImageRequest { int32 CLIPSkip = 11; } +message GenerateVideoRequest { + string prompt = 1; + string start_image = 2; // Path or base64 encoded image for the start frame + string end_image = 3; // 
Path or base64 encoded image for the end frame + int32 width = 4; + int32 height = 5; + int32 num_frames = 6; // Number of frames to generate + int32 fps = 7; // Frames per second + int32 seed = 8; + float cfg_scale = 9; // Classifier-free guidance scale + string dst = 10; // Output path for the generated video +} + message TTSRequest { string text = 1; string model = 2; diff --git a/core/application/startup.go b/core/application/startup.go index 6c93f03f..25b3691b 100644 --- a/core/application/startup.go +++ b/core/application/startup.go @@ -43,18 +43,12 @@ func New(opts ...config.AppOption) (*Application, error) { if err != nil { return nil, fmt.Errorf("unable to create ModelPath: %q", err) } - if options.ImageDir != "" { - err := os.MkdirAll(options.ImageDir, 0750) + if options.GeneratedContentDir != "" { + err := os.MkdirAll(options.GeneratedContentDir, 0750) if err != nil { return nil, fmt.Errorf("unable to create ImageDir: %q", err) } } - if options.AudioDir != "" { - err := os.MkdirAll(options.AudioDir, 0750) - if err != nil { - return nil, fmt.Errorf("unable to create AudioDir: %q", err) - } - } if options.UploadDir != "" { err := os.MkdirAll(options.UploadDir, 0750) if err != nil { diff --git a/core/backend/soundgeneration.go b/core/backend/soundgeneration.go index 94ec9c89..6379fb28 100644 --- a/core/backend/soundgeneration.go +++ b/core/backend/soundgeneration.go @@ -35,12 +35,17 @@ func SoundGeneration( return "", nil, fmt.Errorf("could not load sound generation model") } - if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil { + if err := os.MkdirAll(appConfig.GeneratedContentDir, 0750); err != nil { return "", nil, fmt.Errorf("failed creating audio directory: %s", err) } - fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "sound_generation", ".wav") - filePath := filepath.Join(appConfig.AudioDir, fileName) + audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio") + if err := os.MkdirAll(audioDir, 0750); err != nil { + 
return "", nil, fmt.Errorf("failed creating audio directory: %s", err) + } + + fileName := utils.GenerateUniqueFileName(audioDir, "sound_generation", ".wav") + filePath := filepath.Join(audioDir, fileName) res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{ Text: text, diff --git a/core/backend/tts.go b/core/backend/tts.go index 6157f4c1..81674016 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -32,12 +32,13 @@ func ModelTTS( return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model) } - if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil { + audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio") + if err := os.MkdirAll(audioDir, 0750); err != nil { return "", nil, fmt.Errorf("failed creating audio directory: %s", err) } - fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav") - filePath := filepath.Join(appConfig.AudioDir, fileName) + fileName := utils.GenerateUniqueFileName(audioDir, "tts", ".wav") + filePath := filepath.Join(audioDir, fileName) // We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect. // This should be addressed in a follow up PR soon. 
diff --git a/core/backend/video.go b/core/backend/video.go new file mode 100644 index 00000000..49241070 --- /dev/null +++ b/core/backend/video.go @@ -0,0 +1,36 @@ +package backend + +import ( + "github.com/mudler/LocalAI/core/config" + + "github.com/mudler/LocalAI/pkg/grpc/proto" + model "github.com/mudler/LocalAI/pkg/model" +) + +func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { + + opts := ModelOptions(backendConfig, appConfig) + inferenceModel, err := loader.Load( + opts..., + ) + if err != nil { + return nil, err + } + defer loader.Close() + + fn := func() error { + _, err := inferenceModel.GenerateVideo( + appConfig.Context, + &proto.GenerateVideoRequest{ + Height: height, + Width: width, + Prompt: prompt, + StartImage: startImage, + EndImage: endImage, + Dst: dst, + }) + return err + } + + return fn, nil +} diff --git a/core/cli/run.go b/core/cli/run.go index b245da67..5bc8913a 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -21,8 +21,7 @@ type RunCMD struct { ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"` - ImagePath string `env:"LOCALAI_IMAGE_PATH,IMAGE_PATH" type:"path" default:"/tmp/generated/images" help:"Location for images generated by backends (e.g. stablediffusion)" group:"storage"` - AudioPath string `env:"LOCALAI_AUDIO_PATH,AUDIO_PATH" type:"path" default:"/tmp/generated/audio" help:"Location for audio generated by backends (e.g. 
piper)" group:"storage"` + GeneratedContentPath string `env:"LOCALAI_GENERATED_CONTENT_PATH,GENERATED_CONTENT_PATH" type:"path" default:"/tmp/generated/content" help:"Location for generated content (e.g. images, audio, videos)" group:"storage"` UploadPath string `env:"LOCALAI_UPLOAD_PATH,UPLOAD_PATH" type:"path" default:"/tmp/localai/upload" help:"Path to store uploads from files api" group:"storage"` ConfigPath string `env:"LOCALAI_CONFIG_PATH,CONFIG_PATH" default:"/tmp/localai/config" group:"storage"` LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"` @@ -81,8 +80,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { config.WithModelPath(r.ModelsPath), config.WithContextSize(r.ContextSize), config.WithDebug(zerolog.GlobalLevel() <= zerolog.DebugLevel), - config.WithImageDir(r.ImagePath), - config.WithAudioDir(r.AudioPath), + config.WithGeneratedContentDir(r.GeneratedContentPath), config.WithUploadDir(r.UploadPath), config.WithConfigsDir(r.ConfigPath), config.WithDynamicConfigDir(r.LocalaiConfigDir), diff --git a/core/cli/soundgeneration.go b/core/cli/soundgeneration.go index 3c7e9af4..b7c1d0fe 100644 --- a/core/cli/soundgeneration.go +++ b/core/cli/soundgeneration.go @@ -70,7 +70,7 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error { opts := &config.ApplicationConfig{ ModelPath: t.ModelsPath, Context: context.Background(), - AudioDir: outputDir, + GeneratedContentDir: outputDir, AssetsDestination: t.BackendAssetsPath, ExternalGRPCBackends: externalBackends, } diff --git a/core/cli/tts.go b/core/cli/tts.go index 283372fe..074487e6 100644 --- a/core/cli/tts.go +++ b/core/cli/tts.go @@ -36,10 +36,10 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error { text := strings.Join(t.Text, " ") opts := &config.ApplicationConfig{ - ModelPath: t.ModelsPath, - Context: 
context.Background(), - AudioDir: outputDir, - AssetsDestination: t.BackendAssetsPath, + ModelPath: t.ModelsPath, + Context: context.Background(), + GeneratedContentDir: outputDir, + AssetsDestination: t.BackendAssetsPath, } ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend) diff --git a/core/config/application_config.go b/core/config/application_config.go index 2cc9b01b..9648e454 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -19,20 +19,21 @@ type ApplicationConfig struct { UploadLimitMB, Threads, ContextSize int F16 bool Debug bool - ImageDir string - AudioDir string - UploadDir string - ConfigsDir string - DynamicConfigsDir string - DynamicConfigsDirPollInterval time.Duration - CORS bool - CSRF bool - PreloadJSONModels string - PreloadModelsFromPath string - CORSAllowOrigins string - ApiKeys []string - P2PToken string - P2PNetworkID string + GeneratedContentDir string + + ConfigsDir string + UploadDir string + + DynamicConfigsDir string + DynamicConfigsDirPollInterval time.Duration + CORS bool + CSRF bool + PreloadJSONModels string + PreloadModelsFromPath string + CORSAllowOrigins string + ApiKeys []string + P2PToken string + P2PNetworkID string DisableWebUI bool EnforcePredownloadScans bool @@ -279,15 +280,9 @@ func WithDebug(debug bool) AppOption { } } -func WithAudioDir(audioDir string) AppOption { +func WithGeneratedContentDir(generatedContentDir string) AppOption { return func(o *ApplicationConfig) { - o.AudioDir = audioDir - } -} - -func WithImageDir(imageDir string) AppOption { - return func(o *ApplicationConfig) { - o.ImageDir = imageDir + o.GeneratedContentDir = generatedContentDir } } diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 2c022912..cb1263a6 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -436,18 +436,19 @@ func (c *BackendConfig) HasTemplate() bool { type BackendConfigUsecases int const ( - FLAG_ANY BackendConfigUsecases = 
0b00000000000 - FLAG_CHAT BackendConfigUsecases = 0b00000000001 - FLAG_COMPLETION BackendConfigUsecases = 0b00000000010 - FLAG_EDIT BackendConfigUsecases = 0b00000000100 - FLAG_EMBEDDINGS BackendConfigUsecases = 0b00000001000 - FLAG_RERANK BackendConfigUsecases = 0b00000010000 - FLAG_IMAGE BackendConfigUsecases = 0b00000100000 - FLAG_TRANSCRIPT BackendConfigUsecases = 0b00001000000 - FLAG_TTS BackendConfigUsecases = 0b00010000000 - FLAG_SOUND_GENERATION BackendConfigUsecases = 0b00100000000 - FLAG_TOKENIZE BackendConfigUsecases = 0b01000000000 - FLAG_VAD BackendConfigUsecases = 0b10000000000 + FLAG_ANY BackendConfigUsecases = 0b000000000000 + FLAG_CHAT BackendConfigUsecases = 0b000000000001 + FLAG_COMPLETION BackendConfigUsecases = 0b000000000010 + FLAG_EDIT BackendConfigUsecases = 0b000000000100 + FLAG_EMBEDDINGS BackendConfigUsecases = 0b000000001000 + FLAG_RERANK BackendConfigUsecases = 0b000000010000 + FLAG_IMAGE BackendConfigUsecases = 0b000000100000 + FLAG_TRANSCRIPT BackendConfigUsecases = 0b000001000000 + FLAG_TTS BackendConfigUsecases = 0b000010000000 + FLAG_SOUND_GENERATION BackendConfigUsecases = 0b000100000000 + FLAG_TOKENIZE BackendConfigUsecases = 0b001000000000 + FLAG_VAD BackendConfigUsecases = 0b010000000000 + FLAG_VIDEO BackendConfigUsecases = 0b100000000000 // Common Subsets FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT @@ -468,6 +469,7 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases { "FLAG_TOKENIZE": FLAG_TOKENIZE, "FLAG_VAD": FLAG_VAD, "FLAG_LLM": FLAG_LLM, + "FLAG_VIDEO": FLAG_VIDEO, } } @@ -532,6 +534,17 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool { return false } + } + if (u & FLAG_VIDEO) == FLAG_VIDEO { + videoBackends := []string{"diffusers", "stablediffusion"} + if !slices.Contains(videoBackends, c.Backend) { + return false + } + + if c.Backend == "diffusers" && c.Diffusers.PipelineType == "" { + return false + } + } if (u & FLAG_RERANK) == FLAG_RERANK { if 
c.Backend != "rerankers" { diff --git a/core/http/app.go b/core/http/app.go index 57f95465..0edd7ef1 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "net/http" + "os" + "path/filepath" "github.com/dave-gray101/v2keyauth" "github.com/mudler/LocalAI/pkg/utils" @@ -153,12 +155,19 @@ func API(application *application.Application) (*fiber.App, error) { Browse: true, })) - if application.ApplicationConfig().ImageDir != "" { - router.Static("/generated-images", application.ApplicationConfig().ImageDir) - } + if application.ApplicationConfig().GeneratedContentDir != "" { + os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750) + audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio") + imagePath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "images") + videoPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "videos") - if application.ApplicationConfig().AudioDir != "" { - router.Static("/generated-audio", application.ApplicationConfig().AudioDir) + os.MkdirAll(audioPath, 0750) + os.MkdirAll(imagePath, 0750) + os.MkdirAll(videoPath, 0750) + + router.Static("/generated-audio", audioPath) + router.Static("/generated-images", imagePath) + router.Static("/generated-videos", videoPath) } // Auth is applied to _all_ endpoints. No exceptions. 
Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration diff --git a/core/http/app_test.go b/core/http/app_test.go index ecaf6da3..8d12c496 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -629,8 +629,7 @@ var _ = Describe("API test", func() { application, err := application.New( append(commonOpts, config.WithContext(c), - config.WithAudioDir(tmpdir), - config.WithImageDir(tmpdir), + config.WithGeneratedContentDir(tmpdir), config.WithGalleries(galleries), config.WithModelPath(modelDir), config.WithBackendAssets(backendAssets), diff --git a/core/http/endpoints/localai/video.go b/core/http/endpoints/localai/video.go new file mode 100644 index 00000000..bec8a6a1 --- /dev/null +++ b/core/http/endpoints/localai/video.go @@ -0,0 +1,205 @@ +package localai + +import ( + "bufio" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/google/uuid" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/schema" + + "github.com/mudler/LocalAI/core/backend" + + "github.com/gofiber/fiber/v2" + model "github.com/mudler/LocalAI/pkg/model" + "github.com/rs/zerolog/log" +) + +func downloadFile(url string) (string, error) { + // Get the data + resp, err := http.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // Create the file + out, err := os.CreateTemp("", "video") + if err != nil { + return "", err + } + defer out.Close() + + // Write the body to file + _, err = io.Copy(out, resp.Body) + return out.Name(), err +} + +// + +/* +* + + curl http://localhost:8080/video \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "A cute baby sea otter", + "width": 512, + "height": 512 + }' + +* +*/ +// VideoEndpoint +// @Summary Creates a video given a prompt. 
+// @Param request body schema.VideoRequest true "query params" +// @Success 200 {object} schema.OpenAIResponse "Response" +// @Router /video [post] +func VideoEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest) + if !ok || input.Model == "" { + log.Error().Msg("Video Endpoint - Invalid Input") + return fiber.ErrBadRequest + } + + config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + if !ok || config == nil { + log.Error().Msg("Video Endpoint - Invalid Config") + return fiber.ErrBadRequest + } + + src := "" + if input.StartImage != "" { + + var fileData []byte + var err error + // check if input.StartImage is a URL, if so download it and save it + // to a temporary file + if strings.HasPrefix(input.StartImage, "http://") || strings.HasPrefix(input.StartImage, "https://") { + out, err := downloadFile(input.StartImage) + if err != nil { + return fmt.Errorf("failed downloading file:%w", err) + } + defer os.RemoveAll(out) + + fileData, err = os.ReadFile(out) + if err != nil { + return fmt.Errorf("failed reading file:%w", err) + } + + } else { + // base 64 decode the file and write it somewhere + // that we will cleanup + fileData, err = base64.StdEncoding.DecodeString(input.StartImage) + if err != nil { + return err + } + } + + // Create a temporary file + outputFile, err := os.CreateTemp(appConfig.GeneratedContentDir, "b64") + if err != nil { + return err + } + // write the base64 result + writer := bufio.NewWriter(outputFile) + _, err = writer.Write(fileData) + if err != nil { + outputFile.Close() + return err + } + outputFile.Close() + src = outputFile.Name() + defer os.RemoveAll(src) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + switch config.Backend { + case "stablediffusion": + config.Backend = 
model.StableDiffusionGGMLBackend + case "": + config.Backend = model.StableDiffusionGGMLBackend + } + + width := input.Width + height := input.Height + + if width == 0 { + width = 512 + } + if height == 0 { + height = 512 + } + + b64JSON := input.ResponseFormat == "b64_json" + + tempDir := "" + if !b64JSON { + tempDir = filepath.Join(appConfig.GeneratedContentDir, "videos") + } + // Create a temporary file + outputFile, err := os.CreateTemp(tempDir, "b64") + if err != nil { + return err + } + outputFile.Close() + + // TODO: use mime type to determine the extension + output := outputFile.Name() + ".mp4" + + // Rename the temporary file + err = os.Rename(outputFile.Name(), output) + if err != nil { + return err + } + + baseURL := c.BaseURL() + + fn, err := backend.VideoGeneration(height, width, input.Prompt, src, input.EndImage, output, ml, *config, appConfig) + if err != nil { + return err + } + if err := fn(); err != nil { + return err + } + + item := &schema.Item{} + + if b64JSON { + defer os.RemoveAll(output) + data, err := os.ReadFile(output) + if err != nil { + return err + } + item.B64JSON = base64.StdEncoding.EncodeToString(data) + } else { + base := filepath.Base(output) + item.URL = baseURL + "/generated-videos/" + base + } + + id := uuid.New().String() + created := int(time.Now().Unix()) + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Data: []schema.Item{*item}, + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go index e4ff26db..3ac07cdc 100644 --- a/core/http/endpoints/openai/image.go +++ b/core/http/endpoints/openai/image.go @@ -72,7 +72,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon log.Error().Msg("Image Endpoint - Invalid Input") return fiber.ErrBadRequest } - + config, ok := 
c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) if !ok || config == nil { log.Error().Msg("Image Endpoint - Invalid Config") @@ -108,7 +108,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon } // Create a temporary file - outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64") + outputFile, err := os.CreateTemp(appConfig.GeneratedContentDir, "b64") if err != nil { return err } @@ -184,7 +184,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon tempDir := "" if !b64JSON { - tempDir = appConfig.ImageDir + tempDir = filepath.Join(appConfig.GeneratedContentDir, "images") } // Create a temporary file outputFile, err := os.CreateTemp(tempDir, "b64") @@ -192,6 +192,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon return err } outputFile.Close() + output := outputFile.Name() + ".png" // Rename the temporary file err = os.Rename(outputFile.Name(), output) diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index ebf9c1c9..e369a559 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -59,6 +59,11 @@ func RegisterLocalAIRoutes(router *fiber.App, router.Get("/metrics", localai.LocalAIMetricsEndpoint()) } + router.Post("/video", + requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)), + requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VideoRequest) }), + localai.VideoEndpoint(cl, ml, appConfig)) + // Backend Statistics Module // TODO: Should these use standard middlewares? Refactor later, they are extremely simple. 
backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now diff --git a/core/schema/localai.go b/core/schema/localai.go index 395b26b7..734314a2 100644 --- a/core/schema/localai.go +++ b/core/schema/localai.go @@ -24,6 +24,20 @@ type GalleryResponse struct { StatusURL string `json:"status"` } +type VideoRequest struct { + BasicModelRequest + Prompt string `json:"prompt" yaml:"prompt"` + StartImage string `json:"start_image" yaml:"start_image"` + EndImage string `json:"end_image" yaml:"end_image"` + Width int32 `json:"width" yaml:"width"` + Height int32 `json:"height" yaml:"height"` + NumFrames int32 `json:"num_frames" yaml:"num_frames"` + FPS int32 `json:"fps" yaml:"fps"` + Seed int32 `json:"seed" yaml:"seed"` + CFGScale float32 `json:"cfg_scale" yaml:"cfg_scale"` + ResponseFormat string `json:"response_format" yaml:"response_format"` +} + // @Description TTS request body type TTSRequest struct { BasicModelRequest diff --git a/docs/content/docs/advanced/advanced-usage.md b/docs/content/docs/advanced/advanced-usage.md index 3a370054..9d80b59e 100644 --- a/docs/content/docs/advanced/advanced-usage.md +++ b/docs/content/docs/advanced/advanced-usage.md @@ -481,8 +481,7 @@ In the help text below, BASEPATH is the location that local-ai is being executed |-----------|---------|-------------|----------------------| | --models-path | BASEPATH/models | Path containing models used for inferencing | $LOCALAI_MODELS_PATH | | --backend-assets-path |/tmp/localai/backend_data | Path used to extract libraries that are required by some of the backends in runtime | $LOCALAI_BACKEND_ASSETS_PATH | -| --image-path | /tmp/generated/images | Location for images generated by backends (e.g. stablediffusion) | $LOCALAI_IMAGE_PATH | -| --audio-path | /tmp/generated/audio | Location for audio generated by backends (e.g. piper) | $LOCALAI_AUDIO_PATH | +| --generated-content-path | /tmp/generated/content | Location for assets generated by backends (e.g. 
stablediffusion) | $LOCALAI_GENERATED_CONTENT_PATH | | --upload-path | /tmp/localai/upload | Path to store uploads from files api | $LOCALAI_UPLOAD_PATH | | --config-path | /tmp/localai/config | | $LOCALAI_CONFIG_PATH | | --localai-config-dir | BASEPATH/configuration | Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json) | $LOCALAI_CONFIG_DIR | diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index fabc0268..9f9f19b1 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -39,6 +39,7 @@ type Backend interface { LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) + GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error) diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index 2e1fb209..a992f6d8 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -53,6 +53,10 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error { return fmt.Errorf("unimplemented") } +func (llm *Base) GenerateVideo(*pb.GenerateVideoRequest) error { + return fmt.Errorf("unimplemented") +} + func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) { return pb.TranscriptResult{}, fmt.Errorf("unimplemented") } diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index fe4dcde4..78e1421d 100644 --- a/pkg/grpc/client.go +++ 
b/pkg/grpc/client.go @@ -215,6 +215,28 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, return client.GenerateImage(ctx, in, opts...) } +func (c *Client) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error) { + if !c.parallel { + c.opMutex.Lock() + defer c.opMutex.Unlock() + } + c.setBusy(true) + defer c.setBusy(false) + c.wdMark() + defer c.wdUnMark() + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultCallOptions( + grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB + grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB + )) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + return client.GenerateVideo(ctx, in, opts...) +} + func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) { if !c.parallel { c.opMutex.Lock() diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index 79648c5a..417b3890 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -47,6 +47,10 @@ func (e *embedBackend) GenerateImage(ctx context.Context, in *pb.GenerateImageRe return e.s.GenerateImage(ctx, in) } +func (e *embedBackend) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error) { + return e.s.GenerateVideo(ctx, in) +} + func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) { return e.s.TTS(ctx, in) } diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 9214e3cf..35c5d977 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -14,6 +14,7 @@ type LLM interface { Load(*pb.ModelOptions) error Embeddings(*pb.PredictOptions) ([]float32, error) GenerateImage(*pb.GenerateImageRequest) error + GenerateVideo(*pb.GenerateVideoRequest) error AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) 
TTS(*pb.TTSRequest) error SoundGeneration(*pb.SoundGenerationRequest) error diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index b81c2c3a..546ed291 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -75,6 +75,18 @@ func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) return &pb.Result{Message: "Image generated", Success: true}, nil } +func (s *server) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest) (*pb.Result, error) { + if s.llm.Locking() { + s.llm.Lock() + defer s.llm.Unlock() + } + err := s.llm.GenerateVideo(in) + if err != nil { + return &pb.Result{Message: fmt.Sprintf("Error generating video: %s", err.Error()), Success: false}, err + } + return &pb.Result{Message: "Video generated", Success: true}, nil +} + func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) { if s.llm.Locking() { s.llm.Lock()