From ca735e7ffc211e7c77242b9046a520c912cd6174 Mon Sep 17 00:00:00 2001
From: Aisuko
Date: Mon, 6 Nov 2023 23:58:38 +0000
Subject: [PATCH] Add implement of tiny and update env

Signed-off-by: GitHub
---
Review notes (commentary only; text between the "---" line and the first
"diff --git" line is not part of the commit):

* LoadModel: the `case SamModelType.vit_tiny:` branch now assigns
  `sam = _build_sam_vit_tiny(checkpoint=model_path)` in place of the former
  `# TODO: Implement this` / `pass` placeholder.
* `_build_sam_vit_tiny(checkpoint=None)` builds a `Sam` from a `TinyViT`
  image encoder, a `PromptEncoder` and a `MaskDecoderHQ` (which wraps a
  `TwoWayTransformer`), calls `.eval()` on it and returns it. When
  `checkpoint` is not None the file is opened in binary mode, read with
  `torch.load` mapped to "cuda" if `torch.cuda.is_available()` else "cpu",
  applied with `load_state_dict(..., strict=False)`, and the value returned
  by `load_state_dict` is printed. Parameters whose names contain none of
  `hf_token`, `hf_mlp`, `compress_vit_feat`, `embedding_encoder`,
  `embedding_maskfeature` get `requires_grad = False`.
* NOTE(review): `torch.load` unpickles the file at `request.model_path`.
  Presumably only trusted checkpoints are served -- confirm, or consider
  `weights_only=True` (sam.yml pins torch==2.1.0).
* NOTE(review): `img_size=1024` is a literal while the prompt encoder sizes
  are derived from `IMAGE_SIZE`; presumably both are 1024 -- confirm against
  the constant's definition, which this patch does not show.
* NOTE(review): with `strict=False` the outcome of loading is only printed;
  presumably a checkpoint for a different architecture would not raise
  here -- confirm that is intended.
* sam.yml: 14 pip pins are added (contourpy, cycler, fonttools,
  huggingface-hub, kiwisolver, matplotlib, packaging, pyparsing,
  python-dateutil, pyyaml, safetensors, six, timm, tqdm).
* NOTE(review): the line breaks and indentation below were restored from a
  whitespace-collapsed copy of this patch. Hunk line counts agree with the
  hunk headers, but the position of blank context lines, the indentation of
  context lines, and the nesting inside `_build_sam_vit_tiny` (which
  statements sit under `if checkpoint is not None:`) are presumed -- verify
  against the repository before applying.

 extra/grpc/sam/sam.py  | 58 +++++++++++++++++++++++++++++++++++++++---
 extra/grpc/sam/sam.yml | 14 ++++++++++
 2 files changed, 68 insertions(+), 4 deletions(-)

diff --git a/extra/grpc/sam/sam.py b/extra/grpc/sam/sam.py
index b9d90b84..f413b6db 100644
--- a/extra/grpc/sam/sam.py
+++ b/extra/grpc/sam/sam.py
@@ -17,7 +17,7 @@ import grpc
 import torch
 from functools import partial
 from segment_anything_hq import SamAutomaticMaskGenerator
-from segment_anything_hq.modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer
+from segment_anything_hq.modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer, TinyViT
 import matplotlib.pyplot as plt
 import numpy as np
 
@@ -54,8 +54,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         model_name = request.model_name
         if model_name not in SamModelType.__dict__.keys():
             raise Exception(f"Model name {model_name} not found in {SamModelType.__dict__.keys()}")
+
         model_path = request.model_path
-        # check the model_path is valid
         if not os.path.exists(model_path):
             raise Exception(f"Model path {model_path} does not exist")
 
@@ -69,8 +69,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             case SamModelType.vit_b:
                 sam = _build_sam_vit_b(checkpoint=model_path)
             case SamModelType.vit_tiny:
-                # TODO: Implement this
-                pass
+                sam = _build_sam_vit_tiny(checkpoint=model_path)
             case _:
                 raise Exception(f"Model name {model_name} not found in {SamModelType.__dict__.keys()}")
         # TODO No sure if this is the right way to do it
