# test_latent_space_reconstruction_joint_model.py
# (repository web-view header residue; last committed by mjboos)
import numpy as np
import matplotlib.pyplot as plt
from nilearn import image as img
import pandas as pd
import joblib
# (web-view blame residue: committed by mjboos)
import seaborn as sns
# (web-view blame residue: committed by mjboos)
import dill
from sklearn.linear_model import LinearRegression
from copy import deepcopy
# (web-view blame residue: committed by mjboos)
from coef_helper_functions import remove_BF_from_coefs, get_cluster_coefs_from_estimator, make_df_for_lineplot
from auditory_feature_helpers import *

if __name__ == '__main__':
    # Script purpose: score every feature against the estimator's predictions
    # once with the full estimator ("joint") and once per cluster with that
    # cluster's coefficients removed, then plot the per-feature score
    # differences (joint minus cluster-removed) as one strip plot per feature.

    # 'BSC' is taken out of the feature dict and used as the estimator input;
    # the remaining entries are the features that get scored.
    feature_dict = get_feature_dict()
    bsc = feature_dict.pop('BSC')
    estimator = get_average_estimator()
    joint_pcs = estimator.predict(bsc)
    ratings_dict = joblib.load('ratings_dict.pkl')
    joint_scores = get_feature_scores(
        feature_dict, joint_pcs, ratings_dict['ratings_idx'],
        estimator=LinearRegression())

    # 'index' is popped from the cluster info; its unique values are the
    # cluster labels iterated below.
    cluster_dict = get_cluster_infos()
    cluster_idx = cluster_dict.pop('index')

    # Scores obtained when the coefficients selected by `cluster_idx == i`
    # are removed (presumably remove_BF_from_coefs returns a modified
    # estimator -- confirm against coef_helper_functions).
    # NOTE(review): these predictions are truncated to the first three
    # columns whereas `joint_pcs` above is not; confirm that the resulting
    # scores are meant to be compared directly.
    scores_dict = {}
    for i in np.unique(cluster_idx):
        pc_predictions_wo_cluster = remove_BF_from_coefs(
            estimator, cluster_idx == i).predict(bsc)[:, :3]
        scores_dict[i] = get_feature_scores(
            feature_dict, pc_predictions_wo_cluster,
            ratings_dict['ratings_idx'], estimator=LinearRegression())

    #individual_scores = [get_feature_scores(feature_dict, individual_pcs, ratings_dict['ratings_idx'], estimator=LinearRegression(), return_estimator=True) for individual_pcs in pcs]

    # `dict.items()` replaces the Python-2-only `dict.iteritems()` so the
    # script runs on both Python 2 and Python 3.
    # Per-feature mean of the joint scores (not used by the plot below).
    joint_scores_mean = {feature: feature_arr.mean()
                         for feature, feature_arr in joint_scores.items()}

    # cluster -> feature -> (joint score - score without that cluster).
    cluster_joint_diff = {
        cluster: {feature: joint_scores[feature] - scores_ft
                  for feature, scores_ft in cluster_scores.items()}
        for cluster, cluster_scores in scores_dict.items()}

    # Flatten to (feature, "cluster N") keys; cluster labels are shown
    # 1-based in the figure.
    reshaped_dict = {
        (feature, "cluster {}".format(cluster+1)): cluster_scores[feature]
        for feature in joint_scores
        for cluster, cluster_scores in cluster_joint_diff.items()}

    # Tuple keys become two column levels, so melting yields one column per
    # level plus the value column; they are renamed for the plot labels.
    feature_cluster_df = pd.melt(pd.DataFrame(reshaped_dict))
    feature_cluster_df.columns = ['Feature', 'Cluster', 'Difference in explained variance']
    g = sns.catplot(data=feature_cluster_df, col='Feature', kind='strip',
                    x='Difference in explained variance', y='Cluster',
                    col_wrap=3)
    g.savefig('Differences_explained_variance_per_cluster_compressed_new.svg')
    #fig, axes = plt.subplots(4,3,figsize=(15,20), constrained_layout=True)
    #flat_axes = axes.flatten()