|
| 1 | +from datetime import datetime, timezone |
1 | 2 | from unittest import mock |
2 | 3 |
|
3 | 4 | import pytest |
4 | | -from crowdin_api.api_resources.ai.enums import AIPromptAction, AIProviderType |
| 5 | +from crowdin_api.api_resources.ai.enums import AIPromptAction, AIProviderType, DatasetPurpose |
5 | 6 | from crowdin_api.api_resources.ai.resource import AIResource, EnterpriseAIResource |
6 | | -from crowdin_api.api_resources.ai.types import AIPromptOperation, EditAIPromptPath |
| 7 | +from crowdin_api.api_resources.ai.types import ( |
| 8 | + AIPromptOperation, |
| 9 | + EditAIPromptPath, |
| 10 | + CreateAIPromptFineTuningJobRequest, |
| 11 | + HyperParameters, |
| 12 | + TrainingOptions, GenerateAIPromptFineTuningDatasetRequest |
| 13 | +) |
7 | 14 | from crowdin_api.requester import APIRequester |
8 | 15 |
|
9 | 16 |
|
@@ -405,6 +412,176 @@ def test_create_ai_proxy_chat_completion(self, m_request, base_absolut_url): |
405 | 412 | request_data=request_data, |
406 | 413 | ) |
407 | 414 |
|
| 415 | + @pytest.mark.parametrize( |
| 416 | + "incoming_data, request_data", |
| 417 | + ( |
| 418 | + ( |
| 419 | + GenerateAIPromptFineTuningDatasetRequest( |
| 420 | + projectIds=[1], |
| 421 | + tmIds=[2, 3], |
| 422 | + purpose=DatasetPurpose.TRAINING.value, |
| 423 | + dateFrom=datetime(2019, 9, 23, 11, 26, 54, |
| 424 | + tzinfo=timezone.utc).isoformat(), |
| 425 | + dateTo=datetime(2019, 9, 23, 11, 26, 54, |
| 426 | + tzinfo=timezone.utc).isoformat(), |
| 427 | + maxFileSize=20, |
| 428 | + minExamplesCount=2, |
| 429 | + maxExamplesCount=10 |
| 430 | + ), |
| 431 | + { |
| 432 | + "projectIds": [ |
| 433 | + 1 |
| 434 | + ], |
| 435 | + "tmIds": [ |
| 436 | + 2, 3 |
| 437 | + ], |
| 438 | + "purpose": "training", |
| 439 | + "dateFrom": "2019-09-23T11:26:54+00:00", |
| 440 | + "dateTo": "2019-09-23T11:26:54+00:00", |
| 441 | + "maxFileSize": 20, |
| 442 | + "minExamplesCount": 2, |
| 443 | + "maxExamplesCount": 10 |
| 444 | + } |
| 445 | + ), |
| 446 | + ), |
| 447 | + ) |
| 448 | + @mock.patch("crowdin_api.requester.APIRequester.request") |
| 449 | + def test_generate_ai_prompt_fine_tuning_dataset(self, m_request, incoming_data, request_data, base_absolut_url): |
| 450 | + m_request.return_value = "response" |
| 451 | + |
| 452 | + user_id = 1 |
| 453 | + ai_prompt_id = 2 |
| 454 | + |
| 455 | + resource = self.get_resource(base_absolut_url) |
| 456 | + assert ( |
| 457 | + resource.generate_ai_prompt_fine_tuning_dataset(user_id, ai_prompt_id, request_data=incoming_data) |
| 458 | + == "response" |
| 459 | + ) |
| 460 | + m_request.assert_called_once_with( |
| 461 | + method="post", |
| 462 | + path=f"users/{user_id}/ai/prompts/{ai_prompt_id}/fine-tuning/datasets", |
| 463 | + request_data=request_data, |
| 464 | + ) |
| 465 | + |
| 466 | + @mock.patch("crowdin_api.requester.APIRequester.request") |
| 467 | + def test_get_ai_prompt_fine_tuning_dataset_generation_status(self, m_request, base_absolut_url): |
| 468 | + m_request.return_value = "response" |
| 469 | + |
| 470 | + user_id = 1 |
| 471 | + ai_prompt_id = 2 |
| 472 | + job_identifier = "id" |
| 473 | + |
| 474 | + resource = self.get_resource(base_absolut_url) |
| 475 | + assert ( |
| 476 | + resource.get_ai_prompt_fine_tuning_dataset_generation_status(user_id, ai_prompt_id, job_identifier) |
| 477 | + == "response" |
| 478 | + ) |
| 479 | + m_request.assert_called_once_with( |
| 480 | + method="get", |
| 481 | + path=f"users/{user_id}/ai/prompts/{ai_prompt_id}/fine-tuning/datasets/{job_identifier}", |
| 482 | + ) |
| 483 | + |
| 484 | + @pytest.mark.parametrize( |
| 485 | + "incoming_data, request_data", |
| 486 | + ( |
| 487 | + ( |
| 488 | + CreateAIPromptFineTuningJobRequest( |
| 489 | + dryRun=False, |
| 490 | + hyperparameters=HyperParameters( |
| 491 | + batchSize=1, |
| 492 | + learningRateMultiplier=2.0, |
| 493 | + nEpochs=100, |
| 494 | + ), |
| 495 | + trainingOptions=TrainingOptions( |
| 496 | + projectIds=[1], |
| 497 | + tmIds=[2], |
| 498 | + dateFrom=datetime(2019, 9, 23, 11, 26, 54, |
| 499 | + tzinfo=timezone.utc).isoformat(), |
| 500 | + dateTo=datetime(2019, 9, 23, 11, 26, 54, |
| 501 | + tzinfo=timezone.utc).isoformat(), |
| 502 | + maxFileSize=10, |
| 503 | + minExamplesCount=200, |
| 504 | + maxExamplesCount=300 |
| 505 | + ) |
| 506 | + ), |
| 507 | + { |
| 508 | + "dryRun": False, |
| 509 | + "hyperparameters": { |
| 510 | + "batchSize": 1, |
| 511 | + "learningRateMultiplier": 2.0, |
| 512 | + "nEpochs": 100, |
| 513 | + }, |
| 514 | + "trainingOptions": { |
| 515 | + "projectIds": [1], |
| 516 | + "tmIds": [2], |
| 517 | + "dateFrom": "2019-09-23T11:26:54+00:00", |
| 518 | + "dateTo": "2019-09-23T11:26:54+00:00", |
| 519 | + "maxFileSize": 10, |
| 520 | + "minExamplesCount": 200, |
| 521 | + "maxExamplesCount": 300 |
| 522 | + } |
| 523 | + } |
| 524 | + ), |
| 525 | + ), |
| 526 | + ) |
| 527 | + @mock.patch("crowdin_api.requester.APIRequester.request") |
| 528 | + def test_create_ai_prompt_fine_tuning_job(self, m_request, incoming_data, request_data, base_absolut_url): |
| 529 | + m_request.return_value = "response" |
| 530 | + |
| 531 | + user_id = 1 |
| 532 | + ai_prompt_id = 2 |
| 533 | + |
| 534 | + resource = self.get_resource(base_absolut_url) |
| 535 | + assert ( |
| 536 | + resource.create_ai_prompt_fine_tuning_job(user_id, ai_prompt_id, request_data=incoming_data) |
| 537 | + == "response" |
| 538 | + ) |
| 539 | + m_request.assert_called_once_with( |
| 540 | + method="post", |
| 541 | + path=f"users/{user_id}/ai/prompts/{ai_prompt_id}/fine-tuning/jobs", |
| 542 | + request_data=request_data, |
| 543 | + ) |
| 544 | + |
| 545 | + @mock.patch("crowdin_api.requester.APIRequester.request") |
| 546 | + def test_get_ai_prompt_fine_tuning_job_status(self, m_request, base_absolut_url): |
| 547 | + m_request.return_value = "response" |
| 548 | + |
| 549 | + user_id = 1 |
| 550 | + ai_prompt_id = 2 |
| 551 | + job_identifier = "id" |
| 552 | + |
| 553 | + resource = self.get_resource(base_absolut_url) |
| 554 | + assert ( |
| 555 | + resource.get_ai_prompt_fine_tuning_job_status(user_id, ai_prompt_id, job_identifier) |
| 556 | + == "response" |
| 557 | + ) |
| 558 | + m_request.assert_called_once_with( |
| 559 | + method="get", |
| 560 | + path=f"users/{user_id}/ai/prompts/{ai_prompt_id}/fine-tuning/jobs/{job_identifier}", |
| 561 | + ) |
| 562 | + |
| 563 | + @mock.patch("crowdin_api.requester.APIRequester.request") |
| 564 | + def test_download_ai_prompt_fine_tuning_dataset( |
| 565 | + self, |
| 566 | + m_request, |
| 567 | + base_absolut_url |
| 568 | + ): |
| 569 | + m_request.return_value = "response" |
| 570 | + |
| 571 | + user_id = 1 |
| 572 | + ai_prompt_id = 2 |
| 573 | + job_identifier = "id" |
| 574 | + |
| 575 | + resource = self.get_resource(base_absolut_url) |
| 576 | + assert ( |
| 577 | + resource.download_ai_prompt_fine_tuning_dataset(user_id, ai_prompt_id, job_identifier) |
| 578 | + == "response" |
| 579 | + ) |
| 580 | + m_request.assert_called_once_with( |
| 581 | + method="get", |
| 582 | + path=f"users/{user_id}/ai/prompts/{ai_prompt_id}/fine-tuning/datasets/{job_identifier}/download", |
| 583 | + ) |
| 584 | + |
408 | 585 |
|
409 | 586 | class TestEnterpriseAIResources: |
410 | 587 | resource_class = EnterpriseAIResource |
|
0 commit comments