edge-connect-mcp/main.go

328 lines
9.4 KiB
Go

package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
v2 "edp.buildth.ing/DevFW-CICD/edge-connect-client/v2/sdk/edgeconnect/v2"
"github.com/modelcontextprotocol/go-sdk/mcp"
"edp.buildth.ing/DevFW-CICD/edge-connect-mcp/oauth"
)
// edgeClient is the shared edge-connect API client. main assigns it after
// configuration has been validated and before any MCP tool is registered.
var edgeClient *v2.Client

// config is the effective server configuration: the result of LoadConfig with
// any command-line overrides applied by main.
var config *Config

// Command-line flags. For -host and -port the zero value ("" / 0) means
// "keep whatever the loaded configuration says"; -mode defaults to "stdio".
var (
	mode = flag.String("mode", "stdio", "Server mode: 'stdio' for local, 'remote' for HTTP/SSE")
	host = flag.String("host", "", "Host to bind to for remote mode (default from config or 0.0.0.0)")
	port = flag.Int("port", 0, "Port to bind to for remote mode (default from config or 8080)")
)
func main() {
flag.Parse()
// Load configuration
var err error
config, err = LoadConfig()
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
// Override config with command-line flags
if *mode != "" {
config.ServerMode = *mode
}
if *host != "" {
config.RemoteHost = *host
}
if *port != 0 {
config.RemotePort = *port
}
// config.Debug = true
// Validate configuration
if err := config.Validate(); err != nil {
log.Fatalf("Invalid configuration: %v", err)
}
// Initialize edge-connect client
edgeClient, err = initializeEdgeClient(config)
if err != nil {
log.Fatalf("Failed to initialize edge-connect client: %v", err)
}
// Create MCP server
mcpServer := mcp.NewServer(&mcp.Implementation{
Name: "edge-connect-mcp",
Version: "1.0.0",
}, nil)
// Register all tools
registerTools(mcpServer)
// Start server based on mode
switch config.ServerMode {
case "stdio":
log.Println("Starting server in stdio mode...")
if err := mcpServer.Run(context.Background(), &mcp.StdioTransport{}); err != nil {
log.Fatalf("Server error: %v", err)
}
case "remote":
log.Printf("Starting server in remote mode on %s:%d...", config.RemoteHost, config.RemotePort)
if err := startRemoteServer(mcpServer, config); err != nil {
log.Fatalf("Remote server error: %v", err)
}
default:
log.Fatalf("Invalid server mode: %s (must be 'stdio' or 'remote')", config.ServerMode)
}
}
// startRemoteServer serves mcpServer over the MCP streamable HTTP transport on
// cfg.RemoteHost:cfg.RemotePort. It blocks until SIGINT/SIGTERM is received
// (then shuts the HTTP server down with a 10s deadline and returns nil) or
// until the listener fails (then returns that error).
//
// Routes:
//   - /health: unauthenticated JSON health check.
//   - /mcp: the MCP endpoint. It is protected by the OAuth middleware when
//     cfg.OAuthEnabled, otherwise by authenticateRequest when
//     cfg.RemoteAuthRequired, otherwise left open.
//   - /.well-known/oauth-protected-resource: RFC 9728 metadata (OAuth only).
//
// When cfg.OAuthAuthServerEnabled is set, the built-in authorization server
// is started in a background goroutine as well.
func startRemoteServer(mcpServer *mcp.Server, cfg *Config) error {
	mux := http.NewServeMux()

	// Every session is served by the same MCP server. Authentication is enforced
	// by HTTP middleware below, not in this callback. The go-sdk consults the
	// callback only when creating a new session, and answers a nil return with
	// 400 Bad Request instead of 401 Unauthorized.
	streamableHTTPHandler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server {
		return mcpServer
	}, &mcp.StreamableHTTPOptions{})

	// Health check endpoint (no auth required)
	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"status":"healthy"}`))
	})

	switch {
	case cfg.OAuthEnabled:
		// Create JWT validator
		validator := oauth.NewJWTValidator(
			cfg.OAuthResourceURI,
			cfg.OAuthIssuer,
			cfg.OAuthJWKSURL,
		)
		// Create resource server
		resourceServer := oauth.NewResourceServer(
			cfg.OAuthResourceURI,
			cfg.OAuthAuthServers,
			validator,
		)
		// Register Protected Resource Metadata endpoint (RFC 9728)
		mux.HandleFunc("/.well-known/oauth-protected-resource", resourceServer.ServeMetadata)
		// Wrap MCP handler with OAuth middleware
		authMiddleware := oauth.AuthMiddleware(resourceServer, validator)
		mux.Handle("/mcp", authMiddleware(streamableHTTPHandler))
		log.Printf("OAuth 2.1 enabled")
		log.Printf("Protected Resource URI: %s", cfg.OAuthResourceURI)
		log.Printf("Authorization Servers: %v", cfg.OAuthAuthServers)
	case cfg.RemoteAuthRequired:
		// Simple bearer token auth, checked on every request (including those
		// that reuse an existing MCP session).
		mux.Handle("/mcp", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !authenticateRequest(r, cfg) {
				w.Header().Set("WWW-Authenticate", "Bearer")
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			streamableHTTPHandler.ServeHTTP(w, r)
		}))
	default:
		mux.Handle("/mcp", streamableHTTPHandler)
	}

	addr := fmt.Sprintf("%s:%d", cfg.RemoteHost, cfg.RemotePort)
	httpServer := &http.Server{
		Addr:        addr,
		Handler:     mux,
		ReadTimeout: 30 * time.Second,
		// WriteTimeout disabled for long-lived streaming responses.
		WriteTimeout: 0,
		IdleTimeout:  120 * time.Second,
	}

	// Handle graceful shutdown; release the signal registration on return.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(sigChan)

	// Buffered so the server goroutine never blocks if we already returned.
	errChan := make(chan error, 1)

	// Start basic auth server if enabled
	if cfg.OAuthAuthServerEnabled {
		go startBasicAuthServer(cfg)
	}

	go func() {
		log.Printf("HTTP server listening on %s", addr)
		log.Printf("MCP endpoint: http://%s/mcp", addr)
		log.Printf("Health check: http://%s/health", addr)
		switch {
		case cfg.OAuthEnabled:
			log.Printf("Authentication: OAuth 2.1 (JWT Bearer tokens required)")
			log.Printf("Protected Resource Metadata: http://%s/.well-known/oauth-protected-resource", addr)
		case cfg.RemoteAuthRequired:
			log.Printf("Authentication: Simple Bearer token")
		default:
			log.Printf("Authentication: DISABLED (anyone can connect)")
		}
		if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			errChan <- fmt.Errorf("HTTP server error: %w", err)
		}
	}()

	// Wait for shutdown signal or error
	select {
	case <-sigChan:
		log.Println("Shutdown signal received, stopping server...")
		shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer shutdownCancel()
		if err := httpServer.Shutdown(shutdownCtx); err != nil {
			log.Printf("HTTP server shutdown error: %v", err)
		}
		log.Println("Server stopped")
		return nil
	case err := <-errChan:
		return err
	}
}
// initializeEdgeClient builds the edge-connect API client described by cfg.
//
// Optional client options are collected first, in this order: a logger using
// log.Default() when cfg.Debug is set, then retry settings when
// cfg.RetryMaxRetries > 0. The authentication mode in cfg.AuthType then decides
// how the client is built:
//   - "token": requires cfg.Token; uses a static token provider.
//   - "credentials": requires cfg.Username and cfg.Password; uses
//     v2.NewClientWithCredentials.
//   - "none": uses a no-auth provider.
//
// Any other auth type, or a missing required field, yields an error and a
// nil client.
func initializeEdgeClient(cfg *Config) (*v2.Client, error) {
	var opts []v2.Option

	if cfg.Debug {
		opts = append(opts, v2.WithLogger(log.Default()))
	}

	if cfg.RetryMaxRetries > 0 {
		opts = append(opts, v2.WithRetryOptions(v2.RetryOptions{
			MaxRetries:   cfg.RetryMaxRetries,
			InitialDelay: cfg.RetryInitialDelay,
			MaxDelay:     cfg.RetryMaxDelay,
			Multiplier:   cfg.RetryMultiplier,
		}))
	}

	switch cfg.AuthType {
	case "credentials":
		// Credential auth uses its own constructor rather than an auth option.
		if cfg.Username == "" || cfg.Password == "" {
			return nil, fmt.Errorf("username and password are required when auth_type is 'credentials'")
		}
		return v2.NewClientWithCredentials(cfg.BaseURL, cfg.Username, cfg.Password, opts...), nil
	case "token":
		if cfg.Token == "" {
			return nil, fmt.Errorf("token is required when auth_type is 'token'")
		}
		opts = append(opts, v2.WithAuthProvider(v2.NewStaticTokenProvider(cfg.Token)))
	case "none":
		opts = append(opts, v2.WithAuthProvider(v2.NewNoAuthProvider()))
	default:
		return nil, fmt.Errorf("invalid auth_type: %s (must be 'token', 'credentials', or 'none')", cfg.AuthType)
	}

	// Shared construction path for the provider-based auth modes.
	return v2.NewClient(cfg.BaseURL, opts...), nil
}
// registerTools registers every edge-connect MCP tool on s, in a fixed order,
// and logs how many were registered.
func registerTools(s *mcp.Server) {
	registrars := []func(*mcp.Server){
		// Apps endpoints
		registerCreateAppTool,
		registerShowAppTool,
		registerListAppsTool,
		registerUpdateAppTool,
		registerDeleteAppTool,
		// App Instance endpoints
		registerCreateAppInstanceTool,
		registerShowAppInstanceTool,
		registerListAppInstancesTool,
		registerUpdateAppInstanceTool,
		registerRefreshAppInstanceTool,
		registerDeleteAppInstanceTool,
	}
	for _, register := range registrars {
		register(s)
	}
	// Derive the count from the table so the log cannot drift out of sync
	// when tools are added or removed.
	log.Printf("Registered %d tools", len(registrars))
}
// startBasicAuthServer runs the built-in OAuth 2.1 authorization server on
// cfg.OAuthAuthServerPort (all interfaces). It blocks until the listener fails,
// so callers run it in its own goroutine.
//
// The server is created with issuer cfg.OAuthIssuer and base URL
// http://localhost:<port>. It pre-registers cfg.OAuthClientID with
// cfg.OAuthRedirectURI and exposes /authorize, /token, /register,
// /.well-known/jwks.json and /.well-known/oauth-authorization-server.
// Failures are logged, not returned.
func startBasicAuthServer(cfg *Config) {
	authServer, err := oauth.NewBasicAuthServer(
		cfg.OAuthIssuer,
		fmt.Sprintf("http://localhost:%d", cfg.OAuthAuthServerPort),
	)
	if err != nil {
		log.Printf("Failed to create basic auth server: %v", err)
		return
	}

	// Register client from configuration
	authServer.RegisterClient(cfg.OAuthClientID, []string{cfg.OAuthRedirectURI})

	mux := http.NewServeMux()
	// Authorization endpoint (GET /authorize)
	mux.HandleFunc("/authorize", authServer.HandleAuthorize)
	// Token endpoint (POST /token)
	mux.HandleFunc("/token", authServer.HandleToken)
	// Registration endpoint (POST /register) - RFC 7591
	mux.HandleFunc("/register", authServer.HandleRegistration)
	// JWKS endpoint (GET /.well-known/jwks.json)
	mux.HandleFunc("/.well-known/jwks.json", authServer.HandleJWKS)
	// Authorization server metadata endpoint (GET /.well-known/oauth-authorization-server)
	mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			w.Header().Set("Allow", http.MethodGet)
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("Cache-Control", "public, max-age=3600")
		if err := json.NewEncoder(w).Encode(authServer.GetMetadata()); err != nil {
			log.Printf("Failed to encode metadata response: %v", err)
		}
	})

	addr := fmt.Sprintf(":%d", cfg.OAuthAuthServerPort)
	log.Printf("Basic OAuth 2.1 Authorization Server starting on http://localhost%s", addr)
	log.Printf(" Authorization endpoint: http://localhost%s/authorize", addr)
	log.Printf(" Token endpoint: http://localhost%s/token", addr)
	log.Printf(" Registration endpoint: http://localhost%s/register", addr)
	log.Printf(" JWKS endpoint: http://localhost%s/.well-known/jwks.json", addr)
	log.Printf(" Metadata endpoint: http://localhost%s/.well-known/oauth-authorization-server", addr)
	log.Printf(" Registered client: %s", cfg.OAuthClientID)
	log.Printf(" Dynamic client registration: enabled (RFC 7591)")

	// http.ListenAndServe uses a server with no timeouts, letting slow or idle
	// clients hold connections open indefinitely. All endpoints here are short
	// request/response exchanges, so bounded timeouts are safe.
	srv := &http.Server{
		Addr:              addr,
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       120 * time.Second,
	}
	if err := srv.ListenAndServe(); err != nil {
		log.Printf("Basic auth server error: %v", err)
	}
}