Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 53 additions & 5 deletions server/cmd/api/api/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"os/user"
"path/filepath"
"strconv"
"strings"
Expand Down Expand Up @@ -89,20 +91,66 @@ func buildCmd(body *oapi.ProcessExecRequest) (*exec.Cmd, error) {
env = append(env, k+"="+v)
}
cmd.Env = env

// Configure user if requested
if body.AsRoot != nil && *body.AsRoot && body.AsUser != nil && *body.AsUser != "" {
return nil, errors.New("cannot specify both as_root and as_user")
}
if body.AsRoot != nil && *body.AsRoot {
cmd.SysProcAttr = &syscall.SysProcAttr{
Credential: &syscall.Credential{Uid: 0, Gid: 0},
}
} else if body.AsUser != nil && *body.AsUser != "" {
spec := *body.AsUser
// support forms: "username" or "uid" or "uid:gid"
var uidStr, gidStr string
if i := strings.IndexByte(spec, ':'); i >= 0 {
uidStr = spec[:i]
gidStr = spec[i+1:]
} else {
uidStr = spec
}

var u *user.User
var err error
if _, errNum := strconv.Atoi(uidStr); errNum == nil {
u, err = user.LookupId(uidStr)
} else {
u, err = user.Lookup(uidStr)
}
if err != nil {
return nil, fmt.Errorf("failed to lookup user %q: %w", spec, err)
}
uid64, err := strconv.ParseUint(u.Uid, 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid uid for user %q: %w", spec, err)
}
gid64, err := strconv.ParseUint(u.Gid, 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid gid for user %q: %w", spec, err)
}
// If gid override provided, require it to be numeric
if gidStr != "" {
if gOverride, err := strconv.ParseUint(gidStr, 10, 32); err == nil {
gid64 = gOverride
} else {
return nil, fmt.Errorf("gid override must be numeric, got %q", gidStr)
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic silently ignores non-numeric group specifiers. If a user passes something like "1000:wheel", the "wheel" group name will be silently ignored rather than being looked up or causing an error. Consider either:

  1. Adding group name lookup similar to user lookup, or
  2. Returning an error for non-numeric gid overrides to make the behavior explicit
Suggested change
}
if gidStr != "" {
if gOverride, err := strconv.ParseUint(gidStr, 10, 32); err == nil {
gid64 = gOverride
} else {
return nil, fmt.Errorf("gid override must be numeric, got %q", gidStr)
}
}

Type: Logic | Severity: Medium

}
cmd.SysProcAttr = &syscall.SysProcAttr{
Credential: &syscall.Credential{Uid: uint32(uid64), Gid: uint32(gid64)},
}
}
return cmd, nil
}

