Mirror of https://github.com/mudler/LocalAI.git (synced 2025-05-19 18:15:00 +00:00)
fix: vllm missing logprobs (#5279)
* working to address missing items referencing #3436, #2930 - if I could test it, this might show that the output from the vllm backend is processed and returned to the user
* adding in vllm tests to test-extras
* adding in tests to pipeline for execution
* removing todo block, test via pipeline

Signed-off-by: Wyatt Neal <wyatt.neal+git@gmail.com>
Commit 4076ea0494 (parent 26cbf77c0d)
4 changed files with 101 additions and 19 deletions
.github/workflows/test-extra.yml (vendored): 20 changes

@@ -78,6 +78,26 @@ jobs:
           make --jobs=5 --output-sync=target -C backend/python/diffusers
           make --jobs=5 --output-sync=target -C backend/python/diffusers test
 
+  tests-vllm:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Clone
+        uses: actions/checkout@v4
+        with:
+          submodules: true
+      - name: Dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y build-essential ffmpeg
+          sudo apt-get install -y ca-certificates cmake curl patch python3-pip
+          sudo apt-get install -y libopencv-dev
+          # Install UV
+          curl -LsSf https://astral.sh/uv/install.sh | sh
+          pip install --user --no-cache-dir grpcio-tools==1.64.1
+      - name: Test vllm backend
+        run: |
+          make --jobs=5 --output-sync=target -C backend/python/vllm
+          make --jobs=5 --output-sync=target -C backend/python/vllm test
   # tests-transformers-musicgen:
   #    runs-on: ubuntu-latest
   #    steps:
Makefile: 2 changes

@@ -598,10 +598,12 @@ prepare-extra-conda-environments: protogen-python
 prepare-test-extra: protogen-python
 	$(MAKE) -C backend/python/transformers
 	$(MAKE) -C backend/python/diffusers
+	$(MAKE) -C backend/python/vllm
 
 test-extra: prepare-test-extra
 	$(MAKE) -C backend/python/transformers test
 	$(MAKE) -C backend/python/diffusers test
+	$(MAKE) -C backend/python/vllm test
 
 backend-assets:
 	mkdir -p backend-assets
backend/python/vllm/backend.py

@@ -194,27 +194,40 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             await iterations.aclose()
 
     async def _predict(self, request, context, streaming=False):
+        # Build the sampling parameters
+        # NOTE: this must stay in sync with the vllm backend
+        request_to_sampling_params = {
+            "N": "n",
+            "PresencePenalty": "presence_penalty",
+            "FrequencyPenalty": "frequency_penalty",
+            "RepetitionPenalty": "repetition_penalty",
+            "Temperature": "temperature",
+            "TopP": "top_p",
+            "TopK": "top_k",
+            "MinP": "min_p",
+            "Seed": "seed",
+            "StopPrompts": "stop",
+            "StopTokenIds": "stop_token_ids",
+            "BadWords": "bad_words",
+            "IncludeStopStrInOutput": "include_stop_str_in_output",
+            "IgnoreEOS": "ignore_eos",
+            "Tokens": "max_tokens",
+            "MinTokens": "min_tokens",
+            "Logprobs": "logprobs",
+            "PromptLogprobs": "prompt_logprobs",
+            "SkipSpecialTokens": "skip_special_tokens",
+            "SpacesBetweenSpecialTokens": "spaces_between_special_tokens",
+            "TruncatePromptTokens": "truncate_prompt_tokens",
+            "GuidedDecoding": "guided_decoding",
+        }
+
-        # Build sampling parameters
         sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
-        if request.TopP != 0:
-            sampling_params.top_p = request.TopP
-        if request.Tokens > 0:
-            sampling_params.max_tokens = request.Tokens
-        if request.Temperature != 0:
-            sampling_params.temperature = request.Temperature
-        if request.TopK != 0:
-            sampling_params.top_k = request.TopK
-        if request.PresencePenalty != 0:
-            sampling_params.presence_penalty = request.PresencePenalty
-        if request.FrequencyPenalty != 0:
-            sampling_params.frequency_penalty = request.FrequencyPenalty
-        if request.StopPrompts:
-            sampling_params.stop = request.StopPrompts
-        if request.IgnoreEOS:
-            sampling_params.ignore_eos = request.IgnoreEOS
-        if request.Seed != 0:
-            sampling_params.seed = request.Seed
+        for request_field, param_field in request_to_sampling_params.items():
+            if hasattr(request, request_field):
+                value = getattr(request, request_field)
+                if value not in (None, 0, [], False):
+                    setattr(sampling_params, param_field, value)
 
         # Extract image paths and process images
         prompt = request.Prompt
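For reference, here is a minimal, self-contained sketch of how the mapping loop added in the hunk above behaves. The FakeRequest and FakeSamplingParams classes are hypothetical stand-ins for illustration only (the real code uses backend_pb2.PredictOptions and vllm.SamplingParams); fields left at their zero/empty protobuf defaults are skipped, so vllm's own defaults win.

from dataclasses import dataclass, field
from typing import List, Optional

# Hypothetical stand-ins for backend_pb2.PredictOptions and vllm.SamplingParams,
# used only to illustrate the mapping pattern introduced by this commit.
@dataclass
class FakeRequest:
    TopP: float = 0.0
    Temperature: float = 0.0
    Tokens: int = 0
    Logprobs: int = 0
    StopPrompts: List[str] = field(default_factory=list)

@dataclass
class FakeSamplingParams:
    top_p: float = 0.9
    temperature: float = 1.0
    max_tokens: int = 200
    logprobs: Optional[int] = None
    stop: List[str] = field(default_factory=list)

request_to_sampling_params = {
    "TopP": "top_p",
    "Temperature": "temperature",
    "Tokens": "max_tokens",
    "Logprobs": "logprobs",
    "StopPrompts": "stop",
}

request = FakeRequest(Temperature=0.7, Logprobs=5)
sampling_params = FakeSamplingParams()

# Same pattern as the diff: copy only fields the caller actually set,
# leaving zero/empty values to fall back to the defaults above.
for request_field, param_field in request_to_sampling_params.items():
    if hasattr(request, request_field):
        value = getattr(request, request_field)
        if value not in (None, 0, [], False):
            setattr(sampling_params, param_field, value)

print(sampling_params)
# FakeSamplingParams(top_p=0.9, temperature=0.7, max_tokens=200, logprobs=5, stop=[])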
backend/python/vllm/test.py

@@ -75,6 +75,53 @@ class TestBackendServicer(unittest.TestCase):
         finally:
             self.tearDown()
 
+    def test_sampling_params(self):
+        """
+        This method tests if all sampling parameters are correctly processed
+        NOTE: this does NOT test for correctness, just that we received a compatible response
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
+                self.assertTrue(response.success)
+
+                req = backend_pb2.PredictOptions(
+                    Prompt="The capital of France is",
+                    TopP=0.8,
+                    Tokens=50,
+                    Temperature=0.7,
+                    TopK=40,
+                    PresencePenalty=0.1,
+                    FrequencyPenalty=0.2,
+                    RepetitionPenalty=1.1,
+                    MinP=0.05,
+                    Seed=42,
+                    StopPrompts=["\n"],
+                    StopTokenIds=[50256],
+                    BadWords=["badword"],
+                    IncludeStopStrInOutput=True,
+                    IgnoreEOS=True,
+                    MinTokens=5,
+                    Logprobs=5,
+                    PromptLogprobs=5,
+                    SkipSpecialTokens=True,
+                    SpacesBetweenSpecialTokens=True,
+                    TruncatePromptTokens=10,
+                    GuidedDecoding=True,
+                    N=2,
+                )
+                resp = stub.Predict(req)
+                self.assertIsNotNone(resp.message)
+                self.assertIsNotNone(resp.logprobs)
+        except Exception as err:
+            print(err)
+            self.fail("sampling params service failed")
+        finally:
+            self.tearDown()
+
     def test_embedding(self):
         """
         This method tests if the embeddings are generated successfully
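The assertion on resp.logprobs above is what exercises the commit title ("vllm missing logprobs"). As a hedged illustration only, and not part of this diff, the sketch below shows one way logprob data could be flattened out of a vllm generation result before being serialized into the reply; the flatten_logprobs helper is hypothetical and assumes vllm's documented RequestOutput / CompletionOutput / Logprob structures.

def flatten_logprobs(request_output):
    # Hypothetical helper (not part of this commit): walk a vllm RequestOutput and
    # collect per-token logprob data into plain dicts that a gRPC reply could carry.
    # Assumes RequestOutput.outputs is a list of CompletionOutput, whose .logprobs
    # (when requested) is a list of {token_id: Logprob} dicts, each Logprob carrying
    # .logprob and .decoded_token attributes.
    flattened = []
    for completion in request_output.outputs:
        if not completion.logprobs:
            continue
        for position in completion.logprobs:
            for token_id, lp in position.items():
                flattened.append({
                    "token_id": token_id,
                    "token": lp.decoded_token,
                    "logprob": lp.logprob,
                })
    return flattened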