Skip to content

Commit c96f236

Browse files
authored
Merge pull request #1 from WenWeiTHU/mlnode/forecast
Mlnode/forecast
2 parents 6da2799 + 111bba6 commit c96f236

File tree

4 files changed

+216
-104
lines changed

4 files changed

+216
-104
lines changed

mlnode/iotdb/mlnode/handler.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
#
18-
from iotdb.mlnode.algorithm.factory import create_forecast_model
18+
1919
from iotdb.mlnode.constant import TSStatusCode
2020
from iotdb.mlnode.data_access.factory import create_forecast_dataset
2121
from iotdb.mlnode.parser import parse_training_request
@@ -51,13 +51,8 @@ def createTrainingTask(self, req: TCreateTrainingTaskReq):
5151
model_config['input_vars'] = data_config['input_vars']
5252
model_config['output_vars'] = data_config['output_vars']
5353

54-
# create model & check model config legitimacy
55-
model, model_config = create_forecast_model(**model_config)
56-
57-
model_config['input_vars'] = data_config['input_vars']
58-
model_config['output_vars'] = data_config['output_vars']
5954
# create task & check task config legitimacy
60-
task = self.__task_manager.create_training_task(dataset, model, model_config, task_config)
55+
task = self.__task_manager.create_training_task(dataset, data_config, model_config, task_config)
6156

6257
return get_status(TSStatusCode.SUCCESS_STATUS)
6358
except Exception as e:

mlnode/iotdb/mlnode/process/manager.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,10 @@
1818

1919
import multiprocessing as mp
2020

21-
from typing import Dict
22-
23-
from torch import nn
21+
from typing import Dict, Union
2422
from torch.utils.data import Dataset
25-
2623
from iotdb.mlnode.log import logger
27-
from iotdb.mlnode.process.task import ForecastingTrainingTask
24+
from iotdb.mlnode.process.task import ForecastingSingleTrainingTask, ForecastingTuningTrainingTask
2825

2926

3027
class TaskManager(object):
@@ -43,14 +40,14 @@ def __init__(self, pool_size: int):
4340

4441
def create_training_task(self,
4542
dataset: Dataset,
46-
model: nn.Module,
43+
data_configs: Dict,
4744
model_configs: Dict,
48-
task_configs: Dict) -> ForecastingTrainingTask:
45+
task_configs: Dict):
4946
"""
5047
5148
Args:
5249
dataset: a torch dataset to be used for training
53-
model: torch.nn.Module
50+
data_configs: dict of data configurations
5451
model_configs: dict of model configurations
5552
task_configs: dict of task configurations
5653
@@ -59,16 +56,27 @@ def create_training_task(self,
5956
"""
6057
model_id = task_configs['model_id']
6158
self.__pid_info[model_id] = self.__shared_resource_manager.dict()
62-
task = ForecastingTrainingTask(
63-
task_configs,
64-
model_configs,
65-
model,
66-
dataset,
67-
self.__pid_info
68-
)
59+
if task_configs['tuning']:
60+
task = ForecastingTuningTrainingTask(
61+
task_configs,
62+
model_configs,
63+
self.__pid_info,
64+
data_configs,
65+
dataset,
66+
model_id,
67+
)
68+
else:
69+
task = ForecastingSingleTrainingTask(
70+
task_configs,
71+
model_configs,
72+
self.__pid_info,
73+
data_configs,
74+
dataset,
75+
model_id,
76+
)
6977
return task
7078

71-
def submit_training_task(self, task: ForecastingTrainingTask) -> None:
79+
def submit_training_task(self, task: Union[ForecastingTuningTrainingTask, ForecastingSingleTrainingTask]) -> None:
7280
if task is not None:
7381
self.__training_process_pool.apply_async(task, args=())
7482
logger.info(f'Task: ({task.model_id}) - Training process submitted successfully')

mlnode/iotdb/mlnode/process/task.py

Lines changed: 146 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -18,42 +18,49 @@
1818
import multiprocessing
1919
import os
2020
from abc import abstractmethod
21-
from typing import Dict
21+
from typing import Dict, Tuple
2222

2323
import optuna
2424
from torch import nn
2525
from torch.utils.data import Dataset
26+
from multiprocessing import Pipe
2627

2728
from iotdb.mlnode.log import logger
2829
from iotdb.mlnode.process.trial import ForecastingTrainingTrial
2930
from iotdb.mlnode.algorithm.factory import create_forecast_model
30-
from iotdb.mlnode.client import client_manager
31+
from iotdb.mlnode.client import client_manager, ConfigNodeClient
3132
from iotdb.mlnode.config import descriptor
3233
from iotdb.thrift.common.ttypes import TrainingState
3334