// Execute a command synchronously (optional streaming)
// Execute a command synchronously
// (POST /process/exec)
func (s *ApiService) ProcessExec(ctx context.Context, request oapi.ProcessExecRequestObject) (oapi.ProcessExecResponseObject, error) {
log := logger.FromContext(ctx)
if request.Body == nil {
return oapi.ProcessExec400JSONResponse{BadRequestErrorJSONResponse: oapi.BadRequestErrorJSONResponse{Message: "request body required"}}, nil
}
// Streaming over this endpoint is not supported by the current API definition
if request.Body.Stream != nil && *request.Body.Stream {
return oapi.ProcessExec400JSONResponse{BadRequestErrorJSONResponse: oapi.BadRequestErrorJSONResponse{Message: "streaming not supported for /process/exec"}}, nil
}

cmd, err := buildCmd((*oapi.ProcessExecRequest)(request.Body))
if err != nil {
Expand Down
179 changes: 95 additions & 84 deletions server/cmd/api/api/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@ import (
"encoding/base64"
"encoding/json"
"io"
"os/user"
"strconv"
"strings"
"testing"
"time"

"github.com/google/uuid"
openapi_types "github.com/oapi-codegen/runtime/types"
oapi "github.com/onkernel/kernel-images/server/lib/oapi"
"github.com/stretchr/testify/require"
)

func TestProcessExec(t *testing.T) {
Expand All @@ -24,24 +27,17 @@ func TestProcessExec(t *testing.T) {
args := []string{"-c", "echo -n out; echo -n err 1>&2; exit 3"}
body := &oapi.ProcessExecRequest{Command: cmd, Args: &args}
resp, err := svc.ProcessExec(ctx, oapi.ProcessExecRequestObject{Body: body})
if err != nil {
t.Fatalf("ProcessExec error: %v", err)
}
require.NoError(t, err, "ProcessExec error")
r200, ok := resp.(oapi.ProcessExec200JSONResponse)
if !ok {
t.Fatalf("unexpected resp type: %T", resp)
}
if r200.ExitCode == nil || *r200.ExitCode != 3 {
t.Fatalf("exit code mismatch: %+v", r200.ExitCode)
}
if r200.StdoutB64 == nil || r200.StderrB64 == nil {
t.Fatalf("missing stdout/stderr in response")
}
require.True(t, ok, "unexpected resp type: %T", resp)
require.NotNil(t, r200.ExitCode, "missing exit code")
require.Equal(t, 3, *r200.ExitCode, "exit code mismatch")
require.NotNil(t, r200.StdoutB64, "missing stdout in response")
require.NotNil(t, r200.StderrB64, "missing stderr in response")
out, _ := base64.StdEncoding.DecodeString(*r200.StdoutB64)
errB, _ := base64.StdEncoding.DecodeString(*r200.StderrB64)
if string(out) != "out" || string(errB) != "err" {
t.Fatalf("stdout/stderr mismatch: %q %q", string(out), string(errB))
}
require.Equal(t, "out", string(out), "stdout mismatch")
require.Equal(t, "err", string(errB), "stderr mismatch")
}

func TestProcessSpawnStatusAndStream(t *testing.T) {
Expand All @@ -54,32 +50,23 @@ func TestProcessSpawnStatusAndStream(t *testing.T) {
args := []string{"-c", "printf ABC; sleep 0.05; printf DEF 1>&2; sleep 0.05; exit 0"}
body := &oapi.ProcessSpawnRequest{Command: cmd, Args: &args}
spawnResp, err := svc.ProcessSpawn(ctx, oapi.ProcessSpawnRequestObject{Body: body})
if err != nil {
t.Fatalf("ProcessSpawn error: %v", err)
}
require.NoError(t, err, "ProcessSpawn error")
s200, ok := spawnResp.(oapi.ProcessSpawn200JSONResponse)
if !ok || s200.ProcessId == nil || s200.Pid == nil {
t.Fatalf("unexpected spawn resp: %+v", spawnResp)
}
require.True(t, ok, "unexpected spawn resp type: %T", spawnResp)
require.NotNil(t, s200.ProcessId, "missing ProcessId in spawn resp")
require.NotNil(t, s200.Pid, "missing Pid in spawn resp")

// Status should be running initially (may race to exited; tolerate both by not asserting)
statusResp, err := svc.ProcessStatus(ctx, oapi.ProcessStatusRequestObject{ProcessId: *s200.ProcessId})
if err != nil {
t.Fatalf("ProcessStatus error: %v", err)
}
if _, ok := statusResp.(oapi.ProcessStatus200JSONResponse); !ok {
t.Fatalf("unexpected status resp: %T", statusResp)
}
require.NoError(t, err, "ProcessStatus error")
_, ok = statusResp.(oapi.ProcessStatus200JSONResponse)
require.True(t, ok, "unexpected status resp: %T", statusResp)

// Start stream reader and collect at least two data events and one exit event
streamResp, err := svc.ProcessStdoutStream(ctx, oapi.ProcessStdoutStreamRequestObject{ProcessId: *s200.ProcessId})
if err != nil {
t.Fatalf("StdoutStream error: %v", err)
}
require.NoError(t, err, "StdoutStream error")
st200, ok := streamResp.(oapi.ProcessStdoutStream200TexteventStreamResponse)
if !ok {
t.Fatalf("unexpected stream resp: %T", streamResp)
}
require.True(t, ok, "unexpected stream resp: %T", streamResp)

reader := bufio.NewReader(st200.Body)
var gotStdout, gotStderr, gotExit bool
Expand All @@ -90,15 +77,15 @@ func TestProcessSpawnStatusAndStream(t *testing.T) {
if err == io.EOF {
break
}
t.Fatalf("read SSE line: %v", err)
require.NoError(t, err, "read SSE line")
}
if !strings.HasPrefix(line, "data: ") {
continue
}
payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
var evt oapi.ProcessStreamEvent
if err := json.Unmarshal([]byte(payload), &evt); err != nil {
t.Fatalf("unmarshal event: %v", err)
require.NoError(t, err, "unmarshal event")
}
if evt.Stream != nil && *evt.Stream == "stdout" && evt.DataB64 != nil {
b, _ := base64.StdEncoding.DecodeString(*evt.DataB64)
Expand All @@ -118,9 +105,7 @@ func TestProcessSpawnStatusAndStream(t *testing.T) {
// consume blank line
_, _ = reader.ReadString('\n')
}
if !(gotStdout && gotStderr && gotExit) {
t.Fatalf("missing events: stdout=%v stderr=%v exit=%v", gotStdout, gotStderr, gotExit)
}
require.True(t, gotStdout && gotStderr && gotExit, "missing events: stdout=%v stderr=%v exit=%v", gotStdout, gotStderr, gotExit)
}

func TestProcessStdinAndExit(t *testing.T) {
Expand All @@ -133,42 +118,33 @@ func TestProcessStdinAndExit(t *testing.T) {
args := []string{"-c", "dd of=/dev/null bs=1 count=3 status=none"}
body := &oapi.ProcessSpawnRequest{Command: cmd, Args: &args}
spawnResp, err := svc.ProcessSpawn(ctx, oapi.ProcessSpawnRequestObject{Body: body})
if err != nil {
t.Fatalf("ProcessSpawn error: %v", err)
}
require.NoError(t, err, "ProcessSpawn error")
s200, ok := spawnResp.(oapi.ProcessSpawn200JSONResponse)
if !ok || s200.ProcessId == nil {
t.Fatalf("unexpected spawn resp: %T", spawnResp)
}
require.True(t, ok, "unexpected spawn resp: %T", spawnResp)
require.NotNil(t, s200.ProcessId, "missing ProcessId in spawn resp")

// Write 3 bytes
data := base64.StdEncoding.EncodeToString([]byte("xyz"))
stdinResp, err := svc.ProcessStdin(ctx, oapi.ProcessStdinRequestObject{ProcessId: *s200.ProcessId, Body: &oapi.ProcessStdinRequest{DataB64: data}})
if err != nil {
t.Fatalf("ProcessStdin error: %v", err)
}
require.NoError(t, err, "ProcessStdin error")
st200, ok := stdinResp.(oapi.ProcessStdin200JSONResponse)
if !ok || st200.WrittenBytes == nil || *st200.WrittenBytes != 3 {
t.Fatalf("unexpected stdin resp: %+v", stdinResp)
}
require.True(t, ok, "unexpected stdin resp type: %T", stdinResp)
require.NotNil(t, st200.WrittenBytes, "missing WrittenBytes in stdin resp")
require.Equal(t, 3, *st200.WrittenBytes, "written bytes mismatch")

// Wait for exit via status polling
deadline := time.Now().Add(3 * time.Second)
for time.Now().Before(deadline) {
resp, err := svc.ProcessStatus(ctx, oapi.ProcessStatusRequestObject{ProcessId: *s200.ProcessId})
if err != nil {
t.Fatalf("ProcessStatus error: %v", err)
}
require.NoError(t, err, "ProcessStatus error")
sr, ok := resp.(oapi.ProcessStatus200JSONResponse)
if !ok {
t.Fatalf("unexpected status resp: %T", resp)
}
require.True(t, ok, "unexpected status resp: %T", resp)
if sr.State != nil && *sr.State == "exited" {
return
}
time.Sleep(50 * time.Millisecond)
}
t.Fatalf("process did not exit in time")
require.True(t, false, "process did not exit in time")
}

func TestProcessKill(t *testing.T) {
Expand All @@ -180,41 +156,31 @@ func TestProcessKill(t *testing.T) {
args := []string{"-c", "sleep 5"}
body := &oapi.ProcessSpawnRequest{Command: cmd, Args: &args}
spawnResp, err := svc.ProcessSpawn(ctx, oapi.ProcessSpawnRequestObject{Body: body})
if err != nil {
t.Fatalf("ProcessSpawn error: %v", err)
}
require.NoError(t, err, "ProcessSpawn error")
s200, ok := spawnResp.(oapi.ProcessSpawn200JSONResponse)
if !ok || s200.ProcessId == nil {
t.Fatalf("unexpected spawn resp: %T", spawnResp)
}
require.True(t, ok, "unexpected spawn resp: %T", spawnResp)
require.NotNil(t, s200.ProcessId, "missing ProcessId in spawn resp")

// Send KILL
killBody := &oapi.ProcessKillRequest{Signal: "KILL"}
killResp, err := svc.ProcessKill(ctx, oapi.ProcessKillRequestObject{ProcessId: *s200.ProcessId, Body: killBody})
if err != nil {
t.Fatalf("ProcessKill error: %v", err)
}
if _, ok := killResp.(oapi.ProcessKill200JSONResponse); !ok {
t.Fatalf("unexpected kill resp: %T", killResp)
}
require.NoError(t, err, "ProcessKill error")
_, ok = killResp.(oapi.ProcessKill200JSONResponse)
require.True(t, ok, "unexpected kill resp: %T", killResp)

// Verify exited
deadline := time.Now().Add(3 * time.Second)
for time.Now().Before(deadline) {
resp, err := svc.ProcessStatus(ctx, oapi.ProcessStatusRequestObject{ProcessId: *s200.ProcessId})
if err != nil {
t.Fatalf("ProcessStatus error: %v", err)
}
require.NoError(t, err, "ProcessStatus error")
sr, ok := resp.(oapi.ProcessStatus200JSONResponse)
if !ok {
t.Fatalf("unexpected status resp: %T", resp)
}
require.True(t, ok, "unexpected status resp: %T", resp)
if sr.State != nil && *sr.State == "exited" {
return
}
time.Sleep(50 * time.Millisecond)
}
t.Fatalf("process not killed in time")
require.True(t, false, "process not killed in time")
}

func TestProcessNotFoundRoutes(t *testing.T) {
Expand All @@ -224,14 +190,59 @@ func TestProcessNotFoundRoutes(t *testing.T) {

// random id that will not exist
id := openapi_types.UUID(uuid.New())
if resp, _ := svc.ProcessStatus(ctx, oapi.ProcessStatusRequestObject{ProcessId: id}); resp == nil {
t.Fatalf("expected a response")
} else if _, ok := resp.(oapi.ProcessStatus404JSONResponse); !ok {
t.Fatalf("expected 404, got %T", resp)
if resp, _ := svc.ProcessStatus(ctx, oapi.ProcessStatusRequestObject{ProcessId: id}); true {
require.NotNil(t, resp, "expected a response")
_, ok := resp.(oapi.ProcessStatus404JSONResponse)
require.True(t, ok, "expected 404, got %T", resp)
}
if resp, _ := svc.ProcessStdoutStream(ctx, oapi.ProcessStdoutStreamRequestObject{ProcessId: id}); resp == nil {
t.Fatalf("expected a response")
} else if _, ok := resp.(oapi.ProcessStdoutStream404JSONResponse); !ok {
t.Fatalf("expected 404, got %T", resp)
if resp, _ := svc.ProcessStdoutStream(ctx, oapi.ProcessStdoutStreamRequestObject{ProcessId: id}); true {
require.NotNil(t, resp, "expected a response")
_, ok := resp.(oapi.ProcessStdoutStream404JSONResponse)
require.True(t, ok, "expected 404, got %T", resp)
}
}

func TestBuildCmd_AsRootSetsCredential(t *testing.T) {
t.Parallel()
asRoot := true
body := &oapi.ProcessExecRequest{Command: "true", AsRoot: &asRoot}
cmd, err := buildCmd(body)
require.NoError(t, err, "buildCmd returned error")
require.NotNil(t, cmd.SysProcAttr, "expected SysProcAttr to be set for AsRoot")
require.NotNil(t, cmd.SysProcAttr.Credential, "expected SysProcAttr.Credential to be set for AsRoot")
require.Equal(t, uint32(0), cmd.SysProcAttr.Credential.Uid, "expected root uid")
require.Equal(t, uint32(0), cmd.SysProcAttr.Credential.Gid, "expected root gid")
}

func TestBuildCmd_AsUserUidAndGidOverride(t *testing.T) {
t.Parallel()
cur, err := user.Current()
if err != nil {
t.Skipf("skipping: failed to determine current user: %v", err)
}
// Use numeric uid with an explicit gid override to exercise parsing path
spec := cur.Uid + ":0" // override gid to 0 for determinism; we're not executing
body := &oapi.ProcessExecRequest{Command: "true", AsUser: &spec}
cmd, err := buildCmd(body)
require.NoError(t, err, "buildCmd returned error")
require.NotNil(t, cmd.SysProcAttr, "expected SysProcAttr to be set for AsUser")
require.NotNil(t, cmd.SysProcAttr.Credential, "expected SysProcAttr.Credential to be set for AsUser")
// Verify uid matches the looked-up uid and gid matches the override
wantUID64, err := strconv.ParseUint(cur.Uid, 10, 32)
require.NoError(t, err, "parse current uid")
if cmd.SysProcAttr.Credential.Uid != uint32(wantUID64) {
require.Equal(t, uint32(wantUID64), cmd.SysProcAttr.Credential.Uid, "uid mismatch")
}
if cmd.SysProcAttr.Credential.Gid != 0 {
require.Equal(t, uint32(0), cmd.SysProcAttr.Credential.Gid, "gid override mismatch")
}
}

func TestBuildCmd_AsRootAndAsUserConflict(t *testing.T) {
t.Parallel()
asRoot := true
asUser := "0"
body := &oapi.ProcessExecRequest{Command: "true", AsRoot: &asRoot, AsUser: &asUser}
_, err := buildCmd(body)
require.Error(t, err, "expected error when both as_root and as_user are set")
}
Loading
Loading