diff --git a/cypress/e2e/training-tasks.cy.js b/cypress/e2e/training-tasks.cy.js index e7baada..e512a2f 100644 --- a/cypress/e2e/training-tasks.cy.js +++ b/cypress/e2e/training-tasks.cy.js @@ -20,7 +20,7 @@ describe('Training Tasks Management', () => { const testName = generateUniqueName('tm'); cy.get('input[name="name"]').type(testName); - cy.get('select[name="trainingDatasetId"]').select('LHC24b1b'); + cy.get('select[name="trainingDatasetId"]').select('1'); cy.get('button').click(); let tmObject = cy.contains('tr', testName) @@ -60,7 +60,7 @@ describe('Training Tasks Management', () => { cy.get('@alreadyExisting').then(alreadyExisting => { cy.get('input[name="name"]').type(alreadyExisting); }) - cy.get('select[name="trainingDatasetId"]').select('LHC24b1b'); + cy.get('select[name="trainingDatasetId"]').select('1'); cy.get('button').click(); cy.get('#errors').invoke('text').should('eq', 'Name must be unique\n') diff --git a/internal/db/migrate/migrate.go b/internal/db/migrate/migrate.go index 65ce2dd..ae2093e 100644 --- a/internal/db/migrate/migrate.go +++ b/internal/db/migrate/migrate.go @@ -33,18 +33,87 @@ func SeedDB(db *gorm.DB) error { return err } - aods := []jalien.AODFile{{ - Name: "AO2D.root", - Path: "/alice/sim/2024/LHC24b1b/0/567454/AOD/002/AO2D.root", - Size: 2312421213, - LHCPeriod: "LHC24b1b", - RunNumber: 567454, - AODNumber: 2, - }} + c1Aods := []jalien.AODFile{ + { + Name: "AO2D.root", + Path: "/alice/sim/2023/LHC23c1/302004/AOD/010/AO2D.root", + Size: 3266476446, + LHCPeriod: "LHC23c1", + RunNumber: 302004, + AODNumber: 10, + }, + { + Name: "AO2D.root", + Path: "/alice/sim/2023/LHC23c1/302004/AOD/011/AO2D.root", + Size: 3239114872, + LHCPeriod: "LHC23c1", + RunNumber: 302004, + AODNumber: 11, + }, + { + Name: "AO2D.root", + Path: "/alice/sim/2023/LHC23c1/302004/AOD/013/AO2D.root", + Size: 3260265579, + LHCPeriod: "LHC23c1", + RunNumber: 302004, + AODNumber: 13, + }, + } + + mixedPeriodsAods := []jalien.AODFile{ + { + Name: "AO2D.root", + Path: "/alice/sim/2023/LHC23c1/302004/AOD/010/AO2D.root", + Size: 3266476446, + LHCPeriod: "LHC23c1", + RunNumber: 302004, + AODNumber: 10, + }, + { + Name: "AO2D.root", + Path: "/alice/sim/2023/LHC23c1/302004/AOD/011/AO2D.root", + Size: 3239114872, + LHCPeriod: "LHC23c1", + RunNumber: 302004, + AODNumber: 11, + }, + { + Name: "AO2D.root", + Path: "/alice/sim/2023/LHC23c1/302004/AOD/013/AO2D.root", + Size: 3260265579, + LHCPeriod: "LHC23c1", + RunNumber: 302004, + AODNumber: 13, + }, + { + Name: "AO2D.root", + Path: "/alice/sim/2023/LHC23e1/302002/AOD/013/AO2D.root", + Size: 35403114, + LHCPeriod: "LHC23e1", + RunNumber: 302002, + AODNumber: 13, + }, + { + Name: "AO2D.root", + Path: "/alice/sim/2023/LHC23e1/302002/AOD/024/AO2D.root", + Size: 97906832, + LHCPeriod: "LHC23e1", + RunNumber: 302002, + AODNumber: 24, + }, + { + Name: "AO2D.root", + Path: "/alice/sim/2023/LHC23e1/302002/AOD/030/AO2D.root", + Size: 175726295, + LHCPeriod: "LHC23e1", + RunNumber: 302002, + AODNumber: 30, + }, + } trainingDatasets := []models.TrainingDataset{ - {Name: "LHC24b1b", AODFiles: aods, UserId: users[0].ID}, - {Name: "LHC24b1b2", AODFiles: aods, UserId: users[1].ID}, + {Name: "Mixed periods 2023", AODFiles: mixedPeriodsAods, UserId: users[0].ID}, + {Name: "LHC23c1", AODFiles: c1Aods, UserId: users[1].ID}, } if err := db.Save(trainingDatasets).Error; err != nil { diff --git a/internal/handler/training_task_handler.go b/internal/handler/training_task_handler.go index 619947c..8d3198f 100644 --- a/internal/handler/training_task_handler.go +++ b/internal/handler/training_task_handler.go @@ -172,10 +172,10 @@ func (h *TrainingTaskHandler) Create(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusCreated) } -func InitTrainingTaskRoutes(mux *http.ServeMux, env *environment.Env, ccdbService service.ICCDBService, fileService service.IFileService, nnArch service.INNArchService) { +func InitTrainingTaskRoutes(mux *http.ServeMux, env *environment.Env, ccdbService service.ICCDBService, jalienService service.IJAliEnService, fileService service.IFileService, nnArch service.INNArchService) { prefix := "training-tasks" - ttService := service.NewTrainingTaskService(env.RepositoryContext, ccdbService, fileService, nnArch) + ttService := service.NewTrainingTaskService(env.RepositoryContext, ccdbService, jalienService, fileService, nnArch) tjh := NewTrainingTaskHandler(env, ttService) authMw := middleware.NewAuthMw(env.IAuthService, true) diff --git a/internal/jalien/commands.go b/internal/jalien/commands.go index 6d5c637..9204a3d 100644 --- a/internal/jalien/commands.go +++ b/internal/jalien/commands.go @@ -176,7 +176,7 @@ func ListAndParseDirectory(path string) (*DirectoryContents, error) { if lineParsed.IsDir { dirContents.Subdirs = append(dirContents.Subdirs, Dir{ - Name: lineParsed.Name, + Name: strings.TrimSuffix(lineParsed.Name, "/"), Path: linePath, }) } else if lineParsed.Name == aodFilename { diff --git a/internal/jalien/commands_test.go b/internal/jalien/commands_test.go index c2d3810..6e25f4e 100644 --- a/internal/jalien/commands_test.go +++ b/internal/jalien/commands_test.go @@ -29,7 +29,7 @@ func TestParseLongFormat(t *testing.T) { }, { name: "Valid directory line", - input: "drwxr-xr-x user group 4096 Feb 15 08:00 somedir", + input: "drwxr-xr-x user group 4096 Feb 15 08:00 somedir/", want: &longFormatParsed{ Permissions: "drwxr-xr-x", Owner: "user", @@ -38,7 +38,7 @@ func TestParseLongFormat(t *testing.T) { Month: "Feb", Day: "15", Time: "08:00", - Name: "somedir", + Name: "somedir/", IsDir: true, }, wantErr: false, diff --git a/internal/router.go b/internal/router.go index 2a3e556..a1ce4f2 100644 --- a/internal/router.go +++ b/internal/router.go @@ -40,7 +40,7 @@ func NewRouter(cfg *config.Config, repoContext *repository.RepositoryContext, au // handlers' routes handler.InitLandingRoutes(mux, env) handler.InitTrainingDatasetRoutes(mux, env, jalienService) - handler.InitTrainingTaskRoutes(mux, env, ccdbService, fileService, nnArch) + handler.InitTrainingTaskRoutes(mux, env, ccdbService, jalienService, fileService, nnArch) handler.InitTrainingMachineRoutes(mux, env, hasher) handler.InitQueueRoutes(mux, env, fileService, hasher) diff --git a/internal/service/training_task_service.go b/internal/service/training_task_service.go index d951d3a..febcf5f 100644 --- a/internal/service/training_task_service.go +++ b/internal/service/training_task_service.go @@ -4,10 +4,14 @@ import ( "errors" "fmt" "log" + "regexp" + "slices" + "strconv" "github.com/mytkom/AliceTraINT/internal/ccdb" "github.com/mytkom/AliceTraINT/internal/db/models" "github.com/mytkom/AliceTraINT/internal/db/repository" + "github.com/mytkom/AliceTraINT/internal/jalien" "gorm.io/gorm" ) @@ -32,17 +36,21 @@ type ITrainingTaskService interface { type TrainingTaskService struct { *repository.RepositoryContext - CCDBService ICCDBService - FileService IFileService - NNArch INNArchService + CCDBService ICCDBService + JAliEnService IJAliEnService + FileService IFileService + NNArch INNArchService + PeriodRegex *regexp.Regexp } -func NewTrainingTaskService(repo *repository.RepositoryContext, ccdbService ICCDBService, fileService IFileService, nnArch INNArchService) *TrainingTaskService { +func NewTrainingTaskService(repo *repository.RepositoryContext, ccdbService ICCDBService, jalienService IJAliEnService, fileService IFileService, nnArch INNArchService) *TrainingTaskService { return &TrainingTaskService{ RepositoryContext: repo, CCDBService: ccdbService, + JAliEnService: jalienService, FileService: fileService, NNArch: nnArch, + PeriodRegex: regexp.MustCompile(`(/alice/sim/\d{4}/LHC[a-z0-9A-Z\_].+(/\d+)?)/\d+/AOD/\d+`), } } @@ -152,14 +160,41 @@ func (s *TrainingTaskService) UploadOnnxResults(id uint) error { } } - smallestRun, greatestRun, err := s.findRunNumberRange(trainingTask) + lhcPeriods, err := s.getLHCPeriods(trainingTask) if err != nil { return err } - firstRunInfo, lastRunInfo, err := s.getRunInfoRange(smallestRun, greatestRun) - if err != nil { - return err + var minSOR, maxEOR uint64 + initialized := false + + for i, period := range lhcPeriods { + log.Printf("%d: Name=\"%s\" DirPath=\"%s\"", i, period.Name, period.DirPath) + + dirContents, err := s.JAliEnService.ListAndParseDirectory(period.DirPath) + if err != nil { + return err + } + + smallestRun, greatestRun, err := s.findRunNumberRange(dirContents.Subdirs) + if err != nil { + return err + } + + firstRunInfo, lastRunInfo, err := s.getRunInfoRange(smallestRun, greatestRun) + if err != nil { + return err + } + + if !initialized || firstRunInfo.SOR < minSOR { + minSOR = firstRunInfo.SOR + } + + if !initialized || lastRunInfo.EOR > maxEOR { + maxEOR = lastRunInfo.EOR + } + + initialized = true } mappedOnnxFiles, err := s.filterOnnxFiles(trainingTask.ID) @@ -168,7 +203,7 @@ func (s *TrainingTaskService) UploadOnnxResults(id uint) error { } for uploadName, file := range mappedOnnxFiles { - if err := s.uploadOnnxFile(firstRunInfo.SOR, lastRunInfo.EOR, file, uploadName); err != nil { + if err := s.uploadOnnxFile(minSOR, maxEOR, file, uploadName); err != nil { return err } } @@ -181,16 +216,65 @@ func (s *TrainingTaskService) UploadOnnxResults(id uint) error { return nil } -func (s *TrainingTaskService) findRunNumberRange(task *models.TrainingTask) (uint64, uint64, error) { - var smallestRun, greatestRun uint64 +type lhcPeriod struct { + Name string + DirPath string +} + +func (s *TrainingTaskService) periodPathFromAODPath(aodPath string) (string, error) { + matches := s.PeriodRegex.FindStringSubmatch(aodPath) + + if len(matches) != 3 { + return "", errors.New("unexpected AOD path format, cannot correctly match") + } + + return matches[1], nil +} + +func (s *TrainingTaskService) getLHCPeriods(task *models.TrainingTask) ([]lhcPeriod, error) { + var periods []lhcPeriod initialized := false for _, aod := range task.TrainingDataset.AODFiles { - if !initialized || aod.RunNumber < smallestRun { - smallestRun = aod.RunNumber + if !slices.ContainsFunc(periods, func(p lhcPeriod) bool { + return p.Name == aod.LHCPeriod + }) { + periodPath, err := s.periodPathFromAODPath(aod.Path) + if err != nil { + return nil, err + } + + periods = append(periods, lhcPeriod{ + Name: aod.LHCPeriod, + DirPath: periodPath, + }) + } + initialized = true + } + + if !initialized { + return nil, errors.New("unexpected behaviour: empty training dataset") + } + + return periods, nil +} + +func (s *TrainingTaskService) findRunNumberRange(subdirs []jalien.Dir) (uint64, uint64, error) { + var smallestRun, greatestRun uint64 + initialized := false + + for _, dir := range subdirs { + runNumber, err := strconv.ParseUint(dir.Name, 10, 64) + if err != nil { + log.Println(err.Error()) + continue + } + + if !initialized || runNumber < smallestRun { + smallestRun = runNumber } - if !initialized || aod.RunNumber > greatestRun { - greatestRun = aod.RunNumber + if !initialized || runNumber > greatestRun { + greatestRun = runNumber } initialized = true } diff --git a/test/integration/common.go b/test/integration/common.go index 0b088ab..fcad8b1 100644 --- a/test/integration/common.go +++ b/test/integration/common.go @@ -74,7 +74,7 @@ func mockRouter(db *gorm.DB, cfg *config.Config) *IntegrationTestUtils { // handlers' routes handler.InitLandingRoutes(mux, env) handler.InitTrainingDatasetRoutes(mux, env, jalienService) - handler.InitTrainingTaskRoutes(mux, env, ccdbService, fileService, nnArch) + handler.InitTrainingTaskRoutes(mux, env, ccdbService, jalienService, fileService, nnArch) handler.InitTrainingMachineRoutes(mux, env, hasher) handler.InitQueueRoutes(mux, env, fileService, hasher) diff --git a/test/integration/training_task_handler_test.go b/test/integration/training_task_handler_test.go index d354ccc..c439b9b 100644 --- a/test/integration/training_task_handler_test.go +++ b/test/integration/training_task_handler_test.go @@ -293,13 +293,21 @@ func prepareUploadToCCDB(t *testing.T, ut *IntegrationTestUtils, user *models.Us assert.NoError(t, ut.TrainingTask.Create(trainingTask)) now := uint64(time.Now().UTC().UnixMilli()) - ut.MockedServices.CCDB.On("GetRunInformation", td.AODFiles[0].RunNumber).Return(&ccdb.RunInformation{ - RunNumber: td.AODFiles[0].RunNumber, + ut.MockedServices.JAliEn.On("ListAndParseDirectory", "/alice/sim/2024/LHC24b1b/0").Return(&jalien.DirectoryContents{ + Subdirs: []jalien.Dir{ + {Name: "560000", Path: "/alice/sim/2024/LHC24b1b/0/560000"}, + {Name: "567454", Path: "/alice/sim/2024/LHC24b1b/0/567454"}, + {Name: "567458", Path: "/alice/sim/2024/LHC24b1b/0/567458"}, + {Name: "570000", Path: "/alice/sim/2024/LHC24b1b/0/570000"}, + }, + }, nil) + ut.MockedServices.CCDB.On("GetRunInformation", uint64(560000)).Return(&ccdb.RunInformation{ + RunNumber: 560000, SOR: now - 10000, EOR: now, }, nil) - ut.MockedServices.CCDB.On("GetRunInformation", td.AODFiles[1].RunNumber).Return(&ccdb.RunInformation{ - RunNumber: td.AODFiles[1].RunNumber, + ut.MockedServices.CCDB.On("GetRunInformation", uint64(570000)).Return(&ccdb.RunInformation{ + RunNumber: 570000, SOR: now, EOR: now + 10000, }, nil) diff --git a/test/service/queue_service_test.go b/test/service/queue_service_test.go index 0e18136..592ada7 100644 --- a/test/service/queue_service_test.go +++ b/test/service/queue_service_test.go @@ -12,7 +12,15 @@ import ( "gorm.io/gorm" ) -func newQueueService() (*service.QueueService, *service.MockHasher, *repository.MockTrainingTaskRepository, *repository.MockTrainingMachineRepository, *repository.MockTrainingTaskResultRepository, *service.MockFileService) { +type queueServiceTestUtils struct { + TTRepo *repository.MockTrainingTaskRepository + TTRRepo *repository.MockTrainingTaskResultRepository + TMRepo *repository.MockTrainingMachineRepository + FileService *service.MockFileService + Hasher *service.MockHasher +} + +func newQueueService() (*service.QueueService, *queueServiceTestUtils) { mockHasher := service.NewMockHasher() mockTaskRepo := repository.NewMockTrainingTaskRepository() mockMachineRepo := repository.NewMockTrainingMachineRepository() @@ -25,18 +33,24 @@ func newQueueService() (*service.QueueService, *service.MockHasher, *repository. TrainingTaskResult: mockTaskResultRepo, } - return service.NewQueueService(mockFileService, repoContext, mockHasher), mockHasher, mockTaskRepo, mockMachineRepo, mockTaskResultRepo, mockFileService + return service.NewQueueService(mockFileService, repoContext, mockHasher), &queueServiceTestUtils{ + TTRepo: mockTaskRepo, + TTRRepo: mockTaskResultRepo, + TMRepo: mockMachineRepo, + FileService: mockFileService, + Hasher: mockHasher, + } } func TestQueueService_UpdateTrainingTaskStatus_Success(t *testing.T) { // Arrange - queueService, _, mockTaskRepo, _, _, _ := newQueueService() + queueService, ut := newQueueService() taskID := uint(1) newStatus := models.Training mockTask := &models.TrainingTask{Model: gorm.Model{ID: taskID}, Status: models.Queued} - mockTaskRepo.On("GetByID", taskID).Return(mockTask, nil) - mockTaskRepo.On("Update", mock.AnythingOfType("*models.TrainingTask")).Return(nil) + ut.TTRepo.On("GetByID", taskID).Return(mockTask, nil) + ut.TTRepo.On("Update", mock.AnythingOfType("*models.TrainingTask")).Return(nil) // Act err := queueService.UpdateTrainingTaskStatus(taskID, newStatus) @@ -44,17 +58,17 @@ func TestQueueService_UpdateTrainingTaskStatus_Success(t *testing.T) { // Assert assert.NoError(t, err) assert.Equal(t, newStatus, mockTask.Status) - mockTaskRepo.AssertCalled(t, "GetByID", taskID) - mockTaskRepo.AssertCalled(t, "Update", mockTask) + ut.TTRepo.AssertCalled(t, "GetByID", taskID) + ut.TTRepo.AssertCalled(t, "Update", mockTask) } func TestQueueService_UpdateTrainingTaskStatus_TaskNotFound(t *testing.T) { // Arrange - queueService, _, mockTaskRepo, _, _, _ := newQueueService() + queueService, ut := newQueueService() taskID := uint(1) newStatus := models.Training - mockTaskRepo.On("GetByID", taskID).Return(nil, errors.New("task not found")) + ut.TTRepo.On("GetByID", taskID).Return(nil, errors.New("task not found")) // Act err := queueService.UpdateTrainingTaskStatus(taskID, newStatus) @@ -62,21 +76,21 @@ func TestQueueService_UpdateTrainingTaskStatus_TaskNotFound(t *testing.T) { // Assert assert.Error(t, err) assert.EqualError(t, err, "task not found") - mockTaskRepo.AssertCalled(t, "GetByID", taskID) - mockTaskRepo.AssertNotCalled(t, "Update", mock.Anything) + ut.TTRepo.AssertCalled(t, "GetByID", taskID) + ut.TTRepo.AssertNotCalled(t, "Update", mock.Anything) } func TestQueueService_AuthorizeTrainingMachine_Success(t *testing.T) { // Arrange - queueService, mockHasher, _, mockMachineRepo, _, _ := newQueueService() + queueService, ut := newQueueService() tmID := uint(1) secretID := "valid_secret" hashedSecret := "hashed_secret" trainingMachine := &models.TrainingMachine{SecretKeyHashed: hashedSecret, Model: gorm.Model{ID: tmID}} - mockMachineRepo.On("GetByID", tmID).Return(trainingMachine, nil) - mockMachineRepo.On("Update", trainingMachine).Return(nil) - mockHasher.On("VerifyKey", secretID, hashedSecret).Return(true, nil) + ut.TMRepo.On("GetByID", tmID).Return(trainingMachine, nil) + ut.TMRepo.On("Update", trainingMachine).Return(nil) + ut.Hasher.On("VerifyKey", secretID, hashedSecret).Return(true, nil) // Act result, err := queueService.AuthorizeTrainingMachine(secretID, tmID) @@ -85,20 +99,20 @@ func TestQueueService_AuthorizeTrainingMachine_Success(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, trainingMachine, result) - mockMachineRepo.AssertCalled(t, "GetByID", tmID) - mockHasher.AssertCalled(t, "VerifyKey", secretID, hashedSecret) + ut.TMRepo.AssertCalled(t, "GetByID", tmID) + ut.Hasher.AssertCalled(t, "VerifyKey", secretID, hashedSecret) } func TestQueueService_AuthorizeTrainingMachine_Failure(t *testing.T) { // Arrange - queueService, mockHasher, _, mockMachineRepo, _, _ := newQueueService() + queueService, ut := newQueueService() tmID := uint(1) secretID := "invalid_secret" hashedSecret := "hashed_secret" trainingMachine := &models.TrainingMachine{SecretKeyHashed: hashedSecret, Model: gorm.Model{ID: tmID}} - mockMachineRepo.On("GetByID", tmID).Return(trainingMachine, nil) - mockHasher.On("VerifyKey", secretID, hashedSecret).Return(false, nil) + ut.TMRepo.On("GetByID", tmID).Return(trainingMachine, nil) + ut.Hasher.On("VerifyKey", secretID, hashedSecret).Return(false, nil) // Act result, err := queueService.AuthorizeTrainingMachine(secretID, tmID) @@ -107,18 +121,18 @@ func TestQueueService_AuthorizeTrainingMachine_Failure(t *testing.T) { assert.Error(t, err) assert.Nil(t, result) assert.EqualError(t, err, "authorization failure") - mockMachineRepo.AssertCalled(t, "GetByID", tmID) - mockHasher.AssertCalled(t, "VerifyKey", secretID, hashedSecret) + ut.TMRepo.AssertCalled(t, "GetByID", tmID) + ut.Hasher.AssertCalled(t, "VerifyKey", secretID, hashedSecret) } func TestQueueService_AssignTaskToMachine_Success(t *testing.T) { // Arrange - queueService, _, mockTaskRepo, _, _, _ := newQueueService() + queueService, ut := newQueueService() tmID := uint(1) mockTask := &models.TrainingTask{Model: gorm.Model{ID: 1}, Status: models.Queued} - mockTaskRepo.On("GetFirstQueued").Return(mockTask, nil) - mockTaskRepo.On("Update", mockTask).Return(nil) + ut.TTRepo.On("GetFirstQueued").Return(mockTask, nil) + ut.TTRepo.On("Update", mockTask).Return(nil) // Act task, err := queueService.AssignTaskToMachine(tmID) @@ -128,15 +142,15 @@ func TestQueueService_AssignTaskToMachine_Success(t *testing.T) { assert.Equal(t, task, mockTask) assert.Equal(t, tmID, *mockTask.TrainingMachineId) assert.Equal(t, models.Training, mockTask.Status) - mockTaskRepo.AssertCalled(t, "GetFirstQueued") - mockTaskRepo.AssertCalled(t, "Update", mockTask) + ut.TTRepo.AssertCalled(t, "GetFirstQueued") + ut.TTRepo.AssertCalled(t, "Update", mockTask) } func TestQueueService_AssignTaskToMachine_NoTask(t *testing.T) { // Arrange - queueService, _, mockTaskRepo, _, _, _ := newQueueService() + queueService, ut := newQueueService() - mockTaskRepo.On("GetFirstQueued").Return(nil, errors.New("no task to run")) + ut.TTRepo.On("GetFirstQueued").Return(nil, errors.New("no task to run")) // Act task, err := queueService.AssignTaskToMachine(1) @@ -145,18 +159,18 @@ func TestQueueService_AssignTaskToMachine_NoTask(t *testing.T) { assert.Error(t, err) assert.Nil(t, task) assert.EqualError(t, err, "no task to run") - mockTaskRepo.AssertCalled(t, "GetFirstQueued") - mockTaskRepo.AssertNotCalled(t, "Update", mock.Anything) + ut.TTRepo.AssertCalled(t, "GetFirstQueued") + ut.TTRepo.AssertNotCalled(t, "Update", mock.Anything) } func TestQueueService_AssignTaskToMachine_UpdateError(t *testing.T) { // Arrange - queueService, _, mockTaskRepo, _, _, _ := newQueueService() + queueService, ut := newQueueService() tmID := uint(1) mockTask := &models.TrainingTask{Model: gorm.Model{ID: 1}, Status: models.Queued} - mockTaskRepo.On("GetFirstQueued").Return(mockTask, nil) - mockTaskRepo.On("Update", mockTask).Return(errors.New("update failed")) + ut.TTRepo.On("GetFirstQueued").Return(mockTask, nil) + ut.TTRepo.On("Update", mockTask).Return(errors.New("update failed")) // Act task, err := queueService.AssignTaskToMachine(tmID) @@ -165,13 +179,13 @@ func TestQueueService_AssignTaskToMachine_UpdateError(t *testing.T) { assert.Error(t, err) assert.Nil(t, task) assert.EqualError(t, err, "cannot assign task to machine: update failed") - mockTaskRepo.AssertCalled(t, "GetFirstQueued") - mockTaskRepo.AssertCalled(t, "Update", mockTask) + ut.TTRepo.AssertCalled(t, "GetFirstQueued") + ut.TTRepo.AssertCalled(t, "Update", mockTask) } func TestQueueService_CreateTrainingTaskResult_Success(t *testing.T) { // Arrange - queueService, _, mockTaskRepo, _, mockTaskResultRepo, mockFileService := newQueueService() + queueService, ut := newQueueService() taskID := uint(1) fileName := "test-file.txt" description := "Test description" @@ -186,9 +200,9 @@ func TestQueueService_CreateTrainingTaskResult_Success(t *testing.T) { File: *mockFileModel, } - mockTaskRepo.On("GetByID", taskID).Return(mockTask, nil) - mockFileService.On("SaveFile", mock.Anything, mock.Anything).Return(mockFileModel, nil) - mockTaskResultRepo.On("Create", mock.Anything).Return(nil) + ut.TTRepo.On("GetByID", taskID).Return(mockTask, nil) + ut.FileService.On("SaveFile", mock.Anything, mock.Anything).Return(mockFileModel, nil) + ut.TTRRepo.On("Create", mock.Anything).Return(nil) // Act result, err := queueService.CreateTrainingTaskResult(taskID, nil, nil, fileName, description, fileType) @@ -197,17 +211,17 @@ func TestQueueService_CreateTrainingTaskResult_Success(t *testing.T) { assert.NoError(t, err) assert.Equal(t, mockResult.Name, result.Name) assert.Equal(t, mockResult.Description, result.Description) - mockTaskRepo.AssertCalled(t, "GetByID", taskID) - mockFileService.AssertCalled(t, "SaveFile", mock.Anything, mock.Anything) - mockTaskResultRepo.AssertCalled(t, "Create", mock.Anything) + ut.TTRepo.AssertCalled(t, "GetByID", taskID) + ut.FileService.AssertCalled(t, "SaveFile", mock.Anything, mock.Anything) + ut.TTRRepo.AssertCalled(t, "Create", mock.Anything) } func TestQueueService_CreateTrainingTaskResult_TaskNotFound(t *testing.T) { // Arrange - queueService, _, mockTaskRepo, _, mockTaskResultRepo, mockFileService := newQueueService() + queueService, ut := newQueueService() taskID := uint(1) - mockTaskRepo.On("GetByID", taskID).Return(nil, errors.New("training task does not exist")) + ut.TTRepo.On("GetByID", taskID).Return(nil, errors.New("training task does not exist")) // Act result, err := queueService.CreateTrainingTaskResult(taskID, nil, nil, "test", "desc", "1") @@ -216,19 +230,19 @@ func TestQueueService_CreateTrainingTaskResult_TaskNotFound(t *testing.T) { assert.Error(t, err) assert.Nil(t, result) assert.EqualError(t, err, "training task does not exist") - mockTaskRepo.AssertCalled(t, "GetByID", taskID) - mockFileService.AssertNotCalled(t, "SaveFile", mock.Anything, mock.Anything) - mockTaskResultRepo.AssertNotCalled(t, "Create", mock.Anything) + ut.TTRepo.AssertCalled(t, "GetByID", taskID) + ut.FileService.AssertNotCalled(t, "SaveFile", mock.Anything, mock.Anything) + ut.TTRRepo.AssertNotCalled(t, "Create", mock.Anything) } func TestQueueService_CreateTrainingTaskResult_FileSaveError(t *testing.T) { // Arrange - queueService, _, mockTaskRepo, _, mockTaskResultRepo, mockFileService := newQueueService() + queueService, ut := newQueueService() taskID := uint(1) mockTask := &models.TrainingTask{Model: gorm.Model{ID: taskID}} - mockTaskRepo.On("GetByID", taskID).Return(mockTask, nil) - mockFileService.On("SaveFile", mock.Anything, mock.Anything).Return(nil, errors.New("file save error")) + ut.TTRepo.On("GetByID", taskID).Return(mockTask, nil) + ut.FileService.On("SaveFile", mock.Anything, mock.Anything).Return(nil, errors.New("file save error")) // Act result, err := queueService.CreateTrainingTaskResult(taskID, nil, nil, "test", "desc", "1") @@ -237,7 +251,7 @@ func TestQueueService_CreateTrainingTaskResult_FileSaveError(t *testing.T) { assert.Error(t, err) assert.Nil(t, result) assert.EqualError(t, err, "error saving file: file save error") - mockTaskRepo.AssertCalled(t, "GetByID", taskID) - mockFileService.AssertCalled(t, "SaveFile", mock.Anything, mock.Anything) - mockTaskResultRepo.AssertNotCalled(t, "Create", mock.Anything) + ut.TTRepo.AssertCalled(t, "GetByID", taskID) + ut.FileService.AssertCalled(t, "SaveFile", mock.Anything, mock.Anything) + ut.TTRRepo.AssertNotCalled(t, "Create", mock.Anything) } diff --git a/test/service/training_dataset_service_test.go b/test/service/training_dataset_service_test.go index c0479b5..e768be8 100644 --- a/test/service/training_dataset_service_test.go +++ b/test/service/training_dataset_service_test.go @@ -10,25 +10,32 @@ import ( "github.com/stretchr/testify/assert" ) -func newTrainingDatasetService() (*repository.MockTrainingDatasetRepository, *service.MockJAliEnService, *service.TrainingDatasetService) { +type trainingDatasetServiceTestUtils struct { + TDRepo *repository.MockTrainingDatasetRepository + JAliEnService *service.MockJAliEnService +} + +func newTrainingDatasetService() (*service.TrainingDatasetService, *trainingDatasetServiceTestUtils) { tdRepo := repository.NewMockTrainingDatasetRepository() jalienService := service.NewMockJAliEnService() - tdService := service.NewTrainingDatasetService(&repository.RepositoryContext{ - TrainingDataset: tdRepo, - }, jalienService) - return tdRepo, jalienService, tdService + return service.NewTrainingDatasetService(&repository.RepositoryContext{ + TrainingDataset: tdRepo, + }, jalienService), &trainingDatasetServiceTestUtils{ + TDRepo: tdRepo, + JAliEnService: jalienService, + } } func TestTrainingDatasetService_GetAll_Global(t *testing.T) { // Arrange - tdRepo, _, tdService := newTrainingDatasetService() + tdService, ut := newTrainingDatasetService() userId := uint(1) tds := []models.TrainingDataset{ {Name: "LHC24b1b", UserId: userId, AODFiles: []jalien.AODFile{}}, {Name: "LHC24b1b2", UserId: userId, AODFiles: []jalien.AODFile{}}, } - tdRepo.On("GetAll").Return(tds, nil) + ut.TDRepo.On("GetAll").Return(tds, nil) // Act datasets, err := tdService.GetAll(userId, false) @@ -42,14 +49,14 @@ func TestTrainingDatasetService_GetAll_Global(t *testing.T) { func TestTrainingDatasetService_GetAll_UserScoped(t *testing.T) { // Arrange - tdRepo, _, tdService := newTrainingDatasetService() + tdService, ut := newTrainingDatasetService() userId := uint(1) tds := []models.TrainingDataset{ {Name: "LHC24b1b", UserId: userId, AODFiles: []jalien.AODFile{}}, {Name: "LHC24b1b2", UserId: userId, AODFiles: []jalien.AODFile{}}, } - tdRepo.On("GetAllUser", userId).Return(tds, nil) + ut.TDRepo.On("GetAllUser", userId).Return(tds, nil) // Act datasets, err := tdService.GetAll(userId, true) diff --git a/test/service/training_machine_service_test.go b/test/service/training_machine_service_test.go index 5f16e95..6579517 100644 --- a/test/service/training_machine_service_test.go +++ b/test/service/training_machine_service_test.go @@ -11,32 +11,40 @@ import ( "github.com/stretchr/testify/mock" ) -func newTrainingMachineService() (*service.TrainingMachineService, *repository.MockTrainingMachineRepository, *service.MockHasher) { +type trainingMachineServiceTestUtils struct { + TMRepo *repository.MockTrainingMachineRepository + Hasher *service.MockHasher +} + +func newTrainingMachineService() (*service.TrainingMachineService, *trainingMachineServiceTestUtils) { tmRepo := repository.NewMockTrainingMachineRepository() hasher := service.NewMockHasher() return service.NewTrainingMachineService(&repository.RepositoryContext{ - TrainingMachine: tmRepo, - }, hasher), tmRepo, hasher + TrainingMachine: tmRepo, + }, hasher), &trainingMachineServiceTestUtils{ + TMRepo: tmRepo, + Hasher: hasher, + } } func TestTrainingMachineService_GetAll_Global(t *testing.T) { // Arrange - tmService, tmRepo, _ := newTrainingMachineService() + tmService, ut := newTrainingMachineService() userId := uint(1) tms := []models.TrainingMachine{ {Name: "awm1", UserId: userId, SecretKeyHashed: "secret1", LastActivityAt: time.Now()}, {Name: "awm2", UserId: userId, SecretKeyHashed: "secret2", LastActivityAt: time.Now().Add(5 * time.Hour)}, } - tmRepo.On("GetAll").Return(tms, nil) + ut.TMRepo.On("GetAll").Return(tms, nil) // Act machines, err := tmService.GetAll(userId, false) // Assert assert.NoError(t, err) - tmRepo.AssertCalled(t, "GetAll") - tmRepo.AssertNotCalled(t, "GetAllUser", mock.Anything) + ut.TMRepo.AssertCalled(t, "GetAll") + ut.TMRepo.AssertNotCalled(t, "GetAllUser", mock.Anything) assert.Equal(t, 2, len(machines)) assert.Equal(t, tms[0].Name, machines[0].Name) assert.Equal(t, tms[1].Name, machines[1].Name) @@ -44,21 +52,21 @@ func TestTrainingMachineService_GetAll_Global(t *testing.T) { func TestTrainingMachineService_GetAll_UserScoped(t *testing.T) { // Arrange - tmService, tmRepo, _ := newTrainingMachineService() + tmService, ut := newTrainingMachineService() userId := uint(1) tms := []models.TrainingMachine{ {Name: "awm1", UserId: userId, SecretKeyHashed: "secret1", LastActivityAt: time.Now()}, {Name: "awm2", UserId: userId, SecretKeyHashed: "secret2", LastActivityAt: time.Now().Add(5 * time.Hour)}, } - tmRepo.On("GetAllUser", userId).Return(tms, nil) + ut.TMRepo.On("GetAllUser", userId).Return(tms, nil) // Act machines, err := tmService.GetAll(userId, true) // Assert assert.NoError(t, err) - tmRepo.AssertNotCalled(t, "GetAll") - tmRepo.AssertCalled(t, "GetAllUser", userId) + ut.TMRepo.AssertNotCalled(t, "GetAll") + ut.TMRepo.AssertCalled(t, "GetAllUser", userId) assert.Equal(t, 2, len(machines)) assert.Equal(t, tms[0].Name, machines[0].Name) assert.Equal(t, tms[1].Name, machines[1].Name) @@ -66,25 +74,25 @@ func TestTrainingMachineService_GetAll_UserScoped(t *testing.T) { func TestTrainingMachineService_Create(t *testing.T) { // Arrange - tmService, tmRepo, hasher := newTrainingMachineService() + tmService, ut := newTrainingMachineService() userId := uint(1) tm := models.TrainingMachine{ Name: "awm1", UserId: userId, LastActivityAt: time.Now(), } - tmRepo.On("Create", mock.Anything).Return(nil) - hasher.On("GenerateKey").Return("secret", nil) - hasher.On("HashKey", "secret").Return("secretHashed", nil) + ut.TMRepo.On("Create", mock.Anything).Return(nil) + ut.Hasher.On("GenerateKey").Return("secret", nil) + ut.Hasher.On("HashKey", "secret").Return("secretHashed", nil) // Act secretKey, err := tmService.Create(&tm) // Assert assert.NoError(t, err) - hasher.AssertCalled(t, "GenerateKey") - hasher.AssertCalled(t, "HashKey", "secret") - tmRepo.AssertCalled(t, "Create", &tm) + ut.Hasher.AssertCalled(t, "GenerateKey") + ut.Hasher.AssertCalled(t, "HashKey", "secret") + ut.TMRepo.AssertCalled(t, "Create", &tm) assert.Equal(t, "secret", secretKey) assert.Equal(t, "secretHashed", tm.SecretKeyHashed) } diff --git a/test/service/training_task_service_test.go b/test/service/training_task_service_test.go index d746516..7bbd690 100644 --- a/test/service/training_task_service_test.go +++ b/test/service/training_task_service_test.go @@ -17,10 +17,21 @@ import ( "gorm.io/gorm" ) -func newTrainingTaskService() (*service.TrainingTaskService, *repository.MockTrainingTaskRepository, *repository.MockTrainingDatasetRepository, *repository.MockTrainingTaskResultRepository, *service.MockCCDBService, *service.MockFileService, *service.NNArchServiceInMemory) { +type trainingTaskServiceTestUtils struct { + TTRepo *repository.MockTrainingTaskRepository + TDRepo *repository.MockTrainingDatasetRepository + TTRRepo *repository.MockTrainingTaskResultRepository + CCDBService *service.MockCCDBService + JAliEnService *service.MockJAliEnService + FileService *service.MockFileService + NNArch *service.NNArchServiceInMemory +} + +func newTrainingTaskService() (*service.TrainingTaskService, *trainingTaskServiceTestUtils) { ttRepo := repository.NewMockTrainingTaskRepository() tdRepo := repository.NewMockTrainingDatasetRepository() ttrRepo := repository.NewMockTrainingTaskResultRepository() + jalienService := service.NewMockJAliEnService() ccdbService := service.NewMockCCDBService() fileService := service.NewMockFileService() nnArch := service.NewNNArchServiceInMemory(&service.NNFieldConfigs{ @@ -40,15 +51,23 @@ func newTrainingTaskService() (*service.TrainingTaskService, *repository.MockTra }) return service.NewTrainingTaskService(&repository.RepositoryContext{ - TrainingTask: ttRepo, - TrainingDataset: tdRepo, - TrainingTaskResult: ttrRepo, - }, ccdbService, fileService, nnArch), ttRepo, tdRepo, ttrRepo, ccdbService, fileService, nnArch + TrainingTask: ttRepo, + TrainingDataset: tdRepo, + TrainingTaskResult: ttrRepo, + }, ccdbService, jalienService, fileService, nnArch), &trainingTaskServiceTestUtils{ + TTRepo: ttRepo, + TDRepo: tdRepo, + TTRRepo: ttrRepo, + CCDBService: ccdbService, + JAliEnService: jalienService, + FileService: fileService, + NNArch: nnArch, + } } func TestTrainingTaskService_GetAll_Global(t *testing.T) { // Arrange - ttService, ttRepo, _, _, _, _, _ := newTrainingTaskService() + ttService, ut := newTrainingTaskService() userId := uint(1) tdId := uint(1) tmId := uint(1) @@ -56,15 +75,15 @@ func TestTrainingTaskService_GetAll_Global(t *testing.T) { {Name: "task1", UserId: userId, Status: models.Queued, TrainingMachineId: nil, TrainingDatasetId: tdId, Configuration: ""}, {Name: "task2", UserId: userId, Status: models.Benchmarking, TrainingMachineId: &tmId, TrainingDatasetId: tdId, Configuration: ""}, } - ttRepo.On("GetAll").Return(tts, nil) + ut.TTRepo.On("GetAll").Return(tts, nil) // Act tasks, err := ttService.GetAll(userId, false) // Assert assert.NoError(t, err) - ttRepo.AssertCalled(t, "GetAll") - ttRepo.AssertNotCalled(t, "GetAllUser", mock.Anything) + ut.TTRepo.AssertCalled(t, "GetAll") + ut.TTRepo.AssertNotCalled(t, "GetAllUser", mock.Anything) assert.Equal(t, 2, len(tasks)) assert.Equal(t, tts[0].Name, tasks[0].Name) assert.Equal(t, tts[1].Name, tasks[1].Name) @@ -72,7 +91,7 @@ func TestTrainingTaskService_GetAll_Global(t *testing.T) { func TestTrainingTaskService_GetAll_UserScoped(t *testing.T) { // Arrange - ttService, ttRepo, _, _, _, _, _ := newTrainingTaskService() + ttService, ut := newTrainingTaskService() userId := uint(1) tdId := uint(1) tmId := uint(1) @@ -80,15 +99,15 @@ func TestTrainingTaskService_GetAll_UserScoped(t *testing.T) { {Name: "task1", UserId: userId, Status: models.Queued, TrainingMachineId: nil, TrainingDatasetId: tdId, Configuration: ""}, {Name: "task2", UserId: userId, Status: models.Benchmarking, TrainingMachineId: &tmId, TrainingDatasetId: tdId, Configuration: ""}, } - ttRepo.On("GetAllUser", userId).Return(tts, nil) + ut.TTRepo.On("GetAllUser", userId).Return(tts, nil) // Act tasks, err := ttService.GetAll(userId, true) // Assert assert.NoError(t, err) - ttRepo.AssertNotCalled(t, "GetAll") - ttRepo.AssertCalled(t, "GetAllUser", userId) + ut.TTRepo.AssertNotCalled(t, "GetAll") + ut.TTRepo.AssertCalled(t, "GetAllUser", userId) assert.Equal(t, 2, len(tasks)) assert.Equal(t, tts[0].Name, tasks[0].Name) assert.Equal(t, tts[1].Name, tasks[1].Name) @@ -96,7 +115,7 @@ func TestTrainingTaskService_GetAll_UserScoped(t *testing.T) { func TestTrainingTaskService_Create(t *testing.T) { // Arrange - ttService, ttRepo, _, _, _, _, _ := newTrainingTaskService() + ttService, ut := newTrainingTaskService() userId := uint(1) tdId := uint(1) tt := models.TrainingTask{ @@ -106,42 +125,42 @@ func TestTrainingTaskService_Create(t *testing.T) { TrainingDatasetId: tdId, Configuration: "", } - ttRepo.On("Create", &tt).Return(nil) + ut.TTRepo.On("Create", &tt).Return(nil) // Act err := ttService.Create(&tt) // Assert assert.NoError(t, err) - ttRepo.AssertCalled(t, "Create", &tt) + ut.TTRepo.AssertCalled(t, "Create", &tt) assert.Equal(t, models.Queued, tt.Status) assert.Equal(t, (*uint)(nil), tt.TrainingMachineId) } func TestTrainingTaskService_GetHelpers(t *testing.T) { // Arrange - ttService, _, tdRepo, _, _, _, nnArch := newTrainingTaskService() + ttService, ut := newTrainingTaskService() userId := uint(1) tds := []models.TrainingDataset{ {Name: "LHC24b1b", UserId: userId, AODFiles: []jalien.AODFile{}}, {Name: "LHC24b1b2", UserId: userId, AODFiles: []jalien.AODFile{}}, } - tdRepo.On("GetAllUser", userId).Return(tds, nil) + ut.TDRepo.On("GetAllUser", userId).Return(tds, nil) // Act helpers, err := ttService.GetHelpers(userId) // Assert assert.NoError(t, err) - tdRepo.AssertCalled(t, "GetAllUser", userId) + ut.TDRepo.AssertCalled(t, "GetAllUser", userId) assert.Equal(t, tds[0].Name, helpers.TrainingDatasets[0].Name) assert.Equal(t, tds[1].Name, helpers.TrainingDatasets[1].Name) - assert.True(t, reflect.DeepEqual(helpers.FieldConfigs, nnArch.FieldConfigs)) + assert.True(t, reflect.DeepEqual(helpers.FieldConfigs, ut.NNArch.FieldConfigs)) } func TestTrainingTaskService_GetByID_Queued(t *testing.T) { // Arrange - ttService, ttRepo, ttrRepo, _, _, _, _ := newTrainingTaskService() + ttService, ut := newTrainingTaskService() userId := uint(1) tdId := uint(1) ttId := uint(1) @@ -152,18 +171,18 @@ func TestTrainingTaskService_GetByID_Queued(t *testing.T) { TrainingDatasetId: tdId, Configuration: "", } - ttRepo.On("GetByID", ttId).Return(&tt, nil) - ttRepo.On("GetByType", ttId, models.Onnx).Return([]models.TrainingTaskResult{}, nil) - ttRepo.On("GetByType", ttId, models.Image).Return([]models.TrainingTaskResult{}, nil) + ut.TTRepo.On("GetByID", ttId).Return(&tt, nil) + ut.TTRepo.On("GetByType", ttId, models.Onnx).Return([]models.TrainingTaskResult{}, nil) + ut.TTRepo.On("GetByType", ttId, models.Image).Return([]models.TrainingTaskResult{}, nil) // Act ttWithRes, err := ttService.GetByID(ttId) // Assert assert.NoError(t, err) - ttRepo.AssertCalled(t, "GetByID", ttId) - ttrRepo.AssertNotCalled(t, "GetByType", ttId, models.Onnx) - ttrRepo.AssertNotCalled(t, "GetByType", ttId, models.Image) + ut.TTRepo.AssertCalled(t, "GetByID", ttId) + ut.TTRRepo.AssertNotCalled(t, "GetByType", ttId, models.Onnx) + ut.TTRRepo.AssertNotCalled(t, "GetByType", ttId, models.Image) assert.Equal(t, tt.Status, ttWithRes.TrainingTask.Status) assert.Equal(t, []models.TrainingTaskResult(nil), ttWithRes.OnnxFiles) assert.Equal(t, []models.TrainingTaskResult(nil), ttWithRes.ImageFiles) @@ -171,7 +190,7 @@ func TestTrainingTaskService_GetByID_Queued(t *testing.T) { func TestTrainingTaskService_GetByID_Training(t *testing.T) { // Arrange - ttService, ttRepo, ttrRepo, _, _, _, _ := newTrainingTaskService() + ttService, ut := newTrainingTaskService() userId := uint(1) tdId := uint(1) ttId := uint(1) @@ -182,18 +201,18 @@ func TestTrainingTaskService_GetByID_Training(t *testing.T) { TrainingDatasetId: tdId, Configuration: "", } - ttRepo.On("GetByID", ttId).Return(&tt, nil) - ttRepo.On("GetByType", ttId, models.Onnx).Return([]models.TrainingTaskResult{}, nil) - ttRepo.On("GetByType", ttId, models.Image).Return([]models.TrainingTaskResult{}, nil) + ut.TTRepo.On("GetByID", ttId).Return(&tt, nil) + ut.TTRepo.On("GetByType", ttId, models.Onnx).Return([]models.TrainingTaskResult{}, nil) + ut.TTRepo.On("GetByType", ttId, models.Image).Return([]models.TrainingTaskResult{}, nil) // Act ttWithRes, err := ttService.GetByID(ttId) // Assert assert.NoError(t, err) - ttRepo.AssertCalled(t, "GetByID", ttId) - ttrRepo.AssertNotCalled(t, "GetByType", ttId, models.Onnx) - ttrRepo.AssertNotCalled(t, "GetByType", ttId, models.Image) + ut.TTRepo.AssertCalled(t, "GetByID", ttId) + ut.TTRRepo.AssertNotCalled(t, "GetByType", ttId, models.Onnx) + ut.TTRRepo.AssertNotCalled(t, "GetByType", ttId, models.Image) assert.Equal(t, tt.Status, ttWithRes.TrainingTask.Status) assert.Equal(t, []models.TrainingTaskResult(nil), ttWithRes.OnnxFiles) assert.Equal(t, []models.TrainingTaskResult(nil), ttWithRes.ImageFiles) @@ -201,7 +220,7 @@ func TestTrainingTaskService_GetByID_Training(t *testing.T) { func TestTrainingTaskService_GetByID_Benchmarking(t *testing.T) { // Arrange - ttService, ttRepo, _, ttrRepo, _, _, _ := newTrainingTaskService() + ttService, ut := newTrainingTaskService() userId := uint(1) tdId := uint(1) ttId := uint(1) @@ -218,18 +237,18 @@ func TestTrainingTaskService_GetByID_Benchmarking(t *testing.T) { Name: "file.onnx", Path: "./file.onnx", Size: 12312, }, TrainingTaskId: ttId}, } - ttRepo.On("GetByID", ttId).Return(&tt, nil) - ttrRepo.On("GetByType", ttId, models.Onnx).Return(onnxFiles, nil) - ttrRepo.On("GetByType", ttId, models.Image).Return([]models.TrainingTaskResult{}, nil) + ut.TTRepo.On("GetByID", ttId).Return(&tt, nil) + ut.TTRRepo.On("GetByType", ttId, models.Onnx).Return(onnxFiles, nil) + ut.TTRRepo.On("GetByType", ttId, models.Image).Return([]models.TrainingTaskResult{}, nil) // Act ttWithRes, err := ttService.GetByID(ttId) // Assert assert.NoError(t, err) - ttRepo.AssertCalled(t, "GetByID", ttId) - ttrRepo.AssertCalled(t, "GetByType", ttId, models.Onnx) - ttrRepo.AssertNotCalled(t, "GetByType", ttId, models.Image) + ut.TTRepo.AssertCalled(t, "GetByID", ttId) + ut.TTRRepo.AssertCalled(t, "GetByType", ttId, models.Onnx) + ut.TTRRepo.AssertNotCalled(t, "GetByType", ttId, models.Image) assert.Equal(t, tt.Status, ttWithRes.TrainingTask.Status) assert.True(t, reflect.DeepEqual(onnxFiles, ttWithRes.OnnxFiles)) assert.Equal(t, []models.TrainingTaskResult(nil), ttWithRes.ImageFiles) @@ -237,7 +256,7 @@ func TestTrainingTaskService_GetByID_Benchmarking(t *testing.T) { func TestTrainingTaskService_GetByID_Completed(t *testing.T) { // Arrange - ttService, ttRepo, _, ttrRepo, _, _, _ := newTrainingTaskService() + ttService, ut := newTrainingTaskService() userId := uint(1) tdId := uint(1) ttId := uint(1) @@ -259,18 +278,18 @@ func TestTrainingTaskService_GetByID_Completed(t *testing.T) { Name: "file.png", Path: "./file.png", Size: 12312231, }, TrainingTaskId: ttId}, } - ttRepo.On("GetByID", ttId).Return(&tt, nil) - ttrRepo.On("GetByType", ttId, models.Onnx).Return(onnxFiles, nil) - ttrRepo.On("GetByType", ttId, models.Image).Return(imageFiles, nil) + ut.TTRepo.On("GetByID", ttId).Return(&tt, nil) + ut.TTRRepo.On("GetByType", ttId, models.Onnx).Return(onnxFiles, nil) + ut.TTRRepo.On("GetByType", ttId, models.Image).Return(imageFiles, nil) // Act ttWithRes, err := ttService.GetByID(ttId) // Assert assert.NoError(t, err) - ttRepo.AssertCalled(t, "GetByID", ttId) - ttrRepo.AssertCalled(t, "GetByType", ttId, models.Onnx) - ttrRepo.AssertCalled(t, "GetByType", ttId, models.Image) + ut.TTRepo.AssertCalled(t, "GetByID", ttId) + ut.TTRRepo.AssertCalled(t, "GetByType", ttId, models.Onnx) + ut.TTRRepo.AssertCalled(t, "GetByType", ttId, models.Image) assert.Equal(t, tt.Status, ttWithRes.TrainingTask.Status) assert.True(t, reflect.DeepEqual(onnxFiles, ttWithRes.OnnxFiles)) assert.True(t, reflect.DeepEqual(imageFiles, ttWithRes.ImageFiles)) @@ -278,7 +297,7 @@ func TestTrainingTaskService_GetByID_Completed(t *testing.T) { func TestTrainingTaskService_GetByID_Uploaded(t *testing.T) { // Arrange - ttService, ttRepo, _, ttrRepo, _, _, _ := newTrainingTaskService() + ttService, ut := newTrainingTaskService() userId := uint(1) tdId := uint(1) ttId := uint(1) @@ -300,18 +319,18 @@ func TestTrainingTaskService_GetByID_Uploaded(t *testing.T) { Name: "file.png", Path: "./file.png", Size: 12312231, }, TrainingTaskId: ttId}, } - ttRepo.On("GetByID", ttId).Return(&tt, nil) - ttrRepo.On("GetByType", ttId, models.Onnx).Return(onnxFiles, nil) - ttrRepo.On("GetByType", ttId, models.Image).Return(imageFiles, nil) + ut.TTRepo.On("GetByID", ttId).Return(&tt, nil) + ut.TTRRepo.On("GetByType", ttId, models.Onnx).Return(onnxFiles, nil) + ut.TTRRepo.On("GetByType", ttId, models.Image).Return(imageFiles, nil) // Act ttWithRes, err := ttService.GetByID(ttId) // Assert assert.NoError(t, err) - ttRepo.AssertCalled(t, "GetByID", ttId) - ttrRepo.AssertCalled(t, "GetByType", ttId, models.Onnx) - ttrRepo.AssertCalled(t, "GetByType", ttId, models.Image) + ut.TTRepo.AssertCalled(t, "GetByID", ttId) + ut.TTRRepo.AssertCalled(t, "GetByType", ttId, models.Onnx) + ut.TTRRepo.AssertCalled(t, "GetByType", ttId, models.Image) assert.Equal(t, tt.Status, ttWithRes.TrainingTask.Status) assert.True(t, reflect.DeepEqual(onnxFiles, ttWithRes.OnnxFiles)) assert.True(t, reflect.DeepEqual(imageFiles, ttWithRes.ImageFiles)) @@ -322,9 +341,9 @@ type mockReadCloser struct{} func (m mockReadCloser) Read(p []byte) (int, error) { return 0, nil } func (m mockReadCloser) Close() error { return nil } -func TestTrainingTaskService_UploadToCCDB_Success(t *testing.T) { +func TestTrainingTaskService_UploadToCCDB_OnePeriod(t *testing.T) { // Arrange - ttService, ttRepo, _, ttrRepo, ccdbService, fileService, _ := newTrainingTaskService() + ttService, ut := newTrainingTaskService() userId := uint(1) tdId := uint(1) ttId := uint(1) @@ -338,9 +357,9 @@ func TestTrainingTaskService_UploadToCCDB_Success(t *testing.T) { Model: gorm.Model{ID: tdId}, UserId: userId, AODFiles: []jalien.AODFile{ - {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/AOD/002/321321", RunNumber: 321321, LHCPeriod: "LHC24f3", AODNumber: 2}, - {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/AOD/002/321326", RunNumber: 321326, LHCPeriod: "LHC24f3", AODNumber: 2}, - {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/AOD/002/321338", RunNumber: 321338, LHCPeriod: "LHC24f3", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/321321/AOD/002", RunNumber: 321321, LHCPeriod: "LHC24f3", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/321326/AOD/002", RunNumber: 321326, LHCPeriod: "LHC24f3", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/321338/AOD/002", RunNumber: 321338, LHCPeriod: "LHC24f3", AODNumber: 2}, }, }, Configuration: "", @@ -350,42 +369,191 @@ func TestTrainingTaskService_UploadToCCDB_Success(t *testing.T) { Name: "local_file_temp.onnx", Path: "./local_file_temp.onnx", Size: 12312, }, TrainingTaskId: ttId}, } - ttRepo.On("GetByID", ttId).Return(&tt, nil) - ttRepo.On("Update", &tt).Return(nil) - ttrRepo.On("GetByType", ttId, models.Onnx).Return(onnxFiles, nil) + ut.TTRepo.On("GetByID", ttId).Return(&tt, nil) + ut.TTRepo.On("Update", &tt).Return(nil) + ut.TTRRepo.On("GetByType", ttId, models.Onnx).Return(onnxFiles, nil) file := &mockReadCloser{} - fileService.On("OpenFile", "./local_file_temp.onnx").Return(file, func(r io.ReadCloser) { r.Close() }, nil) + ut.FileService.On("OpenFile", "./local_file_temp.onnx").Return(file, func(r io.ReadCloser) { r.Close() }, nil) + ut.JAliEnService.On("ListAndParseDirectory", "/alice/sim/2024/LHC24f3/0").Return(&jalien.DirectoryContents{ + Subdirs: []jalien.Dir{ + {Name: "321000", Path: "/alice/sim/2024/LHC24f3/0/321000"}, + {Name: "321100", Path: "/alice/sim/2024/LHC24f3/0/321100"}, + {Name: "321321", Path: "/alice/sim/2024/LHC24f3/0/321321"}, + {Name: "321326", Path: "/alice/sim/2024/LHC24f3/0/321326"}, + {Name: "321338", Path: "/alice/sim/2024/LHC24f3/0/321338"}, + {Name: "321400", Path: "/alice/sim/2024/LHC24f3/0/321400"}, + {Name: "321500", Path: "/alice/sim/2024/LHC24f3/0/321500"}, + }, + }, nil) now := time.Now().UnixMilli() - ccdbService.On("GetRunInformation", uint64(321321)).Return(&ccdb.RunInformation{ - RunNumber: 321321, + ut.CCDBService.On("GetRunInformation", uint64(321000)).Return(&ccdb.RunInformation{ + RunNumber: 321000, SOR: uint64(now - 10000), EOR: uint64(now - 9000), }, nil) - ccdbService.On("GetRunInformation", uint64(321338)).Return(&ccdb.RunInformation{ - RunNumber: 321338, + ut.CCDBService.On("GetRunInformation", uint64(321500)).Return(&ccdb.RunInformation{ + RunNumber: 321500, SOR: uint64(now + 7000), EOR: uint64(now + 10000), }, nil) - ccdbService.On("UploadFile", uint64(now-10000), uint64(now+10000), "uploaded_file.onnx", file).Return(nil) + ut.CCDBService.On("UploadFile", uint64(now-10000), uint64(now+10000), "uploaded_file.onnx", file).Return(nil) // Act err := ttService.UploadOnnxResults(ttId) // Assert assert.NoError(t, err) - ttRepo.AssertCalled(t, "GetByID", ttId) - ttRepo.AssertCalled(t, "Update", &tt) + ut.TTRepo.AssertCalled(t, "GetByID", ttId) + ut.TTRepo.AssertCalled(t, "Update", &tt) assert.Equal(t, models.Uploaded, tt.Status) - ttrRepo.AssertCalled(t, "GetByType", ttId, models.Onnx) - fileService.AssertCalled(t, "OpenFile", "./local_file_temp.onnx") - ccdbService.AssertCalled(t, "GetRunInformation", uint64(321321)) - ccdbService.AssertCalled(t, "GetRunInformation", uint64(321338)) - ccdbService.AssertCalled(t, "UploadFile", uint64(now-10000), uint64(now+10000), "uploaded_file.onnx", file) + ut.TTRRepo.AssertCalled(t, "GetByType", ttId, models.Onnx) + ut.FileService.AssertCalled(t, "OpenFile", "./local_file_temp.onnx") + ut.JAliEnService.AssertCalled(t, "ListAndParseDirectory", "/alice/sim/2024/LHC24f3/0") + ut.CCDBService.AssertCalled(t, "GetRunInformation", uint64(321000)) + ut.CCDBService.AssertCalled(t, "GetRunInformation", uint64(321500)) + ut.CCDBService.AssertCalled(t, "UploadFile", uint64(now-10000), uint64(now+10000), "uploaded_file.onnx", file) +} + +// 3 different periods in one dataset +// expected: take min SOR and max EOR from all periods' runs +func TestTrainingTaskService_UploadToCCDB_ManyPeriods(t *testing.T) { + // Arrange + ttService, ut := newTrainingTaskService() + userId := uint(1) + tdId := uint(1) + ttId := uint(1) + tt := models.TrainingTask{ + Model: gorm.Model{ID: ttId}, + Name: "task2", + UserId: userId, + Status: models.Completed, + TrainingDatasetId: tdId, + TrainingDataset: models.TrainingDataset{ + Model: gorm.Model{ID: tdId}, + UserId: userId, + AODFiles: []jalien.AODFile{ + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24b1b/0/320000/AOD/002", RunNumber: 320000, LHCPeriod: "LHC24b1b", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24b1b/0/320100/AOD/002", RunNumber: 320100, LHCPeriod: "LHC24b1b", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24b1b/0/320200/AOD/002", RunNumber: 320200, LHCPeriod: "LHC24b1b", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/321321/AOD/002", RunNumber: 321321, LHCPeriod: "LHC24f3", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/321326/AOD/002", RunNumber: 321326, LHCPeriod: "LHC24f3", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/321338/AOD/002", RunNumber: 321338, LHCPeriod: "LHC24f3", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24c1/322000/AOD/002", RunNumber: 322000, LHCPeriod: "LHC24c1", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24c1/322100/AOD/002", RunNumber: 322100, LHCPeriod: "LHC24c1", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24c1/322200/AOD/002", RunNumber: 322200, LHCPeriod: "LHC24c1", AODNumber: 2}, + }, + }, + Configuration: "", + } + onnxFiles := []models.TrainingTaskResult{ + {Name: "local_file.onnx", Type: models.Onnx, Description: "example", FileId: 1, File: models.File{ + Name: "local_file_temp.onnx", Path: "./local_file_temp.onnx", Size: 12312, + }, TrainingTaskId: ttId}, + } + + ut.TTRepo.On("GetByID", ttId).Return(&tt, nil) + ut.TTRepo.On("Update", &tt).Return(nil) + ut.TTRRepo.On("GetByType", ttId, models.Onnx).Return(onnxFiles, nil) + + file := &mockReadCloser{} + ut.FileService.On("OpenFile", "./local_file_temp.onnx").Return(file, func(r io.ReadCloser) { r.Close() }, nil) + + // Mock JAliEn and CCDB with timestamp information + now := time.Now().UnixMilli() + // LHC24b1b + ut.JAliEnService.On("ListAndParseDirectory", "/alice/sim/2024/LHC24b1b/0").Return(&jalien.DirectoryContents{ + Subdirs: []jalien.Dir{ + {Name: "319900", Path: "/alice/sim/2024/LHC24b1b/0/319900"}, // min run number LHC24b1b + {Name: "320000", Path: "/alice/sim/2024/LHC24b1b/0/320000"}, + {Name: "320100", Path: "/alice/sim/2024/LHC24b1b/0/320100"}, + {Name: "320200", Path: "/alice/sim/2024/LHC24b1b/0/320200"}, + {Name: "320300", Path: "/alice/sim/2024/LHC24b1b/0/320300"}, // max run number LHC24b1b + }, + }, nil) + ut.CCDBService.On("GetRunInformation", uint64(319900)).Return(&ccdb.RunInformation{ + RunNumber: 319900, + SOR: uint64(now - 20000), + EOR: uint64(now - 19000), + }, nil) + ut.CCDBService.On("GetRunInformation", uint64(320300)).Return(&ccdb.RunInformation{ + RunNumber: 320300, + SOR: uint64(now - 12000), + EOR: uint64(now - 11000), + }, nil) + // LHC24f3 + ut.JAliEnService.On("ListAndParseDirectory", "/alice/sim/2024/LHC24f3/0").Return(&jalien.DirectoryContents{ + Subdirs: []jalien.Dir{ + {Name: "321000", Path: "/alice/sim/2024/LHC24f3/0/321000"}, // min run number LHC24f3 + {Name: "321321", Path: "/alice/sim/2024/LHC24f3/0/321321"}, + {Name: "321326", Path: "/alice/sim/2024/LHC24f3/0/321326"}, + {Name: "321338", Path: "/alice/sim/2024/LHC24f3/0/321338"}, + {Name: "321500", Path: "/alice/sim/2024/LHC24f3/0/321500"}, // max run number LHC24f3 + }, + }, nil) + ut.CCDBService.On("GetRunInformation", uint64(321000)).Return(&ccdb.RunInformation{ + RunNumber: 321000, + SOR: uint64(now - 10000), + EOR: uint64(now - 9000), + }, nil) + ut.CCDBService.On("GetRunInformation", uint64(321500)).Return(&ccdb.RunInformation{ + RunNumber: 321500, + SOR: uint64(now - 2000), + EOR: uint64(now - 1000), + }, nil) + // LHC24c1 + // simulate 2023 datasets (without numbered subdir in period dir) + ut.JAliEnService.On("ListAndParseDirectory", "/alice/sim/2024/LHC24c1").Return(&jalien.DirectoryContents{ + Subdirs: []jalien.Dir{ + {Name: "321900", Path: "/alice/sim/2024/LHC24c1/321900"}, // min run number LHC24c1 + {Name: "322000", Path: "/alice/sim/2024/LHC24c1/322000"}, + {Name: "322100", Path: "/alice/sim/2024/LHC24c1/322100"}, + {Name: "322200", Path: "/alice/sim/2024/LHC24c1/322200"}, + {Name: "322300", Path: "/alice/sim/2024/LHC24c1/322300"}, // max run number LHC24c1 + }, + }, nil) + ut.CCDBService.On("GetRunInformation", uint64(321900)).Return(&ccdb.RunInformation{ + RunNumber: 321900, + SOR: uint64(now), + EOR: uint64(now + 1000), + }, nil) + ut.CCDBService.On("GetRunInformation", uint64(322300)).Return(&ccdb.RunInformation{ + RunNumber: 322300, + SOR: uint64(now + 9000), + EOR: uint64(now + 10000), + }, nil) + + // TODO b1b and c1 + ut.CCDBService.On("UploadFile", uint64(now-20000), uint64(now+10000), "uploaded_file.onnx", file).Return(nil) + + // Act + err := ttService.UploadOnnxResults(ttId) + + // Assert + assert.NoError(t, err) + ut.TTRepo.AssertCalled(t, "GetByID", ttId) + ut.TTRepo.AssertCalled(t, "Update", &tt) + assert.Equal(t, models.Uploaded, tt.Status) + ut.TTRRepo.AssertCalled(t, "GetByType", ttId, models.Onnx) + ut.FileService.AssertCalled(t, "OpenFile", "./local_file_temp.onnx") + // LHC24b1b info + ut.JAliEnService.AssertCalled(t, "ListAndParseDirectory", "/alice/sim/2024/LHC24b1b/0") + ut.CCDBService.AssertCalled(t, "GetRunInformation", uint64(319900)) + ut.CCDBService.AssertCalled(t, "GetRunInformation", uint64(320300)) + // LHC24f3 info + ut.JAliEnService.AssertCalled(t, "ListAndParseDirectory", "/alice/sim/2024/LHC24f3/0") + ut.CCDBService.AssertCalled(t, "GetRunInformation", uint64(321000)) + ut.CCDBService.AssertCalled(t, "GetRunInformation", uint64(321500)) + // LHC24c1 info + ut.JAliEnService.AssertCalled(t, "ListAndParseDirectory", "/alice/sim/2024/LHC24c1") + ut.CCDBService.AssertCalled(t, "GetRunInformation", uint64(321900)) + ut.CCDBService.AssertCalled(t, "GetRunInformation", uint64(322300)) + // Upload + ut.CCDBService.AssertCalled(t, "UploadFile", uint64(now-20000), uint64(now+10000), "uploaded_file.onnx", file) } func TestTrainingTaskService_UploadToCCDB_MissingExpectedFile(t *testing.T) { // Arrange - ttService, ttRepo, _, ttrRepo, ccdbService, fileService, _ := newTrainingTaskService() + ttService, ut := newTrainingTaskService() userId := uint(1) tdId := uint(1) ttId := uint(1) @@ -399,9 +567,9 @@ func TestTrainingTaskService_UploadToCCDB_MissingExpectedFile(t *testing.T) { Model: gorm.Model{ID: tdId}, UserId: userId, AODFiles: []jalien.AODFile{ - {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/AOD/002/321321", RunNumber: 321321, LHCPeriod: "LHC24f3", AODNumber: 2}, - {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/AOD/002/321326", RunNumber: 321326, LHCPeriod: "LHC24f3", AODNumber: 2}, - {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/AOD/002/321338", RunNumber: 321338, LHCPeriod: "LHC24f3", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/321321/AOD/002", RunNumber: 321321, LHCPeriod: "LHC24f3", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/321326/AOD/002", RunNumber: 321326, LHCPeriod: "LHC24f3", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/321338/AOD/002", RunNumber: 321338, LHCPeriod: "LHC24f3", AODNumber: 2}, }, }, Configuration: "", @@ -411,23 +579,34 @@ func TestTrainingTaskService_UploadToCCDB_MissingExpectedFile(t *testing.T) { Name: "local_file_temp.onnx", Path: "./local_file_temp.onnx", Size: 12312, }, TrainingTaskId: ttId}, } - ttRepo.On("GetByID", ttId).Return(&tt, nil) - ttRepo.On("Update", &tt).Return(nil) - ttrRepo.On("GetByType", ttId, models.Onnx).Return(onnxFiles, nil) + ut.TTRepo.On("GetByID", ttId).Return(&tt, nil) + ut.TTRepo.On("Update", &tt).Return(nil) + ut.TTRRepo.On("GetByType", ttId, models.Onnx).Return(onnxFiles, nil) file := &mockReadCloser{} - fileService.On("OpenFile", "./local_file_temp.onnx").Return(file, func(r io.ReadCloser) { r.Close() }, nil) + ut.FileService.On("OpenFile", "./local_file_temp.onnx").Return(file, func(r io.ReadCloser) { r.Close() }, nil) + ut.JAliEnService.On("ListAndParseDirectory", "/alice/sim/2024/LHC24f3/0").Return(&jalien.DirectoryContents{ + Subdirs: []jalien.Dir{ + {Name: "321000", Path: "/alice/sim/2024/LHC24f3/0/321000"}, + {Name: "321100", Path: "/alice/sim/2024/LHC24f3/0/321100"}, + {Name: "321321", Path: "/alice/sim/2024/LHC24f3/0/321321"}, + {Name: "321326", Path: "/alice/sim/2024/LHC24f3/0/321326"}, + {Name: "321338", Path: "/alice/sim/2024/LHC24f3/0/321338"}, + {Name: "321400", Path: "/alice/sim/2024/LHC24f3/0/321400"}, + {Name: "321500", Path: "/alice/sim/2024/LHC24f3/0/321500"}, + }, + }, nil) now := time.Now().UnixMilli() - ccdbService.On("GetRunInformation", uint64(321321)).Return(&ccdb.RunInformation{ - RunNumber: 321321, + ut.CCDBService.On("GetRunInformation", uint64(321000)).Return(&ccdb.RunInformation{ + RunNumber: 321000, SOR: uint64(now - 10000), EOR: uint64(now - 9000), }, nil) - ccdbService.On("GetRunInformation", uint64(321338)).Return(&ccdb.RunInformation{ - RunNumber: 321338, + ut.CCDBService.On("GetRunInformation", uint64(321500)).Return(&ccdb.RunInformation{ + RunNumber: 321500, SOR: uint64(now + 7000), EOR: uint64(now + 10000), }, nil) - ccdbService.On("UploadFile", uint64(now-10000), uint64(now+10000), "uploaded_file.onnx", file).Return(nil) + ut.CCDBService.On("UploadFile", uint64(now-10000), uint64(now+10000), "uploaded_file.onnx", file).Return(nil) // Act err := ttService.UploadOnnxResults(ttId) @@ -435,18 +614,19 @@ func TestTrainingTaskService_UploadToCCDB_MissingExpectedFile(t *testing.T) { // Assert assert.Error(t, err) assert.Contains(t, err.Error(), "TrainingTask's result file: local_file.onnx not found") - ttRepo.AssertCalled(t, "GetByID", ttId) - ttRepo.AssertNotCalled(t, "Update", &tt) - ttrRepo.AssertCalled(t, "GetByType", ttId, models.Onnx) - fileService.AssertNotCalled(t, "OpenFile", "./local_file_temp.onnx") - ccdbService.AssertCalled(t, "GetRunInformation", uint64(321321)) - ccdbService.AssertCalled(t, "GetRunInformation", uint64(321338)) - ccdbService.AssertNotCalled(t, "UploadFile", uint64(now-10000), uint64(now+10000), "uploaded_file.onnx", file) + ut.TTRepo.AssertCalled(t, "GetByID", ttId) + ut.TTRepo.AssertNotCalled(t, "Update", &tt) + ut.TTRRepo.AssertCalled(t, "GetByType", ttId, models.Onnx) + ut.FileService.AssertNotCalled(t, "OpenFile", "./local_file_temp.onnx") + ut.JAliEnService.AssertCalled(t, "ListAndParseDirectory", "/alice/sim/2024/LHC24f3/0") + ut.CCDBService.AssertCalled(t, "GetRunInformation", uint64(321000)) + ut.CCDBService.AssertCalled(t, "GetRunInformation", uint64(321500)) + ut.CCDBService.AssertNotCalled(t, "UploadFile", uint64(now-10000), uint64(now+10000), "uploaded_file.onnx", file) } func TestTrainingTaskService_UploadToCCDB_ErrorReadingFile(t *testing.T) { // Arrange - ttService, ttRepo, _, ttrRepo, ccdbService, fileService, _ := newTrainingTaskService() + ttService, ut := newTrainingTaskService() userId := uint(1) tdId := uint(1) ttId := uint(1) @@ -460,9 +640,9 @@ func TestTrainingTaskService_UploadToCCDB_ErrorReadingFile(t *testing.T) { Model: gorm.Model{ID: tdId}, UserId: userId, AODFiles: []jalien.AODFile{ - {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/AOD/002/321321", RunNumber: 321321, LHCPeriod: "LHC24f3", AODNumber: 2}, - {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/AOD/002/321326", RunNumber: 321326, LHCPeriod: "LHC24f3", AODNumber: 2}, - {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/AOD/002/321338", RunNumber: 321338, LHCPeriod: "LHC24f3", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/321321/AOD/002", RunNumber: 321321, LHCPeriod: "LHC24f3", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/321326/AOD/002", RunNumber: 321326, LHCPeriod: "LHC24f3", AODNumber: 2}, + {Name: "AO2D.root", Path: "/alice/sim/2024/LHC24f3/0/321338/AOD/002", RunNumber: 321338, LHCPeriod: "LHC24f3", AODNumber: 2}, }, }, Configuration: "", @@ -472,23 +652,34 @@ func TestTrainingTaskService_UploadToCCDB_ErrorReadingFile(t *testing.T) { Name: "local_file_temp.onnx", Path: "./local_file_temp.onnx", Size: 12312, }, TrainingTaskId: ttId}, } - ttRepo.On("GetByID", ttId).Return(&tt, nil) - ttRepo.On("Update", &tt).Return(nil) - ttrRepo.On("GetByType", ttId, models.Onnx).Return(onnxFiles, nil) + ut.TTRepo.On("GetByID", ttId).Return(&tt, nil) + ut.TTRepo.On("Update", &tt).Return(nil) + ut.TTRRepo.On("GetByType", ttId, models.Onnx).Return(onnxFiles, nil) file := &mockReadCloser{} - fileService.On("OpenFile", "./local_file_temp.onnx").Return(nil, nil, errors.New("error reading file")) + ut.FileService.On("OpenFile", "./local_file_temp.onnx").Return(nil, nil, errors.New("error reading file")) + ut.JAliEnService.On("ListAndParseDirectory", "/alice/sim/2024/LHC24f3/0").Return(&jalien.DirectoryContents{ + Subdirs: []jalien.Dir{ + {Name: "321000", Path: "/alice/sim/2024/LHC24f3/0/321000"}, + {Name: "321100", Path: "/alice/sim/2024/LHC24f3/0/321100"}, + {Name: "321321", Path: "/alice/sim/2024/LHC24f3/0/321321"}, + {Name: "321326", Path: "/alice/sim/2024/LHC24f3/0/321326"}, + {Name: "321338", Path: "/alice/sim/2024/LHC24f3/0/321338"}, + {Name: "321400", Path: "/alice/sim/2024/LHC24f3/0/321400"}, + {Name: "321500", Path: "/alice/sim/2024/LHC24f3/0/321500"}, + }, + }, nil) now := time.Now().UnixMilli() - ccdbService.On("GetRunInformation", uint64(321321)).Return(&ccdb.RunInformation{ - RunNumber: 321321, + ut.CCDBService.On("GetRunInformation", uint64(321000)).Return(&ccdb.RunInformation{ + RunNumber: 321000, SOR: uint64(now - 10000), EOR: uint64(now - 9000), }, nil) - ccdbService.On("GetRunInformation", uint64(321338)).Return(&ccdb.RunInformation{ - RunNumber: 321338, + ut.CCDBService.On("GetRunInformation", uint64(321500)).Return(&ccdb.RunInformation{ + RunNumber: 321500, SOR: uint64(now + 7000), EOR: uint64(now + 10000), }, nil) - ccdbService.On("UploadFile", uint64(now-10000), uint64(now+10000), "uploaded_file.onnx", file).Return(nil) + ut.CCDBService.On("UploadFile", uint64(now-10000), uint64(now+10000), "uploaded_file.onnx", file).Return(nil) // Act err := ttService.UploadOnnxResults(ttId) @@ -496,11 +687,12 @@ func TestTrainingTaskService_UploadToCCDB_ErrorReadingFile(t *testing.T) { // Assert assert.Error(t, err) assert.Contains(t, err.Error(), "error reading file") - ttRepo.AssertCalled(t, "GetByID", ttId) - ttRepo.AssertNotCalled(t, "Update", &tt) - ttrRepo.AssertCalled(t, "GetByType", ttId, models.Onnx) - fileService.AssertCalled(t, "OpenFile", "./local_file_temp.onnx") - ccdbService.AssertCalled(t, "GetRunInformation", uint64(321321)) - ccdbService.AssertCalled(t, "GetRunInformation", uint64(321338)) - ccdbService.AssertNotCalled(t, "UploadFile", uint64(now-10000), uint64(now+10000), "uploaded_file.onnx", file) + ut.TTRepo.AssertCalled(t, "GetByID", ttId) + ut.TTRepo.AssertNotCalled(t, "Update", &tt) + ut.TTRRepo.AssertCalled(t, "GetByType", ttId, models.Onnx) + ut.FileService.AssertCalled(t, "OpenFile", "./local_file_temp.onnx") + ut.JAliEnService.AssertCalled(t, "ListAndParseDirectory", "/alice/sim/2024/LHC24f3/0") + ut.CCDBService.AssertCalled(t, "GetRunInformation", uint64(321000)) + ut.CCDBService.AssertCalled(t, "GetRunInformation", uint64(321500)) + ut.CCDBService.AssertNotCalled(t, "UploadFile", uint64(now-10000), uint64(now+10000), "uploaded_file.onnx", file) }