From bc63d1ee72cc01b6f5d52ef868a6bdaed30dd13f Mon Sep 17 00:00:00 2001
From: Shahin Ramezanzadeh
Date: Wed, 4 Feb 2026 18:01:11 +0100
Subject: [PATCH] added new confusion matrix

---
NOTE(review): the line breaks of this patch had been lost and were
reconstructed. Blank context lines and the indentation of context lines
are assumed, not known - check both hunks against Data/show_plots.py (or
apply with `git apply --ignore-whitespace`) before relying on them.
The "From:" header carries no e-mail address, and the post-image blob id
on the "index" line predates the following edits to the added cell:
  * df_clean is an explicit copy (avoids pandas' SettingWithCopyWarning)
  * one shared edss_labels list replaces three literal copies
  * rows categorised '10+' are reported instead of silently vanishing
    from the matrix while still being counted in the report
  * classification_report uses the same ordered labels as the matrix

 Data/show_plots.py | 91 +++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 90 insertions(+), 1 deletion(-)

diff --git a/Data/show_plots.py b/Data/show_plots.py
index c14e5a5..e699989 100644
--- a/Data/show_plots.py
+++ b/Data/show_plots.py
@@ -151,7 +151,7 @@ plt.figure(figsize=(10, 8))
 sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
             xticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'],
             yticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'])
-plt.title('Confusion Matrix: Ground truth EDSS vs interferred EDSS (Categorized 0-10)')
+#plt.title('Confusion Matrix: Ground truth EDSS vs inferred EDSS (Categorized 0-10)')
 plt.xlabel('LLM Generated EDSS')
 plt.ylabel('Ground Truth EDSS')
 plt.tight_layout()
@@ -168,6 +168,95 @@ print(cm)
 
 
 ##
+# %% Confusion matrix
+import pandas as pd
+import matplotlib.pyplot as plt
+import numpy as np
+from sklearn.metrics import confusion_matrix, classification_report
+import seaborn as sns
+
+# Load the joined EDSS table from the TSV file
+file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
+df = pd.read_csv(file_path, sep='\t')
+
+# Replace comma with dot so GT.EDSS and result.EDSS parse as floats;
+# entries that still cannot be parsed become NaN (errors='coerce')
+for col in ('GT.EDSS', 'result.EDSS'):
+    df[col] = df[col].astype(str).str.replace(',', '.', regex=False)
+    df[col] = pd.to_numeric(df[col], errors='coerce')
+
+# Drop rows where either column is NaN. The explicit copy makes df_clean an
+# independent frame, so adding columns to it below does not trigger pandas'
+# SettingWithCopyWarning.
+df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']).copy()
+
+# Category labels in ascending order (0-1, 1-2, ..., 9-10), shared by the
+# confusion matrix, the plot ticks and the classification report
+edss_labels = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']
+
+
+def categorize_edss(value):
+    """Return the category label for one EDSS score.
+
+    Upper bounds are inclusive (1.0 -> '0-1', 1.5 -> '1-2'). NaN stays NaN
+    and anything above 10 is labelled '10+'.
+    """
+    if pd.isna(value):
+        return np.nan
+    for upper, label in enumerate(edss_labels, start=1):
+        if value <= upper:
+            return label
+    return '10+'
+
+
+# Create categorical versions
+df_clean['GT.EDSS_cat'] = df_clean['GT.EDSS'].apply(categorize_edss)
+df_clean['result.EDSS_cat'] = df_clean['result.EDSS'].apply(categorize_edss)
+
+# Remove any NaN categories
+df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat'])
+
+# '10+' is not one of edss_labels, so such rows never show up in the matrix.
+# Report and drop them so the classification report covers the same rows.
+out_of_range = (df_clean[['GT.EDSS_cat', 'result.EDSS_cat']] == '10+').any(axis=1)
+if out_of_range.any():
+    print("Warning: skipping", int(out_of_range.sum()), "rows with EDSS above 10")
+    df_clean = df_clean[~out_of_range]
+
+# Create confusion matrix
+cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'],
+                      labels=edss_labels)
+
+# Plot confusion matrix
+plt.figure(figsize=(10, 8))
+ax = sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
+                 xticklabels=edss_labels, yticklabels=edss_labels)
+
+# Label the color bar that seaborn attaches to the heatmap
+cbar = ax.collections[0].colorbar
+cbar.set_label('Number of Cases', rotation=270, labelpad=20)
+
+plt.xlabel('LLM Generated EDSS')
+plt.ylabel('Ground Truth EDSS')
+#plt.title('Confusion Matrix: Ground truth EDSS vs inferred EDSS (Categorized 0-10)')
+plt.tight_layout()
+plt.show()
+
+# Print classification report for the same, ordered set of categories;
+# zero_division=0 reports 0 (instead of warning) for categories with no cases
+print("Classification Report:")
+print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'],
+                            labels=edss_labels, zero_division=0))
+
+# Print raw counts
+print("\nConfusion Matrix (Raw Counts):")
+print(cm)
+
+
+##
+
+
+
 
 # %% Classification
 import pandas as pd