feat(video-gen): add endpoint for video generation (#5247)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Author: Ettore Di Giacinto <mudler@localai.io>
Date:   2025-04-26 18:05:01 +02:00 (committed by GitHub)
Parent: a67d22f5f2
Commit: 2c9279a542
GPG key ID: B5690EEEBB952194
23 changed files with 401 additions and 69 deletions


@@ -14,6 +14,7 @@ service Backend {
   rpc PredictStream(PredictOptions) returns (stream Reply) {}
   rpc Embedding(PredictOptions) returns (EmbeddingResult) {}
   rpc GenerateImage(GenerateImageRequest) returns (Result) {}
+  rpc GenerateVideo(GenerateVideoRequest) returns (Result) {}
   rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
   rpc TTS(TTSRequest) returns (Result) {}
   rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
@@ -301,6 +302,19 @@ message GenerateImageRequest {
   int32 CLIPSkip = 11;
 }

+message GenerateVideoRequest {
+  string prompt = 1;
+  string start_image = 2; // Path or base64 encoded image for the start frame
+  string end_image = 3;   // Path or base64 encoded image for the end frame
+  int32 width = 4;
+  int32 height = 5;
+  int32 num_frames = 6;   // Number of frames to generate
+  int32 fps = 7;          // Frames per second
+  int32 seed = 8;
+  float cfg_scale = 9;    // Classifier-free guidance scale
+  string dst = 10;        // Output path for the generated video
+}
+
 message TTSRequest {
   string text = 1;
   string model = 2;
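For orientation, a minimal sketch (not part of the commit) of how this message is populated from Go through the generated bindings. Height, Width, Prompt, StartImage, EndImage and Dst match the usage in core/backend/video.go below; the remaining field names are assumed from standard protoc-gen-go casing.

package main

import (
	"fmt"

	"github.com/mudler/LocalAI/pkg/grpc/proto"
)

func main() {
	// Hypothetical request; NumFrames, Fps, Seed and CfgScale are assumed
	// generated names for num_frames, fps, seed and cfg_scale.
	req := &proto.GenerateVideoRequest{
		Prompt:     "a cat surfing a wave",
		StartImage: "/tmp/start.png", // path or base64-encoded image
		Width:      512,
		Height:     512,
		NumFrames:  25,
		Fps:        8,
		Seed:       42,
		CfgScale:   7.0,
		Dst:        "/tmp/out.mp4",
	}
	fmt.Println(req.GetPrompt(), req.GetDst())
}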


@@ -43,18 +43,12 @@ func New(opts ...config.AppOption) (*Application, error) {
 	if err != nil {
 		return nil, fmt.Errorf("unable to create ModelPath: %q", err)
 	}
-	if options.ImageDir != "" {
-		err := os.MkdirAll(options.ImageDir, 0750)
+	if options.GeneratedContentDir != "" {
+		err := os.MkdirAll(options.GeneratedContentDir, 0750)
 		if err != nil {
 			return nil, fmt.Errorf("unable to create ImageDir: %q", err)
 		}
 	}
-	if options.AudioDir != "" {
-		err := os.MkdirAll(options.AudioDir, 0750)
-		if err != nil {
-			return nil, fmt.Errorf("unable to create AudioDir: %q", err)
-		}
-	}
 	if options.UploadDir != "" {
 		err := os.MkdirAll(options.UploadDir, 0750)
 		if err != nil {


@@ -35,12 +35,17 @@ func SoundGeneration(
 		return "", nil, fmt.Errorf("could not load sound generation model")
 	}

-	if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
+	if err := os.MkdirAll(appConfig.GeneratedContentDir, 0750); err != nil {
 		return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
 	}

-	fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "sound_generation", ".wav")
-	filePath := filepath.Join(appConfig.AudioDir, fileName)
+	audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio")
+	if err := os.MkdirAll(audioDir, 0750); err != nil {
+		return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
+	}
+
+	fileName := utils.GenerateUniqueFileName(audioDir, "sound_generation", ".wav")
+	filePath := filepath.Join(audioDir, fileName)

 	res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
 		Text: text,


@@ -32,12 +32,13 @@ func ModelTTS(
 		return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model)
 	}

-	if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
+	audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio")
+	if err := os.MkdirAll(audioDir, 0750); err != nil {
 		return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
 	}

-	fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
-	filePath := filepath.Join(appConfig.AudioDir, fileName)
+	fileName := utils.GenerateUniqueFileName(audioDir, "tts", ".wav")
+	filePath := filepath.Join(audioDir, fileName)

 	// We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect.
 	// This should be addressed in a follow up PR soon.

core/backend/video.go (new file)

@@ -0,0 +1,36 @@
package backend

import (
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/pkg/grpc/proto"
	model "github.com/mudler/LocalAI/pkg/model"
)

func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {

	opts := ModelOptions(backendConfig, appConfig)
	inferenceModel, err := loader.Load(
		opts...,
	)
	if err != nil {
		return nil, err
	}
	defer loader.Close()

	fn := func() error {
		_, err := inferenceModel.GenerateVideo(
			appConfig.Context,
			&proto.GenerateVideoRequest{
				Height:     height,
				Width:      width,
				Prompt:     prompt,
				StartImage: startImage,
				EndImage:   endImage,
				Dst:        dst,
			})
		return err
	}

	return fn, nil
}
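A minimal usage sketch (not part of the commit): the helper resolves the backend up front and hands back a closure, so the actual gRPC call only runs when the closure is invoked, which is how the HTTP endpoint added later in this commit consumes it. The arguments are assumed to be initialized elsewhere.

package example

import (
	"github.com/mudler/LocalAI/core/backend"
	"github.com/mudler/LocalAI/core/config"
	model "github.com/mudler/LocalAI/pkg/model"
)

// generate is a hypothetical caller; loader, cfg and appCfg are assumed to be
// set up the same way the HTTP endpoint below sets them up.
func generate(loader *model.ModelLoader, cfg config.BackendConfig, appCfg *config.ApplicationConfig) error {
	fn, err := backend.VideoGeneration(512, 512, "a cat surfing a wave",
		"", "", "/tmp/out.mp4", loader, cfg, appCfg)
	if err != nil {
		return err
	}
	// The generation itself happens only when the returned closure is called.
	return fn()
}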


@@ -21,8 +21,7 @@ type RunCMD struct {
 	ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
 	BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"`
-	ImagePath string `env:"LOCALAI_IMAGE_PATH,IMAGE_PATH" type:"path" default:"/tmp/generated/images" help:"Location for images generated by backends (e.g. stablediffusion)" group:"storage"`
-	AudioPath string `env:"LOCALAI_AUDIO_PATH,AUDIO_PATH" type:"path" default:"/tmp/generated/audio" help:"Location for audio generated by backends (e.g. piper)" group:"storage"`
+	GeneratedContentPath string `env:"LOCALAI_GENERATED_CONTENT_PATH,GENERATED_CONTENT_PATH" type:"path" default:"/tmp/generated/content" help:"Location for generated content (e.g. images, audio, videos)" group:"storage"`
 	UploadPath string `env:"LOCALAI_UPLOAD_PATH,UPLOAD_PATH" type:"path" default:"/tmp/localai/upload" help:"Path to store uploads from files api" group:"storage"`
 	ConfigPath string `env:"LOCALAI_CONFIG_PATH,CONFIG_PATH" default:"/tmp/localai/config" group:"storage"`
 	LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"`
@@ -81,8 +80,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
 		config.WithModelPath(r.ModelsPath),
 		config.WithContextSize(r.ContextSize),
 		config.WithDebug(zerolog.GlobalLevel() <= zerolog.DebugLevel),
-		config.WithImageDir(r.ImagePath),
-		config.WithAudioDir(r.AudioPath),
+		config.WithGeneratedContentDir(r.GeneratedContentPath),
 		config.WithUploadDir(r.UploadPath),
 		config.WithConfigsDir(r.ConfigPath),
 		config.WithDynamicConfigDir(r.LocalaiConfigDir),


@@ -70,7 +70,7 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
 	opts := &config.ApplicationConfig{
 		ModelPath: t.ModelsPath,
 		Context:   context.Background(),
-		AudioDir:  outputDir,
+		GeneratedContentDir: outputDir,
 		AssetsDestination:   t.BackendAssetsPath,
 		ExternalGRPCBackends: externalBackends,
 	}


@@ -38,7 +38,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
 	opts := &config.ApplicationConfig{
 		ModelPath: t.ModelsPath,
 		Context:   context.Background(),
-		AudioDir:  outputDir,
+		GeneratedContentDir: outputDir,
 		AssetsDestination: t.BackendAssetsPath,
 	}
 	ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend)


@@ -19,10 +19,11 @@ type ApplicationConfig struct {
 	UploadLimitMB, Threads, ContextSize int
 	F16 bool
 	Debug bool
-	ImageDir string
-	AudioDir string
-	UploadDir string
+	GeneratedContentDir string
 	ConfigsDir string
+	UploadDir string
 	DynamicConfigsDir string
 	DynamicConfigsDirPollInterval time.Duration
 	CORS bool
@@ -279,15 +280,9 @@ func WithDebug(debug bool) AppOption {
 	}
 }

-func WithAudioDir(audioDir string) AppOption {
+func WithGeneratedContentDir(generatedContentDir string) AppOption {
 	return func(o *ApplicationConfig) {
-		o.AudioDir = audioDir
-	}
-}
-
-func WithImageDir(imageDir string) AppOption {
-	return func(o *ApplicationConfig) {
-		o.ImageDir = imageDir
+		o.GeneratedContentDir = generatedContentDir
 	}
 }


@@ -436,18 +436,19 @@ func (c *BackendConfig) HasTemplate() bool {
 type BackendConfigUsecases int

 const (
-	FLAG_ANY              BackendConfigUsecases = 0b00000000000
-	FLAG_CHAT             BackendConfigUsecases = 0b00000000001
-	FLAG_COMPLETION       BackendConfigUsecases = 0b00000000010
-	FLAG_EDIT             BackendConfigUsecases = 0b00000000100
-	FLAG_EMBEDDINGS       BackendConfigUsecases = 0b00000001000
-	FLAG_RERANK           BackendConfigUsecases = 0b00000010000
-	FLAG_IMAGE            BackendConfigUsecases = 0b00000100000
-	FLAG_TRANSCRIPT       BackendConfigUsecases = 0b00001000000
-	FLAG_TTS              BackendConfigUsecases = 0b00010000000
-	FLAG_SOUND_GENERATION BackendConfigUsecases = 0b00100000000
-	FLAG_TOKENIZE         BackendConfigUsecases = 0b01000000000
-	FLAG_VAD              BackendConfigUsecases = 0b10000000000
+	FLAG_ANY              BackendConfigUsecases = 0b000000000000
+	FLAG_CHAT             BackendConfigUsecases = 0b000000000001
+	FLAG_COMPLETION       BackendConfigUsecases = 0b000000000010
+	FLAG_EDIT             BackendConfigUsecases = 0b000000000100
+	FLAG_EMBEDDINGS       BackendConfigUsecases = 0b000000001000
+	FLAG_RERANK           BackendConfigUsecases = 0b000000010000
+	FLAG_IMAGE            BackendConfigUsecases = 0b000000100000
+	FLAG_TRANSCRIPT       BackendConfigUsecases = 0b000001000000
+	FLAG_TTS              BackendConfigUsecases = 0b000010000000
+	FLAG_SOUND_GENERATION BackendConfigUsecases = 0b000100000000
+	FLAG_TOKENIZE         BackendConfigUsecases = 0b001000000000
+	FLAG_VAD              BackendConfigUsecases = 0b010000000000
+	FLAG_VIDEO            BackendConfigUsecases = 0b100000000000

 	// Common Subsets
 	FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
@@ -468,6 +469,7 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
 		"FLAG_TOKENIZE": FLAG_TOKENIZE,
 		"FLAG_VAD":      FLAG_VAD,
 		"FLAG_LLM":      FLAG_LLM,
+		"FLAG_VIDEO":    FLAG_VIDEO,
 	}
 }
@@ -532,6 +534,17 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
 			return false
 		}
 	}
+	if (u & FLAG_VIDEO) == FLAG_VIDEO {
+		videoBackends := []string{"diffusers", "stablediffusion"}
+		if !slices.Contains(videoBackends, c.Backend) {
+			return false
+		}
+
+		if c.Backend == "diffusers" && c.Diffusers.PipelineType == "" {
+			return false
+		}
+
+	}
 	if (u & FLAG_RERANK) == FLAG_RERANK {
 		if c.Backend != "rerankers" {
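The usecase flags stay a plain bitmask, so FLAG_VIDEO composes with the existing filters unchanged. A standalone sketch (not part of the commit) of the bit test that GuessUsecases and BuildUsecaseFilterFn rely on, with values copied from the const block above:

package example

// Values copied from the const block above; the real type is
// config.BackendConfigUsecases.
type usecases int

const (
	flagChat       usecases = 0b000000000001
	flagCompletion usecases = 0b000000000010
	flagEdit       usecases = 0b000000000100
	flagVideo      usecases = 0b100000000000
	flagLLM                 = flagChat | flagCompletion | flagEdit
)

// has mirrors the (u & FLAG_X) == FLAG_X test used in GuessUsecases:
// a config matches only if every requested bit is set.
func has(set, want usecases) bool {
	return set&want == want
}

var (
	_ = has(flagLLM, flagChat)  // true: chat is part of the LLM subset
	_ = has(flagLLM, flagVideo) // false: video support is a separate bit
)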


@@ -5,6 +5,8 @@ import (
 	"errors"
 	"fmt"
 	"net/http"
+	"os"
+	"path/filepath"

 	"github.com/dave-gray101/v2keyauth"
 	"github.com/mudler/LocalAI/pkg/utils"
@@ -153,12 +155,19 @@ func API(application *application.Application) (*fiber.App, error) {
 		Browse: true,
 	}))

-	if application.ApplicationConfig().ImageDir != "" {
-		router.Static("/generated-images", application.ApplicationConfig().ImageDir)
-	}
-
-	if application.ApplicationConfig().AudioDir != "" {
-		router.Static("/generated-audio", application.ApplicationConfig().AudioDir)
+	if application.ApplicationConfig().GeneratedContentDir != "" {
+		os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750)
+		audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio")
+		imagePath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "images")
+		videoPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "videos")
+
+		os.MkdirAll(audioPath, 0750)
+		os.MkdirAll(imagePath, 0750)
+		os.MkdirAll(videoPath, 0750)
+
+		router.Static("/generated-audio", audioPath)
+		router.Static("/generated-images", imagePath)
+		router.Static("/generated-videos", videoPath)
 	}

 	// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration


@@ -629,8 +629,7 @@ var _ = Describe("API test", func() {
 			application, err := application.New(
 				append(commonOpts,
 					config.WithContext(c),
-					config.WithAudioDir(tmpdir),
-					config.WithImageDir(tmpdir),
+					config.WithGeneratedContentDir(tmpdir),
 					config.WithGalleries(galleries),
 					config.WithModelPath(modelDir),
 					config.WithBackendAssets(backendAssets),


@@ -0,0 +1,205 @@
package localai

import (
	"bufio"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/middleware"
	"github.com/mudler/LocalAI/core/schema"

	"github.com/mudler/LocalAI/core/backend"

	"github.com/gofiber/fiber/v2"
	model "github.com/mudler/LocalAI/pkg/model"
	"github.com/rs/zerolog/log"
)

func downloadFile(url string) (string, error) {
	// Get the data
	resp, err := http.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// Create the file
	out, err := os.CreateTemp("", "video")
	if err != nil {
		return "", err
	}
	defer out.Close()

	// Write the body to file
	_, err = io.Copy(out, resp.Body)
	return out.Name(), err
}

//

/*
*
	curl http://localhost:8080/v1/images/generations \
	-H "Content-Type: application/json" \
	-d '{
	"prompt": "A cute baby sea otter",
	"n": 1,
	"size": "512x512"
	}'
*
*/
// VideoEndpoint
// @Summary Creates a video given a prompt.
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /video [post]
func VideoEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
		if !ok || input.Model == "" {
			log.Error().Msg("Video Endpoint - Invalid Input")
			return fiber.ErrBadRequest
		}

		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
		if !ok || config == nil {
			log.Error().Msg("Video Endpoint - Invalid Config")
			return fiber.ErrBadRequest
		}

		src := ""
		if input.StartImage != "" {

			var fileData []byte
			var err error
			// check if input.File is an URL, if so download it and save it
			// to a temporary file
			if strings.HasPrefix(input.StartImage, "http://") || strings.HasPrefix(input.StartImage, "https://") {
				out, err := downloadFile(input.StartImage)
				if err != nil {
					return fmt.Errorf("failed downloading file:%w", err)
				}
				defer os.RemoveAll(out)

				fileData, err = os.ReadFile(out)
				if err != nil {
					return fmt.Errorf("failed reading file:%w", err)
				}
			} else {
				// base 64 decode the file and write it somewhere
				// that we will cleanup
				fileData, err = base64.StdEncoding.DecodeString(input.StartImage)
				if err != nil {
					return err
				}
			}

			// Create a temporary file
			outputFile, err := os.CreateTemp(appConfig.GeneratedContentDir, "b64")
			if err != nil {
				return err
			}
			// write the base64 result
			writer := bufio.NewWriter(outputFile)
			_, err = writer.Write(fileData)
			if err != nil {
				outputFile.Close()
				return err
			}
			outputFile.Close()
			src = outputFile.Name()
			defer os.RemoveAll(src)
		}

		log.Debug().Msgf("Parameter Config: %+v", config)

		switch config.Backend {
		case "stablediffusion":
			config.Backend = model.StableDiffusionGGMLBackend
		case "":
			config.Backend = model.StableDiffusionGGMLBackend
		}

		width := input.Width
		height := input.Height

		if width == 0 {
			width = 512
		}
		if height == 0 {
			height = 512
		}

		b64JSON := input.ResponseFormat == "b64_json"

		tempDir := ""
		if !b64JSON {
			tempDir = filepath.Join(appConfig.GeneratedContentDir, "videos")
		}
		// Create a temporary file
		outputFile, err := os.CreateTemp(tempDir, "b64")
		if err != nil {
			return err
		}
		outputFile.Close()

		// TODO: use mime type to determine the extension
		output := outputFile.Name() + ".mp4"

		// Rename the temporary file
		err = os.Rename(outputFile.Name(), output)
		if err != nil {
			return err
		}

		baseURL := c.BaseURL()

		fn, err := backend.VideoGeneration(height, width, input.Prompt, src, input.EndImage, output, ml, *config, appConfig)
		if err != nil {
			return err
		}
		if err := fn(); err != nil {
			return err
		}

		item := &schema.Item{}

		if b64JSON {
			defer os.RemoveAll(output)
			data, err := os.ReadFile(output)
			if err != nil {
				return err
			}
			item.B64JSON = base64.StdEncoding.EncodeToString(data)
		} else {
			base := filepath.Base(output)
			item.URL = baseURL + "/generated-videos/" + base
		}

		id := uuid.New().String()
		created := int(time.Now().Unix())

		resp := &schema.OpenAIResponse{
			ID:      id,
			Created: created,
			Data:    []schema.Item{*item},
		}

		jsonResult, _ := json.Marshal(resp)
		log.Debug().Msgf("Response: %s", jsonResult)

		// Return the prediction in the response body
		return c.JSON(resp)
	}
}
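End to end, a hypothetical client call against the new endpoint (registered as POST /video in the routes file below). The "model" key from BasicModelRequest and the "url"/"b64_json" keys of schema.Item are not shown in this diff and are assumed to follow the usual OpenAI-style JSON tags; the model name is a placeholder.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Request body fields match the json tags of schema.VideoRequest.
	body, _ := json.Marshal(map[string]any{
		"model":           "my-video-model", // hypothetical model name
		"prompt":          "a cat surfing a wave",
		"width":           512,
		"height":          512,
		"response_format": "url", // "b64_json" returns the video inline instead
	})

	resp, err := http.Post("http://localhost:8080/video", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Assumed shape of the OpenAI-style response produced above.
	var out struct {
		Data []struct {
			URL     string `json:"url"`
			B64JSON string `json:"b64_json"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	if len(out.Data) > 0 {
		fmt.Println("generated video:", out.Data[0].URL)
	}
}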


@@ -108,7 +108,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
 		}

 		// Create a temporary file
-		outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64")
+		outputFile, err := os.CreateTemp(appConfig.GeneratedContentDir, "b64")
 		if err != nil {
 			return err
 		}
@@ -184,7 +184,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
 		tempDir := ""
 		if !b64JSON {
-			tempDir = appConfig.ImageDir
+			tempDir = filepath.Join(appConfig.GeneratedContentDir, "images")
 		}
 		// Create a temporary file
 		outputFile, err := os.CreateTemp(tempDir, "b64")
@@ -192,6 +192,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
 			return err
 		}
 		outputFile.Close()
+
 		output := outputFile.Name() + ".png"
 		// Rename the temporary file
 		err = os.Rename(outputFile.Name(), output)


@@ -59,6 +59,11 @@ func RegisterLocalAIRoutes(router *fiber.App,
 		router.Get("/metrics", localai.LocalAIMetricsEndpoint())
 	}

+	router.Post("/video",
+		requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)),
+		requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.VideoRequest) }),
+		localai.VideoEndpoint(cl, ml, appConfig))
+
 	// Backend Statistics Module
 	// TODO: Should these use standard middlewares? Refactor later, they are extremely simple.
 	backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now


@@ -24,6 +24,20 @@ type GalleryResponse struct {
 	StatusURL string `json:"status"`
 }

+type VideoRequest struct {
+	BasicModelRequest
+	Prompt         string  `json:"prompt" yaml:"prompt"`
+	StartImage     string  `json:"start_image" yaml:"start_image"`
+	EndImage       string  `json:"end_image" yaml:"end_image"`
+	Width          int32   `json:"width" yaml:"width"`
+	Height         int32   `json:"height" yaml:"height"`
+	NumFrames      int32   `json:"num_frames" yaml:"num_frames"`
+	FPS            int32   `json:"fps" yaml:"fps"`
+	Seed           int32   `json:"seed" yaml:"seed"`
+	CFGScale       float32 `json:"cfg_scale" yaml:"cfg_scale"`
+	ResponseFormat string  `json:"response_format" yaml:"response_format"`
+}
+
 // @Description TTS request body
 type TTSRequest struct {
 	BasicModelRequest


@@ -481,8 +481,7 @@ In the help text below, BASEPATH is the location that local-ai is being executed
 |-----------|---------|-------------|----------------------|
 | --models-path | BASEPATH/models | Path containing models used for inferencing | $LOCALAI_MODELS_PATH |
 | --backend-assets-path | /tmp/localai/backend_data | Path used to extract libraries that are required by some of the backends in runtime | $LOCALAI_BACKEND_ASSETS_PATH |
-| --image-path | /tmp/generated/images | Location for images generated by backends (e.g. stablediffusion) | $LOCALAI_IMAGE_PATH |
-| --audio-path | /tmp/generated/audio | Location for audio generated by backends (e.g. piper) | $LOCALAI_AUDIO_PATH |
+| --generated-content-path | /tmp/generated/content | Location for assets generated by backends (e.g. stablediffusion) | $LOCALAI_GENERATED_CONTENT_PATH |
 | --upload-path | /tmp/localai/upload | Path to store uploads from files api | $LOCALAI_UPLOAD_PATH |
 | --config-path | /tmp/localai/config | | $LOCALAI_CONFIG_PATH |
 | --localai-config-dir | BASEPATH/configuration | Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json) | $LOCALAI_CONFIG_DIR |


@@ -39,6 +39,7 @@ type Backend interface {
 	LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
 	PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error
 	GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
+	GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error)


@@ -53,6 +53,10 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {
 	return fmt.Errorf("unimplemented")
 }

+func (llm *Base) GenerateVideo(*pb.GenerateVideoRequest) error {
+	return fmt.Errorf("unimplemented")
+}
+
 func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) {
 	return pb.TranscriptResult{}, fmt.Errorf("unimplemented")
 }


@@ -215,6 +215,28 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest,
 	return client.GenerateImage(ctx, in, opts...)
 }

+func (c *Client) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error) {
+	if !c.parallel {
+		c.opMutex.Lock()
+		defer c.opMutex.Unlock()
+	}
+	c.setBusy(true)
+	defer c.setBusy(false)
+	c.wdMark()
+	defer c.wdUnMark()
+	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
+		grpc.WithDefaultCallOptions(
+			grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
+			grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
+		))
+	if err != nil {
+		return nil, err
+	}
+	defer conn.Close()
+	client := pb.NewBackendClient(conn)
+	return client.GenerateVideo(ctx, in, opts...)
+}
+
 func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) {
 	if !c.parallel {
 		c.opMutex.Lock()


@@ -47,6 +47,10 @@ func (e *embedBackend) GenerateImage(ctx context.Context, in *pb.GenerateImageRe
 	return e.s.GenerateImage(ctx, in)
 }

+func (e *embedBackend) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error) {
+	return e.s.GenerateVideo(ctx, in)
+}
+
 func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) {
 	return e.s.TTS(ctx, in)
 }


@@ -14,6 +14,7 @@ type LLM interface {
 	Load(*pb.ModelOptions) error
 	Embeddings(*pb.PredictOptions) ([]float32, error)
 	GenerateImage(*pb.GenerateImageRequest) error
+	GenerateVideo(*pb.GenerateVideoRequest) error
 	AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
 	TTS(*pb.TTSRequest) error
 	SoundGeneration(*pb.SoundGenerationRequest) error


@@ -75,6 +75,18 @@ func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest)
 	return &pb.Result{Message: "Image generated", Success: true}, nil
 }

+func (s *server) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest) (*pb.Result, error) {
+	if s.llm.Locking() {
+		s.llm.Lock()
+		defer s.llm.Unlock()
+	}
+	err := s.llm.GenerateVideo(in)
+	if err != nil {
+		return &pb.Result{Message: fmt.Sprintf("Error generating video: %s", err.Error()), Success: false}, err
+	}
+	return &pb.Result{Message: "Video generated", Success: true}, nil
+}
+
 func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) {
 	if s.llm.Locking() {
 		s.llm.Lock()