diff --git a/cmd/dst/run.go b/cmd/dst/run.go index ab3c9af6..08e77520 100644 --- a/cmd/dst/run.go +++ b/cmd/dst/run.go @@ -33,6 +33,7 @@ func RunDSTCmd() *cobra.Command { scenario string visualizationPath string verbose bool + printOps bool reqsPerTick = util.NewRangeIntFlag(1, 25) ids = util.NewRangeIntFlag(1, 25) @@ -170,6 +171,7 @@ func RunDSTCmd() *cobra.Command { Timeout: timeout, VisualizationPath: visualizationPath, Verbose: verbose, + PrintOps: printOps, TimeElapsedPerTick: 1000, // ms TimeoutTicks: t, ReqsPerTick: func() int { return reqsPerTick.Resolve(r) }, @@ -210,6 +212,7 @@ func RunDSTCmd() *cobra.Command { cmd.Flags().StringVar(&scenario, "scenario", "default", "can be one of: default, fault, lazy") cmd.Flags().StringVar(&visualizationPath, "visualization-path", "dst.html", "porcupine visualization file path") cmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "log additional information when run is non linearizable") + cmd.Flags().BoolVar(&printOps, "print-ops", true, "log the request/response pairs of a run.") cmd.Flags().Var(reqsPerTick, "reqs-per-tick", "number of requests per tick") cmd.Flags().Var(ids, "ids", "promise id set size") cmd.Flags().Var(idempotencyKeys, "idempotency-keys", "idempotency key set size") @@ -220,6 +223,7 @@ func RunDSTCmd() *cobra.Command { // bind config _ = config.BindDST(cmd) + cmd.SilenceUsage = true cmd.Flags().SortFlags = false diff --git a/internal/app/coroutines/completeTask.go b/internal/app/coroutines/completeTask.go index 7a39eb49..8ae198be 100644 --- a/internal/app/coroutines/completeTask.go +++ b/internal/app/coroutines/completeTask.go @@ -47,7 +47,7 @@ func CompleteTask(c gocoro.Coroutine[*t_aio.Submission, *t_aio.Completion, any], } if t.State == task.Completed || t.State == task.Timedout { - status = t_api.StatusTaskAlreadyCompleted + status = t_api.StatusOK } else if t.State == task.Init || t.State == task.Enqueued { status = t_api.StatusTaskInvalidState } else if t.Counter != r.CompleteTask.Counter { 
diff --git a/internal/app/coroutines/createCallback.go b/internal/app/coroutines/createCallback.go index b2afa234..c2dceb5e 100644 --- a/internal/app/coroutines/createCallback.go +++ b/internal/app/coroutines/createCallback.go @@ -66,7 +66,7 @@ func CreateCallback(c gocoro.Coroutine[*t_aio.Submission, *t_aio.Completion, any createdOn := c.Time() - callbackId := fmt.Sprintf("%s.%s", r.CreateCallback.PromiseId, r.CreateCallback.Id) + cbId := callbackId(r.CreateCallback.RootPromiseId, r.CreateCallback.PromiseId) completion, err := gocoro.YieldAndAwait(c, &t_aio.Submission{ Kind: t_aio.Store, Tags: r.Tags, @@ -76,7 +76,7 @@ func CreateCallback(c gocoro.Coroutine[*t_aio.Submission, *t_aio.Completion, any { Kind: t_aio.CreateCallback, CreateCallback: &t_aio.CreateCallbackCommand{ - Id: callbackId, + Id: cbId, PromiseId: r.CreateCallback.PromiseId, Recv: r.CreateCallback.Recv, Mesg: mesg, @@ -104,7 +104,7 @@ func CreateCallback(c gocoro.Coroutine[*t_aio.Submission, *t_aio.Completion, any if result.RowsAffected == 1 { status = t_api.StatusCreated cb = &callback.Callback{ - Id: callbackId, + Id: cbId, PromiseId: r.CreateCallback.PromiseId, Recv: r.CreateCallback.Recv, Mesg: mesg, @@ -138,3 +138,7 @@ func CreateCallback(c gocoro.Coroutine[*t_aio.Submission, *t_aio.Completion, any util.Assert(res != nil, "response must not be nil") return res, nil } + +func callbackId(rootPromiseId, promiseId string) string { + return fmt.Sprintf("__resume:%s:%s", rootPromiseId, promiseId) +} diff --git a/internal/app/coroutines/createPromise.go b/internal/app/coroutines/createPromise.go index 7d871914..82929da6 100644 --- a/internal/app/coroutines/createPromise.go +++ b/internal/app/coroutines/createPromise.go @@ -1,6 +1,7 @@ package coroutines import ( + "fmt" "log/slog" "github.com/resonatehq/gocoro" @@ -24,6 +25,7 @@ func CreatePromiseAndTask(c gocoro.Coroutine[*t_aio.Submission, *t_aio.Completio util.Assert(r.CreatePromiseAndTask.Promise.Timeout == r.CreatePromiseAndTask.Task.Timeout, 
"timeouts must match") return createPromiseAndTask(c, r, r.CreatePromiseAndTask.Promise, &t_aio.CreateTaskCommand{ + Id: invokeTaskId(r.CreatePromiseAndTask.Task.PromiseId), Recv: nil, Mesg: &message.Mesg{Type: message.Invoke, Root: r.CreatePromiseAndTask.Task.PromiseId, Leaf: r.CreatePromiseAndTask.Task.PromiseId}, Timeout: r.CreatePromiseAndTask.Task.Timeout, @@ -93,8 +95,14 @@ func createPromiseAndTask( if err != nil { return nil, err } + var promiseRowsAffected int64 + if taskCmd == nil { + promiseRowsAffected = completion.Store.Results[0].CreatePromise.RowsAffected + } else { + promiseRowsAffected = completion.Store.Results[0].CreatePromiseAndTask.PromiseRowsAffected + } - if completion.Store.Results[0].CreatePromise.RowsAffected == 0 { + if promiseRowsAffected == 0 { // It's possible that the promise was created by another coroutine // while we were creating. In that case, we should just retry. return createPromiseAndTask(c, r, createPromiseReq, taskCmd) @@ -117,20 +125,21 @@ func createPromiseAndTask( switch r.Kind { case t_api.CreatePromiseAndTask: util.Assert(taskCmd != nil, "create task cmd must not be nil") - util.Assert(completion.Store.Results[1].Kind == t_aio.CreateTask, "completion must be create task") + util.Assert(completion.Store.Results[0].Kind == t_aio.CreatePromiseAndTask, "completion must be createPromiseAndTask") t = &task.Task{ - Id: completion.Store.Results[1].CreateTask.LastInsertId, - ProcessId: taskCmd.ProcessId, - State: taskCmd.State, - Recv: taskCmd.Recv, - Mesg: taskCmd.Mesg, - Timeout: taskCmd.Timeout, - Counter: 1, - Attempt: 0, - Ttl: taskCmd.Ttl, - ExpiresAt: taskCmd.ExpiresAt, - CreatedOn: &taskCmd.CreatedOn, + Id: taskCmd.Id, + ProcessId: taskCmd.ProcessId, + RootPromiseId: p.Id, + State: taskCmd.State, + Recv: taskCmd.Recv, + Mesg: taskCmd.Mesg, + Timeout: taskCmd.Timeout, + Counter: 1, + Attempt: 0, + Ttl: taskCmd.Ttl, + ExpiresAt: taskCmd.ExpiresAt, + CreatedOn: &taskCmd.CreatedOn, } } } else { @@ -203,12 +212,26 @@ func 
createPromise(tags map[string]string, promiseCmd *t_aio.CreatePromiseComman promiseCmd.Tags = map[string]string{} } + isCreatePromiseAndTask := taskCmd != nil + return func(c gocoro.Coroutine[*t_aio.Submission, *t_aio.Completion, *t_aio.Completion]) (*t_aio.Completion, error) { - // add create promise command - commands := []*t_aio.Command{{ - Kind: t_aio.CreatePromise, - CreatePromise: promiseCmd, - }} + commands := []*t_aio.Command{} + + // Combine both commands if taskCmd is not null otherwise add just the CreatePromiseCmd + if isCreatePromiseAndTask { + commands = append(commands, &t_aio.Command{ + Kind: t_aio.CreatePromiseAndTask, + CreatePromiseAndTask: &t_aio.CreatePromiseAndTaskCommand{ + PromiseCommand: promiseCmd, + TaskCommand: taskCmd, + }, + }) + } else { + commands = append(commands, &t_aio.Command{ + Kind: t_aio.CreatePromise, + CreatePromise: promiseCmd, + }) + } // check router to see if a task needs to be created completion, err := gocoro.YieldAndAwait(c, &t_aio.Submission{ @@ -231,8 +254,8 @@ func createPromise(tags map[string]string, promiseCmd *t_aio.CreatePromiseComman slog.Warn("failed to match promise", "cmd", promiseCmd, "err", err) } - if taskCmd != nil && (err != nil || !completion.Router.Matched) { - slog.Error("failed to match promise when creating a task", "cmd", promiseCmd) + if isCreatePromiseAndTask && (err != nil || !completion.Router.Matched) { + slog.Error("failed to match promise with router when creating a task", "cmd", promiseCmd) return nil, t_api.NewError(t_api.StatusPromiseRecvNotFound, err) } @@ -240,10 +263,12 @@ func createPromise(tags map[string]string, promiseCmd *t_aio.CreatePromiseComman util.Assert(completion.Router.Recv != nil, "recv must not be nil") // If there is a taskCmd just update the Recv otherwise create a tasks for the match - if taskCmd != nil { + if isCreatePromiseAndTask { + // Note: we are mutating the taskCmd that is already merged with the createPromiseCmd taskCmd.Recv = completion.Router.Recv } 
else { taskCmd = &t_aio.CreateTaskCommand{ + Id: invokeTaskId(promiseCmd.Id), Recv: completion.Router.Recv, Mesg: &message.Mesg{Type: message.Invoke, Root: promiseCmd.Id, Leaf: promiseCmd.Id}, Timeout: promiseCmd.Timeout, @@ -251,13 +276,14 @@ func createPromise(tags map[string]string, promiseCmd *t_aio.CreatePromiseComman CreatedOn: promiseCmd.CreatedOn, } + // add create task command if matched + commands = append(commands, &t_aio.Command{ + Kind: t_aio.CreateTask, + CreateTask: taskCmd, + }) + } - // add create task command if matched - commands = append(commands, &t_aio.Command{ - Kind: t_aio.CreateTask, - CreateTask: taskCmd, - }) } // add additional commands @@ -281,8 +307,21 @@ func createPromise(tags map[string]string, promiseCmd *t_aio.CreatePromiseComman util.Assert(completion.Store != nil, "completion must not be nil") util.Assert(len(completion.Store.Results) == len(commands), "completion must have same number of results as commands") - util.Assert(completion.Store.Results[0].CreatePromise.RowsAffected == 0 || completion.Store.Results[0].CreatePromise.RowsAffected == 1, "result must return 0 or 1 rows") + if isCreatePromiseAndTask { + promiseAndTaskResult := completion.Store.Results[0].CreatePromiseAndTask + util.Assert(promiseAndTaskResult.PromiseRowsAffected == 0 || promiseAndTaskResult.PromiseRowsAffected == 1, "Creating promise result must return 0 or 1 rows") + if promiseAndTaskResult.PromiseRowsAffected == 0 { + util.Assert(promiseAndTaskResult.TaskRowsAffected == 0, "If not promise was created a task must have not been created") + } + } else { + createPromiseResult := completion.Store.Results[0].CreatePromise + util.Assert(createPromiseResult.RowsAffected == 0 || createPromiseResult.RowsAffected == 1, "CreatePromise result must return 0 or 1 rows") + } return completion, nil } } + +func invokeTaskId(promiseId string) string { + return fmt.Sprintf("__invoke:%s", promiseId) +} diff --git a/internal/app/coroutines/createSubscription.go 
b/internal/app/coroutines/createSubscription.go index b0c16b17..e106b457 100644 --- a/internal/app/coroutines/createSubscription.go +++ b/internal/app/coroutines/createSubscription.go @@ -66,7 +66,7 @@ func CreateSubscription(c gocoro.Coroutine[*t_aio.Submission, *t_aio.Completion, createdOn := c.Time() - callbackId := fmt.Sprintf("%s.%s", r.CreateSubscription.PromiseId, r.CreateSubscription.Id) + callbackId := subscriptionId(r.CreateSubscription.PromiseId, r.CreateSubscription.Id) completion, err := gocoro.YieldAndAwait(c, &t_aio.Submission{ Kind: t_aio.Store, Tags: r.Tags, @@ -138,3 +138,7 @@ func CreateSubscription(c gocoro.Coroutine[*t_aio.Submission, *t_aio.Completion, util.Assert(res != nil, "response must not be nil") return res, nil } + +func subscriptionId(promiseId, customId string) string { + return fmt.Sprintf("__notify:%s:%s", promiseId, customId) +} diff --git a/internal/app/subsystems/aio/sender/sender_dst.go b/internal/app/subsystems/aio/sender/sender_dst.go index 671716b4..c3e16220 100644 --- a/internal/app/subsystems/aio/sender/sender_dst.go +++ b/internal/app/subsystems/aio/sender/sender_dst.go @@ -6,6 +6,7 @@ import ( "github.com/resonatehq/resonate/internal/kernel/bus" "github.com/resonatehq/resonate/internal/kernel/t_aio" + "github.com/resonatehq/resonate/pkg/message" ) // Config @@ -52,8 +53,17 @@ func (s *SenderDST) Process(sqes []*bus.SQE[t_aio.Submission, t_aio.Completion]) for i, sqe := range sqes { var completion *t_aio.SenderCompletion + mesgType := sqe.Submission.Sender.Task.Mesg.Type + + var obj any + if mesgType == message.Notify { + obj = sqe.Submission.Sender.Promise + } else { + obj = sqe.Submission.Sender.Task + } + select { - case s.backchannel <- sqe.Submission.Sender.Task: + case s.backchannel <- obj: completion = &t_aio.SenderCompletion{ Success: s.r.Float64() < s.config.P, } diff --git a/internal/app/subsystems/aio/store/postgres/postgres.go b/internal/app/subsystems/aio/store/postgres/postgres.go index 3555ce20..57d4513f 
100644 --- a/internal/app/subsystems/aio/store/postgres/postgres.go +++ b/internal/app/subsystems/aio/store/postgres/postgres.go @@ -91,7 +91,8 @@ const ( CREATE INDEX IF NOT EXISTS idx_locks_expires_at ON locks(expires_at); CREATE TABLE IF NOT EXISTS tasks ( - id SERIAL PRIMARY KEY, + id TEXT, + sort_id SERIAL, process_id TEXT, state INTEGER DEFAULT 1, root_promise_id TEXT, @@ -103,7 +104,8 @@ const ( ttl INTEGER DEFAULT 0, expires_at BIGINT DEFAULT 0, created_on BIGINT, - completed_on BIGINT + completed_on BIGINT, + PRIMARY KEY(id) ); CREATE INDEX IF NOT EXISTS idx_tasks_process_id ON tasks(process_id); @@ -297,7 +299,7 @@ const ( FROM tasks WHERE state & $1 != 0 AND (expires_at <= $2 OR timeout <= $2) - ORDER BY root_promise_id, id + ORDER BY root_promise_id, sort_id ASC LIMIT $3` TASK_SELECT_ENQUEUEABLE_STATEMENT = ` @@ -324,21 +326,21 @@ const ( WHERE t2.root_promise_id = t1.root_promise_id AND t2.state in (2, 4) -- 2 -> Enqueue, 4 -> Claimed ) - ORDER BY root_promise_id, id + ORDER BY root_promise_id, sort_id ASC LIMIT $1` TASK_INSERT_STATEMENT = ` INSERT INTO tasks - (recv, mesg, timeout, process_id, state, root_promise_id, ttl, expires_at, created_on) + (id, recv, mesg, timeout, process_id, state, root_promise_id, ttl, expires_at, created_on) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9) - RETURNING id` + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ON CONFLICT(id) DO NOTHING` TASK_INSERT_ALL_STATEMENT = ` INSERT INTO tasks - (recv, mesg, timeout, root_promise_id, created_on) + (id, recv, mesg, timeout, root_promise_id, created_on) SELECT - recv, mesg, timeout, root_promise_id, $1 + id, recv, mesg, timeout, root_promise_id, $1 FROM callbacks WHERE @@ -360,7 +362,7 @@ const ( SET state = 8, completed_on = $1 -- State = 8 -> Completed WHERE - root_promise_id = $2 AND state in (1, 2) -- State in (Init, Enqueued)` + root_promise_id = $2 AND state in (1, 2, 4) -- State in (Init, Enqueued, Claimed)` TASK_HEARTBEAT_STATEMENT = ` UPDATE @@ -589,6 +591,7 @@ func (w 
*PostgresStoreWorker) performCommands(tx *sql.Tx, transactions []*t_aio. var lockHeartbeatStmt *sql.Stmt var lockTimeoutStmt *sql.Stmt var tasksInsertStmt *sql.Stmt + var taskInsertStmt *sql.Stmt var taskUpdateStmt *sql.Stmt var tasksCompleteStmt *sql.Stmt var taskHeartbeatStmt *sql.Stmt @@ -765,8 +768,16 @@ func (w *PostgresStoreWorker) performCommands(tx *sql.Tx, transactions []*t_aio. util.Assert(command.ReadEnquableTasks != nil, "command must not be nil") results[i][j], err = w.readEnqueueableTasks(tx, command.ReadEnquableTasks) case t_aio.CreateTask: + if taskInsertStmt == nil { + taskInsertStmt, err = tx.Prepare(TASK_INSERT_STATEMENT) + if err != nil { + return nil, err + } + defer taskInsertStmt.Close() + } + util.Assert(command.CreateTask != nil, "command must not be nil") - results[i][j], err = w.createTask(tx, command.CreateTask) + results[i][j], err = w.createTask(tx, taskInsertStmt, command.CreateTask) case t_aio.CreateTasks: if tasksInsertStmt == nil { tasksInsertStmt, err = tx.Prepare(TASK_INSERT_ALL_STATEMENT) @@ -812,6 +823,26 @@ func (w *PostgresStoreWorker) performCommands(tx *sql.Tx, transactions []*t_aio. 
util.Assert(command.HeartbeatTasks != nil, "command must not be nil") results[i][j], err = w.heartbeatTasks(tx, taskHeartbeatStmt, command.HeartbeatTasks) + case t_aio.CreatePromiseAndTask: + if promiseInsertStmt == nil { + promiseInsertStmt, err = tx.Prepare(PROMISE_INSERT_STATEMENT) + if err != nil { + return nil, err + } + defer promiseInsertStmt.Close() + } + + if taskInsertStmt == nil { + taskInsertStmt, err = tx.Prepare(TASK_INSERT_STATEMENT) + if err != nil { + return nil, err + } + defer taskInsertStmt.Close() + } + + util.Assert(command.CreatePromiseAndTask != nil, "createPromiseAndTask command must bot be nil") + results[i][j], err = w.createPromiseAndTask(tx, promiseInsertStmt, taskInsertStmt, command.CreatePromiseAndTask) + default: panic(fmt.Sprintf("invalid command: %s", command.Kind.String())) } @@ -995,7 +1026,7 @@ func (w *PostgresStoreWorker) searchPromises(tx *sql.Tx, cmd *t_aio.SearchPromis }, nil } -func (w *PostgresStoreWorker) createPromise(tx *sql.Tx, stmt *sql.Stmt, cmd *t_aio.CreatePromiseCommand) (*t_aio.Result, error) { +func (w *PostgresStoreWorker) createPromise(_ *sql.Tx, stmt *sql.Stmt, cmd *t_aio.CreatePromiseCommand) (*t_aio.Result, error) { util.Assert(cmd.Param.Headers != nil, "param headers must not be nil") util.Assert(cmd.Param.Data != nil, "param data must not be nil") util.Assert(cmd.Tags != nil, "tags must not be nil") @@ -1029,6 +1060,34 @@ func (w *PostgresStoreWorker) createPromise(tx *sql.Tx, stmt *sql.Stmt, cmd *t_a }, nil } +func (w *PostgresStoreWorker) createPromiseAndTask(tx *sql.Tx, promiseStmt *sql.Stmt, taskStmt *sql.Stmt, cmd *t_aio.CreatePromiseAndTaskCommand) (*t_aio.Result, error) { + promiseResult, err := w.createPromise(tx, promiseStmt, cmd.PromiseCommand) + if err != nil { + return nil, err + } + + // Couldn't create a promise + if promiseResult.CreatePromise.RowsAffected == 0 { + return &t_aio.Result{ + Kind: t_aio.CreatePromiseAndTask, + CreatePromiseAndTask: &t_aio.AlterPromiseAndTaskResult{}, + }, nil 
+ } + + taskResult, err := w.createTask(tx, taskStmt, cmd.TaskCommand) + if err != nil { + return nil, err + } + + return &t_aio.Result{ + Kind: t_aio.CreatePromiseAndTask, + CreatePromiseAndTask: &t_aio.AlterPromiseAndTaskResult{ + PromiseRowsAffected: promiseResult.CreatePromise.RowsAffected, + TaskRowsAffected: taskResult.CreateTask.RowsAffected, + }, + }, nil +} + func (w *PostgresStoreWorker) updatePromise(tx *sql.Tx, stmt *sql.Stmt, cmd *t_aio.UpdatePromiseCommand) (*t_aio.Result, error) { util.Assert(cmd.State.In(promise.Resolved|promise.Rejected|promise.Canceled|promise.Timedout), "state must be canceled, resolved, rejected, or timedout") util.Assert(cmd.Value.Headers != nil, "value headers must not be nil") @@ -1584,7 +1643,7 @@ func (w *PostgresStoreWorker) readEnqueueableTasks(tx *sql.Tx, cmd *t_aio.ReadEn }, nil } -func (w *PostgresStoreWorker) createTask(tx *sql.Tx, cmd *t_aio.CreateTaskCommand) (*t_aio.Result, error) { +func (w *PostgresStoreWorker) createTask(tx *sql.Tx, stmt *sql.Stmt, cmd *t_aio.CreateTaskCommand) (*t_aio.Result, error) { util.Assert(cmd.Recv != nil, "recv must not be nil") util.Assert(cmd.Mesg != nil, "mesg must not be nil") util.Assert(cmd.State.In(task.Init|task.Claimed), "state must be init or claimed") @@ -1595,23 +1654,21 @@ func (w *PostgresStoreWorker) createTask(tx *sql.Tx, cmd *t_aio.CreateTaskComman return nil, store.StoreErr(err) } - var lastInsertId string - rowsAffected := int64(1) - row := tx.QueryRow(TASK_INSERT_STATEMENT, cmd.Recv, mesg, cmd.Timeout, cmd.ProcessId, cmd.State, cmd.Mesg.Root, cmd.Ttl, cmd.ExpiresAt, cmd.CreatedOn) + // insert + res, err := stmt.Exec(cmd.Id, cmd.Recv, mesg, cmd.Timeout, cmd.ProcessId, cmd.State, cmd.Mesg.Root, cmd.Ttl, cmd.ExpiresAt, cmd.CreatedOn) + if err != nil { + return nil, err + } - if err := row.Scan(&lastInsertId); err != nil { - if err == sql.ErrNoRows { - rowsAffected = 0 - } else { - return nil, store.StoreErr(err) - } + rowsAffected, err := res.RowsAffected() + if err != 
nil { + return nil, err } return &t_aio.Result{ Kind: t_aio.CreateTask, CreateTask: &t_aio.AlterTasksResult{ RowsAffected: rowsAffected, - LastInsertId: lastInsertId, }, }, nil } diff --git a/internal/app/subsystems/aio/store/sqlite/sqlite.go b/internal/app/subsystems/aio/store/sqlite/sqlite.go index bef36638..fdc5654a 100644 --- a/internal/app/subsystems/aio/store/sqlite/sqlite.go +++ b/internal/app/subsystems/aio/store/sqlite/sqlite.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "os" - "strconv" "strings" "time" @@ -88,7 +87,8 @@ const ( CREATE INDEX IF NOT EXISTS idx_locks_expires_at ON locks(expires_at); CREATE TABLE IF NOT EXISTS tasks ( - id INTEGER PRIMARY KEY AUTOINCREMENT, + id TEXT UNIQUE, + sort_id INTEGER PRIMARY KEY AUTOINCREMENT, process_id TEXT, state INTEGER DEFAULT 1, root_promise_id TEXT, @@ -286,7 +286,7 @@ const ( FROM tasks WHERE state & ? != 0 AND (expires_at <= ? OR timeout <= ?) - ORDER BY root_promise_id, id + ORDER BY root_promise_id, sort_id ASC LIMIT ?` TASK_SELECT_ENQUEUEABLE_STATEMENT = ` @@ -314,20 +314,21 @@ const ( AND t2.state in (2, 4) -- 2 -> Enqueue, 4 -> Claimed ) GROUP BY root_promise_id - ORDER BY root_promise_id, id + ORDER BY root_promise_id, sort_id ASC LIMIT ?` TASK_INSERT_STATEMENT = ` INSERT INTO tasks - (recv, mesg, timeout, process_id, state, root_promise_id, ttl, expires_at, created_on) + (id, recv, mesg, timeout, process_id, state, root_promise_id, ttl, expires_at, created_on) VALUES - (?, ?, ?, ?, ?, ?, ?, ?, ?)` + (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO NOTHING` TASK_INSERT_ALL_STATEMENT = ` INSERT INTO tasks - (recv, mesg, timeout, root_promise_id, created_on) + (id, recv, mesg, timeout, root_promise_id, created_on) SELECT - recv, mesg, timeout, root_promise_id, ? + id, recv, mesg, timeout, root_promise_id, ? FROM callbacks WHERE @@ -349,7 +350,7 @@ const ( SET state = 8, completed_on = ? -- State 8 -> Completed WHERE - root_promise_id = ? 
AND state in (1, 2) -- State in (Init, Enqueued)` + root_promise_id = ? AND state in (1, 2, 4) -- State in (Init, Enqueued, Claimed)` TASK_HEARTBEAT_STATEMENT = ` UPDATE @@ -775,6 +776,24 @@ func (w *SqliteStoreWorker) performCommands(tx *sql.Tx, transactions []*t_aio.Tr util.Assert(command.HeartbeatTasks != nil, "command must not be nil") results[i][j], err = w.heartbeatTasks(tx, taskHeartbeatStmt, command.HeartbeatTasks) + case t_aio.CreatePromiseAndTask: + if promiseInsertStmt == nil { + promiseInsertStmt, err = tx.Prepare(PROMISE_INSERT_STATEMENT) + if err != nil { + return nil, err + } + defer promiseInsertStmt.Close() + } + if taskInsertStmt == nil { + taskInsertStmt, err = tx.Prepare(TASK_INSERT_STATEMENT) + if err != nil { + return nil, err + } + defer taskInsertStmt.Close() + } + + util.Assert(command.CreatePromiseAndTask != nil, "createPromiseAndTask command must bot be nil") + results[i][j], err = w.createPromiseAndTask(tx, promiseInsertStmt, taskInsertStmt, command.CreatePromiseAndTask) default: panic(fmt.Sprintf("invalid command: %s", command.Kind.String())) @@ -997,6 +1016,34 @@ func (w *SqliteStoreWorker) createPromise(tx *sql.Tx, stmt *sql.Stmt, cmd *t_aio }, nil } +func (w *SqliteStoreWorker) createPromiseAndTask(tx *sql.Tx, promiseStmt *sql.Stmt, TaskStmt *sql.Stmt, cmd *t_aio.CreatePromiseAndTaskCommand) (*t_aio.Result, error) { + promiseResult, err := w.createPromise(tx, promiseStmt, cmd.PromiseCommand) + if err != nil { + return nil, err + } + + // Couldn't create a promise + if promiseResult.CreatePromise.RowsAffected == 0 { + return &t_aio.Result{ + Kind: t_aio.CreatePromiseAndTask, + CreatePromiseAndTask: &t_aio.AlterPromiseAndTaskResult{}, + }, nil + } + + taskResult, err := w.createTask(tx, TaskStmt, cmd.TaskCommand) + if err != nil { + return nil, err + } + + return &t_aio.Result{ + Kind: t_aio.CreatePromiseAndTask, + CreatePromiseAndTask: &t_aio.AlterPromiseAndTaskResult{ + PromiseRowsAffected: promiseResult.CreatePromise.RowsAffected, + 
TaskRowsAffected: taskResult.CreateTask.RowsAffected, + }, + }, nil +} + func (w *SqliteStoreWorker) updatePromise(tx *sql.Tx, stmt *sql.Stmt, cmd *t_aio.UpdatePromiseCommand) (*t_aio.Result, error) { util.Assert(cmd.State.In(promise.Resolved|promise.Rejected|promise.Canceled|promise.Timedout), "state must be canceled, resolved, rejected, or timedout") util.Assert(cmd.Value.Headers != nil, "value headers must not be nil") @@ -1586,7 +1633,7 @@ func (w *SqliteStoreWorker) createTask(tx *sql.Tx, stmt *sql.Stmt, cmd *t_aio.Cr return nil, store.StoreErr(err) } - res, err := stmt.Exec(cmd.Recv, mesg, cmd.Timeout, cmd.ProcessId, cmd.State, cmd.Mesg.Root, cmd.Ttl, cmd.ExpiresAt, cmd.CreatedOn) + res, err := stmt.Exec(cmd.Id, cmd.Recv, mesg, cmd.Timeout, cmd.ProcessId, cmd.State, cmd.Mesg.Root, cmd.Ttl, cmd.ExpiresAt, cmd.CreatedOn) if err != nil { return nil, store.StoreErr(err) } @@ -1596,21 +1643,10 @@ func (w *SqliteStoreWorker) createTask(tx *sql.Tx, stmt *sql.Stmt, cmd *t_aio.Cr return nil, store.StoreErr(err) } - lastInsertId, err := res.LastInsertId() - if err != nil { - return nil, store.StoreErr(err) - } - - var lastInsertIdStr string - if rowsAffected != 0 { - lastInsertIdStr = strconv.FormatInt(lastInsertId, 10) - } - return &t_aio.Result{ Kind: t_aio.CreateTask, CreateTask: &t_aio.AlterTasksResult{ RowsAffected: rowsAffected, - LastInsertId: lastInsertIdStr, }, }, nil } diff --git a/internal/app/subsystems/aio/store/test/cases.go b/internal/app/subsystems/aio/store/test/cases.go index ad2432c2..af550e31 100644 --- a/internal/app/subsystems/aio/store/test/cases.go +++ b/internal/app/subsystems/aio/store/test/cases.go @@ -2848,12 +2848,188 @@ var TestCases = []*testCase{ }, // TASKS + { + name: "CreatePromiseAndTask", + commands: []*t_aio.Command{ + { + Kind: t_aio.CreatePromiseAndTask, + CreatePromiseAndTask: &t_aio.CreatePromiseAndTaskCommand{ + PromiseCommand: &t_aio.CreatePromiseCommand{ + Id: "foo", + Timeout: 1, + Param: promise.Value{ + Headers: 
map[string]string{}, + Data: []byte{}, + }, + Tags: map[string]string{}, + CreatedOn: 1, + }, + TaskCommand: &t_aio.CreateTaskCommand{ + Id: "__invoke:foo", + Recv: []byte("foo"), + Mesg: &message.Mesg{Type: message.Invoke, Root: "foo", Leaf: "foo"}, + ProcessId: util.ToPointer("pid"), + State: task.Claimed, + CreatedOn: 1, + Ttl: 2, + ExpiresAt: 2, + Timeout: 3, + }, + }, + }, + { + Kind: t_aio.ReadPromise, + ReadPromise: &t_aio.ReadPromiseCommand{ + Id: "foo", + }, + }, + { + Kind: t_aio.ReadTask, + ReadTask: &t_aio.ReadTaskCommand{ + Id: "__invoke:foo", + }, + }, + }, + expected: []*t_aio.Result{ + { + Kind: t_aio.CreatePromiseAndTask, + CreatePromiseAndTask: &t_aio.AlterPromiseAndTaskResult{ + PromiseRowsAffected: 1, + TaskRowsAffected: 1, + }, + }, + { + Kind: t_aio.ReadPromise, + ReadPromise: &t_aio.QueryPromisesResult{ + RowsReturned: 1, + Records: []*promise.PromiseRecord{{ + Id: "foo", + State: 1, + ParamHeaders: []byte("{}"), + ParamData: []byte{}, + Timeout: 1, + Tags: []byte("{}"), + CreatedOn: util.ToPointer(int64(1)), + }}, + }, + }, + { + Kind: t_aio.ReadTask, + ReadTask: &t_aio.QueryTasksResult{ + RowsReturned: 1, + Records: []*task.TaskRecord{ + { + Id: "__invoke:foo", + ProcessId: util.ToPointer("pid"), + State: task.Claimed, + RootPromiseId: "foo", + Recv: []byte("foo"), + Mesg: []byte(`{"type":"invoke","root":"foo","leaf":"foo"}`), + Attempt: 0, + Counter: 1, + CreatedOn: util.ToPointer[int64](1), + Ttl: 2, + ExpiresAt: 2, + Timeout: 3, + }, + }, + }, + }, + }, + }, + { + name: "CreatePromiseAndTask_PromiseAlreadyExists", + commands: []*t_aio.Command{ + { + Kind: t_aio.CreatePromise, + CreatePromise: &t_aio.CreatePromiseCommand{ + Id: "foo", + Timeout: 1, + Param: promise.Value{Headers: map[string]string{}, Data: []byte{}}, + Tags: map[string]string{}, + CreatedOn: 1, + }, + }, + { + Kind: t_aio.CreatePromiseAndTask, + CreatePromiseAndTask: &t_aio.CreatePromiseAndTaskCommand{ + PromiseCommand: &t_aio.CreatePromiseCommand{ + Id: "foo", + Timeout: 
1, + Param: promise.Value{Headers: map[string]string{}, Data: []byte{}}, + Tags: map[string]string{}, + CreatedOn: 1, + }, + TaskCommand: &t_aio.CreateTaskCommand{ + Id: "__invoke:foo", + Recv: []byte("foo"), + Mesg: &message.Mesg{Type: message.Invoke, Root: "foo", Leaf: "foo"}, + ProcessId: util.ToPointer("pid"), + State: task.Claimed, + CreatedOn: 1, + Ttl: 2, + ExpiresAt: 2, + Timeout: 3, + }, + }, + }, + { + Kind: t_aio.ReadPromise, + ReadPromise: &t_aio.ReadPromiseCommand{ + Id: "foo", + }, + }, + { + Kind: t_aio.ReadTask, + ReadTask: &t_aio.ReadTaskCommand{ + Id: "__invoke:foo", + }, + }, + }, + expected: []*t_aio.Result{ + { + Kind: t_aio.CreatePromise, + CreatePromise: &t_aio.AlterPromisesResult{ + RowsAffected: 1, + }, + }, + { + Kind: t_aio.CreatePromiseAndTask, + CreatePromiseAndTask: &t_aio.AlterPromiseAndTaskResult{ + PromiseRowsAffected: 0, + TaskRowsAffected: 0, + }, + }, + { + Kind: t_aio.ReadPromise, + ReadPromise: &t_aio.QueryPromisesResult{ + RowsReturned: 1, + Records: []*promise.PromiseRecord{{ + Id: "foo", + State: 1, + ParamHeaders: []byte("{}"), + ParamData: []byte{}, + Timeout: 1, + Tags: []byte("{}"), + CreatedOn: util.ToPointer(int64(1)), + }}, + }, + }, + { + Kind: t_aio.ReadTask, + ReadTask: &t_aio.QueryTasksResult{ + RowsReturned: 0, + }, + }, + }, + }, { name: "CreateTask", commands: []*t_aio.Command{ { Kind: t_aio.CreateTask, CreateTask: &t_aio.CreateTaskCommand{ + Id: "1", Recv: []byte("foo"), Mesg: &message.Mesg{Type: message.Invoke, Root: "foo", Leaf: "foo"}, Timeout: 1, @@ -2864,6 +3040,7 @@ var TestCases = []*testCase{ { Kind: t_aio.CreateTask, CreateTask: &t_aio.CreateTaskCommand{ + Id: "2", Recv: []byte("bar"), Mesg: &message.Mesg{Type: message.Invoke, Root: "bar", Leaf: "bar"}, Timeout: 2, @@ -2884,14 +3061,12 @@ var TestCases = []*testCase{ Kind: t_aio.CreateTask, CreateTask: &t_aio.AlterTasksResult{ RowsAffected: 1, - LastInsertId: "1", }, }, { Kind: t_aio.CreateTask, CreateTask: &t_aio.AlterTasksResult{ RowsAffected: 1, - 
LastInsertId: "2", }, }, { @@ -3013,7 +3188,7 @@ var TestCases = []*testCase{ RowsReturned: 3, Records: []*task.TaskRecord{ { - Id: "2", + Id: "foo.2", Counter: 1, State: task.Init, RootPromiseId: "bar", @@ -3022,7 +3197,7 @@ var TestCases = []*testCase{ CreatedOn: util.ToPointer[int64](0), }, { - Id: "3", + Id: "foo.3", Counter: 1, State: task.Init, RootPromiseId: "baz", @@ -3031,7 +3206,7 @@ var TestCases = []*testCase{ CreatedOn: util.ToPointer[int64](0), }, { - Id: "1", + Id: "foo.1", Counter: 1, State: task.Init, RootPromiseId: "foo", @@ -3138,7 +3313,7 @@ var TestCases = []*testCase{ { Kind: t_aio.UpdateTask, UpdateTask: &t_aio.UpdateTaskCommand{ - Id: "1", + Id: "foo.1", ProcessId: util.ToPointer("pid"), State: task.Enqueued, Counter: 2, @@ -3159,7 +3334,7 @@ var TestCases = []*testCase{ { Kind: t_aio.UpdateTask, UpdateTask: &t_aio.UpdateTaskCommand{ - Id: "5", + Id: "pbar.1", ProcessId: util.ToPointer("pid"), State: task.Enqueued, Counter: 2, @@ -3245,7 +3420,7 @@ var TestCases = []*testCase{ RowsReturned: 2, Records: []*task.TaskRecord{ { - Id: "1", + Id: "foo.1", Counter: 1, State: task.Init, RootPromiseId: "foo", @@ -3254,7 +3429,7 @@ var TestCases = []*testCase{ CreatedOn: util.ToPointer[int64](0), }, { - Id: "4", + Id: "pbar.1", Counter: 1, State: task.Init, RootPromiseId: "pbar", @@ -3277,7 +3452,7 @@ var TestCases = []*testCase{ RowsReturned: 1, Records: []*task.TaskRecord{ { - Id: "4", + Id: "pbar.1", Counter: 1, State: task.Init, RootPromiseId: "pbar", @@ -3331,7 +3506,7 @@ var TestCases = []*testCase{ { Kind: t_aio.UpdateTask, UpdateTask: &t_aio.UpdateTaskCommand{ - Id: "1", + Id: "foo.1", ProcessId: util.ToPointer("pid"), State: task.Enqueued, Counter: 2, @@ -3346,7 +3521,7 @@ var TestCases = []*testCase{ { Kind: t_aio.UpdateTask, UpdateTask: &t_aio.UpdateTaskCommand{ - Id: "1", + Id: "foo.1", ProcessId: util.ToPointer("pid"), State: task.Claimed, Counter: 3, @@ -3361,7 +3536,7 @@ var TestCases = []*testCase{ { Kind: t_aio.UpdateTask, UpdateTask: 
&t_aio.UpdateTaskCommand{ - Id: "1", + Id: "foo.1", ProcessId: util.ToPointer("pid"), State: task.Claimed, Counter: 4, @@ -3376,7 +3551,7 @@ var TestCases = []*testCase{ { Kind: t_aio.UpdateTask, UpdateTask: &t_aio.UpdateTaskCommand{ - Id: "1", + Id: "foo.1", ProcessId: util.ToPointer("pid"), State: task.Completed, Counter: 5, @@ -3391,7 +3566,7 @@ var TestCases = []*testCase{ { Kind: t_aio.UpdateTask, UpdateTask: &t_aio.UpdateTaskCommand{ - Id: "1", + Id: "foo.1", ProcessId: util.ToPointer("pid"), State: task.Completed, Counter: 6, @@ -3406,7 +3581,7 @@ var TestCases = []*testCase{ { Kind: t_aio.ReadTask, ReadTask: &t_aio.ReadTaskCommand{ - Id: "1", + Id: "foo.1", }, }, }, @@ -3465,7 +3640,7 @@ var TestCases = []*testCase{ RowsReturned: 1, Records: []*task.TaskRecord{ { - Id: "1", + Id: "foo.1", ProcessId: util.ToPointer("pid"), State: task.Completed, RootPromiseId: "foo", @@ -3688,13 +3863,13 @@ var TestCases = []*testCase{ { Kind: t_aio.ReadTask, ReadTask: &t_aio.ReadTaskCommand{ - Id: "1", + Id: "foo.1", }, }, { Kind: t_aio.ReadTask, ReadTask: &t_aio.ReadTaskCommand{ - Id: "2", + Id: "bar.1", }, }, }, @@ -3759,7 +3934,7 @@ var TestCases = []*testCase{ RowsReturned: 1, Records: []*task.TaskRecord{ { - Id: "1", + Id: "foo.1", State: task.Completed, RootPromiseId: "root1", Recv: []byte("foo"), @@ -3780,7 +3955,7 @@ var TestCases = []*testCase{ RowsReturned: 1, Records: []*task.TaskRecord{ { - Id: "2", + Id: "bar.1", State: task.Init, RootPromiseId: "root2", Recv: []byte("bar"), @@ -3902,7 +4077,7 @@ var TestCases = []*testCase{ }, }, { - name: "CompleteTasks_ClaimedTaskNotCompleted", + name: "CompleteTasks_ClaimedTaskCompleted", commands: []*t_aio.Command{ { Kind: t_aio.CreatePromise, @@ -3938,7 +4113,7 @@ var TestCases = []*testCase{ { Kind: t_aio.UpdateTask, UpdateTask: &t_aio.UpdateTaskCommand{ - Id: "1", + Id: "foo.1", ProcessId: util.ToPointer("pid"), State: task.Claimed, // Set task to Claimed state Counter: 2, @@ -3959,7 +4134,7 @@ var TestCases = 
[]*testCase{ { Kind: t_aio.ReadTask, ReadTask: &t_aio.ReadTaskCommand{ - Id: "1", + Id: "foo.1", }, }, }, @@ -3997,7 +4172,7 @@ var TestCases = []*testCase{ { Kind: t_aio.CompleteTasks, CompleteTasks: &t_aio.AlterTasksResult{ - RowsAffected: 0, // No tasks should be completed + RowsAffected: 1, }, }, { @@ -4006,9 +4181,9 @@ var TestCases = []*testCase{ RowsReturned: 1, Records: []*task.TaskRecord{ { - Id: "1", + Id: "foo.1", ProcessId: util.ToPointer("pid"), - State: task.Claimed, // Task should remain in Claimed state + State: task.Completed, RootPromiseId: "root", Recv: []byte("foo"), Mesg: []byte(`{"type":"resume","root":"root","leaf":"foo"}`), @@ -4017,6 +4192,7 @@ var TestCases = []*testCase{ Ttl: 1, ExpiresAt: 1, CreatedOn: util.ToPointer[int64](0), + CompletedOn: util.ToPointer[int64](5), }, }, }, @@ -4060,7 +4236,7 @@ var TestCases = []*testCase{ { Kind: t_aio.UpdateTask, UpdateTask: &t_aio.UpdateTaskCommand{ - Id: "1", + Id: "foo.1", ProcessId: util.ToPointer("pid"), State: task.Timedout, // Set task to Timeout state Counter: 2, @@ -4081,7 +4257,7 @@ var TestCases = []*testCase{ { Kind: t_aio.ReadTask, ReadTask: &t_aio.ReadTaskCommand{ - Id: "1", + Id: "foo.1", }, }, }, @@ -4128,7 +4304,7 @@ var TestCases = []*testCase{ RowsReturned: 1, Records: []*task.TaskRecord{ { - Id: "1", + Id: "foo.1", ProcessId: util.ToPointer("pid"), State: task.Timedout, // Task should remain in Timeout state RootPromiseId: "root", @@ -4192,7 +4368,7 @@ var TestCases = []*testCase{ { Kind: t_aio.UpdateTask, UpdateTask: &t_aio.UpdateTaskCommand{ - Id: "1", + Id: "foo.1", ProcessId: util.ToPointer("bar"), State: task.Claimed, CurrentStates: []task.State{task.Init}, @@ -4202,7 +4378,7 @@ var TestCases = []*testCase{ { Kind: t_aio.UpdateTask, UpdateTask: &t_aio.UpdateTaskCommand{ - Id: "2", + Id: "foo.2", ProcessId: util.ToPointer("bar"), State: task.Claimed, CurrentStates: []task.State{task.Init}, @@ -4212,7 +4388,7 @@ var TestCases = []*testCase{ { Kind: t_aio.UpdateTask, 
UpdateTask: &t_aio.UpdateTaskCommand{ - Id: "3", + Id: "foo.3", ProcessId: util.ToPointer("bar"), State: task.Completed, CurrentStates: []task.State{task.Init}, diff --git a/internal/kernel/t_aio/store.go b/internal/kernel/t_aio/store.go index 9b056cec..05b8eca3 100644 --- a/internal/kernel/t_aio/store.go +++ b/internal/kernel/t_aio/store.go @@ -42,6 +42,7 @@ const ( CompleteTasks UpdateTask HeartbeatTasks + CreatePromiseAndTask // LOCKS ReadLock @@ -99,6 +100,9 @@ func (k StoreKind) String() string { return "UpdateTask" case HeartbeatTasks: return "HeartbeatTasks" + case CreatePromiseAndTask: + return "CreatePromiseAndTask" + // LOCKS case ReadLock: return "ReadLock" @@ -159,14 +163,15 @@ type Command struct { DeleteSchedule *DeleteScheduleCommand // TASKS - ReadTask *ReadTaskCommand - ReadTasks *ReadTasksCommand - ReadEnquableTasks *ReadEnqueueableTasksCommand - CreateTask *CreateTaskCommand - CreateTasks *CreateTasksCommand - CompleteTasks *CompleteTasksCommand - UpdateTask *UpdateTaskCommand - HeartbeatTasks *HeartbeatTasksCommand + ReadTask *ReadTaskCommand + ReadTasks *ReadTasksCommand + ReadEnquableTasks *ReadEnqueueableTasksCommand + CreateTask *CreateTaskCommand + CreateTasks *CreateTasksCommand + CompleteTasks *CompleteTasksCommand + UpdateTask *UpdateTaskCommand + HeartbeatTasks *HeartbeatTasksCommand + CreatePromiseAndTask *CreatePromiseAndTaskCommand // LOCKS ReadLock *ReadLockCommand @@ -211,6 +216,7 @@ type Result struct { CompleteTasks *AlterTasksResult UpdateTask *AlterTasksResult HeartbeatTasks *AlterTasksResult + CreatePromiseAndTask *AlterPromiseAndTaskResult // LOCKS ReadLock *QueryLocksResult @@ -365,6 +371,7 @@ type ReadEnqueueableTasksCommand struct { } type CreateTaskCommand struct { + Id string Recv []byte Mesg *message.Mesg Timeout int64 @@ -403,6 +410,11 @@ type HeartbeatTasksCommand struct { Time int64 } +type CreatePromiseAndTaskCommand struct { + PromiseCommand *CreatePromiseCommand + TaskCommand *CreateTaskCommand +} + // Task 
results type QueryTasksResult struct { @@ -412,7 +424,11 @@ type QueryTasksResult struct { type AlterTasksResult struct { RowsAffected int64 - LastInsertId string +} + +type AlterPromiseAndTaskResult struct { + PromiseRowsAffected int64 + TaskRowsAffected int64 } // Lock commands diff --git a/internal/kernel/t_api/status.go b/internal/kernel/t_api/status.go index abb2e31b..d131d685 100644 --- a/internal/kernel/t_api/status.go +++ b/internal/kernel/t_api/status.go @@ -76,6 +76,8 @@ func (s StatusCode) String() string { return "The specified lock was not found" case StatusTaskNotFound: return "The specified task was not found" + case StatusPromiseRecvNotFound: + return "The specified recv couldn't be found" case StatusPromiseAlreadyExists: return "The specified promise already exists" case StatusScheduleAlreadyExists: diff --git a/test/dst/bc_validator.go b/test/dst/bc_validator.go index d65dc042..05570a96 100644 --- a/test/dst/bc_validator.go +++ b/test/dst/bc_validator.go @@ -4,6 +4,9 @@ import ( "fmt" "math/rand" + "github.com/resonatehq/resonate/internal/util" + "github.com/resonatehq/resonate/pkg/message" + "github.com/resonatehq/resonate/pkg/promise" "github.com/resonatehq/resonate/pkg/task" ) @@ -11,7 +14,7 @@ type BcValidator struct { validators []BcValidatorFn } -type BcValidatorFn func(*Model, *Req) (*Model, error) +type BcValidatorFn func(*Model, int64, int64, *Req) (*Model, error) func NewBcValidator(r *rand.Rand, config *Config) *BcValidator { return &BcValidator{ @@ -23,12 +26,11 @@ func (v *BcValidator) AddBcValidator(bcv BcValidatorFn) { v.validators = append(v.validators, bcv) } -func (v *BcValidator) Validate(model *Model, req *Req) (*Model, error) { +func (v *BcValidator) Validate(model *Model, reqTime int64, resTime int64, req *Req) (*Model, error) { var err error for _, bcv := range v.validators { - model, err = bcv(model, req) + model, err = bcv(model, reqTime, resTime, req) if err != nil { - fmt.Printf("Got error %v\n", err) return model, err } 
} @@ -36,24 +38,83 @@ func (v *BcValidator) Validate(model *Model, req *Req) (*Model, error) { return model, nil } -func ValidateTasksWithSameRootPromiseId(model *Model, req *Req) (*Model, error) { +func ValidateNotify(model *Model, reqTime int64, resTime int64, req *Req) (*Model, error) { + if req.bc.Promise != nil { + storedP := model.promises.get(req.bc.Promise.Id) + p := req.bc.Promise + if storedP.State == promise.Pending { + // the only way this can happen is if the promise timedout + if p.State == promise.GetTimedoutState(storedP) && resTime >= storedP.Timeout { + model = model.Copy() + model.promises.set(p.Id, p) + return model, nil + } + return model, fmt.Errorf("received a notification for promise '%s' but promise was not completed", req.bc.Promise.Id) + } + } + + return model, nil +} + +func ValidateTasksWithSameRootPromiseId(model *Model, reqTime int64, _ int64, req *Req) (*Model, error) { if req.bc.Task != nil { - for _, stored := range *model.tasks { - if stored.value.Id == req.bc.Task.Id && - (stored.value.Counter < req.bc.Task.Counter || - stored.value.Attempt < req.bc.Task.Attempt) { - continue + reqT := req.bc.Task + stored := model.tasks.get(reqT.Id) + p := model.promises.get(req.bc.Task.RootPromiseId) + + if stored != nil && stored.RootPromiseId != reqT.RootPromiseId { + return model, fmt.Errorf("Same task with different rootpromiseId: '%s'", req.bc.Task.Id) + } + + if stored != nil && stored.State.In(task.Completed|task.Timedout) { + return model, nil + } + + if stored != nil && stored.State == task.Claimed && stored.ExpiresAt > reqTime { + return model, nil + } + + state := reqT.State + completedOn := reqT.CompletedOn + if p != nil && + p.State != promise.Pending { + if reqT.Mesg.Type == message.Invoke && *p.CompletedOn <= *reqT.CreatedOn { + return model, fmt.Errorf("Invocation for a promise that is alredy completed.") + } else if *p.CompletedOn > *reqT.CreatedOn { + state = task.Completed + completedOn = util.ToPointer(*p.CompletedOn) } + } 
- if stored.value.RootPromiseId == req.bc.Task.RootPromiseId && - stored.value.State != task.Completed && - stored.value.ExpiresAt > req.time && - stored.value.Timeout > req.time { - return model, fmt.Errorf("task '%s' for root promise '%s' should not have been enqueued", req.bc.Task.Id, req.bc.Task.RootPromiseId) + for _, t := range *model.tasks { + if t.value.RootPromiseId != reqT.RootPromiseId || + t.value.Mesg.Type != reqT.Mesg.Type || + t.value.Id == reqT.Id { + continue + } + if !t.value.State.In(task.Completed|task.Timedout) && + t.value.Timeout > reqTime { + return model, fmt.Errorf("Multiple tasks with same rootpromiseId '%s' active at the same time.", req.bc.Task.RootPromiseId) } } + + newT := &task.Task{ + Id: reqT.Id, + Counter: reqT.Counter, + Timeout: reqT.Timeout, + ProcessId: reqT.ProcessId, + State: state, + RootPromiseId: reqT.RootPromiseId, + Recv: reqT.Recv, + Mesg: reqT.Mesg, + Attempt: reqT.Attempt, + Ttl: reqT.Ttl, + ExpiresAt: reqT.ExpiresAt, + CreatedOn: reqT.CreatedOn, + CompletedOn: completedOn, + } model = model.Copy() - model.tasks.set(req.bc.Task.Id, req.bc.Task) + model.tasks.set(newT.Id, newT) } return model, nil diff --git a/test/dst/dst.go b/test/dst/dst.go index 4318c978..ae35d413 100644 --- a/test/dst/dst.go +++ b/test/dst/dst.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "math/rand" + "sort" "strconv" "time" @@ -15,6 +16,7 @@ import ( "github.com/resonatehq/resonate/internal/kernel/system" "github.com/resonatehq/resonate/internal/kernel/t_api" "github.com/resonatehq/resonate/internal/util" + "github.com/resonatehq/resonate/pkg/promise" "github.com/resonatehq/resonate/pkg/task" ) @@ -23,6 +25,7 @@ type DST struct { generator *Generator validator *Validator bcValidator *BcValidator + partitions [][]porcupine.Operation // set by Partition in the porcupine model } type Config struct { @@ -30,6 +33,7 @@ type Config struct { Timeout time.Duration VisualizationPath string Verbose bool + PrintOps bool TimeElapsedPerTick int64 
TimeoutTicks int64 ReqsPerTick func() int @@ -52,13 +56,6 @@ const ( type Partition int -const ( - Promise Partition = iota - Schedule - Lock - Task -) - type Req struct { kind Kind time int64 @@ -73,8 +70,16 @@ type Res struct { err error } +type BcKind int + +const ( + Task BcKind = iota + Notify +) + type Backchannel struct { - Task *task.Task + Task *task.Task + Promise *promise.Promise } func New(r *rand.Rand, config *Config) *DST { @@ -96,7 +101,6 @@ func (d *DST) Run(r *rand.Rand, api api.API, aio aio.AIO, system *system.System) // promises d.Add(t_api.ReadPromise, d.generator.GenerateReadPromise, d.validator.ValidateReadPromise) - d.Add(t_api.SearchPromises, d.generator.GenerateSearchPromises, d.validator.ValidateSearchPromises) d.Add(t_api.CreatePromise, d.generator.GenerateCreatePromise, d.validator.ValidateCreatePromise) d.Add(t_api.CreatePromiseAndTask, d.generator.GenerateCreatePromiseAndTask, d.validator.ValidateCreatePromiseAndTask) d.Add(t_api.CompletePromise, d.generator.GenerateCompletePromise, d.validator.ValidateCompletePromise) @@ -109,7 +113,6 @@ func (d *DST) Run(r *rand.Rand, api api.API, aio aio.AIO, system *system.System) // schedules d.Add(t_api.ReadSchedule, d.generator.GenerateReadSchedule, d.validator.ValidateReadSchedule) - d.Add(t_api.SearchSchedules, d.generator.GenerateSearchSchedules, d.validator.ValidateSearchSchedules) d.Add(t_api.CreateSchedule, d.generator.GenerateCreateSchedule, d.validator.ValidateCreateSchedule) d.Add(t_api.DeleteSchedule, d.generator.GenerateDeleteSchedule, d.validator.ValidateDeleteSchedule) @@ -125,6 +128,7 @@ func (d *DST) Run(r *rand.Rand, api api.API, aio aio.AIO, system *system.System) // backchannel validators d.bcValidator.AddBcValidator(ValidateTasksWithSameRootPromiseId) + d.bcValidator.AddBcValidator(ValidateNotify) // porcupine ops var ops []porcupine.Operation @@ -139,10 +143,11 @@ func (d *DST) Run(r *rand.Rand, api api.API, aio aio.AIO, system *system.System) req := req reqTime := time - 
req.Tags = map[string]string{ - "id": id, - "name": req.Kind.String(), + if req.Tags == nil { + req.Tags = make(map[string]string) } + req.Tags["id"] = id + req.Tags["name"] = req.Kind.String() api.EnqueueSQE(&bus.SQE[t_api.Request, t_api.Response]{ Submission: req, @@ -152,8 +157,10 @@ func (d *DST) Run(r *rand.Rand, api api.API, aio aio.AIO, system *system.System) resTime = resTime - 1 // subtract 1 to ensure tick timeframes don't overlap } - // log - slog.Info("DST", "t", fmt.Sprintf("%d|%d", reqTime, resTime), "id", id, "req", req, "res", res, "err", err) + if d.config.PrintOps { + // log + slog.Info("DST", "t", fmt.Sprintf("%d|%d", reqTime, resTime), "id", id, "req", req, "res", res, "err", err) + } // extract cursors for subsequent requests if err == nil { @@ -183,44 +190,6 @@ func (d *DST) Run(r *rand.Rand, api api.API, aio aio.AIO, system *system.System) Input: &Req{Op, reqTime, req, nil}, Output: &Res{Op, resTime, res, err}, }) - - // Warning: - // A CreatePromiseAndTask request applies to two partitions, the - // promise partition and the task partition. Merging the - // partitions results in long checking time, so as a workaround - // we create an independent CreatePromise request. 
The mapping - // of requests to partitions is as follows: - // CreatePromise -> p partition - // CreatePromiseAndTask -> t partition - if req.Kind == t_api.CreatePromiseAndTask { - j++ - - req = &t_api.Request{ - Kind: t_api.CreatePromise, - Tags: req.Tags, - CreatePromise: req.CreatePromiseAndTask.Promise, - } - - if res != nil { - res = &t_api.Response{ - Kind: t_api.CreatePromise, - Tags: res.Tags, - CreatePromise: &t_api.CreatePromiseResponse{ - Status: res.CreatePromiseAndTask.Status, - Promise: res.CreatePromiseAndTask.Promise, - }, - } - } - - ops = append(ops, porcupine.Operation{ - ClientId: int(j % d.config.MaxReqsPerTick), - Call: reqTime, - Return: resTime, - Input: &Req{Op, reqTime, req, nil}, - Output: &Res{Op, resTime, res, err}, - }) - } - j++ }, }) @@ -245,9 +214,15 @@ func (d *DST) Run(r *rand.Rand, api api.API, aio aio.AIO, system *system.System) // our model has been updated via the backchannel counter := obj.Counter - r.Intn(2) + // The ProcessId is always the taskId, which means each task + // is always claimed by a different process, which means + // that when heartbeating there will always be a single task at + // most + // add claim req to generator d.generator.AddRequest(&t_api.Request{ Kind: t_api.ClaimTask, + Tags: map[string]string{"partitionId": obj.RootPromiseId}, ClaimTask: &t_api.ClaimTaskRequest{ Id: obj.Id, Counter: counter, @@ -255,6 +230,10 @@ func (d *DST) Run(r *rand.Rand, api api.API, aio aio.AIO, system *system.System) Ttl: RangeIntn(r, 1000, 5000), }, }) + case *promise.Promise: + bc = &Backchannel{ + Promise: obj, + } default: panic("invalid backchannel type") } @@ -304,7 +283,7 @@ func (d *DST) Run(r *rand.Rand, api api.API, aio aio.AIO, system *system.System) case porcupine.Illegal: slog.Error("DST is non linearizable, run with -v flag for more information", "v", d.config.Verbose) if d.config.Verbose { - d.logNonLinearizable(ops, history) + d.logPossibleError(history) } case porcupine.Unknown: 
slog.Error("DST timed out before linearizability could be determined") @@ -313,51 +292,69 @@ func (d *DST) Run(r *rand.Rand, api api.API, aio aio.AIO, system *system.System) return result == porcupine.Ok } -func (d *DST) logNonLinearizable(ops []porcupine.Operation, history porcupine.LinearizationInfo) { - // log the linearizations - linearizations := history.PartialLinearizationsOperations() - util.Assert(len(linearizations) == 4, "linearizations must be equal to the number of partitions") +func (d *DST) logPossibleError(history porcupine.LinearizationInfo) { + // What is printed here and what is visualized in the dst.html diagram might not match. + // this is a best effort to preserve the possible validation that failed. + fmt.Println("====== Possible errors ======") + + linearizationsPartitions := history.PartialLinearizationsOperations() // check each parition individually - // partitions are in order: promise, schedule, lock - for i, p := range []Partition{Promise, Schedule, Lock} { - util.Assert(len(linearizations[i]) > 0, "partition must have at least one linearization") - - // take the first linearization - linearization := linearizations[i][0] - - // determine the next operation that breaks linearizability - if next, ok := next(linearization, ops, p); ok { - // add the next operation to the linearization, so that we can log - // the non linearizable path - linearization = append(linearization, next) - } else { - // if no op is found, we can assume the partition is linearizable + // partitions are in the order they were given to porcupine + for i, partiton := range d.partitions { + util.Assert(len(linearizationsPartitions[i]) > 0, "partition must have at least one linearization") + + // take the first (and we assume, by empirical evidence, the only) linearization + linearization := linearizationsPartitions[i][0] + + // if the linearization includes all the operations in the partition, all good + if len(partiton) == len(linearization) { continue } - // re run and 
log operations - d.log(linearization) + op := nextFailure(linearization, partiton) + d.logError(linearization, op) } } -func (d *DST) log(ops []porcupine.Operation) { +func (d *DST) logError(partialLinearization []porcupine.Operation, lastOp porcupine.Operation) { // create a new model model := NewModel() // re feed operations through model - for i, op := range ops { + for _, op := range partialLinearization { req := op.Input.(*Req) res := op.Output.(*Res) var err error // step through the model (again) - model, err = d.Step(model, req.time, res.time, req.req, res.res, res.err) - slog.Info("DST", "t", fmt.Sprintf("%d|%d", req.time, res.time), "id", req.req.Tags["id"], "req", req.req, "res", res.res, "err", err) + if req.kind == Op { + model, err = d.Step(model, req.time, res.time, req.req, res.res, res.err) + } else { + model, err = d.BcStep(model, req.time, res.time, req) + } + util.Assert(err == nil, "Only the last operation must result in error") + } - util.Assert(i != len(ops)-1 || err != nil, "the last operation must result in an error") + req := lastOp.Input.(*Req) + res := lastOp.Output.(*Res) + var err error + if req.kind == Op { + _, err = d.Step(model, req.time, res.time, req.req, res.res, res.err) + fmt.Printf("Op(id=%s, t=%d|%d), req=%v, res=%v\n", req.req.Tags["id"], req.time, res.time, req.req, res.res) + } else { + _, err = d.BcStep(model, req.time, res.time, req) + var obj any + if req.bc.Task != nil { + obj = req.bc.Task + } else if req.bc.Promise != nil { + obj = req.bc.Promise + } + fmt.Printf("Op(id=backchannel, t=%d|%d), %v\n", req.time, res.time, obj) } + + fmt.Printf("err=%v\n\n", err) } func (d *DST) Model() porcupine.Model { @@ -366,29 +363,28 @@ func (d *DST) Model() porcupine.Model { return NewModel() }, Partition: func(history []porcupine.Operation) [][]porcupine.Operation { - p := []porcupine.Operation{} - s := []porcupine.Operation{} - l := []porcupine.Operation{} - t := []porcupine.Operation{} + partitions := 
make(map[string][]porcupine.Operation) for _, op := range history { req := op.Input.(*Req) + partitionKey := partition(req) + partitions[partitionKey] = append(partitions[partitionKey], op) + } - switch partition(req) { - case Promise: - p = append(p, op) - case Schedule: - s = append(s, op) - case Lock: - l = append(l, op) - // TODO(avillega): Temporaly disable validations over tasks - // TODO(avillega): Add validations over notify tasks. this will require to have the task and promise accesible in the same partition. - // case Task: - // t = append(t, op) - } + // Get sorted keys to iterate over the partitions in a deterministic way + keys := make([]string, 0, len(partitions)) + for k := range partitions { + keys = append(keys, k) + } + sort.Strings(keys) + + var result [][]porcupine.Operation + for _, key := range keys { + result = append(result, partitions[key]) } - return [][]porcupine.Operation{p, s, l, t} + d.partitions = result + return result }, Step: func(state, input, output interface{}) (bool, interface{}) { model := state.(*Model) @@ -405,13 +401,11 @@ func (d *DST) Model() porcupine.Model { } return true, updatedModel case Bc: - // TODO(avillega): Temporaly disable validations over tasks - // updatedModel, err := d.BcStep(model, req) - // if err != nil { - // fmt.Println(err.Error()) - // return false, model - // } - return true, model + updatedModel, err := d.BcStep(model, req.time, res.time, req) + if err != nil { + return false, model + } + return true, updatedModel default: panic(fmt.Sprintf("unknown request kind: %d", req.kind)) } @@ -440,7 +434,13 @@ func (d *DST) Model() porcupine.Model { return fmt.Sprintf("%s | %s → %d", req.req.Tags["id"], req.req, status) case Bc: - return fmt.Sprintf("Backchannel | %s", req.bc.Task) + if req.bc.Task != nil { + return fmt.Sprintf("Backchannel | %s", req.bc.Task) + } else if req.bc.Promise != nil { + return fmt.Sprintf("Backchannel | %s", req.bc.Promise) + } else { + return "Backchannel | unknown(possible 
error)" + } default: panic(fmt.Sprintf("unknown request kind: %d", req.kind)) } @@ -449,9 +449,15 @@ func (d *DST) Model() porcupine.Model { model := state.(*Model) switch { - case len(*model.promises) > 0 || len(*model.callbacks) > 0: + case len(*model.promises) > 0 || len(*model.callbacks) > 0 || len(*model.tasks) > 0: var promises string for _, p := range *model.promises { + var completedOn string + if p.value.CompletedOn == nil { + completedOn = "--" + } else { + completedOn = fmt.Sprintf("%d", *p.value.CompletedOn) + } promises = promises + fmt.Sprintf(` %s @@ -459,8 +465,9 @@ func (d *DST) Model() porcupine.Model { %s %s %d + %s - `, p.value.Id, p.value.State, p.value.IdempotencyKeyForCreate, p.value.IdempotencyKeyForComplete, p.value.Timeout) + `, p.value.Id, p.value.State, p.value.IdempotencyKeyForCreate, p.value.IdempotencyKeyForComplete, p.value.Timeout, completedOn) } var callbacks string @@ -473,49 +480,87 @@ func (d *DST) Model() porcupine.Model { `, c.value.Id, c.value.PromiseId) } + var tasks string + for _, t := range *model.tasks { + tasks = tasks + fmt.Sprintf(` + + %s + %s + %s + %d + %d + %d + + `, t.value.Id, t.value.State, t.value.RootPromiseId, t.value.ExpiresAt, t.value.Timeout, *t.value.CreatedOn) + } return fmt.Sprintf(` - - - - - - - - - - - - + + + + + + + + + + + + + + + + +
PromisesCallbacks
- - - - - - - - - - - - %s - -
idstateikeyCreateikeyCompletetimeout
-
- - - - - - - - - %s - -
idpromiseId
-
PromisesTasks
+ + + + + + + + + + + + + %s + +
idstateikeyCreateikeyCompletetimeoutcompletedOn
+
+ + + + + + + + + + + + + %s + +
idstaterootPromiseIdexpiresAttimeoutcreatedOn
+
+ + + + + + + + + + + + %s + +
Callbacks
idpromiseId
+
- `, promises, callbacks) + `, promises, tasks, callbacks) case len(*model.schedules) > 0: var schedules string for _, s := range *model.schedules { @@ -594,51 +639,6 @@ func (d *DST) Model() porcupine.Model { `, locks) - case len(*model.tasks) > 0: - var tasks string - for _, t := range *model.tasks { - tasks = tasks + fmt.Sprintf(` - - %s - %s - %s - %d - %d - %d - - `, t.value.Id, util.SafeDeref(t.value.ProcessId), t.value.State, t.value.Counter, t.value.ExpiresAt, t.value.Timeout) - } - - return fmt.Sprintf(` - - - - - - - - - - - -
Tasks
- - - - - - - - - - - - - %s - -
idprocessIdstatecounterexpiresAttimeout
-
- `, tasks) default: return "" } @@ -672,13 +672,13 @@ func (d *DST) Step(model *Model, reqTime int64, resTime int64, req *t_api.Reques return d.validator.Validate(model, reqTime, resTime, req, res) } -func (d *DST) BcStep(model *Model, req *Req) (*Model, error) { +func (d *DST) BcStep(model *Model, reqTime int64, resTime int64, req *Req) (*Model, error) { util.Assert(req.kind == Bc, "Backchannel step can only be taken if req is of kind Bc") - if req.bc.Task == nil { + if req.bc.Task == nil && req.bc.Promise == nil { return model, nil } - return d.bcValidator.Validate(model, req) + return d.bcValidator.Validate(model, reqTime, resTime, req) } func (d *DST) Time(t int64) int64 { @@ -699,52 +699,49 @@ func (d *DST) String() string { // Helper functions -func partition(req *Req) Partition { +func partition(req *Req) string { switch req.kind { case Op: - switch req.req.Kind { - case t_api.ReadPromise, t_api.SearchPromises, t_api.CreatePromise, t_api.CompletePromise, t_api.CreateCallback, t_api.CreateSubscription: - return Promise - case t_api.ReadSchedule, t_api.SearchSchedules, t_api.CreateSchedule, t_api.DeleteSchedule: - return Schedule - case t_api.AcquireLock, t_api.ReleaseLock, t_api.HeartbeatLocks: - return Lock - case t_api.ClaimTask, t_api.CompleteTask, t_api.HeartbeatTasks, t_api.CreatePromiseAndTask: - return Task - default: - panic(fmt.Sprintf("unknown request kind: %s", req.req.Kind)) + partition, exists := req.req.Tags["partitionId"] + if !exists { + panic(fmt.Sprintf("Missing partitionId for request %v", req.req)) } + return partition case Bc: - return Task + if req.bc.Task != nil { + return req.bc.Task.RootPromiseId + } else if req.bc.Promise != nil { + return req.bc.Promise.Id + } else { + panic("unknown backchannel type") + } default: panic(fmt.Sprintf("unknown request kind: %d", req.kind)) } } -func next(linearizable []porcupine.Operation, ops []porcupine.Operation, p Partition) (porcupine.Operation, bool) { +// Find the first Operation if any that 
is not part of a partial linearization +// by comparing our partition with the linearization +func nextFailure(linearizationOps []porcupine.Operation, partitionOps []porcupine.Operation) porcupine.Operation { // convert to map for quick lookup - linearizableMap := map[porcupine.Operation]bool{} - for _, op := range linearizable { - linearizableMap[op] = true + linearizableMap := map[*Req]bool{} + for _, op := range linearizationOps { + req := op.Input.(*Req) + linearizableMap[req] = true } - for _, op := range ops { + for _, op := range partitionOps { req := op.Input.(*Req) - - // if req is part of a different partition, skip - if partition(req) != p { - continue - } - // if req is part of the linearizable path, skip - if _, ok := linearizableMap[op]; ok { + if _, ok := linearizableMap[req]; ok { continue } // ops are ordered by time, so the first op is not part of the // linearizable path should break the model - return op, true + return op } - return porcupine.Operation{}, false + util.Assert(false, "There must be an operation not included in the linearization") + return porcupine.Operation{} } diff --git a/test/dst/generator.go b/test/dst/generator.go index ce136ce9..c23d7b2f 100644 --- a/test/dst/generator.go +++ b/test/dst/generator.go @@ -121,6 +121,7 @@ func (g *Generator) GenerateReadPromise(r *rand.Rand, t int64) *t_api.Request { return &t_api.Request{ Kind: t_api.ReadPromise, + Tags: map[string]string{"partitionId": id}, ReadPromise: &t_api.ReadPromiseRequest{ Id: id, }, @@ -176,6 +177,7 @@ func (g *Generator) GenerateCreatePromise(r *rand.Rand, t int64) *t_api.Request return &t_api.Request{ Kind: t_api.CreatePromise, + Tags: map[string]string{"partitionId": id}, CreatePromise: &t_api.CreatePromiseRequest{ Id: id, IdempotencyKey: idempotencyKey, @@ -197,6 +199,7 @@ func (g *Generator) GenerateCreatePromiseAndTask(r *rand.Rand, t int64) *t_api.R return &t_api.Request{ Kind: t_api.CreatePromiseAndTask, + Tags: map[string]string{"partitionId": 
req.CreatePromise.Id}, CreatePromiseAndTask: &t_api.CreatePromiseAndTaskRequest{ Promise: req.CreatePromise, Task: &t_api.CreateTaskRequest{ @@ -219,6 +222,7 @@ func (g *Generator) GenerateCompletePromise(r *rand.Rand, t int64) *t_api.Reques return &t_api.Request{ Kind: t_api.CompletePromise, + Tags: map[string]string{"partitionId": id}, CompletePromise: &t_api.CompletePromiseRequest{ Id: id, IdempotencyKey: idempotencyKey, @@ -239,6 +243,7 @@ func (g *Generator) GenerateCreateCallback(r *rand.Rand, t int64) *t_api.Request return &t_api.Request{ Kind: t_api.CreateCallback, + Tags: map[string]string{"partitionId": promiseId}, CreateCallback: &t_api.CreateCallbackRequest{ Id: id, PromiseId: promiseId, @@ -258,6 +263,7 @@ func (g *Generator) GenerateCreateSubscription(r *rand.Rand, t int64) *t_api.Req return &t_api.Request{ Kind: t_api.CreateSubscription, + Tags: map[string]string{"partitionId": promiseId}, CreateSubscription: &t_api.CreateSubscriptionRequest{ Id: id, PromiseId: promiseId, @@ -274,6 +280,7 @@ func (g *Generator) GenerateReadSchedule(r *rand.Rand, t int64) *t_api.Request { return &t_api.Request{ Kind: t_api.ReadSchedule, + Tags: map[string]string{"partitionId": id}, ReadSchedule: &t_api.ReadScheduleRequest{ Id: id, }, @@ -304,15 +311,19 @@ func (g *Generator) GenerateCreateSchedule(r *rand.Rand, t int64) *t_api.Request id := g.scheduleId(r) cron := fmt.Sprintf("%d * * * *", r.Intn(60)) tags := g.tags(r) + // do not create schedules that can invoke promises. 
+ delete(tags, "resonate:invoke") idempotencyKey := g.idempotencyKey(r) promiseTimeout := RangeInt63n(r, t, g.ticks*g.timeElapsedPerTick) promiseHeaders := g.headers(r) promiseData := g.dataSet[r.Intn(len(g.dataSet))] promiseTags := g.tags(r) + delete(promiseTags, "resonate:invoke") return &t_api.Request{ Kind: t_api.CreateSchedule, + Tags: map[string]string{"partitionId": id}, CreateSchedule: &t_api.CreateScheduleRequest{ Id: id, Description: "", @@ -332,6 +343,7 @@ func (g *Generator) GenerateDeleteSchedule(r *rand.Rand, t int64) *t_api.Request return &t_api.Request{ Kind: t_api.DeleteSchedule, + Tags: map[string]string{"partitionId": id}, DeleteSchedule: &t_api.DeleteScheduleRequest{ Id: id, }, @@ -348,6 +360,7 @@ func (g *Generator) GenerateAcquireLock(r *rand.Rand, t int64) *t_api.Request { return &t_api.Request{ Kind: t_api.AcquireLock, + Tags: map[string]string{"partitionId": "___locks___"}, AcquireLock: &t_api.AcquireLockRequest{ ResourceId: resourceId, ExecutionId: executionId, @@ -363,6 +376,7 @@ func (g *Generator) GenerateReleaseLock(r *rand.Rand, t int64) *t_api.Request { return &t_api.Request{ Kind: t_api.ReleaseLock, + Tags: map[string]string{"partitionId": "___locks___"}, ReleaseLock: &t_api.ReleaseLockRequest{ ResourceId: resourceId, ExecutionId: executionId, @@ -375,6 +389,7 @@ func (g *Generator) GenerateHeartbeatLocks(r *rand.Rand, t int64) *t_api.Request return &t_api.Request{ Kind: t_api.HeartbeatLocks, + Tags: map[string]string{"partitionId": "___locks___"}, HeartbeatLocks: &t_api.HeartbeatLocksRequest{ ProcessId: processId, }, @@ -387,7 +402,7 @@ func (g *Generator) GenerateClaimTask(r *rand.Rand, t int64) *t_api.Request { req := g.pop(r, t_api.ClaimTask) if req != nil { - g.nextTasks(r, req.ClaimTask.Id, req.ClaimTask.ProcessId, req.ClaimTask.Counter) + g.nextTasks(r, req.ClaimTask.Id, req.ClaimTask.ProcessId, req.ClaimTask.Counter, req.Tags) } return req @@ -450,7 +465,7 @@ func (g *Generator) pop(r *rand.Rand, kind t_api.Kind) 
*t_api.Request { return req } -func (g *Generator) nextTasks(r *rand.Rand, id string, pid string, counter int) { +func (g *Generator) nextTasks(r *rand.Rand, id string, pid string, counter int, reqTags map[string]string) { // seed the "next" requests, // sometimes we deliberately do nothing for i := 0; i < r.Intn(3); i++ { @@ -458,6 +473,7 @@ func (g *Generator) nextTasks(r *rand.Rand, id string, pid string, counter int) case 0: g.AddRequest(&t_api.Request{ Kind: t_api.ClaimTask, + Tags: reqTags, ClaimTask: &t_api.ClaimTaskRequest{ Id: id, ProcessId: pid, @@ -468,6 +484,7 @@ func (g *Generator) nextTasks(r *rand.Rand, id string, pid string, counter int) case 1: g.AddRequest(&t_api.Request{ Kind: t_api.CompleteTask, + Tags: reqTags, CompleteTask: &t_api.CompleteTaskRequest{ Id: id, Counter: counter, @@ -476,6 +493,7 @@ func (g *Generator) nextTasks(r *rand.Rand, id string, pid string, counter int) case 2: g.AddRequest(&t_api.Request{ Kind: t_api.HeartbeatTasks, + Tags: reqTags, HeartbeatTasks: &t_api.HeartbeatTasksRequest{ ProcessId: pid, }, diff --git a/test/dst/validator.go b/test/dst/validator.go index 138856c8..77b2a11b 100644 --- a/test/dst/validator.go +++ b/test/dst/validator.go @@ -64,6 +64,7 @@ func (v *Validator) ValidateReadPromise(model *Model, reqTime int64, resTime int if res.ReadPromise.Promise.State == promise.GetTimedoutState(p) && resTime >= p.Timeout { model = model.Copy() model.promises.set(req.ReadPromise.Id, res.ReadPromise.Promise) + completeRelatedTasks(model, p.Id, reqTime) } else { return model, fmt.Errorf("invalid state transition (%s -> %s) for promise '%s'", p.State, res.ReadPromise.Promise.State, req.ReadPromise.Id) } @@ -127,12 +128,6 @@ func (v *Validator) ValidateCreatePromise(model *Model, reqTime int64, resTime i } func (v *Validator) ValidateCreatePromiseAndTask(model *Model, reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) (*Model, error) { - // Do NOT validate the create promise, as a workaround we create a 
- // "duplicate" CreatePromise request and map the requests as - // follows: - // CreatePromise -> p partition - // CreatePromiseAndTask -> t partition - switch res.CreatePromiseAndTask.Status { case t_api.StatusCreated: if model.tasks.get(res.CreatePromiseAndTask.Task.Id) != nil { @@ -143,7 +138,12 @@ func (v *Validator) ValidateCreatePromiseAndTask(model *Model, reqTime int64, re model.tasks.set(res.CreatePromiseAndTask.Task.Id, res.CreatePromiseAndTask.Task) } - return model, nil + promiseRes := &t_api.CreatePromiseResponse{ + Status: res.CreatePromiseAndTask.Status, + Promise: res.CreatePromiseAndTask.Promise, + } + + return v.validateCreatePromise(model, reqTime, resTime, req.CreatePromiseAndTask.Promise, promiseRes) } func (v *Validator) validateCreatePromise(model *Model, reqTime int64, resTime int64, req *t_api.CreatePromiseRequest, res *t_api.CreatePromiseResponse) (*Model, error) { @@ -172,6 +172,7 @@ func (v *Validator) validateCreatePromise(model *Model, reqTime int64, resTime i if res.Promise.State == promise.GetTimedoutState(p) && resTime >= p.Timeout { model = model.Copy() model.promises.set(req.Id, res.Promise) + completeRelatedTasks(model, p.Id, reqTime) } else { return model, fmt.Errorf("invalid state transition (%s -> %s) for promise '%s'", p.State, res.Promise.State, req.Id) } @@ -216,6 +217,7 @@ func (v *Validator) ValidateCompletePromise(model *Model, reqTime int64, resTime // update model state model = model.Copy() model.promises.set(req.CompletePromise.Id, res.CompletePromise.Promise) + completeRelatedTasks(model, p.Id, reqTime) return model, nil case t_api.StatusOK: if p == nil { @@ -229,6 +231,7 @@ func (v *Validator) ValidateCompletePromise(model *Model, reqTime int64, resTime if res.CompletePromise.Promise.State == promise.GetTimedoutState(p) && resTime >= p.Timeout { model = model.Copy() model.promises.set(req.CompletePromise.Id, res.CompletePromise.Promise) + completeRelatedTasks(model, p.Id, reqTime) } else { return model, 
fmt.Errorf("invalid state transition (%s -> %s) for promise '%s'", p.State, res.CompletePromise.Promise.State, req.CompletePromise.Id) } @@ -320,11 +323,12 @@ func (v *Validator) ValidateCreateCallback(model *Model, reqTime int64, resTime if resTime >= p.Timeout { model = model.Copy() model.promises.set(p.Id, res.CreateCallback.Promise) + completeRelatedTasks(model, p.Id, reqTime) return model, nil } // otherwise verify the callback was created previously - callbackId := fmt.Sprintf("%s.%s", p.Id, req.CreateCallback.Id) + callbackId := fmt.Sprintf("__resume:%s:%s", req.CreateCallback.RootPromiseId, req.CreateCallback.PromiseId) if model.callbacks.get(callbackId) == nil { return model, fmt.Errorf("callback '%s' must exist", callbackId) } @@ -385,7 +389,7 @@ func (v *Validator) ValidateCreateSubscription(model *Model, reqTime int64, resT } // otherwise verify the subscription was created previously - subscriptionId := fmt.Sprintf("%s.%s", p.Id, req.CreateSubscription.Id) + subscriptionId := fmt.Sprintf("__notify:%s:%s", p.Id, req.CreateSubscription.Id) if model.callbacks.get(subscriptionId) == nil { return model, fmt.Errorf("subscription '%s' must exist", subscriptionId) } @@ -616,7 +620,29 @@ func (v *Validator) ValidateClaimTask(model *Model, reqTime int64, resTime int64 return model, fmt.Errorf("task '%s' does not exist", req.ClaimTask.Id) } if !t.State.In(task.Completed|task.Timedout) && t.Timeout >= resTime { - return model, fmt.Errorf("task '%s' not completed", req.ClaimTask.Id) + // This could happen if the promise timed out + p := model.promises.get(t.RootPromiseId) + + if !promise.GetTimedoutState(p).In(promise.Pending) && resTime >= p.Timeout { + model = model.Copy() + newP := promise.Promise{ + Id: p.Id, + State: promise.GetTimedoutState(p), + Param: p.Param, + Value: p.Value, + Timeout: p.Timeout, + IdempotencyKeyForCreate: p.IdempotencyKeyForCreate, + IdempotencyKeyForComplete: p.IdempotencyKeyForComplete, + Tags: p.Tags, + CreatedOn: p.CreatedOn, +
CompletedOn: util.ToPointer(p.Timeout), + SortId: p.SortId, + } + model.promises.set(p.Id, &newP) + completeRelatedTasks(model, p.Id, reqTime) + } else { + return model, fmt.Errorf("task '%s' state not completed", req.ClaimTask.Id) + } } return model, nil case t_api.StatusTaskInvalidCounter: @@ -655,12 +681,33 @@ func (v *Validator) ValidateCompleteTask(model *Model, reqTime int64, resTime in model = model.Copy() model.tasks.set(req.CompleteTask.Id, res.CompleteTask.Task) return model, nil - case t_api.StatusTaskAlreadyCompleted: + case t_api.StatusOK: if t == nil { return model, fmt.Errorf("task '%s' does not exist", req.CompleteTask.Id) } if !t.State.In(task.Completed|task.Timedout) && t.Timeout >= resTime { - return model, fmt.Errorf("task '%s' state not completed", req.CompleteTask.Id) + // This could happen if the promise timed out + p := model.promises.get(t.RootPromiseId) + if !promise.GetTimedoutState(p).In(promise.Pending) && resTime >= p.Timeout { + model = model.Copy() + newP := promise.Promise{ + Id: p.Id, + State: promise.GetTimedoutState(p), + Param: p.Param, + Value: p.Value, + Timeout: p.Timeout, + IdempotencyKeyForCreate: p.IdempotencyKeyForCreate, + IdempotencyKeyForComplete: p.IdempotencyKeyForComplete, + Tags: p.Tags, + CreatedOn: p.CreatedOn, + CompletedOn: util.ToPointer(p.Timeout), + SortId: p.SortId, + } + model.promises.set(p.Id, &newP) + completeRelatedTasks(model, p.Id, reqTime) + } else { + return model, fmt.Errorf("task '%s' state not completed", req.CompleteTask.Id) + } } return model, nil case t_api.StatusTaskInvalidCounter: @@ -724,3 +771,27 @@ func (v *Validator) ValidateHeartbeatTasks(model *Model, reqTime int64, resTime return model, fmt.Errorf("unexpected response status '%d'", res.HeartbeatTasks.Status) } } + +// This function modifies the model in place; make sure you have called +// model.Copy() before calling this function +func completeRelatedTasks(model *Model, promiseId string, _ int64) { + new_tasks := []task.Task{} + rp
:= model.promises.get(promiseId) + rpCompletedOn := util.SafeDeref(rp.CompletedOn) + for _, t := range *model.tasks { + if t.value.State.In(task.Completed | task.Timedout) { + continue + } + // A task created after the promise was completed (resumes) must not be completed + if t.value.RootPromiseId == promiseId && *t.value.CreatedOn < rpCompletedOn { + new_t := *t.value // Make a copy to avoid modifying the model + new_t.State = task.Completed + new_t.CompletedOn = &rpCompletedOn + new_tasks = append(new_tasks, new_t) + } + } + + for _, new_t := range new_tasks { + model.tasks.set(new_t.Id, &new_t) + } +}