328 lines
9.4 KiB
Go
328 lines
9.4 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"syscall"
|
|
"time"
|
|
|
|
v2 "edp.buildth.ing/DevFW-CICD/edge-connect-client/v2/sdk/edgeconnect/v2"
|
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
|
|
|
"edp.buildth.ing/DevFW-CICD/edge-connect-mcp/oauth"
|
|
)
|
|
|
|
var (
|
|
edgeClient *v2.Client
|
|
config *Config
|
|
)
|
|
|
|
var (
|
|
mode = flag.String("mode", "stdio", "Server mode: 'stdio' for local, 'remote' for HTTP/SSE")
|
|
host = flag.String("host", "", "Host to bind to for remote mode (default from config or 0.0.0.0)")
|
|
port = flag.Int("port", 0, "Port to bind to for remote mode (default from config or 8080)")
|
|
)
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
|
|
// Load configuration
|
|
var err error
|
|
config, err = LoadConfig()
|
|
if err != nil {
|
|
log.Fatalf("Failed to load configuration: %v", err)
|
|
}
|
|
|
|
// Override config with command-line flags
|
|
if *mode != "" {
|
|
config.ServerMode = *mode
|
|
}
|
|
if *host != "" {
|
|
config.RemoteHost = *host
|
|
}
|
|
if *port != 0 {
|
|
config.RemotePort = *port
|
|
}
|
|
|
|
// config.Debug = true
|
|
|
|
// Validate configuration
|
|
if err := config.Validate(); err != nil {
|
|
log.Fatalf("Invalid configuration: %v", err)
|
|
}
|
|
|
|
// Initialize edge-connect client
|
|
edgeClient, err = initializeEdgeClient(config)
|
|
if err != nil {
|
|
log.Fatalf("Failed to initialize edge-connect client: %v", err)
|
|
}
|
|
|
|
// Create MCP server
|
|
mcpServer := mcp.NewServer(&mcp.Implementation{
|
|
Name: "edge-connect-mcp",
|
|
Version: "1.0.0",
|
|
}, nil)
|
|
|
|
// Register all tools
|
|
registerTools(mcpServer)
|
|
|
|
// Start server based on mode
|
|
switch config.ServerMode {
|
|
case "stdio":
|
|
log.Println("Starting server in stdio mode...")
|
|
if err := mcpServer.Run(context.Background(), &mcp.StdioTransport{}); err != nil {
|
|
log.Fatalf("Server error: %v", err)
|
|
}
|
|
|
|
case "remote":
|
|
log.Printf("Starting server in remote mode on %s:%d...", config.RemoteHost, config.RemotePort)
|
|
if err := startRemoteServer(mcpServer, config); err != nil {
|
|
log.Fatalf("Remote server error: %v", err)
|
|
}
|
|
|
|
default:
|
|
log.Fatalf("Invalid server mode: %s (must be 'stdio' or 'remote')", config.ServerMode)
|
|
}
|
|
}
|
|
|
|
func startRemoteServer(mcpServer *mcp.Server, cfg *Config) error {
|
|
// Create HTTP mux
|
|
mux := http.NewServeMux()
|
|
|
|
streamableHttpHandler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server {
|
|
// Simple bearer token auth - only if OAuth is disabled and auth is required
|
|
if !cfg.OAuthEnabled && cfg.RemoteAuthRequired {
|
|
if !authenticateRequest(r, cfg) {
|
|
return nil
|
|
}
|
|
}
|
|
return mcpServer
|
|
}, &mcp.StreamableHTTPOptions{})
|
|
|
|
// Health check endpoint (no auth required)
|
|
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte(`{"status":"healthy"}`))
|
|
})
|
|
|
|
// Configure OAuth if enabled
|
|
if cfg.OAuthEnabled {
|
|
// Create JWT validator
|
|
validator := oauth.NewJWTValidator(
|
|
cfg.OAuthResourceURI,
|
|
cfg.OAuthIssuer,
|
|
cfg.OAuthJWKSURL,
|
|
)
|
|
|
|
// Create resource server
|
|
resourceServer := oauth.NewResourceServer(
|
|
cfg.OAuthResourceURI,
|
|
cfg.OAuthAuthServers,
|
|
validator,
|
|
)
|
|
|
|
// Register Protected Resource Metadata endpoint (RFC 9728)
|
|
mux.HandleFunc("/.well-known/oauth-protected-resource", resourceServer.ServeMetadata)
|
|
|
|
// Create OAuth middleware
|
|
authMiddleware := oauth.AuthMiddleware(resourceServer, validator)
|
|
|
|
// Wrap MCP handler with OAuth middleware
|
|
mux.Handle("/mcp", authMiddleware(streamableHttpHandler))
|
|
|
|
log.Printf("OAuth 2.1 enabled")
|
|
log.Printf("Protected Resource URI: %s", cfg.OAuthResourceURI)
|
|
log.Printf("Authorization Servers: %v", cfg.OAuthAuthServers)
|
|
} else {
|
|
mux.Handle("/mcp", streamableHttpHandler)
|
|
}
|
|
|
|
// Create HTTP server
|
|
addr := fmt.Sprintf("%s:%d", cfg.RemoteHost, cfg.RemotePort)
|
|
httpServer := &http.Server{
|
|
Addr: addr,
|
|
Handler: mux,
|
|
ReadTimeout: 30 * time.Second,
|
|
// WriteTimeout disabled for SSE long-lived connections
|
|
WriteTimeout: 0,
|
|
IdleTimeout: 120 * time.Second,
|
|
}
|
|
|
|
// Handle graceful shutdown
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
|
|
|
errChan := make(chan error, 1)
|
|
|
|
// Start basic auth server if enabled
|
|
if cfg.OAuthAuthServerEnabled {
|
|
go startBasicAuthServer(cfg)
|
|
}
|
|
|
|
// Start HTTP server
|
|
go func() {
|
|
log.Printf("HTTP server listening on %s", addr)
|
|
log.Printf("SSE endpoint: http://%s/sse", addr)
|
|
log.Printf("Health check: http://%s/health", addr)
|
|
if cfg.OAuthEnabled {
|
|
log.Printf("Authentication: OAuth 2.1 (JWT Bearer tokens required)")
|
|
log.Printf("Protected Resource Metadata: http://%s/.well-known/oauth-protected-resource", addr)
|
|
} else if cfg.RemoteAuthRequired {
|
|
log.Printf("Authentication: Simple Bearer token")
|
|
} else {
|
|
log.Printf("Authentication: DISABLED (anyone can connect)")
|
|
}
|
|
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
errChan <- fmt.Errorf("HTTP server error: %w", err)
|
|
}
|
|
}()
|
|
|
|
// Wait for shutdown signal or error
|
|
select {
|
|
case <-sigChan:
|
|
log.Println("Shutdown signal received, stopping server...")
|
|
|
|
// Shutdown HTTP server
|
|
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer shutdownCancel()
|
|
|
|
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
|
log.Printf("HTTP server shutdown error: %v", err)
|
|
}
|
|
|
|
log.Println("Server stopped")
|
|
return nil
|
|
|
|
case err := <-errChan:
|
|
return err
|
|
}
|
|
}
|
|
|
|
func initializeEdgeClient(cfg *Config) (*v2.Client, error) {
|
|
var options []v2.Option
|
|
|
|
// Add logger if debug enabled
|
|
if cfg.Debug {
|
|
options = append(options, v2.WithLogger(log.Default()))
|
|
}
|
|
|
|
// Configure retry options
|
|
if cfg.RetryMaxRetries > 0 {
|
|
retryOpts := v2.RetryOptions{
|
|
MaxRetries: cfg.RetryMaxRetries,
|
|
InitialDelay: cfg.RetryInitialDelay,
|
|
MaxDelay: cfg.RetryMaxDelay,
|
|
Multiplier: cfg.RetryMultiplier,
|
|
}
|
|
options = append(options, v2.WithRetryOptions(retryOpts))
|
|
}
|
|
|
|
// Initialize client with authentication
|
|
switch cfg.AuthType {
|
|
case "token":
|
|
if cfg.Token == "" {
|
|
return nil, fmt.Errorf("token is required when auth_type is 'token'")
|
|
}
|
|
authProvider := v2.NewStaticTokenProvider(cfg.Token)
|
|
options = append(options, v2.WithAuthProvider(authProvider))
|
|
return v2.NewClient(cfg.BaseURL, options...), nil
|
|
|
|
case "credentials":
|
|
if cfg.Username == "" || cfg.Password == "" {
|
|
return nil, fmt.Errorf("username and password are required when auth_type is 'credentials'")
|
|
}
|
|
return v2.NewClientWithCredentials(cfg.BaseURL, cfg.Username, cfg.Password, options...), nil
|
|
|
|
case "none":
|
|
authProvider := v2.NewNoAuthProvider()
|
|
options = append(options, v2.WithAuthProvider(authProvider))
|
|
return v2.NewClient(cfg.BaseURL, options...), nil
|
|
|
|
default:
|
|
return nil, fmt.Errorf("invalid auth_type: %s (must be 'token', 'credentials', or 'none')", cfg.AuthType)
|
|
}
|
|
}
|
|
|
|
func registerTools(s *mcp.Server) {
|
|
// Apps endpoints
|
|
registerCreateAppTool(s)
|
|
registerShowAppTool(s)
|
|
registerListAppsTool(s)
|
|
registerUpdateAppTool(s)
|
|
registerDeleteAppTool(s)
|
|
|
|
// App Instance endpoints
|
|
registerCreateAppInstanceTool(s)
|
|
registerShowAppInstanceTool(s)
|
|
registerListAppInstancesTool(s)
|
|
registerUpdateAppInstanceTool(s)
|
|
registerRefreshAppInstanceTool(s)
|
|
registerDeleteAppInstanceTool(s)
|
|
|
|
log.Printf("Registered 11 tools")
|
|
}
|
|
|
|
func startBasicAuthServer(cfg *Config) {
|
|
// Create basic authorization server
|
|
authServer, err := oauth.NewBasicAuthServer(
|
|
cfg.OAuthIssuer,
|
|
fmt.Sprintf("http://localhost:%d", cfg.OAuthAuthServerPort),
|
|
)
|
|
if err != nil {
|
|
log.Printf("Failed to create basic auth server: %v", err)
|
|
return
|
|
}
|
|
|
|
// Register client from configuration
|
|
authServer.RegisterClient(cfg.OAuthClientID, []string{cfg.OAuthRedirectURI})
|
|
|
|
// Create HTTP mux for auth server
|
|
mux := http.NewServeMux()
|
|
|
|
// Authorization endpoint (GET /authorize)
|
|
mux.HandleFunc("/authorize", authServer.HandleAuthorize)
|
|
|
|
// Token endpoint (POST /token)
|
|
mux.HandleFunc("/token", authServer.HandleToken)
|
|
|
|
// Registration endpoint (POST /register) - RFC 7591
|
|
mux.HandleFunc("/register", authServer.HandleRegistration)
|
|
|
|
// JWKS endpoint (GET /.well-known/jwks.json)
|
|
mux.HandleFunc("/.well-known/jwks.json", authServer.HandleJWKS)
|
|
|
|
// Authorization server metadata endpoint (GET /.well-known/oauth-authorization-server)
|
|
mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("Cache-Control", "public, max-age=3600")
|
|
if err := json.NewEncoder(w).Encode(authServer.GetMetadata()); err != nil {
|
|
log.Printf("Failed to encode metadata response: %v", err)
|
|
}
|
|
})
|
|
|
|
// Start auth server
|
|
addr := fmt.Sprintf(":%d", cfg.OAuthAuthServerPort)
|
|
log.Printf("Basic OAuth 2.1 Authorization Server starting on http://localhost%s", addr)
|
|
log.Printf(" Authorization endpoint: http://localhost%s/authorize", addr)
|
|
log.Printf(" Token endpoint: http://localhost%s/token", addr)
|
|
log.Printf(" Registration endpoint: http://localhost%s/register", addr)
|
|
log.Printf(" JWKS endpoint: http://localhost%s/.well-known/jwks.json", addr)
|
|
log.Printf(" Metadata endpoint: http://localhost%s/.well-known/oauth-authorization-server", addr)
|
|
log.Printf(" Registered client: %s", cfg.OAuthClientID)
|
|
log.Printf(" Dynamic client registration: enabled (RFC 7591)")
|
|
|
|
if err := http.ListenAndServe(addr, mux); err != nil {
|
|
log.Printf("Basic auth server error: %v", err)
|
|
}
|
|
}
|