// FederatedServer holds the settings of a federated proxy together with the
// per-node request counters used for optional load balancing.
//
// The loadBalanced flag is supplied by the caller of NewFederatedServer. When
// it is set, the proxy path is expected to call EnsureRecordExist for every
// known tunnel address, pick a target with SelectLeastUsedServer (falling back
// to a random tunnel if it returns ""), and then call RecordRequest for the
// chosen address.
//
// NOTE(review): requestTable is a plain map with no synchronisation. If the
// proxy serves connections from multiple goroutines, callers must serialise
// access to these methods — confirm against the proxy implementation.
// NOTE(review): nothing here removes entries, so a node that disappeared can
// still be returned by SelectLeastUsedServer — verify the caller handles that.
type FederatedServer struct {
	listenAddr, service, p2ptoken string

	// requestTable maps a node identifier to the number of requests recorded
	// for it through RecordRequest.
	requestTable map[string]int
	// loadBalanced selects least-used routing instead of random routing.
	loadBalanced bool
}

// NewFederatedServer returns a FederatedServer with an empty request table.
//
// listenAddr, service and p2pToken are stored as given; loadBalanced enables
// least-used node selection.
func NewFederatedServer(listenAddr, service, p2pToken string, loadBalanced bool) *FederatedServer {
	return &FederatedServer{
		listenAddr:   listenAddr,
		service:      service,
		p2ptoken:     p2pToken,
		requestTable: map[string]int{},
		loadBalanced: loadBalanced,
	}
}

// SelectLeastUsedServer returns the node identifier with the lowest request
// count in the request table, or "" when the table is empty.
//
// When several nodes share the lowest count, the one encountered first wins;
// because Go map iteration order is unspecified, the pick among ties is
// effectively arbitrary.
func (fs *FederatedServer) SelectLeastUsedServer() string {
	var (
		leastKey   string
		leastCount int
		seen       bool
	)
	for nodeID, count := range fs.requestTable {
		// Track "seen" explicitly: a count of 0 is a legitimate minimum and
		// must not be mistaken for "no candidate yet".
		if !seen || count < leastCount {
			leastKey, leastCount, seen = nodeID, count, true
		}
	}
	return leastKey
}

// RecordRequest increments the request counter for nodeID, creating the entry
// if it does not exist yet.
func (fs *FederatedServer) RecordRequest(nodeID string) {
	if fs.requestTable == nil {
		// Keep a zero-value FederatedServer usable instead of panicking on a
		// nil-map write.
		fs.requestTable = map[string]int{}
	}
	fs.requestTable[nodeID]++
}

// EnsureRecordExist makes sure nodeID has an entry in the request table,
// adding it with a counter of 0 when missing. Existing counters are left
// untouched.
func (fs *FederatedServer) EnsureRecordExist(nodeID string) {
	if fs.requestTable == nil {
		fs.requestTable = map[string]int{}
	}
	if _, ok := fs.requestTable[nodeID]; !ok {
		fs.requestTable[nodeID] = 0
	}
}