mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-23 12:05:00 +00:00
Feat: rwkv improvements: (#937)
This commit is contained in:
parent
0d6165e481
commit
901f0709c5
7 changed files with 208 additions and 150 deletions
|
@ -20,9 +20,15 @@ type LLM struct {
|
|||
}
|
||||
|
||||
func (llm *LLM) Load(opts *pb.ModelOptions) error {
|
||||
tokenizerFile := opts.Tokenizer
|
||||
if tokenizerFile == "" {
|
||||
modelFile := filepath.Base(opts.ModelFile)
|
||||
tokenizerFile = modelFile + tokenizerSuffix
|
||||
}
|
||||
modelPath := filepath.Dir(opts.ModelFile)
|
||||
modelFile := filepath.Base(opts.ModelFile)
|
||||
model := rwkv.LoadFiles(opts.ModelFile, filepath.Join(modelPath, modelFile+tokenizerSuffix), uint32(opts.GetThreads()))
|
||||
tokenizerPath := filepath.Join(modelPath, tokenizerFile)
|
||||
|
||||
model := rwkv.LoadFiles(opts.ModelFile, tokenizerPath, uint32(opts.GetThreads()))
|
||||
|
||||
if model == nil {
|
||||
return fmt.Errorf("could not load model")
|
||||
|
@ -68,3 +74,22 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (llm *LLM) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
|
||||
tokens, err := llm.rwkv.Tokenizer.Encode(opts.Prompt)
|
||||
if err != nil {
|
||||
return pb.TokenizationResponse{}, err
|
||||
}
|
||||
|
||||
l := len(tokens)
|
||||
i32Tokens := make([]int32, l)
|
||||
|
||||
for i, t := range tokens {
|
||||
i32Tokens[i] = int32(t.ID)
|
||||
}
|
||||
|
||||
return pb.TokenizationResponse{
|
||||
Length: int32(l),
|
||||
Tokens: i32Tokens,
|
||||
}, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue