1818import multiprocessing
1919import os
2020from abc import abstractmethod
21- from typing import Dict
21+ from typing import Dict , Tuple
2222
2323import optuna
2424from torch import nn
2525from torch .utils .data import Dataset
26+ from multiprocessing import Pipe
2627
2728from iotdb .mlnode .log import logger
2829from iotdb .mlnode .process .trial import ForecastingTrainingTrial
2930from iotdb .mlnode .algorithm .factory import create_forecast_model
30- from iotdb .mlnode .client import client_manager
31+ from iotdb .mlnode .client import client_manager , ConfigNodeClient
3132from iotdb .mlnode .config import descriptor
3233from iotdb .thrift .common .ttypes import TrainingState
3334
3435
class ForestingTrainingObjective:
    """
    A callable objective for Optuna: it accepts a trial as argument
    and returns the optimization objective (the loss of that trial).
    Optuna will try to minimize the objective.
    """

    def __init__(
            self,
            trial_configs: Dict,
            model_configs: Dict,
            dataset: Dataset,
    ):
        """
        Args:
            trial_configs: base trial configurations, used as a template for every trial
            model_configs: keyword arguments passed to create_forecast_model
            dataset: dataset handed to every ForecastingTrainingTrial
        """
        self.trial_configs = trial_configs
        self.model_configs = model_configs
        self.dataset = dataset

    def _build_trial_configs(self, trial: optuna.Trial) -> Dict:
        """
        Build the configuration of one trial without touching the shared template.

        Args:
            trial: the Optuna trial that suggests the hyper-parameters

        Returns:
            a new dict holding the base configurations plus 'learning_rate' and 'trial_id'
        """
        # The same objective instance is handed to study.optimize together with
        # n_jobs, so several trials may run at once; writing into the shared dict
        # would let them overwrite each other's learning rate and trial id.
        trial_configs = dict(self.trial_configs)
        # TODO: decide which parameters to tune
        trial_configs['learning_rate'] = trial.suggest_float("lr", 1e-7, 1e-1, log=True)
        trial_configs['trial_id'] = 'tid_' + str(trial._trial_id)
        return trial_configs

    def __call__(self, trial: optuna.Trial):
        """
        Train one model with the parameters suggested for `trial`.

        Args:
            trial: the Optuna trial to evaluate

        Returns:
            the value returned by ForecastingTrainingTrial.start(), used as the objective
        """
        trial_configs = self._build_trial_configs(trial)
        # TODO: check args
        model, model_configs = create_forecast_model(**self.model_configs)
        _trial = ForecastingTrainingTrial(trial_configs, model, model_configs, self.dataset)
        return _trial.start()
5966
@@ -65,71 +72,163 @@ class _BasicTask(object):
6572 """
6673
6774 def __init__ (
68- self ,
69- task_configs : Dict ,
70- model_configs : Dict ,
71- model : nn .Module ,
72- dataset : Dataset ,
73- pid_info : Dict
75+ self ,
76+ task_configs : Dict ,
77+ model_configs : Dict ,
78+ pid_info : Dict
7479 ):
7580 """
7681 Args:
7782 task_configs:
7883 model_configs:
79- model:
80- dataset:
8184 pid_info:
8285 """
8386 self .pid_info = pid_info
8487 self .task_configs = task_configs
8588 self .model_configs = model_configs
86- self .model = model
89+
90+ @abstractmethod
91+ def __call__ (self ):
92+ raise NotImplementedError
93+
94+
class _BasicTrainingTask(_BasicTask):
    """
    Base class of the training tasks: on top of _BasicTask it keeps the data
    configurations, the dataset and a ConfigNode client.
    """

    def __init__(
            self,
            task_configs: Dict,
            model_configs: Dict,
            pid_info: Dict,
            data_configs: Dict,
            dataset: Dataset,
    ):
        """
        Args:
            task_configs: dict of task configurations
            model_configs: dict of model configurations
            pid_info: a map shared between processes, can be used to find the pid with model_id and trial_id
            data_configs: dict of data configurations
            dataset: training dataset
        """
        super(_BasicTrainingTask, self).__init__(task_configs, model_configs, pid_info)
        self.data_configs, self.dataset = data_configs, dataset
        # client borrowed from the shared client manager, kept for later state updates
        self.confignode_client = client_manager.borrow_config_node_client()

    @abstractmethod
    def __call__(self):
        """
        Run the training task; the base version only raises NotImplementedError.
        """
        raise NotImplementedError
120+
121+
class _BasicInferenceTask(_BasicTask):
    """
    Base class of the inference tasks: on top of _BasicTask it keeps the input
    data and the model created from the model configurations.
    """

    # NOTE: this constructor used to be spelled `__int__`, so it was never run
    # when an instance was created and `data` / `model` were never set.
    def __init__(
            self,
            task_configs: Dict,
            model_configs: Dict,
            pid_info: Dict,
            data: Tuple,
    ):
        """
        Args:
            task_configs: dict of task configurations
            model_configs: dict of model configurations, passed as keyword arguments to create_forecast_model
            pid_info: a map shared between processes, can be used to find the pid with model_id and trial_id
            data: tuple holding the inference input (exact layout is not visible here, TODO confirm)
        """
        super().__init__(task_configs, model_configs, pid_info)
        self.data = data
        # create_forecast_model returns the model together with model configs;
        # the returned configs replace the ones stored by the base class
        self.model, self.model_configs = create_forecast_model(**self.model_configs)

    @abstractmethod
    def __call__(self):
        """
        Run the inference task; the base version only raises NotImplementedError.
        """
        raise NotImplementedError

    @abstractmethod
    def data_align(self):
        """
        Align the input data; the base version only raises NotImplementedError.
        """
        raise NotImplementedError
141+
93142
class ForecastingSingleTrainingTask(_BasicTrainingTask):
    """
    Training task that runs exactly one trial (no hyper-parameter tuning).
    """

    def __init__(
            self,
            task_configs: Dict,
            model_configs: Dict,
            pid_info: Dict,
            data_configs: Dict,
            dataset: Dataset,
            model_id: str,
    ):
        """
        Args:
            task_configs: dict of task configurations
            model_configs: dict of model configurations
            pid_info: a map shared between processes, can be used to find the pid with model_id and trial_id
            data_configs: dict of data configurations
            dataset: training dataset
            model_id: id of the model, used as key in pid_info and when updating the model state
        """
        super(ForecastingSingleTrainingTask, self).__init__(
            task_configs, model_configs, pid_info, data_configs, dataset
        )
        self.model_id = model_id
        # a single training always uses the same fixed trial id
        self.default_trial_id = 'tid_0'
        self.task_configs['trial_id'] = self.default_trial_id
        forecast_model, resolved_model_configs = create_forecast_model(**model_configs)
        self.trial = ForecastingTrainingTrial(task_configs, forecast_model, resolved_model_configs, dataset)
        # record the pid of the current process under (model_id, trial_id)
        self.pid_info[self.model_id][self.default_trial_id] = os.getpid()

    def __call__(self):
        """
        Start the trial and, once it returns, report the FINISHED state
        together with the default trial id through the ConfigNode client.
        Any exception is logged and raised again.
        """
        try:
            self.trial.start()
            self.confignode_client.update_model_state(
                self.model_id, TrainingState.FINISHED, self.default_trial_id
            )
        except Exception as err:
            logger.warn(err)
            raise err
176+
177+
class ForecastingTuningTrainingTask(_BasicTrainingTask):
    """
    Training task that tunes hyper-parameters with an Optuna study.
    """

    def __init__(
            self,
            task_configs: Dict,
            model_configs: Dict,
            pid_info: Dict,
            data_configs: Dict,
            dataset: Dataset,
            model_id: str,
    ):
        """
        Args:
            task_configs: dict of task configurations
            model_configs: dict of model configurations
            pid_info: a map shared between processes, can be used to find the pid with model_id and trial_id
            data_configs: dict of data configurations
            dataset: training dataset
            model_id: id of the model, used when updating the model state
        """
        super().__init__(task_configs, model_configs, pid_info, data_configs, dataset)
        self.model_id = model_id
        self.study = optuna.create_study(direction='minimize')

    def __call__(self):
        """
        Run the study and report the FINISHED state together with the id of
        the best trial through the ConfigNode client.
        Any exception is logged and raised again, the same way
        ForecastingSingleTrainingTask handles failures.
        """
        try:
            self.study.optimize(
                ForestingTrainingObjective(
                    self.task_configs,
                    self.model_configs,
                    self.dataset),
                n_trials=descriptor.get_config().get_mn_tuning_trial_num(),
                n_jobs=descriptor.get_config().get_mn_tuning_trial_concurrency())
            best_trial_id = 'tid_' + str(self.study.best_trial._trial_id)
            self.confignode_client.update_model_state(self.model_id, TrainingState.FINISHED, best_trial_id)
        except Exception as e:
            logger.warn(e)
            raise e
214+
215+
class ForecastingInferenceTask(_BasicInferenceTask):
    """
    Inference task for forecasting models. The task body is not implemented yet.
    """

    # NOTE: this constructor used to be spelled `__int__`, so it was never run
    # when an instance was created.
    def __init__(
            self,
            task_configs: Dict,
            model_configs: Dict,
            pid_info: Dict,
            data: Tuple,
            pipe: Pipe,
    ):
        """
        Args:
            task_configs: dict of task configurations
            model_configs: dict of model configurations
            pid_info: a map shared between processes, can be used to find the pid with model_id and trial_id
            data: tuple holding the inference input
            pipe: pipe end handed to the task (presumably used to send results back, TODO confirm)
        """
        super().__init__(task_configs, model_configs, pid_info, data)
        # keep the pipe, it was accepted but dropped before
        self.pipe = pipe

    def __call__(self):
        """
        Run the inference task. Not implemented yet.
        """
        pass

    def data_align(self):
        """
        Align the input data. Not implemented yet.
        """
        pass

    def generate_future_mark(self):
        """
        Not implemented yet.
        """
        pass
0 commit comments