Skip to content

Commit 4914e5b

Browse files
committed
added timepoint decoding for precomputed ec correlations
1 parent ceaa4e2 commit 4914e5b

File tree

1 file changed

+257
-3
lines changed

1 file changed

+257
-3
lines changed

timecorr/helpers.py

Lines changed: 257 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,219 @@ def timepoint_decoder(data, mu=None, nfolds=2, level=0, cfun=isfc, weights_fun=l
472472

473473
return results_pd
474474

475+
def weighted_timepoint_decoder_ec(data, nfolds=2, level=0, optimize_levels=None, cfun=isfc, weights_fun=laplace_weights,
                                  weights_params=laplace_params, combine=mean_combine, rfun=None, opt_init=None):
    """
    Cross-validated timepoint decoding for precomputed electrode/connectivity
    ("ec") correlations, with per-level weights optimized on a held-out subfold.

    :param data: a list of number-of-observations by number-of-features matrices
    :param nfolds: number of cross-validation folds (train using out-of-fold data;
                   test using in-fold data)
    :param level: integer or list of integers for levels to be evaluated (default:0)
    :param optimize_levels: iterable of levels to include when optimizing the
                            per-level mixing weights (default: all levels up to
                            the highest evaluated level)
    :param cfun: function for transforming the group data (default: isfc)
    :param weights_fun: used to compute per-timepoint weights for cfun; default: laplace_weights
    :param weights_params: parameters passed to weights_fun; default: laplace_params
    :params combine: function for combining data within each group, or a list of such functions (default: mean_combine)
    :param rfun: function for reducing output (default: None)
    :param opt_init: initial guess passed through to optimize_weights (default: None)
    :return: results dictionary with the following keys:
       'rank': mean percentile rank (across all timepoints and folds) in the
             decoding distribution of the true timepoint
       'accuracy': mean percent accuracy (across all timepoints and folds)
       'error': mean estimation error (across all timepoints and folds) between
             the decoded and actual window numbers, expressed as a percentage
             of the total number of windows
    """

    if nfolds == 1:
        sub_nfolds = 1
        nfolds = 2
        warnings.warn('When nfolds is set to one, the analysis will be circular.')
    else:
        # sub-folds are collapsed to 1; see the `sub_corrs = corrs` shortcut below
        sub_nfolds = 1

    group_assignments = get_xval_assignments(data.shape[1], nfolds)

    # normalize `level` into a contiguous np.ndarray [0..max]
    if type(level) is int:
        level = np.arange(level + 1)

    level = np.ravel(level)

    assert type(level) is np.ndarray, 'level needs be an integer, list, or np.ndarray'
    assert not np.any(level < 0), 'level cannot contain negative numbers'

    # levels must be contiguous from 0 because each level feeds the next
    if not np.all(np.arange(level.max() + 1) == level):
        level = np.arange(level.max() + 1)

    # broadcast a single combine / cfun callable across all levels
    if callable(combine):
        combine = [combine] * np.shape(level)[0]

    combine = np.ravel(combine)

    assert type(combine) is np.ndarray and type(combine[0]) is not np.str_, 'combine needs to be a function, list of ' \
                                                                            'functions, or np.ndarray of functions'
    assert len(level) == len(combine), 'combine length need to be the same as level if input is type np.ndarray or list'

    if callable(cfun):
        cfun = [cfun] * np.shape(level)[0]

    cfun = np.ravel(cfun)

    assert type(cfun) is np.ndarray and type(cfun[0]) is not np.str_, 'combine needs be a function, list of functions, ' \
                                                                      'or np.ndarray of functions'
    assert len(level) == len(cfun), 'cfun length need to be the same as level if input is type np.ndarray or list'

    if type(rfun) not in [list, np.ndarray]:
        rfun = [rfun] * np.shape(level)[0]

    # level-0 passes use no reduction
    p_rfun = [None] * np.shape(level)[0]

    assert len(level) == len(rfun), 'parameter lengths need to be the same as level if input is ' \
                                    'type np.ndarray or list'

    results_pd = pd.DataFrame()

    for i in range(0, nfolds):

        sub_corrs = []
        corrs = []

        subgroup_assignments = get_xval_assignments(len(data[0][group_assignments == i]), nfolds)

        in_data = [x for x in data[0][group_assignments == i]]
        out_data = [x for x in data[0][group_assignments != i]]

        for v in level:

            if v == 0:

                # level 0: smooth the raw timeseries (no correlation transform)
                in_smooth, out_smooth, in_raw, out_raw = folding_levels_ec(in_data, out_data, level=v, cfun=None,
                                                                           rfun=p_rfun, combine=combine,
                                                                           weights_fun=weights_fun,
                                                                           weights_params=weights_params)

                next_corrs = (1 - sd.cdist(mean_combine(in_smooth), mean_combine(out_smooth),
                                           'correlation'))
                corrs.append(next_corrs)

                for s in range(0, 1):

                    sub_in_data = [x for x in data[0][group_assignments == i][subgroup_assignments == s]]
                    sub_out_data = [x for x in data[0][group_assignments == i][subgroup_assignments != s]]

                    sub_in_smooth, sub_out_smooth, sub_in_raw, sub_out_raw = folding_levels_ec(sub_in_data, sub_out_data,
                                                                                               level=v, cfun=None,
                                                                                               rfun=p_rfun,
                                                                                               combine=combine,
                                                                                               weights_fun=weights_fun,
                                                                                               weights_params=weights_params)
                    next_subcorrs = (1 - sd.cdist(mean_combine(sub_in_smooth),
                                                  mean_combine(sub_out_smooth), 'correlation'))
                    sub_corrs.append(next_subcorrs)

            elif v == 1:

                # level 1: apply cfun to the level-0 raw output carried over
                # from the previous loop iteration
                in_smooth, out_smooth, in_raw, out_raw = folding_levels_ec(in_raw, out_raw, level=v, cfun=cfun,
                                                                           rfun=rfun, combine=combine,
                                                                           weights_fun=weights_fun,
                                                                           weights_params=weights_params)

                next_corrs = (1 - sd.cdist(mean_combine(in_smooth), mean_combine(out_smooth),
                                           'correlation'))
                corrs.append(next_corrs)

                for s in range(0, 1):

                    sub_in_smooth, sub_out_smooth, sub_in_raw, sub_out_raw = folding_levels_ec(sub_in_raw,
                                                                                               sub_out_raw,
                                                                                               level=v,
                                                                                               cfun=cfun,
                                                                                               rfun=rfun,
                                                                                               combine=combine,
                                                                                               weights_fun=weights_fun,
                                                                                               weights_params=weights_params)
                    next_subcorrs = (1 - sd.cdist(mean_combine(sub_in_smooth),
                                                  mean_combine(sub_out_smooth), 'correlation'))
                    sub_corrs.append(next_subcorrs)

            else:

                # levels >= 2: the precomputed correlations for this level are
                # read directly from data[v-1]
                in_raw = [x for x in data[v - 1][group_assignments == i]]
                out_raw = [x for x in data[v - 1][group_assignments != i]]

                in_smooth, out_smooth, in_raw, out_raw = folding_levels_ec(in_raw, out_raw, level=v, cfun=cfun,
                                                                           rfun=rfun, combine=combine,
                                                                           weights_fun=weights_fun,
                                                                           weights_params=weights_params)

                next_corrs = (1 - sd.cdist(in_smooth, out_smooth, 'correlation'))
                corrs.append(next_corrs)

                for s in range(0, 1):

                    sub_in_raw = [x for x in data[v - 1][group_assignments == i][subgroup_assignments == s]]
                    sub_out_raw = [x for x in data[v - 1][group_assignments == i][subgroup_assignments != s]]

                    sub_in_smooth, sub_out_smooth, sub_in_raw, sub_out_raw = folding_levels_ec(sub_in_raw,
                                                                                               sub_out_raw,
                                                                                               level=v,
                                                                                               cfun=cfun,
                                                                                               rfun=rfun,
                                                                                               combine=combine,
                                                                                               weights_fun=weights_fun,
                                                                                               weights_params=weights_params)
                    next_subcorrs = (1 - sd.cdist(sub_in_smooth, sub_out_smooth, 'correlation'))
                    sub_corrs.append(next_subcorrs)

        sub_corrs = np.array(sub_corrs)
        corrs = np.array(corrs)

        if sub_nfolds == 1:
            sub_corrs = corrs

        if not optimize_levels:
            optimize_levels = range(v + 1)

        opt_over = []

        for lev in optimize_levels:

            opt_over.append(lev)

            sub_out_corrs = sub_corrs[opt_over, :, :]
            out_corrs = corrs[opt_over, :, :]

            # learn per-level mixing weights on the subfold correlations, then
            # apply them to the full out-of-fold correlations
            mu = optimize_weights(sub_out_corrs, opt_init)

            w_corrs = weight_corrs(out_corrs, mu)

            next_results_pd = decoder(w_corrs)
            next_results_pd['level'] = lev
            next_results_pd['folds'] = i

            mu_pd = pd.DataFrame()

            for c in opt_over:
                mu_pd['level_' + str(c)] = [0]

            mu_pd += mu

            # NOTE: pd.concat's join_axes keyword was removed in pandas 1.0;
            # reindex to next_results_pd.index achieves the same alignment.
            next_results_pd = pd.concat([next_results_pd, mu_pd], axis=1).reindex(next_results_pd.index)

            results_pd = pd.concat([results_pd, next_results_pd])

    return results_pd
687+
475688

476689
def weighted_timepoint_decoder(data, nfolds=2, level=0, optimize_levels=None, cfun=isfc, weights_fun=laplace_weights,
477690
weights_params=laplace_params, combine=mean_combine, rfun=None, opt_init=None):
@@ -575,7 +788,10 @@ def weighted_timepoint_decoder(data, nfolds=2, level=0, optimize_levels=None, cf
575788
combine=combine, weights_fun=weights_fun,
576789
weights_params=weights_params)
577790

578-
next_corrs = (1 - sd.cdist(mean_combine([x for x in in_raw]), mean_combine([x for x in out_raw]),
791+
# next_corrs = (1 - sd.cdist(mean_combine([x for x in in_raw]), mean_combine([x for x in out_raw]),
792+
# 'correlation'))
793+
794+
next_corrs = (1 - sd.cdist(mean_combine(in_smooth), mean_combine(out_smooth),
579795
'correlation'))
580796
corrs.append(next_corrs)
581797

@@ -588,8 +804,10 @@ def weighted_timepoint_decoder(data, nfolds=2, level=0, optimize_levels=None, cf
588804
level=v, cfun=None, rfun=p_rfun,
589805
combine=combine, weights_fun=weights_fun,
590806
weights_params=weights_params)
591-
next_subcorrs = (1 - sd.cdist(mean_combine([x for x in sub_in_raw]),
592-
mean_combine([x for x in sub_out_raw]), 'correlation'))
807+
# next_subcorrs = (1 - sd.cdist(mean_combine([x for x in sub_in_raw]),
808+
# mean_combine([x for x in sub_out_raw]), 'correlation'))
809+
next_subcorrs = (1 - sd.cdist(mean_combine(sub_in_smooth),
810+
mean_combine(sub_out_smooth), 'correlation'))
593811
sub_corrs.append(next_subcorrs)
594812

595813

@@ -663,6 +881,42 @@ def weighted_timepoint_decoder(data, nfolds=2, level=0, optimize_levels=None, cf
663881
return results_pd
664882

665883

884+
def folding_levels_ec(infold_data, outfold_data, level=0, cfun=None, weights_fun=None, weights_params=None, combine=None,
                      rfun=None):
    """
    Smooth (level 0) or correlate-and-reduce (level >= 1) the in-fold and
    out-of-fold data for one level of the decoding hierarchy.

    :param infold_data: iterable of in-fold timeseries/correlation matrices
    :param outfold_data: iterable of out-of-fold timeseries/correlation matrices
    :param level: integer level being processed (default: 0)
    :param cfun: per-level list of correlation functions; only cfun[level] is
                 used, and only when level >= 1
    :param weights_fun: per-timepoint weights function passed to timecorr
    :param weights_params: parameters for weights_fun
    :param combine: per-level list of combine functions; combine[level] is used
    :param rfun: per-level list of reduce functions; rfun[level] is used
                 (default: None, meaning no reduction at any level)
    :return: (in_fold_smooth, out_fold_smooth, in_fold_raw, out_fold_raw) where
             the *_raw outputs are the inputs passed through unchanged
    """

    from .timecorr import timecorr

    if rfun is None:
        # one (empty) reduction slot per level up to and including this one;
        # `level` is an int here, so size it directly rather than via np.shape
        rfun = [None] * (int(level) + 1)

    if level == 0:
        # level 0: smooth only — no correlation transform (cfun=None)
        in_fold_smooth = np.asarray(timecorr([x for x in infold_data], cfun=None,
                                             rfun=rfun[level], combine=combine[level], weights_function=weights_fun,
                                             weights_params=weights_params))
        out_fold_smooth = np.asarray(timecorr([x for x in outfold_data], cfun=None,
                                              rfun=rfun[level], combine=combine[level], weights_function=weights_fun,
                                              weights_params=weights_params))
    else:
        in_fold_smooth = np.asarray(timecorr(list(infold_data), cfun=cfun[level], rfun=rfun[level], combine=combine[level],
                                             weights_function=weights_fun, weights_params=weights_params))
        out_fold_smooth = np.asarray(timecorr(list(outfold_data), cfun=cfun[level], rfun=rfun[level], combine=combine[level],
                                              weights_function=weights_fun, weights_params=weights_params))

    # raw data is returned untouched so callers can feed it to the next level
    in_fold_raw = infold_data
    out_fold_raw = outfold_data

    return in_fold_smooth, out_fold_smooth, in_fold_raw, out_fold_raw
917+
918+
919+
666920
def pca_decoder(data, nfolds=2, dims=10, cfun=isfc, weights_fun=laplace_weights,
667921
weights_params=laplace_params, combine=mean_combine, rfun=None):
668922
"""

0 commit comments

Comments
 (0)