diff --git a/train_12ECG_classifier.py b/train_12ECG_classifier.py index 7e67f6e..03bfa81 100644 --- a/train_12ECG_classifier.py +++ b/train_12ECG_classifier.py @@ -17,15 +17,6 @@ def train_12ECG_classifier(input_directory, output_directory): header_files.append(g) classes = get_classes(input_directory, header_files) - num_classes = len(classes) - num_files = len(header_files) - recordings = list() - headers = list() - - for i in range(num_files): - recording, header = load_challenge_data(header_files[i]) - recordings.append(recording) - headers.append(header) # Train model. print('Training model...') @@ -33,21 +24,11 @@ def train_12ECG_classifier(input_directory, output_directory): features = list() labels = list() + num_files = len(header_files) for i in range(num_files): - recording = recordings[i] - header = headers[i] - - tmp = get_12ECG_features(recording, header) - features.append(tmp) - - for l in header: - if l.startswith('#Dx:'): - labels_act = np.zeros(num_classes) - arrs = l.strip().split(' ') - for arr in arrs[1].split(','): - class_index = classes.index(arr.rstrip()) # Only use first positive index - labels_act[class_index] = 1 - labels.append(labels_act) + feats, labs = process_single_recording(header_files[i], classes) + features.append(feats) + labels.append(labs) features = np.array(features) labels = np.array(labels) @@ -67,6 +48,7 @@ def train_12ECG_classifier(input_directory, output_directory): filename = os.path.join(output_directory, 'finalized_model.sav') joblib.dump(final_model, filename, protocol=0) + # Load challenge data. def load_challenge_data(header_file): with open(header_file, 'r') as f: @@ -76,6 +58,7 @@ def load_challenge_data(header_file): recording = np.asarray(x['val'], dtype=np.float64) return recording, header + # Find unique classes. def get_classes(input_directory, filenames): classes = set() @@ -87,3 +70,22 @@ def get_classes(input_directory, filenames): for c in tmp: classes.add(c.strip()) return sorted(classes) + + +# Process a single sample. +def process_single_recording(header_file, classes): + + recording, header = load_challenge_data(header_file) + + feats = get_12ECG_features(recording, header) + num_classes = len(classes) + + for l in header: + if l.startswith('#Dx:'): + labels_act = np.zeros(num_classes) + arrs = l.strip().split(' ') + for arr in arrs[1].split(','): + class_index = classes.index(arr.rstrip()) # Only use first positive index + labels_act[class_index] = 1 + + return feats, labels_act