added new confusion matrix
This commit is contained in:
@@ -151,7 +151,7 @@ plt.figure(figsize=(10, 8))
|
||||
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
|
||||
xticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'],
|
||||
yticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'])
|
||||
plt.title('Confusion Matrix: Ground truth EDSS vs interferred EDSS (Categorized 0-10)')
|
||||
#plt.title('Confusion Matrix: Ground truth EDSS vs inferred EDSS (Categorized 0-10)')
|
||||
plt.xlabel('LLM Generated EDSS')
|
||||
plt.ylabel('Ground Truth EDSS')
|
||||
plt.tight_layout()
|
||||
@@ -168,6 +168,98 @@ print(cm)
|
||||
##
|
||||
|
||||
|
||||
# %% Confusion matrix
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from sklearn.metrics import confusion_matrix, classification_report
|
||||
import seaborn as sns
|
||||
|
||||
# Load the joined EDSS table (tab-separated) from disk.
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# The two EDSS columns may contain a decimal comma. Swap it for a dot (plain
# text replacement, not a regex) and parse the result as a number; anything
# that still cannot be parsed becomes NaN because of errors='coerce'.
for column in ('GT.EDSS', 'result.EDSS'):
    as_text = df[column].astype(str).str.replace(',', '.', regex=False)
    df[column] = pd.to_numeric(as_text, errors='coerce')

# Keep only rows where both scores are present. The explicit copy makes
# df_clean an independent frame, so adding columns to it further down does
# not trigger pandas' SettingWithCopyWarning.
df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']).copy()
|
||||
|
||||
# For confusion matrix, we need to categorize the values
|
||||
# Let's create categories up to 10 (0-1, 1-2, 2-3, ..., 9-10)
|
||||
def categorize_edss(value):
    """Map a numeric EDSS score onto a one-point-wide category label.

    Every bucket includes its upper bound: anything up to and including 1.0
    is labelled '0-1' (this also covers values below zero), values above 1.0
    up to 2.0 are '1-2', and so on through '9-10'. Scores above 10.0 are
    labelled '10+'. Missing input, as judged by ``pd.isna``, is returned as
    ``np.nan``.
    """
    if pd.isna(value):
        return np.nan
    # Walk the upper bounds 1..10 in ascending order; the first bound that
    # the value does not exceed identifies its bucket.
    for upper in range(1, 11):
        if value <= upper:
            return f'{upper - 1}-{upper}'
    return '10+'
|
||||
|
||||
# Derive a bucket label column for each numeric EDSS column
for cat_col, score_col in (('GT.EDSS_cat', 'GT.EDSS'), ('result.EDSS_cat', 'result.EDSS')):
    df_clean[cat_col] = df_clean[score_col].apply(categorize_edss)

# Discard rows whose bucket label is missing in either column
category_cols = ['GT.EDSS_cat', 'result.EDSS_cat']
df_clean = df_clean.dropna(subset=category_cols)

# Count ground-truth buckets against generated buckets. Only the ten labels
# listed here are tallied, in this order.
edss_buckets = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']
cm = confusion_matrix(
    df_clean['GT.EDSS_cat'],
    df_clean['result.EDSS_cat'],
    labels=edss_buckets,
)
|
||||
|
||||
# Draw the confusion matrix as an annotated heatmap
plt.figure(figsize=(10, 8))
tick_names = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']
ax = sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=tick_names,
    yticklabels=tick_names,
)

# Fetch the colour bar via the first collection on the axes and give it a
# rotated label
cbar = ax.collections[0].colorbar
cbar.set_label('Number of Cases', rotation=270, labelpad=20)

# Axis captions
plt.ylabel('Ground Truth EDSS')
plt.xlabel('LLM Generated EDSS')
# The figure title is currently disabled:
#plt.title('Confusion Matrix: Ground truth EDSS vs inferred EDSS (Categorized 0-10)')
plt.tight_layout()
plt.show()
|
||||
|
||||
# Text summary of the per-category classification metrics
print("Classification Report:")
report_text = classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'])
print(report_text)

# Dump the unformatted confusion-matrix counts
print("\nConfusion Matrix (Raw Counts):")
print(cm)
|
||||
|
||||
|
||||
##
|
||||
|
||||
|
||||
|
||||
|
||||
# %% Classification
|
||||
import pandas as pd
|
||||
|
||||
Reference in New Issue
Block a user