diff --git a/backend/python/transformers/transformers_server.py b/backend/python/transformers/transformers_server.py index 93b2ce25..f40b8951 100755 --- a/backend/python/transformers/transformers_server.py +++ b/backend/python/transformers/transformers_server.py @@ -89,8 +89,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): quantization = None if self.CUDA: - if request.Device: - device_map=request.Device + if request.MainGPU: + device_map=request.MainGPU else: device_map="cuda:0" if request.Quantization == "bnb_4bit": @@ -143,28 +143,36 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): from optimum.intel.openvino import OVModelForCausalLM from openvino.runtime import Core - if "GPU" in Core().available_devices: - device_map="GPU" + if request.MainGPU: + device_map=request.MainGPU else: - device_map="CPU" + device_map="AUTO" + devices = Core().available_devices + if "GPU" in " ".join(devices): + device_map="AUTO:GPU" + self.model = OVModelForCausalLM.from_pretrained(model_name, compile=True, trust_remote_code=request.TrustRemoteCode, - ov_config={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}, + ov_config={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}, device=device_map) self.OV = True elif request.Type == "OVModelForFeatureExtraction": from optimum.intel.openvino import OVModelForFeatureExtraction from openvino.runtime import Core - if "GPU" in Core().available_devices: - device_map="GPU" + if request.MainGPU: + device_map=request.MainGPU else: - device_map="CPU" + device_map="AUTO" + devices = Core().available_devices + if "GPU" in " ".join(devices): + device_map="AUTO:GPU" + self.model = OVModelForFeatureExtraction.from_pretrained(model_name, compile=True, trust_remote_code=request.TrustRemoteCode, - ov_config={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"}, + ov_config={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT", "GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"}, export=True, device=device_map) self.OV = True @@ -371,4 +379,4 @@ if __name__ == "__main__": ) args = parser.parse_args() - asyncio.run(serve(args.addr)) \ No newline at end of file + asyncio.run(serve(args.addr))