From c216df37fd6d42d53d55ef647fc255532ba2e61c Mon Sep 17 00:00:00 2001 From: Alexander Lerma Date: Mon, 1 Jun 2026 06:16:24 -0500 Subject: [PATCH 1/6] feat(store): add Postgres backend for the hub store MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a faithful pkg/store/postgres mirror of pkg/store/sqlite so the hub can run as a stateless control plane against an external Postgres, configured via database.driver: postgres (SCION_SERVER_DATABASE_DRIVER / SCION_SERVER_DATABASE_URL) instead of a node-local SQLite file. The ent layer already supported Postgres (entc.OpenPostgres); this adds the hand-written store's Postgres twin and wires initStore's case "postgres": to compose them via entadapter.NewCompositeStore. The mirror stays 1-to-1 with the SQLite implementation — same method names, signatures, comments and control flow (206 methods each). The only differences are dialect mechanics: - ? -> $N numbered placeholders - INSERT OR IGNORE/REPLACE -> ON CONFLICT DO NOTHING / DO UPDATE - sqlite_master / PRAGMA introspection -> information_schema - randomblob() UUID seeds -> gen_random_uuid() - json_each() -> json_array_elements_text(); json_array() -> json_build_array() - COLLATE NOCASE -> LOWER(email) functional unique index - BLOB -> BYTEA; datetime('now') -> NOW(); INSTR/SUBSTR -> split_part The grove->project data backfill (MigrateGroveToProjectData) is a SQLite-era repair and is intentionally skipped on the Postgres path — a fresh Postgres DB has no legacy grove rows. The Postgres driver import is gated behind a no_postgres build tag, mirroring the existing no_sqlite tag, so slim images can exclude either driver. Tests: env-gated integration tests (SCION_TEST_POSTGRES_URL) cover all 53 migrations plus CRUD round-trips for users, projects, agents, secrets, groups, policies, invite codes and env vars. They skip when the env var is unset so the CI no_sqlite path (make test-fast, no DB) stays green. Co-Authored-By: Claude Opus 4.8 --- .../project-log/2026-06-01-postgres-store.md | 93 ++ changelog/2026-06-01-changelog.md | 12 + cmd/server_foreground.go | 33 + pkg/store/postgres/agents.go | 732 +++++++++++ pkg/store/postgres/allowlist.go | 330 +++++ pkg/store/postgres/brokers.go | 274 ++++ pkg/store/postgres/brokersecret.go | 290 +++++ pkg/store/postgres/driver.go | 19 + pkg/store/postgres/envvars.go | 199 +++ pkg/store/postgres/gcp_service_account.go | 200 +++ pkg/store/postgres/github_installation.go | 185 +++ pkg/store/postgres/groups.go | 555 ++++++++ pkg/store/postgres/harness_configs.go | 385 ++++++ pkg/store/postgres/invites.go | 259 ++++ pkg/store/postgres/maintenance.go | 269 ++++ pkg/store/postgres/messages.go | 226 ++++ pkg/store/postgres/migrations.go | 1150 +++++++++++++++++ pkg/store/postgres/notification.go | 553 ++++++++ pkg/store/postgres/policies.go | 361 ++++++ pkg/store/postgres/postgres.go | 311 +++++ pkg/store/postgres/postgres_test.go | 659 ++++++++++ pkg/store/postgres/project_sync_state.go | 142 ++ pkg/store/postgres/projects.go | 428 ++++++ pkg/store/postgres/providers.go | 215 +++ pkg/store/postgres/schedule.go | 365 ++++++ pkg/store/postgres/scheduled_event.go | 317 +++++ pkg/store/postgres/secrets.go | 325 +++++ pkg/store/postgres/templates.go | 384 ++++++ pkg/store/postgres/tokens.go | 201 +++ pkg/store/postgres/users.go | 247 ++++ 30 files changed, 9719 insertions(+) create mode 100644 .design/project-log/2026-06-01-postgres-store.md create mode 100644 changelog/2026-06-01-changelog.md create mode 100644 pkg/store/postgres/agents.go create mode 100644 pkg/store/postgres/allowlist.go create mode 100644 pkg/store/postgres/brokers.go create mode 100644 pkg/store/postgres/brokersecret.go create mode 100644 pkg/store/postgres/driver.go create mode 100644 pkg/store/postgres/envvars.go create mode 100644 pkg/store/postgres/gcp_service_account.go create mode 100644 pkg/store/postgres/github_installation.go create mode 100644 pkg/store/postgres/groups.go create mode 100644 pkg/store/postgres/harness_configs.go create mode 100644 pkg/store/postgres/invites.go create mode 100644 pkg/store/postgres/maintenance.go create mode 100644 pkg/store/postgres/messages.go create mode 100644 pkg/store/postgres/migrations.go create mode 100644 pkg/store/postgres/notification.go create mode 100644 pkg/store/postgres/policies.go create mode 100644 pkg/store/postgres/postgres.go create mode 100644 pkg/store/postgres/postgres_test.go create mode 100644 pkg/store/postgres/project_sync_state.go create mode 100644 pkg/store/postgres/projects.go create mode 100644 pkg/store/postgres/providers.go create mode 100644 pkg/store/postgres/schedule.go create mode 100644 pkg/store/postgres/scheduled_event.go create mode 100644 pkg/store/postgres/secrets.go create mode 100644 pkg/store/postgres/templates.go create mode 100644 pkg/store/postgres/tokens.go create mode 100644 pkg/store/postgres/users.go diff --git a/.design/project-log/2026-06-01-postgres-store.md b/.design/project-log/2026-06-01-postgres-store.md new file mode 100644 index 00000000..d1a38182 --- /dev/null +++ b/.design/project-log/2026-06-01-postgres-store.md @@ -0,0 +1,93 @@ +# PostgreSQL Store Implementation + +**Date:** 2026-06-01 + +## Motivation + +The hub needs to run stateless in a hosted/cloud topology where the database is +a GitOps-configured external service (e.g. Cloud SQL, Lakebase, any managed +Postgres). SQLite is process-local and cannot be shared across replicas. +The `database.driver` + `database.url` fields already existed in `GlobalConfig` +to hold a connection URL; what was missing was a `Store` implementation that +consumed them. See `.design/hosted/resource-storage-refactor.md` §1.1 +("Cloud / hosted mode — the storage backend is GCS") for the broader hosted +architecture context that motivated a stateless control plane. + +## What landed + +### `pkg/store/postgres/` + +A new package, parallel in shape to `pkg/store/sqlite/`, implementing the full +`store.Store` interface against PostgreSQL. + +- **`postgres.go`** — `PostgresStore` struct wrapping `*sql.DB`, `New(connURL + string)`, `Migrate(ctx)`, `Ping`, `Close`. Connection pool fixed at + `MaxOpenConns=4` / `MaxIdleConns=4`. +- **`driver.go`** — blank import of `github.com/lib/pq` (database/sql driver + name `postgres`) guarded by `//go:build !no_postgres`. +- **`migrations.go`** — 53 versioned migrations tracked in a + `schema_migrations` table (`version INTEGER PRIMARY KEY`). `Migrate` is + idempotent: it reads `MAX(version)` and skips already-applied steps. Each + migration runs in its own transaction; a `foreignKeysOffMigrations` map is + preserved for shape-parity with the SQLite runner (in Postgres, FK deferral + is handled inside the migration SQL itself via `CASCADE`/explicit FK drops, + so the function body is a plain transaction). +- **Per-entity files** (`agents.go`, `users.go`, `projects.go`, `secrets.go`, + `messages.go`, `groups.go`, `policies.go`, `tokens.go`, `invites.go`, + `brokers.go`, `envvars.go`, `schedule.go`, `scheduled_event.go`, + `notification.go`, `templates.go`, `harness_configs.go`, `allowlist.go`, + `brokersecret.go`, `providers.go`, `project_sync_state.go`, + `gcp_service_account.go`, `github_installation.go`, `maintenance.go`) — + one file per entity group, matching the sqlite layout. + +### `initStore` case in `cmd/server_foreground.go` + +`initStore` gained a `"postgres"` branch: `postgres.New(cfg.Database.URL)` → +`pgStore.Migrate` → `entc.OpenPostgres(cfg.Database.URL)` → `entc.AutoMigrate` +→ `entadapter.NewCompositeStore`. The grove→project data backfill +(`entc.MigrateGroveToProjectData`) is **not** called on the postgres path (see +below). + +### Dialect translation rules applied throughout + +| SQLite pattern | PostgreSQL replacement | +|---|---| +| `?` positional placeholder | `$N` numbered placeholder | +| `INSERT OR IGNORE` / `INSERT OR REPLACE` | `ON CONFLICT … DO NOTHING` / `ON CONFLICT … DO UPDATE SET` | +| `sqlite_master` / `pragma_table_info` | `information_schema.tables` / `information_schema.columns` (both scoped to `table_schema='public'`, queried with `$1`/$`$2` params) | +| `randomblob(16)` | `gen_random_uuid()::text` (pgcrypto built-in; used in several data-backfill migrations) | +| `json_each(…)` | `json_array_elements_text(…::json)` (used in agent ancestry filter) | +| Case-insensitive email uniqueness via `UNIQUE` on TEXT | `CREATE UNIQUE INDEX … ON allow_list (LOWER(email))` (functional unique index) | +| `BLOB` | `BYTEA` (broker secret key column) | + +## What was deliberately skipped + +**Grove→project data backfill** (`entc.MigrateGroveToProjectData`) is omitted +from the postgres `initStore` path. A fresh postgres database starts with the +post-rename schema (V50 renames `groves` → `projects` and all `grove_id` +columns → `project_id` in-place); there is no legacy ent sqlite data to +backfill. The backfill only applies to existing SQLite deployments upgrading +in-place. + +## How it is tested + +`pkg/store/postgres/postgres_test.go` contains integration tests (migration +idempotency, CRUD + filter coverage for users, projects, agents, secrets, +groups, policies, invite codes, env vars) that run against a live Postgres +instance. Tests skip automatically when `SCION_TEST_POSTGRES_URL` is not set: + +```go +const envVarDSN = "SCION_TEST_POSTGRES_URL" +// ... +if dsn == "" { + t.Skipf("set %s to run Postgres tests", envVarDSN) +} +``` + +Each test calls `resetSchema` (`DROP SCHEMA public CASCADE; CREATE SCHEMA +public`) before applying migrations, giving a clean slate per test function. + +`make test-fast` (which passes `-tags no_sqlite`) excludes the SQLite driver +and exercises the rest of the codebase including the postgres package files; CI +runs this path. The full Postgres integration suite requires a live DSN and is +not wired into CI at this time. diff --git a/changelog/2026-06-01-changelog.md b/changelog/2026-06-01-changelog.md new file mode 100644 index 00000000..dc08ae65 --- /dev/null +++ b/changelog/2026-06-01-changelog.md @@ -0,0 +1,12 @@ +# Release Notes (Jun 1, 2026) + +This release introduces a Postgres store backend for the hub, enabling stateless control-plane deployments backed by an external database instead of a node-local SQLite PVC. The ent-layer has supported Postgres for some time; this change brings the hand-written store layer to parity. + +## 🚀 Features +* **[Store]: Postgres Backend.** The hub now accepts `database.driver: postgres` (environment variables `SCION_SERVER_DATABASE_DRIVER=postgres` and `SCION_SERVER_DATABASE_URL=`) to connect to an external Postgres database instead of the default node-local SQLite file. + * **Stateless Control Plane.** With Postgres as the backing store the hub StatefulSet and its associated PVC are no longer required for state durability, enabling fully stateless hub deployments that can scale horizontally or restart without data loss. + * **Ent Parity.** The ent-generated layer has supported Postgres since its introduction; this change adds the hand-written store's Postgres twin so that all hub persistence paths (agents, sessions, secrets, projects) are covered by both drivers. + * **Migration Note.** The grove → project backfill that runs on SQLite databases at startup is skipped automatically on a fresh Postgres deployment; no manual intervention is needed. + +## 🐛 Fixes +* **[Infrastructure]:** Continued monitoring and stabilization of the agent dispatch pipeline. diff --git a/cmd/server_foreground.go b/cmd/server_foreground.go index 307a2cb4..d514a9bf 100644 --- a/cmd/server_foreground.go +++ b/cmd/server_foreground.go @@ -47,6 +47,7 @@ import ( "github.com/GoogleCloudPlatform/scion/pkg/storage" "github.com/GoogleCloudPlatform/scion/pkg/store" "github.com/GoogleCloudPlatform/scion/pkg/store/entadapter" + "github.com/GoogleCloudPlatform/scion/pkg/store/postgres" "github.com/GoogleCloudPlatform/scion/pkg/store/sqlite" "github.com/GoogleCloudPlatform/scion/pkg/util" "github.com/GoogleCloudPlatform/scion/pkg/util/logging" @@ -680,6 +681,38 @@ func initStore(cfg *config.GlobalConfig) (store.Store, error) { return nil, fmt.Errorf("database ping failed: %w", err) } + return s, nil + case "postgres": + pgStore, err := postgres.New(cfg.Database.URL) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + if err := pgStore.Migrate(context.Background()); err != nil { + pgStore.Close() + return nil, fmt.Errorf("failed to run migrations: %w", err) + } + + entClient, err := entc.OpenPostgres(cfg.Database.URL) + if err != nil { + pgStore.Close() + return nil, fmt.Errorf("failed to open ent database: %w", err) + } + if err := entc.AutoMigrate(context.Background(), entClient); err != nil { + entClient.Close() + pgStore.Close() + return nil, fmt.Errorf("failed to run ent migrations: %w", err) + } + + // grove->project backfill is a SQLite-era data repair; a fresh Postgres DB has no legacy grove rows, so it is intentionally skipped. + + s := entadapter.NewCompositeStore(pgStore, entClient) + + if err := s.Ping(context.Background()); err != nil { + pgStore.Close() + return nil, fmt.Errorf("database ping failed: %w", err) + } + return s, nil default: return nil, fmt.Errorf("unsupported database driver: %s", cfg.Database.Driver) diff --git a/pkg/store/postgres/agents.go b/pkg/store/postgres/agents.go new file mode 100644 index 00000000..70c1873b --- /dev/null +++ b/pkg/store/postgres/agents.go @@ -0,0 +1,732 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package postgres provides a PostgreSQL implementation of the Store interface. +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +func (s *PostgresStore) CreateAgent(ctx context.Context, agent *store.Agent) error { + now := time.Now() + agent.Created = now + agent.Updated = now + agent.StateVersion = 1 + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO agents ( + id, agent_id, name, template, project_id, + labels, annotations, + phase, activity, tool_name, + connection_state, container_status, runtime_state, + stalled_from_activity, + image, detached, runtime, runtime_broker_id, web_pty_enabled, task_summary, message, + applied_config, + created_at, updated_at, last_seen, last_activity_event, deleted_at, + created_by, owner_id, visibility, state_version, ancestry + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32) + `, + agent.ID, agent.Slug, agent.Name, agent.Template, agent.ProjectID, + marshalJSON(agent.Labels), marshalJSON(agent.Annotations), + agent.Phase, agent.Activity, agent.ToolName, + agent.ConnectionState, agent.ContainerStatus, agent.RuntimeState, + agent.StalledFromActivity, + agent.Image, boolToInt(agent.Detached), agent.Runtime, nullableString(agent.RuntimeBrokerID), boolToInt(agent.WebPTYEnabled), agent.TaskSummary, agent.Message, + marshalJSON(agent.AppliedConfig), + agent.Created, agent.Updated, nullableTime(agent.LastSeen), nullableTime(agent.LastActivityEvent), nullableTime(agent.DeletedAt), + agent.CreatedBy, agent.OwnerID, agent.Visibility, agent.StateVersion, marshalJSON(agent.Ancestry), + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *PostgresStore) GetAgent(ctx context.Context, id string) (*store.Agent, error) { + agent := &store.Agent{} + var labels, annotations, appliedConfig string + var lastSeen, lastActivityEvent, deletedAt, startedAt sql.NullTime + var runtimeBrokerID, message, toolName, ancestry sql.NullString + + err := s.db.QueryRowContext(ctx, ` + SELECT id, agent_id, name, template, project_id, + labels, annotations, + phase, activity, tool_name, + connection_state, container_status, runtime_state, + stalled_from_activity, + current_turns, current_model_calls, + image, detached, runtime, runtime_broker_id, web_pty_enabled, task_summary, message, + applied_config, + created_at, updated_at, last_seen, last_activity_event, deleted_at, started_at, + created_by, owner_id, visibility, state_version, ancestry + FROM agents WHERE id = $1 + `, id).Scan( + &agent.ID, &agent.Slug, &agent.Name, &agent.Template, &agent.ProjectID, + &labels, &annotations, + &agent.Phase, &agent.Activity, &toolName, + &agent.ConnectionState, &agent.ContainerStatus, &agent.RuntimeState, + &agent.StalledFromActivity, + &agent.CurrentTurns, &agent.CurrentModelCalls, + &agent.Image, &agent.Detached, &agent.Runtime, &runtimeBrokerID, &agent.WebPTYEnabled, &agent.TaskSummary, &message, + &appliedConfig, + &agent.Created, &agent.Updated, &lastSeen, &lastActivityEvent, &deletedAt, &startedAt, + &agent.CreatedBy, &agent.OwnerID, &agent.Visibility, &agent.StateVersion, &ancestry, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + unmarshalJSON(labels, &agent.Labels) + unmarshalJSON(annotations, &agent.Annotations) + unmarshalJSON(appliedConfig, &agent.AppliedConfig) + unmarshalJSON(ancestry.String, &agent.Ancestry) + if lastSeen.Valid { + agent.LastSeen = lastSeen.Time + } + if lastActivityEvent.Valid { + agent.LastActivityEvent = lastActivityEvent.Time + } + if deletedAt.Valid { + agent.DeletedAt = deletedAt.Time + } + if startedAt.Valid { + agent.StartedAt = startedAt.Time + } + if runtimeBrokerID.Valid { + agent.RuntimeBrokerID = runtimeBrokerID.String + } + if message.Valid { + agent.Message = message.String + } + if toolName.Valid { + agent.ToolName = toolName.String + } + + return agent, nil +} + +func (s *PostgresStore) GetAgentBySlug(ctx context.Context, projectID, slug string) (*store.Agent, error) { + agent := &store.Agent{} + var labels, annotations, appliedConfig string + var lastSeen, lastActivityEvent, deletedAt, startedAt sql.NullTime + var runtimeBrokerID, message, toolName, ancestry sql.NullString + + err := s.db.QueryRowContext(ctx, ` + SELECT id, agent_id, name, template, project_id, + labels, annotations, + phase, activity, tool_name, + connection_state, container_status, runtime_state, + stalled_from_activity, + current_turns, current_model_calls, + image, detached, runtime, runtime_broker_id, web_pty_enabled, task_summary, message, + applied_config, + created_at, updated_at, last_seen, last_activity_event, deleted_at, started_at, + created_by, owner_id, visibility, state_version, ancestry + FROM agents WHERE project_id = $1 AND agent_id = $2 + `, projectID, slug).Scan( + &agent.ID, &agent.Slug, &agent.Name, &agent.Template, &agent.ProjectID, + &labels, &annotations, + &agent.Phase, &agent.Activity, &toolName, + &agent.ConnectionState, &agent.ContainerStatus, &agent.RuntimeState, + &agent.StalledFromActivity, + &agent.CurrentTurns, &agent.CurrentModelCalls, + &agent.Image, &agent.Detached, &agent.Runtime, &runtimeBrokerID, &agent.WebPTYEnabled, &agent.TaskSummary, &message, + &appliedConfig, + &agent.Created, &agent.Updated, &lastSeen, &lastActivityEvent, &deletedAt, &startedAt, + &agent.CreatedBy, &agent.OwnerID, &agent.Visibility, &agent.StateVersion, &ancestry, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + unmarshalJSON(labels, &agent.Labels) + unmarshalJSON(annotations, &agent.Annotations) + unmarshalJSON(appliedConfig, &agent.AppliedConfig) + unmarshalJSON(ancestry.String, &agent.Ancestry) + if lastSeen.Valid { + agent.LastSeen = lastSeen.Time + } + if lastActivityEvent.Valid { + agent.LastActivityEvent = lastActivityEvent.Time + } + if deletedAt.Valid { + agent.DeletedAt = deletedAt.Time + } + if startedAt.Valid { + agent.StartedAt = startedAt.Time + } + if runtimeBrokerID.Valid { + agent.RuntimeBrokerID = runtimeBrokerID.String + } + if message.Valid { + agent.Message = message.String + } + if toolName.Valid { + agent.ToolName = toolName.String + } + + return agent, nil +} + +func (s *PostgresStore) UpdateAgent(ctx context.Context, agent *store.Agent) error { + agent.Updated = time.Now() + newVersion := agent.StateVersion + 1 + + result, err := s.db.ExecContext(ctx, ` + UPDATE agents SET + agent_id = $1, name = $2, template = $3, + labels = $4, annotations = $5, + phase = $6, activity = $7, tool_name = $8, + connection_state = $9, container_status = $10, runtime_state = $11, + stalled_from_activity = $12, + image = $13, detached = $14, runtime = $15, runtime_broker_id = $16, web_pty_enabled = $17, task_summary = $18, message = $19, + applied_config = $20, + updated_at = $21, last_seen = $22, last_activity_event = $23, deleted_at = $24, + owner_id = $25, visibility = $26, state_version = $27 + WHERE id = $28 AND state_version = $29 + `, + agent.Slug, agent.Name, agent.Template, + marshalJSON(agent.Labels), marshalJSON(agent.Annotations), + agent.Phase, agent.Activity, agent.ToolName, + agent.ConnectionState, agent.ContainerStatus, agent.RuntimeState, + agent.StalledFromActivity, + agent.Image, boolToInt(agent.Detached), agent.Runtime, nullableString(agent.RuntimeBrokerID), boolToInt(agent.WebPTYEnabled), agent.TaskSummary, agent.Message, + marshalJSON(agent.AppliedConfig), + agent.Updated, nullableTime(agent.LastSeen), nullableTime(agent.LastActivityEvent), nullableTime(agent.DeletedAt), + agent.OwnerID, agent.Visibility, newVersion, + agent.ID, agent.StateVersion, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + // Check if agent exists + var exists bool + s.db.QueryRowContext(ctx, "SELECT 1 FROM agents WHERE id = $1", agent.ID).Scan(&exists) + if !exists { + return store.ErrNotFound + } + return store.ErrVersionConflict + } + + agent.StateVersion = newVersion + return nil +} + +func (s *PostgresStore) DeleteAgent(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, "DELETE FROM agents WHERE id = $1", id) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) ListAgents(ctx context.Context, filter store.AgentFilter, opts store.ListOptions) (*store.ListResult[store.Agent], error) { + var conditions []string + var args []interface{} + + if len(filter.MemberOrOwnerProjectIDs) > 0 { + // Combine project_id membership with owner_id match using OR + placeholders := make([]string, len(filter.MemberOrOwnerProjectIDs)) + for i, id := range filter.MemberOrOwnerProjectIDs { + placeholders[i] = fmt.Sprintf("$%d", len(args)+1) + args = append(args, id) + } + orParts := []string{"project_id IN (" + strings.Join(placeholders, ",") + ")"} + if filter.OwnerID != "" { + orParts = append(orParts, fmt.Sprintf("owner_id = $%d", len(args)+1)) + args = append(args, filter.OwnerID) + } + conditions = append(conditions, "("+strings.Join(orParts, " OR ")+")") + } else if len(filter.MemberProjectIDs) > 0 { + placeholders := make([]string, len(filter.MemberProjectIDs)) + for i, id := range filter.MemberProjectIDs { + placeholders[i] = fmt.Sprintf("$%d", len(args)+1) + args = append(args, id) + } + conditions = append(conditions, "project_id IN ("+strings.Join(placeholders, ",")+")") + } else if filter.OwnerID != "" { + conditions = append(conditions, fmt.Sprintf("owner_id = $%d", len(args)+1)) + args = append(args, filter.OwnerID) + } + if filter.ExcludeOwnerID != "" { + conditions = append(conditions, fmt.Sprintf("owner_id != $%d", len(args)+1)) + args = append(args, filter.ExcludeOwnerID) + } + if filter.ProjectID != "" { + conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args)+1)) + args = append(args, filter.ProjectID) + } + if filter.RuntimeBrokerID != "" { + conditions = append(conditions, fmt.Sprintf("runtime_broker_id = $%d", len(args)+1)) + args = append(args, filter.RuntimeBrokerID) + } + if filter.Phase != "" { + conditions = append(conditions, fmt.Sprintf("phase = $%d", len(args)+1)) + args = append(args, filter.Phase) + } + if filter.AncestorID != "" { + conditions = append(conditions, fmt.Sprintf("EXISTS (SELECT 1 FROM json_array_elements_text(ancestry::json) AS e(value) WHERE e.value = $%d)", len(args)+1)) + args = append(args, filter.AncestorID) + } + + // Exclude soft-deleted agents unless explicitly requested + if !filter.IncludeDeleted { + conditions = append(conditions, "deleted_at IS NULL") + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + // Get total count + var totalCount int + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM agents %s", whereClause) + if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount); err != nil { + return nil, err + } + + // Apply pagination + limit := opts.Limit + if limit <= 0 { + limit = 50 + } + if limit > 200 { + limit = 200 + } + + query := fmt.Sprintf(` + SELECT id, agent_id, name, template, project_id, + labels, annotations, + phase, activity, tool_name, + connection_state, container_status, runtime_state, + stalled_from_activity, + current_turns, current_model_calls, + image, detached, runtime, runtime_broker_id, web_pty_enabled, task_summary, message, + applied_config, + created_at, updated_at, last_seen, last_activity_event, deleted_at, started_at, + created_by, owner_id, visibility, state_version, ancestry + FROM agents %s ORDER BY created_at DESC LIMIT $%d + `, whereClause, len(args)+1) + args = append(args, limit+1) // Fetch one extra to determine if there's a next page + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var agents []store.Agent + for rows.Next() { + var agent store.Agent + var labels, annotations, appliedConfig string + var lastSeen, lastActivityEvent, deletedAt, startedAt sql.NullTime + var runtimeBrokerID, message, toolName, ancestry sql.NullString + + if err := rows.Scan( + &agent.ID, &agent.Slug, &agent.Name, &agent.Template, &agent.ProjectID, + &labels, &annotations, + &agent.Phase, &agent.Activity, &toolName, + &agent.ConnectionState, &agent.ContainerStatus, &agent.RuntimeState, + &agent.StalledFromActivity, + &agent.CurrentTurns, &agent.CurrentModelCalls, + &agent.Image, &agent.Detached, &agent.Runtime, &runtimeBrokerID, &agent.WebPTYEnabled, &agent.TaskSummary, &message, + &appliedConfig, + &agent.Created, &agent.Updated, &lastSeen, &lastActivityEvent, &deletedAt, &startedAt, + &agent.CreatedBy, &agent.OwnerID, &agent.Visibility, &agent.StateVersion, &ancestry, + ); err != nil { + return nil, err + } + + unmarshalJSON(labels, &agent.Labels) + unmarshalJSON(annotations, &agent.Annotations) + unmarshalJSON(appliedConfig, &agent.AppliedConfig) + unmarshalJSON(ancestry.String, &agent.Ancestry) + if lastSeen.Valid { + agent.LastSeen = lastSeen.Time + } + if lastActivityEvent.Valid { + agent.LastActivityEvent = lastActivityEvent.Time + } + if deletedAt.Valid { + agent.DeletedAt = deletedAt.Time + } + if startedAt.Valid { + agent.StartedAt = startedAt.Time + } + if runtimeBrokerID.Valid { + agent.RuntimeBrokerID = runtimeBrokerID.String + } + if message.Valid { + agent.Message = message.String + } + if toolName.Valid { + agent.ToolName = toolName.String + } + + agents = append(agents, agent) + } + + result := &store.ListResult[store.Agent]{ + Items: agents, + TotalCount: totalCount, + } + + // Handle pagination + if len(agents) > limit { + result.Items = agents[:limit] + result.NextCursor = agents[limit-1].ID + } + + return result, nil +} + +func (s *PostgresStore) UpdateAgentStatus(ctx context.Context, id string, su store.AgentStatusUpdate) error { + now := time.Now() + + // When activity is being updated to something other than "executing", + // clear tool_name (it's only meaningful during execution). + // We signal this by setting the activity-provided flag. + activityProvided := su.Activity != "" + + // Prepare nullable values for limits tracking fields + var currentTurnsProvided bool + var currentTurnsVal int + if su.CurrentTurns != nil { + currentTurnsProvided = true + currentTurnsVal = *su.CurrentTurns + } + var currentModelCallsProvided bool + var currentModelCallsVal int + if su.CurrentModelCalls != nil { + currentModelCallsProvided = true + currentModelCallsVal = *su.CurrentModelCalls + } + + result, err := s.db.ExecContext(ctx, ` + UPDATE agents SET + phase = COALESCE(NULLIF($1, ''), phase), + activity = CASE WHEN $2 != '' THEN + CASE WHEN phase = 'stopped' + AND activity IN ('crashed', 'limits_exceeded') + AND $3 NOT IN ('crashed', 'limits_exceeded') + THEN activity ELSE $4 END + ELSE activity END, + tool_name = CASE WHEN $5 THEN $6 ELSE tool_name END, + message = COALESCE(NULLIF($7, ''), message), + connection_state = COALESCE(NULLIF($8, ''), connection_state), + container_status = COALESCE(NULLIF($9, ''), container_status), + runtime_state = COALESCE(NULLIF($10, ''), runtime_state), + task_summary = COALESCE(NULLIF($11, ''), task_summary), + stalled_from_activity = CASE WHEN $12 != '' THEN '' ELSE stalled_from_activity END, + last_activity_event = CASE WHEN $13 != '' THEN $14 ELSE last_activity_event END, + current_turns = CASE WHEN $15 THEN $16 ELSE current_turns END, + current_model_calls = CASE WHEN $17 THEN $18 ELSE current_model_calls END, + started_at = COALESCE(NULLIF($19, ''), started_at), + updated_at = $20, + last_seen = $21 + WHERE id = $22 + `, + su.Phase, + su.Activity, su.Activity, su.Activity, + activityProvided, su.ToolName, + su.Message, su.ConnectionState, su.ContainerStatus, + su.RuntimeState, su.TaskSummary, + su.Activity, + su.Activity, now, + currentTurnsProvided, currentTurnsVal, + currentModelCallsProvided, currentModelCallsVal, + su.StartedAt, + now, now, id, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) PurgeDeletedAgents(ctx context.Context, cutoff time.Time) (int, error) { + result, err := s.db.ExecContext(ctx, + "DELETE FROM agents WHERE deleted_at IS NOT NULL AND deleted_at < $1", + cutoff, + ) + if err != nil { + return 0, err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return 0, err + } + return int(rowsAffected), nil +} + +func (s *PostgresStore) MarkStaleAgentsOffline(ctx context.Context, threshold time.Time) ([]store.Agent, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + now := time.Now() + + // Update stale agents to offline activity. + // Only affects agents that: + // - Have reported at least one heartbeat (last_seen IS NOT NULL) + // - Are in the running phase + // - Are not already in a terminal/sticky activity (completed, limits_exceeded, offline) + _, err = tx.ExecContext(ctx, ` + UPDATE agents SET + activity = 'offline', + updated_at = $1 + WHERE last_seen < $2 + AND last_seen IS NOT NULL + AND phase = 'running' + AND activity NOT IN ('completed', 'limits_exceeded', 'blocked', 'offline') + `, now, threshold) + if err != nil { + return nil, err + } + + // Fetch the agents that were just updated. + rows, err := tx.QueryContext(ctx, ` + SELECT id, agent_id, name, template, project_id, + labels, annotations, + phase, activity, tool_name, + connection_state, container_status, runtime_state, + stalled_from_activity, + current_turns, current_model_calls, + image, detached, runtime, runtime_broker_id, web_pty_enabled, task_summary, message, + applied_config, + created_at, updated_at, last_seen, last_activity_event, deleted_at, started_at, + created_by, owner_id, visibility, state_version, ancestry + FROM agents + WHERE activity = 'offline' AND updated_at = $1 + AND last_seen < $2 + AND last_seen IS NOT NULL + AND phase = 'running' + `, now, threshold) + if err != nil { + return nil, err + } + defer rows.Close() + + var agents []store.Agent + for rows.Next() { + var agent store.Agent + var labels, annotations, appliedConfig string + var lastSeen, lastActivityEvent, deletedAt, startedAt sql.NullTime + var runtimeBrokerID, message, toolName, ancestry sql.NullString + + if err := rows.Scan( + &agent.ID, &agent.Slug, &agent.Name, &agent.Template, &agent.ProjectID, + &labels, &annotations, + &agent.Phase, &agent.Activity, &toolName, + &agent.ConnectionState, &agent.ContainerStatus, &agent.RuntimeState, + &agent.StalledFromActivity, + &agent.CurrentTurns, &agent.CurrentModelCalls, + &agent.Image, &agent.Detached, &agent.Runtime, &runtimeBrokerID, &agent.WebPTYEnabled, &agent.TaskSummary, &message, + &appliedConfig, + &agent.Created, &agent.Updated, &lastSeen, &lastActivityEvent, &deletedAt, &startedAt, + &agent.CreatedBy, &agent.OwnerID, &agent.Visibility, &agent.StateVersion, &ancestry, + ); err != nil { + return nil, err + } + + unmarshalJSON(labels, &agent.Labels) + unmarshalJSON(annotations, &agent.Annotations) + unmarshalJSON(appliedConfig, &agent.AppliedConfig) + unmarshalJSON(ancestry.String, &agent.Ancestry) + if lastSeen.Valid { + agent.LastSeen = lastSeen.Time + } + if lastActivityEvent.Valid { + agent.LastActivityEvent = lastActivityEvent.Time + } + if deletedAt.Valid { + agent.DeletedAt = deletedAt.Time + } + if startedAt.Valid { + agent.StartedAt = startedAt.Time + } + if runtimeBrokerID.Valid { + agent.RuntimeBrokerID = runtimeBrokerID.String + } + if message.Valid { + agent.Message = message.String + } + if toolName.Valid { + agent.ToolName = toolName.String + } + + agents = append(agents, agent) + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + return agents, nil +} + +func (s *PostgresStore) MarkStalledAgents(ctx context.Context, activityThreshold, heartbeatRecency time.Time) ([]store.Agent, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + now := time.Now() + + // Update agents to stalled activity. + // Only affects agents that: + // - Have a stale last_activity_event (older than activityThreshold) + // - Have a recent heartbeat (last_seen >= heartbeatRecency) — process is alive + // - Are in the running phase + // - Are not already in a terminal/sticky/waiting activity or already stalled/offline + _, err = tx.ExecContext(ctx, ` + UPDATE agents SET + stalled_from_activity = activity, + activity = 'stalled', + updated_at = $1 + WHERE last_activity_event < $2 + AND last_activity_event IS NOT NULL + AND last_seen >= $3 + AND last_seen IS NOT NULL + AND phase = 'running' + AND activity NOT IN ('completed', 'limits_exceeded', 'blocked', 'stalled', 'offline', 'waiting_for_input') + `, now, activityThreshold, heartbeatRecency) + if err != nil { + return nil, err + } + + // Fetch the agents that were just updated. + rows, err := tx.QueryContext(ctx, ` + SELECT id, agent_id, name, template, project_id, + labels, annotations, + phase, activity, tool_name, + connection_state, container_status, runtime_state, + stalled_from_activity, + current_turns, current_model_calls, + image, detached, runtime, runtime_broker_id, web_pty_enabled, task_summary, message, + applied_config, + created_at, updated_at, last_seen, last_activity_event, deleted_at, started_at, + created_by, owner_id, visibility, state_version, ancestry + FROM agents + WHERE activity = 'stalled' AND updated_at = $1 + AND last_activity_event < $2 + AND last_activity_event IS NOT NULL + AND last_seen >= $3 + AND last_seen IS NOT NULL + AND phase = 'running' + `, now, activityThreshold, heartbeatRecency) + if err != nil { + return nil, err + } + defer rows.Close() + + var agents []store.Agent + for rows.Next() { + var agent store.Agent + var labels, annotations, appliedConfig string + var lastSeen, lastActivityEvent, deletedAt, startedAt sql.NullTime + var runtimeBrokerID, message, toolName, ancestry sql.NullString + + if err := rows.Scan( + &agent.ID, &agent.Slug, &agent.Name, &agent.Template, &agent.ProjectID, + &labels, &annotations, + &agent.Phase, &agent.Activity, &toolName, + &agent.ConnectionState, &agent.ContainerStatus, &agent.RuntimeState, + &agent.StalledFromActivity, + &agent.CurrentTurns, &agent.CurrentModelCalls, + &agent.Image, &agent.Detached, &agent.Runtime, &runtimeBrokerID, &agent.WebPTYEnabled, &agent.TaskSummary, &message, + &appliedConfig, + &agent.Created, &agent.Updated, &lastSeen, &lastActivityEvent, &deletedAt, &startedAt, + &agent.CreatedBy, &agent.OwnerID, &agent.Visibility, &agent.StateVersion, &ancestry, + ); err != nil { + return nil, err + } + + unmarshalJSON(labels, &agent.Labels) + unmarshalJSON(annotations, &agent.Annotations) + unmarshalJSON(appliedConfig, &agent.AppliedConfig) + unmarshalJSON(ancestry.String, &agent.Ancestry) + if lastSeen.Valid { + agent.LastSeen = lastSeen.Time + } + if lastActivityEvent.Valid { + agent.LastActivityEvent = lastActivityEvent.Time + } + if deletedAt.Valid { + agent.DeletedAt = deletedAt.Time + } + if startedAt.Valid { + agent.StartedAt = startedAt.Time + } + if runtimeBrokerID.Valid { + agent.RuntimeBrokerID = runtimeBrokerID.String + } + if message.Valid { + agent.Message = message.String + } + if toolName.Valid { + agent.ToolName = toolName.String + } + + agents = append(agents, agent) + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + return agents, nil +} diff --git a/pkg/store/postgres/allowlist.go b/pkg/store/postgres/allowlist.go new file mode 100644 index 00000000..5e8fc6a8 --- /dev/null +++ b/pkg/store/postgres/allowlist.go @@ -0,0 +1,330 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +func (s *PostgresStore) AddAllowListEntry(ctx context.Context, entry *store.AllowListEntry) error { + if entry.Created.IsZero() { + entry.Created = time.Now() + } + entry.Email = strings.ToLower(entry.Email) + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO allow_list (id, email, note, added_by, invite_id, created) + VALUES ($1, $2, $3, $4, $5, $6) + `, entry.ID, entry.Email, entry.Note, entry.AddedBy, entry.InviteID, entry.Created) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *PostgresStore) RemoveAllowListEntry(ctx context.Context, email string) error { + result, err := s.db.ExecContext(ctx, "DELETE FROM allow_list WHERE email = $1", strings.ToLower(email)) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) GetAllowListEntry(ctx context.Context, email string) (*store.AllowListEntry, error) { + entry := &store.AllowListEntry{} + err := s.db.QueryRowContext(ctx, ` + SELECT id, email, note, added_by, invite_id, created + FROM allow_list WHERE email = $1 + `, strings.ToLower(email)).Scan( + &entry.ID, &entry.Email, &entry.Note, &entry.AddedBy, &entry.InviteID, &entry.Created, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + return entry, nil +} + +func (s *PostgresStore) ListAllowListEntries(ctx context.Context, opts store.ListOptions) (*store.ListResult[store.AllowListEntry], error) { + var totalCount int + if err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM allow_list").Scan(&totalCount); err != nil { + return nil, err + } + + limit := opts.Limit + if limit <= 0 { + limit = 50 + } + + var conditions []string + var args []interface{} + + if opts.Cursor != "" { + var cursorCreated time.Time + if err := s.db.QueryRowContext(ctx, "SELECT created FROM allow_list WHERE id = $1", opts.Cursor).Scan(&cursorCreated); err != nil { + return nil, fmt.Errorf("invalid cursor: %w", err) + } + conditions = append(conditions, `(created < $1 OR (created = $2 AND id < $3))`) + args = append(args, cursorCreated, cursorCreated, opts.Cursor) + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + query := fmt.Sprintf(` + SELECT id, email, note, added_by, invite_id, created + FROM allow_list %s ORDER BY created DESC, id DESC LIMIT $%d + `, whereClause, len(args)+1) + args = append(args, limit+1) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var entries []store.AllowListEntry + for rows.Next() { + var entry store.AllowListEntry + if err := rows.Scan(&entry.ID, &entry.Email, &entry.Note, &entry.AddedBy, &entry.InviteID, &entry.Created); err != nil { + return nil, err + } + entries = append(entries, entry) + } + if err := rows.Err(); err != nil { + return nil, err + } + if entries == nil { + entries = []store.AllowListEntry{} + } + + var nextCursor string + if len(entries) > limit { + nextCursor = entries[limit-1].ID + entries = entries[:limit] + } + + return &store.ListResult[store.AllowListEntry]{ + Items: entries, + TotalCount: totalCount, + NextCursor: nextCursor, + }, nil +} + +func (s *PostgresStore) IsEmailAllowListed(ctx context.Context, email string) (bool, error) { + var count int + err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM allow_list WHERE email = $1", strings.ToLower(email)).Scan(&count) + if err != nil { + return false, err + } + return count > 0, nil +} + +func (s *PostgresStore) UpdateAllowListEntryInviteID(ctx context.Context, email string, inviteID string) error { + result, err := s.db.ExecContext(ctx, + "UPDATE allow_list SET invite_id = $1 WHERE email = $2", + inviteID, strings.ToLower(email)) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) ListAllowListEntriesWithInvites(ctx context.Context, opts store.ListOptions) (*store.ListResult[store.AllowListEntryWithInvite], error) { + var totalCount int + if err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM allow_list").Scan(&totalCount); err != nil { + return nil, err + } + + limit := opts.Limit + if limit <= 0 { + limit = 50 + } + + var conditions []string + var args []interface{} + + if opts.Cursor != "" { + var cursorCreated time.Time + if err := s.db.QueryRowContext(ctx, "SELECT created FROM allow_list WHERE id = $1", opts.Cursor).Scan(&cursorCreated); err != nil { + return nil, fmt.Errorf("invalid cursor: %w", err) + } + conditions = append(conditions, `(a.created < $1 OR (a.created = $2 AND a.id < $3))`) + args = append(args, cursorCreated, cursorCreated, opts.Cursor) + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + query := fmt.Sprintf(` + SELECT a.id, a.email, a.note, a.added_by, a.invite_id, a.created, + i.code_prefix, i.max_uses, i.use_count, i.expires_at, i.revoked + FROM allow_list a + LEFT JOIN invite_codes i ON a.invite_id = i.id AND a.invite_id != '' + %s ORDER BY a.created DESC, a.id DESC LIMIT $%d + `, whereClause, len(args)+1) + args = append(args, limit+1) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var entries []store.AllowListEntryWithInvite + for rows.Next() { + var entry store.AllowListEntryWithInvite + var codePrefix sql.NullString + var maxUses, useCount, revoked sql.NullInt64 + var expiresAt sql.NullTime + if err := rows.Scan( + &entry.ID, &entry.Email, &entry.Note, &entry.AddedBy, &entry.InviteID, &entry.Created, + &codePrefix, &maxUses, &useCount, &expiresAt, &revoked, + ); err != nil { + return nil, err + } + if codePrefix.Valid { + entry.InviteCodePrefix = codePrefix.String + } + if maxUses.Valid { + entry.InviteMaxUses = int(maxUses.Int64) + } + if useCount.Valid { + entry.InviteUseCount = int(useCount.Int64) + } + if expiresAt.Valid { + entry.InviteExpiresAt = expiresAt.Time + } + if revoked.Valid { + entry.InviteRevoked = revoked.Int64 != 0 + } + entries = append(entries, entry) + } + if err := rows.Err(); err != nil { + return nil, err + } + if entries == nil { + entries = []store.AllowListEntryWithInvite{} + } + + var nextCursor string + if len(entries) > limit { + nextCursor = entries[limit-1].ID + entries = entries[:limit] + } + + return &store.ListResult[store.AllowListEntryWithInvite]{ + Items: entries, + TotalCount: totalCount, + NextCursor: nextCursor, + }, nil +} + +func (s *PostgresStore) BulkAddAllowListEntries(ctx context.Context, entries []*store.AllowListEntry) (int, int, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return 0, 0, err + } + defer tx.Rollback() + + stmt, err := tx.PrepareContext(ctx, ` + INSERT INTO allow_list (id, email, note, added_by, invite_id, created) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (email) DO NOTHING + `) + if err != nil { + return 0, 0, err + } + defer stmt.Close() + + added := 0 + skipped := 0 + now := time.Now() + + for _, entry := range entries { + entry.Email = strings.ToLower(entry.Email) + if entry.Created.IsZero() { + entry.Created = now + } + result, err := stmt.ExecContext(ctx, entry.ID, entry.Email, entry.Note, entry.AddedBy, entry.InviteID, entry.Created) + if err != nil { + return added, skipped, err + } + rows, _ := result.RowsAffected() + if rows > 0 { + added++ + } else { + skipped++ + } + } + + if err := tx.Commit(); err != nil { + return 0, 0, err + } + return added, skipped, nil +} + +func (s *PostgresStore) ListEmailDomains(ctx context.Context) ([]string, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT DISTINCT split_part(email, '@', 2) AS domain + FROM users + WHERE email LIKE '%@%' + ORDER BY domain + `) + if err != nil { + return nil, err + } + defer rows.Close() + + var domains []string + for rows.Next() { + var domain string + if err := rows.Scan(&domain); err != nil { + return nil, err + } + domains = append(domains, domain) + } + return domains, rows.Err() +} diff --git a/pkg/store/postgres/brokers.go b/pkg/store/postgres/brokers.go new file mode 100644 index 00000000..e46218cf --- /dev/null +++ b/pkg/store/postgres/brokers.go @@ -0,0 +1,274 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +func (s *PostgresStore) CreateRuntimeBroker(ctx context.Context, broker *store.RuntimeBroker) error { + now := time.Now() + broker.Created = now + broker.Updated = now + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO runtime_brokers ( + id, name, slug, type, mode, version, + status, connection_state, last_heartbeat, + capabilities, supported_harnesses, resources, runtimes, + labels, annotations, endpoint, + created_at, updated_at, created_by, auto_provide + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20) + `, + broker.ID, broker.Name, broker.Slug, "", "", broker.Version, + broker.Status, broker.ConnectionState, broker.LastHeartbeat, + marshalJSON(broker.Capabilities), "[]", + "{}", marshalJSON(broker.Profiles), + marshalJSON(broker.Labels), marshalJSON(broker.Annotations), broker.Endpoint, + broker.Created, broker.Updated, nullableString(broker.CreatedBy), boolToInt(broker.AutoProvide), + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *PostgresStore) GetRuntimeBroker(ctx context.Context, id string) (*store.RuntimeBroker, error) { + broker := &store.RuntimeBroker{} + var capabilities, profiles, labels, annotations string + var brokerType, brokerMode, harnesses, resources string // unused columns kept for schema compatibility + var lastHeartbeat sql.NullTime + var createdBy sql.NullString + + err := s.db.QueryRowContext(ctx, ` + SELECT id, name, slug, type, mode, version, + status, connection_state, last_heartbeat, + capabilities, supported_harnesses, resources, runtimes, + labels, annotations, endpoint, + created_at, updated_at, created_by, auto_provide + FROM runtime_brokers WHERE id = $1 + `, id).Scan( + &broker.ID, &broker.Name, &broker.Slug, &brokerType, &brokerMode, &broker.Version, + &broker.Status, &broker.ConnectionState, &lastHeartbeat, + &capabilities, &harnesses, &resources, &profiles, + &labels, &annotations, &broker.Endpoint, + &broker.Created, &broker.Updated, &createdBy, &broker.AutoProvide, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + if lastHeartbeat.Valid { + broker.LastHeartbeat = lastHeartbeat.Time + } + if createdBy.Valid { + broker.CreatedBy = createdBy.String + } + unmarshalJSON(capabilities, &broker.Capabilities) + unmarshalJSON(profiles, &broker.Profiles) + unmarshalJSON(labels, &broker.Labels) + unmarshalJSON(annotations, &broker.Annotations) + + return broker, nil +} + +func (s *PostgresStore) GetRuntimeBrokerByName(ctx context.Context, name string) (*store.RuntimeBroker, error) { + var id string + err := s.db.QueryRowContext(ctx, "SELECT id FROM runtime_brokers WHERE LOWER(name) = LOWER($1)", name).Scan(&id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + return s.GetRuntimeBroker(ctx, id) +} + +func (s *PostgresStore) UpdateRuntimeBroker(ctx context.Context, broker *store.RuntimeBroker) error { + broker.Updated = time.Now() + + result, err := s.db.ExecContext(ctx, ` + UPDATE runtime_brokers SET + name = $1, slug = $2, type = $3, version = $4, + status = $5, connection_state = $6, last_heartbeat = $7, + capabilities = $8, supported_harnesses = $9, resources = $10, runtimes = $11, + labels = $12, annotations = $13, endpoint = $14, + updated_at = $15, auto_provide = $16 + WHERE id = $17 + `, + broker.Name, broker.Slug, "", broker.Version, + broker.Status, broker.ConnectionState, broker.LastHeartbeat, + marshalJSON(broker.Capabilities), "[]", + "{}", marshalJSON(broker.Profiles), + marshalJSON(broker.Labels), marshalJSON(broker.Annotations), broker.Endpoint, + broker.Updated, boolToInt(broker.AutoProvide), + broker.ID, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) DeleteRuntimeBroker(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, "DELETE FROM runtime_brokers WHERE id = $1", id) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) ListRuntimeBrokers(ctx context.Context, filter store.RuntimeBrokerFilter, opts store.ListOptions) (*store.ListResult[store.RuntimeBroker], error) { + var conditions []string + var args []interface{} + + if filter.Status != "" { + conditions = append(conditions, fmt.Sprintf("status = $%d", len(args)+1)) + args = append(args, filter.Status) + } + if filter.ProjectID != "" { + conditions = append(conditions, fmt.Sprintf("(id IN (SELECT broker_id FROM project_contributors WHERE project_id = $%d) OR auto_provide = 1)", len(args)+1)) + args = append(args, filter.ProjectID) + } + if filter.Name != "" { + conditions = append(conditions, fmt.Sprintf("LOWER(name) = LOWER($%d)", len(args)+1)) + args = append(args, filter.Name) + } + if filter.AutoProvide != nil { + conditions = append(conditions, fmt.Sprintf("auto_provide = $%d", len(args)+1)) + args = append(args, boolToInt(*filter.AutoProvide)) + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + var totalCount int + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM runtime_brokers %s", whereClause) + if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount); err != nil { + return nil, err + } + + limit := opts.Limit + if limit <= 0 { + limit = 50 + } + + query := fmt.Sprintf(` + SELECT id, name, slug, type, mode, version, + status, connection_state, last_heartbeat, + capabilities, supported_harnesses, resources, runtimes, + labels, annotations, endpoint, + created_at, updated_at, created_by, auto_provide + FROM runtime_brokers %s ORDER BY created_at DESC LIMIT $%d + `, whereClause, len(args)+1) + args = append(args, limit) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var hosts []store.RuntimeBroker + for rows.Next() { + var broker store.RuntimeBroker + var capabilities, profiles, labels, annotations string + var brokerType, brokerMode, harnesses, resources string // unused columns kept for schema compatibility + var lastHeartbeat sql.NullTime + var createdBy sql.NullString + + if err := rows.Scan( + &broker.ID, &broker.Name, &broker.Slug, &brokerType, &brokerMode, &broker.Version, + &broker.Status, &broker.ConnectionState, &lastHeartbeat, + &capabilities, &harnesses, &resources, &profiles, + &labels, &annotations, &broker.Endpoint, + &broker.Created, &broker.Updated, &createdBy, &broker.AutoProvide, + ); err != nil { + return nil, err + } + + if lastHeartbeat.Valid { + broker.LastHeartbeat = lastHeartbeat.Time + } + if createdBy.Valid { + broker.CreatedBy = createdBy.String + } + unmarshalJSON(capabilities, &broker.Capabilities) + unmarshalJSON(profiles, &broker.Profiles) + unmarshalJSON(labels, &broker.Labels) + unmarshalJSON(annotations, &broker.Annotations) + + hosts = append(hosts, broker) + } + + return &store.ListResult[store.RuntimeBroker]{ + Items: hosts, + TotalCount: totalCount, + }, nil +} + +func (s *PostgresStore) UpdateRuntimeBrokerHeartbeat(ctx context.Context, id string, status string) error { + now := time.Now() + + result, err := s.db.ExecContext(ctx, ` + UPDATE runtime_brokers SET + status = $1, + last_heartbeat = $2, + updated_at = $3 + WHERE id = $4 + `, status, now, now, id) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} diff --git a/pkg/store/postgres/brokersecret.go b/pkg/store/postgres/brokersecret.go new file mode 100644 index 00000000..65196d19 --- /dev/null +++ b/pkg/store/postgres/brokersecret.go @@ -0,0 +1,290 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package postgres provides a PostgreSQL implementation of the Store interface. +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +// ============================================================================ +// Broker Secret Operations +// ============================================================================ + +// CreateBrokerSecret creates a new broker secret record. +func (s *PostgresStore) CreateBrokerSecret(ctx context.Context, secret *store.BrokerSecret) error { + if secret.BrokerID == "" { + return store.ErrInvalidInput + } + + now := time.Now() + if secret.CreatedAt.IsZero() { + secret.CreatedAt = now + } + if secret.Algorithm == "" { + secret.Algorithm = store.BrokerSecretAlgorithmHMACSHA256 + } + if secret.Status == "" { + secret.Status = store.BrokerSecretStatusActive + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO broker_secrets ( + broker_id, secret_key, algorithm, + created_at, rotated_at, expires_at, status + ) VALUES ($1, $2, $3, $4, $5, $6, $7) + `, + secret.BrokerID, secret.SecretKey, secret.Algorithm, + secret.CreatedAt, nullableTime(secret.RotatedAt), nullableTime(secret.ExpiresAt), secret.Status, + ) + if err != nil { + if strings.Contains(err.Error(), "unique constraint") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + if strings.Contains(err.Error(), "foreign key constraint") { + return fmt.Errorf("broker %s does not exist: %w", secret.BrokerID, store.ErrNotFound) + } + return err + } + return nil +} + +// GetBrokerSecret retrieves a broker secret by broker ID. +func (s *PostgresStore) GetBrokerSecret(ctx context.Context, brokerID string) (*store.BrokerSecret, error) { + secret := &store.BrokerSecret{} + var rotatedAt, expiresAt sql.NullTime + + err := s.db.QueryRowContext(ctx, ` + SELECT broker_id, secret_key, algorithm, + created_at, rotated_at, expires_at, status + FROM broker_secrets WHERE broker_id = $1 + `, brokerID).Scan( + &secret.BrokerID, &secret.SecretKey, &secret.Algorithm, + &secret.CreatedAt, &rotatedAt, &expiresAt, &secret.Status, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + if rotatedAt.Valid { + secret.RotatedAt = rotatedAt.Time + } + if expiresAt.Valid { + secret.ExpiresAt = expiresAt.Time + } + + return secret, nil +} + +// GetActiveSecrets retrieves all active and deprecated secrets for a broker. +// This supports dual-secret validation during rotation grace periods. +func (s *PostgresStore) GetActiveSecrets(ctx context.Context, brokerID string) ([]*store.BrokerSecret, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT broker_id, secret_key, algorithm, + created_at, rotated_at, expires_at, status + FROM broker_secrets + WHERE broker_id = $1 AND status IN ($2, $3) + ORDER BY created_at DESC + `, brokerID, store.BrokerSecretStatusActive, store.BrokerSecretStatusDeprecated) + if err != nil { + return nil, err + } + defer rows.Close() + + var secrets []*store.BrokerSecret + for rows.Next() { + secret := &store.BrokerSecret{} + var rotatedAt, expiresAt sql.NullTime + + if err := rows.Scan( + &secret.BrokerID, &secret.SecretKey, &secret.Algorithm, + &secret.CreatedAt, &rotatedAt, &expiresAt, &secret.Status, + ); err != nil { + return nil, err + } + + if rotatedAt.Valid { + secret.RotatedAt = rotatedAt.Time + } + if expiresAt.Valid { + secret.ExpiresAt = expiresAt.Time + } + + secrets = append(secrets, secret) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return secrets, nil +} + +// UpdateBrokerSecret updates an existing broker secret. +func (s *PostgresStore) UpdateBrokerSecret(ctx context.Context, secret *store.BrokerSecret) error { + result, err := s.db.ExecContext(ctx, ` + UPDATE broker_secrets SET + secret_key = $1, + algorithm = $2, + rotated_at = $3, + expires_at = $4, + status = $5 + WHERE broker_id = $6 + `, + secret.SecretKey, secret.Algorithm, + nullableTime(secret.RotatedAt), nullableTime(secret.ExpiresAt), secret.Status, + secret.BrokerID, + ) + if err != nil { + return err + } + + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +// DeleteBrokerSecret removes a broker secret. +func (s *PostgresStore) DeleteBrokerSecret(ctx context.Context, brokerID string) error { + result, err := s.db.ExecContext(ctx, ` + DELETE FROM broker_secrets WHERE broker_id = $1 + `, brokerID) + if err != nil { + return err + } + + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +// ============================================================================ +// Broker Join Token Operations +// ============================================================================ + +// CreateJoinToken creates a new join token for broker registration. +func (s *PostgresStore) CreateJoinToken(ctx context.Context, token *store.BrokerJoinToken) error { + if token.BrokerID == "" || token.TokenHash == "" { + return store.ErrInvalidInput + } + + now := time.Now() + if token.CreatedAt.IsZero() { + token.CreatedAt = now + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO broker_join_tokens ( + broker_id, token_hash, expires_at, created_at, created_by + ) VALUES ($1, $2, $3, $4, $5) + `, + token.BrokerID, token.TokenHash, token.ExpiresAt, token.CreatedAt, token.CreatedBy, + ) + if err != nil { + if strings.Contains(err.Error(), "unique constraint") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + if strings.Contains(err.Error(), "foreign key constraint") { + return store.ErrNotFound + } + return err + } + return nil +} + +// GetJoinToken retrieves a join token by token hash. +func (s *PostgresStore) GetJoinToken(ctx context.Context, tokenHash string) (*store.BrokerJoinToken, error) { + token := &store.BrokerJoinToken{} + + err := s.db.QueryRowContext(ctx, ` + SELECT broker_id, token_hash, expires_at, created_at, created_by + FROM broker_join_tokens WHERE token_hash = $1 + `, tokenHash).Scan( + &token.BrokerID, &token.TokenHash, &token.ExpiresAt, &token.CreatedAt, &token.CreatedBy, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + return token, nil +} + +// GetJoinTokenByBrokerID retrieves a join token by broker ID. +func (s *PostgresStore) GetJoinTokenByBrokerID(ctx context.Context, brokerID string) (*store.BrokerJoinToken, error) { + token := &store.BrokerJoinToken{} + + err := s.db.QueryRowContext(ctx, ` + SELECT broker_id, token_hash, expires_at, created_at, created_by + FROM broker_join_tokens WHERE broker_id = $1 + `, brokerID).Scan( + &token.BrokerID, &token.TokenHash, &token.ExpiresAt, &token.CreatedAt, &token.CreatedBy, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + return token, nil +} + +// DeleteJoinToken removes a join token by broker ID. +func (s *PostgresStore) DeleteJoinToken(ctx context.Context, brokerID string) error { + result, err := s.db.ExecContext(ctx, ` + DELETE FROM broker_join_tokens WHERE broker_id = $1 + `, brokerID) + if err != nil { + return err + } + + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +// CleanExpiredJoinTokens removes all expired join tokens. +func (s *PostgresStore) CleanExpiredJoinTokens(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, ` + DELETE FROM broker_join_tokens WHERE expires_at < $1 + `, time.Now()) + return err +} diff --git a/pkg/store/postgres/driver.go b/pkg/store/postgres/driver.go new file mode 100644 index 00000000..27137177 --- /dev/null +++ b/pkg/store/postgres/driver.go @@ -0,0 +1,19 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !no_postgres + +package postgres + +import _ "github.com/lib/pq" // Postgres driver (database/sql driver name: postgres) diff --git a/pkg/store/postgres/envvars.go b/pkg/store/postgres/envvars.go new file mode 100644 index 00000000..3781e1f2 --- /dev/null +++ b/pkg/store/postgres/envvars.go @@ -0,0 +1,199 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +func (s *PostgresStore) CreateEnvVar(ctx context.Context, envVar *store.EnvVar) error { + now := time.Now() + envVar.Created = now + envVar.Updated = now + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO env_vars (id, key, value, scope, scope_id, description, sensitive, injection_mode, secret, created_at, updated_at, created_by) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + `, + envVar.ID, envVar.Key, envVar.Value, envVar.Scope, envVar.ScopeID, + envVar.Description, boolToInt(envVar.Sensitive), envVar.InjectionMode, boolToInt(envVar.Secret), + envVar.Created, envVar.Updated, envVar.CreatedBy, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *PostgresStore) GetEnvVar(ctx context.Context, key, scope, scopeID string) (*store.EnvVar, error) { + envVar := &store.EnvVar{} + + err := s.db.QueryRowContext(ctx, ` + SELECT id, key, value, scope, scope_id, description, sensitive, injection_mode, secret, created_at, updated_at, created_by + FROM env_vars WHERE key = $1 AND scope = $2 AND scope_id = $3 + `, key, scope, scopeID).Scan( + &envVar.ID, &envVar.Key, &envVar.Value, &envVar.Scope, &envVar.ScopeID, + &envVar.Description, &envVar.Sensitive, &envVar.InjectionMode, &envVar.Secret, + &envVar.Created, &envVar.Updated, &envVar.CreatedBy, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + return envVar, nil +} + +func (s *PostgresStore) UpdateEnvVar(ctx context.Context, envVar *store.EnvVar) error { + envVar.Updated = time.Now() + + result, err := s.db.ExecContext(ctx, ` + UPDATE env_vars SET + value = $1, description = $2, sensitive = $3, injection_mode = $4, secret = $5, updated_at = $6 + WHERE key = $7 AND scope = $8 AND scope_id = $9 + `, + envVar.Value, envVar.Description, boolToInt(envVar.Sensitive), envVar.InjectionMode, boolToInt(envVar.Secret), envVar.Updated, + envVar.Key, envVar.Scope, envVar.ScopeID, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) UpsertEnvVar(ctx context.Context, envVar *store.EnvVar) (bool, error) { + now := time.Now() + envVar.Updated = now + + // Check if it already exists + existing, err := s.GetEnvVar(ctx, envVar.Key, envVar.Scope, envVar.ScopeID) + if err != nil && err != store.ErrNotFound { + return false, err + } + + if existing != nil { + // Update existing + envVar.ID = existing.ID + envVar.Created = existing.Created + envVar.CreatedBy = existing.CreatedBy + if err := s.UpdateEnvVar(ctx, envVar); err != nil { + return false, err + } + return false, nil + } + + // Create new + envVar.Created = now + if err := s.CreateEnvVar(ctx, envVar); err != nil { + return false, err + } + return true, nil +} + +func (s *PostgresStore) DeleteEnvVar(ctx context.Context, key, scope, scopeID string) error { + result, err := s.db.ExecContext(ctx, "DELETE FROM env_vars WHERE key = $1 AND scope = $2 AND scope_id = $3", key, scope, scopeID) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) DeleteEnvVarsByScope(ctx context.Context, scope, scopeID string) (int, error) { + result, err := s.db.ExecContext(ctx, "DELETE FROM env_vars WHERE scope = $1 AND scope_id = $2", scope, scopeID) + if err != nil { + return 0, err + } + n, err := result.RowsAffected() + if err != nil { + return 0, err + } + return int(n), nil +} + +func (s *PostgresStore) ListEnvVars(ctx context.Context, filter store.EnvVarFilter) ([]store.EnvVar, error) { + var conditions []string + var args []interface{} + + if filter.Scope != "" { + conditions = append(conditions, fmt.Sprintf("scope = $%d", len(args)+1)) + args = append(args, filter.Scope) + } + if filter.ScopeID != "" { + conditions = append(conditions, fmt.Sprintf("scope_id = $%d", len(args)+1)) + args = append(args, filter.ScopeID) + } + if filter.Key != "" { + conditions = append(conditions, fmt.Sprintf("key = $%d", len(args)+1)) + args = append(args, filter.Key) + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + query := fmt.Sprintf(` + SELECT id, key, value, scope, scope_id, description, sensitive, injection_mode, secret, created_at, updated_at, created_by + FROM env_vars %s ORDER BY key + `, whereClause) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var envVars []store.EnvVar + for rows.Next() { + var envVar store.EnvVar + if err := rows.Scan( + &envVar.ID, &envVar.Key, &envVar.Value, &envVar.Scope, &envVar.ScopeID, + &envVar.Description, &envVar.Sensitive, &envVar.InjectionMode, &envVar.Secret, + &envVar.Created, &envVar.Updated, &envVar.CreatedBy, + ); err != nil { + return nil, err + } + envVars = append(envVars, envVar) + } + + return envVars, nil +} diff --git a/pkg/store/postgres/gcp_service_account.go b/pkg/store/postgres/gcp_service_account.go new file mode 100644 index 00000000..a8195f9b --- /dev/null +++ b/pkg/store/postgres/gcp_service_account.go @@ -0,0 +1,200 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +func (s *PostgresStore) CreateGCPServiceAccount(ctx context.Context, sa *store.GCPServiceAccount) error { + if sa.CreatedAt.IsZero() { + sa.CreatedAt = time.Now() + } + + scopesStr := strings.Join(sa.DefaultScopes, ",") + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO gcp_service_accounts (id, scope, scope_id, email, project_id, display_name, default_scopes, verified, verified_at, created_by, created_at, managed, managed_by) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)`, + sa.ID, sa.Scope, sa.ScopeID, sa.Email, sa.ProjectID, sa.DisplayName, + scopesStr, boolToInt(sa.Verified), nullableTime(sa.VerifiedAt), sa.CreatedBy, sa.CreatedAt, + boolToInt(sa.Managed), sa.ManagedBy, + ) + if err != nil { + if strings.Contains(err.Error(), "unique constraint") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *PostgresStore) GetGCPServiceAccount(ctx context.Context, id string) (*store.GCPServiceAccount, error) { + var sa store.GCPServiceAccount + var scopesStr string + var verifiedAt sql.NullTime + + err := s.db.QueryRowContext(ctx, ` + SELECT id, scope, scope_id, email, project_id, display_name, default_scopes, verified, verified_at, created_by, created_at, managed, managed_by + FROM gcp_service_accounts WHERE id = $1`, id, + ).Scan(&sa.ID, &sa.Scope, &sa.ScopeID, &sa.Email, &sa.ProjectID, &sa.DisplayName, + &scopesStr, &sa.Verified, &verifiedAt, &sa.CreatedBy, &sa.CreatedAt, + &sa.Managed, &sa.ManagedBy, + ) + if err == sql.ErrNoRows { + return nil, store.ErrNotFound + } + if err != nil { + return nil, err + } + + if scopesStr != "" { + sa.DefaultScopes = strings.Split(scopesStr, ",") + } + if verifiedAt.Valid { + sa.VerifiedAt = verifiedAt.Time + } + + return &sa, nil +} + +func (s *PostgresStore) UpdateGCPServiceAccount(ctx context.Context, sa *store.GCPServiceAccount) error { + scopesStr := strings.Join(sa.DefaultScopes, ",") + + result, err := s.db.ExecContext(ctx, ` + UPDATE gcp_service_accounts + SET email = $1, project_id = $2, display_name = $3, default_scopes = $4, verified = $5, verified_at = $6, managed = $7, managed_by = $8 + WHERE id = $9`, + sa.Email, sa.ProjectID, sa.DisplayName, scopesStr, boolToInt(sa.Verified), nullableTime(sa.VerifiedAt), + boolToInt(sa.Managed), sa.ManagedBy, sa.ID, + ) + if err != nil { + return err + } + rows, _ := result.RowsAffected() + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) DeleteGCPServiceAccount(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, `DELETE FROM gcp_service_accounts WHERE id = $1`, id) + if err != nil { + return err + } + rows, _ := result.RowsAffected() + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) ListGCPServiceAccounts(ctx context.Context, filter store.GCPServiceAccountFilter) ([]store.GCPServiceAccount, error) { + query := `SELECT id, scope, scope_id, email, project_id, display_name, default_scopes, verified, verified_at, created_by, created_at, managed, managed_by FROM gcp_service_accounts WHERE 1=1` + var args []interface{} + n := 1 + + if filter.Scope != "" { + query += fmt.Sprintf(` AND scope = $%d`, n) + args = append(args, filter.Scope) + n++ + } + if filter.ScopeID != "" { + query += fmt.Sprintf(` AND scope_id = $%d`, n) + args = append(args, filter.ScopeID) + n++ + } + if filter.Email != "" { + query += fmt.Sprintf(` AND email = $%d`, n) + args = append(args, filter.Email) + n++ + } + if filter.Managed != nil { + query += fmt.Sprintf(` AND managed = $%d`, n) + args = append(args, boolToInt(*filter.Managed)) + n++ + } + + query += ` ORDER BY created_at DESC` + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var results []store.GCPServiceAccount + for rows.Next() { + var sa store.GCPServiceAccount + var scopesStr string + var verifiedAt sql.NullTime + + if err := rows.Scan(&sa.ID, &sa.Scope, &sa.ScopeID, &sa.Email, &sa.ProjectID, &sa.DisplayName, + &scopesStr, &sa.Verified, &verifiedAt, &sa.CreatedBy, &sa.CreatedAt, + &sa.Managed, &sa.ManagedBy, + ); err != nil { + return nil, err + } + + if scopesStr != "" { + sa.DefaultScopes = strings.Split(scopesStr, ",") + } + if verifiedAt.Valid { + sa.VerifiedAt = verifiedAt.Time + } + + results = append(results, sa) + } + + return results, rows.Err() +} + +func (s *PostgresStore) CountGCPServiceAccounts(ctx context.Context, filter store.GCPServiceAccountFilter) (int, error) { + query := `SELECT COUNT(*) FROM gcp_service_accounts WHERE 1=1` + var args []interface{} + n := 1 + + if filter.Scope != "" { + query += fmt.Sprintf(` AND scope = $%d`, n) + args = append(args, filter.Scope) + n++ + } + if filter.ScopeID != "" { + query += fmt.Sprintf(` AND scope_id = $%d`, n) + args = append(args, filter.ScopeID) + n++ + } + if filter.Email != "" { + query += fmt.Sprintf(` AND email = $%d`, n) + args = append(args, filter.Email) + n++ + } + if filter.Managed != nil { + query += fmt.Sprintf(` AND managed = $%d`, n) + args = append(args, boolToInt(*filter.Managed)) + n++ + } + + var count int + err := s.db.QueryRowContext(ctx, query, args...).Scan(&count) + return count, err +} diff --git a/pkg/store/postgres/github_installation.go b/pkg/store/postgres/github_installation.go new file mode 100644 index 00000000..ad103d1b --- /dev/null +++ b/pkg/store/postgres/github_installation.go @@ -0,0 +1,185 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +func (s *PostgresStore) CreateGitHubInstallation(ctx context.Context, installation *store.GitHubInstallation) error { + if installation.CreatedAt.IsZero() { + installation.CreatedAt = time.Now() + } + if installation.UpdatedAt.IsZero() { + installation.UpdatedAt = installation.CreatedAt + } + if installation.Status == "" { + installation.Status = store.GitHubInstallationStatusActive + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO github_installations (installation_id, account_login, account_type, app_id, repositories, status, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT DO NOTHING`, + installation.InstallationID, installation.AccountLogin, installation.AccountType, + installation.AppID, marshalJSON(installation.Repositories), + installation.Status, installation.CreatedAt, installation.UpdatedAt, + ) + if err != nil { + return err + } + return nil +} + +func (s *PostgresStore) GetGitHubInstallation(ctx context.Context, installationID int64) (*store.GitHubInstallation, error) { + var inst store.GitHubInstallation + var repos string + + err := s.db.QueryRowContext(ctx, ` + SELECT installation_id, account_login, account_type, app_id, repositories, status, created_at, updated_at + FROM github_installations WHERE installation_id = $1`, installationID, + ).Scan(&inst.InstallationID, &inst.AccountLogin, &inst.AccountType, + &inst.AppID, &repos, &inst.Status, &inst.CreatedAt, &inst.UpdatedAt, + ) + if err == sql.ErrNoRows { + return nil, store.ErrNotFound + } + if err != nil { + return nil, err + } + + unmarshalJSON(repos, &inst.Repositories) + return &inst, nil +} + +func (s *PostgresStore) UpdateGitHubInstallation(ctx context.Context, installation *store.GitHubInstallation) error { + installation.UpdatedAt = time.Now() + + result, err := s.db.ExecContext(ctx, ` + UPDATE github_installations SET + account_login = $1, account_type = $2, app_id = $3, + repositories = $4, status = $5, updated_at = $6 + WHERE installation_id = $7`, + installation.AccountLogin, installation.AccountType, installation.AppID, + marshalJSON(installation.Repositories), installation.Status, installation.UpdatedAt, + installation.InstallationID, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) DeleteGitHubInstallation(ctx context.Context, installationID int64) error { + result, err := s.db.ExecContext(ctx, `DELETE FROM github_installations WHERE installation_id = $1`, installationID) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) GetInstallationForRepository(ctx context.Context, repoFullName string) (*store.GitHubInstallation, error) { + // Search active installations whose repositories JSON array contains the repo. + installations, err := s.ListGitHubInstallations(ctx, store.GitHubInstallationFilter{ + Status: store.GitHubInstallationStatusActive, + }) + if err != nil { + return nil, err + } + + for i := range installations { + for _, repo := range installations[i].Repositories { + if repo == repoFullName { + return &installations[i], nil + } + } + } + return nil, store.ErrNotFound +} + +func (s *PostgresStore) ListGitHubInstallations(ctx context.Context, filter store.GitHubInstallationFilter) ([]store.GitHubInstallation, error) { + query := "SELECT installation_id, account_login, account_type, app_id, repositories, status, created_at, updated_at FROM github_installations WHERE 1=1" + var args []interface{} + n := 1 + + if filter.AccountLogin != "" { + query += fmt.Sprintf(" AND account_login = $%d", n) + args = append(args, filter.AccountLogin) + n++ + } + if filter.Status != "" { + query += fmt.Sprintf(" AND status = $%d", n) + args = append(args, filter.Status) + n++ + } + if filter.AppID != 0 { + query += fmt.Sprintf(" AND app_id = $%d", n) + args = append(args, filter.AppID) + n++ + } + + query += " ORDER BY created_at ASC" + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var results []store.GitHubInstallation + for rows.Next() { + var inst store.GitHubInstallation + var repos string + + if err := rows.Scan(&inst.InstallationID, &inst.AccountLogin, &inst.AccountType, + &inst.AppID, &repos, &inst.Status, &inst.CreatedAt, &inst.UpdatedAt); err != nil { + return nil, err + } + + unmarshalJSON(repos, &inst.Repositories) + results = append(results, inst) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + // Ensure we never return nil slice (return empty slice instead) + if results == nil { + results = []store.GitHubInstallation{} + } + + return results, nil +} diff --git a/pkg/store/postgres/groups.go b/pkg/store/postgres/groups.go new file mode 100644 index 00000000..cde649e0 --- /dev/null +++ b/pkg/store/postgres/groups.go @@ -0,0 +1,555 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +func (s *PostgresStore) CreateGroup(ctx context.Context, group *store.Group) error { + now := time.Now() + group.Created = now + group.Updated = now + if group.GroupType == "" { + group.GroupType = store.GroupTypeExplicit + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO groups (id, name, slug, description, group_type, project_id, parent_id, labels, annotations, created_at, updated_at, created_by, owner_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + `, + group.ID, group.Name, group.Slug, group.Description, + group.GroupType, nullableString(group.ProjectID), + nullableString(group.ParentID), + marshalJSON(group.Labels), marshalJSON(group.Annotations), + group.Created, group.Updated, group.CreatedBy, group.OwnerID, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *PostgresStore) GetGroup(ctx context.Context, id string) (*store.Group, error) { + group := &store.Group{} + var labels, annotations string + var parentID, projectID sql.NullString + + err := s.db.QueryRowContext(ctx, ` + SELECT id, name, slug, description, group_type, project_id, parent_id, labels, annotations, created_at, updated_at, created_by, owner_id + FROM groups WHERE id = $1 + `, id).Scan( + &group.ID, &group.Name, &group.Slug, &group.Description, + &group.GroupType, &projectID, + &parentID, + &labels, &annotations, + &group.Created, &group.Updated, &group.CreatedBy, &group.OwnerID, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + if parentID.Valid { + group.ParentID = parentID.String + } + if projectID.Valid { + group.ProjectID = projectID.String + } + unmarshalJSON(labels, &group.Labels) + unmarshalJSON(annotations, &group.Annotations) + if group.GroupType == "" { + group.GroupType = store.GroupTypeExplicit + } + + return group, nil +} + +func (s *PostgresStore) GetGroupBySlug(ctx context.Context, slug string) (*store.Group, error) { + var id string + err := s.db.QueryRowContext(ctx, "SELECT id FROM groups WHERE slug = $1", slug).Scan(&id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + return s.GetGroup(ctx, id) +} + +func (s *PostgresStore) UpdateGroup(ctx context.Context, group *store.Group) error { + group.Updated = time.Now() + + result, err := s.db.ExecContext(ctx, ` + UPDATE groups SET + name = $1, slug = $2, description = $3, group_type = $4, project_id = $5, + parent_id = $6, labels = $7, annotations = $8, + updated_at = $9, owner_id = $10 + WHERE id = $11 + `, + group.Name, group.Slug, group.Description, + group.GroupType, nullableString(group.ProjectID), + nullableString(group.ParentID), + marshalJSON(group.Labels), marshalJSON(group.Annotations), + group.Updated, group.OwnerID, + group.ID, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) DeleteGroup(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, "DELETE FROM groups WHERE id = $1", id) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) ListGroups(ctx context.Context, filter store.GroupFilter, opts store.ListOptions) (*store.ListResult[store.Group], error) { + var conditions []string + var args []interface{} + + if filter.OwnerID != "" { + conditions = append(conditions, fmt.Sprintf("owner_id = $%d", len(args)+1)) + args = append(args, filter.OwnerID) + } + if filter.ParentID != "" { + conditions = append(conditions, fmt.Sprintf("parent_id = $%d", len(args)+1)) + args = append(args, filter.ParentID) + } + if filter.GroupType != "" { + conditions = append(conditions, fmt.Sprintf("group_type = $%d", len(args)+1)) + args = append(args, filter.GroupType) + } + if filter.ProjectID != "" { + conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args)+1)) + args = append(args, filter.ProjectID) + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + var totalCount int + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM groups %s", whereClause) + if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount); err != nil { + return nil, err + } + + limit := opts.Limit + if limit <= 0 { + limit = 50 + } + + query := fmt.Sprintf(` + SELECT id, name, slug, description, group_type, project_id, parent_id, labels, annotations, created_at, updated_at, created_by, owner_id + FROM groups %s ORDER BY created_at DESC LIMIT $%d + `, whereClause, len(args)+1) + args = append(args, limit) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var groups []store.Group + for rows.Next() { + var group store.Group + var labels, annotations string + var parentID, projectID sql.NullString + + if err := rows.Scan( + &group.ID, &group.Name, &group.Slug, &group.Description, + &group.GroupType, &projectID, + &parentID, + &labels, &annotations, + &group.Created, &group.Updated, &group.CreatedBy, &group.OwnerID, + ); err != nil { + return nil, err + } + + if parentID.Valid { + group.ParentID = parentID.String + } + if projectID.Valid { + group.ProjectID = projectID.String + } + unmarshalJSON(labels, &group.Labels) + unmarshalJSON(annotations, &group.Annotations) + if group.GroupType == "" { + group.GroupType = store.GroupTypeExplicit + } + + groups = append(groups, group) + } + + return &store.ListResult[store.Group]{ + Items: groups, + TotalCount: totalCount, + }, nil +} + +func (s *PostgresStore) AddGroupMember(ctx context.Context, member *store.GroupMember) error { + if member.AddedAt.IsZero() { + member.AddedAt = time.Now() + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO group_members (group_id, member_type, member_id, role, added_at, added_by) + VALUES ($1, $2, $3, $4, $5, $6) + `, + member.GroupID, member.MemberType, member.MemberID, member.Role, member.AddedAt, member.AddedBy, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") || strings.Contains(err.Error(), "PRIMARY KEY constraint failed") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *PostgresStore) UpdateGroupMemberRole(ctx context.Context, groupID, memberType, memberID, newRole string) error { + result, err := s.db.ExecContext(ctx, + `UPDATE group_members SET role = $1 WHERE group_id = $2 AND member_type = $3 AND member_id = $4`, + newRole, groupID, memberType, memberID, + ) + if err != nil { + return err + } + rows, _ := result.RowsAffected() + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) RemoveGroupMember(ctx context.Context, groupID, memberType, memberID string) error { + result, err := s.db.ExecContext(ctx, + "DELETE FROM group_members WHERE group_id = $1 AND member_type = $2 AND member_id = $3", + groupID, memberType, memberID, + ) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) GetGroupMembers(ctx context.Context, groupID string) ([]store.GroupMember, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT group_id, member_type, member_id, role, added_at, added_by + FROM group_members WHERE group_id = $1 + `, groupID) + if err != nil { + return nil, err + } + defer rows.Close() + + var members []store.GroupMember + for rows.Next() { + var member store.GroupMember + if err := rows.Scan( + &member.GroupID, &member.MemberType, &member.MemberID, &member.Role, &member.AddedAt, &member.AddedBy, + ); err != nil { + return nil, err + } + members = append(members, member) + } + + return members, nil +} + +func (s *PostgresStore) GetUserGroups(ctx context.Context, userID string) ([]store.GroupMember, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT group_id, member_type, member_id, role, added_at, added_by + FROM group_members WHERE member_type = 'user' AND member_id = $1 + `, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var memberships []store.GroupMember + for rows.Next() { + var member store.GroupMember + if err := rows.Scan( + &member.GroupID, &member.MemberType, &member.MemberID, &member.Role, &member.AddedAt, &member.AddedBy, + ); err != nil { + return nil, err + } + memberships = append(memberships, member) + } + + return memberships, nil +} + +func (s *PostgresStore) GetGroupMembership(ctx context.Context, groupID, memberType, memberID string) (*store.GroupMember, error) { + member := &store.GroupMember{} + + err := s.db.QueryRowContext(ctx, ` + SELECT group_id, member_type, member_id, role, added_at, added_by + FROM group_members WHERE group_id = $1 AND member_type = $2 AND member_id = $3 + `, groupID, memberType, memberID).Scan( + &member.GroupID, &member.MemberType, &member.MemberID, &member.Role, &member.AddedAt, &member.AddedBy, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + return member, nil +} + +// WouldCreateCycle checks if adding memberGroupID as a member of groupID would create a cycle. +// A cycle exists if groupID is reachable from memberGroupID by following the containment relationship. +// Example: if A contains B, and we try to add A as member of B, we'd have A->B->A (cycle). +func (s *PostgresStore) WouldCreateCycle(ctx context.Context, groupID, memberGroupID string) (bool, error) { + // If they're the same, it's a direct cycle + if groupID == memberGroupID { + return true, nil + } + + // Check if groupID is reachable from memberGroupID by traversing DOWN the containment graph + // (i.e., checking what groups memberGroupID contains, and what those contain, etc.) + visited := make(map[string]bool) + return s.hasPathDown(ctx, memberGroupID, groupID, visited) +} + +// hasPathDown checks if 'target' is reachable from 'current' by following containment. +// It looks at what groups 'current' contains as members. +func (s *PostgresStore) hasPathDown(ctx context.Context, current, target string, visited map[string]bool) (bool, error) { + if current == target { + return true, nil + } + if visited[current] { + return false, nil + } + visited[current] = true + + // Get all groups that 'current' contains (groups where current is the group_id) + rows, err := s.db.QueryContext(ctx, + "SELECT member_id FROM group_members WHERE member_type = 'group' AND group_id = $1", current) + if err != nil { + return false, err + } + defer rows.Close() + + for rows.Next() { + var childGroupID string + if err := rows.Scan(&childGroupID); err != nil { + return false, err + } + found, err := s.hasPathDown(ctx, childGroupID, target, visited) + if err != nil { + return false, err + } + if found { + return true, nil + } + } + + return false, nil +} + +// GetEffectiveGroups returns all groups a user belongs to, including transitive memberships. +func (s *PostgresStore) GetEffectiveGroups(ctx context.Context, userID string) ([]string, error) { + // Start with direct group memberships + directMemberships, err := s.GetUserGroups(ctx, userID) + if err != nil { + return nil, err + } + + effectiveGroups := make(map[string]bool) + for _, m := range directMemberships { + effectiveGroups[m.GroupID] = true + // Add transitive group memberships + if err := s.addTransitiveGroups(ctx, m.GroupID, effectiveGroups); err != nil { + return nil, err + } + } + + result := make([]string, 0, len(effectiveGroups)) + for groupID := range effectiveGroups { + result = append(result, groupID) + } + + return result, nil +} + +// addTransitiveGroups recursively adds all groups that contain the given group. +func (s *PostgresStore) addTransitiveGroups(ctx context.Context, groupID string, visited map[string]bool) error { + // Find all groups where this group is a member + rows, err := s.db.QueryContext(ctx, + "SELECT group_id FROM group_members WHERE member_type = 'group' AND member_id = $1", groupID) + if err != nil { + return err + } + + // Collect all parent group IDs first, then close rows before recursing + // This avoids issues with SQLite connections during recursive queries + var parentGroupIDs []string + for rows.Next() { + var parentGroupID string + if err := rows.Scan(&parentGroupID); err != nil { + rows.Close() + return err + } + parentGroupIDs = append(parentGroupIDs, parentGroupID) + } + rows.Close() + + // Now recurse after rows are closed + for _, parentGroupID := range parentGroupIDs { + if !visited[parentGroupID] { + visited[parentGroupID] = true + if err := s.addTransitiveGroups(ctx, parentGroupID, visited); err != nil { + return err + } + } + } + + return nil +} + +// GetGroupByProjectID retrieves the project_agents group associated with a project. +func (s *PostgresStore) GetGroupByProjectID(ctx context.Context, projectID string) (*store.Group, error) { + var id string + err := s.db.QueryRowContext(ctx, "SELECT id FROM groups WHERE project_id = $1 AND group_type = $2 LIMIT 1", + projectID, store.GroupTypeProjectAgents).Scan(&id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + return s.GetGroup(ctx, id) +} + +// GetEffectiveGroupsForAgent returns all groups an agent belongs to. +func (s *PostgresStore) GetEffectiveGroupsForAgent(ctx context.Context, agentID string) ([]string, error) { + return nil, nil +} + +// CheckDelegatedAccess is a stub for the SQLite store. Delegation resolution +// is implemented in the Ent adapter. +func (s *PostgresStore) CheckDelegatedAccess(ctx context.Context, agentID string, conditions *store.PolicyConditions) (bool, error) { + return false, nil +} + +// GetGroupsByIDs is a stub for the SQLite store. Group retrieval by IDs +// is implemented in the Ent adapter. +func (s *PostgresStore) GetGroupsByIDs(ctx context.Context, ids []string) ([]store.Group, error) { + if len(ids) == 0 { + return nil, nil + } + + placeholders := make([]string, len(ids)) + args := make([]interface{}, len(ids)) + for i, id := range ids { + placeholders[i] = fmt.Sprintf("$%d", i+1) + args[i] = id + } + + rows, err := s.db.QueryContext(ctx, + `SELECT id, name, slug, description, group_type, project_id, parent_id, labels, annotations, created_at, updated_at, created_by, owner_id + FROM groups WHERE id IN (`+strings.Join(placeholders, ",")+`)`, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var groups []store.Group + for rows.Next() { + var g store.Group + var labels, annotations string + var parentID, projectID sql.NullString + if err := rows.Scan( + &g.ID, &g.Name, &g.Slug, &g.Description, + &g.GroupType, &projectID, + &parentID, + &labels, &annotations, + &g.Created, &g.Updated, &g.CreatedBy, &g.OwnerID, + ); err != nil { + return nil, err + } + if parentID.Valid { + g.ParentID = parentID.String + } + if projectID.Valid { + g.ProjectID = projectID.String + } + unmarshalJSON(labels, &g.Labels) + unmarshalJSON(annotations, &g.Annotations) + if g.GroupType == "" { + g.GroupType = store.GroupTypeExplicit + } + groups = append(groups, g) + } + + return groups, rows.Err() +} + +func (s *PostgresStore) CountGroupMembersByRole(ctx context.Context, groupID, role string) (int, error) { + var count int + err := s.db.QueryRowContext(ctx, + `SELECT COUNT(*) FROM group_members WHERE group_id = $1 AND role = $2`, + groupID, role, + ).Scan(&count) + if err != nil { + return 0, err + } + return count, nil +} diff --git a/pkg/store/postgres/harness_configs.go b/pkg/store/postgres/harness_configs.go new file mode 100644 index 00000000..e0cfd1d0 --- /dev/null +++ b/pkg/store/postgres/harness_configs.go @@ -0,0 +1,385 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +func (s *PostgresStore) CreateHarnessConfig(ctx context.Context, hc *store.HarnessConfig) error { + now := time.Now() + hc.Created = now + hc.Updated = now + + if hc.Status == "" { + hc.Status = store.HarnessConfigStatusActive + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO harness_configs ( + id, name, slug, display_name, description, harness, config, + content_hash, scope, scope_id, + storage_uri, storage_bucket, storage_path, files, + locked, status, + owner_id, created_by, updated_by, visibility, + created_at, updated_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22) + `, + hc.ID, hc.Name, hc.Slug, nullableString(hc.DisplayName), nullableString(hc.Description), + hc.Harness, marshalJSON(hc.Config), + nullableString(hc.ContentHash), hc.Scope, nullableString(hc.ScopeID), + nullableString(hc.StorageURI), nullableString(hc.StorageBucket), nullableString(hc.StoragePath), marshalJSON(hc.Files), + hc.Locked, hc.Status, + nullableString(hc.OwnerID), nullableString(hc.CreatedBy), nullableString(hc.UpdatedBy), hc.Visibility, + hc.Created, hc.Updated, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *PostgresStore) GetHarnessConfig(ctx context.Context, id string) (*store.HarnessConfig, error) { + hc := &store.HarnessConfig{} + var configJSON, filesJSON string + var displayName, description, contentHash, scopeID sql.NullString + var storageURI, storageBucket, storagePath sql.NullString + var createdBy, updatedBy, ownerID, visibility sql.NullString + + err := s.db.QueryRowContext(ctx, ` + SELECT id, name, slug, display_name, description, harness, config, + content_hash, scope, scope_id, + storage_uri, storage_bucket, storage_path, files, + locked, status, + owner_id, created_by, updated_by, visibility, + created_at, updated_at + FROM harness_configs WHERE id = $1 + `, id).Scan( + &hc.ID, &hc.Name, &hc.Slug, &displayName, &description, + &hc.Harness, &configJSON, + &contentHash, &hc.Scope, &scopeID, + &storageURI, &storageBucket, &storagePath, &filesJSON, + &hc.Locked, &hc.Status, + &ownerID, &createdBy, &updatedBy, &visibility, + &hc.Created, &hc.Updated, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + if displayName.Valid { + hc.DisplayName = displayName.String + } + if description.Valid { + hc.Description = description.String + } + if contentHash.Valid { + hc.ContentHash = contentHash.String + } + if scopeID.Valid { + hc.ScopeID = scopeID.String + } + if storageURI.Valid { + hc.StorageURI = storageURI.String + } + if storageBucket.Valid { + hc.StorageBucket = storageBucket.String + } + if storagePath.Valid { + hc.StoragePath = storagePath.String + } + if ownerID.Valid { + hc.OwnerID = ownerID.String + } + if createdBy.Valid { + hc.CreatedBy = createdBy.String + } + if updatedBy.Valid { + hc.UpdatedBy = updatedBy.String + } + if visibility.Valid { + hc.Visibility = visibility.String + } + unmarshalJSON(configJSON, &hc.Config) + unmarshalJSON(filesJSON, &hc.Files) + + return hc, nil +} + +func (s *PostgresStore) GetHarnessConfigBySlug(ctx context.Context, slug, scope, scopeID string) (*store.HarnessConfig, error) { + var id string + var err error + + if scopeID != "" { + err = s.db.QueryRowContext(ctx, "SELECT id FROM harness_configs WHERE slug = $1 AND scope = $2 AND scope_id = $3", slug, scope, scopeID).Scan(&id) + } else { + err = s.db.QueryRowContext(ctx, "SELECT id FROM harness_configs WHERE slug = $1 AND scope = $2", slug, scope).Scan(&id) + } + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + return s.GetHarnessConfig(ctx, id) +} + +func (s *PostgresStore) UpdateHarnessConfig(ctx context.Context, hc *store.HarnessConfig) error { + hc.Updated = time.Now() + + result, err := s.db.ExecContext(ctx, ` + UPDATE harness_configs SET + name = $1, slug = $2, display_name = $3, description = $4, + harness = $5, config = $6, + content_hash = $7, scope = $8, scope_id = $9, + storage_uri = $10, storage_bucket = $11, storage_path = $12, files = $13, + locked = $14, status = $15, + owner_id = $16, updated_by = $17, visibility = $18, + updated_at = $19 + WHERE id = $20 + `, + hc.Name, hc.Slug, nullableString(hc.DisplayName), nullableString(hc.Description), + hc.Harness, marshalJSON(hc.Config), + nullableString(hc.ContentHash), hc.Scope, nullableString(hc.ScopeID), + nullableString(hc.StorageURI), nullableString(hc.StorageBucket), nullableString(hc.StoragePath), marshalJSON(hc.Files), + hc.Locked, hc.Status, + nullableString(hc.OwnerID), nullableString(hc.UpdatedBy), hc.Visibility, + hc.Updated, + hc.ID, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) DeleteHarnessConfig(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, "DELETE FROM harness_configs WHERE id = $1", id) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) DeleteHarnessConfigsByScope(ctx context.Context, scope, scopeID string) (int, error) { + result, err := s.db.ExecContext(ctx, "DELETE FROM harness_configs WHERE scope = $1 AND scope_id = $2", scope, scopeID) + if err != nil { + return 0, err + } + n, err := result.RowsAffected() + if err != nil { + return 0, err + } + return int(n), nil +} + +func (s *PostgresStore) ListHarnessConfigs(ctx context.Context, filter store.HarnessConfigFilter, opts store.ListOptions) (*store.ListResult[store.HarnessConfig], error) { + var conditions []string + var args []interface{} + + if filter.Name != "" { + n := len(args) + 1 + conditions = append(conditions, fmt.Sprintf("(name = $%d OR slug = $%d)", n, n+1)) + args = append(args, filter.Name, filter.Name) + } + if filter.Scope != "" { + n := len(args) + 1 + conditions = append(conditions, fmt.Sprintf("scope = $%d", n)) + args = append(args, filter.Scope) + } + if filter.ScopeID != "" { + n := len(args) + 1 + conditions = append(conditions, fmt.Sprintf("scope_id = $%d", n)) + args = append(args, filter.ScopeID) + } else if filter.ProjectID != "" && filter.Scope == "" { + // When projectId is set without scope, return global + project-scoped configs for this project + n := len(args) + 1 + conditions = append(conditions, fmt.Sprintf("(scope = 'global' OR (scope = 'project' AND scope_id = $%d))", n)) + args = append(args, filter.ProjectID) + } else if (filter.Scope == "project" || filter.Scope == "grove") && filter.ProjectID != "" { + // projectId combined with an explicit scope filter — narrow to that project. + n := len(args) + 1 + conditions = append(conditions, fmt.Sprintf("scope_id = $%d", n)) + args = append(args, filter.ProjectID) + } + if filter.Harness != "" { + n := len(args) + 1 + conditions = append(conditions, fmt.Sprintf("harness = $%d", n)) + args = append(args, filter.Harness) + } + if filter.OwnerID != "" { + n := len(args) + 1 + conditions = append(conditions, fmt.Sprintf("owner_id = $%d", n)) + args = append(args, filter.OwnerID) + } + if filter.Status != "" { + n := len(args) + 1 + conditions = append(conditions, fmt.Sprintf("status = $%d", n)) + args = append(args, filter.Status) + } + if filter.Search != "" { + n := len(args) + 1 + conditions = append(conditions, fmt.Sprintf("(name LIKE $%d OR description LIKE $%d)", n, n+1)) + searchPattern := "%" + filter.Search + "%" + args = append(args, searchPattern, searchPattern) + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + var totalCount int + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM harness_configs %s", whereClause) + if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount); err != nil { + return nil, err + } + + limit := opts.Limit + if limit <= 0 { + limit = 50 + } + + n := len(args) + 1 + query := fmt.Sprintf(` + SELECT id, name, slug, display_name, description, harness, config, + content_hash, scope, scope_id, + storage_uri, storage_bucket, storage_path, files, + locked, status, + owner_id, created_by, updated_by, visibility, + created_at, updated_at + FROM harness_configs %s ORDER BY created_at DESC LIMIT $%d + `, whereClause, n) + args = append(args, limit) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var harnessConfigs []store.HarnessConfig + for rows.Next() { + var hc store.HarnessConfig + var configJSON, filesJSON string + var displayName, description, contentHash, scopeID sql.NullString + var storageURI, storageBucket, storagePath sql.NullString + var createdBy, updatedBy, ownerID, visibility sql.NullString + + if err := rows.Scan( + &hc.ID, &hc.Name, &hc.Slug, &displayName, &description, + &hc.Harness, &configJSON, + &contentHash, &hc.Scope, &scopeID, + &storageURI, &storageBucket, &storagePath, &filesJSON, + &hc.Locked, &hc.Status, + &ownerID, &createdBy, &updatedBy, &visibility, + &hc.Created, &hc.Updated, + ); err != nil { + return nil, err + } + + if displayName.Valid { + hc.DisplayName = displayName.String + } + if description.Valid { + hc.Description = description.String + } + if contentHash.Valid { + hc.ContentHash = contentHash.String + } + if scopeID.Valid { + hc.ScopeID = scopeID.String + } + if storageURI.Valid { + hc.StorageURI = storageURI.String + } + if storageBucket.Valid { + hc.StorageBucket = storageBucket.String + } + if storagePath.Valid { + hc.StoragePath = storagePath.String + } + if ownerID.Valid { + hc.OwnerID = ownerID.String + } + if createdBy.Valid { + hc.CreatedBy = createdBy.String + } + if updatedBy.Valid { + hc.UpdatedBy = updatedBy.String + } + if visibility.Valid { + hc.Visibility = visibility.String + } + unmarshalJSON(configJSON, &hc.Config) + unmarshalJSON(filesJSON, &hc.Files) + + harnessConfigs = append(harnessConfigs, hc) + } + + // When querying by ProjectID without explicit Scope, the query returns both + // global and project-scoped configs. Deduplicate by slug, preferring the more + // specific scope (project > global). + if filter.ProjectID != "" && filter.Scope == "" { + seen := make(map[string]int, len(harnessConfigs)) + deduped := make([]store.HarnessConfig, 0, len(harnessConfigs)) + for _, hc := range harnessConfigs { + if idx, exists := seen[hc.Slug]; exists { + if hc.Scope == "project" && deduped[idx].Scope == "global" { + deduped[idx] = hc + } + } else { + seen[hc.Slug] = len(deduped) + deduped = append(deduped, hc) + } + } + harnessConfigs = deduped + totalCount = len(deduped) + } + + return &store.ListResult[store.HarnessConfig]{ + Items: harnessConfigs, + TotalCount: totalCount, + }, nil +} diff --git a/pkg/store/postgres/invites.go b/pkg/store/postgres/invites.go new file mode 100644 index 00000000..8ee1d2fa --- /dev/null +++ b/pkg/store/postgres/invites.go @@ -0,0 +1,259 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +func (s *PostgresStore) CreateInviteCode(ctx context.Context, invite *store.InviteCode) error { + if invite.Created.IsZero() { + invite.Created = time.Now() + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO invite_codes (id, code_hash, code_prefix, max_uses, use_count, expires_at, revoked, created_by, note, created) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + `, invite.ID, invite.CodeHash, invite.CodePrefix, invite.MaxUses, invite.UseCount, + invite.ExpiresAt, boolToInt(invite.Revoked), invite.CreatedBy, invite.Note, invite.Created) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *PostgresStore) GetInviteCodeByHash(ctx context.Context, codeHash string) (*store.InviteCode, error) { + invite := &store.InviteCode{} + var revoked int + err := s.db.QueryRowContext(ctx, ` + SELECT id, code_hash, code_prefix, max_uses, use_count, expires_at, revoked, created_by, note, created + FROM invite_codes WHERE code_hash = $1 + `, codeHash).Scan( + &invite.ID, &invite.CodeHash, &invite.CodePrefix, &invite.MaxUses, &invite.UseCount, + &invite.ExpiresAt, &revoked, &invite.CreatedBy, &invite.Note, &invite.Created, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + invite.Revoked = revoked != 0 + return invite, nil +} + +func (s *PostgresStore) GetInviteCode(ctx context.Context, id string) (*store.InviteCode, error) { + invite := &store.InviteCode{} + var revoked int + err := s.db.QueryRowContext(ctx, ` + SELECT id, code_hash, code_prefix, max_uses, use_count, expires_at, revoked, created_by, note, created + FROM invite_codes WHERE id = $1 + `, id).Scan( + &invite.ID, &invite.CodeHash, &invite.CodePrefix, &invite.MaxUses, &invite.UseCount, + &invite.ExpiresAt, &revoked, &invite.CreatedBy, &invite.Note, &invite.Created, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + invite.Revoked = revoked != 0 + return invite, nil +} + +func (s *PostgresStore) ListInviteCodes(ctx context.Context, opts store.ListOptions) (*store.ListResult[store.InviteCode], error) { + var totalCount int + if err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM invite_codes").Scan(&totalCount); err != nil { + return nil, err + } + + limit := opts.Limit + if limit <= 0 { + limit = 50 + } + + var conditions []string + var args []interface{} + + if opts.Cursor != "" { + conditions = append(conditions, `(created < (SELECT created FROM invite_codes WHERE id = $1) + OR (created = (SELECT created FROM invite_codes WHERE id = $2) AND id < $3))`) + args = append(args, opts.Cursor, opts.Cursor, opts.Cursor) + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + query := fmt.Sprintf(` + SELECT id, code_prefix, max_uses, use_count, expires_at, revoked, created_by, note, created + FROM invite_codes %s ORDER BY created DESC, id DESC LIMIT $%d + `, whereClause, len(args)+1) + args = append(args, limit+1) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var invites []store.InviteCode + for rows.Next() { + var invite store.InviteCode + var revoked int + if err := rows.Scan( + &invite.ID, &invite.CodePrefix, &invite.MaxUses, &invite.UseCount, + &invite.ExpiresAt, &revoked, &invite.CreatedBy, &invite.Note, &invite.Created, + ); err != nil { + return nil, err + } + invite.Revoked = revoked != 0 + invites = append(invites, invite) + } + if err := rows.Err(); err != nil { + return nil, err + } + if invites == nil { + invites = []store.InviteCode{} + } + + var nextCursor string + if len(invites) > limit { + nextCursor = invites[limit-1].ID + invites = invites[:limit] + } + + return &store.ListResult[store.InviteCode]{ + Items: invites, + TotalCount: totalCount, + NextCursor: nextCursor, + }, nil +} + +func (s *PostgresStore) IncrementInviteUseCount(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, ` + UPDATE invite_codes SET use_count = use_count + 1 + WHERE id = $1 AND revoked = 0 AND expires_at > NOW() + AND (max_uses = 0 OR use_count < max_uses) + `, id) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) RevokeInviteCode(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, "UPDATE invite_codes SET revoked = 1 WHERE id = $1", id) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) DeleteInviteCode(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, "DELETE FROM invite_codes WHERE id = $1", id) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) GetInviteStats(ctx context.Context) (*store.InviteStats, error) { + stats := &store.InviteStats{} + + // Count pending (active, not expired, not exhausted) invites + err := s.db.QueryRowContext(ctx, ` + SELECT COUNT(*) FROM invite_codes + WHERE revoked = 0 + AND expires_at > NOW() + AND (max_uses = 0 OR use_count < max_uses) + `).Scan(&stats.PendingInvites) + if err != nil { + return nil, err + } + + // Total redemptions across all invites + err = s.db.QueryRowContext(ctx, ` + SELECT COALESCE(SUM(use_count), 0) FROM invite_codes + `).Scan(&stats.TotalRedemptions) + if err != nil { + return nil, err + } + + // Allow list count + err = s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM allow_list`).Scan(&stats.AllowListCount) + if err != nil { + return nil, err + } + + // Recent invites that have been redeemed (use_count > 0), ordered by most recently created + rows, err := s.db.QueryContext(ctx, ` + SELECT id, code_prefix, use_count, max_uses, expires_at, note, created + FROM invite_codes + WHERE use_count > 0 + ORDER BY created DESC + LIMIT 10 + `) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var info store.InviteCodeInfo + if err := rows.Scan(&info.ID, &info.CodePrefix, &info.UseCount, &info.MaxUses, &info.ExpiresAt, &info.Note, &info.Created); err != nil { + return nil, err + } + stats.RecentRedemptions = append(stats.RecentRedemptions, info) + } + if stats.RecentRedemptions == nil { + stats.RecentRedemptions = []store.InviteCodeInfo{} + } + + return stats, rows.Err() +} diff --git a/pkg/store/postgres/maintenance.go b/pkg/store/postgres/maintenance.go new file mode 100644 index 00000000..a973f4ae --- /dev/null +++ b/pkg/store/postgres/maintenance.go @@ -0,0 +1,269 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +// ============================================================================ +// Maintenance Operation Operations +// ============================================================================ + +// ListMaintenanceOperations returns all registered operations and migrations. +func (s *PostgresStore) ListMaintenanceOperations(ctx context.Context) ([]store.MaintenanceOperation, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, key, title, description, category, status, + created_at, started_at, completed_at, started_by, result, metadata + FROM maintenance_operations + ORDER BY category, created_at + `) + if err != nil { + return nil, err + } + defer rows.Close() + + var ops []store.MaintenanceOperation + for rows.Next() { + var op store.MaintenanceOperation + var startedAt, completedAt sql.NullTime + var startedBy, result, metadata sql.NullString + + if err := rows.Scan( + &op.ID, &op.Key, &op.Title, &op.Description, &op.Category, &op.Status, + &op.CreatedAt, &startedAt, &completedAt, &startedBy, &result, &metadata, + ); err != nil { + return nil, err + } + + if startedAt.Valid { + op.StartedAt = &startedAt.Time + } + if completedAt.Valid { + op.CompletedAt = &completedAt.Time + } + op.StartedBy = startedBy.String + op.Result = result.String + op.Metadata = metadata.String + + ops = append(ops, op) + } + return ops, rows.Err() +} + +// GetMaintenanceOperation returns a single operation by key. +func (s *PostgresStore) GetMaintenanceOperation(ctx context.Context, key string) (*store.MaintenanceOperation, error) { + op := &store.MaintenanceOperation{} + var startedAt, completedAt sql.NullTime + var startedBy, result, metadata sql.NullString + + err := s.db.QueryRowContext(ctx, ` + SELECT id, key, title, description, category, status, + created_at, started_at, completed_at, started_by, result, metadata + FROM maintenance_operations WHERE key = $1 + `, key).Scan( + &op.ID, &op.Key, &op.Title, &op.Description, &op.Category, &op.Status, + &op.CreatedAt, &startedAt, &completedAt, &startedBy, &result, &metadata, + ) + if err == sql.ErrNoRows { + return nil, store.ErrNotFound + } + if err != nil { + return nil, err + } + + if startedAt.Valid { + op.StartedAt = &startedAt.Time + } + if completedAt.Valid { + op.CompletedAt = &completedAt.Time + } + op.StartedBy = startedBy.String + op.Result = result.String + op.Metadata = metadata.String + + return op, nil +} + +// UpdateMaintenanceOperation updates an operation's status and result fields. +func (s *PostgresStore) UpdateMaintenanceOperation(ctx context.Context, op *store.MaintenanceOperation) error { + res, err := s.db.ExecContext(ctx, ` + UPDATE maintenance_operations + SET status = $1, started_at = $2, completed_at = $3, started_by = $4, result = $5, metadata = $6 + WHERE key = $7 + `, + op.Status, + nullableTime(timeFromPtr(op.StartedAt)), + nullableTime(timeFromPtr(op.CompletedAt)), + nullableString(op.StartedBy), + nullableString(op.Result), + nullableString(op.Metadata), + op.Key, + ) + if err != nil { + return err + } + n, _ := res.RowsAffected() + if n == 0 { + return store.ErrNotFound + } + return nil +} + +// CreateMaintenanceRun inserts a new run record. +func (s *PostgresStore) CreateMaintenanceRun(ctx context.Context, run *store.MaintenanceOperationRun) error { + _, err := s.db.ExecContext(ctx, ` + INSERT INTO maintenance_operation_runs ( + id, operation_key, status, started_at, completed_at, started_by, result, log + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + `, + run.ID, run.OperationKey, run.Status, run.StartedAt, + nullableTime(timeFromPtr(run.CompletedAt)), + nullableString(run.StartedBy), + nullableString(run.Result), + run.Log, + ) + return err +} + +// UpdateMaintenanceRun updates a run's status, result, and log. +func (s *PostgresStore) UpdateMaintenanceRun(ctx context.Context, run *store.MaintenanceOperationRun) error { + res, err := s.db.ExecContext(ctx, ` + UPDATE maintenance_operation_runs + SET status = $1, completed_at = $2, result = $3, log = $4 + WHERE id = $5 + `, + run.Status, + nullableTime(timeFromPtr(run.CompletedAt)), + nullableString(run.Result), + run.Log, + run.ID, + ) + if err != nil { + return err + } + n, _ := res.RowsAffected() + if n == 0 { + return store.ErrNotFound + } + return nil +} + +// GetMaintenanceRun returns a single run by ID. +func (s *PostgresStore) GetMaintenanceRun(ctx context.Context, id string) (*store.MaintenanceOperationRun, error) { + run := &store.MaintenanceOperationRun{} + var completedAt sql.NullTime + var startedBy, result sql.NullString + + err := s.db.QueryRowContext(ctx, ` + SELECT id, operation_key, status, started_at, completed_at, started_by, result, log + FROM maintenance_operation_runs WHERE id = $1 + `, id).Scan( + &run.ID, &run.OperationKey, &run.Status, &run.StartedAt, + &completedAt, &startedBy, &result, &run.Log, + ) + if err == sql.ErrNoRows { + return nil, store.ErrNotFound + } + if err != nil { + return nil, err + } + + if completedAt.Valid { + run.CompletedAt = &completedAt.Time + } + run.StartedBy = startedBy.String + run.Result = result.String + + return run, nil +} + +// AbortRunningMaintenanceOps transitions any "running" operation runs and +// migrations to "failed" with an appropriate result message. This is called at +// server startup to clean up operations interrupted by a restart. +func (s *PostgresStore) AbortRunningMaintenanceOps(ctx context.Context) (int64, int64, error) { + now := sql.NullTime{Time: time.Now(), Valid: true} + result := `{"error":"aborted: server was restarted while operation was running"}` + + // Abort stalled runs. + res, err := s.db.ExecContext(ctx, ` + UPDATE maintenance_operation_runs + SET status = 'failed', completed_at = $1, result = $2 + WHERE status = 'running' + `, now, result) + if err != nil { + return 0, 0, err + } + runs, _ := res.RowsAffected() + + // Reset stalled migrations back to pending (they can be retried). + res, err = s.db.ExecContext(ctx, ` + UPDATE maintenance_operations + SET status = 'pending', started_at = NULL, completed_at = NULL, result = $1 + WHERE status = 'running' AND category = 'migration' + `, result) + if err != nil { + return runs, 0, err + } + migrations, _ := res.RowsAffected() + + return runs, migrations, nil +} + +// ListMaintenanceRuns returns runs for a given operation key, ordered by started_at DESC. +func (s *PostgresStore) ListMaintenanceRuns(ctx context.Context, operationKey string, limit int) ([]store.MaintenanceOperationRun, error) { + if limit <= 0 { + limit = 20 + } + + rows, err := s.db.QueryContext(ctx, ` + SELECT id, operation_key, status, started_at, completed_at, started_by, result, log + FROM maintenance_operation_runs + WHERE operation_key = $1 + ORDER BY started_at DESC + LIMIT $2 + `, operationKey, limit) + if err != nil { + return nil, err + } + defer rows.Close() + + var runs []store.MaintenanceOperationRun + for rows.Next() { + var run store.MaintenanceOperationRun + var completedAt sql.NullTime + var startedBy, result sql.NullString + + if err := rows.Scan( + &run.ID, &run.OperationKey, &run.Status, &run.StartedAt, + &completedAt, &startedBy, &result, &run.Log, + ); err != nil { + return nil, err + } + + if completedAt.Valid { + run.CompletedAt = &completedAt.Time + } + run.StartedBy = startedBy.String + run.Result = result.String + + runs = append(runs, run) + } + return runs, rows.Err() +} diff --git a/pkg/store/postgres/messages.go b/pkg/store/postgres/messages.go new file mode 100644 index 00000000..e918ea69 --- /dev/null +++ b/pkg/store/postgres/messages.go @@ -0,0 +1,226 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package postgres provides a Postgres implementation of the Store interface. +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +// ============================================================================ +// Message Operations +// ============================================================================ + +// CreateMessage persists a new message. +func (s *PostgresStore) CreateMessage(ctx context.Context, msg *store.Message) error { + if msg.ID == "" || msg.ProjectID == "" || msg.Msg == "" { + return store.ErrInvalidInput + } + if msg.CreatedAt.IsZero() { + msg.CreatedAt = time.Now() + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO messages ( + id, project_id, sender, sender_id, recipient, recipient_id, + msg, type, urgent, broadcasted, read, agent_id, group_id, created_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + `, + msg.ID, msg.ProjectID, msg.Sender, msg.SenderID, msg.Recipient, msg.RecipientID, + msg.Msg, msg.Type, + boolToInt(msg.Urgent), boolToInt(msg.Broadcasted), boolToInt(msg.Read), + msg.AgentID, msg.GroupID, msg.CreatedAt, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +// GetMessage returns a single message by ID. +func (s *PostgresStore) GetMessage(ctx context.Context, id string) (*store.Message, error) { + row := s.db.QueryRowContext(ctx, ` + SELECT id, project_id, sender, sender_id, recipient, recipient_id, + msg, type, urgent, broadcasted, read, agent_id, group_id, created_at + FROM messages + WHERE id = $1 + `, id) + + var msg store.Message + var urgent, broadcasted, read int + if err := row.Scan( + &msg.ID, &msg.ProjectID, &msg.Sender, &msg.SenderID, &msg.Recipient, &msg.RecipientID, + &msg.Msg, &msg.Type, &urgent, &broadcasted, &read, + &msg.AgentID, &msg.GroupID, &msg.CreatedAt, + ); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + msg.Urgent = urgent != 0 + msg.Broadcasted = broadcasted != 0 + msg.Read = read != 0 + return &msg, nil +} + +// ListMessages returns messages matching the given filter, ordered by created_at DESC. +func (s *PostgresStore) ListMessages(ctx context.Context, filter store.MessageFilter, opts store.ListOptions) (*store.ListResult[store.Message], error) { + var conditions []string + var args []interface{} + + if filter.ProjectID != "" { + args = append(args, filter.ProjectID) + conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args))) + } + if filter.AgentID != "" { + args = append(args, filter.AgentID) + conditions = append(conditions, fmt.Sprintf("agent_id = $%d", len(args))) + } + if filter.RecipientID != "" { + args = append(args, filter.RecipientID) + conditions = append(conditions, fmt.Sprintf("recipient_id = $%d", len(args))) + } + if filter.SenderID != "" { + args = append(args, filter.SenderID) + conditions = append(conditions, fmt.Sprintf("sender_id = $%d", len(args))) + } + if filter.ParticipantID != "" { + args = append(args, filter.ParticipantID, filter.ParticipantID) + conditions = append(conditions, fmt.Sprintf("(recipient_id = $%d OR sender_id = $%d)", len(args)-1, len(args))) + } + if filter.OnlyUnread { + conditions = append(conditions, "read = 0") + } + if filter.Type != "" { + args = append(args, filter.Type) + conditions = append(conditions, fmt.Sprintf("type = $%d", len(args))) + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + var totalCount int + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM messages %s", whereClause) + if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount); err != nil { + return nil, err + } + + limit := opts.Limit + if limit <= 0 { + limit = 50 + } + if limit > 200 { + limit = 200 + } + + args = append(args, limit+1) + query := fmt.Sprintf(` + SELECT id, project_id, sender, sender_id, recipient, recipient_id, + msg, type, urgent, broadcasted, read, agent_id, group_id, created_at + FROM messages %s ORDER BY created_at DESC LIMIT $%d + `, whereClause, len(args)) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var msgs []store.Message + for rows.Next() { + var msg store.Message + var urgent, broadcasted, read int + if err := rows.Scan( + &msg.ID, &msg.ProjectID, &msg.Sender, &msg.SenderID, &msg.Recipient, &msg.RecipientID, + &msg.Msg, &msg.Type, &urgent, &broadcasted, &read, + &msg.AgentID, &msg.GroupID, &msg.CreatedAt, + ); err != nil { + return nil, err + } + msg.Urgent = urgent != 0 + msg.Broadcasted = broadcasted != 0 + msg.Read = read != 0 + msgs = append(msgs, msg) + } + if err := rows.Err(); err != nil { + return nil, err + } + + result := &store.ListResult[store.Message]{ + Items: msgs, + TotalCount: totalCount, + } + if len(msgs) > limit { + result.Items = msgs[:limit] + result.NextCursor = msgs[limit-1].ID + } + return result, nil +} + +// MarkMessageRead marks a message as read. +func (s *PostgresStore) MarkMessageRead(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, ` + UPDATE messages SET read = 1 WHERE id = $1 + `, id) + if err != nil { + return err + } + n, err := result.RowsAffected() + if err != nil { + return err + } + if n == 0 { + return store.ErrNotFound + } + return nil +} + +// MarkAllMessagesRead marks all messages for a recipient as read. +func (s *PostgresStore) MarkAllMessagesRead(ctx context.Context, recipientID string) error { + _, err := s.db.ExecContext(ctx, ` + UPDATE messages SET read = 1 WHERE recipient_id = $1 + `, recipientID) + return err +} + +// PurgeOldMessages removes read messages older than readCutoff and unread messages +// older than unreadCutoff. Returns the number of messages removed. +func (s *PostgresStore) PurgeOldMessages(ctx context.Context, readCutoff time.Time, unreadCutoff time.Time) (int, error) { + result, err := s.db.ExecContext(ctx, ` + DELETE FROM messages + WHERE (read = 1 AND created_at < $1) OR (read = 0 AND created_at < $2) + `, readCutoff, unreadCutoff) + if err != nil { + return 0, err + } + n, err := result.RowsAffected() + if err != nil { + return 0, err + } + return int(n), nil +} diff --git a/pkg/store/postgres/migrations.go b/pkg/store/postgres/migrations.go new file mode 100644 index 00000000..8ce84d33 --- /dev/null +++ b/pkg/store/postgres/migrations.go @@ -0,0 +1,1150 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package postgres provides a PostgreSQL implementation of the Store interface. +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" +) + +// Migration V1: Initial schema +const migrationV1 = ` +-- Projects table +CREATE TABLE IF NOT EXISTS groves ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + slug TEXT NOT NULL, + git_remote TEXT UNIQUE, + labels TEXT, + annotations TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_by TEXT, + owner_id TEXT, + visibility TEXT NOT NULL DEFAULT 'private' +); +CREATE INDEX IF NOT EXISTS idx_groves_slug ON groves(slug); +CREATE INDEX IF NOT EXISTS idx_groves_git_remote ON groves(git_remote); +CREATE INDEX IF NOT EXISTS idx_groves_owner ON groves(owner_id); + +-- Runtime brokers table +CREATE TABLE IF NOT EXISTS runtime_brokers ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + slug TEXT NOT NULL, + type TEXT NOT NULL, + mode TEXT NOT NULL DEFAULT 'connected', + version TEXT, + status TEXT NOT NULL DEFAULT 'offline', + connection_state TEXT DEFAULT 'disconnected', + last_heartbeat TIMESTAMP, + capabilities TEXT, + supported_harnesses TEXT, + resources TEXT, + runtimes TEXT, + labels TEXT, + annotations TEXT, + endpoint TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE INDEX IF NOT EXISTS idx_runtime_brokers_slug ON runtime_brokers(slug); +CREATE INDEX IF NOT EXISTS idx_runtime_brokers_status ON runtime_brokers(status); + +-- Project contributors (many-to-many relationship) +CREATE TABLE IF NOT EXISTS grove_contributors ( + grove_id TEXT NOT NULL, + broker_id TEXT NOT NULL, + broker_name TEXT NOT NULL, + mode TEXT NOT NULL DEFAULT 'connected', + status TEXT NOT NULL DEFAULT 'offline', + profiles TEXT, + last_seen TIMESTAMP, + PRIMARY KEY (grove_id, broker_id), + FOREIGN KEY (grove_id) REFERENCES groves(id) ON DELETE CASCADE, + FOREIGN KEY (broker_id) REFERENCES runtime_brokers(id) ON DELETE CASCADE +); + +-- Agents table +CREATE TABLE IF NOT EXISTS agents ( + id TEXT PRIMARY KEY, + agent_id TEXT NOT NULL, + name TEXT NOT NULL, + template TEXT NOT NULL, + grove_id TEXT NOT NULL, + labels TEXT, + annotations TEXT, + status TEXT NOT NULL DEFAULT 'pending', + connection_state TEXT DEFAULT 'unknown', + container_status TEXT, + session_status TEXT, + runtime_state TEXT, + image TEXT, + detached INTEGER NOT NULL DEFAULT 1, + runtime TEXT, + runtime_broker_id TEXT, + web_pty_enabled INTEGER NOT NULL DEFAULT 0, + task_summary TEXT, + applied_config TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + last_seen TIMESTAMP, + created_by TEXT, + owner_id TEXT, + visibility TEXT NOT NULL DEFAULT 'private', + state_version INTEGER NOT NULL DEFAULT 1, + FOREIGN KEY (grove_id) REFERENCES groves(id) ON DELETE CASCADE, + FOREIGN KEY (runtime_broker_id) REFERENCES runtime_brokers(id) ON DELETE SET NULL +); +-- Use (agent_id, grove_id) order to match Ent schema's (slug, project_id) +CREATE UNIQUE INDEX IF NOT EXISTS idx_agents_grove_slug ON agents(agent_id, grove_id); +CREATE INDEX IF NOT EXISTS idx_agents_grove ON agents(grove_id); +CREATE INDEX IF NOT EXISTS idx_agents_status ON agents(status); +CREATE INDEX IF NOT EXISTS idx_agents_runtime_broker ON agents(runtime_broker_id); + +-- Templates table +CREATE TABLE IF NOT EXISTS templates ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + slug TEXT NOT NULL, + harness TEXT NOT NULL, + image TEXT, + config TEXT, + scope TEXT NOT NULL DEFAULT 'global', + grove_id TEXT, + storage_uri TEXT, + owner_id TEXT, + visibility TEXT NOT NULL DEFAULT 'private', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (grove_id) REFERENCES groves(id) ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS idx_templates_slug_scope ON templates(slug, scope); +CREATE INDEX IF NOT EXISTS idx_templates_harness ON templates(harness); + +-- Users table +CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, + email TEXT UNIQUE NOT NULL, + display_name TEXT NOT NULL, + avatar_url TEXT, + role TEXT NOT NULL DEFAULT 'member', + status TEXT NOT NULL DEFAULT 'active', + preferences TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + last_login TIMESTAMP +); +CREATE INDEX IF NOT EXISTS idx_users_email ON users(email); +` + +// Migration V2: Add default_runtime_broker_id to groves +const migrationV2 = ` +-- Add default runtime broker to groves +ALTER TABLE groves ADD COLUMN default_runtime_broker_id TEXT REFERENCES runtime_brokers(id) ON DELETE SET NULL; +CREATE INDEX IF NOT EXISTS idx_groves_default_runtime_broker ON groves(default_runtime_broker_id); +` + +// Migration V3: Add local_path to grove_contributors +const migrationV3 = ` +-- Add local_path column to grove_contributors for tracking filesystem paths per broker +ALTER TABLE grove_contributors ADD COLUMN local_path TEXT; +` + +// Migration V4: Add environment variables and secrets tables +const migrationV4 = ` +-- Environment variables table +CREATE TABLE IF NOT EXISTS env_vars ( + id TEXT PRIMARY KEY, + key TEXT NOT NULL, + value TEXT NOT NULL, + scope TEXT NOT NULL, + scope_id TEXT NOT NULL, + description TEXT, + sensitive INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_by TEXT +); +CREATE UNIQUE INDEX IF NOT EXISTS idx_env_vars_key_scope ON env_vars(key, scope, scope_id); +CREATE INDEX IF NOT EXISTS idx_env_vars_scope ON env_vars(scope, scope_id); + +-- Secrets table +CREATE TABLE IF NOT EXISTS secrets ( + id TEXT PRIMARY KEY, + key TEXT NOT NULL, + encrypted_value TEXT NOT NULL, + scope TEXT NOT NULL, + scope_id TEXT NOT NULL, + description TEXT, + version INTEGER NOT NULL DEFAULT 1, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_by TEXT, + updated_by TEXT +); +CREATE UNIQUE INDEX IF NOT EXISTS idx_secrets_key_scope ON secrets(key, scope, scope_id); +CREATE INDEX IF NOT EXISTS idx_secrets_scope ON secrets(scope, scope_id); +` + +// Migration V5: Groups and Policies (Hub Permissions System) +const migrationV5 = ` +-- Groups table +CREATE TABLE IF NOT EXISTS groups ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + slug TEXT UNIQUE NOT NULL, + description TEXT, + parent_id TEXT REFERENCES groups(id) ON DELETE SET NULL, + labels TEXT, + annotations TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_by TEXT, + owner_id TEXT +); +CREATE INDEX IF NOT EXISTS idx_groups_slug ON groups(slug); +CREATE INDEX IF NOT EXISTS idx_groups_parent ON groups(parent_id); +CREATE INDEX IF NOT EXISTS idx_groups_owner ON groups(owner_id); + +-- Group members table (users and nested groups) +CREATE TABLE IF NOT EXISTS group_members ( + group_id TEXT NOT NULL, + member_type TEXT NOT NULL, -- 'user' or 'group' + member_id TEXT NOT NULL, + role TEXT NOT NULL DEFAULT 'member', + added_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + added_by TEXT, + PRIMARY KEY (group_id, member_type, member_id), + FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS idx_group_members_member ON group_members(member_type, member_id); + +-- Policies table +CREATE TABLE IF NOT EXISTS policies ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + scope_type TEXT NOT NULL, + scope_id TEXT, + resource_type TEXT NOT NULL DEFAULT '*', + resource_id TEXT, + actions TEXT NOT NULL, -- JSON array + effect TEXT NOT NULL, + conditions TEXT, -- JSON object + priority INTEGER NOT NULL DEFAULT 0, + labels TEXT, + annotations TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_by TEXT +); +CREATE INDEX IF NOT EXISTS idx_policies_scope ON policies(scope_type, scope_id); +CREATE INDEX IF NOT EXISTS idx_policies_effect ON policies(effect); +CREATE INDEX IF NOT EXISTS idx_policies_priority ON policies(priority DESC); + +-- Policy bindings table +CREATE TABLE IF NOT EXISTS policy_bindings ( + policy_id TEXT NOT NULL, + principal_type TEXT NOT NULL, -- 'user' or 'group' + principal_id TEXT NOT NULL, + PRIMARY KEY (policy_id, principal_type, principal_id), + FOREIGN KEY (policy_id) REFERENCES policies(id) ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS idx_policy_bindings_principal ON policy_bindings(principal_type, principal_id); +` + +// Migration V6: Extend templates table for hosted template management +const migrationV6 = ` +-- Add new columns to templates table +ALTER TABLE templates ADD COLUMN display_name TEXT; +ALTER TABLE templates ADD COLUMN description TEXT; +ALTER TABLE templates ADD COLUMN content_hash TEXT; +ALTER TABLE templates ADD COLUMN scope_id TEXT; +ALTER TABLE templates ADD COLUMN storage_bucket TEXT; +ALTER TABLE templates ADD COLUMN storage_path TEXT; +ALTER TABLE templates ADD COLUMN files TEXT; +ALTER TABLE templates ADD COLUMN base_template TEXT; +ALTER TABLE templates ADD COLUMN locked INTEGER NOT NULL DEFAULT 0; +ALTER TABLE templates ADD COLUMN status TEXT NOT NULL DEFAULT 'active'; +ALTER TABLE templates ADD COLUMN created_by TEXT; +ALTER TABLE templates ADD COLUMN updated_by TEXT; + +-- Add indexes for new columns +CREATE INDEX IF NOT EXISTS idx_templates_status ON templates(status); +CREATE INDEX IF NOT EXISTS idx_templates_content_hash ON templates(content_hash); +CREATE INDEX IF NOT EXISTS idx_templates_scope_id ON templates(scope, scope_id); +` + +const migrationV7 = ` +-- Add API keys table +CREATE TABLE IF NOT EXISTS api_keys ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + name TEXT NOT NULL, + prefix TEXT NOT NULL, + key_hash TEXT NOT NULL UNIQUE, + scopes TEXT, + revoked INTEGER NOT NULL DEFAULT 0, + expires_at TIMESTAMP, + last_used TIMESTAMP, + created_at TIMESTAMP NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE +); + +-- Add indexes for API keys +CREATE INDEX IF NOT EXISTS idx_api_keys_user_id ON api_keys(user_id); +CREATE INDEX IF NOT EXISTS idx_api_keys_key_hash ON api_keys(key_hash); +CREATE INDEX IF NOT EXISTS idx_api_keys_prefix ON api_keys(prefix); +` + +const migrationV8 = ` +-- Add message column to agents table +ALTER TABLE agents ADD COLUMN message TEXT; +` + +// Migration V9: Broker secrets and join tokens for Runtime Broker authentication +const migrationV9 = ` +-- Broker secrets table for HMAC-based authentication +CREATE TABLE IF NOT EXISTS broker_secrets ( + broker_id TEXT PRIMARY KEY, + secret_key BYTEA NOT NULL, + algorithm TEXT NOT NULL DEFAULT 'hmac-sha256', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + rotated_at TIMESTAMP, + expires_at TIMESTAMP, + status TEXT NOT NULL DEFAULT 'active', + FOREIGN KEY (broker_id) REFERENCES runtime_brokers(id) ON DELETE CASCADE +); + +-- Broker join tokens table for registration bootstrap +CREATE TABLE IF NOT EXISTS broker_join_tokens ( + broker_id TEXT PRIMARY KEY, + token_hash TEXT NOT NULL UNIQUE, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_by TEXT NOT NULL, + FOREIGN KEY (broker_id) REFERENCES runtime_brokers(id) ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS idx_broker_join_tokens_hash ON broker_join_tokens(token_hash); +CREATE INDEX IF NOT EXISTS idx_broker_join_tokens_expires ON broker_join_tokens(expires_at); +` + +// Migration V10: Add user tracking to grove_contributors and runtime_brokers +const migrationV10 = ` +-- Add linked_by and linked_at columns to grove_contributors for tracking who linked a broker +ALTER TABLE grove_contributors ADD COLUMN linked_by TEXT; +ALTER TABLE grove_contributors ADD COLUMN linked_at TIMESTAMP; + +-- Add created_by column to runtime_brokers for tracking who registered the broker +ALTER TABLE runtime_brokers ADD COLUMN created_by TEXT; +` + +// Migration V11: Add auto_provide column to runtime_brokers +const migrationV11 = ` +-- Add auto_provide column to runtime_brokers for automatic project provider registration +ALTER TABLE runtime_brokers ADD COLUMN auto_provide INTEGER NOT NULL DEFAULT 0; +` + +// Migration V12: Add injection_mode and secret columns to env_vars +const migrationV12 = ` +ALTER TABLE env_vars ADD COLUMN injection_mode TEXT NOT NULL DEFAULT 'as_needed'; +ALTER TABLE env_vars ADD COLUMN secret INTEGER NOT NULL DEFAULT 0; +` + +const migrationV13 = ` +ALTER TABLE secrets ADD COLUMN secret_type TEXT NOT NULL DEFAULT 'environment'; +ALTER TABLE secrets ADD COLUMN target TEXT; +` + +const migrationV14 = ` +ALTER TABLE secrets ADD COLUMN secret_ref TEXT; +` + +const migrationV15 = ` +UPDATE agents SET status = session_status WHERE session_status IS NOT NULL AND session_status != ''; +ALTER TABLE agents DROP COLUMN session_status; +` + +// Migration V16: Add harness_configs table +const migrationV16 = ` +CREATE TABLE IF NOT EXISTS harness_configs ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + slug TEXT NOT NULL, + display_name TEXT, + description TEXT, + harness TEXT NOT NULL, + config TEXT, + content_hash TEXT, + scope TEXT NOT NULL DEFAULT 'global', + scope_id TEXT, + storage_uri TEXT, + storage_bucket TEXT, + storage_path TEXT, + files TEXT, + locked INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL DEFAULT 'active', + owner_id TEXT, + created_by TEXT, + updated_by TEXT, + visibility TEXT NOT NULL DEFAULT 'private', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE INDEX IF NOT EXISTS idx_harness_configs_slug_scope ON harness_configs(slug, scope); +CREATE INDEX IF NOT EXISTS idx_harness_configs_harness ON harness_configs(harness); +CREATE INDEX IF NOT EXISTS idx_harness_configs_status ON harness_configs(status); +CREATE INDEX IF NOT EXISTS idx_harness_configs_content_hash ON harness_configs(content_hash); +CREATE INDEX IF NOT EXISTS idx_harness_configs_scope_id ON harness_configs(scope, scope_id); +` + +// Migration V17: Add deleted_at column to agents for soft-delete support +const migrationV17 = ` +ALTER TABLE agents ADD COLUMN deleted_at TIMESTAMP; +CREATE INDEX IF NOT EXISTS idx_agents_deleted ON agents(status, deleted_at) WHERE status = 'deleted'; +` + +// Migration V18: Notification subscriptions and notifications tables +const migrationV18 = ` +CREATE TABLE IF NOT EXISTS notification_subscriptions ( + id TEXT PRIMARY KEY, + agent_id TEXT NOT NULL, + subscriber_type TEXT NOT NULL DEFAULT 'agent', + subscriber_id TEXT NOT NULL, + grove_id TEXT NOT NULL, + trigger_statuses TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_by TEXT NOT NULL, + FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS idx_notification_subs_agent ON notification_subscriptions(agent_id); +CREATE INDEX IF NOT EXISTS idx_notification_subs_project ON notification_subscriptions(grove_id); + +CREATE TABLE IF NOT EXISTS notifications ( + id TEXT PRIMARY KEY, + subscription_id TEXT NOT NULL, + agent_id TEXT NOT NULL, + grove_id TEXT NOT NULL, + subscriber_type TEXT NOT NULL, + subscriber_id TEXT NOT NULL, + status TEXT NOT NULL, + message TEXT NOT NULL, + dispatched INTEGER NOT NULL DEFAULT 0, + acknowledged INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (subscription_id) REFERENCES notification_subscriptions(id) ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS idx_notifications_subscriber ON notifications(subscriber_type, subscriber_id); +CREATE INDEX IF NOT EXISTS idx_notifications_project ON notifications(grove_id); +` + +const migrationV19 = ` +CREATE TABLE IF NOT EXISTS scheduled_events ( + id TEXT PRIMARY KEY, + grove_id TEXT NOT NULL, + event_type TEXT NOT NULL, + fire_at TIMESTAMP NOT NULL, + payload TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_by TEXT, + fired_at TIMESTAMP, + error TEXT, + + FOREIGN KEY (grove_id) REFERENCES groves(id) ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS idx_scheduled_events_status ON scheduled_events(status); +CREATE INDEX IF NOT EXISTS idx_scheduled_events_fire_at ON scheduled_events(fire_at) WHERE status = 'pending'; +CREATE INDEX IF NOT EXISTS idx_scheduled_events_project ON scheduled_events(grove_id); +` + +const migrationV20 = ` +ALTER TABLE agents ADD COLUMN phase TEXT NOT NULL DEFAULT 'created'; +ALTER TABLE agents ADD COLUMN activity TEXT DEFAULT ''; +ALTER TABLE agents ADD COLUMN tool_name TEXT DEFAULT ''; + +-- Backfill phase/activity from existing status values +UPDATE agents SET phase = 'created' WHERE status IN ('created', 'pending'); +UPDATE agents SET phase = 'provisioning' WHERE status = 'provisioning'; +UPDATE agents SET phase = 'cloning' WHERE status = 'cloning'; +UPDATE agents SET phase = 'running', activity = 'idle' WHERE status = 'running'; +UPDATE agents SET phase = 'stopped' WHERE status = 'stopped'; +UPDATE agents SET phase = 'error' WHERE status = 'error'; +UPDATE agents SET phase = 'running', activity = 'thinking' WHERE status = 'busy'; +UPDATE agents SET phase = 'running', activity = 'idle' WHERE status = 'idle'; +UPDATE agents SET phase = 'running', activity = 'waiting_for_input' WHERE status = 'waiting_for_input'; +UPDATE agents SET phase = 'running', activity = 'completed' WHERE status = 'completed'; +UPDATE agents SET phase = 'running', activity = 'limits_exceeded' WHERE status = 'limits_exceeded'; +UPDATE agents SET phase = 'stopped' WHERE status IN ('deleted', 'restored'); +UPDATE agents SET phase = 'running', activity = 'offline' WHERE status = 'undetermined'; + +CREATE INDEX IF NOT EXISTS idx_agents_phase ON agents(phase); +` + +// Migration V21: Remove legacy status column from agents table. +// Phase 6 of the agent state refactor — the status column is superseded by +// the phase/activity columns added in V20. +const migrationV21 = ` +-- Backfill any remaining agents where phase was not set +UPDATE agents SET phase = status WHERE (phase = '' OR phase IS NULL) AND status IN ('created','provisioning','cloning','starting','running','stopping','stopped','error'); +UPDATE agents SET phase = 'created' WHERE (phase = '' OR phase IS NULL) AND status = 'pending'; +UPDATE agents SET phase = 'stopped' WHERE (phase = '' OR phase IS NULL) AND status = 'deleted'; + +-- Backfill activity from status for running agents +UPDATE agents SET activity = status WHERE phase = 'running' AND (activity = '' OR activity IS NULL) AND status IN ('idle','waiting_for_input','completed','limits_exceeded','offline'); +UPDATE agents SET activity = 'thinking' WHERE phase = 'running' AND (activity = '' OR activity IS NULL) AND status = 'busy'; + +-- Update soft-delete index: rely on deleted_at instead of status +DROP INDEX IF EXISTS idx_agents_deleted; +CREATE INDEX IF NOT EXISTS idx_agents_deleted ON agents(deleted_at) WHERE deleted_at IS NOT NULL; + +-- Drop the status index before dropping the column +DROP INDEX IF EXISTS idx_agents_status; + +-- Drop the status column (SQLite supports this from 3.35.0+) +ALTER TABLE agents DROP COLUMN status; +` + +// Migration V22: Rename trigger_statuses to trigger_activities in notification_subscriptions. +const migrationV22 = ` +ALTER TABLE notification_subscriptions RENAME COLUMN trigger_statuses TO trigger_activities; +` + +// Migration V23: Add injection_mode column to secrets +const migrationV23 = ` +ALTER TABLE secrets ADD COLUMN injection_mode TEXT NOT NULL DEFAULT 'as_needed'; +` + +// Migration V24: Add last_activity_event column to agents for stalled detection. +// Backfills existing agents to prevent false positives on upgrade. +const migrationV24 = ` +ALTER TABLE agents ADD COLUMN last_activity_event TIMESTAMP; +UPDATE agents SET last_activity_event = COALESCE(last_seen, updated_at, created_at); +` + +// Migration V25: Add stalled_from_activity column for stalled detection. +// Records the activity that was active when the agent was marked stalled, +// so heartbeats can distinguish "still stuck" from "genuinely recovered". +const migrationV25 = ` +ALTER TABLE agents ADD COLUMN stalled_from_activity TEXT DEFAULT ''; +` + +// Migration V26: Add limits tracking columns to agents table. +// These fields are updated by sciontool status reports from inside the container. +const migrationV26 = ` +ALTER TABLE agents ADD COLUMN current_turns INTEGER DEFAULT 0; +ALTER TABLE agents ADD COLUMN current_model_calls INTEGER DEFAULT 0; +ALTER TABLE agents ADD COLUMN started_at TIMESTAMP; +` + +const migrationV27 = ` +ALTER TABLE users ADD COLUMN last_seen TIMESTAMP; +` + +// Migration V28: Add shared_dirs column to groves table. +// Stores project-level shared directory configuration as JSON. +const migrationV28 = ` +ALTER TABLE groves ADD COLUMN shared_dirs TEXT DEFAULT ''; +` + +// Migration V29: Add group_type and grove_id columns to groups table. +// These enable filtering groups by type and project association. +const migrationV29 = ` +ALTER TABLE groups ADD COLUMN group_type TEXT NOT NULL DEFAULT 'explicit'; +ALTER TABLE groups ADD COLUMN grove_id TEXT DEFAULT ''; +CREATE INDEX IF NOT EXISTS idx_groups_project ON groups(grove_id); +` + +// Migration V30: Create gcp_service_accounts table for GCP identity management. +const migrationV30 = ` +CREATE TABLE IF NOT EXISTS gcp_service_accounts ( + id TEXT PRIMARY KEY, + scope TEXT NOT NULL, + scope_id TEXT NOT NULL, + email TEXT NOT NULL, + grove_id TEXT NOT NULL, + display_name TEXT NOT NULL DEFAULT '', + default_scopes TEXT NOT NULL DEFAULT '', + verified INTEGER NOT NULL DEFAULT 0, + verified_at TIMESTAMP, + created_by TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(email, scope, scope_id) +); +CREATE INDEX IF NOT EXISTS idx_gcp_sa_scope ON gcp_service_accounts(scope, scope_id); +` + +// Migration V31: Add scope column to notification_subscriptions and make agent_id nullable. +// Enables project-scoped subscriptions (watch all agents in a project) in addition to +// agent-scoped subscriptions. Adds unique constraint for deduplication. +const migrationV31 = ` +-- Postgres supports ALTER TABLE directly, so we recreate the table as in SQLite source. +CREATE TABLE notification_subscriptions_new ( + id TEXT PRIMARY KEY, + scope TEXT NOT NULL DEFAULT 'agent', + agent_id TEXT, + subscriber_type TEXT NOT NULL DEFAULT 'agent', + subscriber_id TEXT NOT NULL, + grove_id TEXT NOT NULL, + trigger_activities TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_by TEXT NOT NULL, + FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE CASCADE +); + +-- Copy existing data (all existing subscriptions are agent-scoped) +INSERT INTO notification_subscriptions_new + (id, scope, agent_id, subscriber_type, subscriber_id, grove_id, trigger_activities, created_at, created_by) +SELECT id, 'agent', agent_id, subscriber_type, subscriber_id, grove_id, trigger_activities, created_at, created_by +FROM notification_subscriptions; + +DROP TABLE notification_subscriptions CASCADE; +ALTER TABLE notification_subscriptions_new RENAME TO notification_subscriptions; + +-- Recreate indexes +CREATE INDEX IF NOT EXISTS idx_notification_subs_agent ON notification_subscriptions(agent_id); +CREATE INDEX IF NOT EXISTS idx_notification_subs_project ON notification_subscriptions(grove_id); +CREATE INDEX IF NOT EXISTS idx_notification_subs_subscriber ON notification_subscriptions(subscriber_type, subscriber_id); + +-- Unique constraint: one subscription per (scope, target, subscriber, project) +CREATE UNIQUE INDEX IF NOT EXISTS idx_notification_subs_unique + ON notification_subscriptions(scope, COALESCE(agent_id, ''), subscriber_type, subscriber_id, grove_id); +` + +// Migration V32: Recurring schedules table and schedule_id FK on scheduled_events. +const migrationV32 = ` +CREATE TABLE IF NOT EXISTS schedules ( + id TEXT PRIMARY KEY, + grove_id TEXT NOT NULL, + name TEXT NOT NULL, + cron_expr TEXT NOT NULL, + event_type TEXT NOT NULL, + payload TEXT NOT NULL DEFAULT '{}', + status TEXT NOT NULL DEFAULT 'active', + next_run_at TIMESTAMP, + last_run_at TIMESTAMP, + last_run_status TEXT, + last_run_error TEXT, + run_count INTEGER NOT NULL DEFAULT 0, + error_count INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_by TEXT, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (grove_id) REFERENCES groves(id) ON DELETE CASCADE, + UNIQUE(grove_id, name) +); +CREATE INDEX IF NOT EXISTS idx_schedules_project ON schedules(grove_id); +CREATE INDEX IF NOT EXISTS idx_schedules_next_run ON schedules(next_run_at) WHERE status = 'active'; + +ALTER TABLE scheduled_events ADD COLUMN schedule_id TEXT DEFAULT ''; +` + +// Migration V33: Subscription templates table. +const migrationV33 = ` +CREATE TABLE IF NOT EXISTS subscription_templates ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + scope TEXT NOT NULL DEFAULT 'project', + trigger_activities TEXT NOT NULL, + grove_id TEXT NOT NULL DEFAULT '', + created_by TEXT NOT NULL, + UNIQUE(grove_id, name) +); +CREATE INDEX IF NOT EXISTS idx_sub_templates_project ON subscription_templates(grove_id); +` + +// Migration V34: User access tokens table (replaces api_keys). +const migrationV34 = ` +CREATE TABLE IF NOT EXISTS user_access_tokens ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + name TEXT NOT NULL, + prefix TEXT NOT NULL, + key_hash TEXT NOT NULL UNIQUE, + grove_id TEXT NOT NULL, + scopes TEXT NOT NULL, + revoked INTEGER NOT NULL DEFAULT 0, + expires_at TIMESTAMP, + last_used TIMESTAMP, + created_at TIMESTAMP NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + FOREIGN KEY (grove_id) REFERENCES groves(id) ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS idx_uat_user_id ON user_access_tokens(user_id); +CREATE INDEX IF NOT EXISTS idx_uat_key_hash ON user_access_tokens(key_hash); +` + +// Migration V35: GitHub App installations and project GitHub App fields. +const migrationV35 = ` +CREATE TABLE IF NOT EXISTS github_installations ( + installation_id BIGINT PRIMARY KEY, + account_login TEXT NOT NULL, + account_type TEXT NOT NULL DEFAULT 'Organization', + app_id INTEGER NOT NULL, + repositories TEXT NOT NULL DEFAULT '[]', + status TEXT NOT NULL DEFAULT 'active', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE INDEX IF NOT EXISTS idx_github_installations_account ON github_installations(account_login); +CREATE INDEX IF NOT EXISTS idx_github_installations_status ON github_installations(status); + +ALTER TABLE groves ADD COLUMN github_installation_id INTEGER; +ALTER TABLE groves ADD COLUMN github_permissions TEXT; +ALTER TABLE groves ADD COLUMN github_app_status TEXT; +` + +// Migration V36: Git identity configuration for commit attribution. +const migrationV36 = ` +ALTER TABLE groves ADD COLUMN git_identity TEXT; +` + +// Migration V37: Add ancestry column for transitive access control. +const migrationV37 = ` +ALTER TABLE agents ADD COLUMN ancestry TEXT; +` + +// Migration V38: Backfill ancestry for existing agents from created_by. +const migrationV38 = ` +UPDATE agents SET ancestry = json_build_array(created_by)::text +WHERE created_by IS NOT NULL AND created_by != '' AND ancestry IS NULL; +` + +// Migration V39: Messages table for bidirectional human-agent messaging. +const migrationV39 = ` +CREATE TABLE IF NOT EXISTS messages ( + id TEXT PRIMARY KEY, + grove_id TEXT NOT NULL, + sender TEXT NOT NULL, + sender_id TEXT NOT NULL DEFAULT '', + recipient TEXT NOT NULL, + recipient_id TEXT NOT NULL DEFAULT '', + msg TEXT NOT NULL, + type TEXT NOT NULL DEFAULT 'instruction', + urgent INTEGER NOT NULL DEFAULT 0, + broadcasted INTEGER NOT NULL DEFAULT 0, + read INTEGER NOT NULL DEFAULT 0, + agent_id TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX IF NOT EXISTS idx_messages_project ON messages(grove_id); +CREATE INDEX IF NOT EXISTS idx_messages_recipient ON messages(recipient_id, read); +CREATE INDEX IF NOT EXISTS idx_messages_agent ON messages(agent_id); +CREATE INDEX IF NOT EXISTS idx_messages_sender ON messages(sender_id); +CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(created_at DESC); +` + +// Migration V40: Allow multiple groves per git remote (drop UNIQUE on git_remote), +// and enforce slug uniqueness (add UNIQUE on slug). Requires table recreation +// because SQLite does not support ALTER TABLE DROP CONSTRAINT. +// +// IMPORTANT: This migration requires foreign_keys=OFF around the DROP TABLE. +// SQLite ignores PRAGMA changes inside transactions, so the migration runner +// handles this via the foreignKeysOffMigrations set. The PRAGMA statements are +// intentionally NOT included in the SQL string. +const migrationV40 = ` +CREATE TABLE IF NOT EXISTS groves_new ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + slug TEXT NOT NULL UNIQUE, + git_remote TEXT, + labels TEXT, + annotations TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_by TEXT, + owner_id TEXT, + visibility TEXT NOT NULL DEFAULT 'private', + default_runtime_broker_id TEXT REFERENCES runtime_brokers(id) ON DELETE SET NULL, + shared_dirs TEXT, + github_installation_id INTEGER REFERENCES github_installations(installation_id), + github_permissions TEXT, + github_app_status TEXT, + git_identity TEXT +); + +INSERT INTO groves_new SELECT + id, name, slug, git_remote, labels, annotations, + created_at, updated_at, created_by, owner_id, visibility, + default_runtime_broker_id, shared_dirs, + github_installation_id, github_permissions, github_app_status, + git_identity +FROM groves ON CONFLICT DO NOTHING; + +DROP TABLE IF EXISTS groves CASCADE; +ALTER TABLE groves_new RENAME TO groves; + +CREATE INDEX IF NOT EXISTS idx_groves_slug ON groves(slug); +CREATE INDEX IF NOT EXISTS idx_groves_git_remote ON groves(git_remote); +CREATE INDEX IF NOT EXISTS idx_groves_owner ON groves(owner_id); +CREATE INDEX IF NOT EXISTS idx_groves_default_runtime_broker ON groves(default_runtime_broker_id); +` + +// Migration V41: Maintenance operations tables for the admin maintenance panel. +// Tracks one-time migrations and repeatable operations with execution history. +const migrationV41 = ` +CREATE TABLE IF NOT EXISTS maintenance_operations ( + id TEXT PRIMARY KEY, + key TEXT NOT NULL UNIQUE, + title TEXT NOT NULL, + description TEXT NOT NULL DEFAULT '', + category TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + started_at TIMESTAMP, + completed_at TIMESTAMP, + started_by TEXT, + result TEXT, + metadata TEXT NOT NULL DEFAULT '{}' +); +CREATE INDEX IF NOT EXISTS idx_maintenance_ops_category ON maintenance_operations(category); +CREATE INDEX IF NOT EXISTS idx_maintenance_ops_status ON maintenance_operations(status); + +CREATE TABLE IF NOT EXISTS maintenance_operation_runs ( + id TEXT PRIMARY KEY, + operation_key TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'running', + started_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP, + started_by TEXT, + result TEXT, + log TEXT NOT NULL DEFAULT '', + FOREIGN KEY (operation_key) REFERENCES maintenance_operations(key) +); +CREATE INDEX IF NOT EXISTS idx_maintenance_runs_key ON maintenance_operation_runs(operation_key); +CREATE INDEX IF NOT EXISTS idx_maintenance_runs_started ON maintenance_operation_runs(started_at DESC); + +-- Seed: one-time migrations +INSERT INTO maintenance_operations (id, key, title, description, category, status) +VALUES ( + gen_random_uuid()::text, + 'secret-hub-id-migration', + 'Secret Hub ID Namespace Migration', + 'Migrates hub-scoped secrets from the legacy fixed "hub" scope ID to the per-instance hub ID. Required when upgrading a hub that was created before the hub ID namespacing feature. Only needed for GCP Secret Manager backend.', + 'migration', + 'pending' +); + +-- Seed: repeatable operations +INSERT INTO maintenance_operations (id, key, title, description, category, status) +VALUES ( + gen_random_uuid()::text, + 'pull-images', + 'Pull Container Images', + 'Pulls the latest container images for all configured harnesses from the image registry.', + 'operation', + 'pending' +); + +INSERT INTO maintenance_operations (id, key, title, description, category, status) +VALUES ( + gen_random_uuid()::text, + 'rebuild-server', + 'Rebuild Server from Git', + 'Pulls latest code from the repository, rebuilds the server binary and web assets, then restarts the hub service. Equivalent to the fast-deploy mode of gce-start-hub.sh.', + 'operation', + 'pending' +); + +INSERT INTO maintenance_operations (id, key, title, description, category, status) +VALUES ( + gen_random_uuid()::text, + 'rebuild-web', + 'Rebuild Web Frontend', + 'Rebuilds only the web frontend assets from source without restarting the server binary. Changes take effect on the next page load.', + 'operation', + 'pending' +); +` + +const migrationV42 = ` +CREATE TABLE IF NOT EXISTS grove_sync_state ( + grove_id TEXT NOT NULL, + broker_id TEXT NOT NULL DEFAULT '', + last_sync_time TIMESTAMP, + last_commit_sha TEXT, + file_count INTEGER NOT NULL DEFAULT 0, + total_bytes INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY (grove_id, broker_id), + FOREIGN KEY (grove_id) REFERENCES groves(id) ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS idx_grove_sync_state_project ON grove_sync_state(grove_id); +` + +// migrationV43 fixes pre-existing signing key secrets that were stored with +// the default secret_type ('environment' or ") instead of 'internal'. Without +// this, stale rows created before the fix would still be resolved and injected +// into agent containers. +const migrationV43 = ` +UPDATE secrets SET secret_type = 'internal' +WHERE key IN ('agent_signing_key', 'user_signing_key') + AND scope = 'hub' + AND secret_type != 'internal'; +` + +// Migration V44: Add managed and managed_by columns to gcp_service_accounts table +// for hub-minted service accounts. +const migrationV44 = ` +ALTER TABLE gcp_service_accounts ADD COLUMN managed INTEGER NOT NULL DEFAULT 0; +ALTER TABLE gcp_service_accounts ADD COLUMN managed_by TEXT NOT NULL DEFAULT ''; +` + +// Migration V45: Add allow_progeny column to secrets table +const migrationV45 = ` +ALTER TABLE secrets ADD COLUMN allow_progeny INTEGER NOT NULL DEFAULT 0; +` + +const migrationV46 = ` +ALTER TABLE templates ADD COLUMN default_harness_config TEXT; +` + +const migrationV47 = ` +INSERT INTO maintenance_operations (id, key, title, description, category, status) +VALUES ( + gen_random_uuid()::text, + 'rebuild-container-binaries', + 'Rebuild Container Binaries', + 'Rebuilds scion and sciontool binaries for Linux containers (make container-binaries). Only available when SCION_DEV_BINARIES is set. Binaries are written to .build/container/ in the source checkout.', + 'operation', + 'pending' +); +` + +const migrationV48 = ` +CREATE TABLE IF NOT EXISTS allow_list ( + id TEXT PRIMARY KEY, + email TEXT NOT NULL, + note TEXT NOT NULL DEFAULT '', + added_by TEXT NOT NULL, + invite_id TEXT NOT NULL DEFAULT '', + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE UNIQUE INDEX IF NOT EXISTS allow_list_email_unique ON allow_list (LOWER(email)); +` + +const migrationV49 = ` +CREATE TABLE IF NOT EXISTS invite_codes ( + id TEXT PRIMARY KEY, + code_hash TEXT NOT NULL UNIQUE, + code_prefix TEXT NOT NULL, + max_uses INTEGER NOT NULL DEFAULT 1, + use_count INTEGER NOT NULL DEFAULT 0, + expires_at TIMESTAMP NOT NULL, + revoked INTEGER NOT NULL DEFAULT 0, + created_by TEXT NOT NULL, + note TEXT NOT NULL DEFAULT '', + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE INDEX IF NOT EXISTS idx_invite_codes_expires ON invite_codes(expires_at); +` + +// migrateV50 renames 'grove' entities to 'project' idempotently. +// This is Phase 4 of the grove-to-project rename strategy. +// Each rename operation checks whether the old name still exists before +// attempting the rename, so the migration can be re-run safely on databases +// that partially applied an earlier (non-idempotent) version of V50. +func migrateV50(ctx context.Context, tx *sql.Tx) error { + // 1. Rename Tables (check before renaming) + tableRenames := [][2]string{ + {"groves", "projects"}, + {"grove_contributors", "project_contributors"}, + {"grove_sync_state", "project_sync_state"}, + } + for _, r := range tableRenames { + exists, err := tableExists(ctx, tx, r[0]) + if err != nil { + return fmt.Errorf("checking table %s: %w", r[0], err) + } + if exists { + if _, err := tx.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s RENAME TO %s", r[0], r[1])); err != nil { + return fmt.Errorf("renaming table %s to %s: %w", r[0], r[1], err) + } + } + } + + // 2. Rename Columns (check before renaming) + // After step 1, tables are at their new names. If step 1 was already + // applied in a prior run, the tables are also at their new names. + columnRenames := [][3]string{ + {"project_contributors", "grove_id", "project_id"}, + {"project_sync_state", "grove_id", "project_id"}, + {"agents", "grove_id", "project_id"}, + {"templates", "grove_id", "project_id"}, + {"notification_subscriptions", "grove_id", "project_id"}, + {"notifications", "grove_id", "project_id"}, + {"scheduled_events", "grove_id", "project_id"}, + {"schedules", "grove_id", "project_id"}, + {"subscription_templates", "grove_id", "project_id"}, + {"user_access_tokens", "grove_id", "project_id"}, + {"messages", "grove_id", "project_id"}, + {"groups", "grove_id", "project_id"}, + {"gcp_service_accounts", "grove_id", "project_id"}, + } + for _, r := range columnRenames { + exists, err := columnExists(ctx, tx, r[0], r[1]) + if err != nil { + return fmt.Errorf("checking column %s.%s: %w", r[0], r[1], err) + } + if exists { + if _, err := tx.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", r[0], r[1], r[2])); err != nil { + return fmt.Errorf("renaming column %s.%s to %s: %w", r[0], r[1], r[2], err) + } + } + } + + // 3. Update Data Values (already idempotent — UPDATE WHERE is a no-op + // when the old value no longer exists) + dataUpdates := ` +UPDATE env_vars SET scope = 'project' WHERE scope = 'grove'; +UPDATE secrets SET scope = 'project' WHERE scope = 'grove'; +UPDATE policies SET scope_type = 'project' WHERE scope_type = 'grove'; +UPDATE gcp_service_accounts SET scope = 'project' WHERE scope = 'grove'; +UPDATE groups SET group_type = 'project_agents' WHERE group_type = 'grove_agents'; +UPDATE notification_subscriptions SET scope = 'project' WHERE scope = 'grove'; +UPDATE subscription_templates SET scope = 'project' WHERE scope = 'grove'; +UPDATE templates SET scope = 'project' WHERE scope = 'grove'; +UPDATE harness_configs SET scope = 'project' WHERE scope = 'grove'; +` + if _, err := tx.ExecContext(ctx, dataUpdates); err != nil { + return fmt.Errorf("updating data values: %w", err) + } + + // 4. Rename/Recreate Indexes (already idempotent — DROP IF EXISTS / CREATE IF NOT EXISTS) + indexSQL := ` +DROP INDEX IF EXISTS idx_groves_slug; +CREATE UNIQUE INDEX IF NOT EXISTS idx_projects_slug ON projects(slug); +DROP INDEX IF EXISTS idx_groves_git_remote; +CREATE INDEX IF NOT EXISTS idx_projects_git_remote ON projects(git_remote); +DROP INDEX IF EXISTS idx_groves_owner; +CREATE INDEX IF NOT EXISTS idx_projects_owner ON projects(owner_id); +DROP INDEX IF EXISTS idx_groves_default_runtime_broker; +CREATE INDEX IF NOT EXISTS idx_projects_default_runtime_broker ON projects(default_runtime_broker_id); + +DROP INDEX IF EXISTS idx_agents_grove_slug; +DROP INDEX IF EXISTS idx_agents_project_slug; +CREATE UNIQUE INDEX IF NOT EXISTS idx_agents_project_slug ON agents(agent_id, project_id); +DROP INDEX IF EXISTS idx_agents_grove; +CREATE INDEX IF NOT EXISTS idx_agents_project ON agents(project_id); + +DROP INDEX IF EXISTS idx_grove_sync_state_grove; +CREATE INDEX IF NOT EXISTS idx_project_sync_state_project ON project_sync_state(project_id); + +DROP INDEX IF EXISTS idx_notification_subs_grove; +CREATE INDEX IF NOT EXISTS idx_notification_subs_project ON notification_subscriptions(project_id); + +DROP INDEX IF EXISTS idx_notifications_grove; +CREATE INDEX IF NOT EXISTS idx_notifications_project ON notifications(project_id); + +DROP INDEX IF EXISTS idx_scheduled_events_grove; +CREATE INDEX IF NOT EXISTS idx_scheduled_events_project ON scheduled_events(project_id); + +DROP INDEX IF EXISTS idx_schedules_grove; +CREATE INDEX IF NOT EXISTS idx_schedules_project ON schedules(project_id); + +DROP INDEX IF EXISTS idx_sub_templates_grove; +CREATE INDEX IF NOT EXISTS idx_sub_templates_project ON subscription_templates(project_id); + +DROP INDEX IF EXISTS idx_messages_grove; +CREATE INDEX IF NOT EXISTS idx_messages_project ON messages(project_id); + +DROP INDEX IF EXISTS idx_groups_grove; +CREATE INDEX IF NOT EXISTS idx_groups_project ON groups(project_id); + +DROP INDEX IF EXISTS idx_gcp_sa_grove; +CREATE INDEX IF NOT EXISTS idx_gcp_sa_project ON gcp_service_accounts(project_id); +` + if _, err := tx.ExecContext(ctx, indexSQL); err != nil { + return fmt.Errorf("updating indexes: %w", err) + } + + return nil +} + +// migrationV51 adds group_id to messages for correlating set[] deliveries. +const migrationV51 = ` +ALTER TABLE messages ADD COLUMN group_id TEXT NOT NULL DEFAULT ''; +` + +// migrationV52 renames the idle activity to working for clearer agent state reporting. +const migrationV52 = ` +UPDATE agents SET activity = 'working' WHERE activity = 'idle'; +UPDATE agents SET stalled_from_activity = 'working' WHERE stalled_from_activity = 'idle'; +` + +// migrationV53 adds an index on (created, id) to allow_list for efficient keyset pagination. +// It also ensures the allow_list table exists, because databases created before V48/V49 were +// inserted into the migration sequence already have version 48 recorded with different content +// (the grove-to-project rename that is now V50). On those databases V48 is skipped, so the +// allow_list table was never created. +const migrationV53 = ` +CREATE TABLE IF NOT EXISTS allow_list ( + id TEXT PRIMARY KEY, + email TEXT NOT NULL, + note TEXT NOT NULL DEFAULT '', + added_by TEXT NOT NULL, + invite_id TEXT NOT NULL DEFAULT '', + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE UNIQUE INDEX IF NOT EXISTS allow_list_email_unique ON allow_list (LOWER(email)); +CREATE TABLE IF NOT EXISTS invite_codes ( + id TEXT PRIMARY KEY, + code_hash TEXT NOT NULL UNIQUE, + code_prefix TEXT NOT NULL, + max_uses INTEGER NOT NULL DEFAULT 1, + use_count INTEGER NOT NULL DEFAULT 0, + expires_at TIMESTAMP NOT NULL, + revoked INTEGER NOT NULL DEFAULT 0, + created_by TEXT NOT NULL, + note TEXT NOT NULL DEFAULT '', + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE INDEX IF NOT EXISTS idx_invite_codes_expires ON invite_codes(expires_at); +CREATE INDEX IF NOT EXISTS idx_allow_list_created_id ON allow_list (created DESC, id DESC); +` + +// tableExists checks whether a table with the given name exists in the database. +func tableExists(ctx context.Context, tx *sql.Tx, tableName string) (bool, error) { + var name string + err := tx.QueryRowContext(ctx, + "SELECT table_name FROM information_schema.tables WHERE table_name=$1 AND table_schema='public'", tableName, + ).Scan(&name) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + if err != nil { + return false, err + } + return true, nil +} + +// columnExists checks whether a column with the given name exists in the specified table. +func columnExists(ctx context.Context, tx *sql.Tx, tableName, columnName string) (bool, error) { + var name string + err := tx.QueryRowContext(ctx, + "SELECT column_name FROM information_schema.columns WHERE table_name=$1 AND column_name=$2 AND table_schema='public'", + tableName, columnName, + ).Scan(&name) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + if err != nil { + return false, err + } + return true, nil +} diff --git a/pkg/store/postgres/notification.go b/pkg/store/postgres/notification.go new file mode 100644 index 00000000..e1890399 --- /dev/null +++ b/pkg/store/postgres/notification.go @@ -0,0 +1,553 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package postgres provides a Postgres implementation of the Store interface. +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +// ============================================================================ +// Notification Subscription Operations +// ============================================================================ + +// CreateNotificationSubscription creates a new notification subscription. +func (s *PostgresStore) CreateNotificationSubscription(ctx context.Context, sub *store.NotificationSubscription) error { + if sub.ID == "" || sub.SubscriberID == "" || sub.ProjectID == "" { + return store.ErrInvalidInput + } + + // Default scope to agent for backward compatibility + if sub.Scope == "" { + sub.Scope = store.SubscriptionScopeAgent + } + + // Validate scope-specific constraints + switch sub.Scope { + case store.SubscriptionScopeAgent: + if sub.AgentID == "" { + return store.ErrInvalidInput + } + case store.SubscriptionScopeProject: + sub.AgentID = "" // Ensure no agent_id for project-scoped + default: + return fmt.Errorf("invalid scope %q: %w", sub.Scope, store.ErrInvalidInput) + } + + now := time.Now() + if sub.CreatedAt.IsZero() { + sub.CreatedAt = now + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO notification_subscriptions ( + id, scope, agent_id, subscriber_type, subscriber_id, project_id, + trigger_activities, created_at, created_by + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + `, + sub.ID, sub.Scope, nullableString(sub.AgentID), sub.SubscriberType, sub.SubscriberID, sub.ProjectID, + marshalJSON(sub.TriggerActivities), sub.CreatedAt, sub.CreatedBy, + ) + if err != nil { + if strings.Contains(err.Error(), "unique constraint") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + if strings.Contains(err.Error(), "foreign key constraint") { + return fmt.Errorf("agent %s does not exist: %w", sub.AgentID, store.ErrInvalidInput) + } + return err + } + return nil +} + +// GetNotificationSubscription returns a single subscription by ID. +func (s *PostgresStore) GetNotificationSubscription(ctx context.Context, id string) (*store.NotificationSubscription, error) { + row := s.db.QueryRowContext(ctx, ` + SELECT id, scope, agent_id, subscriber_type, subscriber_id, project_id, + trigger_activities, created_at, created_by + FROM notification_subscriptions + WHERE id = $1 + `, id) + + var sub store.NotificationSubscription + var agentID sql.NullString + var triggerActivitiesJSON string + + if err := row.Scan( + &sub.ID, &sub.Scope, &agentID, &sub.SubscriberType, &sub.SubscriberID, &sub.ProjectID, + &triggerActivitiesJSON, &sub.CreatedAt, &sub.CreatedBy, + ); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + if agentID.Valid { + sub.AgentID = agentID.String + } + unmarshalJSON(triggerActivitiesJSON, &sub.TriggerActivities) + return &sub, nil +} + +// GetNotificationSubscriptions returns all agent-scoped subscriptions for a watched agent. +func (s *PostgresStore) GetNotificationSubscriptions(ctx context.Context, agentID string) ([]store.NotificationSubscription, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, scope, agent_id, subscriber_type, subscriber_id, project_id, + trigger_activities, created_at, created_by + FROM notification_subscriptions + WHERE agent_id = $1 + ORDER BY created_at ASC + `, agentID) + if err != nil { + return nil, err + } + defer rows.Close() + + return scanSubscriptions(rows) +} + +// GetNotificationSubscriptionsByProject returns all subscriptions within a project (any scope). +func (s *PostgresStore) GetNotificationSubscriptionsByProject(ctx context.Context, projectID string) ([]store.NotificationSubscription, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, scope, agent_id, subscriber_type, subscriber_id, project_id, + trigger_activities, created_at, created_by + FROM notification_subscriptions + WHERE project_id = $1 + ORDER BY created_at ASC + `, projectID) + if err != nil { + return nil, err + } + defer rows.Close() + + return scanSubscriptions(rows) +} + +// GetNotificationSubscriptionsByProjectScope returns project-scoped subscriptions +// (scope='project') for a given project. +func (s *PostgresStore) GetNotificationSubscriptionsByProjectScope(ctx context.Context, projectID string) ([]store.NotificationSubscription, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, scope, agent_id, subscriber_type, subscriber_id, project_id, + trigger_activities, created_at, created_by + FROM notification_subscriptions + WHERE project_id = $1 AND scope = 'project' + ORDER BY created_at ASC + `, projectID) + if err != nil { + return nil, err + } + defer rows.Close() + + return scanSubscriptions(rows) +} + +// GetSubscriptionsForSubscriber returns all subscriptions owned by a subscriber. +func (s *PostgresStore) GetSubscriptionsForSubscriber(ctx context.Context, subscriberType, subscriberID string) ([]store.NotificationSubscription, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, scope, agent_id, subscriber_type, subscriber_id, project_id, + trigger_activities, created_at, created_by + FROM notification_subscriptions + WHERE subscriber_type = $1 AND subscriber_id = $2 + ORDER BY created_at ASC + `, subscriberType, subscriberID) + if err != nil { + return nil, err + } + defer rows.Close() + + return scanSubscriptions(rows) +} + +// UpdateNotificationSubscriptionTriggers updates the trigger activities of a subscription. +func (s *PostgresStore) UpdateNotificationSubscriptionTriggers(ctx context.Context, id string, triggerActivities []string) error { + if id == "" || len(triggerActivities) == 0 { + return store.ErrInvalidInput + } + + result, err := s.db.ExecContext(ctx, ` + UPDATE notification_subscriptions SET trigger_activities = $1 WHERE id = $2 + `, marshalJSON(triggerActivities), id) + if err != nil { + return err + } + + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +// DeleteNotificationSubscription deletes a subscription by ID. +func (s *PostgresStore) DeleteNotificationSubscription(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, ` + DELETE FROM notification_subscriptions WHERE id = $1 + `, id) + if err != nil { + return err + } + + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +// DeleteNotificationSubscriptionsForAgent deletes all subscriptions for a watched agent. +// No error on zero rows affected. +func (s *PostgresStore) DeleteNotificationSubscriptionsForAgent(ctx context.Context, agentID string) error { + _, err := s.db.ExecContext(ctx, ` + DELETE FROM notification_subscriptions WHERE agent_id = $1 + `, agentID) + return err +} + +// ============================================================================ +// Notification Operations +// ============================================================================ + +// CreateNotification creates a new notification record. +func (s *PostgresStore) CreateNotification(ctx context.Context, notif *store.Notification) error { + if notif.ID == "" || notif.SubscriptionID == "" || notif.AgentID == "" { + return store.ErrInvalidInput + } + + now := time.Now() + if notif.CreatedAt.IsZero() { + notif.CreatedAt = now + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO notifications ( + id, subscription_id, agent_id, project_id, + subscriber_type, subscriber_id, + status, message, dispatched, acknowledged, created_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + `, + notif.ID, notif.SubscriptionID, notif.AgentID, notif.ProjectID, + notif.SubscriberType, notif.SubscriberID, + notif.Status, notif.Message, + boolToInt(notif.Dispatched), boolToInt(notif.Acknowledged), + notif.CreatedAt, + ) + if err != nil { + if strings.Contains(err.Error(), "foreign key constraint") { + return fmt.Errorf("subscription %s does not exist: %w", notif.SubscriptionID, store.ErrInvalidInput) + } + return err + } + return nil +} + +// GetNotifications returns notifications for a subscriber. +// If onlyUnacknowledged is true, only unacknowledged notifications are returned. +// Results are ordered by created_at DESC. +func (s *PostgresStore) GetNotifications(ctx context.Context, subscriberType, subscriberID string, onlyUnacknowledged bool) ([]store.Notification, error) { + query := ` + SELECT id, subscription_id, agent_id, project_id, + subscriber_type, subscriber_id, + status, message, dispatched, acknowledged, created_at + FROM notifications + WHERE subscriber_type = $1 AND subscriber_id = $2 + ` + args := []interface{}{subscriberType, subscriberID} + + if onlyUnacknowledged { + query += ` AND acknowledged = 0` + } + + query += ` ORDER BY created_at DESC` + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + return scanNotifications(rows) +} + +// GetNotificationsByAgent returns notifications for a subscriber filtered by agent ID. +// If onlyUnacknowledged is true, only unacknowledged notifications are returned. +// Results are ordered by created_at DESC. +func (s *PostgresStore) GetNotificationsByAgent(ctx context.Context, agentID, subscriberType, subscriberID string, onlyUnacknowledged bool) ([]store.Notification, error) { + query := ` + SELECT id, subscription_id, agent_id, project_id, + subscriber_type, subscriber_id, + status, message, dispatched, acknowledged, created_at + FROM notifications + WHERE agent_id = $1 AND subscriber_type = $2 AND subscriber_id = $3 + ` + args := []interface{}{agentID, subscriberType, subscriberID} + + if onlyUnacknowledged { + query += ` AND acknowledged = 0` + } + + query += ` ORDER BY created_at DESC` + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + return scanNotifications(rows) +} + +// AcknowledgeNotification marks a notification as acknowledged. +func (s *PostgresStore) AcknowledgeNotification(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, ` + UPDATE notifications SET acknowledged = 1 WHERE id = $1 + `, id) + if err != nil { + return err + } + + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +// AcknowledgeAllNotifications marks all notifications for a subscriber as acknowledged. +// No error on zero rows affected. +func (s *PostgresStore) AcknowledgeAllNotifications(ctx context.Context, subscriberType, subscriberID string) error { + _, err := s.db.ExecContext(ctx, ` + UPDATE notifications SET acknowledged = 1 + WHERE subscriber_type = $1 AND subscriber_id = $2 + `, subscriberType, subscriberID) + return err +} + +// MarkNotificationDispatched marks a notification as dispatched. +func (s *PostgresStore) MarkNotificationDispatched(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, ` + UPDATE notifications SET dispatched = 1 WHERE id = $1 + `, id) + if err != nil { + return err + } + + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +// GetLastNotificationStatus returns the status of the most recent notification +// for a given subscription. Returns ("", nil) if no notifications exist. +func (s *PostgresStore) GetLastNotificationStatus(ctx context.Context, subscriptionID string) (string, error) { + var status string + err := s.db.QueryRowContext(ctx, ` + SELECT status FROM notifications + WHERE subscription_id = $1 + ORDER BY created_at DESC + LIMIT 1 + `, subscriptionID).Scan(&status) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", nil + } + return "", err + } + return status, nil +} + +// ============================================================================ +// Subscription Template Operations +// ============================================================================ + +// CreateSubscriptionTemplate creates a new subscription template. +func (s *PostgresStore) CreateSubscriptionTemplate(ctx context.Context, tmpl *store.SubscriptionTemplate) error { + if tmpl.ID == "" || tmpl.Name == "" || len(tmpl.TriggerActivities) == 0 { + return store.ErrInvalidInput + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO subscription_templates (id, name, scope, trigger_activities, project_id, created_by) + VALUES ($1, $2, $3, $4, $5, $6) + `, tmpl.ID, tmpl.Name, tmpl.Scope, marshalJSON(tmpl.TriggerActivities), tmpl.ProjectID, tmpl.CreatedBy) + if err != nil { + if strings.Contains(err.Error(), "unique constraint") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +// GetSubscriptionTemplate returns a template by ID. +func (s *PostgresStore) GetSubscriptionTemplate(ctx context.Context, id string) (*store.SubscriptionTemplate, error) { + row := s.db.QueryRowContext(ctx, ` + SELECT id, name, scope, trigger_activities, project_id, created_by + FROM subscription_templates WHERE id = $1 + `, id) + + var tmpl store.SubscriptionTemplate + var triggersJSON string + if err := row.Scan(&tmpl.ID, &tmpl.Name, &tmpl.Scope, &triggersJSON, &tmpl.ProjectID, &tmpl.CreatedBy); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + unmarshalJSON(triggersJSON, &tmpl.TriggerActivities) + return &tmpl, nil +} + +// ListSubscriptionTemplates returns all templates. If projectID is non-empty, +// returns both global templates and project-specific templates. +func (s *PostgresStore) ListSubscriptionTemplates(ctx context.Context, projectID string) ([]store.SubscriptionTemplate, error) { + var rows *sql.Rows + var err error + + if projectID != "" { + rows, err = s.db.QueryContext(ctx, ` + SELECT id, name, scope, trigger_activities, project_id, created_by + FROM subscription_templates + WHERE project_id = '' OR project_id = $1 + ORDER BY project_id ASC, name ASC + `, projectID) + } else { + rows, err = s.db.QueryContext(ctx, ` + SELECT id, name, scope, trigger_activities, project_id, created_by + FROM subscription_templates + WHERE project_id = '' + ORDER BY name ASC + `) + } + if err != nil { + return nil, err + } + defer rows.Close() + + var templates []store.SubscriptionTemplate + for rows.Next() { + var tmpl store.SubscriptionTemplate + var triggersJSON string + if err := rows.Scan(&tmpl.ID, &tmpl.Name, &tmpl.Scope, &triggersJSON, &tmpl.ProjectID, &tmpl.CreatedBy); err != nil { + return nil, err + } + unmarshalJSON(triggersJSON, &tmpl.TriggerActivities) + templates = append(templates, tmpl) + } + return templates, rows.Err() +} + +// DeleteSubscriptionTemplate deletes a template by ID. +func (s *PostgresStore) DeleteSubscriptionTemplate(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, ` + DELETE FROM subscription_templates WHERE id = $1 + `, id) + if err != nil { + return err + } + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +// ============================================================================ +// Helpers +// ============================================================================ + +// boolToInt converts a bool to an int for SQLite storage. +func boolToInt(b bool) int { + if b { + return 1 + } + return 0 +} + +// scanSubscriptions scans rows into NotificationSubscription slices. +func scanSubscriptions(rows *sql.Rows) ([]store.NotificationSubscription, error) { + var subs []store.NotificationSubscription + for rows.Next() { + var sub store.NotificationSubscription + var agentID sql.NullString + var triggerActivitiesJSON string + + if err := rows.Scan( + &sub.ID, &sub.Scope, &agentID, &sub.SubscriberType, &sub.SubscriberID, &sub.ProjectID, + &triggerActivitiesJSON, &sub.CreatedAt, &sub.CreatedBy, + ); err != nil { + return nil, err + } + + if agentID.Valid { + sub.AgentID = agentID.String + } + unmarshalJSON(triggerActivitiesJSON, &sub.TriggerActivities) + subs = append(subs, sub) + } + if err := rows.Err(); err != nil { + return nil, err + } + return subs, nil +} + +// scanNotifications scans rows into Notification slices. +func scanNotifications(rows *sql.Rows) ([]store.Notification, error) { + var notifs []store.Notification + for rows.Next() { + var notif store.Notification + var dispatched, acknowledged int + + if err := rows.Scan( + ¬if.ID, ¬if.SubscriptionID, ¬if.AgentID, ¬if.ProjectID, + ¬if.SubscriberType, ¬if.SubscriberID, + ¬if.Status, ¬if.Message, &dispatched, &acknowledged, ¬if.CreatedAt, + ); err != nil { + return nil, err + } + + notif.Dispatched = dispatched != 0 + notif.Acknowledged = acknowledged != 0 + notifs = append(notifs, notif) + } + if err := rows.Err(); err != nil { + return nil, err + } + return notifs, nil +} diff --git a/pkg/store/postgres/policies.go b/pkg/store/postgres/policies.go new file mode 100644 index 00000000..c3661c73 --- /dev/null +++ b/pkg/store/postgres/policies.go @@ -0,0 +1,361 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package postgres provides a PostgreSQL implementation of the Store interface. +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +func (s *PostgresStore) CreatePolicy(ctx context.Context, policy *store.Policy) error { + now := time.Now() + policy.Created = now + policy.Updated = now + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO policies (id, name, description, scope_type, scope_id, resource_type, resource_id, actions, effect, conditions, priority, labels, annotations, created_at, updated_at, created_by) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + `, + policy.ID, policy.Name, policy.Description, policy.ScopeType, policy.ScopeID, + policy.ResourceType, policy.ResourceID, + marshalJSON(policy.Actions), policy.Effect, marshalJSON(policy.Conditions), + policy.Priority, marshalJSON(policy.Labels), marshalJSON(policy.Annotations), + policy.Created, policy.Updated, policy.CreatedBy, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *PostgresStore) GetPolicy(ctx context.Context, id string) (*store.Policy, error) { + policy := &store.Policy{} + var actions, conditions, labels, annotations string + + err := s.db.QueryRowContext(ctx, ` + SELECT id, name, description, scope_type, scope_id, resource_type, resource_id, actions, effect, conditions, priority, labels, annotations, created_at, updated_at, created_by + FROM policies WHERE id = $1 + `, id).Scan( + &policy.ID, &policy.Name, &policy.Description, &policy.ScopeType, &policy.ScopeID, + &policy.ResourceType, &policy.ResourceID, + &actions, &policy.Effect, &conditions, + &policy.Priority, &labels, &annotations, + &policy.Created, &policy.Updated, &policy.CreatedBy, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + unmarshalJSON(actions, &policy.Actions) + unmarshalJSON(conditions, &policy.Conditions) + unmarshalJSON(labels, &policy.Labels) + unmarshalJSON(annotations, &policy.Annotations) + + return policy, nil +} + +func (s *PostgresStore) UpdatePolicy(ctx context.Context, policy *store.Policy) error { + policy.Updated = time.Now() + + result, err := s.db.ExecContext(ctx, ` + UPDATE policies SET + name = $1, description = $2, scope_type = $3, scope_id = $4, + resource_type = $5, resource_id = $6, + actions = $7, effect = $8, conditions = $9, + priority = $10, labels = $11, annotations = $12, + updated_at = $13 + WHERE id = $14 + `, + policy.Name, policy.Description, policy.ScopeType, policy.ScopeID, + policy.ResourceType, policy.ResourceID, + marshalJSON(policy.Actions), policy.Effect, marshalJSON(policy.Conditions), + policy.Priority, marshalJSON(policy.Labels), marshalJSON(policy.Annotations), + policy.Updated, + policy.ID, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) DeletePolicy(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, "DELETE FROM policies WHERE id = $1", id) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) ListPolicies(ctx context.Context, filter store.PolicyFilter, opts store.ListOptions) (*store.ListResult[store.Policy], error) { + var conditions []string + var args []interface{} + + if filter.Name != "" { + conditions = append(conditions, fmt.Sprintf("name = $%d", len(args)+1)) + args = append(args, filter.Name) + } + if filter.ScopeType != "" { + conditions = append(conditions, fmt.Sprintf("scope_type = $%d", len(args)+1)) + args = append(args, filter.ScopeType) + } + if filter.ScopeID != "" { + conditions = append(conditions, fmt.Sprintf("scope_id = $%d", len(args)+1)) + args = append(args, filter.ScopeID) + } + if filter.ResourceType != "" { + conditions = append(conditions, fmt.Sprintf("resource_type = $%d", len(args)+1)) + args = append(args, filter.ResourceType) + } + if filter.Effect != "" { + conditions = append(conditions, fmt.Sprintf("effect = $%d", len(args)+1)) + args = append(args, filter.Effect) + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + var totalCount int + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM policies %s", whereClause) + if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount); err != nil { + return nil, err + } + + limit := opts.Limit + if limit <= 0 { + limit = 50 + } + + query := fmt.Sprintf(` + SELECT id, name, description, scope_type, scope_id, resource_type, resource_id, actions, effect, conditions, priority, labels, annotations, created_at, updated_at, created_by + FROM policies %s ORDER BY priority DESC, created_at DESC LIMIT $%d + `, whereClause, len(args)+1) + args = append(args, limit) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var policies []store.Policy + for rows.Next() { + var policy store.Policy + var actions, conditions, labels, annotations string + + if err := rows.Scan( + &policy.ID, &policy.Name, &policy.Description, &policy.ScopeType, &policy.ScopeID, + &policy.ResourceType, &policy.ResourceID, + &actions, &policy.Effect, &conditions, + &policy.Priority, &labels, &annotations, + &policy.Created, &policy.Updated, &policy.CreatedBy, + ); err != nil { + return nil, err + } + + unmarshalJSON(actions, &policy.Actions) + unmarshalJSON(conditions, &policy.Conditions) + unmarshalJSON(labels, &policy.Labels) + unmarshalJSON(annotations, &policy.Annotations) + + policies = append(policies, policy) + } + + return &store.ListResult[store.Policy]{ + Items: policies, + TotalCount: totalCount, + }, nil +} + +func (s *PostgresStore) AddPolicyBinding(ctx context.Context, binding *store.PolicyBinding) error { + _, err := s.db.ExecContext(ctx, ` + INSERT INTO policy_bindings (policy_id, principal_type, principal_id) + VALUES ($1, $2, $3) + `, + binding.PolicyID, binding.PrincipalType, binding.PrincipalID, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") || strings.Contains(err.Error(), "PRIMARY KEY constraint failed") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *PostgresStore) RemovePolicyBinding(ctx context.Context, policyID, principalType, principalID string) error { + result, err := s.db.ExecContext(ctx, + "DELETE FROM policy_bindings WHERE policy_id = $1 AND principal_type = $2 AND principal_id = $3", + policyID, principalType, principalID, + ) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) GetPolicyBindings(ctx context.Context, policyID string) ([]store.PolicyBinding, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT policy_id, principal_type, principal_id + FROM policy_bindings WHERE policy_id = $1 + `, policyID) + if err != nil { + return nil, err + } + defer rows.Close() + + var bindings []store.PolicyBinding + for rows.Next() { + var binding store.PolicyBinding + if err := rows.Scan(&binding.PolicyID, &binding.PrincipalType, &binding.PrincipalID); err != nil { + return nil, err + } + bindings = append(bindings, binding) + } + + return bindings, nil +} + +func (s *PostgresStore) GetPoliciesForPrincipal(ctx context.Context, principalType, principalID string) ([]store.Policy, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT p.id, p.name, p.description, p.scope_type, p.scope_id, p.resource_type, p.resource_id, p.actions, p.effect, p.conditions, p.priority, p.labels, p.annotations, p.created_at, p.updated_at, p.created_by + FROM policies p + INNER JOIN policy_bindings pb ON p.id = pb.policy_id + WHERE pb.principal_type = $1 AND pb.principal_id = $2 + ORDER BY p.priority DESC, p.created_at DESC + `, principalType, principalID) + if err != nil { + return nil, err + } + defer rows.Close() + + var policies []store.Policy + for rows.Next() { + var policy store.Policy + var actions, conditions, labels, annotations string + + if err := rows.Scan( + &policy.ID, &policy.Name, &policy.Description, &policy.ScopeType, &policy.ScopeID, + &policy.ResourceType, &policy.ResourceID, + &actions, &policy.Effect, &conditions, + &policy.Priority, &labels, &annotations, + &policy.Created, &policy.Updated, &policy.CreatedBy, + ); err != nil { + return nil, err + } + + unmarshalJSON(actions, &policy.Actions) + unmarshalJSON(conditions, &policy.Conditions) + unmarshalJSON(labels, &policy.Labels) + unmarshalJSON(annotations, &policy.Annotations) + + policies = append(policies, policy) + } + + return policies, nil +} + +func (s *PostgresStore) GetPoliciesForPrincipals(ctx context.Context, principals []store.PrincipalRef) ([]store.Policy, error) { + if len(principals) == 0 { + return nil, nil + } + + // Build dynamic OR clauses for each principal + var clauses []string + var args []interface{} + for _, p := range principals { + n := len(args) + 1 + clauses = append(clauses, fmt.Sprintf("(pb.principal_type = $%d AND pb.principal_id = $%d)", n, n+1)) + args = append(args, p.Type, p.ID) + } + + query := ` + SELECT DISTINCT p.id, p.name, p.description, p.scope_type, p.scope_id, p.resource_type, p.resource_id, p.actions, p.effect, p.conditions, p.priority, p.labels, p.annotations, p.created_at, p.updated_at, p.created_by + FROM policies p + INNER JOIN policy_bindings pb ON p.id = pb.policy_id + WHERE ` + strings.Join(clauses, " OR ") + ` + ORDER BY + CASE p.scope_type WHEN 'hub' THEN 0 WHEN 'project' THEN 1 WHEN 'resource' THEN 2 END, + p.priority ASC + ` + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var policies []store.Policy + for rows.Next() { + var policy store.Policy + var actions, conditions, labels, annotations string + + if err := rows.Scan( + &policy.ID, &policy.Name, &policy.Description, &policy.ScopeType, &policy.ScopeID, + &policy.ResourceType, &policy.ResourceID, + &actions, &policy.Effect, &conditions, + &policy.Priority, &labels, &annotations, + &policy.Created, &policy.Updated, &policy.CreatedBy, + ); err != nil { + return nil, err + } + + unmarshalJSON(actions, &policy.Actions) + unmarshalJSON(conditions, &policy.Conditions) + unmarshalJSON(labels, &policy.Labels) + unmarshalJSON(annotations, &policy.Annotations) + + policies = append(policies, policy) + } + + return policies, nil +} diff --git a/pkg/store/postgres/postgres.go b/pkg/store/postgres/postgres.go new file mode 100644 index 00000000..4247de15 --- /dev/null +++ b/pkg/store/postgres/postgres.go @@ -0,0 +1,311 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package postgres provides a PostgreSQL implementation of the Store interface. +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "time" +) + +// PostgresStore implements the Store interface using PostgreSQL. +type PostgresStore struct { + db *sql.DB +} + +// New creates a new Postgres store with the given connection URL. +// The URL is passed directly to lib/pq (e.g. "postgres://user:pass@host/db?sslmode=disable"). +func New(connURL string) (*PostgresStore, error) { + db, err := sql.Open("postgres", connURL) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + db.SetMaxOpenConns(4) + db.SetMaxIdleConns(4) + + return &PostgresStore{db: db}, nil +} + +// Close closes the database connection. +func (s *PostgresStore) Close() error { + return s.db.Close() +} + +// DB returns the underlying *sql.DB for direct access in tests. +func (s *PostgresStore) DB() *sql.DB { + return s.db +} + +// Ping checks database connectivity. +func (s *PostgresStore) Ping(ctx context.Context) error { + return s.db.PingContext(ctx) +} + +// Migrate applies database migrations. +func (s *PostgresStore) Migrate(ctx context.Context) error { + migrations := []any{ + migrationV1, + migrationV2, + migrationV3, + migrationV4, + migrationV5, + migrationV6, + migrationV7, + migrationV8, + migrationV9, + migrationV10, + migrationV11, + migrationV12, + migrationV13, + migrationV14, + migrationV15, + migrationV16, + migrationV17, + migrationV18, + migrationV19, + migrationV20, + migrationV21, + migrationV22, + migrationV23, + migrationV24, + migrationV25, + migrationV26, + migrationV27, + migrationV28, + migrationV29, + migrationV30, + migrationV31, + migrationV32, + migrationV33, + migrationV34, + migrationV35, + migrationV36, + migrationV37, + migrationV38, + migrationV39, + migrationV40, + migrationV41, + migrationV42, + migrationV43, + migrationV44, + migrationV45, + migrationV46, + migrationV47, + migrationV48, + migrationV49, + migrateV50, + migrationV51, + migrationV52, + migrationV53, + } + + // Create migrations table if not exists + if _, err := s.db.ExecContext(ctx, ` + CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + `); err != nil { + return fmt.Errorf("failed to create migrations table: %w", err) + } + + // Get current version + var currentVersion int + err := s.db.QueryRowContext(ctx, "SELECT COALESCE(MAX(version), 0) FROM schema_migrations").Scan(¤tVersion) + if err != nil { + return fmt.Errorf("failed to get current schema version: %w", err) + } + + // Migrations that require PRAGMA foreign_keys=OFF around the transaction. + // SQLite ignores PRAGMA changes inside transactions, so we must disable + // foreign keys before BeginTx and re-enable after Commit. Without this, + // DROP TABLE on a parent table triggers ON DELETE CASCADE on child tables. + foreignKeysOffMigrations := map[int]bool{ + 40: true, // V40 drops and recreates the projects table + } + + // Apply pending migrations + for i, migration := range migrations { + version := i + 1 + if version <= currentVersion { + continue + } + + switch m := migration.(type) { + case string: + needsFKOff := foreignKeysOffMigrations[version] + + if needsFKOff { + if err := s.applyMigrationWithFKOff(ctx, version, m); err != nil { + return err + } + continue + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to start transaction for migration %d: %w", version, err) + } + + if _, err := tx.ExecContext(ctx, m); err != nil { + tx.Rollback() + return fmt.Errorf("failed to apply migration %d: %w", version, err) + } + + if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES ($1)", version); err != nil { + tx.Rollback() + return fmt.Errorf("failed to record migration %d: %w", version, err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit migration %d: %w", version, err) + } + + case func(ctx context.Context, tx *sql.Tx) error: + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to start transaction for migration %d: %w", version, err) + } + + if err := m(ctx, tx); err != nil { + tx.Rollback() + return fmt.Errorf("failed to apply migration %d: %w", version, err) + } + + if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES ($1)", version); err != nil { + tx.Rollback() + return fmt.Errorf("failed to record migration %d: %w", version, err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit migration %d: %w", version, err) + } + + default: + return fmt.Errorf("migration %d: unsupported type %T", version, migration) + } + } + + return nil +} + +// applyMigrationWithFKOff runs a migration that requires PRAGMA +// foreign_keys=OFF. In Postgres, foreign key deferral is handled within the +// migration SQL itself (e.g. DROP ... CASCADE / explicit FK drops), so this +// function simply runs the migration in a plain transaction. The +// foreignKeysOffMigrations map and this function are kept so the runner shape +// stays 1-to-1 with the SQLite implementation. +func (s *PostgresStore) applyMigrationWithFKOff(ctx context.Context, version int, migration string) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to start transaction for migration %d: %w", version, err) + } + + if _, err := tx.ExecContext(ctx, migration); err != nil { + tx.Rollback() + return fmt.Errorf("failed to apply migration %d: %w", version, err) + } + + if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES ($1)", version); err != nil { + tx.Rollback() + return fmt.Errorf("failed to record migration %d: %w", version, err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit migration %d: %w", version, err) + } + + return nil +} + +// Helper functions for JSON marshaling/unmarshaling +func marshalJSON(v interface{}) string { + if v == nil { + return "" + } + data, err := json.Marshal(v) + if err != nil { + return "" + } + return string(data) +} + +func unmarshalJSON[T any](data string, v *T) { + if data == "" { + return + } + json.Unmarshal([]byte(data), v) +} + +// nullableString returns a sql.NullString for database insertion. +// Empty strings become NULL, which is important for UNIQUE and FK constraints. +func nullableString(s string) sql.NullString { + if s == "" { + return sql.NullString{Valid: false} + } + return sql.NullString{String: s, Valid: true} +} + +// nullableTime returns a sql.NullTime for database insertion. +// Zero time values become NULL. +func nullableTime(t time.Time) sql.NullTime { + if t.IsZero() { + return sql.NullTime{Valid: false} + } + return sql.NullTime{Time: t, Valid: true} +} + +// nullableInt64 returns a sql.NullInt64 for database insertion. +// Nil pointers become NULL. +func nullableInt64(v *int64) sql.NullInt64 { + if v == nil { + return sql.NullInt64{Valid: false} + } + return sql.NullInt64{Int64: *v, Valid: true} +} + +// marshalJSONPtr marshals a pointer value to JSON string, returning empty string for nil pointers. +// Unlike marshalJSON, this correctly detects nil typed pointers. +func marshalJSONPtr[T any](v *T) string { + if v == nil { + return "" + } + data, err := json.Marshal(v) + if err != nil { + return "" + } + return string(data) +} + +// nullableTimePtr returns a *time.Time for scanning nullable timestamps. +func nullableTimePtr(t sql.NullTime) *time.Time { + if !t.Valid { + return nil + } + return &t.Time +} + +// ptrToNullTime converts a *time.Time to sql.NullTime for database insertion. +// Nil pointers become NULL. +func ptrToNullTime(t *time.Time) sql.NullTime { + if t == nil { + return sql.NullTime{Valid: false} + } + return sql.NullTime{Time: *t, Valid: true} +} diff --git a/pkg/store/postgres/postgres_test.go b/pkg/store/postgres/postgres_test.go new file mode 100644 index 00000000..7514edac --- /dev/null +++ b/pkg/store/postgres/postgres_test.go @@ -0,0 +1,659 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "os" + "testing" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/api" + "github.com/GoogleCloudPlatform/scion/pkg/store" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const envVarDSN = "SCION_TEST_POSTGRES_URL" + +// resetSchema drops and recreates the public schema, giving each test a clean slate. +func resetSchema(t *testing.T, s *PostgresStore) { + t.Helper() + ctx := context.Background() + _, err := s.db.ExecContext(ctx, `DROP SCHEMA public CASCADE; CREATE SCHEMA public;`) + require.NoError(t, err, "resetSchema") +} + +// newTestStore opens a PostgresStore against the live DB and applies all +// migrations. It skips the test when SCION_TEST_POSTGRES_URL is not set. +func newTestStore(t *testing.T) *PostgresStore { + t.Helper() + dsn := os.Getenv(envVarDSN) + if dsn == "" { + t.Skipf("set %s to run Postgres tests", envVarDSN) + } + + s, err := New(dsn) + require.NoError(t, err) + + resetSchema(t, s) + + ctx := context.Background() + err = s.Migrate(ctx) + require.NoError(t, err, "Migrate") + + t.Cleanup(func() { s.Close() }) + return s +} + +// ============================================================================ +// Migration Tests +// ============================================================================ + +func TestMigrate(t *testing.T) { + dsn := os.Getenv(envVarDSN) + if dsn == "" { + t.Skipf("set %s to run Postgres tests", envVarDSN) + } + + s, err := New(dsn) + require.NoError(t, err) + defer s.Close() + + resetSchema(t, s) + + ctx := context.Background() + + // First run: all 53 migrations must apply cleanly. + err = s.Migrate(ctx) + require.NoError(t, err, "first Migrate must succeed") + + // Verify final version in schema_migrations. + var version int + err = s.db.QueryRowContext(ctx, "SELECT MAX(version) FROM schema_migrations").Scan(&version) + require.NoError(t, err) + assert.Equal(t, 53, version, "all 53 migrations should be applied") + + // Second run: must be idempotent (no error, no re-apply). + err = s.Migrate(ctx) + require.NoError(t, err, "second Migrate must be idempotent") + + var version2 int + err = s.db.QueryRowContext(ctx, "SELECT MAX(version) FROM schema_migrations").Scan(&version2) + require.NoError(t, err) + assert.Equal(t, 53, version2, "idempotent run must not change version") +} + +// ============================================================================ +// User Tests +// ============================================================================ + +func TestPGUserCRUD(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + user := &store.User{ + ID: api.NewUUID(), + Email: "alice@example.com", + DisplayName: "Alice", + Role: store.UserRoleMember, + Status: "active", + Preferences: &store.UserPreferences{Theme: "dark"}, + } + + require.NoError(t, s.CreateUser(ctx, user)) + assert.NotZero(t, user.Created) + + // Get by ID + got, err := s.GetUser(ctx, user.ID) + require.NoError(t, err) + assert.Equal(t, user.Email, got.Email) + assert.Equal(t, "dark", got.Preferences.Theme) + + // Get by email + got2, err := s.GetUserByEmail(ctx, "alice@example.com") + require.NoError(t, err) + assert.Equal(t, user.ID, got2.ID) + + // Duplicate email returns ErrAlreadyExists + dup := &store.User{ + ID: api.NewUUID(), Email: "alice@example.com", + DisplayName: "Alice2", Role: store.UserRoleMember, Status: "active", + } + err = s.CreateUser(ctx, dup) + assert.ErrorIs(t, err, store.ErrAlreadyExists) + + // Update + got.DisplayName = "Alice Updated" + got.LastLogin = time.Now() + require.NoError(t, s.UpdateUser(ctx, got)) + + got3, err := s.GetUser(ctx, user.ID) + require.NoError(t, err) + assert.Equal(t, "Alice Updated", got3.DisplayName) + assert.NotZero(t, got3.LastLogin) + + // Delete + require.NoError(t, s.DeleteUser(ctx, user.ID)) + _, err = s.GetUser(ctx, user.ID) + assert.ErrorIs(t, err, store.ErrNotFound) +} + +func TestPGUserList(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 3; i++ { + u := &store.User{ + ID: api.NewUUID(), + Email: "user" + string(rune('a'+i)) + "@example.com", + DisplayName: "User " + string(rune('A'+i)), + Role: store.UserRoleMember, + Status: "active", + } + if i == 0 { + u.Role = store.UserRoleAdmin + } + require.NoError(t, s.CreateUser(ctx, u)) + } + + result, err := s.ListUsers(ctx, store.UserFilter{}, store.ListOptions{}) + require.NoError(t, err) + assert.Equal(t, 3, result.TotalCount) + + result, err = s.ListUsers(ctx, store.UserFilter{Role: store.UserRoleAdmin}, store.ListOptions{}) + require.NoError(t, err) + assert.Equal(t, 1, result.TotalCount) +} + +// ============================================================================ +// Project Tests +// ============================================================================ + +func TestPGProjectCRUD(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + project := &store.Project{ + ID: api.NewUUID(), + Name: "My Project", + Slug: "my-project", + GitRemote: "github.com/org/repo", + Visibility: store.VisibilityPrivate, + Labels: map[string]string{"team": "platform"}, + } + + require.NoError(t, s.CreateProject(ctx, project)) + assert.NotZero(t, project.Created) + + got, err := s.GetProject(ctx, project.ID) + require.NoError(t, err) + assert.Equal(t, project.Name, got.Name) + assert.Equal(t, "platform", got.Labels["team"]) + + // Slug uniqueness + dup := &store.Project{ + ID: api.NewUUID(), Name: "Dup", Slug: "my-project", Visibility: store.VisibilityPrivate, + } + err = s.CreateProject(ctx, dup) + assert.ErrorIs(t, err, store.ErrAlreadyExists) + + // Update + got.Name = "Updated Project" + require.NoError(t, s.UpdateProject(ctx, got)) + + got2, err := s.GetProject(ctx, project.ID) + require.NoError(t, err) + assert.Equal(t, "Updated Project", got2.Name) + + // Delete + require.NoError(t, s.DeleteProject(ctx, project.ID)) + _, err = s.GetProject(ctx, project.ID) + assert.ErrorIs(t, err, store.ErrNotFound) +} + +// ============================================================================ +// Agent Tests +// ============================================================================ + +func TestPGAgentCRUD(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + project := &store.Project{ + ID: api.NewUUID(), + Name: "Test Project", + Slug: "test-project", + Visibility: store.VisibilityPrivate, + } + require.NoError(t, s.CreateProject(ctx, project)) + + agent := &store.Agent{ + ID: api.NewUUID(), + Slug: "test-agent", + Name: "Test Agent", + Template: "claude", + ProjectID: project.ID, + Phase: "created", + Visibility: store.VisibilityPrivate, + Labels: map[string]string{"env": "test"}, + } + + require.NoError(t, s.CreateAgent(ctx, agent)) + assert.NotZero(t, agent.Created) + assert.Equal(t, int64(1), agent.StateVersion) + + got, err := s.GetAgent(ctx, agent.ID) + require.NoError(t, err) + assert.Equal(t, agent.Slug, got.Slug) + assert.Equal(t, "test", got.Labels["env"]) + + got, err = s.GetAgentBySlug(ctx, project.ID, "test-agent") + require.NoError(t, err) + assert.Equal(t, agent.ID, got.ID) + + got.Name = "Updated Agent" + got.Phase = "running" + require.NoError(t, s.UpdateAgent(ctx, got)) + assert.Equal(t, int64(2), got.StateVersion) + + got2, err := s.GetAgent(ctx, agent.ID) + require.NoError(t, err) + assert.Equal(t, "Updated Agent", got2.Name) + + // Version conflict + got2.StateVersion = 1 + err = s.UpdateAgent(ctx, got2) + assert.ErrorIs(t, err, store.ErrVersionConflict) + + require.NoError(t, s.DeleteAgent(ctx, agent.ID)) + _, err = s.GetAgent(ctx, agent.ID) + assert.ErrorIs(t, err, store.ErrNotFound) +} + +func TestPGAgentList(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + project := &store.Project{ + ID: api.NewUUID(), + Name: "Test Project", + Slug: "test-project", + Visibility: store.VisibilityPrivate, + } + require.NoError(t, s.CreateProject(ctx, project)) + + for i := 0; i < 5; i++ { + agent := &store.Agent{ + ID: api.NewUUID(), + Slug: "agent-" + string(rune('a'+i)), + Name: "Agent " + string(rune('A'+i)), + Template: "claude", + ProjectID: project.ID, + Phase: "running", + Visibility: store.VisibilityPrivate, + } + if i%2 == 0 { + agent.Phase = "stopped" + } + require.NoError(t, s.CreateAgent(ctx, agent)) + } + + result, err := s.ListAgents(ctx, store.AgentFilter{}, store.ListOptions{}) + require.NoError(t, err) + assert.Equal(t, 5, result.TotalCount) + + result, err = s.ListAgents(ctx, store.AgentFilter{Phase: "running"}, store.ListOptions{}) + require.NoError(t, err) + assert.Equal(t, 2, result.TotalCount) + + result, err = s.ListAgents(ctx, store.AgentFilter{ProjectID: project.ID}, store.ListOptions{}) + require.NoError(t, err) + assert.Equal(t, 5, result.TotalCount) + + result, err = s.ListAgents(ctx, store.AgentFilter{}, store.ListOptions{Limit: 2}) + require.NoError(t, err) + assert.Len(t, result.Items, 2) +} + +// ============================================================================ +// Secret Tests (incl. Upsert) +// ============================================================================ + +func TestPGSecretCRUD(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + secret := &store.Secret{ + ID: api.NewUUID(), + Key: "MY_SECRET", + EncryptedValue: "enc_val_1", + Scope: store.ScopeHub, + ScopeID: "hub-1", + SecretType: store.SecretTypeEnvironment, + CreatedBy: "user-1", + UpdatedBy: "user-1", + } + + require.NoError(t, s.CreateSecret(ctx, secret)) + assert.NotZero(t, secret.Created) + assert.Equal(t, 1, secret.Version) + + got, err := s.GetSecret(ctx, "MY_SECRET", store.ScopeHub, "hub-1") + require.NoError(t, err) + assert.Equal(t, secret.ID, got.ID) + assert.Equal(t, "MY_SECRET", got.Key) + + got.EncryptedValue = "enc_val_2" + require.NoError(t, s.UpdateSecret(ctx, got)) + assert.Equal(t, 2, got.Version) + + got2, err := s.GetSecretValue(ctx, "MY_SECRET", store.ScopeHub, "hub-1") + require.NoError(t, err) + assert.Equal(t, "enc_val_2", got2) + + require.NoError(t, s.DeleteSecret(ctx, "MY_SECRET", store.ScopeHub, "hub-1")) + _, err = s.GetSecret(ctx, "MY_SECRET", store.ScopeHub, "hub-1") + assert.ErrorIs(t, err, store.ErrNotFound) +} + +func TestPGSecretUpsert(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + secret := &store.Secret{ + ID: api.NewUUID(), + Key: "UPSERT_KEY", + EncryptedValue: "first_value", + Scope: store.ScopeUser, + ScopeID: "u1", + CreatedBy: "user-1", + UpdatedBy: "user-1", + } + + created, err := s.UpsertSecret(ctx, secret) + require.NoError(t, err) + assert.True(t, created, "first upsert should create") + + // Second upsert with same key/scope should update + secret2 := &store.Secret{ + Key: "UPSERT_KEY", + EncryptedValue: "second_value", + Scope: store.ScopeUser, + ScopeID: "u1", + UpdatedBy: "user-1", + } + created2, err := s.UpsertSecret(ctx, secret2) + require.NoError(t, err) + assert.False(t, created2, "second upsert should update, not create") + + got, err := s.GetSecretValue(ctx, "UPSERT_KEY", store.ScopeUser, "u1") + require.NoError(t, err) + assert.Equal(t, "second_value", got) +} + +func TestPGListSecrets(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 3; i++ { + require.NoError(t, s.CreateSecret(ctx, &store.Secret{ + ID: api.NewUUID(), + Key: "KEY_" + string(rune('A'+i)), + EncryptedValue: "val", + Scope: store.ScopeProject, + ScopeID: "proj-1", + })) + } + // One in a different scope + require.NoError(t, s.CreateSecret(ctx, &store.Secret{ + ID: api.NewUUID(), + Key: "HUB_KEY", + EncryptedValue: "val", + Scope: store.ScopeHub, + ScopeID: "hub-1", + })) + + secrets, err := s.ListSecrets(ctx, store.SecretFilter{Scope: store.ScopeProject, ScopeID: "proj-1"}) + require.NoError(t, err) + assert.Len(t, secrets, 3) + + n, err := s.DeleteSecretsByScope(ctx, store.ScopeProject, "proj-1") + require.NoError(t, err) + assert.Equal(t, 3, n) +} + +// ============================================================================ +// Group Tests (+membership) +// ============================================================================ + +func TestPGGroupCRUD(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + group := &store.Group{ + ID: api.NewUUID(), + Name: "Engineering", + Slug: "engineering", + Description: "Eng team", + Created: time.Now(), + Updated: time.Now(), + } + require.NoError(t, s.CreateGroup(ctx, group)) + + got, err := s.GetGroup(ctx, group.ID) + require.NoError(t, err) + assert.Equal(t, "Engineering", got.Name) + + got.Description = "Eng team updated" + require.NoError(t, s.UpdateGroup(ctx, got)) + + got2, err := s.GetGroup(ctx, group.ID) + require.NoError(t, err) + assert.Equal(t, "Eng team updated", got2.Description) + + require.NoError(t, s.DeleteGroup(ctx, group.ID)) + _, err = s.GetGroup(ctx, group.ID) + assert.ErrorIs(t, err, store.ErrNotFound) +} + +func TestPGGroupMembership(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + group := &store.Group{ + ID: api.NewUUID(), + Name: "Alpha", + Slug: "alpha", + Created: time.Now(), + Updated: time.Now(), + } + require.NoError(t, s.CreateGroup(ctx, group)) + + member := &store.GroupMember{ + GroupID: group.ID, + MemberType: "user", + MemberID: "user-42", + Role: "member", + AddedAt: time.Now(), + } + require.NoError(t, s.AddGroupMember(ctx, member)) + + members, err := s.GetGroupMembers(ctx, group.ID) + require.NoError(t, err) + assert.Len(t, members, 1) + assert.Equal(t, "user-42", members[0].MemberID) + + groups, err := s.GetEffectiveGroups(ctx, "user-42") + require.NoError(t, err) + assert.Contains(t, groups, group.ID) + + require.NoError(t, s.RemoveGroupMember(ctx, group.ID, "user", "user-42")) + members, err = s.GetGroupMembers(ctx, group.ID) + require.NoError(t, err) + assert.Len(t, members, 0) +} + +// ============================================================================ +// Policy Tests +// ============================================================================ + +func TestPGPolicyCRUD(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + policy := &store.Policy{ + ID: api.NewUUID(), + Name: "ReadPolicy", + ScopeType: store.PolicyScopeHub, + ScopeID: "hub-1", + ResourceType: "*", + Actions: []string{"read"}, + Effect: "allow", + Priority: 10, + Created: time.Now(), + Updated: time.Now(), + } + require.NoError(t, s.CreatePolicy(ctx, policy)) + + got, err := s.GetPolicy(ctx, policy.ID) + require.NoError(t, err) + assert.Equal(t, "ReadPolicy", got.Name) + assert.Equal(t, []string{"read"}, got.Actions) + + result, err := s.ListPolicies(ctx, store.PolicyFilter{ScopeType: store.PolicyScopeHub}, store.ListOptions{}) + require.NoError(t, err) + assert.Equal(t, 1, result.TotalCount) + + require.NoError(t, s.DeletePolicy(ctx, policy.ID)) + _, err = s.GetPolicy(ctx, policy.ID) + assert.ErrorIs(t, err, store.ErrNotFound) +} + +// ============================================================================ +// InviteCode Tests +// ============================================================================ + +func TestPGInviteCode(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + invite := &store.InviteCode{ + ID: api.NewUUID(), + CodeHash: "hash-abc", + CodePrefix: "scion_inv_", + MaxUses: 5, + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedBy: "admin", + Note: "test invite", + Created: time.Now(), + } + require.NoError(t, s.CreateInviteCode(ctx, invite)) + + got, err := s.GetInviteCodeByHash(ctx, "hash-abc") + require.NoError(t, err) + assert.Equal(t, invite.ID, got.ID) + assert.Equal(t, "test invite", got.Note) + + require.NoError(t, s.IncrementInviteUseCount(ctx, invite.ID)) + + got2, err := s.GetInviteCodeByHash(ctx, "hash-abc") + require.NoError(t, err) + assert.Equal(t, 1, got2.UseCount) + + require.NoError(t, s.RevokeInviteCode(ctx, invite.ID)) + + got3, err := s.GetInviteCodeByHash(ctx, "hash-abc") + require.NoError(t, err) + assert.True(t, got3.Revoked) +} + +// ============================================================================ +// EnvVar Tests +// ============================================================================ + +func TestPGEnvVarCRUD(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + ev := &store.EnvVar{ + ID: api.NewUUID(), + Key: "DATABASE_URL", + Value: "postgres://localhost/mydb", + Scope: store.ScopeProject, + ScopeID: "proj-1", + } + require.NoError(t, s.CreateEnvVar(ctx, ev)) + assert.NotZero(t, ev.Created) + + got, err := s.GetEnvVar(ctx, "DATABASE_URL", store.ScopeProject, "proj-1") + require.NoError(t, err) + assert.Equal(t, ev.Value, got.Value) + + got.Value = "postgres://localhost/newdb" + require.NoError(t, s.UpdateEnvVar(ctx, got)) + + got2, err := s.GetEnvVar(ctx, "DATABASE_URL", store.ScopeProject, "proj-1") + require.NoError(t, err) + assert.Equal(t, "postgres://localhost/newdb", got2.Value) + + require.NoError(t, s.DeleteEnvVar(ctx, "DATABASE_URL", store.ScopeProject, "proj-1")) + _, err = s.GetEnvVar(ctx, "DATABASE_URL", store.ScopeProject, "proj-1") + assert.ErrorIs(t, err, store.ErrNotFound) +} + +func TestPGEnvVarList(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 3; i++ { + require.NoError(t, s.CreateEnvVar(ctx, &store.EnvVar{ + ID: api.NewUUID(), + Key: "VAR_" + string(rune('A'+i)), + Value: "val", + Scope: store.ScopeProject, + ScopeID: "proj-evlist", + })) + } + require.NoError(t, s.CreateEnvVar(ctx, &store.EnvVar{ + ID: api.NewUUID(), + Key: "HUB_VAR", + Value: "hub", + Scope: store.ScopeHub, + ScopeID: "hub-1", + })) + + vars, err := s.ListEnvVars(ctx, store.EnvVarFilter{Scope: store.ScopeProject, ScopeID: "proj-evlist"}) + require.NoError(t, err) + assert.Len(t, vars, 3) + + n, err := s.DeleteEnvVarsByScope(ctx, store.ScopeProject, "proj-evlist") + require.NoError(t, err) + assert.Equal(t, 3, n) + + vars, err = s.ListEnvVars(ctx, store.EnvVarFilter{Scope: store.ScopeProject, ScopeID: "proj-evlist"}) + require.NoError(t, err) + assert.Empty(t, vars) +} + +// ============================================================================ +// Ping +// ============================================================================ + +func TestPGPing(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + require.NoError(t, s.Ping(ctx)) +} diff --git a/pkg/store/postgres/project_sync_state.go b/pkg/store/postgres/project_sync_state.go new file mode 100644 index 00000000..85908a4d --- /dev/null +++ b/pkg/store/postgres/project_sync_state.go @@ -0,0 +1,142 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +// ============================================================================ +// Project Sync State Operations +// ============================================================================ + +// UpsertProjectSyncState creates or updates sync state for a project. +func (s *PostgresStore) UpsertProjectSyncState(ctx context.Context, state *store.ProjectSyncState) error { + if state.ProjectID == "" { + return store.ErrInvalidInput + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO project_sync_state (project_id, broker_id, last_sync_time, last_commit_sha, file_count, total_bytes) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT(project_id, broker_id) DO UPDATE SET + last_sync_time = excluded.last_sync_time, + last_commit_sha = excluded.last_commit_sha, + file_count = excluded.file_count, + total_bytes = excluded.total_bytes + `, state.ProjectID, state.BrokerID, + ptrToNullTime(state.LastSyncTime), + nullableString(state.LastCommitSHA), + state.FileCount, state.TotalBytes, + ) + return err +} + +// GetProjectSyncState retrieves sync state for a project and optional broker. +func (s *PostgresStore) GetProjectSyncState(ctx context.Context, projectID, brokerID string) (*store.ProjectSyncState, error) { + state := &store.ProjectSyncState{} + var lastSyncTime sql.NullTime + var lastCommitSHA sql.NullString + + err := s.db.QueryRowContext(ctx, ` + SELECT project_id, broker_id, last_sync_time, last_commit_sha, file_count, total_bytes + FROM project_sync_state + WHERE project_id = $1 AND broker_id = $2 + `, projectID, brokerID).Scan( + &state.ProjectID, &state.BrokerID, + &lastSyncTime, &lastCommitSHA, + &state.FileCount, &state.TotalBytes, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, store.ErrNotFound + } + return nil, err + } + + if lastSyncTime.Valid { + state.LastSyncTime = &lastSyncTime.Time + } + if lastCommitSHA.Valid { + state.LastCommitSHA = lastCommitSHA.String + } + + return state, nil +} + +// ListProjectSyncStates returns all sync states for a project. +func (s *PostgresStore) ListProjectSyncStates(ctx context.Context, projectID string) ([]store.ProjectSyncState, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT project_id, broker_id, last_sync_time, last_commit_sha, file_count, total_bytes + FROM project_sync_state + WHERE project_id = $1 + ORDER BY broker_id + `, projectID) + if err != nil { + return nil, err + } + defer rows.Close() + + var states []store.ProjectSyncState + for rows.Next() { + var state store.ProjectSyncState + var lastSyncTime sql.NullTime + var lastCommitSHA sql.NullString + + if err := rows.Scan( + &state.ProjectID, &state.BrokerID, + &lastSyncTime, &lastCommitSHA, + &state.FileCount, &state.TotalBytes, + ); err != nil { + return nil, err + } + + if lastSyncTime.Valid { + state.LastSyncTime = &lastSyncTime.Time + } + if lastCommitSHA.Valid { + state.LastCommitSHA = lastCommitSHA.String + } + + states = append(states, state) + } + + if states == nil { + states = []store.ProjectSyncState{} + } + return states, rows.Err() +} + +// DeleteProjectSyncState removes sync state for a project and optional broker. +func (s *PostgresStore) DeleteProjectSyncState(ctx context.Context, projectID, brokerID string) error { + result, err := s.db.ExecContext(ctx, ` + DELETE FROM project_sync_state WHERE project_id = $1 AND broker_id = $2 + `, projectID, brokerID) + if err != nil { + return err + } + + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return store.ErrNotFound + } + return nil +} diff --git a/pkg/store/postgres/projects.go b/pkg/store/postgres/projects.go new file mode 100644 index 00000000..db850818 --- /dev/null +++ b/pkg/store/postgres/projects.go @@ -0,0 +1,428 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +func (s *PostgresStore) CreateProject(ctx context.Context, project *store.Project) error { + now := time.Now() + project.Created = now + project.Updated = now + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO projects (id, name, slug, git_remote, default_runtime_broker_id, labels, annotations, shared_dirs, created_at, updated_at, created_by, owner_id, visibility, github_installation_id, github_permissions, github_app_status, git_identity) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) + `, + project.ID, project.Name, project.Slug, nullableString(project.GitRemote), nullableString(project.DefaultRuntimeBrokerID), + marshalJSON(project.Labels), marshalJSON(project.Annotations), marshalJSON(project.SharedDirs), + project.Created, project.Updated, project.CreatedBy, project.OwnerID, project.Visibility, + nullableInt64(project.GitHubInstallationID), marshalJSONPtr(project.GitHubPermissions), marshalJSONPtr(project.GitHubAppStatus), + marshalJSONPtr(project.GitIdentity), + ) + if err != nil { + if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *PostgresStore) GetProject(ctx context.Context, id string) (*store.Project, error) { + project := &store.Project{} + var labels, annotations, sharedDirs string + var gitRemote, defaultRuntimeBrokerID sql.NullString + var githubInstallationID sql.NullInt64 + var githubPermissions, githubAppStatus, gitIdentity string + + err := s.db.QueryRowContext(ctx, ` + SELECT id, name, slug, git_remote, default_runtime_broker_id, labels, annotations, shared_dirs, created_at, updated_at, created_by, owner_id, visibility, github_installation_id, COALESCE(github_permissions, ''), COALESCE(github_app_status, ''), COALESCE(git_identity, '') + FROM projects WHERE id = $1 + `, id).Scan( + &project.ID, &project.Name, &project.Slug, &gitRemote, &defaultRuntimeBrokerID, + &labels, &annotations, &sharedDirs, + &project.Created, &project.Updated, &project.CreatedBy, &project.OwnerID, &project.Visibility, + &githubInstallationID, &githubPermissions, &githubAppStatus, &gitIdentity, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + if gitRemote.Valid { + project.GitRemote = gitRemote.String + } + if defaultRuntimeBrokerID.Valid { + project.DefaultRuntimeBrokerID = defaultRuntimeBrokerID.String + } + if githubInstallationID.Valid { + id := githubInstallationID.Int64 + project.GitHubInstallationID = &id + } + unmarshalJSON(labels, &project.Labels) + unmarshalJSON(annotations, &project.Annotations) + unmarshalJSON(sharedDirs, &project.SharedDirs) + if githubPermissions != "" { + project.GitHubPermissions = &store.GitHubTokenPermissions{} + unmarshalJSON(githubPermissions, project.GitHubPermissions) + } + if githubAppStatus != "" { + project.GitHubAppStatus = &store.GitHubAppProjectStatus{} + unmarshalJSON(githubAppStatus, project.GitHubAppStatus) + } + if gitIdentity != "" { + project.GitIdentity = &store.GitIdentityConfig{} + unmarshalJSON(gitIdentity, project.GitIdentity) + } + + // Populate computed fields + s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM agents WHERE project_id = $1", id).Scan(&project.AgentCount) + s.db.QueryRowContext(ctx, ` + SELECT (SELECT COUNT(*) FROM project_contributors WHERE project_id = $1 AND status = 'online') + + (SELECT COUNT(*) FROM runtime_brokers WHERE auto_provide = 1 AND status = 'online' + AND id NOT IN (SELECT broker_id FROM project_contributors WHERE project_id = $2)) + `, id, id).Scan(&project.ActiveBrokerCount) + s.populateProjectType(ctx, project) + + return project, nil +} + +// populateProjectType sets the computed ProjectType field based on how the project was established. +// Type is "linked" (pre-existing local project linked to Hub) or "hub-managed" (created via Hub). +// Whether a project is git-backed is orthogonal — indicated by the GitRemote field. +func (s *PostgresStore) populateProjectType(ctx context.Context, project *store.Project) { + // Check if any provider has a local_path not under ~/.scion/projects/ (i.e. broker-linked) + var linkedCount int + s.db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM project_contributors WHERE project_id = $1 AND local_path != '' AND local_path NOT LIKE '%/.scion/projects/%'", + project.ID).Scan(&linkedCount) + if linkedCount > 0 { + project.ProjectType = store.ProjectTypeLinked + return + } + project.ProjectType = store.ProjectTypeHubManaged +} + +func (s *PostgresStore) GetProjectBySlug(ctx context.Context, slug string) (*store.Project, error) { + var id string + err := s.db.QueryRowContext(ctx, "SELECT id FROM projects WHERE slug = $1", slug).Scan(&id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + return s.GetProject(ctx, id) +} + +func (s *PostgresStore) GetProjectBySlugCaseInsensitive(ctx context.Context, slug string) (*store.Project, error) { + var id string + err := s.db.QueryRowContext(ctx, "SELECT id FROM projects WHERE LOWER(slug) = LOWER($1)", slug).Scan(&id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + return s.GetProject(ctx, id) +} + +func (s *PostgresStore) GetProjectsByGitRemote(ctx context.Context, gitRemote string) ([]*store.Project, error) { + rows, err := s.db.QueryContext(ctx, "SELECT id FROM projects WHERE git_remote = $1 ORDER BY created_at ASC", gitRemote) + if err != nil { + return nil, err + } + + // Collect all IDs first, then close the cursor before calling GetProject + // (SQLite single-connection can't serve a new query while rows are open). + var ids []string + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + rows.Close() + return nil, err + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + rows.Close() + return nil, err + } + rows.Close() + + projects := make([]*store.Project, 0, len(ids)) + for _, id := range ids { + project, err := s.GetProject(ctx, id) + if err != nil { + return nil, err + } + projects = append(projects, project) + } + return projects, nil +} + +func (s *PostgresStore) NextAvailableSlug(ctx context.Context, baseSlug string) (string, error) { + // Check if the base slug is available + var count int + if err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM projects WHERE slug = $1", baseSlug).Scan(&count); err != nil { + return "", err + } + if count == 0 { + return baseSlug, nil + } + + // Find the next available serial suffix + for i := 1; ; i++ { + candidate := fmt.Sprintf("%s-%d", baseSlug, i) + if err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM projects WHERE slug = $1", candidate).Scan(&count); err != nil { + return "", err + } + if count == 0 { + return candidate, nil + } + } +} + +func (s *PostgresStore) UpdateProject(ctx context.Context, project *store.Project) error { + project.Updated = time.Now() + + result, err := s.db.ExecContext(ctx, ` + UPDATE projects SET + name = $1, slug = $2, git_remote = $3, default_runtime_broker_id = $4, + labels = $5, annotations = $6, shared_dirs = $7, + updated_at = $8, owner_id = $9, visibility = $10, + github_installation_id = $11, github_permissions = $12, github_app_status = $13, + git_identity = $14 + WHERE id = $15 + `, + project.Name, project.Slug, nullableString(project.GitRemote), nullableString(project.DefaultRuntimeBrokerID), + marshalJSON(project.Labels), marshalJSON(project.Annotations), marshalJSON(project.SharedDirs), + project.Updated, project.OwnerID, project.Visibility, + nullableInt64(project.GitHubInstallationID), marshalJSONPtr(project.GitHubPermissions), marshalJSONPtr(project.GitHubAppStatus), + marshalJSONPtr(project.GitIdentity), + project.ID, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) DeleteProject(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, "DELETE FROM projects WHERE id = $1", id) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) ListProjects(ctx context.Context, filter store.ProjectFilter, opts store.ListOptions) (*store.ListResult[store.Project], error) { + var conditions []string + var args []interface{} + n := 1 + + if len(filter.MemberOrOwnerIDs) > 0 { + // Combine owner_id match with project ID membership using OR + placeholders := make([]string, len(filter.MemberOrOwnerIDs)) + for i, id := range filter.MemberOrOwnerIDs { + placeholders[i] = fmt.Sprintf("$%d", n) + n++ + args = append(args, id) + } + orParts := []string{"id IN (" + strings.Join(placeholders, ",") + ")"} + if filter.OwnerID != "" { + orParts = append(orParts, fmt.Sprintf("owner_id = $%d", n)) + n++ + args = append(args, filter.OwnerID) + } + conditions = append(conditions, "("+strings.Join(orParts, " OR ")+")") + } else if len(filter.MemberProjectIDs) > 0 { + // Strict project ID membership (no owner OR) + placeholders := make([]string, len(filter.MemberProjectIDs)) + for i, id := range filter.MemberProjectIDs { + placeholders[i] = fmt.Sprintf("$%d", n) + n++ + args = append(args, id) + } + conditions = append(conditions, "id IN ("+strings.Join(placeholders, ",")+")") + } else if filter.OwnerID != "" { + conditions = append(conditions, fmt.Sprintf("owner_id = $%d", n)) + n++ + args = append(args, filter.OwnerID) + } + if filter.ExcludeOwnerID != "" { + conditions = append(conditions, fmt.Sprintf("owner_id != $%d", n)) + n++ + args = append(args, filter.ExcludeOwnerID) + } + if filter.Visibility != "" { + conditions = append(conditions, fmt.Sprintf("visibility = $%d", n)) + n++ + args = append(args, filter.Visibility) + } + if filter.GitRemote != "" { + conditions = append(conditions, fmt.Sprintf("git_remote = $%d", n)) + n++ + args = append(args, filter.GitRemote) + } else if filter.GitRemotePrefix != "" { + conditions = append(conditions, fmt.Sprintf("git_remote LIKE $%d", n)) + n++ + args = append(args, filter.GitRemotePrefix+"%") + } + if filter.BrokerID != "" { + conditions = append(conditions, fmt.Sprintf("id IN (SELECT project_id FROM project_contributors WHERE broker_id = $%d)", n)) + n++ + args = append(args, filter.BrokerID) + } + if filter.Name != "" { + conditions = append(conditions, fmt.Sprintf("LOWER(name) = LOWER($%d)", n)) + n++ + args = append(args, filter.Name) + } + if filter.Slug != "" { + conditions = append(conditions, fmt.Sprintf("LOWER(slug) = LOWER($%d)", n)) + n++ + args = append(args, filter.Slug) + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + var totalCount int + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM projects %s", whereClause) + if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount); err != nil { + return nil, err + } + + limit := opts.Limit + if limit <= 0 { + limit = 50 + } + + query := fmt.Sprintf(` + SELECT id, name, slug, git_remote, default_runtime_broker_id, labels, annotations, shared_dirs, created_at, updated_at, created_by, owner_id, visibility, + github_installation_id, COALESCE(github_permissions, ''), COALESCE(github_app_status, ''), COALESCE(git_identity, '') + FROM projects %s ORDER BY created_at DESC LIMIT $%d + `, whereClause, n) + args = append(args, limit) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var projects []store.Project + type projectRow struct { + project store.Project + labels string + annotations string + sharedDirs string + gitRemote sql.NullString + brokerID sql.NullString + githubInstallationID sql.NullInt64 + githubPermissions string + githubAppStatus string + gitIdentity string + } + var rowData []projectRow + + for rows.Next() { + var r projectRow + if err := rows.Scan( + &r.project.ID, &r.project.Name, &r.project.Slug, &r.gitRemote, &r.brokerID, + &r.labels, &r.annotations, &r.sharedDirs, + &r.project.Created, &r.project.Updated, &r.project.CreatedBy, &r.project.OwnerID, &r.project.Visibility, + &r.githubInstallationID, &r.githubPermissions, &r.githubAppStatus, &r.gitIdentity, + ); err != nil { + return nil, err + } + rowData = append(rowData, r) + } + rows.Close() // Close early to release connection for nested queries + + for _, r := range rowData { + project := r.project + if r.gitRemote.Valid { + project.GitRemote = r.gitRemote.String + } + if r.brokerID.Valid { + project.DefaultRuntimeBrokerID = r.brokerID.String + } + if r.githubInstallationID.Valid { + id := r.githubInstallationID.Int64 + project.GitHubInstallationID = &id + } + unmarshalJSON(r.labels, &project.Labels) + unmarshalJSON(r.annotations, &project.Annotations) + unmarshalJSON(r.sharedDirs, &project.SharedDirs) + if r.githubPermissions != "" { + project.GitHubPermissions = &store.GitHubTokenPermissions{} + unmarshalJSON(r.githubPermissions, project.GitHubPermissions) + } + if r.githubAppStatus != "" { + project.GitHubAppStatus = &store.GitHubAppProjectStatus{} + unmarshalJSON(r.githubAppStatus, project.GitHubAppStatus) + } + if r.gitIdentity != "" { + project.GitIdentity = &store.GitIdentityConfig{} + unmarshalJSON(r.gitIdentity, project.GitIdentity) + } + + // Populate computed fields - these now have a connection available + s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM agents WHERE project_id = $1", project.ID).Scan(&project.AgentCount) + s.db.QueryRowContext(ctx, ` + SELECT (SELECT COUNT(*) FROM project_contributors WHERE project_id = $1 AND status = 'online') + + (SELECT COUNT(*) FROM runtime_brokers WHERE auto_provide = 1 AND status = 'online' + AND id NOT IN (SELECT broker_id FROM project_contributors WHERE project_id = $2)) + `, project.ID, project.ID).Scan(&project.ActiveBrokerCount) + s.populateProjectType(ctx, &project) + + projects = append(projects, project) + } + + return &store.ListResult[store.Project]{ + Items: projects, + TotalCount: totalCount, + }, nil +} diff --git a/pkg/store/postgres/providers.go b/pkg/store/postgres/providers.go new file mode 100644 index 00000000..dc459d0b --- /dev/null +++ b/pkg/store/postgres/providers.go @@ -0,0 +1,215 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +// ============================================================================ +// ProjectProvider Operations + +// ============================================================================ + +func (s *PostgresStore) AddProjectProvider(ctx context.Context, provider *store.ProjectProvider) error { + // Set LinkedAt to now if not already set + if provider.LinkedAt.IsZero() && provider.LinkedBy != "" { + provider.LinkedAt = time.Now() + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO project_contributors (project_id, broker_id, broker_name, local_path, mode, status, profiles, last_seen, linked_by, linked_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ON CONFLICT (project_id, broker_id) DO UPDATE SET + broker_name = EXCLUDED.broker_name, + local_path = EXCLUDED.local_path, + mode = EXCLUDED.mode, + status = EXCLUDED.status, + profiles = EXCLUDED.profiles, + last_seen = EXCLUDED.last_seen, + linked_by = EXCLUDED.linked_by, + linked_at = EXCLUDED.linked_at + `, + provider.ProjectID, provider.BrokerID, provider.BrokerName, provider.LocalPath, "", provider.Status, + "[]", provider.LastSeen, // profiles column kept for schema compat but no longer used + nullableString(provider.LinkedBy), nullableTime(provider.LinkedAt), + ) + return err +} + +func (s *PostgresStore) RemoveProjectProvider(ctx context.Context, projectID, brokerID string) error { + result, err := s.db.ExecContext(ctx, "DELETE FROM project_contributors WHERE project_id = $1 AND broker_id = $2", projectID, brokerID) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) GetProjectProvider(ctx context.Context, projectID, brokerID string) (*store.ProjectProvider, error) { + var provider store.ProjectProvider + var localPath, linkedBy sql.NullString + var providerMode, profiles string // unused columns kept for schema compat + var lastSeen, linkedAt sql.NullTime + + err := s.db.QueryRowContext(ctx, ` + SELECT project_id, broker_id, broker_name, local_path, mode, status, profiles, last_seen, linked_by, linked_at + FROM project_contributors WHERE project_id = $1 AND broker_id = $2 + `, projectID, brokerID).Scan( + &provider.ProjectID, &provider.BrokerID, &provider.BrokerName, &localPath, &providerMode, &provider.Status, + &profiles, &lastSeen, &linkedBy, &linkedAt, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + if localPath.Valid { + provider.LocalPath = localPath.String + } + if lastSeen.Valid { + provider.LastSeen = lastSeen.Time + } + if linkedBy.Valid { + provider.LinkedBy = linkedBy.String + } + if linkedAt.Valid { + provider.LinkedAt = linkedAt.Time + } + // profiles column no longer used - lookup from RuntimeBroker.Profiles instead + + return &provider, nil +} + +func (s *PostgresStore) GetProjectProviders(ctx context.Context, projectID string) ([]store.ProjectProvider, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT project_id, broker_id, broker_name, local_path, mode, status, profiles, last_seen, linked_by, linked_at + FROM project_contributors WHERE project_id = $1 + `, projectID) + if err != nil { + return nil, err + } + defer rows.Close() + + var providers []store.ProjectProvider + for rows.Next() { + var provider store.ProjectProvider + var localPath, linkedBy sql.NullString + var providerMode, profiles string // unused columns kept for schema compat + var lastSeen, linkedAt sql.NullTime + + if err := rows.Scan( + &provider.ProjectID, &provider.BrokerID, &provider.BrokerName, &localPath, &providerMode, &provider.Status, + &profiles, &lastSeen, &linkedBy, &linkedAt, + ); err != nil { + return nil, err + } + + if localPath.Valid { + provider.LocalPath = localPath.String + } + if lastSeen.Valid { + provider.LastSeen = lastSeen.Time + } + if linkedBy.Valid { + provider.LinkedBy = linkedBy.String + } + if linkedAt.Valid { + provider.LinkedAt = linkedAt.Time + } + // profiles column no longer used - lookup from RuntimeBroker.Profiles instead + + providers = append(providers, provider) + } + + return providers, nil +} + +func (s *PostgresStore) GetBrokerProjects(ctx context.Context, brokerID string) ([]store.ProjectProvider, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT project_id, broker_id, broker_name, local_path, mode, status, profiles, last_seen, linked_by, linked_at + FROM project_contributors WHERE broker_id = $1 + `, brokerID) + if err != nil { + return nil, err + } + defer rows.Close() + + var providers []store.ProjectProvider + for rows.Next() { + var provider store.ProjectProvider + var localPath, linkedBy sql.NullString + var providerMode, profiles string // unused columns kept for schema compat + var lastSeen, linkedAt sql.NullTime + + if err := rows.Scan( + &provider.ProjectID, &provider.BrokerID, &provider.BrokerName, &localPath, &providerMode, &provider.Status, + &profiles, &lastSeen, &linkedBy, &linkedAt, + ); err != nil { + return nil, err + } + + if localPath.Valid { + provider.LocalPath = localPath.String + } + if lastSeen.Valid { + provider.LastSeen = lastSeen.Time + } + if linkedBy.Valid { + provider.LinkedBy = linkedBy.String + } + if linkedAt.Valid { + provider.LinkedAt = linkedAt.Time + } + // profiles column no longer used - lookup from RuntimeBroker.Profiles instead + + providers = append(providers, provider) + } + + return providers, nil +} + +func (s *PostgresStore) UpdateProviderStatus(ctx context.Context, projectID, brokerID, status string) error { + now := time.Now() + + result, err := s.db.ExecContext(ctx, ` + UPDATE project_contributors SET status = $1, last_seen = $2 WHERE project_id = $3 AND broker_id = $4 + `, status, now, projectID, brokerID) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} diff --git a/pkg/store/postgres/schedule.go b/pkg/store/postgres/schedule.go new file mode 100644 index 00000000..7fbe9de3 --- /dev/null +++ b/pkg/store/postgres/schedule.go @@ -0,0 +1,365 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +// ============================================================================ +// Schedule Operations (Recurring Schedules) +// ============================================================================ + +// CreateSchedule creates a new recurring schedule. +func (s *PostgresStore) CreateSchedule(ctx context.Context, schedule *store.Schedule) error { + if schedule.ID == "" || schedule.ProjectID == "" || schedule.Name == "" || schedule.CronExpr == "" { + return store.ErrInvalidInput + } + + now := time.Now() + if schedule.CreatedAt.IsZero() { + schedule.CreatedAt = now + } + if schedule.UpdatedAt.IsZero() { + schedule.UpdatedAt = now + } + if schedule.Status == "" { + schedule.Status = store.ScheduleStatusActive + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO schedules ( + id, project_id, name, cron_expr, event_type, payload, status, + next_run_at, last_run_at, last_run_status, last_run_error, + run_count, error_count, created_at, created_by, updated_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + `, + schedule.ID, schedule.ProjectID, schedule.Name, schedule.CronExpr, + schedule.EventType, schedule.Payload, schedule.Status, + nullableTime(timeFromNullablePtr(schedule.NextRunAt)), + nullableTime(timeFromNullablePtr(schedule.LastRunAt)), + nullableString(schedule.LastRunStatus), nullableString(schedule.LastRunError), + schedule.RunCount, schedule.ErrorCount, + schedule.CreatedAt, nullableString(schedule.CreatedBy), schedule.UpdatedAt, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + if strings.Contains(err.Error(), "FOREIGN KEY constraint failed") { + return fmt.Errorf("project %s does not exist: %w", schedule.ProjectID, store.ErrInvalidInput) + } + return err + } + return nil +} + +// GetSchedule retrieves a schedule by ID. +func (s *PostgresStore) GetSchedule(ctx context.Context, id string) (*store.Schedule, error) { + schedule := &store.Schedule{} + var nextRunAt, lastRunAt sql.NullTime + var lastRunStatus, lastRunError, createdBy sql.NullString + + err := s.db.QueryRowContext(ctx, ` + SELECT id, project_id, name, cron_expr, event_type, payload, status, + next_run_at, last_run_at, last_run_status, last_run_error, + run_count, error_count, created_at, created_by, updated_at + FROM schedules WHERE id = $1 + `, id).Scan( + &schedule.ID, &schedule.ProjectID, &schedule.Name, &schedule.CronExpr, + &schedule.EventType, &schedule.Payload, &schedule.Status, + &nextRunAt, &lastRunAt, &lastRunStatus, &lastRunError, + &schedule.RunCount, &schedule.ErrorCount, + &schedule.CreatedAt, &createdBy, &schedule.UpdatedAt, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + if nextRunAt.Valid { + schedule.NextRunAt = &nextRunAt.Time + } + if lastRunAt.Valid { + schedule.LastRunAt = &lastRunAt.Time + } + if lastRunStatus.Valid { + schedule.LastRunStatus = lastRunStatus.String + } + if lastRunError.Valid { + schedule.LastRunError = lastRunError.String + } + if createdBy.Valid { + schedule.CreatedBy = createdBy.String + } + + return schedule, nil +} + +// ListSchedules returns schedules matching the filter criteria. +func (s *PostgresStore) ListSchedules(ctx context.Context, filter store.ScheduleFilter, opts store.ListOptions) (*store.ListResult[store.Schedule], error) { + var conditions []string + var args []interface{} + + if filter.ProjectID != "" { + args = append(args, filter.ProjectID) + conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args))) + } + if filter.Status != "" { + args = append(args, filter.Status) + conditions = append(conditions, fmt.Sprintf("status = $%d", len(args))) + } else { + // By default, exclude deleted schedules + args = append(args, store.ScheduleStatusDeleted) + conditions = append(conditions, fmt.Sprintf("status != $%d", len(args))) + } + if filter.Name != "" { + args = append(args, filter.Name) + conditions = append(conditions, fmt.Sprintf("name = $%d", len(args))) + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + // Get total count + var totalCount int + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM schedules %s", whereClause) + if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount); err != nil { + return nil, err + } + + limit := opts.Limit + if limit <= 0 { + limit = 50 + } + if limit > 200 { + limit = 200 + } + + query := fmt.Sprintf(` + SELECT id, project_id, name, cron_expr, event_type, payload, status, + next_run_at, last_run_at, last_run_status, last_run_error, + run_count, error_count, created_at, created_by, updated_at + FROM schedules %s + ORDER BY created_at DESC + LIMIT $%d + `, whereClause, len(args)+1) + + queryArgs := append(args, limit+1) //nolint:gocritic + + rows, err := s.db.QueryContext(ctx, query, queryArgs...) + if err != nil { + return nil, err + } + defer rows.Close() + + schedules, err := scanSchedules(rows) + if err != nil { + return nil, err + } + + result := &store.ListResult[store.Schedule]{ + TotalCount: totalCount, + } + + if len(schedules) > limit { + result.Items = schedules[:limit] + result.NextCursor = schedules[limit-1].ID + } else { + result.Items = schedules + } + + return result, nil +} + +// UpdateSchedule updates an existing schedule. +func (s *PostgresStore) UpdateSchedule(ctx context.Context, schedule *store.Schedule) error { + schedule.UpdatedAt = time.Now() + + result, err := s.db.ExecContext(ctx, ` + UPDATE schedules SET + name = $1, cron_expr = $2, event_type = $3, payload = $4, + status = $5, next_run_at = $6, updated_at = $7 + WHERE id = $8 + `, + schedule.Name, schedule.CronExpr, schedule.EventType, schedule.Payload, + schedule.Status, nullableTime(timeFromNullablePtr(schedule.NextRunAt)), + schedule.UpdatedAt, schedule.ID, + ) + if err != nil { + return err + } + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +// UpdateScheduleStatus updates only the status of a schedule. +func (s *PostgresStore) UpdateScheduleStatus(ctx context.Context, id string, status string) error { + result, err := s.db.ExecContext(ctx, ` + UPDATE schedules SET status = $1, updated_at = $2 WHERE id = $3 + `, status, time.Now(), id) + if err != nil { + return err + } + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +// UpdateScheduleAfterRun updates a schedule after a run completes. +func (s *PostgresStore) UpdateScheduleAfterRun(ctx context.Context, id string, ranAt time.Time, nextRunAt time.Time, errMsg string) error { + var query string + var args []interface{} + + if errMsg != "" { + query = ` + UPDATE schedules SET + last_run_at = $1, next_run_at = $2, last_run_status = $3, last_run_error = $4, + run_count = run_count + 1, error_count = error_count + 1, updated_at = $5 + WHERE id = $6 + ` + args = []interface{}{ranAt, nextRunAt, store.ScheduleRunError, errMsg, time.Now(), id} + } else { + query = ` + UPDATE schedules SET + last_run_at = $1, next_run_at = $2, last_run_status = $3, last_run_error = NULL, + run_count = run_count + 1, updated_at = $4 + WHERE id = $5 + ` + args = []interface{}{ranAt, nextRunAt, store.ScheduleRunSuccess, time.Now(), id} + } + + result, err := s.db.ExecContext(ctx, query, args...) + if err != nil { + return err + } + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +// DeleteSchedule removes a schedule by ID (hard delete). +func (s *PostgresStore) DeleteSchedule(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, "DELETE FROM schedules WHERE id = $1", id) + if err != nil { + return err + } + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +// ListDueSchedules returns active schedules whose next_run_at has passed. +func (s *PostgresStore) ListDueSchedules(ctx context.Context, now time.Time) ([]store.Schedule, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, project_id, name, cron_expr, event_type, payload, status, + next_run_at, last_run_at, last_run_status, last_run_error, + run_count, error_count, created_at, created_by, updated_at + FROM schedules + WHERE status = $1 AND next_run_at IS NOT NULL AND next_run_at <= $2 + ORDER BY next_run_at ASC + `, store.ScheduleStatusActive, now) + if err != nil { + return nil, err + } + defer rows.Close() + + return scanSchedules(rows) +} + +// ============================================================================ +// Helpers +// ============================================================================ + +// timeFromNullablePtr returns the time from a pointer, or zero time if nil. +func timeFromNullablePtr(t *time.Time) time.Time { + if t == nil { + return time.Time{} + } + return *t +} + +// scanSchedules scans rows into Schedule slices. +func scanSchedules(rows *sql.Rows) ([]store.Schedule, error) { + var schedules []store.Schedule + for rows.Next() { + var schedule store.Schedule + var nextRunAt, lastRunAt sql.NullTime + var lastRunStatus, lastRunError, createdBy sql.NullString + + if err := rows.Scan( + &schedule.ID, &schedule.ProjectID, &schedule.Name, &schedule.CronExpr, + &schedule.EventType, &schedule.Payload, &schedule.Status, + &nextRunAt, &lastRunAt, &lastRunStatus, &lastRunError, + &schedule.RunCount, &schedule.ErrorCount, + &schedule.CreatedAt, &createdBy, &schedule.UpdatedAt, + ); err != nil { + return nil, err + } + + if nextRunAt.Valid { + schedule.NextRunAt = &nextRunAt.Time + } + if lastRunAt.Valid { + schedule.LastRunAt = &lastRunAt.Time + } + if lastRunStatus.Valid { + schedule.LastRunStatus = lastRunStatus.String + } + if lastRunError.Valid { + schedule.LastRunError = lastRunError.String + } + if createdBy.Valid { + schedule.CreatedBy = createdBy.String + } + schedules = append(schedules, schedule) + } + if err := rows.Err(); err != nil { + return nil, err + } + return schedules, nil +} diff --git a/pkg/store/postgres/scheduled_event.go b/pkg/store/postgres/scheduled_event.go new file mode 100644 index 00000000..4be53b7b --- /dev/null +++ b/pkg/store/postgres/scheduled_event.go @@ -0,0 +1,317 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +// ============================================================================ +// Scheduled Event Operations +// ============================================================================ + +// CreateScheduledEvent creates a new scheduled event. +func (s *PostgresStore) CreateScheduledEvent(ctx context.Context, event *store.ScheduledEvent) error { + if event.ID == "" || event.ProjectID == "" || event.EventType == "" { + return store.ErrInvalidInput + } + + now := time.Now() + if event.CreatedAt.IsZero() { + event.CreatedAt = now + } + if event.Status == "" { + event.Status = store.ScheduledEventPending + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO scheduled_events ( + id, project_id, event_type, fire_at, payload, status, + created_at, created_by, fired_at, error, schedule_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + `, + event.ID, event.ProjectID, event.EventType, event.FireAt, event.Payload, event.Status, + event.CreatedAt, nullableString(event.CreatedBy), nullableTime(timeFromPtr(event.FiredAt)), nullableString(event.Error), + nullableString(event.ScheduleID), + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + if strings.Contains(err.Error(), "foreign key constraint") || strings.Contains(err.Error(), "FOREIGN KEY constraint failed") { + return fmt.Errorf("project %s does not exist: %w", event.ProjectID, store.ErrInvalidInput) + } + return err + } + return nil +} + +// GetScheduledEvent retrieves a scheduled event by ID. +func (s *PostgresStore) GetScheduledEvent(ctx context.Context, id string) (*store.ScheduledEvent, error) { + event := &store.ScheduledEvent{} + var createdBy sql.NullString + var firedAt sql.NullTime + var errMsg sql.NullString + var scheduleID sql.NullString + + err := s.db.QueryRowContext(ctx, ` + SELECT id, project_id, event_type, fire_at, payload, status, + created_at, created_by, fired_at, error, schedule_id + FROM scheduled_events WHERE id = $1 + `, id).Scan( + &event.ID, &event.ProjectID, &event.EventType, &event.FireAt, &event.Payload, &event.Status, + &event.CreatedAt, &createdBy, &firedAt, &errMsg, &scheduleID, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + if createdBy.Valid { + event.CreatedBy = createdBy.String + } + if firedAt.Valid { + event.FiredAt = &firedAt.Time + } + if errMsg.Valid { + event.Error = errMsg.String + } + if scheduleID.Valid { + event.ScheduleID = scheduleID.String + } + + return event, nil +} + +// ListPendingScheduledEvents returns all events with status "pending", +// ordered by fire_at ASC. +func (s *PostgresStore) ListPendingScheduledEvents(ctx context.Context) ([]store.ScheduledEvent, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, project_id, event_type, fire_at, payload, status, + created_at, created_by, fired_at, error, schedule_id + FROM scheduled_events + WHERE status = $1 + ORDER BY fire_at ASC + `, store.ScheduledEventPending) + if err != nil { + return nil, err + } + defer rows.Close() + + return scanScheduledEvents(rows) +} + +// UpdateScheduledEventStatus updates the status and optional error for an event. +func (s *PostgresStore) UpdateScheduledEventStatus(ctx context.Context, id string, status string, firedAt *time.Time, errMsg string) error { + _, err := s.db.ExecContext(ctx, ` + UPDATE scheduled_events SET status = $1, fired_at = $2, error = $3 + WHERE id = $4 + `, status, nullableTime(timeFromPtr(firedAt)), nullableString(errMsg), id) + return err +} + +// CancelScheduledEvent marks an event as cancelled. +// Returns ErrNotFound if the event doesn't exist or is not pending. +func (s *PostgresStore) CancelScheduledEvent(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, ` + UPDATE scheduled_events SET status = $1 + WHERE id = $2 AND status = $3 + `, store.ScheduledEventCancelled, id, store.ScheduledEventPending) + if err != nil { + return err + } + + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return store.ErrNotFound + } + return nil +} + +// ListScheduledEvents returns events matching the filter criteria. +func (s *PostgresStore) ListScheduledEvents(ctx context.Context, filter store.ScheduledEventFilter, opts store.ListOptions) (*store.ListResult[store.ScheduledEvent], error) { + var conditions []string + var args []interface{} + + if filter.ProjectID != "" { + conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args)+1)) + args = append(args, filter.ProjectID) + } + if filter.EventType != "" { + conditions = append(conditions, fmt.Sprintf("event_type = $%d", len(args)+1)) + args = append(args, filter.EventType) + } + if filter.Status != "" { + conditions = append(conditions, fmt.Sprintf("status = $%d", len(args)+1)) + args = append(args, filter.Status) + } + if filter.ScheduleID != "" { + conditions = append(conditions, fmt.Sprintf("schedule_id = $%d", len(args)+1)) + args = append(args, filter.ScheduleID) + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + // Get total count + var totalCount int + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM scheduled_events %s", whereClause) + if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount); err != nil { + return nil, err + } + + // Apply pagination + limit := opts.Limit + if limit <= 0 { + limit = 50 + } + if limit > 200 { + limit = 200 + } + + query := fmt.Sprintf(` + SELECT id, project_id, event_type, fire_at, payload, status, + created_at, created_by, fired_at, error, schedule_id + FROM scheduled_events %s + ORDER BY created_at DESC + LIMIT $%d + `, whereClause, len(args)+1) + + queryArgs := append(args, limit+1) //nolint:gocritic // intentional append to copy + + if opts.Cursor != "" { + if whereClause == "" { + query = fmt.Sprintf(` + SELECT id, project_id, event_type, fire_at, payload, status, + created_at, created_by, fired_at, error, schedule_id + FROM scheduled_events WHERE id < $%d + ORDER BY created_at DESC + LIMIT $%d + `, len(args)+1, len(args)+2) + } else { + query = fmt.Sprintf(` + SELECT id, project_id, event_type, fire_at, payload, status, + created_at, created_by, fired_at, error, schedule_id + FROM scheduled_events %s AND id < $%d + ORDER BY created_at DESC + LIMIT $%d + `, whereClause, len(args)+1, len(args)+2) + } + queryArgs = append(args, opts.Cursor, limit+1) //nolint:gocritic + } + + rows, err := s.db.QueryContext(ctx, query, queryArgs...) + if err != nil { + return nil, err + } + defer rows.Close() + + events, err := scanScheduledEvents(rows) + if err != nil { + return nil, err + } + + result := &store.ListResult[store.ScheduledEvent]{ + TotalCount: totalCount, + } + + if len(events) > limit { + result.Items = events[:limit] + result.NextCursor = events[limit-1].ID + } else { + result.Items = events + } + + return result, nil +} + +// PurgeOldScheduledEvents removes non-pending events older than cutoff. +func (s *PostgresStore) PurgeOldScheduledEvents(ctx context.Context, cutoff time.Time) (int, error) { + result, err := s.db.ExecContext(ctx, + "DELETE FROM scheduled_events WHERE status != $1 AND created_at < $2", + store.ScheduledEventPending, cutoff, + ) + if err != nil { + return 0, err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return 0, err + } + return int(rowsAffected), nil +} + +// ============================================================================ +// Helpers +// ============================================================================ + +// timeFromPtr returns the time from a pointer, or zero time if nil. +func timeFromPtr(t *time.Time) time.Time { + if t == nil { + return time.Time{} + } + return *t +} + +// scanScheduledEvents scans rows into ScheduledEvent slices. +func scanScheduledEvents(rows *sql.Rows) ([]store.ScheduledEvent, error) { + var events []store.ScheduledEvent + for rows.Next() { + var event store.ScheduledEvent + var createdBy sql.NullString + var firedAt sql.NullTime + var errMsg sql.NullString + var scheduleID sql.NullString + + if err := rows.Scan( + &event.ID, &event.ProjectID, &event.EventType, &event.FireAt, &event.Payload, &event.Status, + &event.CreatedAt, &createdBy, &firedAt, &errMsg, &scheduleID, + ); err != nil { + return nil, err + } + + if createdBy.Valid { + event.CreatedBy = createdBy.String + } + if firedAt.Valid { + event.FiredAt = &firedAt.Time + } + if errMsg.Valid { + event.Error = errMsg.String + } + if scheduleID.Valid { + event.ScheduleID = scheduleID.String + } + events = append(events, event) + } + if err := rows.Err(); err != nil { + return nil, err + } + return events, nil +} diff --git a/pkg/store/postgres/secrets.go b/pkg/store/postgres/secrets.go new file mode 100644 index 00000000..79f19ba6 --- /dev/null +++ b/pkg/store/postgres/secrets.go @@ -0,0 +1,325 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package postgres provides a PostgreSQL implementation of the Store interface. +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +func (s *PostgresStore) CreateSecret(ctx context.Context, secret *store.Secret) error { + now := time.Now() + secret.Created = now + secret.Updated = now + secret.Version = 1 + + if secret.SecretType == "" { + secret.SecretType = store.SecretTypeEnvironment + } + if secret.Target == "" { + secret.Target = secret.Key + } + if secret.InjectionMode == "" { + secret.InjectionMode = store.InjectionModeAsNeeded + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO secrets (id, key, encrypted_value, secret_ref, secret_type, target, scope, scope_id, description, injection_mode, allow_progeny, version, created_at, updated_at, created_by, updated_by) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + `, + secret.ID, secret.Key, secret.EncryptedValue, nullableString(secret.SecretRef), + secret.SecretType, nullableString(secret.Target), + secret.Scope, secret.ScopeID, + secret.Description, secret.InjectionMode, boolToInt(secret.AllowProgeny), secret.Version, + secret.Created, secret.Updated, secret.CreatedBy, secret.UpdatedBy, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *PostgresStore) GetSecret(ctx context.Context, key, scope, scopeID string) (*store.Secret, error) { + secret := &store.Secret{} + var target sql.NullString + var secretRef sql.NullString + + var allowProgeny int + err := s.db.QueryRowContext(ctx, ` + SELECT id, key, encrypted_value, secret_ref, secret_type, COALESCE(target, key), scope, scope_id, description, injection_mode, allow_progeny, version, created_at, updated_at, created_by, updated_by + FROM secrets WHERE key = $1 AND scope = $2 AND scope_id = $3 + `, key, scope, scopeID).Scan( + &secret.ID, &secret.Key, &secret.EncryptedValue, &secretRef, + &secret.SecretType, &target, + &secret.Scope, &secret.ScopeID, + &secret.Description, &secret.InjectionMode, &allowProgeny, &secret.Version, + &secret.Created, &secret.Updated, &secret.CreatedBy, &secret.UpdatedBy, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + if target.Valid { + secret.Target = target.String + } + if secretRef.Valid { + secret.SecretRef = secretRef.String + } + secret.AllowProgeny = allowProgeny != 0 + + return secret, nil +} + +func (s *PostgresStore) UpdateSecret(ctx context.Context, secret *store.Secret) error { + secret.Updated = time.Now() + secret.Version++ // Increment version on each update + + if secret.SecretType == "" { + secret.SecretType = store.SecretTypeEnvironment + } + if secret.Target == "" { + secret.Target = secret.Key + } + if secret.InjectionMode == "" { + secret.InjectionMode = store.InjectionModeAsNeeded + } + + result, err := s.db.ExecContext(ctx, ` + UPDATE secrets SET + encrypted_value = $1, secret_ref = $2, secret_type = $3, target = $4, description = $5, injection_mode = $6, allow_progeny = $7, version = $8, updated_at = $9, updated_by = $10 + WHERE key = $11 AND scope = $12 AND scope_id = $13 + `, + secret.EncryptedValue, nullableString(secret.SecretRef), + secret.SecretType, nullableString(secret.Target), + secret.Description, secret.InjectionMode, boolToInt(secret.AllowProgeny), secret.Version, secret.Updated, secret.UpdatedBy, + secret.Key, secret.Scope, secret.ScopeID, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) UpsertSecret(ctx context.Context, secret *store.Secret) (bool, error) { + now := time.Now() + secret.Updated = now + + // Check if it already exists + existing, err := s.GetSecret(ctx, secret.Key, secret.Scope, secret.ScopeID) + if err != nil && err != store.ErrNotFound { + return false, err + } + + if existing != nil { + // Update existing + secret.ID = existing.ID + secret.Created = existing.Created + secret.CreatedBy = existing.CreatedBy + secret.Version = existing.Version // Will be incremented in UpdateSecret + if err := s.UpdateSecret(ctx, secret); err != nil { + return false, err + } + return false, nil + } + + // Create new + secret.Created = now + if err := s.CreateSecret(ctx, secret); err != nil { + return false, err + } + return true, nil +} + +func (s *PostgresStore) DeleteSecret(ctx context.Context, key, scope, scopeID string) error { + result, err := s.db.ExecContext(ctx, "DELETE FROM secrets WHERE key = $1 AND scope = $2 AND scope_id = $3", key, scope, scopeID) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) DeleteSecretsByScope(ctx context.Context, scope, scopeID string) (int, error) { + result, err := s.db.ExecContext(ctx, "DELETE FROM secrets WHERE scope = $1 AND scope_id = $2", scope, scopeID) + if err != nil { + return 0, err + } + n, err := result.RowsAffected() + if err != nil { + return 0, err + } + return int(n), nil +} + +func (s *PostgresStore) ListSecrets(ctx context.Context, filter store.SecretFilter) ([]store.Secret, error) { + var conditions []string + var args []interface{} + + if filter.Scope != "" { + conditions = append(conditions, fmt.Sprintf("scope = $%d", len(args)+1)) + args = append(args, filter.Scope) + } + if filter.ScopeID != "" { + conditions = append(conditions, fmt.Sprintf("scope_id = $%d", len(args)+1)) + args = append(args, filter.ScopeID) + } + if filter.Key != "" { + conditions = append(conditions, fmt.Sprintf("key = $%d", len(args)+1)) + args = append(args, filter.Key) + } + if filter.Type != "" { + conditions = append(conditions, fmt.Sprintf("secret_type = $%d", len(args)+1)) + args = append(args, filter.Type) + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + // Note: We do NOT select encrypted_value for listing + query := fmt.Sprintf(` + SELECT id, key, secret_ref, secret_type, COALESCE(target, key), scope, scope_id, description, injection_mode, allow_progeny, version, created_at, updated_at, created_by, updated_by + FROM secrets %s ORDER BY key + `, whereClause) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var secrets []store.Secret + for rows.Next() { + var secret store.Secret + var target sql.NullString + var secretRef sql.NullString + var allowProgeny int + if err := rows.Scan( + &secret.ID, &secret.Key, &secretRef, &secret.SecretType, &target, + &secret.Scope, &secret.ScopeID, + &secret.Description, &secret.InjectionMode, &allowProgeny, &secret.Version, + &secret.Created, &secret.Updated, &secret.CreatedBy, &secret.UpdatedBy, + ); err != nil { + return nil, err + } + if target.Valid { + secret.Target = target.String + } + if secretRef.Valid { + secret.SecretRef = secretRef.String + } + secret.AllowProgeny = allowProgeny != 0 + secrets = append(secrets, secret) + } + + return secrets, nil +} + +func (s *PostgresStore) ListProgenySecrets(ctx context.Context, ancestorIDs []string) ([]store.Secret, error) { + if len(ancestorIDs) == 0 { + return nil, nil + } + + // Build placeholder list for IN clause + placeholders := make([]string, len(ancestorIDs)) + args := make([]interface{}, len(ancestorIDs)) + for i, id := range ancestorIDs { + placeholders[i] = fmt.Sprintf("$%d", i+1) + args[i] = id + } + + query := fmt.Sprintf(` + SELECT id, key, secret_ref, secret_type, COALESCE(target, key), scope, scope_id, description, injection_mode, allow_progeny, version, created_at, updated_at, created_by, updated_by + FROM secrets + WHERE scope = 'user' AND allow_progeny = 1 AND created_by IN (%s) + ORDER BY key + `, strings.Join(placeholders, ", ")) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var secrets []store.Secret + for rows.Next() { + var secret store.Secret + var target sql.NullString + var secretRef sql.NullString + var allowProgeny int + if err := rows.Scan( + &secret.ID, &secret.Key, &secretRef, &secret.SecretType, &target, + &secret.Scope, &secret.ScopeID, + &secret.Description, &secret.InjectionMode, &allowProgeny, &secret.Version, + &secret.Created, &secret.Updated, &secret.CreatedBy, &secret.UpdatedBy, + ); err != nil { + return nil, err + } + if target.Valid { + secret.Target = target.String + } + if secretRef.Valid { + secret.SecretRef = secretRef.String + } + secret.AllowProgeny = allowProgeny != 0 + secrets = append(secrets, secret) + } + + return secrets, nil +} + +func (s *PostgresStore) GetSecretValue(ctx context.Context, key, scope, scopeID string) (string, error) { + var encryptedValue string + + err := s.db.QueryRowContext(ctx, ` + SELECT encrypted_value FROM secrets WHERE key = $1 AND scope = $2 AND scope_id = $3 + `, key, scope, scopeID).Scan(&encryptedValue) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", store.ErrNotFound + } + return "", err + } + + return encryptedValue, nil +} diff --git a/pkg/store/postgres/templates.go b/pkg/store/postgres/templates.go new file mode 100644 index 00000000..cb352899 --- /dev/null +++ b/pkg/store/postgres/templates.go @@ -0,0 +1,384 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +// ============================================================================ +// Template Operations +// ============================================================================ + +func (s *PostgresStore) CreateTemplate(ctx context.Context, template *store.Template) error { + now := time.Now() + template.Created = now + template.Updated = now + + // Set default status if not provided + if template.Status == "" { + template.Status = store.TemplateStatusActive + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO templates ( + id, name, slug, display_name, description, harness, default_harness_config, image, config, + content_hash, scope, scope_id, project_id, + storage_uri, storage_bucket, storage_path, files, + base_template, locked, status, + owner_id, created_by, updated_by, visibility, + created_at, updated_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) + `, + template.ID, template.Name, template.Slug, nullableString(template.DisplayName), nullableString(template.Description), + template.Harness, nullableString(template.DefaultHarnessConfig), template.Image, marshalJSON(template.Config), + nullableString(template.ContentHash), template.Scope, nullableString(template.ScopeID), nullableString(template.ProjectID), + nullableString(template.StorageURI), nullableString(template.StorageBucket), nullableString(template.StoragePath), marshalJSON(template.Files), + nullableString(template.BaseTemplate), template.Locked, template.Status, + nullableString(template.OwnerID), nullableString(template.CreatedBy), nullableString(template.UpdatedBy), template.Visibility, + template.Created, template.Updated, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *PostgresStore) GetTemplate(ctx context.Context, id string) (*store.Template, error) { + template := &store.Template{} + var config, files string + var displayName, description, contentHash, scopeID, projectID sql.NullString + var storageURI, storageBucket, storagePath, baseTemplate sql.NullString + var createdBy, updatedBy, ownerID, visibility sql.NullString + var defaultHarnessConfig sql.NullString + + err := s.db.QueryRowContext(ctx, ` + SELECT id, name, slug, display_name, description, harness, default_harness_config, image, config, + content_hash, scope, scope_id, project_id, + storage_uri, storage_bucket, storage_path, files, + base_template, locked, status, + owner_id, created_by, updated_by, visibility, + created_at, updated_at + FROM templates WHERE id = $1 + `, id).Scan( + &template.ID, &template.Name, &template.Slug, &displayName, &description, + &template.Harness, &defaultHarnessConfig, &template.Image, &config, + &contentHash, &template.Scope, &scopeID, &projectID, + &storageURI, &storageBucket, &storagePath, &files, + &baseTemplate, &template.Locked, &template.Status, + &ownerID, &createdBy, &updatedBy, &visibility, + &template.Created, &template.Updated, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + if displayName.Valid { + template.DisplayName = displayName.String + } + if description.Valid { + template.Description = description.String + } + if defaultHarnessConfig.Valid { + template.DefaultHarnessConfig = defaultHarnessConfig.String + } + if contentHash.Valid { + template.ContentHash = contentHash.String + } + if scopeID.Valid { + template.ScopeID = scopeID.String + } + if projectID.Valid { + template.ProjectID = projectID.String + } + if storageURI.Valid { + template.StorageURI = storageURI.String + } + if storageBucket.Valid { + template.StorageBucket = storageBucket.String + } + if storagePath.Valid { + template.StoragePath = storagePath.String + } + if baseTemplate.Valid { + template.BaseTemplate = baseTemplate.String + } + if ownerID.Valid { + template.OwnerID = ownerID.String + } + if createdBy.Valid { + template.CreatedBy = createdBy.String + } + if updatedBy.Valid { + template.UpdatedBy = updatedBy.String + } + if visibility.Valid { + template.Visibility = visibility.String + } + unmarshalJSON(config, &template.Config) + unmarshalJSON(files, &template.Files) + + return template, nil +} + +func (s *PostgresStore) GetTemplateBySlug(ctx context.Context, slug, scope, scopeID string) (*store.Template, error) { + var id string + var err error + + if scope == "project" && scopeID != "" { + // Try scope_id first, then fall back to project_id for backwards compatibility + err = s.db.QueryRowContext(ctx, "SELECT id FROM templates WHERE slug = $1 AND scope = $2 AND (scope_id = $3 OR project_id = $4)", slug, scope, scopeID, scopeID).Scan(&id) + } else if scope == "user" && scopeID != "" { + err = s.db.QueryRowContext(ctx, "SELECT id FROM templates WHERE slug = $1 AND scope = $2 AND scope_id = $3", slug, scope, scopeID).Scan(&id) + } else { + err = s.db.QueryRowContext(ctx, "SELECT id FROM templates WHERE slug = $1 AND scope = $2", slug, scope).Scan(&id) + } + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + return s.GetTemplate(ctx, id) +} + +func (s *PostgresStore) UpdateTemplate(ctx context.Context, template *store.Template) error { + template.Updated = time.Now() + + result, err := s.db.ExecContext(ctx, ` + UPDATE templates SET + name = $1, slug = $2, display_name = $3, description = $4, + harness = $5, default_harness_config = $6, image = $7, config = $8, + content_hash = $9, scope = $10, scope_id = $11, project_id = $12, + storage_uri = $13, storage_bucket = $14, storage_path = $15, files = $16, + base_template = $17, locked = $18, status = $19, + owner_id = $20, updated_by = $21, visibility = $22, + updated_at = $23 + WHERE id = $24 + `, + template.Name, template.Slug, nullableString(template.DisplayName), nullableString(template.Description), + template.Harness, nullableString(template.DefaultHarnessConfig), template.Image, marshalJSON(template.Config), + nullableString(template.ContentHash), template.Scope, nullableString(template.ScopeID), nullableString(template.ProjectID), + nullableString(template.StorageURI), nullableString(template.StorageBucket), nullableString(template.StoragePath), marshalJSON(template.Files), + nullableString(template.BaseTemplate), template.Locked, template.Status, + nullableString(template.OwnerID), nullableString(template.UpdatedBy), template.Visibility, + template.Updated, + template.ID, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) DeleteTemplate(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, "DELETE FROM templates WHERE id = $1", id) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) DeleteTemplatesByScope(ctx context.Context, scope, scopeID string) (int, error) { + result, err := s.db.ExecContext(ctx, "DELETE FROM templates WHERE scope = $1 AND scope_id = $2", scope, scopeID) + if err != nil { + return 0, err + } + n, err := result.RowsAffected() + if err != nil { + return 0, err + } + return int(n), nil +} + +func (s *PostgresStore) ListTemplates(ctx context.Context, filter store.TemplateFilter, opts store.ListOptions) (*store.ListResult[store.Template], error) { + var conditions []string + var args []interface{} + + if filter.Name != "" { + // Exact match on name or slug + conditions = append(conditions, fmt.Sprintf("(name = $%d OR slug = $%d)", len(args)+1, len(args)+2)) + args = append(args, filter.Name, filter.Name) + } + if filter.Scope != "" { + conditions = append(conditions, fmt.Sprintf("scope = $%d", len(args)+1)) + args = append(args, filter.Scope) + } + if filter.ScopeID != "" { + conditions = append(conditions, fmt.Sprintf("(scope_id = $%d OR project_id = $%d)", len(args)+1, len(args)+2)) + args = append(args, filter.ScopeID, filter.ScopeID) + } else if filter.ProjectID != "" && filter.Scope == "" { + // When projectId is set without scope, return global + project-scoped templates for this project + conditions = append(conditions, fmt.Sprintf("(scope = 'global' OR (scope = 'project' AND (scope_id = $%d OR project_id = $%d)))", len(args)+1, len(args)+2)) + args = append(args, filter.ProjectID, filter.ProjectID) + } else if filter.ProjectID != "" { + // Backwards compatibility: projectId with explicit scope + conditions = append(conditions, fmt.Sprintf("(scope_id = $%d OR project_id = $%d)", len(args)+1, len(args)+2)) + args = append(args, filter.ProjectID, filter.ProjectID) + } + if filter.Harness != "" { + conditions = append(conditions, fmt.Sprintf("harness = $%d", len(args)+1)) + args = append(args, filter.Harness) + } + if filter.OwnerID != "" { + conditions = append(conditions, fmt.Sprintf("owner_id = $%d", len(args)+1)) + args = append(args, filter.OwnerID) + } + if filter.Status != "" { + conditions = append(conditions, fmt.Sprintf("status = $%d", len(args)+1)) + args = append(args, filter.Status) + } + if filter.Search != "" { + conditions = append(conditions, fmt.Sprintf("(name LIKE $%d OR description LIKE $%d)", len(args)+1, len(args)+2)) + searchPattern := "%" + filter.Search + "%" + args = append(args, searchPattern, searchPattern) + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + var totalCount int + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM templates %s", whereClause) + if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount); err != nil { + return nil, err + } + + limit := opts.Limit + if limit <= 0 { + limit = 50 + } + + query := fmt.Sprintf(` + SELECT id, name, slug, display_name, description, harness, default_harness_config, image, config, + content_hash, scope, scope_id, project_id, + storage_uri, storage_bucket, storage_path, files, + base_template, locked, status, + owner_id, created_by, updated_by, visibility, + created_at, updated_at + FROM templates %s ORDER BY created_at DESC LIMIT $%d + `, whereClause, len(args)+1) + args = append(args, limit) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var templates []store.Template + for rows.Next() { + var template store.Template + var config, files string + var displayName, description, contentHash, scopeID, projectID sql.NullString + var storageURI, storageBucket, storagePath, baseTemplate sql.NullString + var createdBy, updatedBy, ownerID, visibility sql.NullString + var defaultHarnessConfig sql.NullString + + if err := rows.Scan( + &template.ID, &template.Name, &template.Slug, &displayName, &description, + &template.Harness, &defaultHarnessConfig, &template.Image, &config, + &contentHash, &template.Scope, &scopeID, &projectID, + &storageURI, &storageBucket, &storagePath, &files, + &baseTemplate, &template.Locked, &template.Status, + &ownerID, &createdBy, &updatedBy, &visibility, + &template.Created, &template.Updated, + ); err != nil { + return nil, err + } + + if displayName.Valid { + template.DisplayName = displayName.String + } + if description.Valid { + template.Description = description.String + } + if defaultHarnessConfig.Valid { + template.DefaultHarnessConfig = defaultHarnessConfig.String + } + if contentHash.Valid { + template.ContentHash = contentHash.String + } + if scopeID.Valid { + template.ScopeID = scopeID.String + } + if projectID.Valid { + template.ProjectID = projectID.String + } + if storageURI.Valid { + template.StorageURI = storageURI.String + } + if storageBucket.Valid { + template.StorageBucket = storageBucket.String + } + if storagePath.Valid { + template.StoragePath = storagePath.String + } + if baseTemplate.Valid { + template.BaseTemplate = baseTemplate.String + } + if ownerID.Valid { + template.OwnerID = ownerID.String + } + if createdBy.Valid { + template.CreatedBy = createdBy.String + } + if updatedBy.Valid { + template.UpdatedBy = updatedBy.String + } + if visibility.Valid { + template.Visibility = visibility.String + } + unmarshalJSON(config, &template.Config) + unmarshalJSON(files, &template.Files) + + templates = append(templates, template) + } + + return &store.ListResult[store.Template]{ + Items: templates, + TotalCount: totalCount, + }, nil +} diff --git a/pkg/store/postgres/tokens.go b/pkg/store/postgres/tokens.go new file mode 100644 index 00000000..ba991c0b --- /dev/null +++ b/pkg/store/postgres/tokens.go @@ -0,0 +1,201 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +// ============================================================================ +// User Access Token Operations +// ============================================================================ + +func (s *PostgresStore) CreateUserAccessToken(ctx context.Context, token *store.UserAccessToken) error { + _, err := s.db.ExecContext(ctx, ` + INSERT INTO user_access_tokens ( + id, user_id, name, prefix, key_hash, project_id, scopes, + revoked, expires_at, last_used, created_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + `, + token.ID, token.UserID, token.Name, token.Prefix, token.KeyHash, + token.ProjectID, marshalJSON(token.Scopes), + boolToInt(token.Revoked), ptrToNullTime(token.ExpiresAt), ptrToNullTime(token.LastUsed), token.Created, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + if strings.Contains(err.Error(), "foreign key constraint") || strings.Contains(err.Error(), "FOREIGN KEY constraint failed") { + return store.ErrInvalidInput + } + return err + } + return nil +} + +func (s *PostgresStore) GetUserAccessToken(ctx context.Context, id string) (*store.UserAccessToken, error) { + token := &store.UserAccessToken{} + var scopes string + var expiresAt, lastUsed sql.NullTime + + err := s.db.QueryRowContext(ctx, ` + SELECT id, user_id, name, prefix, key_hash, project_id, scopes, + revoked, expires_at, last_used, created_at + FROM user_access_tokens WHERE id = $1 + `, id).Scan( + &token.ID, &token.UserID, &token.Name, &token.Prefix, &token.KeyHash, + &token.ProjectID, &scopes, + &token.Revoked, &expiresAt, &lastUsed, &token.Created, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + unmarshalJSON(scopes, &token.Scopes) + if expiresAt.Valid { + token.ExpiresAt = &expiresAt.Time + } + if lastUsed.Valid { + token.LastUsed = &lastUsed.Time + } + return token, nil +} + +func (s *PostgresStore) GetUserAccessTokenByHash(ctx context.Context, hash string) (*store.UserAccessToken, error) { + token := &store.UserAccessToken{} + var scopes string + var expiresAt, lastUsed sql.NullTime + + err := s.db.QueryRowContext(ctx, ` + SELECT id, user_id, name, prefix, key_hash, project_id, scopes, + revoked, expires_at, last_used, created_at + FROM user_access_tokens WHERE key_hash = $1 + `, hash).Scan( + &token.ID, &token.UserID, &token.Name, &token.Prefix, &token.KeyHash, + &token.ProjectID, &scopes, + &token.Revoked, &expiresAt, &lastUsed, &token.Created, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + unmarshalJSON(scopes, &token.Scopes) + if expiresAt.Valid { + token.ExpiresAt = &expiresAt.Time + } + if lastUsed.Valid { + token.LastUsed = &lastUsed.Time + } + return token, nil +} + +func (s *PostgresStore) UpdateUserAccessTokenLastUsed(ctx context.Context, id string) error { + _, err := s.db.ExecContext(ctx, + "UPDATE user_access_tokens SET last_used = $1 WHERE id = $2", + time.Now(), id, + ) + return err +} + +func (s *PostgresStore) RevokeUserAccessToken(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, + "UPDATE user_access_tokens SET revoked = 1 WHERE id = $1", id, + ) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) DeleteUserAccessToken(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, "DELETE FROM user_access_tokens WHERE id = $1", id) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) ListUserAccessTokens(ctx context.Context, userID string) ([]store.UserAccessToken, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, user_id, name, prefix, project_id, scopes, + revoked, expires_at, last_used, created_at + FROM user_access_tokens WHERE user_id = $1 + ORDER BY created_at DESC + `, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var tokens []store.UserAccessToken + for rows.Next() { + var token store.UserAccessToken + var scopes string + var expiresAt, lastUsed sql.NullTime + + if err := rows.Scan( + &token.ID, &token.UserID, &token.Name, &token.Prefix, + &token.ProjectID, &scopes, + &token.Revoked, &expiresAt, &lastUsed, &token.Created, + ); err != nil { + return nil, err + } + + unmarshalJSON(scopes, &token.Scopes) + if expiresAt.Valid { + token.ExpiresAt = &expiresAt.Time + } + if lastUsed.Valid { + token.LastUsed = &lastUsed.Time + } + tokens = append(tokens, token) + } + return tokens, nil +} + +func (s *PostgresStore) CountUserAccessTokens(ctx context.Context, userID string) (int, error) { + var count int + err := s.db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM user_access_tokens WHERE user_id = $1 AND revoked = 0", + userID, + ).Scan(&count) + return count, err +} diff --git a/pkg/store/postgres/users.go b/pkg/store/postgres/users.go new file mode 100644 index 00000000..238b5608 --- /dev/null +++ b/pkg/store/postgres/users.go @@ -0,0 +1,247 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/GoogleCloudPlatform/scion/pkg/store" +) + +func (s *PostgresStore) CreateUser(ctx context.Context, user *store.User) error { + now := time.Now() + user.Created = now + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO users (id, email, display_name, avatar_url, role, status, preferences, created_at, last_login) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + `, + user.ID, user.Email, user.DisplayName, user.AvatarURL, user.Role, user.Status, + marshalJSON(user.Preferences), user.Created, user.LastLogin, + ) + if err != nil { + if strings.Contains(err.Error(), "unique constraint") || strings.Contains(err.Error(), "duplicate key") { + return store.ErrAlreadyExists + } + return err + } + return nil +} + +func (s *PostgresStore) GetUser(ctx context.Context, id string) (*store.User, error) { + user := &store.User{} + var preferences string + var lastLogin, lastSeen sql.NullTime + + err := s.db.QueryRowContext(ctx, ` + SELECT id, email, display_name, avatar_url, role, status, preferences, created_at, last_login, last_seen + FROM users WHERE id = $1 + `, id).Scan( + &user.ID, &user.Email, &user.DisplayName, &user.AvatarURL, &user.Role, &user.Status, + &preferences, &user.Created, &lastLogin, &lastSeen, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + + if lastLogin.Valid { + user.LastLogin = lastLogin.Time + } + if lastSeen.Valid { + user.LastSeen = lastSeen.Time + } + unmarshalJSON(preferences, &user.Preferences) + + return user, nil +} + +func (s *PostgresStore) GetUserByEmail(ctx context.Context, email string) (*store.User, error) { + var id string + err := s.db.QueryRowContext(ctx, "SELECT id FROM users WHERE email = $1", email).Scan(&id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrNotFound + } + return nil, err + } + return s.GetUser(ctx, id) +} + +func (s *PostgresStore) UpdateUser(ctx context.Context, user *store.User) error { + result, err := s.db.ExecContext(ctx, ` + UPDATE users SET + email = $1, display_name = $2, avatar_url = $3, + role = $4, status = $5, preferences = $6, last_login = $7, last_seen = $8 + WHERE id = $9 + `, + user.Email, user.DisplayName, user.AvatarURL, + user.Role, user.Status, marshalJSON(user.Preferences), user.LastLogin, user.LastSeen, + user.ID, + ) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) UpdateUserLastSeen(ctx context.Context, id string, t time.Time) error { + _, err := s.db.ExecContext(ctx, `UPDATE users SET last_seen = $1 WHERE id = $2`, t, id) + return err +} + +func (s *PostgresStore) DeleteUser(ctx context.Context, id string) error { + result, err := s.db.ExecContext(ctx, "DELETE FROM users WHERE id = $1", id) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return store.ErrNotFound + } + return nil +} + +func (s *PostgresStore) ListUsers(ctx context.Context, filter store.UserFilter, opts store.ListOptions) (*store.ListResult[store.User], error) { + var conditions []string + var args []interface{} + + if filter.Role != "" { + args = append(args, filter.Role) + conditions = append(conditions, fmt.Sprintf("role = $%d", len(args))) + } + if filter.Status != "" { + args = append(args, filter.Status) + conditions = append(conditions, fmt.Sprintf("status = $%d", len(args))) + } + if filter.Search != "" { + pattern := "%" + filter.Search + "%" + args = append(args, pattern, pattern) + conditions = append(conditions, fmt.Sprintf("(email LIKE $%d OR display_name LIKE $%d)", len(args)-1, len(args))) + } + + whereClause := "" + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } + + var totalCount int + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM users %s", whereClause) + if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount); err != nil { + return nil, err + } + + limit := opts.Limit + if limit <= 0 { + limit = 50 + } + + if limit > 200 { + limit = 200 + } + + offset := 0 + if opts.Cursor != "" { + if parsed, err := strconv.Atoi(opts.Cursor); err == nil && parsed > 0 { + offset = parsed + } + } + + // Map sort field to column name (whitelist to prevent SQL injection) + orderColumn := "created_at" + orderDir := "DESC" + switch opts.SortBy { + case "name": + orderColumn = "display_name" + orderDir = "ASC" // default ascending for name + case "lastSeen": + orderColumn = "last_seen" + case "created": + orderColumn = "created_at" + } + if opts.SortDir == "asc" { + orderDir = "ASC" + } else if opts.SortDir == "desc" { + orderDir = "DESC" + } + + args = append(args, limit+1, offset) + query := fmt.Sprintf(` + SELECT id, email, display_name, avatar_url, role, status, preferences, created_at, last_login, last_seen + FROM users %s ORDER BY %s %s LIMIT $%d OFFSET $%d + `, whereClause, orderColumn, orderDir, len(args)-1, len(args)) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var users []store.User + for rows.Next() { + var user store.User + var preferences string + var lastLogin, lastSeen sql.NullTime + + if err := rows.Scan( + &user.ID, &user.Email, &user.DisplayName, &user.AvatarURL, &user.Role, &user.Status, + &preferences, &user.Created, &lastLogin, &lastSeen, + ); err != nil { + return nil, err + } + + if lastLogin.Valid { + user.LastLogin = lastLogin.Time + } + if lastSeen.Valid { + user.LastSeen = lastSeen.Time + } + unmarshalJSON(preferences, &user.Preferences) + + users = append(users, user) + } + + result := &store.ListResult[store.User]{ + Items: users, + TotalCount: totalCount, + } + + // Handle pagination: if we got more than limit, there's a next page + if len(users) > limit { + result.Items = users[:limit] + result.NextCursor = strconv.Itoa(offset + limit) + } + + return result, nil +} From 46b57a85eed62b72f20b5be22b2ba023694e6762 Mon Sep 17 00:00:00 2001 From: Alexander Lerma Date: Mon, 1 Jun 2026 06:58:34 -0500 Subject: [PATCH 2/6] =?UTF-8?q?fix(store/postgres):=20address=20Gemini=20r?= =?UTF-8?q?eview=20=E2=80=94=20FK-safe=20migrations,=20ancestry=20guard,?= =?UTF-8?q?=20concurrency?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three genuine Postgres-specific correctness fixes from the PR #289 review: - V40/V31 migrations: replace the SQLite-style DROP TABLE ... CASCADE + recreate with in-place ALTER TABLE. In Postgres, CASCADE silently drops every foreign key in other tables that references the recreated table and never restores them (schema corruption). Postgres supports DROP/ADD CONSTRAINT and ALTER COLUMN DROP NOT NULL directly. Verified against live Postgres 16: all 7 FKs referencing the projects (ex-groves) table and the notifications->notification_subscriptions FK survive. - ListAgents ancestry filter: json_array_elements_text(ancestry::json) crashes on empty-string/NULL ancestry ('invalid input syntax for type json'). Guard with COALESCE(NULLIF(ancestry,''),'[]'). SQLite's json_each tolerated this. - Migrate(): acquire a database-global pg_advisory_lock on a pinned connection so concurrent hub replicas serialize migrations (the stateless-scale-out goal of this backend). Bump pool 4 -> 25 to avoid connection starvation under concurrent load. Co-Authored-By: Claude Opus 4.8 --- pkg/store/postgres/agents.go | 5 +- pkg/store/postgres/migrations.go | 83 +++++++------------------------- pkg/store/postgres/postgres.go | 46 ++++++++++++++---- 3 files changed, 57 insertions(+), 77 deletions(-) diff --git a/pkg/store/postgres/agents.go b/pkg/store/postgres/agents.go index 70c1873b..df41b61e 100644 --- a/pkg/store/postgres/agents.go +++ b/pkg/store/postgres/agents.go @@ -306,7 +306,10 @@ func (s *PostgresStore) ListAgents(ctx context.Context, filter store.AgentFilter args = append(args, filter.Phase) } if filter.AncestorID != "" { - conditions = append(conditions, fmt.Sprintf("EXISTS (SELECT 1 FROM json_array_elements_text(ancestry::json) AS e(value) WHERE e.value = $%d)", len(args)+1)) + // Guard against empty-string / NULL ancestry: ''::json raises a type error + // in Postgres (SQLite's json_each silently tolerates NULL), so coalesce to + // an empty JSON array before expanding it. + conditions = append(conditions, fmt.Sprintf("EXISTS (SELECT 1 FROM json_array_elements_text(COALESCE(NULLIF(ancestry, ''), '[]')::json) AS e(value) WHERE e.value = $%d)", len(args)+1)) args = append(args, filter.AncestorID) } diff --git a/pkg/store/postgres/migrations.go b/pkg/store/postgres/migrations.go index 8ce84d33..1ea77777 100644 --- a/pkg/store/postgres/migrations.go +++ b/pkg/store/postgres/migrations.go @@ -592,33 +592,14 @@ CREATE INDEX IF NOT EXISTS idx_gcp_sa_scope ON gcp_service_accounts(scope, scope // Migration V31: Add scope column to notification_subscriptions and make agent_id nullable. // Enables project-scoped subscriptions (watch all agents in a project) in addition to // agent-scoped subscriptions. Adds unique constraint for deduplication. +// The SQLite source recreates notification_subscriptions because SQLite cannot +// ALTER COLUMN to drop NOT NULL. Postgres supports ADD COLUMN and ALTER COLUMN +// DROP NOT NULL directly, so this is a plain in-place ALTER. Recreating the +// table with DROP ... CASCADE would silently destroy the notifications +// .subscription_id foreign key and never recreate it. const migrationV31 = ` --- Postgres supports ALTER TABLE directly, so we recreate the table as in SQLite source. -CREATE TABLE notification_subscriptions_new ( - id TEXT PRIMARY KEY, - scope TEXT NOT NULL DEFAULT 'agent', - agent_id TEXT, - subscriber_type TEXT NOT NULL DEFAULT 'agent', - subscriber_id TEXT NOT NULL, - grove_id TEXT NOT NULL, - trigger_activities TEXT NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - created_by TEXT NOT NULL, - FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE CASCADE -); - --- Copy existing data (all existing subscriptions are agent-scoped) -INSERT INTO notification_subscriptions_new - (id, scope, agent_id, subscriber_type, subscriber_id, grove_id, trigger_activities, created_at, created_by) -SELECT id, 'agent', agent_id, subscriber_type, subscriber_id, grove_id, trigger_activities, created_at, created_by -FROM notification_subscriptions; - -DROP TABLE notification_subscriptions CASCADE; -ALTER TABLE notification_subscriptions_new RENAME TO notification_subscriptions; - --- Recreate indexes -CREATE INDEX IF NOT EXISTS idx_notification_subs_agent ON notification_subscriptions(agent_id); -CREATE INDEX IF NOT EXISTS idx_notification_subs_project ON notification_subscriptions(grove_id); +ALTER TABLE notification_subscriptions ADD COLUMN IF NOT EXISTS scope TEXT NOT NULL DEFAULT 'agent'; +ALTER TABLE notification_subscriptions ALTER COLUMN agent_id DROP NOT NULL; CREATE INDEX IF NOT EXISTS idx_notification_subs_subscriber ON notification_subscriptions(subscriber_type, subscriber_id); -- Unique constraint: one subscription per (scope, target, subscriber, project) @@ -751,49 +732,19 @@ CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(created_at DESC); ` // Migration V40: Allow multiple groves per git remote (drop UNIQUE on git_remote), -// and enforce slug uniqueness (add UNIQUE on slug). Requires table recreation -// because SQLite does not support ALTER TABLE DROP CONSTRAINT. +// and enforce slug uniqueness (add UNIQUE on slug). // -// IMPORTANT: This migration requires foreign_keys=OFF around the DROP TABLE. -// SQLite ignores PRAGMA changes inside transactions, so the migration runner -// handles this via the foreignKeysOffMigrations set. The PRAGMA statements are -// intentionally NOT included in the SQL string. +// The SQLite source recreates the whole `groves` table because SQLite cannot +// ALTER TABLE ... DROP/ADD CONSTRAINT. Postgres supports both directly, so this +// is a plain in-place ALTER. We deliberately do NOT drop+recreate the table: +// in Postgres, DROP TABLE ... CASCADE would silently drop every foreign key in +// other tables (agents, templates, schedules, ...) that references groves(id), +// and those FKs would never be recreated — corrupting the schema. The inline +// `git_remote TEXT UNIQUE` from V1 is the auto-named constraint groves_git_remote_key. const migrationV40 = ` -CREATE TABLE IF NOT EXISTS groves_new ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL, - slug TEXT NOT NULL UNIQUE, - git_remote TEXT, - labels TEXT, - annotations TEXT, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - created_by TEXT, - owner_id TEXT, - visibility TEXT NOT NULL DEFAULT 'private', - default_runtime_broker_id TEXT REFERENCES runtime_brokers(id) ON DELETE SET NULL, - shared_dirs TEXT, - github_installation_id INTEGER REFERENCES github_installations(installation_id), - github_permissions TEXT, - github_app_status TEXT, - git_identity TEXT -); - -INSERT INTO groves_new SELECT - id, name, slug, git_remote, labels, annotations, - created_at, updated_at, created_by, owner_id, visibility, - default_runtime_broker_id, shared_dirs, - github_installation_id, github_permissions, github_app_status, - git_identity -FROM groves ON CONFLICT DO NOTHING; - -DROP TABLE IF EXISTS groves CASCADE; -ALTER TABLE groves_new RENAME TO groves; - +ALTER TABLE groves DROP CONSTRAINT IF EXISTS groves_git_remote_key; +ALTER TABLE groves ADD CONSTRAINT groves_slug_key UNIQUE (slug); CREATE INDEX IF NOT EXISTS idx_groves_slug ON groves(slug); -CREATE INDEX IF NOT EXISTS idx_groves_git_remote ON groves(git_remote); -CREATE INDEX IF NOT EXISTS idx_groves_owner ON groves(owner_id); -CREATE INDEX IF NOT EXISTS idx_groves_default_runtime_broker ON groves(default_runtime_broker_id); ` // Migration V41: Maintenance operations tables for the admin maintenance panel. diff --git a/pkg/store/postgres/postgres.go b/pkg/store/postgres/postgres.go index 4247de15..39bcb952 100644 --- a/pkg/store/postgres/postgres.go +++ b/pkg/store/postgres/postgres.go @@ -36,8 +36,11 @@ func New(connURL string) (*PostgresStore, error) { return nil, fmt.Errorf("failed to open database: %w", err) } - db.SetMaxOpenConns(4) - db.SetMaxIdleConns(4) + // SQLite pins this at 1 (single-writer). Postgres is built for concurrent + // access, and the whole point of this backend is a horizontally-scaled, + // stateless hub — 4 would starve under concurrent load. 25 is a sane default. + db.SetMaxOpenConns(25) + db.SetMaxIdleConns(25) return &PostgresStore{db: db}, nil } @@ -57,8 +60,32 @@ func (s *PostgresStore) Ping(ctx context.Context) error { return s.db.PingContext(ctx) } +// advisoryLockID is an arbitrary, stable key for the migration advisory lock. +// In a stateless, horizontally-scaled hub, multiple replicas may start and run +// Migrate concurrently; a database-global advisory lock serializes them so only +// one replica applies migrations while the others wait. +const advisoryLockID = 0x5C104D16 // "SCIONDB" mnemonic + // Migrate applies database migrations. func (s *PostgresStore) Migrate(ctx context.Context) error { + // Serialize migrations across replicas with a session-level advisory lock. + // The lock and unlock MUST run on the same session, so pin them to a + // dedicated connection; the migrations themselves run via the pool — the + // lock is database-global and blocks other replicas regardless of which + // connection runs the DDL. + lockConn, err := s.db.Conn(ctx) + if err != nil { + return fmt.Errorf("failed to acquire connection for migration lock: %w", err) + } + defer lockConn.Close() + if _, err := lockConn.ExecContext(ctx, "SELECT pg_advisory_lock($1)", advisoryLockID); err != nil { + return fmt.Errorf("failed to acquire migration advisory lock: %w", err) + } + defer func() { + // Best-effort unlock; closing the connection also releases the lock. + _, _ = lockConn.ExecContext(ctx, "SELECT pg_advisory_unlock($1)", advisoryLockID) + }() + migrations := []any{ migrationV1, migrationV2, @@ -127,18 +154,17 @@ func (s *PostgresStore) Migrate(ctx context.Context) error { // Get current version var currentVersion int - err := s.db.QueryRowContext(ctx, "SELECT COALESCE(MAX(version), 0) FROM schema_migrations").Scan(¤tVersion) + err = s.db.QueryRowContext(ctx, "SELECT COALESCE(MAX(version), 0) FROM schema_migrations").Scan(¤tVersion) if err != nil { return fmt.Errorf("failed to get current schema version: %w", err) } - // Migrations that require PRAGMA foreign_keys=OFF around the transaction. - // SQLite ignores PRAGMA changes inside transactions, so we must disable - // foreign keys before BeginTx and re-enable after Commit. Without this, - // DROP TABLE on a parent table triggers ON DELETE CASCADE on child tables. - foreignKeysOffMigrations := map[int]bool{ - 40: true, // V40 drops and recreates the projects table - } + // In SQLite these migrations need PRAGMA foreign_keys=OFF because they + // recreate a parent table (DROP + rename), which would otherwise cascade. + // The Postgres translations use in-place ALTER TABLE instead, so none + // actually need special handling — the map is kept (empty) plus + // applyMigrationWithFKOff so the runner shape stays 1-to-1 with SQLite. + foreignKeysOffMigrations := map[int]bool{} // Apply pending migrations for i, migration := range migrations { From d3fa5713c10561acd95d8987183ca789bd008362 Mon Sep 17 00:00:00 2001 From: Alexander Lerma Date: Mon, 1 Jun 2026 19:40:43 -0500 Subject: [PATCH 3/6] fix(postgres): isolate Ent tables in a dedicated schema to avoid type-cast collision MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The SQLite store gives Ent its own database file (entDSN := url + "_ent"), so the raw-migration tables and the Ent-managed tables never share storage. The Postgres path opened Ent on the *same* DSN as the raw store, so Ent's auto-migration tried to ALTER tables the raw migrations had already created with different column types — projects.id is TEXT in the raw schema but UUID in the Ent model — failing at boot with: ent migrate: modify "projects" table: pq: column "id" cannot be cast automatically to type uuid (42804) Create a dedicated `ent` schema and pin the Ent client's search_path to it (withSearchPath helper, URL + keyword DSN forms), the Postgres analog of SQLite's separate _ent file. Verified end-to-end against a live PG16 Lakebase: raw store lands 31 tables in `public`, Ent lands 8 in `ent`, public.projects.id stays TEXT while ent.projects.id is UUID, and the full initStore path is idempotent across restarts. Co-Authored-By: Claude Opus 4.8 --- cmd/server_foreground.go | 53 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/cmd/server_foreground.go b/cmd/server_foreground.go index d514a9bf..2ce6d3b7 100644 --- a/cmd/server_foreground.go +++ b/cmd/server_foreground.go @@ -23,6 +23,7 @@ import ( "io" "log" "log/slog" + "net/url" "os" "os/signal" "path/filepath" @@ -693,7 +694,26 @@ func initStore(cfg *config.GlobalConfig) (store.Store, error) { return nil, fmt.Errorf("failed to run migrations: %w", err) } - entClient, err := entc.OpenPostgres(cfg.Database.URL) + // Isolate the Ent-managed tables from the raw-store tables. On SQLite + // these two table sets live in physically separate database files + // (entDSN := cfg.Database.URL + "_ent" above); several tables exist in + // both worlds with deliberately different column types — e.g. the raw + // migrations create projects.id as TEXT while the Ent schema models it + // as UUID. Pointing Ent at the same Postgres schema as the raw store + // makes Ent's auto-migration try to ALTER those shared tables in place + // ("column \"id\" cannot be cast automatically to type uuid"). A + // dedicated `ent` schema is the Postgres analog of the separate _ent + // file, keeping the two table sets from colliding. + if _, err := pgStore.DB().ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS ent"); err != nil { + pgStore.Close() + return nil, fmt.Errorf("failed to create ent schema: %w", err) + } + entDSN, err := withSearchPath(cfg.Database.URL, "ent") + if err != nil { + pgStore.Close() + return nil, fmt.Errorf("failed to build ent DSN: %w", err) + } + entClient, err := entc.OpenPostgres(entDSN) if err != nil { pgStore.Close() return nil, fmt.Errorf("failed to open ent database: %w", err) @@ -719,6 +739,37 @@ func initStore(cfg *config.GlobalConfig) (store.Store, error) { } } +// withSearchPath returns the Postgres DSN with its connection search_path +// pinned to schemaName, so an Ent client opened on it confines all of its +// tables to that schema. It understands both DSN flavors lib/pq accepts: a +// URL form ("postgres://user:pass@host/db?...") and the keyword/value form +// ("host=... dbname=..."). For the URL form the schema is set via the +// `options` query parameter (-c search_path=...); for the keyword form an +// `options` keyword is appended. An existing search_path/options is replaced. +func withSearchPath(dsn, schemaName string) (string, error) { + opt := "-c search_path=" + schemaName + if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { + u, err := url.Parse(dsn) + if err != nil { + return "", fmt.Errorf("parsing postgres URL: %w", err) + } + q := u.Query() + q.Set("options", opt) + u.RawQuery = q.Encode() + return u.String(), nil + } + // Keyword/value DSN: drop any existing options token, then append ours. + fields := make([]string, 0) + for _, f := range strings.Fields(dsn) { + if strings.HasPrefix(f, "options=") { + continue + } + fields = append(fields, f) + } + fields = append(fields, "options='"+opt+"'") + return strings.Join(fields, " "), nil +} + // initDevAuth initializes dev authentication and returns the token. func initDevAuth(cfg *config.GlobalConfig, globalDir string) (string, error) { devAuthCfg := apiclient.DevAuthConfig{ From c290d651e6267468166dc32032cba5085874a17c Mon Sep 17 00:00:00 2001 From: Alexander Lerma Date: Mon, 1 Jun 2026 21:23:53 -0500 Subject: [PATCH 4/6] fix(store/postgres): boolToInt the locked column on write CreateTemplate/UpdateTemplate and CreateHarnessConfig/UpdateHarnessConfig passed the Go bool template.Locked / hc.Locked straight into the locked INTEGER column. SQLite coerces bool->int silently, but lib/pq sends it as the string "false" and Postgres rejects it: pq: invalid input syntax for type integer: "false" (22P02) This made every template + harness-config bootstrap import fail on a Postgres-backed hub (claude/codex/gemini/opencode all skipped), so agents would later error harness-config "claude" not found. Wrap both write sites with boolToInt() to match every other bool->INTEGER column in the store (brokers.AutoProvide, envvars.Secret/Sensitive, tokens.Revoked, ...). Co-Authored-By: Claude Opus 4.8 --- pkg/store/postgres/harness_configs.go | 4 ++-- pkg/store/postgres/templates.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/store/postgres/harness_configs.go b/pkg/store/postgres/harness_configs.go index e0cfd1d0..eb9b2359 100644 --- a/pkg/store/postgres/harness_configs.go +++ b/pkg/store/postgres/harness_configs.go @@ -48,7 +48,7 @@ func (s *PostgresStore) CreateHarnessConfig(ctx context.Context, hc *store.Harne hc.Harness, marshalJSON(hc.Config), nullableString(hc.ContentHash), hc.Scope, nullableString(hc.ScopeID), nullableString(hc.StorageURI), nullableString(hc.StorageBucket), nullableString(hc.StoragePath), marshalJSON(hc.Files), - hc.Locked, hc.Status, + boolToInt(hc.Locked), hc.Status, nullableString(hc.OwnerID), nullableString(hc.CreatedBy), nullableString(hc.UpdatedBy), hc.Visibility, hc.Created, hc.Updated, ) @@ -168,7 +168,7 @@ func (s *PostgresStore) UpdateHarnessConfig(ctx context.Context, hc *store.Harne hc.Harness, marshalJSON(hc.Config), nullableString(hc.ContentHash), hc.Scope, nullableString(hc.ScopeID), nullableString(hc.StorageURI), nullableString(hc.StorageBucket), nullableString(hc.StoragePath), marshalJSON(hc.Files), - hc.Locked, hc.Status, + boolToInt(hc.Locked), hc.Status, nullableString(hc.OwnerID), nullableString(hc.UpdatedBy), hc.Visibility, hc.Updated, hc.ID, diff --git a/pkg/store/postgres/templates.go b/pkg/store/postgres/templates.go index cb352899..2dd2d954 100644 --- a/pkg/store/postgres/templates.go +++ b/pkg/store/postgres/templates.go @@ -53,7 +53,7 @@ func (s *PostgresStore) CreateTemplate(ctx context.Context, template *store.Temp template.Harness, nullableString(template.DefaultHarnessConfig), template.Image, marshalJSON(template.Config), nullableString(template.ContentHash), template.Scope, nullableString(template.ScopeID), nullableString(template.ProjectID), nullableString(template.StorageURI), nullableString(template.StorageBucket), nullableString(template.StoragePath), marshalJSON(template.Files), - nullableString(template.BaseTemplate), template.Locked, template.Status, + nullableString(template.BaseTemplate), boolToInt(template.Locked), template.Status, nullableString(template.OwnerID), nullableString(template.CreatedBy), nullableString(template.UpdatedBy), template.Visibility, template.Created, template.Updated, ) @@ -186,7 +186,7 @@ func (s *PostgresStore) UpdateTemplate(ctx context.Context, template *store.Temp template.Harness, nullableString(template.DefaultHarnessConfig), template.Image, marshalJSON(template.Config), nullableString(template.ContentHash), template.Scope, nullableString(template.ScopeID), nullableString(template.ProjectID), nullableString(template.StorageURI), nullableString(template.StorageBucket), nullableString(template.StoragePath), marshalJSON(template.Files), - nullableString(template.BaseTemplate), template.Locked, template.Status, + nullableString(template.BaseTemplate), boolToInt(template.Locked), template.Status, nullableString(template.OwnerID), nullableString(template.UpdatedBy), template.Visibility, template.Updated, template.ID, From 38363476edf71b9841bf23255569bf5e1bd251e0 Mon Sep 17 00:00:00 2001 From: Alexander Lerma Date: Mon, 1 Jun 2026 23:22:21 -0500 Subject: [PATCH 5/6] fix(store/postgres): cast started_at text param to timestamp in UpdateAgentStatus The agent heartbeat status update bound $19 (su.StartedAt, an RFC3339 string) straight into COALESCE(NULLIF($19, ''), started_at) against the started_at TIMESTAMP column. Postgres rejects this with "COALESCE types text and timestamp without time zone cannot be matched (42804)" because NULLIF($19,'') is typed text, while SQLite's dynamic typing tolerated it. This broke every running-agent heartbeat: the hub logged "Failed to update agent status from heartbeat" and returned 502 to the broker dispatch, so `scion start` saw a 30s context-deadline timeout even though the agent pod came up fine. Cast the text param to ::timestamp so COALESCE has matching types. The empty -> NULL path still falls through to the existing started_at value. Verified the RFC3339 'Z' string casts cleanly against the live Lakebase backend. Co-Authored-By: Claude Opus 4.8 --- pkg/store/postgres/agents.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/store/postgres/agents.go b/pkg/store/postgres/agents.go index df41b61e..5c287c93 100644 --- a/pkg/store/postgres/agents.go +++ b/pkg/store/postgres/agents.go @@ -466,7 +466,7 @@ func (s *PostgresStore) UpdateAgentStatus(ctx context.Context, id string, su sto last_activity_event = CASE WHEN $13 != '' THEN $14 ELSE last_activity_event END, current_turns = CASE WHEN $15 THEN $16 ELSE current_turns END, current_model_calls = CASE WHEN $17 THEN $18 ELSE current_model_calls END, - started_at = COALESCE(NULLIF($19, ''), started_at), + started_at = COALESCE(NULLIF($19, '')::timestamp, started_at), updated_at = $20, last_seen = $21 WHERE id = $22 From 0aa55e540297ffca96601d3d7fde6e5f595e0ead Mon Sep 17 00:00:00 2001 From: Alexander Lerma Date: Thu, 4 Jun 2026 00:46:16 -0500 Subject: [PATCH 6/6] feat(hub): add generic OAuth/OIDC login provider (Dex) Adds a configurable 'generic' OAuth2/OIDC provider alongside the hardcoded Google/GitHub ones, so the Hub web/CLI/device login can federate to any standards-compliant issuer (notably the in-cluster Dex) instead of only Google/GitHub. Config mirrors Better Auth's genericOAuth shape, via SCION_SERVER_OAUTH__GENERIC_*: - clientId / clientSecret - discoveryUrl (full .well-known URL) or issuer (derives the well-known path) for OIDC discovery, or explicit authorizationUrl/tokenUrl/userInfoUrl - scopes (default 'openid email profile') Endpoints resolve as: explicit authorize+token > discoveryUrl > issuer-derived discovery; discovery docs are cached per URL. Userinfo maps the standard sub/email/name/picture claims. Touches: hubclient OAuthProviderGeneric constant + order; hub OAuthService dispatch (auth URL + code exchange); config struct + server wiring; web login + callback provider validation. Google/GitHub unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) --- cmd/server_foreground.go | 30 ++++ pkg/config/hub_config.go | 14 ++ pkg/hub/oauth.go | 35 ++++- pkg/hub/oauth_generic.go | 239 +++++++++++++++++++++++++++++++ pkg/hub/oauth_generic_test.go | 259 ++++++++++++++++++++++++++++++++++ pkg/hub/web.go | 4 +- pkg/hubclient/auth.go | 5 + 7 files changed, 583 insertions(+), 3 deletions(-) create mode 100644 pkg/hub/oauth_generic.go create mode 100644 pkg/hub/oauth_generic_test.go diff --git a/cmd/server_foreground.go b/cmd/server_foreground.go index 2ce6d3b7..41da988a 100644 --- a/cmd/server_foreground.go +++ b/cmd/server_foreground.go @@ -909,6 +909,16 @@ func initHubServer(ctx context.Context, cfg *config.GlobalConfig, s store.Store, ClientID: cfg.OAuth.Web.GitHub.ClientID, ClientSecret: cfg.OAuth.Web.GitHub.ClientSecret, }, + Generic: hub.OAuthProviderConfig{ + ClientID: cfg.OAuth.Web.Generic.ClientID, + ClientSecret: cfg.OAuth.Web.Generic.ClientSecret, + DiscoveryURL: cfg.OAuth.Web.Generic.DiscoveryURL, + Issuer: cfg.OAuth.Web.Generic.Issuer, + AuthorizationURL: cfg.OAuth.Web.Generic.AuthorizationURL, + TokenURL: cfg.OAuth.Web.Generic.TokenURL, + UserInfoURL: cfg.OAuth.Web.Generic.UserInfoURL, + Scopes: cfg.OAuth.Web.Generic.Scopes, + }, }, CLI: hub.OAuthClientConfig{ Google: hub.OAuthProviderConfig{ @@ -919,6 +929,16 @@ func initHubServer(ctx context.Context, cfg *config.GlobalConfig, s store.Store, ClientID: cfg.OAuth.CLI.GitHub.ClientID, ClientSecret: cfg.OAuth.CLI.GitHub.ClientSecret, }, + Generic: hub.OAuthProviderConfig{ + ClientID: cfg.OAuth.CLI.Generic.ClientID, + ClientSecret: cfg.OAuth.CLI.Generic.ClientSecret, + DiscoveryURL: cfg.OAuth.CLI.Generic.DiscoveryURL, + Issuer: cfg.OAuth.CLI.Generic.Issuer, + AuthorizationURL: cfg.OAuth.CLI.Generic.AuthorizationURL, + TokenURL: cfg.OAuth.CLI.Generic.TokenURL, + UserInfoURL: cfg.OAuth.CLI.Generic.UserInfoURL, + Scopes: cfg.OAuth.CLI.Generic.Scopes, + }, }, Device: hub.OAuthClientConfig{ Google: hub.OAuthProviderConfig{ @@ -929,6 +949,16 @@ func initHubServer(ctx context.Context, cfg *config.GlobalConfig, s store.Store, ClientID: cfg.OAuth.Device.GitHub.ClientID, ClientSecret: cfg.OAuth.Device.GitHub.ClientSecret, }, + Generic: hub.OAuthProviderConfig{ + ClientID: cfg.OAuth.Device.Generic.ClientID, + ClientSecret: cfg.OAuth.Device.Generic.ClientSecret, + DiscoveryURL: cfg.OAuth.Device.Generic.DiscoveryURL, + Issuer: cfg.OAuth.Device.Generic.Issuer, + AuthorizationURL: cfg.OAuth.Device.Generic.AuthorizationURL, + TokenURL: cfg.OAuth.Device.Generic.TokenURL, + UserInfoURL: cfg.OAuth.Device.Generic.UserInfoURL, + Scopes: cfg.OAuth.Device.Generic.Scopes, + }, }, }, MaintenanceConfig: resolveMaintenanceConfig(cfg), diff --git a/pkg/config/hub_config.go b/pkg/config/hub_config.go index 2d284bac..59f66234 100644 --- a/pkg/config/hub_config.go +++ b/pkg/config/hub_config.go @@ -163,6 +163,16 @@ type OAuthProviderConfig struct { ClientID string `json:"clientId" yaml:"clientId" koanf:"clientId"` // ClientSecret is the OAuth application client secret. ClientSecret string `json:"clientSecret" yaml:"clientSecret" koanf:"clientSecret"` + // The following fields are only used by the generic OAuth/OIDC provider + // (e.g. Dex); field names mirror Better Auth's genericOAuth config. Set + // DiscoveryURL or Issuer for OIDC discovery, or set the endpoints + // explicitly. Left empty for Google/GitHub. + DiscoveryURL string `json:"discoveryUrl" yaml:"discoveryUrl" koanf:"discoveryUrl"` + Issuer string `json:"issuer" yaml:"issuer" koanf:"issuer"` + AuthorizationURL string `json:"authorizationUrl" yaml:"authorizationUrl" koanf:"authorizationUrl"` + TokenURL string `json:"tokenUrl" yaml:"tokenUrl" koanf:"tokenUrl"` + UserInfoURL string `json:"userInfoUrl" yaml:"userInfoUrl" koanf:"userInfoUrl"` + Scopes string `json:"scopes" yaml:"scopes" koanf:"scopes"` } // OAuthClientConfig holds OAuth provider configurations for a specific client type. @@ -171,6 +181,10 @@ type OAuthClientConfig struct { Google OAuthProviderConfig `json:"google" yaml:"google" koanf:"google"` // GitHub OAuth settings for this client type. GitHub OAuthProviderConfig `json:"github" yaml:"github" koanf:"github"` + // Generic is a configurable OAuth2/OIDC provider (e.g. Dex) for this client + // type. Configure via SCION_SERVER_OAUTH__GENERIC_{CLIENTID,CLIENTSECRET} + // plus GENERIC_ISSUER (discovery) or explicit GENERIC_{AUTHURL,TOKENURL,USERINFOURL}. + Generic OAuthProviderConfig `json:"generic" yaml:"generic" koanf:"generic"` } // OAuthConfig holds OAuth provider configurations. diff --git a/pkg/hub/oauth.go b/pkg/hub/oauth.go index 3efd34cb..45b199a4 100644 --- a/pkg/hub/oauth.go +++ b/pkg/hub/oauth.go @@ -22,6 +22,7 @@ import ( "net/http" "net/url" "strings" + "sync" "time" "github.com/GoogleCloudPlatform/scion/pkg/hubclient" @@ -31,17 +32,32 @@ import ( type OAuthProviderConfig struct { ClientID string ClientSecret string + // The following fields are only used by the generic OAuth/OIDC provider + // (Google/GitHub leave them empty). Field names mirror Better Auth's + // genericOAuth config. Endpoints resolve in this order: explicit + // AuthorizationURL/TokenURL/UserInfoURL win; else OIDC discovery against + // DiscoveryURL; else discovery derived from Issuer + // (Issuer + "/.well-known/openid-configuration"). + DiscoveryURL string // full .well-known/openid-configuration URL + Issuer string // issuer identifier; also derives DiscoveryURL when that is unset + AuthorizationURL string + TokenURL string + UserInfoURL string + Scopes string // space-separated; defaults to "openid email profile" } // OAuthClientConfig holds OAuth provider configurations for a specific client type. type OAuthClientConfig struct { Google OAuthProviderConfig GitHub OAuthProviderConfig + // Generic is a configurable OAuth2/OIDC provider (e.g. Dex) — discovery via + // Issuer, or explicit AuthURL/TokenURL/UserInfoURL. + Generic OAuthProviderConfig } // IsConfigured returns true if at least one OAuth provider is configured. func (c *OAuthClientConfig) IsConfigured() bool { - return c.Google.ClientID != "" || c.GitHub.ClientID != "" + return c.Google.ClientID != "" || c.GitHub.ClientID != "" || c.Generic.ClientID != "" } // IsProviderConfigured returns true if the specified provider is configured. @@ -51,6 +67,12 @@ func (c *OAuthClientConfig) IsProviderConfigured(provider string) bool { return c.Google.ClientID != "" && c.Google.ClientSecret != "" case hubclient.OAuthProviderGitHub: return c.GitHub.ClientID != "" && c.GitHub.ClientSecret != "" + case hubclient.OAuthProviderGeneric: + // Needs credentials plus a way to resolve endpoints: a discovery URL or + // issuer (for discovery), or explicit authorize+token endpoints. + hasEndpoints := c.Generic.DiscoveryURL != "" || c.Generic.Issuer != "" || + (c.Generic.AuthorizationURL != "" && c.Generic.TokenURL != "") + return c.Generic.ClientID != "" && c.Generic.ClientSecret != "" && hasEndpoints default: return false } @@ -63,6 +85,8 @@ func (c *OAuthClientConfig) GetProvider(provider string) OAuthProviderConfig { return c.Google case hubclient.OAuthProviderGitHub: return c.GitHub + case hubclient.OAuthProviderGeneric: + return c.Generic default: return OAuthProviderConfig{} } @@ -110,6 +134,10 @@ func oauthProviderOrder() []string { type OAuthService struct { config OAuthConfig httpClient *http.Client + + // oidcCache memoizes OIDC discovery documents keyed by issuer URL. + oidcMu sync.RWMutex + oidcCache map[string]*oidcDiscovery } // NewOAuthService creates a new OAuth service. @@ -119,6 +147,7 @@ func NewOAuthService(config OAuthConfig) *OAuthService { httpClient: &http.Client{ Timeout: 30 * time.Second, }, + oidcCache: make(map[string]*oidcDiscovery), } } @@ -212,6 +241,8 @@ func (s *OAuthService) GetAuthorizationURLForClient(clientType OAuthClientType, return s.getGoogleAuthURLWithConfig(cfg.Google, callbackURL, state) case hubclient.OAuthProviderGitHub: return s.getGitHubAuthURLWithConfig(cfg.GitHub, callbackURL, state) + case hubclient.OAuthProviderGeneric: + return s.getGenericAuthURLWithConfig(cfg.Generic, callbackURL, state) default: return "", fmt.Errorf("unsupported OAuth provider: %s", provider) } @@ -278,6 +309,8 @@ func (s *OAuthService) ExchangeCodeForClient(ctx context.Context, clientType OAu return s.exchangeGoogleCodeWithConfig(ctx, cfg.Google, code, callbackURL) case "github": return s.exchangeGitHubCodeWithConfig(ctx, cfg.GitHub, code, callbackURL) + case hubclient.OAuthProviderGeneric: + return s.exchangeGenericCodeWithConfig(ctx, cfg.Generic, code, callbackURL) default: return nil, fmt.Errorf("unsupported OAuth provider: %s", provider) } diff --git a/pkg/hub/oauth_generic.go b/pkg/hub/oauth_generic.go new file mode 100644 index 00000000..a228dbed --- /dev/null +++ b/pkg/hub/oauth_generic.go @@ -0,0 +1,239 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hub + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "github.com/GoogleCloudPlatform/scion/pkg/hubclient" +) + +// Generic OAuth/OIDC provider. +// +// Unlike the hardcoded Google/GitHub providers, the generic provider is +// configurable for any standards-compliant OAuth2/OIDC issuer (e.g. the +// in-cluster Dex). Modeled on Better Auth's genericOAuth plugin: point it at +// an issuer and let OIDC discovery resolve the endpoints, OR set the +// authorization/token/userinfo endpoints explicitly when the provider has no +// `.well-known/openid-configuration`. +// +// Configure via SCION_SERVER_OAUTH__GENERIC_{CLIENTID,CLIENTSECRET} +// plus either GENERIC_ISSUER (discovery) or the explicit +// GENERIC_{AUTHURL,TOKENURL,USERINFOURL}. + +// genericEndpoints holds the resolved OAuth2/OIDC endpoints for the generic provider. +type genericEndpoints struct { + AuthURL string + TokenURL string + UserInfoURL string +} + +// oidcDiscovery is the subset of the OIDC discovery document we consume. +type oidcDiscovery struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint"` +} + +// resolveGenericEndpoints returns the provider's endpoints, preferring explicit +// overrides and falling back to OIDC discovery (via DiscoveryURL, or derived +// from Issuer). Discovered documents are cached per discovery URL. +func (s *OAuthService) resolveGenericEndpoints(ctx context.Context, cfg OAuthProviderConfig) (genericEndpoints, error) { + ep := genericEndpoints{ + AuthURL: cfg.AuthorizationURL, + TokenURL: cfg.TokenURL, + UserInfoURL: cfg.UserInfoURL, + } + + // If both the authorize and token endpoints are explicit, no discovery needed. + if ep.AuthURL != "" && ep.TokenURL != "" { + return ep, nil + } + + // Resolve the discovery document URL: explicit DiscoveryURL wins, otherwise + // derive the standard well-known path from the issuer. + discoveryURL := cfg.DiscoveryURL + if discoveryURL == "" && cfg.Issuer != "" { + discoveryURL = strings.TrimSuffix(cfg.Issuer, "/") + "/.well-known/openid-configuration" + } + if discoveryURL == "" { + return genericEndpoints{}, fmt.Errorf("generic OAuth provider requires a discoveryUrl or issuer (for discovery), or explicit authorizationUrl+tokenUrl") + } + + disc, err := s.discoverOIDC(ctx, discoveryURL) + if err != nil { + return genericEndpoints{}, err + } + // Explicit overrides win over discovered values when both are present. + if ep.AuthURL == "" { + ep.AuthURL = disc.AuthorizationEndpoint + } + if ep.TokenURL == "" { + ep.TokenURL = disc.TokenEndpoint + } + if ep.UserInfoURL == "" { + ep.UserInfoURL = disc.UserinfoEndpoint + } + return ep, nil +} + +// discoverOIDC fetches (and caches) the OIDC discovery document from the given +// discovery URL (a full .well-known/openid-configuration URL). +func (s *OAuthService) discoverOIDC(ctx context.Context, discoveryURL string) (*oidcDiscovery, error) { + if discoveryURL == "" { + return nil, fmt.Errorf("OIDC discovery URL is not configured") + } + + s.oidcMu.RLock() + cached, ok := s.oidcCache[discoveryURL] + s.oidcMu.RUnlock() + if ok { + return cached, nil + } + + req, err := http.NewRequestWithContext(ctx, "GET", discoveryURL, nil) + if err != nil { + return nil, err + } + resp, err := s.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("OIDC discovery request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("OIDC discovery failed: %s - %s", resp.Status, string(body)) + } + + var disc oidcDiscovery + if err := json.NewDecoder(resp.Body).Decode(&disc); err != nil { + return nil, fmt.Errorf("failed to decode OIDC discovery document: %w", err) + } + if disc.AuthorizationEndpoint == "" || disc.TokenEndpoint == "" { + return nil, fmt.Errorf("OIDC discovery document missing required endpoints") + } + + s.oidcMu.Lock() + s.oidcCache[discoveryURL] = &disc + s.oidcMu.Unlock() + + return &disc, nil +} + +// getGenericAuthURLWithConfig generates an authorization URL for the generic +// provider using the given config. +func (s *OAuthService) getGenericAuthURLWithConfig(cfg OAuthProviderConfig, callbackURL, state string) (string, error) { + if cfg.ClientID == "" { + return "", fmt.Errorf("generic OAuth is not configured") + } + ep, err := s.resolveGenericEndpoints(context.Background(), cfg) + if err != nil { + return "", err + } + + scopes := cfg.Scopes + if strings.TrimSpace(scopes) == "" { + scopes = "openid email profile" + } + + params := url.Values{ + "client_id": {cfg.ClientID}, + "redirect_uri": {callbackURL}, + "response_type": {"code"}, + "scope": {scopes}, + "state": {state}, + } + + return ep.AuthURL + "?" + params.Encode(), nil +} + +// exchangeGenericCodeWithConfig exchanges an authorization code for user info +// against the generic provider using the given config. +func (s *OAuthService) exchangeGenericCodeWithConfig(ctx context.Context, cfg OAuthProviderConfig, code, callbackURL string) (*OAuthUserInfo, error) { + if cfg.ClientID == "" || cfg.ClientSecret == "" { + return nil, fmt.Errorf("generic OAuth is not configured") + } + ep, err := s.resolveGenericEndpoints(ctx, cfg) + if err != nil { + return nil, err + } + + // Standard OAuth2 authorization-code exchange (same form-POST shape as Google). + tokenResp, err := s.exchangeCodeForToken(ctx, ep.TokenURL, cfg.ClientID, cfg.ClientSecret, code, callbackURL) + if err != nil { + return nil, fmt.Errorf("failed to exchange generic OAuth code: %w", err) + } + + userInfo, err := s.getGenericUserInfo(ctx, ep.UserInfoURL, tokenResp.AccessToken) + if err != nil { + return nil, fmt.Errorf("failed to get generic OAuth user info: %w", err) + } + return userInfo, nil +} + +// genericUserInfo is the subset of the standard OIDC userinfo response we use. +type genericUserInfo struct { + Sub string `json:"sub"` + Email string `json:"email"` + Name string `json:"name"` + Picture string `json:"picture"` +} + +// getGenericUserInfo retrieves user information from a standard OIDC userinfo endpoint. +func (s *OAuthService) getGenericUserInfo(ctx context.Context, userinfoURL, accessToken string) (*OAuthUserInfo, error) { + if userinfoURL == "" { + return nil, fmt.Errorf("generic OAuth provider has no userinfo endpoint (set issuer for discovery or userInfoURL explicitly)") + } + req, err := http.NewRequestWithContext(ctx, "GET", userinfoURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := s.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("failed to get user info: %s - %s", resp.Status, string(body)) + } + + var ui genericUserInfo + if err := json.NewDecoder(resp.Body).Decode(&ui); err != nil { + return nil, fmt.Errorf("failed to decode user info: %w", err) + } + if ui.Sub == "" || ui.Email == "" { + return nil, fmt.Errorf("generic OAuth userinfo missing required sub/email claims") + } + + return &OAuthUserInfo{ + ID: ui.Sub, + Email: ui.Email, + DisplayName: ui.Name, + AvatarURL: ui.Picture, + Provider: hubclient.OAuthProviderGeneric, + }, nil +} diff --git a/pkg/hub/oauth_generic_test.go b/pkg/hub/oauth_generic_test.go new file mode 100644 index 00000000..5a60055a --- /dev/null +++ b/pkg/hub/oauth_generic_test.go @@ -0,0 +1,259 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hub + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/GoogleCloudPlatform/scion/pkg/hubclient" +) + +func TestGenericProvider_IsProviderConfigured(t *testing.T) { + tests := []struct { + name string + cfg OAuthClientConfig + expected bool + }{ + { + name: "empty", + cfg: OAuthClientConfig{}, + expected: false, + }, + { + name: "credentials but no endpoints", + cfg: OAuthClientConfig{ + Generic: OAuthProviderConfig{ClientID: "id", ClientSecret: "secret"}, + }, + expected: false, + }, + { + name: "issuer discovery", + cfg: OAuthClientConfig{ + Generic: OAuthProviderConfig{ClientID: "id", ClientSecret: "secret", Issuer: "https://dex.example.com"}, + }, + expected: true, + }, + { + name: "explicit discovery url", + cfg: OAuthClientConfig{ + Generic: OAuthProviderConfig{ClientID: "id", ClientSecret: "secret", DiscoveryURL: "https://dex.example.com/.well-known/openid-configuration"}, + }, + expected: true, + }, + { + name: "explicit endpoints", + cfg: OAuthClientConfig{ + Generic: OAuthProviderConfig{ + ClientID: "id", + ClientSecret: "secret", + AuthorizationURL: "https://idp.example.com/auth", + TokenURL: "https://idp.example.com/token", + }, + }, + expected: true, + }, + { + name: "explicit authorize only is not enough", + cfg: OAuthClientConfig{ + Generic: OAuthProviderConfig{ClientID: "id", ClientSecret: "secret", AuthorizationURL: "https://idp.example.com/auth"}, + }, + expected: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.IsProviderConfigured(hubclient.OAuthProviderGeneric); got != tt.expected { + t.Errorf("IsProviderConfigured(generic) = %v, want %v", got, tt.expected) + } + }) + } +} + +// newDiscoveryServer returns an httptest server that serves an OIDC discovery +// document pointing its endpoints back at itself. +func newDiscoveryServer(t *testing.T) *httptest.Server { + t.Helper() + mux := http.NewServeMux() + srv := httptest.NewServer(mux) + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "issuer": srv.URL, + "authorization_endpoint": srv.URL + "/auth", + "token_endpoint": srv.URL + "/token", + "userinfo_endpoint": srv.URL + "/userinfo", + }) + }) + return srv +} + +func TestGenericProvider_AuthURL_Discovery(t *testing.T) { + srv := newDiscoveryServer(t) + defer srv.Close() + + svc := NewOAuthService(OAuthConfig{ + Web: OAuthClientConfig{ + Generic: OAuthProviderConfig{ + ClientID: "scion-web", + ClientSecret: "secret", + Issuer: srv.URL, // discovery derived from issuer + }, + }, + }) + + authURL, err := svc.GetAuthorizationURLForClient(OAuthClientTypeWeb, hubclient.OAuthProviderGeneric, "https://hub.example.com/auth/callback/generic", "state123") + if err != nil { + t.Fatalf("GetAuthorizationURLForClient: %v", err) + } + if !strings.HasPrefix(authURL, srv.URL+"/auth?") { + t.Fatalf("auth URL %q does not start with discovered authorization endpoint %q", authURL, srv.URL+"/auth") + } + u, err := url.Parse(authURL) + if err != nil { + t.Fatalf("parse auth URL: %v", err) + } + q := u.Query() + if q.Get("client_id") != "scion-web" { + t.Errorf("client_id = %q, want scion-web", q.Get("client_id")) + } + if q.Get("state") != "state123" { + t.Errorf("state = %q, want state123", q.Get("state")) + } + if q.Get("scope") != "openid email profile" { + t.Errorf("scope = %q, want default openid email profile", q.Get("scope")) + } + if q.Get("response_type") != "code" { + t.Errorf("response_type = %q, want code", q.Get("response_type")) + } +} + +func TestGenericProvider_AuthURL_ExplicitEndpointsAndScopes(t *testing.T) { + svc := NewOAuthService(OAuthConfig{ + Web: OAuthClientConfig{ + Generic: OAuthProviderConfig{ + ClientID: "scion-web", + ClientSecret: "secret", + AuthorizationURL: "https://idp.example.com/authorize", + TokenURL: "https://idp.example.com/oauth/token", + Scopes: "openid email", + }, + }, + }) + + authURL, err := svc.GetAuthorizationURLForClient(OAuthClientTypeWeb, hubclient.OAuthProviderGeneric, "https://hub.example.com/cb", "s") + if err != nil { + t.Fatalf("GetAuthorizationURLForClient: %v", err) + } + if !strings.HasPrefix(authURL, "https://idp.example.com/authorize?") { + t.Fatalf("auth URL %q does not use explicit authorization endpoint", authURL) + } + if got := mustQuery(t, authURL).Get("scope"); got != "openid email" { + t.Errorf("scope = %q, want custom openid email", got) + } +} + +func TestGenericProvider_ExchangeCode(t *testing.T) { + mux := http.NewServeMux() + srv := httptest.NewServer(mux) + defer srv.Close() + + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]string{ + "issuer": srv.URL, + "authorization_endpoint": srv.URL + "/auth", + "token_endpoint": srv.URL + "/token", + "userinfo_endpoint": srv.URL + "/userinfo", + }) + }) + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"access_token": "at-123", "token_type": "Bearer", "expires_in": 3600}) + }) + mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer at-123" { + http.Error(w, "missing bearer", http.StatusUnauthorized) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "sub": "dex-user-uuid", + "email": "alex@neuralnetes.com", + "name": "Alex", + "picture": "https://img/avatar.png", + }) + }) + + svc := NewOAuthService(OAuthConfig{ + Web: OAuthClientConfig{ + Generic: OAuthProviderConfig{ClientID: "scion-web", ClientSecret: "secret", DiscoveryURL: srv.URL + "/.well-known/openid-configuration"}, + }, + }) + + info, err := svc.ExchangeCodeForClient(context.Background(), OAuthClientTypeWeb, hubclient.OAuthProviderGeneric, "auth-code", "https://hub.example.com/cb") + if err != nil { + t.Fatalf("ExchangeCodeForClient: %v", err) + } + if info.ID != "dex-user-uuid" { + t.Errorf("ID = %q, want dex-user-uuid", info.ID) + } + if info.Email != "alex@neuralnetes.com" { + t.Errorf("Email = %q", info.Email) + } + if info.Provider != hubclient.OAuthProviderGeneric { + t.Errorf("Provider = %q, want generic", info.Provider) + } +} + +func TestGenericProvider_DiscoveryCached(t *testing.T) { + var hits int + mux := http.NewServeMux() + srv := httptest.NewServer(mux) + defer srv.Close() + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + hits++ + _ = json.NewEncoder(w).Encode(map[string]string{ + "issuer": srv.URL, + "authorization_endpoint": srv.URL + "/auth", + "token_endpoint": srv.URL + "/token", + "userinfo_endpoint": srv.URL + "/userinfo", + }) + }) + + svc := NewOAuthService(OAuthConfig{}) + cfg := OAuthProviderConfig{ClientID: "id", ClientSecret: "s", DiscoveryURL: srv.URL + "/.well-known/openid-configuration"} + for i := 0; i < 3; i++ { + if _, err := svc.resolveGenericEndpoints(context.Background(), cfg); err != nil { + t.Fatalf("resolveGenericEndpoints: %v", err) + } + } + if hits != 1 { + t.Errorf("discovery fetched %d times, want 1 (cached)", hits) + } +} + +func mustQuery(t *testing.T, rawURL string) url.Values { + t.Helper() + u, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("parse %q: %v", rawURL, err) + } + return u.Query() +} diff --git a/pkg/hub/web.go b/pkg/hub/web.go index 46b1766e..317126b3 100644 --- a/pkg/hub/web.go +++ b/pkg/hub/web.go @@ -1253,7 +1253,7 @@ func (ws *WebServer) handleOAuthLogin(w http.ResponseWriter, r *http.Request) { } // Validate provider - if provider != "google" && provider != "github" { + if provider != "google" && provider != "github" && provider != "generic" { http.Error(w, "unsupported OAuth provider", http.StatusBadRequest) return } @@ -1310,7 +1310,7 @@ func (ws *WebServer) handleOAuthCallback(w http.ResponseWriter, r *http.Request) provider := strings.TrimPrefix(r.URL.Path, "/auth/callback/") provider = strings.TrimSuffix(provider, "/") - if provider != "google" && provider != "github" { + if provider != "google" && provider != "github" && provider != "generic" { http.Error(w, "unsupported OAuth provider", http.StatusBadRequest) return } diff --git a/pkg/hubclient/auth.go b/pkg/hubclient/auth.go index 55f5f86c..d14c278b 100644 --- a/pkg/hubclient/auth.go +++ b/pkg/hubclient/auth.go @@ -28,6 +28,10 @@ type OAuthClientType string const ( OAuthProviderGoogle = "google" OAuthProviderGitHub = "github" + // OAuthProviderGeneric is a generic, configurable OAuth2/OIDC provider + // (e.g. the in-cluster Dex). Unlike Google/GitHub, its endpoints are + // resolved from the issuer's OIDC discovery document, or set explicitly. + OAuthProviderGeneric = "generic" OAuthClientTypeWeb OAuthClientType = "web" OAuthClientTypeCLI OAuthClientType = "cli" @@ -38,6 +42,7 @@ func OAuthProviderOrder() []string { return []string{ OAuthProviderGoogle, OAuthProviderGitHub, + OAuthProviderGeneric, } }