@@ -163,6 +162,57 @@ def _build_sam_vit_l(checkpoint=None):
 def _build_sam_vit_b(checkpoint=None):
     return _constrcut_sam(encoder_embed_dim=768,encoder_depth=12,encoder_num_heads=12,encoder_global_attn_indexes=[2,5,8,11],checkpoint=checkpoint)
 
+def _build_sam_vit_tiny(checkpoint=None):
+    image_embedding_size = IMAGE_SIZE // VIT_PATCH_SIZE
+    mobile_sam = Sam(
+            image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000,
+                embed_dims=[64, 128, 160, 320],
+                depths=[2, 2, 6, 2],
+                num_heads=[2, 4, 5, 10],
+                window_sizes=[7, 7, 14, 7],
+                mlp_ratio=4.,
+                drop_rate=0.,
+                drop_path_rate=0.0,
+                use_checkpoint=False,
+                mbconv_expand_ratio=4.0,
+                local_conv_size=3,
+                layer_lr_decay=0.8
+            ),
+            prompt_encoder=PromptEncoder(
+                embed_dim=PROMT_EMBED_DIM,
+                image_embedding_size=(image_embedding_size, image_embedding_size),
+                input_image_size=(IMAGE_SIZE, IMAGE_SIZE),
+                mask_in_chans=16,
+            ),
+            mask_decoder=MaskDecoderHQ(
+                num_multimask_outputs=3,
+                transformer=TwoWayTransformer(
+                    depth=2,
+                    embedding_dim=PROMT_EMBED_DIM,
+                    mlp_dim=2048,
+                    num_heads=8,
+                ),
+                transformer_dim=PROMT_EMBED_DIM,
+                iou_head_depth=3,
+                iou_head_hidden_dim=256,
+                vit_dim=160,
+            ),
+            pixel_mean=[123.675, 116.28, 103.53],
+            pixel_std=[58.395, 57.12, 57.375],
+        )
+
+    mobile_sam.eval()
+    if checkpoint is not None:
+        with open(checkpoint, "rb") as f:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            state_dict = torch.load(f, map_location=device)
+        info = mobile_sam.load_state_dict(state_dict, strict=False)
+        print(info)
+    for n, p in mobile_sam.named_parameters():
+        if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
+            p.requires_grad = False
+    return mobile_sam
+
 def masks_to_image(anns, request):
     if len(anns)==0:
         return
diff --git a/extra/grpc/sam/sam.yml b/extra/grpc/sam/sam.yml
index 42717a0a..1529b1e5 100644
--- a/extra/grpc/sam/sam.yml
+++ b/extra/grpc/sam/sam.yml
@@ -27,12 +27,18 @@ dependencies:
   - pip:
       - certifi==2023.7.22
       - charset-normalizer==3.3.2
+      - contourpy==1.2.0
+      - cycler==0.12.1
       - filelock==3.13.1
+      - fonttools==4.44.0
       - fsspec==2023.10.0
       - grpcio==1.59.2
+      - huggingface-hub==0.18.0
       - idna==3.4
       - jinja2==3.1.2
+      - kiwisolver==1.4.5
       - markupsafe==2.1.3
+      - matplotlib==3.8.1
       - mpmath==1.3.0
       - networkx==3.2.1
       - numpy==1.26.1
@@ -48,13 +54,21 @@ dependencies:
       - nvidia-nccl-cu12==2.18.1
       - nvidia-nvjitlink-cu12==12.3.52
       - nvidia-nvtx-cu12==12.1.105
+      - packaging==23.2
       - pillow==10.1.0
       - protobuf==4.25.0
+      - pyparsing==3.1.1
+      - python-dateutil==2.8.2
+      - pyyaml==6.0.1
       - requests==2.31.0
+      - safetensors==0.4.0
       - segment-anything-hq==0.3
+      - six==1.16.0
       - sympy==1.12
+      - timm==0.9.10
       - torch==2.1.0
       - torchvision==0.16.0
+      - tqdm==4.66.1
       - triton==2.1.0
       - typing-extensions==4.8.0
       - urllib3==2.0.7