Mirror of https://github.com/mudler/LocalAI.git (synced 2025-05-19 18:15:00 +00:00)
fix: vllm missing logprobs (#5279)
* working to address missing items
  referencing #3436, #2930 - if I could test it, this might show that the output from the vllm backend is processed and returned to the user
  Signed-off-by: Wyatt Neal <wyatt.neal+git@gmail.com>
* adding in vllm tests to test-extras
  Signed-off-by: Wyatt Neal <wyatt.neal+git@gmail.com>
* adding in tests to pipeline for execution
  Signed-off-by: Wyatt Neal <wyatt.neal+git@gmail.com>
* removing todo block, test via pipeline
  Signed-off-by: Wyatt Neal <wyatt.neal+git@gmail.com>
---------
Signed-off-by: Wyatt Neal <wyatt.neal+git@gmail.com>
parent: 26cbf77c0d
commit: 4076ea0494

4 changed files with 101 additions and 19 deletions
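The substance of the fix is that per-token log probabilities requested through the Logprobs and PromptLogprobs options are forwarded to vLLM and returned to the caller instead of being dropped. As orientation for the diffs below, here is a rough sketch of the kind of extraction involved. The helper name, the JSON encoding, and how the result would be attached to the gRPC reply are assumptions for illustration, not LocalAI's actual implementation; only the shape of vLLM's CompletionOutput.logprobs (one token-id-to-Logprob mapping per generated token) is taken from vLLM.

import json

# Hypothetical helper, for illustration only: flatten the per-token logprobs
# that vLLM attaches to a generation result (when SamplingParams(logprobs=N)
# is set) into something that can ride along on the gRPC reply.
def extract_logprobs(request_output) -> str:
    completion = request_output.outputs[0]          # first CompletionOutput
    if completion.logprobs is None:                 # logprobs were not requested
        return ""
    per_position = []
    for token_alternatives in completion.logprobs:  # one dict per generated token
        per_position.append(
            {str(token_id): lp.logprob for token_id, lp in token_alternatives.items()}
        )
    return json.dumps(per_position)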
.github/workflows/test-extra.yml (vendored): 20 changes
@@ -78,6 +78,26 @@ jobs:
           make --jobs=5 --output-sync=target -C backend/python/diffusers
           make --jobs=5 --output-sync=target -C backend/python/diffusers test
 
+  tests-vllm:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Clone
+        uses: actions/checkout@v4
+        with:
+          submodules: true
+      - name: Dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y build-essential ffmpeg
+          sudo apt-get install -y ca-certificates cmake curl patch python3-pip
+          sudo apt-get install -y libopencv-dev
+          # Install UV
+          curl -LsSf https://astral.sh/uv/install.sh | sh
+          pip install --user --no-cache-dir grpcio-tools==1.64.1
+      - name: Test vllm backend
+        run: |
+          make --jobs=5 --output-sync=target -C backend/python/vllm
+          make --jobs=5 --output-sync=target -C backend/python/vllm test
   # tests-transformers-musicgen:
   #   runs-on: ubuntu-latest
   #   steps:
Makefile: 2 changes
@@ -598,10 +598,12 @@ prepare-extra-conda-environments: protogen-python
 prepare-test-extra: protogen-python
 	$(MAKE) -C backend/python/transformers
 	$(MAKE) -C backend/python/diffusers
+	$(MAKE) -C backend/python/vllm
 
 test-extra: prepare-test-extra
 	$(MAKE) -C backend/python/transformers test
 	$(MAKE) -C backend/python/diffusers test
+	$(MAKE) -C backend/python/vllm test
 
 backend-assets:
 	mkdir -p backend-assets
backend/python/vllm/backend.py

@@ -194,27 +194,40 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         await iterations.aclose()
 
     async def _predict(self, request, context, streaming=False):
+        # Build the sampling parameters
+        # NOTE: this must stay in sync with the vllm backend
+        request_to_sampling_params = {
+            "N": "n",
+            "PresencePenalty": "presence_penalty",
+            "FrequencyPenalty": "frequency_penalty",
+            "RepetitionPenalty": "repetition_penalty",
+            "Temperature": "temperature",
+            "TopP": "top_p",
+            "TopK": "top_k",
+            "MinP": "min_p",
+            "Seed": "seed",
+            "StopPrompts": "stop",
+            "StopTokenIds": "stop_token_ids",
+            "BadWords": "bad_words",
+            "IncludeStopStrInOutput": "include_stop_str_in_output",
+            "IgnoreEOS": "ignore_eos",
+            "Tokens": "max_tokens",
+            "MinTokens": "min_tokens",
+            "Logprobs": "logprobs",
+            "PromptLogprobs": "prompt_logprobs",
+            "SkipSpecialTokens": "skip_special_tokens",
+            "SpacesBetweenSpecialTokens": "spaces_between_special_tokens",
+            "TruncatePromptTokens": "truncate_prompt_tokens",
+            "GuidedDecoding": "guided_decoding",
+        }
+
         # Build sampling parameters
         sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
-        if request.TopP != 0:
-            sampling_params.top_p = request.TopP
-        if request.Tokens > 0:
-            sampling_params.max_tokens = request.Tokens
-        if request.Temperature != 0:
-            sampling_params.temperature = request.Temperature
-        if request.TopK != 0:
-            sampling_params.top_k = request.TopK
-        if request.PresencePenalty != 0:
-            sampling_params.presence_penalty = request.PresencePenalty
-        if request.FrequencyPenalty != 0:
-            sampling_params.frequency_penalty = request.FrequencyPenalty
-        if request.StopPrompts:
-            sampling_params.stop = request.StopPrompts
-        if request.IgnoreEOS:
-            sampling_params.ignore_eos = request.IgnoreEOS
-        if request.Seed != 0:
-            sampling_params.seed = request.Seed
+
+        for request_field, param_field in request_to_sampling_params.items():
+            if hasattr(request, request_field):
+                value = getattr(request, request_field)
+                if value not in (None, 0, [], False):
+                    setattr(sampling_params, param_field, value)
 
         # Extract image paths and process images
         prompt = request.Prompt
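The table-driven loop above replaces the earlier per-field if chain, so every request field named in request_to_sampling_params, including the newly mapped Logprobs and PromptLogprobs, is forwarded to vLLM whenever it carries a non-default value. Below is a self-contained sketch of the same pattern using stand-in dataclasses rather than the real protobuf PredictOptions and vllm SamplingParams; the class and function names are invented for the example.

from dataclasses import dataclass, field
from typing import Optional

# Stand-ins for the real protobuf PredictOptions and vllm.SamplingParams,
# used only to illustrate the mapping loop in _predict above.
@dataclass
class FakeRequest:
    Temperature: float = 0.0
    TopP: float = 0.0
    Logprobs: int = 0
    StopPrompts: list = field(default_factory=list)

@dataclass
class FakeSamplingParams:
    temperature: float = 1.0
    top_p: float = 0.9
    logprobs: Optional[int] = None
    stop: list = field(default_factory=list)

REQUEST_TO_SAMPLING_PARAMS = {
    "Temperature": "temperature",
    "TopP": "top_p",
    "Logprobs": "logprobs",
    "StopPrompts": "stop",
}

def build_sampling_params(request: FakeRequest) -> FakeSamplingParams:
    params = FakeSamplingParams()
    for request_field, param_field in REQUEST_TO_SAMPLING_PARAMS.items():
        if hasattr(request, request_field):
            value = getattr(request, request_field)
            # Proto3 scalar defaults (0, empty list, False) are treated as
            # "not set", so the sampling defaults are kept in that case.
            if value not in (None, 0, [], False):
                setattr(params, param_field, value)
    return params

print(build_sampling_params(FakeRequest(Temperature=0.7, Logprobs=5)))
# temperature and logprobs are overridden; top_p and stop keep their defaults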
backend/python/vllm/test.py

@@ -75,6 +75,53 @@ class TestBackendServicer(unittest.TestCase):
         finally:
             self.tearDown()
 
+    def test_sampling_params(self):
+        """
+        This method tests if all sampling parameters are correctly processed
+        NOTE: this does NOT test for correctness, just that we received a compatible response
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
+                self.assertTrue(response.success)
+
+                req = backend_pb2.PredictOptions(
+                    Prompt="The capital of France is",
+                    TopP=0.8,
+                    Tokens=50,
+                    Temperature=0.7,
+                    TopK=40,
+                    PresencePenalty=0.1,
+                    FrequencyPenalty=0.2,
+                    RepetitionPenalty=1.1,
+                    MinP=0.05,
+                    Seed=42,
+                    StopPrompts=["\n"],
+                    StopTokenIds=[50256],
+                    BadWords=["badword"],
+                    IncludeStopStrInOutput=True,
+                    IgnoreEOS=True,
+                    MinTokens=5,
+                    Logprobs=5,
+                    PromptLogprobs=5,
+                    SkipSpecialTokens=True,
+                    SpacesBetweenSpecialTokens=True,
+                    TruncatePromptTokens=10,
+                    GuidedDecoding=True,
+                    N=2,
+                )
+                resp = stub.Predict(req)
+                self.assertIsNotNone(resp.message)
+                self.assertIsNotNone(resp.logprobs)
+        except Exception as err:
+            print(err)
+            self.fail("sampling params service failed")
+        finally:
+            self.tearDown()
+
+
     def test_embedding(self):
         """
         This method tests if the embeddings are generated successfully
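For comparison with the test above, a caller exercising the new logprobs path against a backend already serving on localhost:50051 might look like the following minimal sketch. It reuses the generated stubs the test imports, and the request and reply field names are taken from the test rather than from the proto definition.

import grpc

import backend_pb2
import backend_pb2_grpc

# Minimal client sketch (assumes the vLLM backend gRPC server is already
# running on localhost:50051, as in the test above).
with grpc.insecure_channel("localhost:50051") as channel:
    stub = backend_pb2_grpc.BackendStub(channel)
    stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
    reply = stub.Predict(
        backend_pb2.PredictOptions(
            Prompt="The capital of France is",
            Tokens=16,
            Logprobs=3,  # ask vLLM for the top-3 alternatives per generated token
        )
    )
    print(reply.message)   # generated text
    print(reply.logprobs)  # populated by this fix; previously missing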