Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 25 additions & 23 deletions train_12ECG_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,18 @@ def train_12ECG_classifier(input_directory, output_directory):
header_files.append(g)

classes = get_classes(input_directory, header_files)
num_classes = len(classes)
num_files = len(header_files)
recordings = list()
headers = list()

for i in range(num_files):
recording, header = load_challenge_data(header_files[i])
recordings.append(recording)
headers.append(header)

# Train model.
print('Training model...')

features = list()
labels = list()

num_files = len(header_files)
for i in range(num_files):
recording = recordings[i]
header = headers[i]

tmp = get_12ECG_features(recording, header)
features.append(tmp)

for l in header:
if l.startswith('#Dx:'):
labels_act = np.zeros(num_classes)
arrs = l.strip().split(' ')
for arr in arrs[1].split(','):
class_index = classes.index(arr.rstrip()) # Only use first positive index
labels_act[class_index] = 1
labels.append(labels_act)
feats, labs = process_single_recording(header_files[i], classes)
features.append(feats)
labels.append(labs)

features = np.array(features)
labels = np.array(labels)
Expand All @@ -67,6 +48,7 @@ def train_12ECG_classifier(input_directory, output_directory):
filename = os.path.join(output_directory, 'finalized_model.sav')
joblib.dump(final_model, filename, protocol=0)


# Load challenge data.
def load_challenge_data(header_file):
with open(header_file, 'r') as f:
Expand All @@ -76,6 +58,7 @@ def load_challenge_data(header_file):
recording = np.asarray(x['val'], dtype=np.float64)
return recording, header


# Find unique classes.
def get_classes(input_directory, filenames):
classes = set()
Expand All @@ -87,3 +70,22 @@ def get_classes(input_directory, filenames):
for c in tmp:
classes.add(c.strip())
return sorted(classes)


# Process a single sample.
def process_single_recording(header_file, classes):

recording, header = load_challenge_data(header_file)

feats = get_12ECG_features(recording, header)
num_classes = len(classes)

for l in header:
if l.startswith('#Dx:'):
labels_act = np.zeros(num_classes)
arrs = l.strip().split(' ')
for arr in arrs[1].split(','):
class_index = classes.index(arr.rstrip()) # Only use first positive index
labels_act[class_index] = 1

return feats, labels_act