diff --git a/core/http/routes/peerguard.go b/core/http/routes/peerguard.go deleted file mode 100644 index dd576656..00000000 --- a/core/http/routes/peerguard.go +++ /dev/null @@ -1,80 +0,0 @@ -package routes - -import ( - "context" - "time" - - "github.com/gofiber/fiber/v2" - "github.com/mudler/edgevpn/pkg/node" -) - -const DefaultInterval = 5 * time.Second -const Timeout = 20 * time.Second - -// TODO connect routes and write a middleware for authorization based on p2p auth providers private keys -func RegisterPeerguardAuthRoutes(app *fiber.App, e *node.Node) { - app.Get("ledger/:bucket/:key", func(c *fiber.Ctx) error { - bucket := c.Params("bucket") - key := c.Params("key") - - ledger, err := e.Ledger() - if err != nil { - return err - } - - return c.JSON(ledger.CurrentData()[bucket][key]) - }) - - app.Get("ledger/:bucket", func(c *fiber.Ctx) error { - bucket := c.Params("bucket") - - ledger, err := e.Ledger() - if err != nil { - return err - } - - return c.JSON(ledger.CurrentData()[bucket]) - }) - - announcing := struct{ State string }{"Announcing"} - - // Store arbitrary data - app.Get("ledger/:bucket/:key/:value", func(c *fiber.Ctx) error { - bucket := c.Params("bucket") - key := c.Params("key") - value := c.Params("value") - - ledger, err := e.Ledger() - if err != nil { - return err - } - - ledger.Persist(context.Background(), DefaultInterval, Timeout, bucket, key, value) - return c.JSON(announcing) - }) - // Delete data from ledger - app.Get("ledger/:bucket", func(c *fiber.Ctx) error { - bucket := c.Params("bucket") - - ledger, err := e.Ledger() - if err != nil { - return err - } - - ledger.AnnounceDeleteBucket(context.Background(), DefaultInterval, Timeout, bucket) - return c.JSON(announcing) - }) - - app.Get("ledger/:bucket/:key", func(c *fiber.Ctx) error { - bucket := c.Params("bucket") - key := c.Params("key") - - ledger, err := e.Ledger() - if err != nil { - return err - } - - ledger.AnnounceDeleteBucketKey(context.Background(), DefaultInterval, Timeout, bucket, key) - return c.JSON(announcing) - }) -} diff --git a/core/p2p/federated_server.go b/core/p2p/federated_server.go index 081bd134..81fccaa3 100644 --- a/core/p2p/federated_server.go +++ b/core/p2p/federated_server.go @@ -4,17 +4,36 @@ package p2p import ( + "bufio" "context" + "encoding/json" "errors" "fmt" "io" "net" + "net/http" + "slices" + "strings" + "time" + logP2P "github.com/ipfs/go-log/v2" cliP2P "github.com/mudler/LocalAI/core/cli/p2p" + edgevpnConfig "github.com/mudler/edgevpn/pkg/config" + "github.com/mudler/edgevpn/pkg/logger" "github.com/mudler/edgevpn/pkg/node" + "github.com/mudler/edgevpn/pkg/trustzone" "github.com/rs/zerolog/log" ) +const Timeout = 20 * time.Second + +const ( + peekBufferSize = 512 + authHeader = "X-Auth-Token" + headerEnd = "\r\n\r\n" + lineEnd = "\r\n" +) + func (fs *FederatedServer) Start(ctx context.Context, p2pCommonFlags cliP2P.P2PCommonFlags) error { p2pCfg := NewP2PConfig(p2pCommonFlags) p2pCfg.NetworkToken = fs.p2ptoken @@ -36,11 +55,26 @@ func (fs *FederatedServer) Start(ctx context.Context, p2pCommonFlags cliP2P.P2PC return err } - return fs.proxy(ctx, n) + lvl, err := logP2P.LevelFromString(p2pCfg.LogLevel) + if err != nil { + lvl = logP2P.LevelError + } + llger := logger.New(lvl) + + aps := []trustzone.AuthProvider{} + for ap, providerOpts := range p2pCfg.PeerGuard.AuthProviders { + a, err := edgevpnConfig.AuthProvider(llger, ap, providerOpts) + if err != nil { + log.Warn().Msgf("invalid authprovider: %v", err) + continue + } + aps = append(aps, a) + } + + return fs.listener(ctx, n, aps) } -func (fs *FederatedServer) proxy(ctx context.Context, node *node.Node) error { - +func (fs *FederatedServer) listener(ctx context.Context, node *node.Node, aps []trustzone.AuthProvider) error { log.Info().Msgf("Allocating service '%s' on: %s", fs.service, fs.listenAddr) // Open local port for listening l, err := net.Listen("tcp", fs.listenAddr) @@ -57,6 +91,7 @@ func (fs *FederatedServer) proxy(ctx context.Context, node *node.Node) error { nodeAnnounce(ctx, node) defer l.Close() + for { select { case <-ctx.Done(): @@ -70,62 +105,243 @@ func (fs *FederatedServer) proxy(ctx context.Context, node *node.Node) error { continue } - // Handle connections in a new goroutine, forwarding to the p2p service go func() { - workerID := "" - if fs.workerTarget != "" { - workerID = fs.workerTarget - } else if fs.loadBalanced { - log.Debug().Msgf("Load balancing request") - - workerID = fs.SelectLeastUsedServer() - if workerID == "" { - log.Debug().Msgf("Least used server not found, selecting random") - workerID = fs.RandomServer() + if len(aps) > 0 { + if fs.handleHTTP(conn, node, aps) { + return } - } else { - workerID = fs.RandomServer() - } - - if workerID == "" { - log.Error().Msg("No available nodes yet") - fs.sendHTMLResponse(conn, 503, "Sorry, waiting for nodes to connect") - return - } - - log.Debug().Msgf("Selected node %s", workerID) - nodeData, exists := GetNode(fs.service, workerID) - if !exists { - log.Error().Msgf("Node %s not found", workerID) - fs.sendHTMLResponse(conn, 404, "Node not found") - return - } - - proxyP2PConnection(ctx, node, nodeData.ServiceID, conn) - if fs.loadBalanced { - fs.RecordRequest(workerID) } + fs.proxy(ctx, node, conn) }() } } } -// sendHTMLResponse sends a basic HTML response with a status code and a message. -// This is extracted to make the HTML content maintainable. -func (fs *FederatedServer) sendHTMLResponse(conn net.Conn, statusCode int, message string) { +func (fs *FederatedServer) handleHTTP(conn net.Conn, node *node.Node, aps []trustzone.AuthProvider) bool { + defer func() { + if r := recover(); r != nil { + log.Debug().Msgf("Recovered from panic: %v", r) + conn.Close() + } + }() + + r, err := testForHTTPRequest(conn) + if err != nil { + return false + } + defer r.Body.Close() + pathParts := strings.Split(strings.TrimPrefix(r.URL.Path, "/ledger/"), "/") + announcing := struct{ State string }{"Announcing"} + + // TODO deal with AuthProviders + // pubKey := r.Header.Get(authHeader) + + switch r.Method { + case http.MethodGet: + switch len(pathParts) { + case 2: // /ledger/:bucket/:key + bucket := pathParts[0] + key := pathParts[1] + + ledger, err := node.Ledger() + if err != nil { + fs.sendRawResponse(conn, http.StatusInternalServerError, "text/plain", []byte(err.Error())) + return true + } + + fs.sendJSONResponse(conn, http.StatusOK, ledger.CurrentData()[bucket][key]) + + case 1: // /ledger/:bucket + bucket := pathParts[0] + + ledger, err := node.Ledger() + if err != nil { + fs.sendRawResponse(conn, http.StatusInternalServerError, "text/plain", []byte(err.Error())) + return true + } + + fs.sendJSONResponse(conn, http.StatusOK, ledger.CurrentData()[bucket]) + + default: + fs.sendRawResponse(conn, http.StatusNotFound, "text/plain", []byte("not found")) + + } + + case http.MethodPut: + if len(pathParts) == 3 { // /ledger/:bucket/:key/:value + bucket := pathParts[0] + key := pathParts[1] + value := pathParts[2] + + ledger, err := node.Ledger() + if err != nil { + fs.sendRawResponse(conn, http.StatusInternalServerError, "text/plain", []byte(err.Error())) + return true + } + + ledger.Persist(context.Background(), DefaultInterval, Timeout, bucket, key, value) + fs.sendJSONResponse(conn, http.StatusOK, announcing) + + } else { + fs.sendRawResponse(conn, http.StatusNotFound, "text/plain", []byte("not found")) + } + + case http.MethodDelete: + switch len(pathParts) { + case 1: // /ledger/:bucket + bucket := pathParts[0] + + ledger, err := node.Ledger() + if err != nil { + fs.sendRawResponse(conn, http.StatusInternalServerError, "text/plain", []byte(err.Error())) + return true + } + + ledger.AnnounceDeleteBucket(context.Background(), DefaultInterval, Timeout, bucket) + fs.sendJSONResponse(conn, http.StatusOK, announcing) + + case 2: // /ledger/:bucket/:key + bucket := pathParts[0] + key := pathParts[1] + + ledger, err := node.Ledger() + if err != nil { + fs.sendRawResponse(conn, http.StatusInternalServerError, "text/plain", []byte(err.Error())) + return true + } + + ledger.AnnounceDeleteBucketKey(context.Background(), DefaultInterval, Timeout, bucket, key) + fs.sendJSONResponse(conn, http.StatusOK, announcing) + + default: + fs.sendRawResponse(conn, http.StatusNotFound, "text/plain", []byte("not found")) + + } + } + + return true +} + +// testForHTTPRequest peeking the first N bytes from the accepted conn, and trying to match it +// against the supported http methods, then against the supported route, then if there is auth header +func testForHTTPRequest(conn net.Conn) (*http.Request, error) { + reader := bufio.NewReader(conn) + + peekedData, err := reader.Peek(peekBufferSize) + if err != nil && err != bufio.ErrBufferFull { + log.Debug().Msgf("Error peeking data: %v", err) + return nil, err + } + peekedString := string(peekedData) + + // 1. Parse Request Line + firstLineEnd := strings.Index(peekedString, lineEnd) + if firstLineEnd == -1 { + log.Debug().Msg("Could not find request line end") + return nil, err + } + requestLine := peekedString[:firstLineEnd] + parts := strings.Split(requestLine, " ") + if len(parts) != 3 { + log.Debug().Msg("Invalid request line format") + return nil, err + } + method := parts[0] + uri := parts[1] + + if !slices.Contains([]string{ + http.MethodGet, + http.MethodPut, + http.MethodDelete, + }, method) { + log.Debug().Msg("Unsupported HTTP method") + return nil, err + } + if !strings.HasPrefix(uri, "/ledger") { + log.Debug().Msg("Unsupported HTTP route") + return nil, err + } + + headersPart := peekedString[firstLineEnd+len(lineEnd):] + headerEndIndex := strings.Index(headersPart, headerEnd) + if headerEndIndex == -1 { + log.Debug().Msg("Could not find end of headers within peek buffer") + return nil, err + } + headersString := headersPart[:headerEndIndex] + headers := strings.Split(headersString, lineEnd) + + foundAuth := false + for _, header := range headers { + if strings.HasPrefix(header, authHeader+":") { + parts := strings.SplitN(header, ":", 2) + if len(parts) == 2 { + foundAuth = true + break + } + } + } + + if !foundAuth { + log.Debug().Msgf("Required header '%s' not found.", authHeader) + return nil, err + } + + req, err := http.ReadRequest(reader) + if err != nil { + log.Debug().Msgf("Error reading full request: %v", err) + return nil, err + } + return req, nil +} + +func (fs *FederatedServer) proxy(ctx context.Context, node *node.Node, conn net.Conn) { + workerID := "" + if fs.workerTarget != "" { + workerID = fs.workerTarget + } else if fs.loadBalanced { + log.Debug().Msgf("Load balancing request") + + workerID = fs.SelectLeastUsedServer() + if workerID == "" { + log.Debug().Msgf("Least used server not found, selecting random") + workerID = fs.RandomServer() + } + } else { + workerID = fs.RandomServer() + } + + if workerID == "" { + log.Error().Msg("No available nodes yet") + fs.sendHTMLResponse(conn, 503, "Sorry, waiting for nodes to connect") + return + } + + log.Debug().Msgf("Selected node %s", workerID) + nodeData, exists := GetNode(fs.service, workerID) + if !exists { + log.Error().Msgf("Node %s not found", workerID) + fs.sendHTMLResponse(conn, 404, "Node not found") + return + } + + proxyP2PConnection(ctx, node, nodeData.ServiceID, conn) + if fs.loadBalanced { + fs.RecordRequest(workerID) + } +} + +// sendRawResponse sends whatever provided byte data with provided content type header +func (fs *FederatedServer) sendRawResponse(conn net.Conn, statusCode int, contentType string, data []byte) { defer conn.Close() - // Define the HTML content separately for easier maintenance. - htmlContent := fmt.Sprintf("