Mirror of https://github.com/mudler/LocalAI.git (synced 2025-05-20 10:35:01 +00:00)
feat(transformers): add support to Mamba (#4669)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
parent 200fe358f0, commit 89429a439b
17 changed files with 10 additions and 345 deletions
@@ -21,7 +21,7 @@ import torch.cuda
 XPU=os.environ.get("XPU", "0") == "1"
-from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
+from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM
 from transformers import AutoProcessor, MusicgenForConditionalGeneration
 from scipy.io import wavfile
 import outetts
@@ -245,6 +245,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 autoTokenizer = False
                 self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
                 self.SentenceTransformer = True
+            elif request.Type == "Mamba":
+                autoTokenizer = False
+                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+                self.model = MambaForCausalLM.from_pretrained(model_name)
             else:
                 print("Automodel", file=sys.stderr)
                 self.model = AutoModel.from_pretrained(model_name,
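For reference, below is a minimal standalone sketch of what the new branch does once a model is loaded with the "Mamba" type: the tokenizer comes from AutoTokenizer and the weights from MambaForCausalLM, after which generation works like any other transformers causal LM. The checkpoint name "state-spaces/mamba-130m-hf" is only an illustrative example and does not appear in this commit.

# Minimal sketch of the Mamba load path added above, outside the gRPC backend.
# "state-spaces/mamba-130m-hf" is an example checkpoint, not taken from the diff.
from transformers import AutoTokenizer, MambaForCausalLM

model_name = "state-spaces/mamba-130m-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = MambaForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Mamba is a state-space model that", return_tensors="pt")
# MambaForCausalLM supports generate() like other causal LM heads in transformers.
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))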