3435

35-
class TrainingTrialObjective:
36+
class ForestingTrainingObjective:
3637
"""
3738
A class which serve as a function, should accept trial as args
3839
and return the optimization objective.
3940
Optuna will try to minimize the objective.
4041
"""
4142

42-
def __init__(self, trial_configs: Dict, model_configs: Dict, dataset: Dataset, pid_info: Dict):
43+
def __init__(
44+
self,
45+
trial_configs: Dict,
46+
model_configs: Dict,
47+
dataset: Dataset,
48+
# pid_info: Dict
49+
):
4350
self.trial_configs = trial_configs
4451
self.model_configs = model_configs
4552
self.dataset = dataset
46-
self.pid_info = pid_info
53+
# self.pid_info = pid_info
4754

4855
def __call__(self, trial: optuna.Trial):
4956
# TODO: decide which parameters to tune
5057
trial_configs = self.trial_configs
5158
trial_configs['learning_rate'] = trial.suggest_float("lr", 1e-7, 1e-1, log=True)
5259
trial_configs['trial_id'] = 'tid_' + str(trial._trial_id)
5360
# TODO: check args
54-
model, model_cfg = create_forecast_model(**self.model_configs)
55-
self.pid_info[self.trial_configs['model_id']][trial._trial_id] = os.getpid()
56-
_trial = ForecastingTrainingTrial(trial_configs, model, self.model_configs, self.dataset)
61+
model, model_configs = create_forecast_model(**self.model_configs)
62+
# self.pid_info[self.trial_configs['model_id']][trial._trial_id] = os.getpid()
63+
_trial = ForecastingTrainingTrial(trial_configs, model, model_configs, self.dataset)
5764
loss = _trial.start()
5865
return loss
5966

