Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions backend/internal/handler/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,19 +260,8 @@ func (r *RequestsHandler) GenerateRequest(c *fiber.Ctx) error {
Notes: notes,
}}

res, err := r.RequestRepository.InsertRequest(c.Context(), &req)
if err != nil {
slog.Error("failed to insert generated request", "error", err)
return errs.InternalServerError()
}
if r.NotificationSender != nil && req.UserID != nil {
if err := r.NotificationSender.Notify(c.Context(), *req.UserID, models.TypeTaskAssigned, msgTaskAssigned, res.Name); err != nil {
slog.Error("failed to send task assigned notification", "err", err)
}
}

return c.JSON(models.GenerateRequestResponse{
Request: *res,
Request: req,
Warning: warningFromAI(parsed.Warning),
})
}
Expand Down
150 changes: 11 additions & 139 deletions backend/internal/handler/requests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"errors"
"io"
"log/slog"
"net/http/httptest"
"testing"
"time"
Expand Down Expand Up @@ -593,13 +592,6 @@ func TestRequestHandler_Generate_Request(t *testing.T) {
t.Parallel()

description := "Guest requested extra towels for their room"
repoMock := &mockRequestRepository{
makeRequestFunc: func(ctx context.Context, req *models.Request) (*models.Request, error) {
req.ID = "generated-uuid"
req.CreatedAt = time.Now()
return req, nil
},
}

llmMock := &mockLLMService{
runGenerateRequestFunc: func(ctx context.Context, input aiflows.GenerateRequestInput) (aiflows.GenerateRequestOutput, error) {
Expand All @@ -614,7 +606,7 @@ func TestRequestHandler_Generate_Request(t *testing.T) {
}

app := fiber.New()
h := NewRequestsHandler(repoMock, llmMock, nil)
h := NewRequestsHandler(&mockRequestRepository{}, llmMock, nil)
app.Post("/request/generate", h.GenerateRequest)

req := httptest.NewRequest("POST", "/request/generate", bytes.NewBufferString(validBody))
Expand All @@ -626,7 +618,6 @@ func TestRequestHandler_Generate_Request(t *testing.T) {

body, _ := io.ReadAll(resp.Body)
assert.Contains(t, string(body), `"request"`)
assert.Contains(t, string(body), "generated-uuid")
assert.Contains(t, string(body), "Extra Towels Request")
assert.Contains(t, string(body), "high")
})
Expand All @@ -640,13 +631,6 @@ func TestRequestHandler_Generate_Request(t *testing.T) {
notes := "Guest prefers eco-friendly products"
estimatedTime := 30

repoMock := &mockRequestRepository{
makeRequestFunc: func(ctx context.Context, req *models.Request) (*models.Request, error) {
req.ID = "generated-uuid"
return req, nil
},
}

llmMock := &mockLLMService{
runGenerateRequestFunc: func(ctx context.Context, input aiflows.GenerateRequestInput) (aiflows.GenerateRequestOutput, error) {
return aiflows.GenerateRequestOutput{
Expand All @@ -664,7 +648,7 @@ func TestRequestHandler_Generate_Request(t *testing.T) {
}

app := fiber.New()
h := NewRequestsHandler(repoMock, llmMock, nil)
h := NewRequestsHandler(&mockRequestRepository{}, llmMock, nil)
app.Post("/request/generate", h.GenerateRequest)

req := httptest.NewRequest("POST", "/request/generate", bytes.NewBufferString(validBody))
Expand Down Expand Up @@ -776,48 +760,9 @@ func TestRequestHandler_Generate_Request(t *testing.T) {
assert.Equal(t, 500, resp.StatusCode)
})

t.Run("returns 500 on db error", func(t *testing.T) {
t.Parallel()

repoMock := &mockRequestRepository{
makeRequestFunc: func(ctx context.Context, req *models.Request) (*models.Request, error) {
return nil, errors.New("db connection failed")
},
}

llmMock := &mockLLMService{
runGenerateRequestFunc: func(ctx context.Context, input aiflows.GenerateRequestInput) (aiflows.GenerateRequestOutput, error) {
return aiflows.GenerateRequestOutput{
Name: "Towel Request",
RequestType: "one-time",
Status: "pending",
Priority: "medium",
}, nil
},
}

app := fiber.New(fiber.Config{ErrorHandler: errs.ErrorHandler})
h := NewRequestsHandler(repoMock, llmMock, nil)
app.Post("/request/generate", h.GenerateRequest)

req := httptest.NewRequest("POST", "/request/generate", bytes.NewBufferString(validBody))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
require.NoError(t, err)

assert.Equal(t, 500, resp.StatusCode)
})

t.Run("returns 500 when LLM output fails validation", func(t *testing.T) {
t.Parallel()

repoMock := &mockRequestRepository{
makeRequestFunc: func(ctx context.Context, req *models.Request) (*models.Request, error) {
t.Fatal("request should not be inserted when LLM output is invalid")
return nil, nil
},
}

llmMock := &mockLLMService{
runGenerateRequestFunc: func(ctx context.Context, input aiflows.GenerateRequestInput) (aiflows.GenerateRequestOutput, error) {
return aiflows.GenerateRequestOutput{
Expand All @@ -830,7 +775,7 @@ func TestRequestHandler_Generate_Request(t *testing.T) {
}

app := fiber.New(fiber.Config{ErrorHandler: errs.ErrorHandler})
h := NewRequestsHandler(repoMock, llmMock, nil)
h := NewRequestsHandler(&mockRequestRepository{}, llmMock, nil)
app.Post("/request/generate", h.GenerateRequest)

req := httptest.NewRequest("POST", "/request/generate", bytes.NewBufferString(validBody))
Expand All @@ -846,13 +791,6 @@ func TestRequestHandler_Generate_Request(t *testing.T) {

var capturedInput aiflows.GenerateRequestInput

repoMock := &mockRequestRepository{
makeRequestFunc: func(ctx context.Context, req *models.Request) (*models.Request, error) {
req.ID = "generated-uuid"
return req, nil
},
}

llmMock := &mockLLMService{
runGenerateRequestFunc: func(ctx context.Context, input aiflows.GenerateRequestInput) (aiflows.GenerateRequestOutput, error) {
capturedInput = input
Expand All @@ -866,7 +804,7 @@ func TestRequestHandler_Generate_Request(t *testing.T) {
}

app := fiber.New()
h := NewRequestsHandler(repoMock, llmMock, nil)
h := NewRequestsHandler(&mockRequestRepository{}, llmMock, nil)
app.Post("/request/generate", h.GenerateRequest)

customBody := `{
Expand All @@ -886,16 +824,6 @@ func TestRequestHandler_Generate_Request(t *testing.T) {
t.Run("uses hotel_id from request body not LLM", func(t *testing.T) {
t.Parallel()

var capturedRequest *models.Request

repoMock := &mockRequestRepository{
makeRequestFunc: func(ctx context.Context, req *models.Request) (*models.Request, error) {
capturedRequest = req
req.ID = "generated-uuid"
return req, nil
},
}

llmMock := &mockLLMService{
runGenerateRequestFunc: func(ctx context.Context, input aiflows.GenerateRequestInput) (aiflows.GenerateRequestOutput, error) {
return aiflows.GenerateRequestOutput{
Expand All @@ -908,7 +836,7 @@ func TestRequestHandler_Generate_Request(t *testing.T) {
}

app := fiber.New()
h := NewRequestsHandler(repoMock, llmMock, nil)
h := NewRequestsHandler(&mockRequestRepository{}, llmMock, nil)
app.Post("/request/generate", h.GenerateRequest)

req := httptest.NewRequest("POST", "/request/generate", bytes.NewBufferString(validBody))
Expand All @@ -917,22 +845,13 @@ func TestRequestHandler_Generate_Request(t *testing.T) {
require.NoError(t, err)

assert.Equal(t, 200, resp.StatusCode)
assert.Equal(t, "550e8400-e29b-41d4-a716-446655440000", capturedRequest.HotelID)
body, _ := io.ReadAll(resp.Body)
assert.Contains(t, string(body), "550e8400-e29b-41d4-a716-446655440000")
})

t.Run("defaults notes to empty string when LLM omits notes", func(t *testing.T) {
t.Parallel()

var capturedRequest *models.Request

repoMock := &mockRequestRepository{
makeRequestFunc: func(ctx context.Context, req *models.Request) (*models.Request, error) {
capturedRequest = req
req.ID = "generated-uuid"
return req, nil
},
}

llmMock := &mockLLMService{
runGenerateRequestFunc: func(ctx context.Context, input aiflows.GenerateRequestInput) (aiflows.GenerateRequestOutput, error) {
return aiflows.GenerateRequestOutput{
Expand All @@ -946,7 +865,7 @@ func TestRequestHandler_Generate_Request(t *testing.T) {
}

app := fiber.New()
h := NewRequestsHandler(repoMock, llmMock, nil)
h := NewRequestsHandler(&mockRequestRepository{}, llmMock, nil)
app.Post("/request/generate", h.GenerateRequest)

req := httptest.NewRequest("POST", "/request/generate", bytes.NewBufferString(validBody))
Expand All @@ -955,23 +874,15 @@ func TestRequestHandler_Generate_Request(t *testing.T) {
require.NoError(t, err)

assert.Equal(t, 200, resp.StatusCode)
require.NotNil(t, capturedRequest)
require.NotNil(t, capturedRequest.Notes)
assert.Equal(t, "", *capturedRequest.Notes)
body, _ := io.ReadAll(resp.Body)
assert.Contains(t, string(body), `"notes":""`)
})

t.Run("returns warning metadata when room lookup warning exists", func(t *testing.T) {
t.Parallel()

warningMessage := "Room 301 could not be resolved for this hotel."

repoMock := &mockRequestRepository{
makeRequestFunc: func(ctx context.Context, req *models.Request) (*models.Request, error) {
req.ID = "generated-uuid"
return req, nil
},
}

llmMock := &mockLLMService{
runGenerateRequestFunc: func(ctx context.Context, input aiflows.GenerateRequestInput) (aiflows.GenerateRequestOutput, error) {
return aiflows.GenerateRequestOutput{
Expand All @@ -988,7 +899,7 @@ func TestRequestHandler_Generate_Request(t *testing.T) {
}

app := fiber.New()
h := NewRequestsHandler(repoMock, llmMock, nil)
h := NewRequestsHandler(&mockRequestRepository{}, llmMock, nil)
app.Post("/request/generate", h.GenerateRequest)

req := httptest.NewRequest("POST", "/request/generate", bytes.NewBufferString(validBody))
Expand All @@ -1004,45 +915,6 @@ func TestRequestHandler_Generate_Request(t *testing.T) {
assert.Contains(t, string(body), warningMessage)
})

t.Run("logs repository insert failures before returning 500", func(t *testing.T) {
t.Parallel()

var logBuffer bytes.Buffer
previousLogger := slog.Default()
logger := slog.New(slog.NewTextHandler(&logBuffer, nil))
slog.SetDefault(logger)
t.Cleanup(func() { slog.SetDefault(previousLogger) })

repoMock := &mockRequestRepository{
makeRequestFunc: func(ctx context.Context, req *models.Request) (*models.Request, error) {
return nil, errors.New("db connection failed")
},
}

llmMock := &mockLLMService{
runGenerateRequestFunc: func(ctx context.Context, input aiflows.GenerateRequestInput) (aiflows.GenerateRequestOutput, error) {
return aiflows.GenerateRequestOutput{
Name: "Towel Request",
RequestType: "one-time",
Status: "pending",
Priority: "medium",
}, nil
},
}

app := fiber.New(fiber.Config{ErrorHandler: errs.ErrorHandler})
h := NewRequestsHandler(repoMock, llmMock, nil)
app.Post("/request/generate", h.GenerateRequest)

req := httptest.NewRequest("POST", "/request/generate", bytes.NewBufferString(validBody))
req.Header.Set("Content-Type", "application/json")
resp, err := app.Test(req)
require.NoError(t, err)

assert.Equal(t, 500, resp.StatusCode)
assert.Contains(t, logBuffer.String(), "failed to insert generated request")
assert.Contains(t, logBuffer.String(), "db connection failed")
})
}

func TestRequestHandler_GetRequestByCursor(t *testing.T) {
Expand Down
Loading