Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 104 additions & 37 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,51 +8,114 @@ import (
"github.com/spf13/viper"
)

// Config holds all application configuration
// Config holds all application configuration loaded from environment variables.
type Config struct {
// Server configuration
MCPHost string `mapstructure:"MCP_HOST"`
MCPPort int `mapstructure:"MCP_PORT"`
MCPSSLKeyfile string `mapstructure:"MCP_SSL_KEYFILE"`
MCPSSLCertfile string `mapstructure:"MCP_SSL_CERTFILE"`

// MCPHost is the hostname or IP address the server binds to.
MCPHost string `mapstructure:"MCP_HOST"`

// MCPPort is the port number the server listens on (1024-65535).
MCPPort int `mapstructure:"MCP_PORT"`

// MCPSSLKeyfile is the path to the SSL private key file (optional).
MCPSSLKeyfile string `mapstructure:"MCP_SSL_KEYFILE"`

// MCPSSLCertfile is the path to the SSL certificate file (optional).
MCPSSLCertfile string `mapstructure:"MCP_SSL_CERTFILE"`

// MCPTransportProtocol specifies the transport protocol ("stdio", "streamable-http", "sse", "http").
MCPTransportProtocol string `mapstructure:"MCP_TRANSPORT_PROTOCOL"`
MCPHostEndpoint string `mapstructure:"MCP_HOST_ENDPOINT"`
Environment string `mapstructure:"ENVIRONMENT"`

// MCPHostEndpoint is the full URL endpoint of the server.
MCPHostEndpoint string `mapstructure:"MCP_HOST_ENDPOINT"`

// Environment specifies the deployment environment ("development", "staging", "production").
Environment string `mapstructure:"ENVIRONMENT"`

// Logging configuration

// LogLevel sets the logging verbosity ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL").
LogLevel string `mapstructure:"LOG_LEVEL"`

// CORS configuration
CORSEnabled bool `mapstructure:"CORS_ENABLED"`
CORSOrigins []string `mapstructure:"CORS_ORIGINS"`
CORSCredentials bool `mapstructure:"CORS_CREDENTIALS"`
CORSMethods []string `mapstructure:"CORS_METHODS"`
CORSHeaders []string `mapstructure:"CORS_HEADERS"`

// CORSEnabled controls whether CORS middleware is enabled.
CORSEnabled bool `mapstructure:"CORS_ENABLED"`

// CORSOrigins lists allowed origins for CORS requests.
CORSOrigins []string `mapstructure:"CORS_ORIGINS"`

// CORSCredentials controls whether credentials are allowed in CORS requests.
CORSCredentials bool `mapstructure:"CORS_CREDENTIALS"`

// CORSMethods lists allowed HTTP methods for CORS requests.
CORSMethods []string `mapstructure:"CORS_METHODS"`

// CORSHeaders lists allowed headers for CORS requests.
CORSHeaders []string `mapstructure:"CORS_HEADERS"`

// SSO/OAuth configuration
SSOClientID string `mapstructure:"SSO_CLIENT_ID"`
SSOClientSecret string `mapstructure:"SSO_CLIENT_SECRET"`
SSOCallbackURL string `mapstructure:"SSO_CALLBACK_URL"`
SSOAuthorizationURL string `mapstructure:"SSO_AUTHORIZATION_URL"`
SSOTokenURL string `mapstructure:"SSO_TOKEN_URL"`
SSOIntrospectionURL string `mapstructure:"SSO_INTROSPECTION_URL"`
SessionSecret string `mapstructure:"SESSION_SECRET"`
UseExternalBrowserAuth bool `mapstructure:"USE_EXTERNAL_BROWSER_AUTH"`
CompatibleWithCursor bool `mapstructure:"COMPATIBLE_WITH_CURSOR"`
CursorCompatibleSSE bool `mapstructure:"CURSOR_COMPATIBLE_SSE"`
EnableAuth bool `mapstructure:"ENABLE_AUTH"`

// SSOClientID is the OAuth client ID for SSO authentication.
SSOClientID string `mapstructure:"SSO_CLIENT_ID"`

// SSOClientSecret is the OAuth client secret for SSO authentication.
SSOClientSecret string `mapstructure:"SSO_CLIENT_SECRET"`

// SSOCallbackURL is the OAuth callback/redirect URL.
SSOCallbackURL string `mapstructure:"SSO_CALLBACK_URL"`

// SSOAuthorizationURL is the OAuth authorization endpoint.
SSOAuthorizationURL string `mapstructure:"SSO_AUTHORIZATION_URL"`

// SSOTokenURL is the OAuth token exchange endpoint.
SSOTokenURL string `mapstructure:"SSO_TOKEN_URL"`

// SSOIntrospectionURL is the OAuth token introspection endpoint.
SSOIntrospectionURL string `mapstructure:"SSO_INTROSPECTION_URL"`

// SessionSecret is the secret key used for session encryption.
SessionSecret string `mapstructure:"SESSION_SECRET"`

// UseExternalBrowserAuth enables external browser-based authentication.
UseExternalBrowserAuth bool `mapstructure:"USE_EXTERNAL_BROWSER_AUTH"`

// CompatibleWithCursor enables Cursor IDE compatibility mode.
CompatibleWithCursor bool `mapstructure:"COMPATIBLE_WITH_CURSOR"`

// CursorCompatibleSSE enables SSE streaming compatible with Cursor IDE.
CursorCompatibleSSE bool `mapstructure:"CURSOR_COMPATIBLE_SSE"`

// EnableAuth controls whether authentication is required.
EnableAuth bool `mapstructure:"ENABLE_AUTH"`

// PostgreSQL configuration
PostgresHost string `mapstructure:"POSTGRES_HOST"`
PostgresPort int `mapstructure:"POSTGRES_PORT"`
PostgresDB string `mapstructure:"POSTGRES_DB"`
PostgresUser string `mapstructure:"POSTGRES_USER"`
PostgresPassword string `mapstructure:"POSTGRES_PASSWORD"`
PostgresPoolSize int `mapstructure:"POSTGRES_POOL_SIZE"`
PostgresMaxConnections int `mapstructure:"POSTGRES_MAX_CONNECTIONS"`

// PostgresHost is the PostgreSQL server hostname or IP address.
PostgresHost string `mapstructure:"POSTGRES_HOST"`

// PostgresPort is the PostgreSQL server port (default: 5432).
PostgresPort int `mapstructure:"POSTGRES_PORT"`

// PostgresDB is the name of the PostgreSQL database.
PostgresDB string `mapstructure:"POSTGRES_DB"`

// PostgresUser is the PostgreSQL authentication username.
PostgresUser string `mapstructure:"POSTGRES_USER"`

// PostgresPassword is the PostgreSQL authentication password.
PostgresPassword string `mapstructure:"POSTGRES_PASSWORD"`

// PostgresPoolSize is the initial number of connections in the pool.
PostgresPoolSize int `mapstructure:"POSTGRES_POOL_SIZE"`

// PostgresMaxConnections is the maximum number of concurrent connections.
PostgresMaxConnections int `mapstructure:"POSTGRES_MAX_CONNECTIONS"`
}