@@ -65,71 +72,163 @@ class _BasicTask(object):
6572
"""
6673

6774
def __init__(
68-
self,
69-
task_configs: Dict,
70-
model_configs: Dict,
71-
model: nn.Module,
72-
dataset: Dataset,
73-
pid_info: Dict
75+
self,
76+
task_configs: Dict,
77+
model_configs: Dict,
78+
pid_info: Dict
7479
):
7580
"""
7681
Args:
7782
task_configs:
7883
model_configs:
79-
model:
80-
dataset:
8184
pid_info:
8285
"""
8386
self.pid_info = pid_info
8487
self.task_configs = task_configs
8588
self.model_configs = model_configs
86-
self.model = model
89+
90+
@abstractmethod
91+
def __call__(self):
92+
raise NotImplementedError
93+
94+
95+
class _BasicTrainingTask(_BasicTask):
96+
def __init__(
97+
self,
98+
task_configs: Dict,
99+
model_configs: Dict,
100+
pid_info: Dict,
101+
data_configs: Dict,
102+
dataset: Dataset,
103+
):
104+
"""
105+
Args:
106+
task_configs:
107+
model_configs:
108+
pid_info:
109+
data_configs:
110+
dataset:
111+
"""
112+
super().__init__(task_configs, model_configs, pid_info)
113+
self.data_configs = data_configs
87114
self.dataset = dataset
115+
self.confignode_client = client_manager.borrow_config_node_client()
116+
117+
@abstractmethod
118+
def __call__(self):
119+
raise NotImplementedError
120+
121+
122+
class _BasicInferenceTask(_BasicTask):
123+
def __int__(
124+
self,
125+
task_configs: Dict,
126+
model_configs: Dict,
127+
pid_info: Dict,
128+
data: Tuple,
129+
):
130+
super().__init__(task_configs, model_configs, pid_info)
131+
self.data = data
132+
self.model, self.model_configs = create_forecast_model(**self.model_configs)
88133

89134
@abstractmethod
90135
def __call__(self):
91136
raise NotImplementedError
92137

138+
@abstractmethod
139+
def data_align(self):
140+
raise NotImplementedError
141+
93142

94-
class ForecastingTrainingTask(_BasicTask):
95-
def __init__(self, task_configs: Dict, model_configs: Dict, model: nn.Module, dataset: Dataset,
96-
pid_info: Dict):
143+
class ForecastingSingleTrainingTask(_BasicTrainingTask):
144+
def __init__(
145+
self,
146+
task_configs: Dict,
147+
model_configs: Dict,
148+
pid_info: Dict,
149+
data_configs: Dict,
150+
dataset: Dataset,
151+
model_id: str,
152+
):
97153
"""
98154
Args:
99155
task_configs: dict of task configurations
100156
model_configs: dict of model configurations
101-
model: nn.Module
102-
dataset: training dataset
103157
pid_info: a map shared between processes, can be used to find the pid with model_id and trial_id
158+
data_configs: dict of data configurations
159+
dataset: training dataset
104160
"""
105-
super(ForecastingTrainingTask, self).__init__(task_configs, model_configs, model, dataset, pid_info)
106-
self.model_id = self.task_configs['model_id']
107-
self.tuning = self.task_configs['tuning']
108-
self.confignode_client = client_manager.borrow_config_node_client()
109-
110-
if self.tuning:
111-
self.study = optuna.create_study(direction='minimize')
112-
else:
113-
self.default_trial_id = 'tid_0'
114-
self.task_configs['trial_id'] = self.default_trial_id
115-
self.trial = ForecastingTrainingTrial(self.task_configs, self.model, self.model_configs, self.dataset)
116-
self.pid_info[self.model_id][self.default_trial_id] = os.getpid()
161+
super().__init__(task_configs, model_configs, pid_info, data_configs, dataset)
162+
self.model_id = model_id
163+
self.default_trial_id = 'tid_0'
164+
self.task_configs['trial_id'] = self.default_trial_id
165+
model, model_configs = create_forecast_model(**model_configs)
166+
self.trial = ForecastingTrainingTrial(task_configs, model, model_configs, dataset)
167+
self.pid_info[self.model_id][self.default_trial_id] = os.getpid()
117168

118169
def __call__(self):
119170
try:
120-
if self.tuning:
121-
self.study.optimize(TrainingTrialObjective(
122-
self.task_configs,
123-
self.model_configs,
124-
self.dataset,
125-
self.pid_info),
126-
n_trials=descriptor.get_config().get_mn_tuning_trial_num(),
127-
n_jobs=descriptor.get_config().get_mn_tuning_trial_concurrency())
128-
best_trial_id = 'tid_' + str(self.study.best_trial._trial_id)
129-
self.confignode_client.update_model_state(self.model_id, TrainingState.FINISHED, best_trial_id)
130-
else:
131-
self.trial.start()
132-
self.confignode_client.update_model_state(self.model_id, TrainingState.FINISHED, self.default_trial_id)
171+
self.trial.start()
172+
self.confignode_client.update_model_state(self.model_id, TrainingState.FINISHED, self.default_trial_id)
133173
except Exception as e:
134174
logger.warn(e)
135175
raise e
176+
177+
178+
class ForecastingTuningTrainingTask(_BasicTrainingTask):
179+
def __init__(
180+
self,
181+
task_configs: Dict,
182+
model_configs: Dict,
183+
pid_info: Dict,
184+
data_configs: Dict,
185+
dataset: Dataset,
186+
model_id: str,
187+
):
188+
"""
189+
Args:
190+
task_configs: dict of task configurations
191+
model_configs: dict of model configurations
192+
pid_info: a map shared between processes, can be used to find the pid with model_id and trial_id
193+
data_configs: dict of data configurations
194+
dataset: training dataset
195+
"""
196+
super().__init__(task_configs, model_configs, pid_info, data_configs, dataset)
197+
self.model_id = model_id
198+
self.study = optuna.create_study(direction='minimize')
199+
200+
def __call__(self):
201+
# try:
202+
self.study.optimize(ForestingTrainingObjective(
203+
self.task_configs,
204+
self.model_configs,
205+
self.dataset),
206+
n_trials=descriptor.get_config().get_mn_tuning_trial_num(),
207+
n_jobs=descriptor.get_config().get_mn_tuning_trial_concurrency())
208+
best_trial_id = 'tid_' + str(self.study.best_trial._trial_id)
209+
self.confignode_client.update_model_state(self.model_id, TrainingState.FINISHED, best_trial_id)
210+
#
211+
# except Exception as e:
212+
# logger.warn(e)
213+
# raise e
214+
215+
216+
class ForecastingInferenceTask(_BasicInferenceTask):
217+
def __int__(
218+
self,
219+
task_configs: Dict,
220+
model_configs: Dict,
221+
pid_info: Dict,
222+
data: Tuple,
223+
pipe: Pipe,
224+
):
225+
super().__init__(task_configs, model_configs, pid_info, data)
226+
227+
def __call__(self):
228+
pass
229+
230+
def data_align(self):
231+
pass
232+
233+
def generate_future_mark(self):
234+
pass

0 commit comments

Comments
 (0)