analysis.py 1.93 KB
Newer Older
mjboos's avatar
mjboos committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
# coding: utf-8
import joblib
import numpy as np
import matplotlib.pyplot as plt
import yaml
import seaborn as sns
import glob
import os
import pandas as pd

def get_yaml_content(fn):
    """Load one rating file and group its ratings by the entries of ``clf``.

    Parameters
    ----------
    fn : str
        Path to a YAML file holding a mapping. Each key is a string whose
        part before the first ``'.'`` parses as an integer sample index;
        each value is a sequence whose first element is the rating.

    Returns
    -------
    dict
        For every key of the module-level ``clf`` dict, a NumPy array with
        the ratings of the sample indices stored under that key, in the
        same order as those indices.

    Raises
    ------
    KeyError
        If ``clf`` refers to a sample index that is absent from the file.
    """
    with open(fn, 'r') as f:
        # NOTE: this used to be ``yaml.load(f)``. Without an explicit Loader
        # that call can construct arbitrary Python objects and is rejected
        # outright by PyYAML >= 6. ``safe_load`` parses plain mappings,
        # sequences and scalars only.
        # NOTE(review): if a rating file relies on Python-specific YAML
        # tags, an explicit Loader has to be chosen instead.
        raw = yaml.safe_load(f)
    # ``.items()`` instead of ``.iteritems()`` so this runs on Python 2 and 3.
    smpl_dict = {int(key.split('.')[0]): val[0] for key, val in raw.items()}
    # ``clf`` is a module-level dict that must exist before the first call.
    rating_dict = {key: np.array([smpl_dict[idx] for idx in indices])
                   for key, indices in clf.items()}
    return rating_dict

def ratings_to_df(ratings):
    """Flatten a per-condition rating dict into a long-format DataFrame.

    Parameters
    ----------
    ratings : dict
        Maps condition names to NumPy arrays of ratings. The second
        ``'_'``-separated token of each name is used as the condition
        label.

    Returns
    -------
    pandas.DataFrame
        One row per rating with the columns ``'noise level'`` (the rating)
        and ``'signal to noise ratio'`` (the condition label). Rows are
        grouped by condition name in sorted order.
    """
    conditions = sorted(ratings.keys())

    # All ratings back to back, one condition after the other.
    levels = np.concatenate([ratings[name] for name in conditions])

    # One label per rating, repeated to line up with ``levels``.
    label_runs = []
    for name in conditions:
        label = name.split('_')[1]
        n_ratings = ratings[name].shape[0]
        label_runs.append(np.repeat([label], n_ratings))
    labels = np.concatenate(label_runs)

    columns = {'noise level': levels, 'signal to noise ratio': labels}
    return pd.DataFrame.from_dict(columns)

# Load the stored predictions: a dict from label to an array of values.
# NOTE: joblib.load unpickles the file, so only load pickles you trust.
clf = joblib.load('classification_FG_ridge_logBSC_H200_predictions.pkl')
# Per label, keep the last 50 entries of the ascending argsort, i.e. the
# indices of the 50 largest values (assumes each array is 1-D -- TODO confirm).
# ``.items()`` instead of ``.iteritems()`` so this runs on Python 2 and 3.
clf = {lbl: np.argsort(data)[-50:] for lbl, data in clf.items()}

# Rating files are named with two digits ('01.yml', '02.yml', ...); sorting
# makes the processing order independent of the file system.
files = sorted(glob.glob('[0-9][0-9].yml'))

# File name without extension -> long-format DataFrame of that file's ratings.
vp_dict = {fn.split('.')[0]: ratings_to_df(get_yaml_content(fn)) for fn in files}

# One box plot per entry of vp_dict, saved as 'vp_<id>.svg'. Existing files
# are left alone, so re-running only draws the missing plots.
snr_order = ['{}db'.format(n) for n in [15, 10, 5, 0]]  # x-axis order, loop-invariant
# ``.items()`` instead of ``.iteritems()`` so this runs on Python 2 and 3.
for vp, ratings in vp_dict.items():
    out_fn = 'vp_{}.svg'.format(vp)
    if os.path.exists(out_fn):
        continue
    # Alternative visualisations, kept for reference:
#    plt.boxplot([ratings['speech_{}db_snr'.format(n)] for n in [0,5,10,15]], labels=['{} db'.format(n) for n in [0,5,10,15]], showmeans=True)
#    sns.swarmplot(data=ratings, x='signal to noise ratio', y='noise level', order=['{}db'.format(n) for n in [15, 10, 5, 0]])
    sns.boxplot(data=ratings, x='signal to noise ratio', y='noise level', order=snr_order)
    plt.savefig(out_fn)
    # Close the figure so the next plot starts from a clean canvas.
    plt.close()

# Aggregate results: stack every DataFrame in vp_dict and draw a single box
# plot over the combined data.

per_vp_frames = [vp_dict[vp] for vp in vp_dict]
all_vp_dict = pd.concat(per_vp_frames)
#sns.swarmplot(data=all_vp_dict, x='signal to noise ratio', y='noise level', order=['{}db'.format(n) for n in [15, 10, 5, 0]])
pooled_order = ['{}db'.format(n) for n in [15, 10, 5, 0]]
sns.boxplot(
    data=all_vp_dict,
    x='signal to noise ratio',
    y='noise level',
    order=pooled_order,
)
plt.savefig('all_participants.svg')
plt.close()