// Load loads configuration from environment variables
// Load loads configuration from environment variables and optional .env file.
// It applies default values, reads configuration, and performs validation.
func Load() (*Config, error) {
v := viper.New()

Expand Down Expand Up @@ -89,6 +152,7 @@ func Load() (*Config, error) {
return &cfg, nil
}

// setDefaults sets default configuration values.
func setDefaults(v *viper.Viper) {
// Server defaults
v.SetDefault("MCP_HOST", "localhost")
Expand Down Expand Up @@ -119,7 +183,8 @@ func setDefaults(v *viper.Viper) {
v.SetDefault("POSTGRES_MAX_CONNECTIONS", 20)
}

// Validate performs validation on configuration values
// Validate performs validation on configuration values.
// It returns an error if any configuration value is invalid.
func (c *Config) Validate() error {
// Validate port range
if c.MCPPort < 1024 || c.MCPPort > 65535 {
Expand Down Expand Up @@ -183,12 +248,12 @@ func (c *Config) Validate() error {
return nil
}

// GetServerAddress returns the server address string
// GetServerAddress returns the server address string in "host:port" format.
func (c *Config) GetServerAddress() string {
return fmt.Sprintf("%s:%d", c.MCPHost, c.MCPPort)
}

// GetPostgresConnectionString returns the PostgreSQL connection string
// GetPostgresConnectionString returns the PostgreSQL connection string.
func (c *Config) GetPostgresConnectionString() string {
return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s",
c.PostgresUser,
Expand All @@ -199,12 +264,13 @@ func (c *Config) GetPostgresConnectionString() string {
)
}

// HasSSL returns true if SSL is configured
// HasSSL returns true if SSL is configured (both key and cert files are set).
func (c *Config) HasSSL() bool {
return c.MCPSSLKeyfile != "" && c.MCPSSLCertfile != ""
}

// GetSessionSecret returns the session secret, generating one for development if needed
// GetSessionSecret returns the session secret, generating one for development if needed.
// In production, a secret must be explicitly configured.
func (c *Config) GetSessionSecret() string {
if c.SessionSecret != "" {
return c.SessionSecret
Expand All @@ -218,7 +284,8 @@ func (c *Config) GetSessionSecret() string {
return ""
}

// generateEphemeralKey creates a temporary session secret for development
// generateEphemeralKey creates a temporary session secret for development environments.
// This should NEVER be used in production.
func generateEphemeralKey() string {
// This is a simple implementation; in production use crypto/rand
return fmt.Sprintf("dev-ephemeral-key-%d", time.Now().UnixNano())
Expand Down
71 changes: 48 additions & 23 deletions internal/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,38 @@ import (
"net/http"
)

// AppError represents an application-level error
// AppError represents an application-level error with HTTP status code and error code.
// It implements the error interface and supports error wrapping.
type AppError struct {
Code string `json:"code"`
Message string `json:"message"`
StatusCode int `json:"-"`
Err error `json:"-"`
// Code is a machine-readable error code (e.g., "bad_request", "not_found").
Code string `json:"code"`

// Message is a human-readable error message.
Message string `json:"message"`

// StatusCode is the HTTP status code to return (e.g., 400, 404, 500).
StatusCode int `json:"-"`

// Err is the underlying error that caused this AppError.
Err error `json:"-"`
}

// Error implements the error interface
// Error implements the error interface.
// It returns a formatted error message, including the underlying error if present.
func (e *AppError) Error() string {
if e.Err != nil {
return fmt.Sprintf("%s: %v", e.Message, e.Err)
}
return e.Message
}

// Unwrap returns the underlying error
// Unwrap returns the underlying error for error chain inspection.
// This allows errors.Is() and errors.As() to work correctly.
func (e *AppError) Unwrap() error {
return e.Err
}

// New creates a new AppError
// New creates a new AppError with the specified code, message, status code, and underlying error.
func New(code, message string, statusCode int, err error) *AppError {
return &AppError{
Code: code,
Expand All @@ -36,7 +46,8 @@ func New(code, message string, statusCode int, err error) *AppError {
}
}

// Wrap wraps an error with additional context
// Wrap wraps an existing error with additional context.
// This is useful for adding HTTP status codes and error codes to standard errors.
func Wrap(err error, code, message string, statusCode int) *AppError {
return &AppError{
Code: code,
Expand All @@ -48,48 +59,62 @@ func Wrap(err error, code, message string, statusCode int) *AppError {

// Common error constructors

// ErrBadRequest creates a 400 Bad Request error
// ErrBadRequest creates a 400 Bad Request error.
func ErrBadRequest(message string, err error) *AppError {
return New("bad_request", message, http.StatusBadRequest, err)
}

// ErrUnauthorized creates a 401 Unauthorized error
// ErrUnauthorized creates a 401 Unauthorized error.
func ErrUnauthorized(message string, err error) *AppError {
return New("unauthorized", message, http.StatusUnauthorized, err)
}

// ErrForbidden creates a 403 Forbidden error
// ErrForbidden creates a 403 Forbidden error.
func ErrForbidden(message string, err error) *AppError {
return New("forbidden", message, http.StatusForbidden, err)
}

// ErrNotFound creates a 404 Not Found error
// ErrNotFound creates a 404 Not Found error.
func ErrNotFound(message string, err error) *AppError {
return New("not_found", message, http.StatusNotFound, err)
}

// ErrConflict creates a 409 Conflict error
// ErrConflict creates a 409 Conflict error.
func ErrConflict(message string, err error) *AppError {
return New("conflict", message, http.StatusConflict, err)
}

// ErrInternal creates a 500 Internal Server Error
// ErrInternal creates a 500 Internal Server Error.
func ErrInternal(message string, err error) *AppError {
return New("internal_error", message, http.StatusInternalServerError, err)
}

// ErrServiceUnavailable creates a 503 Service Unavailable error
// ErrServiceUnavailable creates a 503 Service Unavailable error.
func ErrServiceUnavailable(message string, err error) *AppError {
return New("service_unavailable", message, http.StatusServiceUnavailable, err)
}

// Predefined errors
// Predefined errors for common scenarios.
// These can be used directly or wrapped with additional context.
var (
ErrInvalidInput = ErrBadRequest("Invalid input", nil)
// ErrInvalidInput indicates that the request contains invalid data.
ErrInvalidInput = ErrBadRequest("Invalid input", nil)

// ErrMissingAuthHeader indicates that the Authorization header is missing.
ErrMissingAuthHeader = ErrUnauthorized("Missing authorization header", nil)
ErrInvalidToken = ErrUnauthorized("Invalid token", nil)
ErrExpiredToken = ErrUnauthorized("Token expired", nil)
ErrResourceNotFound = ErrNotFound("Resource not found", nil)
ErrDatabaseError = ErrInternal("Database error", nil)
ErrStorageError = ErrInternal("Storage error", nil)

// ErrInvalidToken indicates that the provided token is invalid or malformed.
ErrInvalidToken = ErrUnauthorized("Invalid token", nil)

// ErrExpiredToken indicates that the provided token has expired.
ErrExpiredToken = ErrUnauthorized("Token expired", nil)

// ErrResourceNotFound indicates that the requested resource doesn't exist.
ErrResourceNotFound = ErrNotFound("Resource not found", nil)

// ErrDatabaseError indicates a database operation failed.
ErrDatabaseError = ErrInternal("Database error", nil)

// ErrStorageError indicates a storage operation failed.
ErrStorageError = ErrInternal("Storage error", nil)
)
Loading