|
5 | 5 | from scaleapi.exceptions import ScaleInvalidRequest
|
6 | 6 | from scaleapi.files import File
|
7 | 7 | from scaleapi.projects import Project
|
| 8 | +from scaleapi.training_tasks import TrainingTask |
8 | 9 |
|
9 | 10 | from ._version import __version__ # noqa: F401
|
10 | 11 | from .api import Api
|
@@ -794,10 +795,12 @@ def create_evaluation_task(
|
794 | 795 | task_type: TaskType,
|
795 | 796 | **kwargs,
|
796 | 797 | ) -> EvaluationTask:
|
797 |
| - """This method can only be used for Self-Serve projects. |
| 798 | + """This method can only be used for Rapid projects. |
798 | 799 | Supported Task Types: [
|
| 800 | + DocumentTranscription, |
| 801 | + SegmentAnnotation, |
| 802 | + VideoPlaybackAnnotation, |
799 | 803 | ImageAnnotation,
|
800 |
| - Categorization, |
801 | 804 | TextCollection,
|
802 | 805 | NamedEntityRecognition
|
803 | 806 | ]
|
@@ -827,3 +830,38 @@ def create_evaluation_task(
|
827 | 830 |
|
828 | 831 | evaluation_task_data = self.api.post_request(endpoint, body=kwargs)
|
829 | 832 | return EvaluationTask(evaluation_task_data, self)
|
| 833 | + |
| 834 | + def create_training_task( |
| 835 | + self, |
| 836 | + task_type: TaskType, |
| 837 | + **kwargs, |
| 838 | + ) -> TrainingTask: |
| 839 | + """This method can only be used for Rapid projects. |
| 840 | + Supported Task Types: [ |
| 841 | + DocumentTranscription, |
| 842 | + SegmentAnnotation, |
| 843 | + VideoPlaybackAnnotation, |
| 844 | + ImageAnnotation, |
| 845 | + TextCollection, |
| 846 | + NamedEntityRecognition |
| 847 | + ] |
| 848 | + Parameters may differ based on the given task_type. |
| 849 | +
|
| 850 | + Args: |
| 851 | + task_type (TaskType): |
| 852 | + Task type to be created |
| 853 | + e.g.. `TaskType.ImageAnnotation` |
| 854 | + **kwargs: |
| 855 | + The same set of parameters are expected with |
| 856 | + create_task function and an additional expected_response. |
| 857 | + Scale's API documentation. |
| 858 | + https://docs.scale.com/reference |
| 859 | +
|
| 860 | + Returns: |
| 861 | + TrainingTask: |
| 862 | + Returns created training task. |
| 863 | + """ |
| 864 | + endpoint = f"training_tasks/{task_type.value}" |
| 865 | + |
| 866 | + training_task_data = self.api.post_request(endpoint, body=kwargs) |
| 867 | + return TrainingTask(training_task_data, self) |
0 commit comments