Feat: rwkv improvements (#937)

This commit is contained in:
Dave 2023-08-22 12:48:06 -04:00 committed by GitHub
parent 0d6165e481
commit 901f0709c5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 208 additions and 150 deletions

View file

@ -20,9 +20,15 @@ type LLM struct {
}
func (llm *LLM) Load(opts *pb.ModelOptions) error {
tokenizerFile := opts.Tokenizer
if tokenizerFile == "" {
modelFile := filepath.Base(opts.ModelFile)
tokenizerFile = modelFile + tokenizerSuffix
}
modelPath := filepath.Dir(opts.ModelFile)
modelFile := filepath.Base(opts.ModelFile)
model := rwkv.LoadFiles(opts.ModelFile, filepath.Join(modelPath, modelFile+tokenizerSuffix), uint32(opts.GetThreads()))
tokenizerPath := filepath.Join(modelPath, tokenizerFile)
model := rwkv.LoadFiles(opts.ModelFile, tokenizerPath, uint32(opts.GetThreads()))
if model == nil {
return fmt.Errorf("could not load model")
@ -68,3 +74,22 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) erro
return nil
}
// TokenizeString runs opts.Prompt through the model's tokenizer
// (llm.rwkv.Tokenizer.Encode) and reports the resulting token IDs.
//
// On success the response carries the token count in Length and the IDs,
// converted to int32, in Tokens. If encoding fails, an empty response is
// returned together with the tokenizer's error, unmodified.
func (llm *LLM) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
	encoded, err := llm.rwkv.Tokenizer.Encode(opts.Prompt)
	if err != nil {
		return pb.TokenizationResponse{}, err
	}

	// The response message stores IDs as int32, so convert each token's ID
	// while copying it into a slice sized up front.
	ids := make([]int32, 0, len(encoded))
	for _, tok := range encoded {
		ids = append(ids, int32(tok.ID))
	}

	return pb.TokenizationResponse{
		Length: int32(len(encoded)),
		Tokens: ids,
	}, nil
}