diff --git a/internal/config/config.go b/internal/config/config.go index 4937f12..e1924da 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,51 +8,114 @@ import ( "github.com/spf13/viper" ) -// Config holds all application configuration +// Config holds all application configuration loaded from environment variables. type Config struct { // Server configuration - MCPHost string `mapstructure:"MCP_HOST"` - MCPPort int `mapstructure:"MCP_PORT"` - MCPSSLKeyfile string `mapstructure:"MCP_SSL_KEYFILE"` - MCPSSLCertfile string `mapstructure:"MCP_SSL_CERTFILE"` + + // MCPHost is the hostname or IP address the server binds to. + MCPHost string `mapstructure:"MCP_HOST"` + + // MCPPort is the port number the server listens on (1024-65535). + MCPPort int `mapstructure:"MCP_PORT"` + + // MCPSSLKeyfile is the path to the SSL private key file (optional). + MCPSSLKeyfile string `mapstructure:"MCP_SSL_KEYFILE"` + + // MCPSSLCertfile is the path to the SSL certificate file (optional). + MCPSSLCertfile string `mapstructure:"MCP_SSL_CERTFILE"` + + // MCPTransportProtocol specifies the transport protocol ("stdio", "streamable-http", "sse", "http"). MCPTransportProtocol string `mapstructure:"MCP_TRANSPORT_PROTOCOL"` - MCPHostEndpoint string `mapstructure:"MCP_HOST_ENDPOINT"` - Environment string `mapstructure:"ENVIRONMENT"` + + // MCPHostEndpoint is the full URL endpoint of the server. + MCPHostEndpoint string `mapstructure:"MCP_HOST_ENDPOINT"` + + // Environment specifies the deployment environment ("development", "staging", "production"). + Environment string `mapstructure:"ENVIRONMENT"` // Logging configuration + + // LogLevel sets the logging verbosity ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"). LogLevel string `mapstructure:"LOG_LEVEL"` // CORS configuration - CORSEnabled bool `mapstructure:"CORS_ENABLED"` - CORSOrigins []string `mapstructure:"CORS_ORIGINS"` - CORSCredentials bool `mapstructure:"CORS_CREDENTIALS"` - CORSMethods []string `mapstructure:"CORS_METHODS"` - CORSHeaders []string `mapstructure:"CORS_HEADERS"` + + // CORSEnabled controls whether CORS middleware is enabled. + CORSEnabled bool `mapstructure:"CORS_ENABLED"` + + // CORSOrigins lists allowed origins for CORS requests. + CORSOrigins []string `mapstructure:"CORS_ORIGINS"` + + // CORSCredentials controls whether credentials are allowed in CORS requests. + CORSCredentials bool `mapstructure:"CORS_CREDENTIALS"` + + // CORSMethods lists allowed HTTP methods for CORS requests. + CORSMethods []string `mapstructure:"CORS_METHODS"` + + // CORSHeaders lists allowed headers for CORS requests. + CORSHeaders []string `mapstructure:"CORS_HEADERS"` // SSO/OAuth configuration - SSOClientID string `mapstructure:"SSO_CLIENT_ID"` - SSOClientSecret string `mapstructure:"SSO_CLIENT_SECRET"` - SSOCallbackURL string `mapstructure:"SSO_CALLBACK_URL"` - SSOAuthorizationURL string `mapstructure:"SSO_AUTHORIZATION_URL"` - SSOTokenURL string `mapstructure:"SSO_TOKEN_URL"` - SSOIntrospectionURL string `mapstructure:"SSO_INTROSPECTION_URL"` - SessionSecret string `mapstructure:"SESSION_SECRET"` - UseExternalBrowserAuth bool `mapstructure:"USE_EXTERNAL_BROWSER_AUTH"` - CompatibleWithCursor bool `mapstructure:"COMPATIBLE_WITH_CURSOR"` - CursorCompatibleSSE bool `mapstructure:"CURSOR_COMPATIBLE_SSE"` - EnableAuth bool `mapstructure:"ENABLE_AUTH"` + + // SSOClientID is the OAuth client ID for SSO authentication. + SSOClientID string `mapstructure:"SSO_CLIENT_ID"` + + // SSOClientSecret is the OAuth client secret for SSO authentication. + SSOClientSecret string `mapstructure:"SSO_CLIENT_SECRET"` + + // SSOCallbackURL is the OAuth callback/redirect URL. + SSOCallbackURL string `mapstructure:"SSO_CALLBACK_URL"` + + // SSOAuthorizationURL is the OAuth authorization endpoint. + SSOAuthorizationURL string `mapstructure:"SSO_AUTHORIZATION_URL"` + + // SSOTokenURL is the OAuth token exchange endpoint. + SSOTokenURL string `mapstructure:"SSO_TOKEN_URL"` + + // SSOIntrospectionURL is the OAuth token introspection endpoint. + SSOIntrospectionURL string `mapstructure:"SSO_INTROSPECTION_URL"` + + // SessionSecret is the secret key used for session encryption. + SessionSecret string `mapstructure:"SESSION_SECRET"` + + // UseExternalBrowserAuth enables external browser-based authentication. + UseExternalBrowserAuth bool `mapstructure:"USE_EXTERNAL_BROWSER_AUTH"` + + // CompatibleWithCursor enables Cursor IDE compatibility mode. + CompatibleWithCursor bool `mapstructure:"COMPATIBLE_WITH_CURSOR"` + + // CursorCompatibleSSE enables SSE streaming compatible with Cursor IDE. + CursorCompatibleSSE bool `mapstructure:"CURSOR_COMPATIBLE_SSE"` + + // EnableAuth controls whether authentication is required. + EnableAuth bool `mapstructure:"ENABLE_AUTH"` // PostgreSQL configuration - PostgresHost string `mapstructure:"POSTGRES_HOST"` - PostgresPort int `mapstructure:"POSTGRES_PORT"` - PostgresDB string `mapstructure:"POSTGRES_DB"` - PostgresUser string `mapstructure:"POSTGRES_USER"` - PostgresPassword string `mapstructure:"POSTGRES_PASSWORD"` - PostgresPoolSize int `mapstructure:"POSTGRES_POOL_SIZE"` - PostgresMaxConnections int `mapstructure:"POSTGRES_MAX_CONNECTIONS"` + + // PostgresHost is the PostgreSQL server hostname or IP address. + PostgresHost string `mapstructure:"POSTGRES_HOST"` + + // PostgresPort is the PostgreSQL server port (default: 5432). + PostgresPort int `mapstructure:"POSTGRES_PORT"` + + // PostgresDB is the name of the PostgreSQL database. + PostgresDB string `mapstructure:"POSTGRES_DB"` + + // PostgresUser is the PostgreSQL authentication username. + PostgresUser string `mapstructure:"POSTGRES_USER"` + + // PostgresPassword is the PostgreSQL authentication password. + PostgresPassword string `mapstructure:"POSTGRES_PASSWORD"` + + // PostgresPoolSize is the initial number of connections in the pool. + PostgresPoolSize int `mapstructure:"POSTGRES_POOL_SIZE"` + + // PostgresMaxConnections is the maximum number of concurrent connections. + PostgresMaxConnections int `mapstructure:"POSTGRES_MAX_CONNECTIONS"` } -// Load loads configuration from environment variables +// Load loads configuration from environment variables and optional .env file. +// It applies default values, reads configuration, and performs validation. func Load() (*Config, error) { v := viper.New() @@ -89,6 +152,7 @@ func Load() (*Config, error) { return &cfg, nil } +// setDefaults sets default configuration values. func setDefaults(v *viper.Viper) { // Server defaults v.SetDefault("MCP_HOST", "localhost") @@ -119,7 +183,8 @@ func setDefaults(v *viper.Viper) { v.SetDefault("POSTGRES_MAX_CONNECTIONS", 20) } -// Validate performs validation on configuration values +// Validate performs validation on configuration values. +// It returns an error if any configuration value is invalid. func (c *Config) Validate() error { // Validate port range if c.MCPPort < 1024 || c.MCPPort > 65535 { @@ -183,12 +248,12 @@ func (c *Config) Validate() error { return nil } -// GetServerAddress returns the server address string +// GetServerAddress returns the server address string in "host:port" format. func (c *Config) GetServerAddress() string { return fmt.Sprintf("%s:%d", c.MCPHost, c.MCPPort) } -// GetPostgresConnectionString returns the PostgreSQL connection string +// GetPostgresConnectionString returns the PostgreSQL connection string. func (c *Config) GetPostgresConnectionString() string { return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", c.PostgresUser, @@ -199,12 +264,13 @@ func (c *Config) GetPostgresConnectionString() string { ) } -// HasSSL returns true if SSL is configured +// HasSSL returns true if SSL is configured (both key and cert files are set). func (c *Config) HasSSL() bool { return c.MCPSSLKeyfile != "" && c.MCPSSLCertfile != "" } -// GetSessionSecret returns the session secret, generating one for development if needed +// GetSessionSecret returns the session secret, generating one for development if needed. +// In production, a secret must be explicitly configured. func (c *Config) GetSessionSecret() string { if c.SessionSecret != "" { return c.SessionSecret @@ -218,7 +284,8 @@ func (c *Config) GetSessionSecret() string { return "" } -// generateEphemeralKey creates a temporary session secret for development +// generateEphemeralKey creates a temporary session secret for development environments. +// This should NEVER be used in production. func generateEphemeralKey() string { // This is a simple implementation; in production use crypto/rand return fmt.Sprintf("dev-ephemeral-key-%d", time.Now().UnixNano()) diff --git a/internal/errors/errors.go b/internal/errors/errors.go index f099ecd..153d2dc 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -5,15 +5,24 @@ import ( "net/http" ) -// AppError represents an application-level error +// AppError represents an application-level error with HTTP status code and error code. +// It implements the error interface and supports error wrapping. type AppError struct { - Code string `json:"code"` - Message string `json:"message"` - StatusCode int `json:"-"` - Err error `json:"-"` + // Code is a machine-readable error code (e.g., "bad_request", "not_found"). + Code string `json:"code"` + + // Message is a human-readable error message. + Message string `json:"message"` + + // StatusCode is the HTTP status code to return (e.g., 400, 404, 500). + StatusCode int `json:"-"` + + // Err is the underlying error that caused this AppError. + Err error `json:"-"` } -// Error implements the error interface +// Error implements the error interface. +// It returns a formatted error message, including the underlying error if present. func (e *AppError) Error() string { if e.Err != nil { return fmt.Sprintf("%s: %v", e.Message, e.Err) @@ -21,12 +30,13 @@ func (e *AppError) Error() string { return e.Message } -// Unwrap returns the underlying error +// Unwrap returns the underlying error for error chain inspection. +// This allows errors.Is() and errors.As() to work correctly. func (e *AppError) Unwrap() error { return e.Err } -// New creates a new AppError +// New creates a new AppError with the specified code, message, status code, and underlying error. func New(code, message string, statusCode int, err error) *AppError { return &AppError{ Code: code, @@ -36,7 +46,8 @@ func New(code, message string, statusCode int, err error) *AppError { } } -// Wrap wraps an error with additional context +// Wrap wraps an existing error with additional context. +// This is useful for adding HTTP status codes and error codes to standard errors. func Wrap(err error, code, message string, statusCode int) *AppError { return &AppError{ Code: code, @@ -48,48 +59,62 @@ func Wrap(err error, code, message string, statusCode int) *AppError { // Common error constructors -// ErrBadRequest creates a 400 Bad Request error +// ErrBadRequest creates a 400 Bad Request error. func ErrBadRequest(message string, err error) *AppError { return New("bad_request", message, http.StatusBadRequest, err) } -// ErrUnauthorized creates a 401 Unauthorized error +// ErrUnauthorized creates a 401 Unauthorized error. func ErrUnauthorized(message string, err error) *AppError { return New("unauthorized", message, http.StatusUnauthorized, err) } -// ErrForbidden creates a 403 Forbidden error +// ErrForbidden creates a 403 Forbidden error. func ErrForbidden(message string, err error) *AppError { return New("forbidden", message, http.StatusForbidden, err) } -// ErrNotFound creates a 404 Not Found error +// ErrNotFound creates a 404 Not Found error. func ErrNotFound(message string, err error) *AppError { return New("not_found", message, http.StatusNotFound, err) } -// ErrConflict creates a 409 Conflict error +// ErrConflict creates a 409 Conflict error. func ErrConflict(message string, err error) *AppError { return New("conflict", message, http.StatusConflict, err) } -// ErrInternal creates a 500 Internal Server Error +// ErrInternal creates a 500 Internal Server Error. func ErrInternal(message string, err error) *AppError { return New("internal_error", message, http.StatusInternalServerError, err) } -// ErrServiceUnavailable creates a 503 Service Unavailable error +// ErrServiceUnavailable creates a 503 Service Unavailable error. func ErrServiceUnavailable(message string, err error) *AppError { return New("service_unavailable", message, http.StatusServiceUnavailable, err) } -// Predefined errors +// Predefined errors for common scenarios. +// These can be used directly or wrapped with additional context. var ( - ErrInvalidInput = ErrBadRequest("Invalid input", nil) + // ErrInvalidInput indicates that the request contains invalid data. + ErrInvalidInput = ErrBadRequest("Invalid input", nil) + + // ErrMissingAuthHeader indicates that the Authorization header is missing. ErrMissingAuthHeader = ErrUnauthorized("Missing authorization header", nil) - ErrInvalidToken = ErrUnauthorized("Invalid token", nil) - ErrExpiredToken = ErrUnauthorized("Token expired", nil) - ErrResourceNotFound = ErrNotFound("Resource not found", nil) - ErrDatabaseError = ErrInternal("Database error", nil) - ErrStorageError = ErrInternal("Storage error", nil) + + // ErrInvalidToken indicates that the provided token is invalid or malformed. + ErrInvalidToken = ErrUnauthorized("Invalid token", nil) + + // ErrExpiredToken indicates that the provided token has expired. + ErrExpiredToken = ErrUnauthorized("Token expired", nil) + + // ErrResourceNotFound indicates that the requested resource doesn't exist. + ErrResourceNotFound = ErrNotFound("Resource not found", nil) + + // ErrDatabaseError indicates a database operation failed. + ErrDatabaseError = ErrInternal("Database error", nil) + + // ErrStorageError indicates a storage operation failed. + ErrStorageError = ErrInternal("Storage error", nil) ) diff --git a/internal/storage/interface.go b/internal/storage/interface.go index 5c76182..75ca1ae 100644 --- a/internal/storage/interface.go +++ b/internal/storage/interface.go @@ -5,88 +5,188 @@ import ( "time" ) -// Client represents an OAuth client +// Client represents an OAuth 2.0 client registered with the server. +// Clients are identified by their ID and authenticated using their Secret. type Client struct { - ID string `json:"id"` - Secret string `json:"secret"` - Name string `json:"name"` - RedirectURIs []string `json:"redirect_uris"` - GrantTypes []string `json:"grant_types"` - ResponseTypes []string `json:"response_types"` - Scope string `json:"scope"` - CreatedAt time.Time `json:"created_at"` + // ID is the unique identifier for this client. + ID string `json:"id"` + + // Secret is the client's authentication credential. + // This should be stored securely and never exposed in logs. + Secret string `json:"secret"` + + // Name is a human-readable name for the client. + Name string `json:"name"` + + // RedirectURIs are the allowed callback URLs for this client. + RedirectURIs []string `json:"redirect_uris"` + + // GrantTypes specifies which OAuth grant types this client can use. + // Common values: "authorization_code", "refresh_token", "client_credentials" + GrantTypes []string `json:"grant_types"` + + // ResponseTypes specifies allowed response types. + // Common values: "code", "token" + ResponseTypes []string `json:"response_types"` + + // Scope defines the default scope for this client. + Scope string `json:"scope"` + + // CreatedAt is when this client was registered. + CreatedAt time.Time `json:"created_at"` } -// AuthorizationCode represents an authorization code +// AuthorizationCode represents an OAuth 2.0 authorization code. +// Authorization codes are short-lived credentials exchanged for access tokens. type AuthorizationCode struct { - Code string `json:"code"` - ClientID string `json:"client_id"` - RedirectURI string `json:"redirect_uri"` - Scope string `json:"scope"` - CodeChallenge string `json:"code_challenge"` - CodeChallengeMethod string `json:"code_challenge_method"` - SnowflakeToken map[string]interface{} `json:"snowflake_token,omitempty"` - ExpiresAt time.Time `json:"expires_at"` - State string `json:"state"` + // Code is the authorization code value. + Code string `json:"code"` + + // ClientID identifies which client this code was issued to. + ClientID string `json:"client_id"` + + // RedirectURI is the callback URL that must match during token exchange. + RedirectURI string `json:"redirect_uri"` + + // Scope defines the permissions granted by this authorization. + Scope string `json:"scope"` + + // CodeChallenge is the PKCE code challenge (if using PKCE). + CodeChallenge string `json:"code_challenge"` + + // CodeChallengeMethod specifies the PKCE challenge method ("S256" or "plain"). + CodeChallengeMethod string `json:"code_challenge_method"` + + // SnowflakeToken contains Snowflake-specific token data. + SnowflakeToken map[string]interface{} `json:"snowflake_token,omitempty"` + + // ExpiresAt is when this authorization code expires. + ExpiresAt time.Time `json:"expires_at"` + + // State is the CSRF protection state parameter. + State string `json:"state"` } -// AccessToken represents an access token +// AccessToken represents an OAuth 2.0 access token. +// Access tokens grant access to protected resources. type AccessToken struct { - Token string `json:"token"` - ClientID string `json:"client_id"` - Scope string `json:"scope"` - TokenType string `json:"token_type"` + // Token is the access token value. + Token string `json:"token"` + + // ClientID identifies which client this token was issued to. + ClientID string `json:"client_id"` + + // Scope defines the permissions granted by this token. + Scope string `json:"scope"` + + // TokenType specifies the token type (typically "Bearer"). + TokenType string `json:"token_type"` + + // ExpiresAt is when this access token expires. ExpiresAt time.Time `json:"expires_at"` } -// RefreshToken represents a refresh token +// RefreshToken represents an OAuth 2.0 refresh token. +// Refresh tokens are used to obtain new access tokens without re-authentication. type RefreshToken struct { - Token string `json:"token"` - ClientID string `json:"client_id"` - AccessToken string `json:"access_token"` - Scope string `json:"scope"` - ExpiresAt time.Time `json:"expires_at"` + // Token is the refresh token value. + Token string `json:"token"` + + // ClientID identifies which client this token was issued to. + ClientID string `json:"client_id"` + + // AccessToken is the associated access token. + AccessToken string `json:"access_token"` + + // Scope defines the permissions granted by this token. + Scope string `json:"scope"` + + // ExpiresAt is when this refresh token expires. + ExpiresAt time.Time `json:"expires_at"` } -// Store defines the interface for storage operations +// Store defines the interface for persistent storage operations. +// Implementations must be safe for concurrent use. type Store interface { - // Connection management + // Connect establishes a connection to the storage backend. + // It should be called before any other operations. Connect(ctx context.Context, cfg Config) error + + // Disconnect closes the connection to the storage backend. + // It should be called during graceful shutdown. Disconnect(ctx context.Context) error + + // IsHealthy returns true if the storage backend is responsive. IsHealthy(ctx context.Context) bool - // Client operations + // GetClientByNameAndRedirectURIs retrieves a client by name and redirect URIs. + // Returns nil if no matching client is found. GetClientByNameAndRedirectURIs(ctx context.Context, name string, redirectURIs []string) (*Client, error) + + // StoreClient persists a client to storage. StoreClient(ctx context.Context, client *Client) error + + // GetClient retrieves a client by ID. + // Returns nil if the client doesn't exist. GetClient(ctx context.Context, clientID string) (*Client, error) - // Authorization code operations + // StoreAuthorizationCode persists an authorization code to storage. StoreAuthorizationCode(ctx context.Context, code *AuthorizationCode) error + + // GetAuthorizationCode retrieves an authorization code by value. + // Returns nil if the code doesn't exist or has expired. GetAuthorizationCode(ctx context.Context, code string) (*AuthorizationCode, error) + + // UpdateAuthorizationCodeToken updates the Snowflake token associated with a code. UpdateAuthorizationCodeToken(ctx context.Context, code string, token map[string]interface{}) error + + // DeleteAuthorizationCode removes an authorization code from storage. DeleteAuthorizationCode(ctx context.Context, code string) error - // Access token operations + // StoreAccessToken persists an access token to storage. StoreAccessToken(ctx context.Context, token *AccessToken) error + + // GetAccessToken retrieves an access token by value. + // Returns nil if the token doesn't exist or has expired. GetAccessToken(ctx context.Context, token string) (*AccessToken, error) + + // DeleteAccessToken removes an access token from storage. DeleteAccessToken(ctx context.Context, token string) error - // Refresh token operations + // StoreRefreshToken persists a refresh token to storage. StoreRefreshToken(ctx context.Context, token *RefreshToken) error + + // GetRefreshToken retrieves a refresh token by value. + // Returns nil if the token doesn't exist or has expired. GetRefreshToken(ctx context.Context, token string) (*RefreshToken, error) + + // DeleteRefreshToken removes a refresh token from storage. DeleteRefreshToken(ctx context.Context, token string) error - // Status + // GetStatus returns diagnostic information about the storage backend. GetStatus(ctx context.Context) map[string]interface{} } -// Config holds storage configuration +// Config holds storage configuration. type Config struct { - Host string - Port int - Database string - Username string - Password string - PoolSize int + // Host is the database server hostname or IP address. + Host string + + // Port is the database server port. + Port int + + // Database is the name of the database to connect to. + Database string + + // Username is the database authentication username. + Username string + + // Password is the database authentication password. + Password string + + // PoolSize is the initial number of connections in the pool. + PoolSize int + + // MaxConnections is the maximum number of concurrent connections allowed. MaxConnections int }