diff --git a/.gitignore b/.gitignore index f165fa7..8ceafbb 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ __pycache__ .hypothesis +*.log +*.db \ No newline at end of file diff --git a/go.mod b/go.mod index 6b35a39..e5ff47b 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,9 @@ go 1.22 toolchain go1.23.3 require ( + github.com/aws/aws-sdk-go v1.55.5 github.com/google/go-containerregistry v0.20.2 + github.com/mattn/go-sqlite3 v1.14.24 gopkg.in/yaml.v3 v3.0.1 ) @@ -25,6 +27,7 @@ require ( github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/klauspost/compress v1.17.10 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect diff --git a/go.sum b/go.sum index 3fb0b66..e4a8eab 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25 github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= +github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= @@ -43,6 +45,10 @@ github.com/google/go-containerregistry v0.20.2 h1:B1wPJ1SN/S7pB+ZAimcciVD+r+yV/l github.com/google/go-containerregistry v0.20.2/go.mod h1:z38EKdKh4h7IP2gSfUUqEvalZBqs6AoLeWfUy34nQC8= github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0 h1:Wqo399gCIufwto+VfwCSvsnfGpF/w5E9CNxSwbpD6No= github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0/go.mod h1:qmOFXW2epJhM0qSnUUYpldc7gVz2KMQwJ/QYCDIa7XU= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.17.10 h1:oXAz+Vh0PMUvJczoi+flxpnBEPxoER1IaAnU/NMPtT0= @@ -51,6 +57,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= +github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= @@ -148,6 +156,8 @@ google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHh gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/reefagent/agent.go b/reefagent/agent.go new file mode 100644 index 0000000..aa72e7d --- /dev/null +++ b/reefagent/agent.go @@ -0,0 +1,84 @@ +package reefagent + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "time" +) + +type Agent struct { + Id string + Queue string + Job *Job + ServiceHost string +} + +func (a *Agent) Start() { + // start a loop to ping the server every 1 second on a goroutine until ping returns a job + for { + success, job := a.Ping() + if success { + a.AcquireAndRunJob(job) + return + } + time.Sleep(1 * time.Second) + } +} + +func (a *Agent) Ping() (bool, *Job) { + // call POST /ping endpoint with agent ID and queue + url := fmt.Sprintf("%s/ping?agentId=%s&queue=%s", a.ServiceHost, a.Id, a.Queue) + resp, err := http.Get(url) + + if err != nil { + log.Fatal(err) + } + var response struct { + JobId string `json:"jobId"` + Commands []string `json:"commands"` + } + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + log.Fatal(err) + } + job := &Job{ + Id: response.JobId, + Commands: response.Commands, + } + defer resp.Body.Close() + // if response is 200, return the job and true + // otherwise return nil and false + if resp.StatusCode == 200 { + return true, job + } + return false, nil +} + +func (a *Agent) AcquireAndRunJob(job *Job) { + success := a.AcquireJob(job.Id) + if !success { + fmt.Println("Failed to acquire job") + return + } + a.Job = job + a.RunJob() +} + +func (a *Agent) RunJob() { + jr := NewJobRunner(a.Job, a.ServiceHost) + jr.Run() +} + +func (a *Agent) AcquireJob(jobId string) bool { + // Send POST request to acquire the job + url := fmt.Sprintf("%s/job/acquire?jobId=%s&agentId=%s", a.ServiceHost, jobId, a.Id) + resp, err := http.Post(url, "application/json", nil) + if err != nil { + log.Printf("Error acquiring job: %v", err) + return false + } + defer resp.Body.Close() + + return resp.StatusCode == 200 +} diff --git a/reefagent/job.go b/reefagent/job.go new file mode 100644 index 0000000..c99fb10 --- /dev/null +++ b/reefagent/job.go @@ -0,0 +1,6 @@ +package reefagent + +type Job struct { + Id string + Commands []string +} diff --git a/reefagent/job_runner.go b/reefagent/job_runner.go new file mode 100644 index 0000000..093ead1 --- /dev/null +++ b/reefagent/job_runner.go @@ -0,0 +1,37 @@ +package reefagent + +import ( + "fmt" + "io" + "os" + "os/exec" + "strings" +) + +type JobRunner struct { + logStreamer *LogStreamer + job *Job +} + +func (jr *JobRunner) Run() { + jr.logStreamer.Start() + for _, command := range jr.job.Commands { + parts := strings.Split(command, " ") + cmd := exec.Command(parts[0], parts[1:]...) + cmd.Stdout = jr.logStreamer.logsWriter + cmd.Stderr = jr.logStreamer.logsWriter + if err := cmd.Run(); err != nil { + fmt.Printf("Error running command %s: %v\n", command, err) + } + } + jr.logStreamer.Stop() +} + +func NewJobRunner(job *Job, serviceHost string) *JobRunner { + jr := &JobRunner{ + job: job, + logStreamer: NewLogStreamer(job.Id, serviceHost), + } + jr.logStreamer.logsWriter = io.MultiWriter(jr.logStreamer.logs, os.Stdout) + return jr +} diff --git a/reefagent/log_streamer.go b/reefagent/log_streamer.go new file mode 100644 index 0000000..9385bab --- /dev/null +++ b/reefagent/log_streamer.go @@ -0,0 +1,149 @@ +package reefagent + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "os" + "sync" + "time" +) + +type LogChunk struct { + Data []byte + Sequence int +} + +type Buffer struct { + buf []byte +} + +func (l *Buffer) Write(b []byte) (int, error) { + l.buf = append(l.buf, b...) + return len(b), nil +} + +func (l *Buffer) ReadAndFlush() []byte { + buf := l.buf + l.buf = []byte{} + return buf +} + +type LogStreamer struct { + jobId string + logsWriter io.Writer + logs *Buffer + logOrder int + queue chan LogChunk + maxSize int + active bool + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + serviceHost string +} + +func NewLogStreamer(jobId string) *LogStreamer { + ctx, cancel := context.WithCancel(context.Background()) + return &LogStreamer{ + jobId: jobId, + logs: &Buffer{}, + logOrder: 0, + maxSize: 10 * 1024 * 1024, // 10MB + active: false, + queue: make(chan LogChunk, 100), + ctx: ctx, + cancel: cancel, + } +} + +func (ls *LogStreamer) Start() { + ls.active = true + ls.wg.Add(2) + go ls.StreamLogs() + go ls.RetrieveAndUploadChunk() +} + +func (ls *LogStreamer) Stop() { + ls.ChunkLogs(ls.logs.ReadAndFlush()) + ls.cancel() + ls.wg.Wait() + close(ls.queue) +} + +// StreamLogs streams the result from the logs and chunks them into the queue +func (ls *LogStreamer) StreamLogs() { + defer ls.wg.Done() + for { + // read the logs and chunk them then push into the queue + logs := ls.logs.ReadAndFlush() + if len(logs) > 0 { + ls.ChunkLogs(logs) + } + select { + case <-time.After(1 * time.Second): + case <-ls.ctx.Done(): + return + } + } +} + +func (ls *LogStreamer) ChunkLogs(data []byte) { + chunkSize := ls.maxSize + for i := 0; i < len(data); i += chunkSize { + chunkData := data[i:min(i+chunkSize, len(data))] + logChunk := LogChunk{Data: chunkData, Sequence: ls.logOrder} + ls.queue <- logChunk + ls.logOrder++ + } +} + +// function that takes chunk from the queue and write it to file +func (ls *LogStreamer) RetrieveAndUploadChunk() { + defer ls.wg.Done() + for { + select { + case chunk, ok := <-ls.queue: + if !ok { + fmt.Println("Queue closed.. exiting") + return + } + // write the chunk to file + ls.WriteToFile(chunk) + // send the chunk to server + ls.UploadChunk(chunk) + case <-ls.ctx.Done(): + return + } + } +} + +func (ls *LogStreamer) WriteToFile(logChunk LogChunk) { + fileName := fmt.Sprintf("logs/%s-%d.log", ls.jobId, logChunk.Sequence) + file, err := os.OpenFile(fileName, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + fmt.Println("Error writing to file", err) + return + } + defer file.Close() + file.Write(logChunk.Data) +} + +func (ls *LogStreamer) UploadChunk(logChunk LogChunk) { + // send the chunk to server + url := fmt.Sprintf("%s/job/logs?jobId=%s&sequence=%d", ls.serviceHost, ls.jobId, logChunk.Sequence) + req, err := http.NewRequest("POST", url, bytes.NewBuffer(logChunk.Data)) + if err != nil { + fmt.Println("Error creating request", err) + return + } + req.Header.Set("Content-Type", "text/plain") + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + fmt.Println("Error uploading chunk", err) + } + defer resp.Body.Close() +} diff --git a/reefd/instance_manager.go b/reefd/instance_manager.go new file mode 100644 index 0000000..0259a8f --- /dev/null +++ b/reefd/instance_manager.go @@ -0,0 +1,196 @@ +package reefd + +import ( + "database/sql" + "fmt" + "log" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ec2" +) + +type InstanceConfiguration struct { + InstanceType string `json:"instance_type"` + AMI string `json:"ami"` +} + +// LaunchRequest represents a row of launch_requests table in the database +type LaunchRequest struct { + Id string `json:"id"` + InstanceConfigName string `json:"instance_config_name"` + DesiredState string `json:"desired_state"` + CurrentState string `json:"current_state"` + InstanceId *string `json:"instance_id,omitempty"` +} + +const ( + defaultRegion = "us-west-2" +) + +type instanceManager struct { + db *sql.DB + ec2Client EC2Client +} + +func (m *instanceManager) getInstanceConfig(instanceConfigName string) (InstanceConfiguration, error) { + var instanceConfig InstanceConfiguration + err := m.db.QueryRow("SELECT instance_type, ami FROM instance_configs WHERE name = ?", instanceConfigName).Scan(&instanceConfig.InstanceType, &instanceConfig.AMI) + if err != nil { + return InstanceConfiguration{}, fmt.Errorf("error querying database: %v", err) + } + return instanceConfig, nil +} + +// launchInstance launches an instance with the given instance type and AMI +func (m *instanceManager) launchInstance(instanceConfigName string) (string, error) { + svc := m.ec2Client + if svc == nil { + return "", fmt.Errorf("failed to get EC2 client") + } + log.Printf("Launching instance with config: %s", instanceConfigName) + instanceConfig, err := m.getInstanceConfig(instanceConfigName) + if err != nil { + return "", fmt.Errorf("error getting instance config: %v", err) + } + instanceType := instanceConfig.InstanceType + ami := instanceConfig.AMI + + runResult, err := svc.RunInstances(&ec2.RunInstancesInput{ + ImageId: aws.String(ami), + InstanceType: aws.String(instanceType), + MinCount: aws.Int64(1), + MaxCount: aws.Int64(1), + TagSpecifications: []*ec2.TagSpecification{{ + ResourceType: aws.String("instance"), + Tags: []*ec2.Tag{{ + Key: aws.String("Name"), + Value: aws.String("Kevin-launch"), + }}, + }}, + }) + + if err != nil { + return "", fmt.Errorf("failed to launch instance: %v", err) + } + instanceID := *runResult.Instances[0].InstanceId + log.Printf("Created instance: %s", instanceID) + return instanceID, nil +} + +// getInstanceState retrieves the current state of the instance from AWS with the given instance ID +func (m *instanceManager) getInstanceState(instanceID string) (string, error) { + log.Printf("Getting instance state for %s", instanceID) + svc := m.ec2Client + if svc == nil { + return "", fmt.Errorf("ec2 client does not exist") + } + + result, err := svc.DescribeInstances(&ec2.DescribeInstancesInput{ + InstanceIds: []*string{aws.String(instanceID)}, + }) + if err != nil { + return "", fmt.Errorf("error describing instance %s: %v", instanceID, err) + } + + if len(result.Reservations) == 0 || len(result.Reservations[0].Instances) == 0 { + return "", fmt.Errorf("no instance found with ID %s", instanceID) + } + + instance := result.Reservations[0].Instances[0] + return *instance.State.Name, nil +} + + +// updateCurrentState updates the current state of the existing instances in the database table +func (m *instanceManager) updateCurrentState() error { + log.Printf("Updating current state of existing instances") + rows, err := m.db.Query("SELECT id, instance_id FROM launch_requests WHERE instance_id IS NOT NULL") + if err != nil { + return fmt.Errorf("error querying database: %v", err) + } + defer rows.Close() + + currentStateMap := make(map[string]string) + for rows.Next() { + var id, instanceID string + if err := rows.Scan(&id, &instanceID); err != nil { + log.Printf("Error scanning row: %s", err) + continue + } + + state, err := m.getInstanceState(instanceID) + if err != nil { + log.Printf("Error getting instance state: %v", err) + continue + } + + currentStateMap[id] = state + } + + for id, currentState := range currentStateMap { + if _, err := m.db.Exec(`UPDATE launch_requests SET current_state = ? WHERE id = ?`, currentState, id); err != nil { + return fmt.Errorf("Error updating current state for each request: %v", err) + } + } + return nil +} + +type EC2Client interface { + RunInstances(*ec2.RunInstancesInput) (*ec2.Reservation, error) + DescribeInstances(*ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) +} + +func getEC2Client() EC2Client { + sess, err := session.NewSession(&aws.Config{Region: aws.String(defaultRegion)}) + if err != nil { + log.Fatalf("Error creating session: %v", err) + } + return ec2.New(sess) +} + +// processLaunchRequests updates the current state of the existing instances and launches new instances for the launch requests that have not been launched yet +func processLaunchRequests(instanceManager *instanceManager) error { + // update the current state of the existing instances + if err := instanceManager.updateCurrentState(); err != nil { + return fmt.Errorf("error updating current state: %v", err) + } + + // query all launch requests where the instance has not been launched yet + log.Printf("Scanning for launch requests with different desired and current states") + rows, err := instanceManager.db.Query("SELECT id, instance_config_name FROM launch_requests WHERE current_state IS NULL OR instance_id IS NULL") + if err != nil { + return fmt.Errorf("error querying database: %v", err) + } + defer rows.Close() + + instanceIDMap := make(map[string]string) + + // iterate over all matching launch requests + for rows.Next() { + var id, instanceConfigName string + if err := rows.Scan(&id, &instanceConfigName); err != nil { + log.Printf("Error scanning row: %s", err) + continue + } + + // launch instance + instanceID, err := instanceManager.launchInstance(instanceConfigName) + if err != nil { + log.Printf("Error launching instance: %v", err) + continue + } + if instanceID != "" { + instanceIDMap[id] = instanceID // map the id to the instance id to update on the db later + } + } + + // update the launch requests with the instance id + for id, instanceID := range instanceIDMap { + log.Printf("Updating instance ID for request %s: %s", id, instanceID) + if _, err := instanceManager.db.Exec(`UPDATE launch_requests SET instance_id = ? WHERE id = ?`, instanceID, id); err != nil { + return fmt.Errorf("error updating instance ID for request %s: %v", id, err) + } + } + return nil +} diff --git a/reefd/instance_manager_test.go b/reefd/instance_manager_test.go new file mode 100644 index 0000000..6bb3af1 --- /dev/null +++ b/reefd/instance_manager_test.go @@ -0,0 +1,237 @@ +package reefd + +import ( + "database/sql" + "io" + "net/http" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ec2" + _ "github.com/mattn/go-sqlite3" +) + +type mockEC2Client struct { + describeInstancesOutput *ec2.DescribeInstancesOutput + runInstancesOutput *ec2.Reservation + err error +} + +func (m *mockEC2Client) DescribeInstances(*ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) { + return m.describeInstancesOutput, m.err +} + +func (m *mockEC2Client) RunInstances(*ec2.RunInstancesInput) (*ec2.Reservation, error) { + return m.runInstancesOutput, m.err +} + +const sampleDesiredState = `{"instance_type":"t3.micro","ami":"ami-1234567890abcdef0","state":"running"}` + +func setupTestDB(t *testing.T) (*sql.DB, string) { + dbPath := filepath.Join(os.TempDir(), "test_db.sqlite3") + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatalf("Error opening test database: %s", err) + } + + _, err = db.Exec(` + CREATE TABLE launch_requests ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + instance_id TEXT, + desired_state TEXT, + current_state TEXT + ) + `) + if err != nil { + t.Fatalf("Error creating test launch_requests table: %s", err) + } + + return db, dbPath +} + +func TestGetInstanceState(t *testing.T) { + testInstanceManager := &instanceManager{ + db: nil, + ec2Client: &mockEC2Client{ + describeInstancesOutput: &ec2.DescribeInstancesOutput{ + Reservations: []*ec2.Reservation{{ + Instances: []*ec2.Instance{{ + InstanceType: aws.String("t3.micro"), + ImageId: aws.String("ami-1234567890abcdef0"), + State: &ec2.InstanceState{ + Name: aws.String("running"), + }, + }}, + }}, + }, + }, + } + + instanceInfo, err := testInstanceManager.getInstanceState("i-1234567890abcdef0") + if err != nil { + t.Fatalf("Error getting instance state: %v", err) + } + want := &InstanceInfo{ + InstanceType: "t3.micro", + AMI: "ami-1234567890abcdef0", + State: "running", + } + + if !reflect.DeepEqual(instanceInfo, want) { + t.Errorf("got %v, want %v", instanceInfo, want) + } +} + +func TestLaunchInstance(t *testing.T) { + testInstanceManager := &instanceManager{ + db: nil, + ec2Client: &mockEC2Client{ + runInstancesOutput: &ec2.Reservation{ + Instances: []*ec2.Instance{{ + InstanceId: aws.String("i-1234567890abcdef0"), + }}, + }, + }, + } + + instanceID, err := testInstanceManager.launchInstance("t3.micro", "ami-1234567890abcdef0") + if err != nil { + t.Fatalf("Error launching instance: %v", err) + } + + want := "i-1234567890abcdef0" + if instanceID != want { + t.Errorf("got %q, want %q", instanceID, want) + } +} + +func TestUpdateCurrentState(t *testing.T) { + db, dbPath := setupTestDB(t) + defer db.Close() + defer os.Remove(dbPath) + + testInstanceManager := &instanceManager{ + db: db, + ec2Client: &mockEC2Client{ + describeInstancesOutput: &ec2.DescribeInstancesOutput{ + Reservations: []*ec2.Reservation{{ + Instances: []*ec2.Instance{{ + InstanceType: aws.String("t3.micro"), + ImageId: aws.String("ami-1234567890abcdef0"), + State: &ec2.InstanceState{ + Name: aws.String("running"), + }, + }}, + }}, + }, + }, + } + + _, err := db.Exec( + `INSERT INTO launch_requests (id, instance_id, desired_state, current_state) VALUES (?, ?, ?, ?)`, + "1", "i-1234567890abcdef0", sampleDesiredState, nil, + ) + if err != nil { + t.Fatalf("Error inserting test launch_requests row: %s", err) + } + + testInstanceManager.updateCurrentState() + + var currentState string + err = db.QueryRow(`SELECT current_state FROM launch_requests WHERE id = ?`, "1").Scan(¤tState) + if err != nil { + t.Fatalf("Error querying current_state column: %s", err) + } + + want := sampleDesiredState + if currentState != want { + t.Errorf("got %q, want %q", currentState, want) + } +} + +func TestProcessLaunchRequests(t *testing.T) { + db, dbPath := setupTestDB(t) + defer db.Close() + defer os.Remove(dbPath) + + testInstanceManager := &instanceManager{ + db: db, + ec2Client: &mockEC2Client{ + runInstancesOutput: &ec2.Reservation{ + Instances: []*ec2.Instance{{ + InstanceId: aws.String("i-1234567890abcdef1"), + }}, + }, + }, + } + + _, err := db.Exec( + `INSERT INTO launch_requests (id, instance_id, desired_state, current_state) VALUES (?, ?, ?, ?)`, + "1", nil, sampleDesiredState, nil, + ) + if err != nil { + t.Fatalf("Error inserting test launch_requests row: %s", err) + } + + processLaunchRequests(testInstanceManager) + + var instanceID string + err = db.QueryRow(`SELECT instance_id FROM launch_requests WHERE id = ?`, "1").Scan(&instanceID) + if err != nil { + t.Fatalf("Error querying instance_id column: %s", err) + } + + want := "i-1234567890abcdef1" + if instanceID != want { + t.Errorf("got %q, want %q", instanceID, want) + } +} + +func TestHandleLaunchRequest(t *testing.T) { + db, dbPath := setupTestDB(t) + defer db.Close() + defer os.Remove(dbPath) + + ec2Client := &mockEC2Client{ + runInstancesOutput: &ec2.Reservation{ + Instances: []*ec2.Instance{{ + InstanceId: aws.String("i-1234567890abcdef1"), + }}, + }, + } + + handleLaunchRequest(db, nil, &http.Request{ + Body: io.NopCloser(strings.NewReader(sampleDesiredState)), + }, ec2Client) + + // wait for the goroutine to finish + time.Sleep(5 * time.Second) + + // check how many rows are in the launch_requests table + var count int + err := db.QueryRow(`SELECT COUNT(*) FROM launch_requests`).Scan(&count) + if err != nil { + t.Fatalf("Error querying launch_requests table: %s", err) + } + + want_count := 1 + if count != want_count { + t.Errorf("got %d, want %d", count, want_count) + } + + // check that instance_id is set on that row + var instanceID string + err = db.QueryRow(`SELECT instance_id FROM launch_requests WHERE id = ?`, "1").Scan(&instanceID) + if err != nil { + t.Fatalf("Error querying instance_id column: %s", err) + } + want_instanceID := "i-1234567890abcdef1" + if instanceID != want_instanceID { + t.Errorf("got %q, want %q", instanceID, want_instanceID) + } +} diff --git a/reefd/job.go b/reefd/job.go new file mode 100644 index 0000000..ebf6cbb --- /dev/null +++ b/reefd/job.go @@ -0,0 +1,44 @@ +package reefd + +import ( + "database/sql" + "encoding/json" + "time" +) + +type Job struct { + Id string + Queue string + Commands []string + AgentId string `json:"agent_id,omitempty"` + CreatedAt time.Time +} + +/* +Add jobs table to database: + +CREATE TABLE IF NOT EXISTS jobs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + queue TEXT NOT NULL, + commands TEXT NOT NULL, + agent_id TEXT, + created_at TIMESTAMP NOT NULL +); +*/ + +func getJob(db *sql.DB, queue string) (*Job, error) { + job := &Job{} + var commandsJSON string + err := db.QueryRow(`SELECT id, queue, commands, created_at FROM jobs WHERE queue = ? AND agent_id IS NULL ORDER BY created_at ASC LIMIT 1`, queue).Scan(&job.Id, &job.Queue, &commandsJSON, &job.CreatedAt) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + + if err := json.Unmarshal([]byte(commandsJSON), &job.Commands); err != nil { + return nil, err + } + return job, nil +} diff --git a/reefd/push.sh b/reefd/push.sh new file mode 100644 index 0000000..a92ae35 --- /dev/null +++ b/reefd/push.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +set -euo pipefail +set -x + +go build -o reefd/reefd ./reefd +ssh rayci -- docker rm -f rayci +scp reefd/reefd rayci:/opt/apps/rayci/bin/reefd +ssh rayci -- /bin/bash /opt/apps/rayci/create.sh +ssh rayci -- docker start rayci \ No newline at end of file diff --git a/reefd/reefd/.gitignore b/reefd/reefd/.gitignore new file mode 100644 index 0000000..8fdd2ee --- /dev/null +++ b/reefd/reefd/.gitignore @@ -0,0 +1 @@ +/reefd diff --git a/reefd/reefd/agent.go b/reefd/reefd/agent.go new file mode 100644 index 0000000..a2663f5 --- /dev/null +++ b/reefd/reefd/agent.go @@ -0,0 +1,13 @@ +package main + +import ( + "github.com/ray-project/rayci/reefagent" +) + +func main() { + a := reefagent.Agent{ + Id: "123", + Queue: "test", + } + a.Start() +} diff --git a/reefd/reefd/main.go b/reefd/reefd/main.go index 2dacc63..e69de29 100644 --- a/reefd/reefd/main.go +++ b/reefd/reefd/main.go @@ -1,18 +0,0 @@ -package main - -import ( - "flag" - "log" - - "github.com/ray-project/rayci/reefd" -) - -func main() { - config := &reefd.Config{} - addr := flag.String("addr", "localhost:8000", "address to listen on") - flag.Parse() - - if err := reefd.Serve(*addr, config); err != nil { - log.Fatal(err) - } -} diff --git a/reefd/reefd/server.go b/reefd/reefd/server.go new file mode 100644 index 0000000..2a0c6cc --- /dev/null +++ b/reefd/reefd/server.go @@ -0,0 +1,45 @@ +package main + +import ( + "database/sql" + "flag" + "log" + "os" + + _ "github.com/mattn/go-sqlite3" + "github.com/ray-project/rayci/reefd" +) + +func main() { + dbPath := flag.String("db", "", "Path to .db file") + flag.Parse() + + if *dbPath == "" { + log.Fatal("Database path is required") + } + + if _, err := os.Stat(*dbPath); err != nil { + if os.IsNotExist(err) { + log.Fatalf("File %s does not exist", *dbPath) + } + log.Fatalf("Error checking database file %s: %v", *dbPath, err) + } + + db, err := sql.Open("sqlite3", *dbPath) + if err != nil { + log.Fatalf("Error connecting to database: %s", err) + } + defer db.Close() + + config := &reefd.Config{ + DB: db, + } + + addr := flag.String("addr", "0.0.0.0:1235", "address to listen on") + flag.Parse() + + log.Println("serving at:", *addr) + if err := reefd.Serve(*addr, config); err != nil { + log.Fatal(err) + } +} diff --git a/reefd/serve.go b/reefd/serve.go index 30c27c4..e77c38a 100644 --- a/reefd/serve.go +++ b/reefd/serve.go @@ -2,12 +2,17 @@ package reefd import ( + "database/sql" + "encoding/json" + "fmt" "io" "net/http" + "time" ) // Config contains the configuration for the running the server. type Config struct { + DB *sql.DB } type server struct { @@ -19,15 +24,158 @@ func newServer(c *Config) *server { } func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, "Hello, World!") + io.WriteString(w, "Hello, Kevin!") +} + +// handleLaunchRequest retrieves the desired state from the request body and inserts it into the database +// then starts a goroutine to process the launch requests +func handleLaunchRequest(db *sql.DB, w http.ResponseWriter, r *http.Request, ec2Client EC2Client) { + instanceConfigName := r.URL.Query().Get("instanceConfigName") + + // insert the desired state into the database + if _, err := db.Exec(`INSERT INTO launch_requests (instance_config_name, desired_state) VALUES (?, ?)`, instanceConfigName, "running"); err != nil { + http.Error(w, "Error inserting into database: "+err.Error(), http.StatusInternalServerError) + return + } + + // start a goroutine to scan the database for launch requests with different desired and current states + instanceManager := &instanceManager{db: db, ec2Client: ec2Client} + go processLaunchRequests(instanceManager) +} + +func handleInstanceConfigAdd(db *sql.DB, w http.ResponseWriter, r *http.Request) { + // read instanceConfig from the request body + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Error reading body: "+err.Error(), http.StatusInternalServerError) + return + } + + // to parse request body into name and instance config + var requestBody struct { + Name string `json:"name"` + InstanceConfig InstanceConfiguration `json:"instance_config"` + } + if err := json.Unmarshal(body, &requestBody); err != nil { + http.Error(w, "Error unmarshalling request body: "+err.Error(), http.StatusInternalServerError) + return + } + instanceConfig := requestBody.InstanceConfig + instanceType := instanceConfig.InstanceType + ami := instanceConfig.AMI + name := requestBody.Name + + // insert the instance config into the database + if _, err := db.Exec(`INSERT INTO instance_configs (name, instance_type, ami) VALUES (?, ?, ?)`, name, instanceType, ami); err != nil { + http.Error(w, "Error inserting into database: "+err.Error(), http.StatusInternalServerError) + return + } + + // send OK response + w.WriteHeader(http.StatusOK) +} + +// handleJobLogs handles job logs sent from the agent +func handleJobLogs(w http.ResponseWriter, r *http.Request) { + jobId := r.URL.Query().Get("jobId") + sequence := r.URL.Query().Get("sequence") + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Error reading body: "+err.Error(), http.StatusInternalServerError) + return + } + fmt.Println("log ", jobId, "-", sequence, ": ", string(body)) + // TODO: figure out how to store and display logs in order and a nice way +} + +// handlePing handles requests from the agent to check if there's any job for agent to take +func handlePing(db *sql.DB, w http.ResponseWriter, r *http.Request) { + // get agent ID from the request + queue := r.URL.Query().Get("queue") + // Look into database to see if there's any job that is in the queue + // If there's any job, send it to the agent + job, err := getJob(db, queue) + if err != nil { + http.Error(w, "Error getting job: "+err.Error(), http.StatusInternalServerError) + return + } + if job == nil { + http.Error(w, "No job found", http.StatusNotFound) + return + } + // send the jobId and job commands back in response + json.NewEncoder(w).Encode(map[string]interface{}{ + "jobId": job.Id, + "commands": job.Commands, + }) +} + +// handleAcquireJob handles request from the agent to acquire a job +func handleAcquireJob(db *sql.DB, w http.ResponseWriter, r *http.Request) { + agentId := r.URL.Query().Get("agentId") + jobId := r.URL.Query().Get("jobId") + + // update the job with the agent ID + if _, err := db.Exec(`UPDATE jobs SET agent_id = ? WHERE id = ?`, agentId, jobId); err != nil { + http.Error(w, "Error updating job: "+err.Error(), http.StatusInternalServerError) + return + } + + fmt.Println("Agent", agentId, "acquired job", jobId) + w.WriteHeader(http.StatusOK) +} + +// handleJobAdd handles requests to add a job +func handleJobAdd(db *sql.DB, w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Error reading body: "+err.Error(), http.StatusInternalServerError) + return + } + + // decompose request body into job + var job Job + if err := json.Unmarshal(body, &job); err != nil { + http.Error(w, "Error unmarshalling job: "+err.Error(), http.StatusInternalServerError) + return + } + + // insert the job into the jobs db table + commandsJson, err := json.Marshal(job.Commands) + if err != nil { + http.Error(w, "Error marshalling commands: "+err.Error(), http.StatusInternalServerError) + return + } + if _, err := db.Exec(`INSERT INTO jobs (commands, queue, created_at) VALUES (?, ?, ?)`, string(commandsJson), job.Queue, time.Now()); err != nil { + http.Error(w, "Error inserting into database: "+err.Error(), http.StatusInternalServerError) + return + } + // send OK response + w.WriteHeader(http.StatusOK) } // Serve runs the server. func Serve(addr string, c *Config) error { - s := newServer(c) - httpServer := &http.Server{ - Addr: addr, - Handler: s, - } - return httpServer.ListenAndServe() + ec2Client := getEC2Client() + + http.HandleFunc("/instance/launch", func(w http.ResponseWriter, r *http.Request) { + handleLaunchRequest(c.DB, w, r, ec2Client) + }) + http.HandleFunc("/instance_config/add", func(w http.ResponseWriter, r *http.Request) { + handleInstanceConfigAdd(c.DB, w, r) + }) + http.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { + handlePing(c.DB, w, r) + }) + http.HandleFunc("/job/logs", func(w http.ResponseWriter, r *http.Request) { + handleJobLogs(w, r) + }) + http.HandleFunc("/job/acquire", func(w http.ResponseWriter, r *http.Request) { + handleAcquireJob(c.DB, w, r) + }) + http.HandleFunc("/job/add", func(w http.ResponseWriter, r *http.Request) { + handleJobAdd(c.DB, w, r) + }) + + return http.ListenAndServe(addr, nil) }