Mirror of https://github.com/mudler/LocalAI.git (synced 2025-05-20 18:45:00 +00:00)
working to address missing items

Referencing #3436, #2930: if I could test it, this might show that the output from the vllm backend is processed and returned to the user.

Signed-off-by: Wyatt Neal <wyatt.neal+git@gmail.com>
This commit is contained in:
parent 2b2d907a3a
commit 1569bc4959

2 changed files with 79 additions and 19 deletions
@@ -194,27 +194,40 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         await iterations.aclose()

     async def _predict(self, request, context, streaming=False):
-        # Build sampling parameters
+        # Build the sampling parameters
+        # NOTE: this must stay in sync with the vllm backend
+        request_to_sampling_params = {
+            "N": "n",
+            "PresencePenalty": "presence_penalty",
+            "FrequencyPenalty": "frequency_penalty",
+            "RepetitionPenalty": "repetition_penalty",
+            "Temperature": "temperature",
+            "TopP": "top_p",
+            "TopK": "top_k",
+            "MinP": "min_p",
+            "Seed": "seed",
+            "StopPrompts": "stop",
+            "StopTokenIds": "stop_token_ids",
+            "BadWords": "bad_words",
+            "IncludeStopStrInOutput": "include_stop_str_in_output",
+            "IgnoreEOS": "ignore_eos",
+            "Tokens": "max_tokens",
+            "MinTokens": "min_tokens",
+            "Logprobs": "logprobs",
+            "PromptLogprobs": "prompt_logprobs",
+            "SkipSpecialTokens": "skip_special_tokens",
+            "SpacesBetweenSpecialTokens": "spaces_between_special_tokens",
+            "TruncatePromptTokens": "truncate_prompt_tokens",
+            "GuidedDecoding": "guided_decoding",
+        }
+
         sampling_params = SamplingParams(top_p=0.9, max_tokens=200)
-        if request.TopP != 0:
-            sampling_params.top_p = request.TopP
-        if request.Tokens > 0:
-            sampling_params.max_tokens = request.Tokens
-        if request.Temperature != 0:
-            sampling_params.temperature = request.Temperature
-        if request.TopK != 0:
-            sampling_params.top_k = request.TopK
-        if request.PresencePenalty != 0:
-            sampling_params.presence_penalty = request.PresencePenalty
-        if request.FrequencyPenalty != 0:
-            sampling_params.frequency_penalty = request.FrequencyPenalty
-        if request.StopPrompts:
-            sampling_params.stop = request.StopPrompts
-        if request.IgnoreEOS:
-            sampling_params.ignore_eos = request.IgnoreEOS
-        if request.Seed != 0:
-            sampling_params.seed = request.Seed
+        for request_field, param_field in request_to_sampling_params.items():
+            if hasattr(request, request_field):
+                value = getattr(request, request_field)
+                if value not in (None, 0, [], False):
+                    setattr(sampling_params, param_field, value)

         # Extract image paths and process images
         prompt = request.Prompt
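The hunk above replaces the per-field if chain with a generic mapping loop. Below is a minimal, self-contained sketch of that loop so it can be eyeballed without a running vLLM instance. FakeSamplingParams and the SimpleNamespace request are hypothetical stand-ins invented for illustration; the real backend uses vllm.SamplingParams and the generated backend_pb2.PredictOptions message, whose unset scalar fields default to 0, which is why zero-like values are treated as "not set".

# Standalone sketch (illustration only): mimic the mapping loop from _predict
# without vLLM or the gRPC stubs. FakeSamplingParams and the SimpleNamespace
# request below are stand-ins, not part of the LocalAI code.
from types import SimpleNamespace

class FakeSamplingParams:
    """Minimal stand-in for vllm.SamplingParams with the same two defaults."""
    def __init__(self, top_p=0.9, max_tokens=200):
        self.top_p = top_p
        self.max_tokens = max_tokens

# Subset of the mapping table from the diff above.
request_to_sampling_params = {
    "TopP": "top_p",
    "Temperature": "temperature",
    "Tokens": "max_tokens",
    "StopPrompts": "stop",
}

# Stand-in for the protobuf request; proto3 scalars default to 0 when unset,
# which is why the loop skips zero-like values.
request = SimpleNamespace(TopP=0.8, Temperature=0.0, Tokens=50, StopPrompts=["\n"])

sampling_params = FakeSamplingParams()
for request_field, param_field in request_to_sampling_params.items():
    if hasattr(request, request_field):
        value = getattr(request, request_field)
        if value not in (None, 0, [], False):
            setattr(sampling_params, param_field, value)

print(vars(sampling_params))
# -> {'top_p': 0.8, 'max_tokens': 50, 'stop': ['\n']}  (Temperature=0.0 is skipped)

One caveat the loop inherits from the old if chain: a caller cannot explicitly request a value of 0 (for example Temperature=0), since zero-valued fields are indistinguishable from unset ones in the protobuf message.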
@@ -75,6 +75,53 @@ class TestBackendServicer(unittest.TestCase):
         finally:
             self.tearDown()

+    def test_sampling_params(self):
+        """
+        This method tests if all sampling parameters are correctly processed
+        NOTE: this does NOT test for correctness, just that we received a compatible response
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
+                self.assertTrue(response.success)
+
+                req = backend_pb2.PredictOptions(
+                    Prompt="The capital of France is",
+                    TopP=0.8,
+                    Tokens=50,
+                    Temperature=0.7,
+                    TopK=40,
+                    PresencePenalty=0.1,
+                    FrequencyPenalty=0.2,
+                    RepetitionPenalty=1.1,
+                    MinP=0.05,
+                    Seed=42,
+                    StopPrompts=["\n"],
+                    StopTokenIds=[50256],
+                    BadWords=["badword"],
+                    IncludeStopStrInOutput=True,
+                    IgnoreEOS=True,
+                    MinTokens=5,
+                    Logprobs=5,
+                    PromptLogprobs=5,
+                    SkipSpecialTokens=True,
+                    SpacesBetweenSpecialTokens=True,
+                    TruncatePromptTokens=10,
+                    GuidedDecoding=True,
+                    N=2,
+                )
+                resp = stub.Predict(req)
+                self.assertIsNotNone(resp.message)
+                self.assertIsNotNone(resp.logprobs)
+        except Exception as err:
+            print(err)
+            self.fail("sampling params service failed")
+        finally:
+            self.tearDown()
+
     def test_embedding(self):
         """
         This method tests if the embeddings are generated successfully
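A hedged aside on the "must stay in sync" NOTE in _predict: one cheap way to catch drift between the fields exercised by test_sampling_params and the mapping table is a small set comparison. The snippet below is a sketch, not part of the commit; both field lists are copied from the diff above, and Prompt is excluded because it is the prompt rather than a sampling parameter.

# Sketch (not part of the commit): compare the request fields set in
# test_sampling_params with the keys of the mapping table used by _predict.
mapped_request_fields = {
    "N", "PresencePenalty", "FrequencyPenalty", "RepetitionPenalty",
    "Temperature", "TopP", "TopK", "MinP", "Seed", "StopPrompts",
    "StopTokenIds", "BadWords", "IncludeStopStrInOutput", "IgnoreEOS",
    "Tokens", "MinTokens", "Logprobs", "PromptLogprobs", "SkipSpecialTokens",
    "SpacesBetweenSpecialTokens", "TruncatePromptTokens", "GuidedDecoding",
}

fields_set_in_test = {
    "Prompt", "TopP", "Tokens", "Temperature", "TopK", "PresencePenalty",
    "FrequencyPenalty", "RepetitionPenalty", "MinP", "Seed", "StopPrompts",
    "StopTokenIds", "BadWords", "IncludeStopStrInOutput", "IgnoreEOS",
    "MinTokens", "Logprobs", "PromptLogprobs", "SkipSpecialTokens",
    "SpacesBetweenSpecialTokens", "TruncatePromptTokens", "GuidedDecoding", "N",
}

# Prompt is handled separately, so it is the only test field expected to have
# no entry in the mapping table.
untested = mapped_request_fields - fields_set_in_test
unmapped = fields_set_in_test - mapped_request_fields - {"Prompt"}
print("mapped but not exercised by the test:", untested or "none")
print("set in the test but not mapped:", unmapped or "none")

With the field lists as they stand in this commit, both sets come back empty.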