mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-31 16:05:00 +00:00
Add implement of tiny and update env
Signed-off-by: GitHub <noreply@github.com>
This commit is contained in:
parent
edcae7a5f1
commit
ca735e7ffc
2 changed files with 68 additions and 4 deletions
|
@ -17,7 +17,7 @@ import grpc
|
||||||
import torch
|
import torch
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from segment_anything_hq import SamAutomaticMaskGenerator
|
from segment_anything_hq import SamAutomaticMaskGenerator
|
||||||
from segment_anything_hq.modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer
|
from segment_anything_hq.modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer, TinyViT
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -54,8 +54,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
model_name = request.model_name
|
model_name = request.model_name
|
||||||
if model_name not in SamModelType.__dict__.keys():
|
if model_name not in SamModelType.__dict__.keys():
|
||||||
raise Exception(f"Model name {model_name} not found in {SamModelType.__dict__.keys()}")
|
raise Exception(f"Model name {model_name} not found in {SamModelType.__dict__.keys()}")
|
||||||
|
|
||||||
model_path = request.model_path
|
model_path = request.model_path
|
||||||
# check the model_path is valid
|
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path):
|
||||||
raise Exception(f"Model path {model_path} does not exist")
|
raise Exception(f"Model path {model_path} does not exist")
|
||||||
|
|
||||||
|
@ -69,8 +69,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
case SamModelType.vit_b:
|
case SamModelType.vit_b:
|
||||||
sam = _build_sam_vit_b(checkpoint=model_path)
|
sam = _build_sam_vit_b(checkpoint=model_path)
|
||||||
case SamModelType.vit_tiny:
|
case SamModelType.vit_tiny:
|
||||||
# TODO: Implement this
|
sam = _build_sam_vit_tiny(checkpoint=model_path)
|
||||||
pass
|
|
||||||
case _:
|
case _:
|
||||||
raise Exception(f"Model name {model_name} not found in {SamModelType.__dict__.keys()}")
|
raise Exception(f"Model name {model_name} not found in {SamModelType.__dict__.keys()}")
|
||||||
# TODO No sure if this is the right way to do it
|
# TODO No sure if this is the right way to do it
|
||||||
|
@ -163,6 +162,57 @@ def _build_sam_vit_l(checkpoint=None):
|
||||||
def _build_sam_vit_b(checkpoint=None):
|
def _build_sam_vit_b(checkpoint=None):
|
||||||
return _constrcut_sam(encoder_embed_dim=768,encoder_depth=12,encoder_num_heads=12,encoder_global_attn_indexes=[2,5,8,11],checkpoint=checkpoint)
|
return _constrcut_sam(encoder_embed_dim=768,encoder_depth=12,encoder_num_heads=12,encoder_global_attn_indexes=[2,5,8,11],checkpoint=checkpoint)
|
||||||
|
|
||||||
|
def _build_sam_vit_tiny(checkpoint=None):
|
||||||
|
image_embedding_size = IMAGE_SIZE // VIT_PATCH_SIZE
|
||||||
|
mobile_sam = Sam(
|
||||||
|
image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000,
|
||||||
|
embed_dims=[64, 128, 160, 320],
|
||||||
|
depths=[2, 2, 6, 2],
|
||||||
|
num_heads=[2, 4, 5, 10],
|
||||||
|
window_sizes=[7, 7, 14, 7],
|
||||||
|
mlp_ratio=4.,
|
||||||
|
drop_rate=0.,
|
||||||
|
drop_path_rate=0.0,
|
||||||
|
use_checkpoint=False,
|
||||||
|
mbconv_expand_ratio=4.0,
|
||||||
|
local_conv_size=3,
|
||||||
|
layer_lr_decay=0.8
|
||||||
|
),
|
||||||
|
prompt_encoder=PromptEncoder(
|
||||||
|
embed_dim=PROMT_EMBED_DIM,
|
||||||
|
image_embedding_size=(image_embedding_size, image_embedding_size),
|
||||||
|
input_image_size=(IMAGE_SIZE, IMAGE_SIZE),
|
||||||
|
mask_in_chans=16,
|
||||||
|
),
|
||||||
|
mask_decoder=MaskDecoderHQ(
|
||||||
|
num_multimask_outputs=3,
|
||||||
|
transformer=TwoWayTransformer(
|
||||||
|
depth=2,
|
||||||
|
embedding_dim=PROMT_EMBED_DIM,
|
||||||
|
mlp_dim=2048,
|
||||||
|
num_heads=8,
|
||||||
|
),
|
||||||
|
transformer_dim=PROMT_EMBED_DIM,
|
||||||
|
iou_head_depth=3,
|
||||||
|
iou_head_hidden_dim=256,
|
||||||
|
vit_dim=160,
|
||||||
|
),
|
||||||
|
pixel_mean=[123.675, 116.28, 103.53],
|
||||||
|
pixel_std=[58.395, 57.12, 57.375],
|
||||||
|
)
|
||||||
|
|
||||||
|
mobile_sam.eval()
|
||||||
|
if checkpoint is not None:
|
||||||
|
with open(checkpoint, "rb") as f:
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
state_dict = torch.load(f, map_location=device)
|
||||||
|
info = mobile_sam.load_state_dict(state_dict, strict=False)
|
||||||
|
print(info)
|
||||||
|
for n, p in mobile_sam.named_parameters():
|
||||||
|
if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
|
||||||
|
p.requires_grad = False
|
||||||
|
return mobile_sam
|
||||||
|
|
||||||
def masks_to_image(anns, request):
|
def masks_to_image(anns, request):
|
||||||
if len(anns)==0:
|
if len(anns)==0:
|
||||||
return
|
return
|
||||||
|
|
|
@ -27,12 +27,18 @@ dependencies:
|
||||||
- pip:
|
- pip:
|
||||||
- certifi==2023.7.22
|
- certifi==2023.7.22
|
||||||
- charset-normalizer==3.3.2
|
- charset-normalizer==3.3.2
|
||||||
|
- contourpy==1.2.0
|
||||||
|
- cycler==0.12.1
|
||||||
- filelock==3.13.1
|
- filelock==3.13.1
|
||||||
|
- fonttools==4.44.0
|
||||||
- fsspec==2023.10.0
|
- fsspec==2023.10.0
|
||||||
- grpcio==1.59.2
|
- grpcio==1.59.2
|
||||||
|
- huggingface-hub==0.18.0
|
||||||
- idna==3.4
|
- idna==3.4
|
||||||
- jinja2==3.1.2
|
- jinja2==3.1.2
|
||||||
|
- kiwisolver==1.4.5
|
||||||
- markupsafe==2.1.3
|
- markupsafe==2.1.3
|
||||||
|
- matplotlib==3.8.1
|
||||||
- mpmath==1.3.0
|
- mpmath==1.3.0
|
||||||
- networkx==3.2.1
|
- networkx==3.2.1
|
||||||
- numpy==1.26.1
|
- numpy==1.26.1
|
||||||
|
@ -48,13 +54,21 @@ dependencies:
|
||||||
- nvidia-nccl-cu12==2.18.1
|
- nvidia-nccl-cu12==2.18.1
|
||||||
- nvidia-nvjitlink-cu12==12.3.52
|
- nvidia-nvjitlink-cu12==12.3.52
|
||||||
- nvidia-nvtx-cu12==12.1.105
|
- nvidia-nvtx-cu12==12.1.105
|
||||||
|
- packaging==23.2
|
||||||
- pillow==10.1.0
|
- pillow==10.1.0
|
||||||
- protobuf==4.25.0
|
- protobuf==4.25.0
|
||||||
|
- pyparsing==3.1.1
|
||||||
|
- python-dateutil==2.8.2
|
||||||
|
- pyyaml==6.0.1
|
||||||
- requests==2.31.0
|
- requests==2.31.0
|
||||||
|
- safetensors==0.4.0
|
||||||
- segment-anything-hq==0.3
|
- segment-anything-hq==0.3
|
||||||
|
- six==1.16.0
|
||||||
- sympy==1.12
|
- sympy==1.12
|
||||||
|
- timm==0.9.10
|
||||||
- torch==2.1.0
|
- torch==2.1.0
|
||||||
- torchvision==0.16.0
|
- torchvision==0.16.0
|
||||||
|
- tqdm==4.66.1
|
||||||
- triton==2.1.0
|
- triton==2.1.0
|
||||||
- typing-extensions==4.8.0
|
- typing-extensions==4.8.0
|
||||||
- urllib3==2.0.7
|
- urllib3==2.0.7
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue