-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathconfig.py
68 lines (57 loc) · 1.99 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# coding:utf-8
# Created on 2021/04
# Author: NZY & XJM
import torch
DEBUG = False  # selects the smaller DEBUG settings instead of the full-size ones
DATASET = 'EEG_ERP'  # the name of the dataset

# Root directory shared by every run artefact below (tensorboard data,
# data pickle, model params); change it here to relocate all of them.
TMP_LOGGING_DIR = './tmp_logging/'
TENSORBOARD_LOG_DIR = TMP_LOGGING_DIR + 'tensorboard/'  # path of saving tensorboard data

SAVE_DATA_PICKLE = False  # save the data into pickle
LOAD_DATA_PICKLE = True  # load the data from pickle
# path of saving or loading pickle data
DATA_PICKLE_FILE = TMP_LOGGING_DIR + ('DATA_%s.pkl' % DATASET)

GPUS = [0, 1, 2, 3]  # GPU device ids; extend the list (e.g. with 4-7) on bigger machines
# CUDA is used only if torch reports a usable device AND at least one id is listed.
USE_CUDA = torch.cuda.is_available() and len(GPUS) > 0
RANDOM_SEED = 0  # the random seed
RELOAD_PRETRAIN_MODEL = False  # reload pretrain model
# path of saving model params
PARAMS_PATH = TMP_LOGGING_DIR
# path of pretrained model
# NOTE(review): the ',' in this filename looks like a typo (maybe a lost format
# placeholder) — confirm before changing, existing checkpoints may rely on it.
MODEL_SAVE_PATH = TMP_LOGGING_DIR + 'model_,.pt'
IS_TRAIN = True
EEG_SAMPLE_RATE = 1000  # EEG samples per second
# Length of the ERP extraction window in milliseconds (used for LOCAL_RNN_STEP).
ERP_WINDOW_MS = 600

# Mode-dependent settings: the DEBUG branch uses fewer epochs, smaller batches
# and a smaller model; the other branch holds the full-size settings.
if DEBUG:
    DATA_FOLDER = "./../EEG_ERP/"
    MAX_EPOCH_NUM = 20
    BATCH_SIZE = 4
    BATCH_SIZE_EVAL = 2
    BATCH_SIZE_TEST = 2
    WIN_FOR_EEG_SIGNAL = 16
    EEG_ENCODER_DIM = 32
    EEG_FEATURE_DIM = 16
    RNN_HIDDEN_DIM = 32
    RNN_LAYER_NUM = 1
    GLOBAL_RNN_STEP = 12  # 12 sequence classification tasks
    EARLY_STOP_EPOCH_GAP = 10
    RELOAD_EVAL_EPOCH = 20
else:
    DATA_FOLDER = "/home/Brain_Machine_Interface/EEG_ERP/"  # the path of the dataset
    MAX_EPOCH_NUM = 150
    BATCH_SIZE = 16
    BATCH_SIZE_EVAL = 16
    BATCH_SIZE_TEST = 4
    WIN_FOR_EEG_SIGNAL = 16
    EEG_ENCODER_DIM = 256
    EEG_FEATURE_DIM = 64
    RNN_HIDDEN_DIM = 128
    RNN_LAYER_NUM = 3
    GLOBAL_RNN_STEP = 12  # 12 sequence classification tasks
    # the epoch step of reload pretrain model
    # NOTE(review): the name says "early stop gap" but this comment says
    # "reload step" — confirm which meaning callers rely on.
    EARLY_STOP_EPOCH_GAP = 100
    RELOAD_EVAL_EPOCH = 821

# Number of local RNN steps covering one ERP window: the window length in
# samples (ERP_WINDOW_MS * EEG_SAMPLE_RATE / 1000) floor-divided by half of
# WIN_FOR_EEG_SIGNAL (presumably the hop of half-overlapping windows — confirm).
# Integer arithmetic avoids float rounding of 600/1000; int() keeps the result
# an int even if EEG_SAMPLE_RATE is ever configured as a float.
LOCAL_RNN_STEP = int(ERP_WINDOW_MS * EEG_SAMPLE_RATE // 1000 // (WIN_FOR_EEG_SIGNAL // 2))
# Optimiser schedule.
INIT_LEARNING_RATE = 1e-3  # initial learning rate
LR_PATIENT = 2  # patience before the learning rate is decayed
ES_PATIENT = 20  # patience before training is stopped early

# Bookkeeping cadence.
EVAL_EPOCH = 1  # number of epochs between evaluations
SAVE_EPOCH = 1  # number of epochs between model saves