1404 lines
46 KiB
Python
1404 lines
46 KiB
Python
# %% Scatter
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
|
||
# Load the LLM-vs-ground-truth EDSS results (tab-separated).
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_results+MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.tsv'
df = pd.read_csv(file_path, sep='\t')

# Scores may use a decimal comma ("3,5"); normalise to a dot before parsing.
df['GT_EDSS'] = df['GT_EDSS'].astype(str).str.replace(',', '.')
df['LLM_Results'] = df['LLM_Results'].astype(str).str.replace(',', '.')

# Parse as float; unparseable entries become NaN.
df['GT_EDSS'] = pd.to_numeric(df['GT_EDSS'], errors='coerce')
df['LLM_Results'] = pd.to_numeric(df['LLM_Results'], errors='coerce')

# Keep only rows where both scores are present.
df_clean = df.dropna(subset=['GT_EDSS', 'LLM_Results'])

# Scatter of ground truth (x) against LLM output (y).
plt.figure(figsize=(8, 6))
plt.scatter(df_clean['GT_EDSS'], df_clean['LLM_Results'], alpha=0.7, color='blue')

plt.xlabel('GT_EDSS')
plt.ylabel('LLM_Results')
plt.title('Comparison of GT_EDSS vs LLM_Results')

# Reference diagonal y == x ("perfect prediction"). It spans the range of BOTH
# axes: previously it stopped at max(GT_EDSS), so LLM values above the largest
# ground-truth value lay beyond the end of the line. The guard avoids calling
# max()/min() on an empty frame.
if not df_clean.empty:
    upper = max(df_clean['GT_EDSS'].max(), df_clean['LLM_Results'].max())
    lower = min(0.0, df_clean['GT_EDSS'].min(), df_clean['LLM_Results'].min())
    plt.plot([lower, upper], [lower, upper], color='red', linestyle='--', label='Perfect Prediction')
    plt.legend()
else:
    print("No rows with both GT_EDSS and LLM_Results; skipping reference line.")

plt.grid(True)
plt.tight_layout()
plt.show()
|
||
|
||
##
|
||
|
||
|
||
# %% Bland-Altman
|
||
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import statsmodels.api as sm
|
||
|
||
# Read the joined results table (tab-separated).
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_results+MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.tsv'
df = pd.read_csv(file_path, sep='\t')

# For each score column: turn decimal commas into dots, then parse as float
# (anything unparseable becomes NaN).
for score_col in ('GT_EDSS', 'LLM_Results'):
    as_text = df[score_col].astype(str).str.replace(',', '.')
    df[score_col] = pd.to_numeric(as_text, errors='coerce')

# Only rows with both scores can be compared.
df_clean = df.dropna(subset=['GT_EDSS', 'LLM_Results'])
gt_scores = df_clean['GT_EDSS']
llm_scores = df_clean['LLM_Results']

# Bland-Altman (mean-difference) plot; the difference is GT minus LLM.
f, ax = plt.subplots(1, figsize=(8, 5))
sm.graphics.mean_diff_plot(gt_scores, llm_scores, ax=ax)

ax.set_title('Bland-Altman Plot: GT_EDSS vs LLM_Results')
ax.set_xlabel('Mean of GT_EDSS and LLM_Results')
ax.set_ylabel('Difference between GT_EDSS and LLM_Results')

plt.tight_layout()
plt.show()

# Summary of the differences. np.std defaults to ddof=0 (population SD);
# the limits of agreement are mean +/- 1.96 SD.
differences = gt_scores - llm_scores
mean_diff = np.mean(differences)
std_diff = np.std(differences)
half_width = 1.96 * std_diff

print(f"Mean difference: {mean_diff:.3f}")
print(f"Standard deviation of differences: {std_diff:.3f}")
print(f"95% Limits of Agreement: [{mean_diff - half_width:.3f}, {mean_diff + half_width:.3f}]")
|
||
|
||
##
|
||
|
||
|
||
|
||
# %% Confusion matrix
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
from sklearn.metrics import confusion_matrix, classification_report
|
||
import seaborn as sns
|
||
|
||
# Load the joined ground-truth / LLM EDSS table (tab-separated).
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Scores may use a decimal comma ("3,5"); normalise to a dot before parsing.
df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.')
df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.')

# Parse as float; unparseable entries become NaN.
df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce')
df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce')

# Drop rows missing either score. `.copy()` makes df_clean an independent
# frame: category columns are added to it next, and assigning into the
# filtered result of dropna() otherwise raises pandas' SettingWithCopyWarning.
df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']).copy()
|
||
|
||
# Bucket continuous EDSS scores into unit-width bins for the confusion matrix:
# '0-1', '1-2', ..., '9-10', plus '10+' for anything above 10.
def categorize_edss(value):
    """Return the EDSS bin label for *value*, or NaN when *value* is missing.

    Bins are closed on the right: 1.0 -> '0-1', 1.5 -> '1-2', 10.0 -> '9-10'.
    Anything at or below 1.0 (negatives included) is '0-1'; anything above
    10.0 is '10+'.
    """
    if pd.isna(value):
        return np.nan
    for upper in range(1, 11):
        if value <= upper:
            return f'{upper - 1}-{upper}'
    return '10+'
|
||
|
||
# Bin both score columns into EDSS categories.
df_clean['GT.EDSS_cat'] = df_clean['GT.EDSS'].apply(categorize_edss)
df_clean['result.EDSS_cat'] = df_clean['result.EDSS'].apply(categorize_edss)

# categorize_edss yields NaN only for missing scores; drop any such rows.
df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat'])

# Bin labels shared by the matrix and both heatmap axes. '10+' is not listed,
# so samples falling in that bin are not counted in the matrix (they still
# appear in the classification report below).
EDSS_BIN_LABELS = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']

cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'],
                      labels=EDSS_BIN_LABELS)

# Heatmap: rows = ground truth, columns = LLM output.
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=EDSS_BIN_LABELS, yticklabels=EDSS_BIN_LABELS)
#plt.title('Confusion Matrix: Ground truth EDSS vs inferred EDSS (Categorized 0-10)')
plt.xlabel('LLM Generated EDSS')
plt.ylabel('Ground Truth EDSS')
plt.tight_layout()
plt.show()

# Per-bin precision/recall and the raw counts.
print("Classification Report:")
print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat']))
print("\nConfusion Matrix (Raw Counts):")
print(cm)
|
||
|
||
##
|
||
|
||
|
||
# %% Confusion matrix (with colour-bar label)
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
from sklearn.metrics import confusion_matrix, classification_report
|
||
import seaborn as sns
|
||
|
||
# Load the joined ground-truth / LLM EDSS table (tab-separated).
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Scores may use a decimal comma ("3,5"); normalise to a dot before parsing.
df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.')
df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.')

# Parse as float; unparseable entries become NaN.
df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce')
df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce')

# Drop rows missing either score. `.copy()` makes df_clean an independent
# frame: category columns are added to it next, and assigning into the
# filtered result of dropna() otherwise raises pandas' SettingWithCopyWarning.
df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']).copy()
|
||
|
||
# Map an EDSS score to one of the unit bins '0-1' ... '9-10', or '10+'.
def categorize_edss(value):
    """Bin *value* into an EDSS category label; missing values give NaN.

    Each bin includes its upper bound (2.0 -> '1-2'). Values <= 1.0,
    including negatives, are '0-1'; values above 10.0 are '10+'.
    """
    if pd.isna(value):
        return np.nan
    upper = next((bound for bound in range(1, 11) if value <= bound), None)
    if upper is None:
        return '10+'
    return f'{upper - 1}-{upper}'
|
||
|
||
# Add an EDSS-bin column next to each score column.
for score_col in ('GT.EDSS', 'result.EDSS'):
    df_clean[f'{score_col}_cat'] = df_clean[score_col].apply(categorize_edss)

# categorize_edss returns NaN only for missing input; drop such rows.
df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat'])

# '0-1' ... '9-10'; the '10+' bin is deliberately not part of the matrix.
bin_labels = [f'{low}-{low + 1}' for low in range(10)]

cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'],
                      labels=bin_labels)

plt.figure(figsize=(10, 8))
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                 xticklabels=bin_labels, yticklabels=bin_labels)

# Title the heatmap's colour bar (seaborn attaches it to the first collection).
cbar = ax.collections[0].colorbar
cbar.set_label('Number of Cases', rotation=270, labelpad=20)

plt.xlabel('LLM Generated EDSS')
plt.ylabel('Ground Truth EDSS')
#plt.title('Confusion Matrix: Ground truth EDSS vs inferred EDSS (Categorized 0-10)')
plt.tight_layout()
plt.show()

# Text summaries: per-bin report and raw counts.
print("Classification Report:")
print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat']))
print("\nConfusion Matrix (Raw Counts):")
print(cm)
|
||
|
||
|
||
##
|
||
|
||
|
||
|
||
|
||
# %% Classification
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import seaborn as sns
|
||
from sklearn.metrics import confusion_matrix
|
||
import numpy as np
|
||
|
||
# Read the joined table and print a quick overview: shape, head, column names.
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

print("Data shape:", df.shape)
print("First few rows:")
print(df.head())

print("\nColumn names:")
for column_name in df.columns:
    print(f" {column_name}")
|
||
|
||
# Normalise free-form true/false columns (e.g. "klassifizierbar") to booleans.
def safe_bool_convert(series, default=False):
    """Convert *series* to a boolean Series, tolerating mixed input formats.

    Every value is stringified, stripped and lower-cased, then looked up in a
    table of recognised tokens: English (true/false, yes/no, y/n), German
    (wahr/falsch, ja/nein) and numeric (1/0 and their float forms 1.0/0.0,
    which is what a float column -- e.g. one containing NaN -- stringifies
    to). Anything unrecognised, including missing values, becomes *default*.

    Args:
        series: pandas Series holding the raw values.
        default: Result for unrecognised or missing entries. Defaults to
            False, which matches the previous behaviour.

    Returns:
        A pandas Series of dtype bool with the same index as *series*.
    """
    normalised = series.astype(str).str.strip().str.lower()

    bool_map = {
        'true': True, '1': True, '1.0': True, 'yes': True, 'y': True,
        'wahr': True, 'ja': True,
        'false': False, '0': False, '0.0': False, 'no': False, 'n': False,
        'falsch': False, 'nein': False,
    }

    # Look up each token with an explicit fallback rather than map()+fillna():
    # fillna(False) on the object result left dtype=object, and recent pandas
    # versions warn about that implicit downcast.
    return normalised.map(lambda token: bool_map.get(token, default)).astype(bool)
|
||
|
||
# Normalise each klassifizierbar column that exists (LLM result first, then
# ground truth) to booleans; print raw values before and counts after.
for flag_col in ('result.klassifizierbar', 'GT.klassifizierbar'):
    if flag_col not in df.columns:
        continue

    print(f"\n{flag_col} column info:")
    print(df[flag_col].head(10))
    print("Unique values:", df[flag_col].unique())

    df[flag_col] = safe_bool_convert(df[flag_col])
    print("After conversion:")
    print(df[flag_col].value_counts())
|
||
|
||
# Bar chart comparing how many rows the LLM and the ground truth mark as
# classifiable (True). Drawn only when both columns exist.
if 'result.klassifizierbar' in df.columns and 'GT.klassifizierbar' in df.columns:
    # Summing a boolean Series counts its True values.
    llm_true_count = df['result.klassifizierbar'].sum()
    gt_true_count = df['GT.klassifizierbar'].sum()

    fig, ax = plt.subplots(figsize=(8, 6))

    x = np.arange(2)
    width = 0.35

    # One bar per source, centred on its tick. The bars used to be offset by
    # -width/2 and +width/2 (a grouped-bar leftover) while the tick labels and
    # value annotations stayed at x[0]/x[1], so they did not line up.
    bars1 = ax.bar(x[0], llm_true_count, width, label='LLM', color='skyblue', alpha=0.8)
    bars2 = ax.bar(x[1], gt_true_count, width, label='GT', color='lightcoral', alpha=0.8)

    # Value label just above each bar.
    for xpos, count in ((x[0], llm_true_count), (x[1], gt_true_count)):
        ax.annotate(f'{count}',
                    xy=(xpos, count),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha='center', va='bottom')

    ax.set_xlabel('Classification Status (klassifizierbar)')
    ax.set_ylabel('Count')
    ax.set_title('True Values Comparison: LLM vs GT for "klassifizierbar"')
    ax.set_xticks(x)
    ax.set_xticklabels(['LLM', 'GT'])
    ax.legend()

    plt.tight_layout()
    plt.show()
|
||
|
||
# Confusion matrix of the LLM vs ground-truth "klassifizierbar" flags.
if 'result.klassifizierbar' in df.columns and 'GT.klassifizierbar' in df.columns:
    try:
        # Coerce to plain booleans (missing -> False).
        llm_bool = df['result.klassifizierbar'].fillna(False).astype(bool)
        gt_bool = df['GT.klassifizierbar'].fillna(False).astype(bool)

        # Pin both axes to [False, True]. Without `labels`, a column holding a
        # single value yields a 1x1 matrix that does not match the two fixed
        # tick labels used for the heatmap below.
        cm = confusion_matrix(gt_bool, llm_bool, labels=[False, True])

        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['False ', 'True '],
                    yticklabels=['False', 'True '],
                    ax=ax)
        ax.set_xlabel('LLM Predictions ')
        ax.set_ylabel('GT Labels ')
        ax.set_title('Confusion Matrix: LLM vs GT for "klassifizierbar"')

        plt.tight_layout()
        plt.show()

        print("Confusion Matrix:")
        print(cm)

    except Exception as e:
        # Best effort: report plotting problems without aborting the cell.
        print(f"Error creating confusion matrix: {e}")

# Dtypes of whichever flag columns exist. Selecting both unconditionally raised
# KeyError when one was missing; info() prints itself and returns None, so it
# is not wrapped in print() (which used to emit a stray "None").
print("\nFinal DataFrame info:")
flag_columns = [col for col in ('result.klassifizierbar', 'GT.klassifizierbar') if col in df.columns]
df[flag_columns].info()
|
||
|
||
##
|
||
|
||
|
||
|
||
|
||
# %% Boxplot
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import seaborn as sns
|
||
import numpy as np
|
||
|
||
# Load the de-duplicated results table (tab-separated).
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_results_unique.tsv'
df = pd.read_csv(file_path, sep='\t')

# Scores may use a decimal comma ("3,5"); normalise to a dot before parsing.
df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.')
df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.')

# Parse as float; unparseable entries become NaN.
df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce')
df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce')

# Drop rows missing either score. `.copy()` makes df_clean an independent
# frame: a category column is added to it next, and assigning into the
# filtered result of dropna() otherwise raises pandas' SettingWithCopyWarning.
df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']).copy()
|
||
|
||
# 1. Fixed, numerically sensible x-axis order: '0-1' ... '9-10', then '10+'.
category_order = [f'{low}-{low + 1}' for low in range(10)] + ['10+']

# Bin the ground-truth scores (categorize_edss is defined in the
# confusion-matrix cells of this file) and store them as an ordered categorical.
gt_bins = df_clean['GT.EDSS'].apply(categorize_edss)
df_clean['GT.EDSS_cat'] = pd.Categorical(gt_bins, categories=category_order, ordered=True)

plt.figure(figsize=(14, 8))

# 2. hue mirrors x so that seaborn colours every bin and builds a legend.
box_plot = sns.boxplot(
    data=df_clean,
    x='GT.EDSS_cat',
    y='result.EDSS',
    hue='GT.EDSS_cat',
    palette='viridis',
    linewidth=1.5,
    legend=True,
)

# 3. Titles, legend outside the plot area, rotated ticks and a light y-grid.
plt.title('Distribution of result.EDSS by GT.EDSS Category', fontsize=18, pad=20)
plt.xlabel('Ground Truth EDSS Category', fontsize=14)
plt.ylabel('LLM Predicted EDSS', fontsize=14)
plt.legend(title="EDSS Categories", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.xticks(rotation=45, ha='right', fontsize=10)
plt.grid(True, axis='y', alpha=0.3)
plt.tight_layout()

plt.show()
|
||
##
|
||
|
||
|
||
# %% Postprocessing: rename ground-truth columns
|
||
|
||
import pandas as pd
|
||
|
||
# Read the joined table. NOTE: this file is overwritten in place below.
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Map the original (mostly German) ground-truth column names to English
# names carrying a 'GT.' prefix.
# NOTE(review): 'Sensibiliät' looks like a misspelling of 'Sensibilität'; it
# only takes effect if the TSV header uses exactly this spelling -- verify.
column_mapping = {
    'EDSS': 'GT.EDSS',
    'klassifizierbar': 'GT.klassifizierbar',
    'Sehvermögen': 'GT.VISUAL_OPTIC_FUNCTIONS',
    'Cerebellum': 'GT.CEREBELLAR_FUNCTIONS',
    'Hirnstamm': 'GT.BRAINSTEM_FUNCTIONS',
    'Sensibiliät': 'GT.SENSORY_FUNCTIONS',
    'Pyramidalmotorik': 'GT.PYRAMIDAL_FUNCTIONS',
    'Ambulation': 'GT.AMBULATION',
    'Cerebrale_Funktion': 'GT.CEREBRAL_FUNCTIONS',
    'Blasen-_und_Mastdarmfunktion': 'GT.BOWEL_AND_BLADDER_FUNCTIONS'
}

# Record which mappings actually apply BEFORE renaming. The old report loop
# checked `old_name in df.columns` after the rename, when those names no
# longer exist, so it never listed anything.
renamed = {old: new for old, new in column_mapping.items() if old in df.columns}

df = df.rename(columns=column_mapping)

# Write the result back to the same file (overwrites the input).
df.to_csv(file_path, sep='\t', index=False)

print("Columns have been successfully renamed!")
print("Renamed columns:")
for old_name, new_name in renamed.items():
    print(f" {old_name} -> {new_name}")
if not renamed:
    print(" (none of the mapped columns were present)")
|
||
|
||
|
||
##
|
||
|
||
|
||
|
||
|
||
# %% Styled table
|
||
import pandas as pd
|
||
import numpy as np
|
||
import seaborn as sns
|
||
import matplotlib.pyplot as plt
|
||
import dataframe_image as dfi
|
||
# Load the joined table and split its columns into ground-truth ('GT.') and
# LLM ('result.') groups.
df = pd.read_csv("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv", sep='\t')

gt_columns = list(filter(lambda name: name.startswith('GT.'), df.columns))
result_columns = list(filter(lambda name: name.startswith('result.'), df.columns))

print("GT Columns found:", gt_columns)
print("Result Columns found:", result_columns)
|
||
|
||
# 2. Pair every GT column with its LLM ('result.') counterpart. The two sides
# do not spell names consistently (spaces vs underscores, letter case), so
# matching falls back to progressively looser comparisons.


def _normalized_name(name, keep=''):
    """Return *name* lower-cased, keeping only alphanumerics and chars in *keep*."""
    return ''.join(ch for ch in name if ch.isalnum() or ch in keep).lower()


def _find_result_column(base_name, result_columns):
    """Return the 'result.' column matching GT base name *base_name*, or None.

    Attempts, in order:
      1. the exact name or a simple variant (spaces/underscores swapped or
         removed, all lower or all upper case);
      2. case-insensitive equality ignoring punctuation other than '_' and ' ';
      3. case-insensitive equality ignoring every non-alphanumeric character,
         so e.g. 'VISUAL_OPTIC_FUNCTIONS' matches 'visual optic functions'.
         The previous fallback stopped at step 2, so such pairs were missed.
    """
    candidates = [
        f'result.{base_name}',                     # exact
        f'result.{base_name.replace(" ", "_")}',   # spaces -> underscores
        f'result.{base_name.replace("_", " ")}',   # underscores -> spaces
        f'result.{base_name.replace(" ", "")}',    # no spaces
        f'result.{base_name.replace("_", "")}',    # no underscores
        f'result.{base_name.lower()}',
        f'result.{base_name.upper()}',
    ]
    for candidate in candidates:
        if candidate in result_columns:
            return candidate

    prefix_len = len('result.')
    for keep in ('_ ', ''):
        target = _normalized_name(base_name, keep)
        for result_col in result_columns:
            if _normalized_name(result_col[prefix_len:], keep) == target:
                return result_col
    return None


column_mapping = {}
for gt_col in gt_columns:
    # Strip only the leading 'GT.' (str.replace would also remove it mid-name).
    base_name = gt_col[len('GT.'):]
    result_col = _find_result_column(base_name, result_columns)
    if result_col is None:
        print(f"No matching result column for {gt_col}")
    else:
        column_mapping[gt_col] = result_col

print("Column mapping:", column_mapping)
|
||
|
||
# 3. Agreement per matched column pair: the share of rows whose GT and LLM
# values differ by at most 0.5 points, among rows where both are numeric.


def _parse_scores(series):
    """Parse *series* as floats, accepting decimal commas; unparseable -> NaN.

    Other cells in this file convert '3,5'-style values explicitly; without
    that step pd.to_numeric(errors='coerce') silently turns them into NaN and
    drops them from the comparison. Only string cells are rewritten, so bool
    and numeric values convert exactly as pd.to_numeric handles them.
    """
    normalized = series.map(lambda v: v.replace(',', '.') if isinstance(v, str) else v)
    return pd.to_numeric(normalized, errors='coerce').astype(float)


data_list = []

for gt_col, result_col in column_mapping.items():
    print(f"Processing {gt_col} vs {result_col}")

    s1 = _parse_scores(df[gt_col])
    s2 = _parse_scores(df[result_col])

    # A comparison counts as a match when |GT - LLM| <= 0.5.
    diff = np.abs(s1 - s2)
    matches = (diff <= 0.5).sum()
    valid_count = diff.notna().sum()  # rows where both values are numeric

    # With no comparable rows the agreement is undefined: report NaN rather
    # than a misleading 0 %.
    percentage = (matches / valid_count) * 100 if valid_count > 0 else np.nan

    data_list.append({
        'GT': gt_col[len('GT.'):],
        'Match %': round(percentage, 1),
    })

# 4. Tabulate with readable labels, best agreement first (NaN rows sort last).
# Explicit columns keep the frame well-formed even if nothing was matched.
match_df = pd.DataFrame(data_list, columns=['GT', 'Match %'])
match_df['GT'] = match_df['GT'].str.replace('_', ' ').str.title()
match_df = match_df.sort_values('Match %', ascending=False)
|
||
|
||
# 5. Render the agreement table as a styled single-column Seaborn heatmap.
def create_luxury_table(df, output_file="edss_agreement.png"):
    """Draw *df* as a one-column heatmap "table" and save it to *output_file*.

    Args:
        df: DataFrame with a 'GT' label column and a numeric 'Match %' column.
        output_file: Image path passed to plt.savefig (saved at 300 dpi).

    The figure is left open so that later code can save the same current
    pyplot figure in another format. An empty *df* is reported and skipped,
    because a zero-row heatmap cannot be drawn.
    """
    if df.empty:
        print("No agreement data to plot; nothing saved.")
        return

    sns.set_theme(style="white", font="sans-serif")

    # One row per GT item, a single 'Match %' column.
    plot_data = df.set_index('GT')[['Match %']]

    # 0.6 inch per row, but never so short that the padded title and footer
    # no longer fit (one row alone would give a 0.6 inch tall figure).
    fig_height = max(2.0, len(df) * 0.6)
    fig, ax = plt.subplots(figsize=(8, fig_height))

    # Diverging red (hue 15) to green (hue 135) palette with a light midpoint.
    cmap = sns.diverging_palette(15, 135, s=80, l=55, as_cmap=True)

    sns.heatmap(
        plot_data,
        annot=True,
        fmt=".1f",
        cmap=cmap,
        center=85,           # colour midpoint at 85 %
        vmin=50, vmax=100,   # colour scale range
        linewidths=2,
        linecolor='white',
        cbar=False,          # no colour bar, so it reads like a table
        annot_kws={"size": 14, "weight": "bold", "family": "sans-serif"}
    )

    # Make the heatmap look like a table: no axis titles, header on top.
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.xaxis.tick_top()
    ax.set_xticklabels(['Agreement (%)'], fontsize=14, fontweight='bold', color='#2c3e50')
    ax.tick_params(axis='y', labelsize=12, labelcolor='#2c3e50', length=0)

    # Thin light border around the plot.
    for _, spine in ax.spines.items():
        spine.set_visible(True)
        spine.set_color('#ecf0f1')

    plt.title('EDSS Subcategory Consistency Analysis', fontsize=16, pad=40, fontweight='bold', color='#2c3e50')

    # Footer stating the match tolerance used for 'Match %'.
    plt.figtext(0.5, 0.0, "Tolerance: ±0.5 points",
                wrap=True, horizontalalignment='center', fontsize=10, color='gray', style='italic')

    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"Beautiful table saved as {output_file}")


# Execute
create_luxury_table(match_df)
|
||
|
||
|
||
# 6. Also export the current pyplot figure (the table drawn by
# create_luxury_table above) as SVG. A former call to save_styled_table() was
# removed: that function is not defined anywhere before this point, so the
# call raised NameError and the SVG below was never written.
plt.savefig("agreement_table.svg", format='svg', dpi=300, bbox_inches='tight')
print("Successfully saved agreement_table.svg")

# Show the figure when running in a GUI environment.
plt.show()
|
||
##
|
||
|
||
|
||
|
||
# %% Time Plot
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import pandas as pd
|
||
from scipy import stats
|
||
|
||
# Read inference times and drop missing entries.
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')
inference_times = df['inference_time_sec'].dropna()

# Summary statistics; pandas .std() is the sample standard deviation.
mean_time, std_time = inference_times.mean(), inference_times.std()
median_time = np.median(inference_times)

fig, ax = plt.subplots(figsize=(10, 6))

# Integer-aligned 1-second bins covering every observed time; plot raw counts.
min_time = int(inference_times.min())
max_time = int(inference_times.max()) + 1
bins = np.arange(min_time, max_time + 1, 1)
n, bins, patches = ax.hist(inference_times, bins=bins, color='lightblue', alpha=0.7, edgecolor='black', linewidth=0.5)

# Normal density with the sample mean/SD, scaled to counts (pdf * N * bin width).
x = np.linspace(inference_times.min(), inference_times.max(), 100)
bin_width = bins[1] - bins[0]
gaussian_counts = stats.norm.pdf(x, mean_time, std_time) * len(inference_times) * bin_width
ax.plot(x, gaussian_counts, color='red', linewidth=2, label=f'Gaussian Fit (μ={mean_time:.1f}s, σ={std_time:.1f}s)')

# Vertical markers: mean, median and mean +/- one SD (drawn in this order).
reference_lines = [
    (mean_time, {'color': 'blue', 'linestyle': '--', 'linewidth': 2,
                 'label': f'Mean = {mean_time:.1f}s'}),
    (median_time, {'color': 'green', 'linestyle': '--', 'linewidth': 2,
                   'label': f'Median = {median_time:.1f}s'}),
    (mean_time + std_time, {'color': 'saddlebrown', 'linestyle': ':', 'linewidth': 1, 'alpha': 0.7,
                            'label': f'+1σ = {mean_time + std_time:.1f}s'}),
    (mean_time - std_time, {'color': 'saddlebrown', 'linestyle': ':', 'linewidth': 1, 'alpha': 0.7,
                            'label': f'-1σ = {mean_time - std_time:.1f}s'}),
]
for xpos, line_style in reference_lines:
    ax.axvline(xpos, **line_style)

ax.set_xlabel('Inference Time (seconds)')
ax.set_ylabel('Frequency')
ax.set_title('Inference Time Distribution with Gaussian Fit')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
|
||
|
||
##
|
||
|
||
|
||
|
||
|
||
|
||
|
||
# %% Dashboard
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import seaborn as sns
|
||
from datetime import datetime
|
||
import numpy as np
|
||
|
||
import matplotlib.dates as mdates

# Load the joined ground-truth / LLM table.
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Drop the 'result.' prefix from the LLM columns (spaces -> underscores) so
# they can be looked up by their plain functional-system names below.
column_mapping = {
    col: col.replace('result.', '').replace(' ', '_')
    for col in df.columns
    if col.startswith('result.')
}
df = df.rename(columns=column_mapping)

# NOTE(review): no explicit format/dayfirst is passed; if MedDatum is written
# day-first (e.g. DD.MM.YYYY), confirm the dates are parsed as intended.
df['MedDatum'] = pd.to_datetime(df['MedDatum'])

# Debug overview of what the table contains.
print("Available columns:")
print(df.columns.tolist())
print("\nFirst few rows:")
print(df.head())

# Patient to plot (hard-coded) and one subplot per functional system.
patient_names = ['6b56865d']
functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual']

num_plots = len(functional_systems)
num_cols = 2
num_rows = (num_plots + num_cols - 1) // num_cols  # ceiling division

fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 4*num_rows), sharex=False)
# Always a flat 1-D array of Axes. The former single-plot special case wrapped
# the two-element axes array in a list, so axes[0] was an array, not an Axes.
axes = np.atleast_1d(axes).ravel()


def _find_system_column(system, columns):
    """Return *system* if it is a column; otherwise the first column whose name
    contains it (case-insensitive); otherwise None."""
    if system in columns:
        return system
    return next((col for col in columns if system.lower() in col.lower()), None)


# The patient's visits in date order, selected once rather than per subplot.
patient_data = df[df['unique_id'] == patient_names[0]].sort_values('MedDatum')

if patient_data.empty:
    print(f"No data found for patient: {patient_names[0]}")
else:
    for ax, system in zip(axes, functional_systems):
        column = _find_system_column(system, patient_data.columns)
        ax.set_ylabel('Score')
        ax.grid(True, alpha=0.3)

        if column is None:
            ax.set_title(f'Functional System: {system} (Column not found)')
            continue

        title = f'Functional System: {system}'
        if column != system:
            print(f"Found similar column: {column}")
            title = f'{title} (found as: {column})'

        if patient_data[column].isna().all():
            ax.set_title(f'{title} (No data)')
            continue

        ax.plot(patient_data['MedDatum'], patient_data[column], marker='o', linewidth=2, label=column)
        ax.set_title(title)
        ax.legend()

# Hide grid cells beyond the number of plots.
for ax in axes[num_plots:]:
    ax.set_visible(False)

# 'Date' on the bottom-most visible subplot of every column (the old check
# only reached the final grid row, leaving the other column unlabelled).
for i in range(num_plots):
    if i + num_cols >= num_plots:
        axes[i].set_xlabel('Date')

# Date ticks on every subplot. The x-axes are not shared, so each keeps its own
# tick labels; fig.autofmt_xdate() was dropped because it hides the tick
# labels (and x label) of every subplot outside the last grid row.
for ax in axes:
    ax.tick_params(axis='x', rotation=45)
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    ax.xaxis.set_major_locator(mdates.MonthLocator())

plt.tight_layout()
plt.show()
|
||
|
||
##
|
||
|
||
|
||
# %% Table
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import seaborn as sns
|
||
from datetime import datetime
|
||
import numpy as np
|
||
|
||
import matplotlib.dates as mdates

# Load the joined ground-truth / LLM table.
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# NOTE(review): no explicit format/dayfirst is passed; if MedDatum is written
# day-first (e.g. DD.MM.YYYY), confirm the dates are parsed as intended.
df['MedDatum'] = pd.to_datetime(df['MedDatum'])

# Debug overview: columns, first rows and dtypes.
print("Available columns:")
print(df.columns.tolist())
print("\nFirst few rows:")
print(df.head())
print("\nData types:")
print(df.dtypes)

# Patient to plot (hard-coded) and one subplot per functional system.
patient_names = ['6ccda8c6']
functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual']

num_plots = len(functional_systems)
num_cols = 2
num_rows = (num_plots + num_cols - 1) // num_cols  # ceiling division

fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 4*num_rows), sharex=False)
# Always a flat 1-D array of Axes. The former single-plot special case wrapped
# the two-element axes array in a list, so axes[0] was an array, not an Axes.
axes = np.atleast_1d(axes).ravel()


def _find_system_column(system, columns):
    """Return *system* if it is a column; otherwise the first column whose name
    contains it (case-insensitive); otherwise None."""
    if system in columns:
        return system
    return next((col for col in columns if system.lower() in col.lower()), None)


# The patient's visits in date order, selected once rather than per subplot.
patient_data = df[df['unique_id'] == patient_names[0]].sort_values('MedDatum')
if patient_data.empty:
    print(f"No data found for patient: {patient_names[0]}")

for ax, system in zip(axes, functional_systems):
    ax.set_ylabel('Score')

    if patient_data.empty:
        ax.set_title(f'Functional System: {system} (No data)')
        continue

    column = _find_system_column(system, patient_data.columns)
    if column is None:
        ax.set_title(f'Functional System: {system} (Column not found)')
        continue

    # Plot only visits where this system has a value.
    valid_data = patient_data.dropna(subset=[column])
    if valid_data.empty:
        ax.set_title(f'Functional System: {system} (No valid data)')
        continue

    ax.plot(valid_data['MedDatum'], valid_data[column], marker='o', linewidth=2, label=column)
    suffix = '' if column == system else f' (found as: {column})'
    ax.set_title(f'Functional System: {system}{suffix}')
    ax.grid(True, alpha=0.3)
    ax.legend()

# Hide grid cells beyond the number of plots.
for ax in axes[num_plots:]:
    ax.set_visible(False)

# 'Date' on the bottom-most visible subplot of every column (the old check
# only reached the final grid row, leaving the other column unlabelled).
for i in range(num_plots):
    if i + num_cols >= num_plots:
        axes[i].set_xlabel('Date')

# Date tick formatting, only on subplots that actually contain a line.
for ax in axes:
    if ax.get_lines():
        ax.tick_params(axis='x', rotation=45)
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))

plt.tight_layout()
plt.show()
|
||
|
||
|
||
|
||
|
||
##
|
||
|
||
|
||
|
||
|
||
# %% Histogram Fig1
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib.font_manager as fm
|
||
import json
|
||
import os
|
||
|
||
def create_visit_frequency_plot(
    file_path,
    output_dir='/home/shahin/Lab/Doktorarbeit/Barcelona/Data',
    output_filename='visit_frequency_distribution.svg',
    fontsize=10,
    color_scheme_path='colors.json',
    figsize_cm=(7.94, 6),
    show=False,
):
    """
    Creates a publication-ready bar chart of patient visit frequency.

    The input TSV must provide a 'Visits Count' column (bar positions / x-axis)
    and a 'Unique Patients' column (bar heights).

    Args:
        file_path (str): Path to the input TSV file.
        output_dir (str): Directory to save the output SVG file (created if missing).
        output_filename (str): Name of the output SVG file.
        fontsize (int): Font size for all text elements (labels, title).
        color_scheme_path (str): Path to the JSON file containing the color palette.
            The bar color is read from ``colors['sequential']['blues'][-2]``; if the
            file is missing or does not contain that entry, a default blue is used.
        figsize_cm (tuple): Figure (width, height) in centimetres. Matplotlib's
            ``figsize`` is in inches, so the values are converted before use.
        show (bool): If True, display the figure with ``plt.show()`` after saving;
            otherwise the figure is closed so it does not reappear in later plots.

    Returns:
        str | None: Path of the saved SVG, or None if the input file was not found.
    """
    # --- 1. Load Data and Color Scheme ---
    try:
        df = pd.read_csv(file_path, sep='\t')
    except FileNotFoundError:
        print(f"Error: The file was not found at {file_path}")
        return None
    print("Data loaded successfully.")
    # Sort data for easier visual comparison
    df = df.sort_values(by='Visits Count')

    default_bar_color = '#2171b5'  # A common matplotlib blue
    try:
        with open(color_scheme_path, 'r', encoding='utf-8') as f:
            colors = json.load(f)
        # Select a blue from the sequential palette for the bars
        bar_color = colors['sequential']['blues'][-2]  # A saturated blue
    except FileNotFoundError:
        print(f"Warning: Color scheme file not found at {color_scheme_path}. Using default blue.")
        bar_color = default_bar_color
    except (KeyError, IndexError, TypeError, json.JSONDecodeError):
        # The file exists but is malformed or lacks the expected palette entry.
        print(f"Warning: Color scheme at {color_scheme_path} has no usable 'sequential'/'blues' palette. Using default blue.")
        bar_color = default_bar_color

    # --- 2. Set up the Plot with Scientific Style ---
    # figsize is in inches; convert the centimetre dimensions
    # (default width: 7.94 cm single column).
    cm_per_inch = 2.54
    fig = plt.figure(figsize=(figsize_cm[0] / cm_per_inch, figsize_cm[1] / cm_per_inch))

    # Set the font to Arial. NOTE: rcParams are global, so these settings also
    # apply to figures created later in the same session.
    arial_font = fm.FontProperties(family='Arial', size=fontsize)
    plt.rcParams['font.family'] = 'Arial'
    plt.rcParams['font.size'] = fontsize

    # --- 3. Create the Bar Chart ---
    ax = fig.gca()
    bars = ax.bar(
        x=df['Visits Count'],
        height=df['Unique Patients'],
        color=bar_color,
        edgecolor='black',
        linewidth=0.5,  # Minimum line thickness
        width=0.7,
    )

    # Explicitly set x-ticks and labels so every visit count is shown,
    # with one tick at the centre of each bar.
    visit_counts = df['Visits Count'].unique()
    ax.set_xticks(visit_counts)
    ax.set_xticklabels(visit_counts, fontproperties=arial_font)

    # --- 4. Customize Axes and Layout (Nature style) ---
    # Display only left and bottom axes
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # Turn off axis ticks (the marks, not the labels)
    ax.tick_params(axis='both', which='both', length=0)
    # Remove grid lines
    ax.grid(False)
    # Set background to white (no shading)
    ax.set_facecolor('white')
    fig.set_facecolor('white')

    # --- 5. Add Labels and Title ---
    ax.set_xlabel('Number of Visits', fontproperties=arial_font, labelpad=10)
    ax.set_ylabel('Number of Unique Patients', fontproperties=arial_font, labelpad=10)
    ax.set_title('Distribution of Patient Visit Frequency', fontproperties=arial_font, pad=20)

    # --- 6. Add y-axis values on top of each bar ---
    # This adds the count of unique patients directly above each bar.
    ax.bar_label(bars, fmt='%d', padding=3)

    # --- 7. Export the Figure ---
    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)
    full_output_path = os.path.join(output_dir, output_filename)
    fig.savefig(full_output_path, format='svg', dpi=300, bbox_inches='tight')
    print(f"\nFigure saved as '{full_output_path}'")

    # --- 8. Display or release the figure ---
    if show:
        plt.show()
    else:
        # Close it so a later plt.show() in this session does not display it again.
        plt.close(fig)
    return full_output_path
|
||
|
||
# --- Main execution ---
if __name__ == '__main__':
    # Frequency table for create_visit_frequency_plot; it must provide the
    # 'Visits Count' and 'Unique Patients' columns.
    # NOTE(review): "freuency" in the file name looks like a typo, but it has to
    # match the file on disk — verify before renaming.
    visit_table_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/visit_freuency.tsv'

    # Render and save the chart with 10 pt text, as required by the figure guidelines.
    create_visit_frequency_plot(fontsize=10, file_path=visit_table_path)
|
||
|
||
##
|
||
|
||
|
||
|
||
# %% Scatter Plot functional system
|
||
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import json
|
||
import os
|
||
|
||
# --- Configuration ---
# Use Arial for every text element in the plot, per the figure guidelines.
plt.rcParams['font.family'] = 'Arial'

# Input comparison table and output locations.
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
# Relative to the current working directory.
figure_save_path = 'project/visuals/edss_functional_systems_comparison.svg'

# --- 1. Load the Dataset ---
try:
    df = pd.read_csv(data_path, sep='\t')
except FileNotFoundError:
    # Report the missing input, then re-raise: nothing below can run without it.
    print(f"Error: The file at {data_path} was not found.")
    raise
else:
    print(f"Successfully loaded data from {data_path}")
    print(f"Data shape: {df.shape}")
|
||
|
||
# --- 2. Define Functional Systems and Create Color Mapping ---
# (ground_truth_column, result_column) pairs. Ground-truth columns use
# underscores; the matching LLM result columns use spaces.
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS'),
]

# System names (the part after "GT.") key the color map and legend labels.
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]

# Qualitative palette: four blues followed by four warm tones, one per system.
colors = [
    '#003366',  # Dark Blue
    '#336699',  # Medium Blue
    '#6699CC',  # Light Blue
    '#99CCFF',  # Very Light Blue
    '#FF9966',  # Coral
    '#FF6666',  # Light Red
    '#CC6699',  # Magenta
    '#9966CC',  # Purple
]

# zip() would silently drop systems beyond the palette length, which would only
# surface later as a KeyError while plotting; fail early with a clear message.
if len(colors) < len(system_names):
    raise ValueError(
        f"Color palette has {len(colors)} colors but "
        f"{len(system_names)} functional systems are defined."
    )

# Map each system name to its color
color_map = dict(zip(system_names, colors))

# Ensure the directory for the JSON file exists (dirname is '' for a bare
# file name, which os.makedirs cannot create).
color_json_dir = os.path.dirname(color_json_path)
if color_json_dir:
    os.makedirs(color_json_dir, exist_ok=True)

# Save the color map so other figures can reuse the same system colors
with open(color_json_path, 'w', encoding='utf-8') as f:
    json.dump(color_map, f, indent=4)

print(f"Color mapping saved to {color_json_path}")
|
||
|
||
# --- 3-4. Calculate Agreement Percentages and Reshape Data for Plotting ---
# Both steps need the same numeric, row-aligned (ground truth, inference)
# pairs, so each system's columns are converted once and reused.
agreement_percentages = {}
legend_labels = {}
plot_data = []

for gt_col, res_col in functional_systems_to_plot:
    system_name = gt_col.split('.')[1]

    # Convert to numeric. Decimal commas (e.g. "2,5") are normalised first;
    # anything still unparseable becomes NaN.
    gt_numeric = pd.to_numeric(
        df[gt_col].astype(str).str.replace(',', '.', regex=False), errors='coerce'
    )
    res_numeric = pd.to_numeric(
        df[res_col].astype(str).str.replace(',', '.', regex=False), errors='coerce'
    )

    # Keep only rows where both values are present, so the agreement figure and
    # the plotted points are computed from exactly the same rows.
    pair_df = pd.DataFrame({
        'system': system_name,
        'ground_truth': gt_numeric,
        'inference': res_numeric,
    }).dropna()

    if pair_df.empty:
        # No comparable rows: report "n/a" instead of a misleading 0% agreement.
        agreement = float('nan')
        agreement_text = 'n/a'
    else:
        # Percentage of rows where the LLM score exactly matches the ground truth
        agreement = (pair_df['ground_truth'] == pair_df['inference']).mean() * 100
        agreement_text = f"{agreement:.1f}%"

    agreement_percentages[system_name] = agreement

    # Format the system name for the legend (e.g. "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions")
    formatted_name = " ".join(word.capitalize() for word in system_name.split('_'))
    legend_labels[system_name] = f"{formatted_name} ({agreement_text})"

    plot_data.append(pair_df)

# Concatenate all per-system frames into one long-format DataFrame
plot_df = pd.concat(plot_data, ignore_index=True)

if plot_df.empty:
    print("Warning: No valid numeric data to plot after conversion. The plot will be blank.")
else:
    print(f"Prepared plot data with {len(plot_df)} data points.")
|
||
|
||
# --- 5. Create the Scatter Plot ---
plt.figure(figsize=(10, 8))

# Plot each functional system with its assigned color and formatted legend label.
# Iterating system_names (groupby would sort alphabetically) keeps the legend in
# the same order as functional_systems_to_plot and the color palette.
for system in system_names:
    group = plot_df[plot_df['system'] == system]
    if group.empty:
        continue  # no comparable rows: no points and no legend entry
    plt.scatter(
        group['ground_truth'],
        group['inference'],
        label=legend_labels[system],
        color=color_map[system],
        alpha=0.7,
        s=30,
    )

# Add a diagonal line representing perfect agreement (y = x). It spans the
# combined range of both axes so it also covers predictions that fall outside
# the ground-truth range.
if not plot_df.empty:
    lo = min(plot_df['ground_truth'].min(), plot_df['inference'].min())
    hi = max(plot_df['ground_truth'].max(), plot_df['inference'].max())
    plt.plot(
        [lo, hi],
        [lo, hi],
        color='black',
        linestyle='--',
        linewidth=0.8,
        alpha=0.7,
    )

# --- 6. Apply Styling and Labels ---
plt.xlabel('Ground Truth', fontsize=12)
plt.ylabel('LLM Inference', fontsize=12)
plt.title('Comparison of EDSS Functional Systems: Ground Truth vs. LLM Inference', fontsize=14)

# Apply scientific visualization styling rules
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.tick_params(axis='both', which='both', length=0)  # Remove ticks
ax.grid(False)  # Remove grid lines

# Only draw a legend when there are labelled points; otherwise matplotlib
# warns that there are no artists with labels.
if not plot_df.empty:
    plt.legend(title='Functional System', frameon=False, fontsize=10)

# --- 7. Save and Display the Figure ---
# Ensure the directory for the figure exists
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)

plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
print(f"Figure successfully saved to {figure_save_path}")

# Display the plot
plt.show()
|
||
##
|
||
|
||
|
||
|
||
|
||
# %% Confusion Matrix functional systems
|
||
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
import json
|
||
import os
|
||
import numpy as np
|
||
import matplotlib.colors as mcolors
|
||
|
||
# --- Configuration ---
# Arial for every text element in the figure.
plt.rcParams['font.family'] = 'Arial'

# Input comparison table and output figure (the latter relative to the
# current working directory).
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg'

# --- 1. Load the Dataset ---
# No error handling here: a missing file raises FileNotFoundError directly.
df = pd.read_csv(filepath_or_buffer=data_path, sep='\t')
|
||
|
||
# --- 2. Define Functional Systems and Colors ---
# (ground_truth_column, result_column) pairs. The ground-truth key uses
# underscores; the matching LLM result column uses the same words separated by spaces.
functional_systems_to_plot = [
    (f'GT.{key}', 'result.' + key.replace('_', ' '))
    for key in (
        'VISUAL_OPTIC_FUNCTIONS',
        'CEREBELLAR_FUNCTIONS',
        'BRAINSTEM_FUNCTIONS',
        'SENSORY_FUNCTIONS',
        'PYRAMIDAL_FUNCTIONS',
        'AMBULATION',
        'CEREBRAL_FUNCTIONS',
        'BOWEL_AND_BLADDER_FUNCTIONS',
    )
]

# Bare system keys (text after "GT."), in plotting order.
system_names = [gt_col.partition('.')[2] for gt_col, _ in functional_systems_to_plot]

# Qualitative palette: four blues followed by four warm tones, one per system.
colors = [
    '#003366', '#336699', '#6699CC', '#99CCFF',
    '#FF9966', '#FF6666', '#CC6699', '#9966CC',
]
color_map = {name: hex_code for name, hex_code in zip(system_names, colors)}
|
||
|
||
# --- 3. Categorization Function ---
# Ten unit-width score bins. A value v > 0 falls in bin k when k < v <= k + 1
# (so an integer score k + 1 lands in bin "k-(k+1)"); 0 goes into "0-1".
categories = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']
category_to_index = {cat: i for i, cat in enumerate(categories)}
n_categories = len(categories)


def categorize_edss(value):
    """Map a score to its category label in ``categories``.

    Values are clamped to [0, 10]. Bins are left-open and right-closed
    (``k < v <= k + 1``), with 0 assigned to '0-1'.

    Args:
        value: A number or numeric string; NaN/None is passed through.

    Returns:
        str | float: A label from ``categories``, or ``np.nan`` for missing input.

    Raises:
        ValueError: If ``value`` is a non-numeric string (raised by ``float``).
    """
    if pd.isna(value):
        return np.nan
    # Ensure value is float to avoid TypeError, then clamp to the scale range
    val = min(max(float(value), 0.0), 10.0)
    # ceil(v) - 1 yields k exactly for k < v <= k + 1. An epsilon shortcut such
    # as int(v - 0.001) would misplace values lying within 0.001 above an integer.
    idx = max(int(np.ceil(val)) - 1, 0)
    return categories[min(idx, len(categories) - 1)]
|
||
|
||
# --- 4. Prepare Mixed Color Matrix with Saturation ---
# cell_system_counts[gt_bin, llm_bin, s] = number of rows of functional system s
# whose ground truth falls in gt_bin and whose LLM score falls in llm_bin.
cell_system_counts = np.zeros((n_categories, n_categories, len(system_names)))

for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
    # Coerce both columns to numbers (unparseable entries -> NaN, avoiding
    # string comparisons) and keep only rows where both are present.
    numeric_pairs = pd.DataFrame({
        gt_col: pd.to_numeric(df[gt_col], errors='coerce'),
        res_col: pd.to_numeric(df[res_col], errors='coerce'),
    }).dropna()

    for gt_value, res_value in zip(numeric_pairs[gt_col], numeric_pairs[res_col]):
        gt_cat = categorize_edss(gt_value)
        res_cat = categorize_edss(res_value)
        if gt_cat in category_to_index and res_cat in category_to_index:
            cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1

# RGBA image (n_categories x n_categories x 4):
#   RGB   = count-weighted average of the contributing systems' colors,
#   alpha = sqrt(cell total / largest cell total), floored at 0.1 so sparse
#           cells stay visible; empty cells are fully transparent white.
rgba_matrix = np.zeros((n_categories, n_categories, 4))

total_counts = cell_system_counts.sum(axis=2)
peak_count = total_counts.max()
max_count = peak_count if peak_count > 0 else 1

# Each system's color as an RGB float array, in system_names order
system_rgbs = [np.array(mcolors.to_rgb(color_map[name])) for name in system_names]

for i, j in np.ndindex(n_categories, n_categories):
    count_sum = total_counts[i, j]
    if count_sum <= 0:
        rgba_matrix[i, j] = [1, 1, 1, 0]
        continue

    mixed_rgb = np.zeros(3)
    for s_idx, system_rgb in enumerate(system_rgbs):
        mixed_rgb += system_rgb * (cell_system_counts[i, j, s_idx] / count_sum)

    rgba_matrix[i, j, :3] = mixed_rgb
    # Square-root scale lifts low counts while keeping them lighter than dense cells
    rgba_matrix[i, j, 3] = max(0.1, np.sqrt(count_sum / max_count))
||
|
||
# --- 5. Plotting ---
fig, ax = plt.subplots(figsize=(12, 10))

# Show the matrix with row 0 ('0-1') at the top, as is usual for confusion
# matrices; use origin='lower' instead to put it at the bottom.
ax.imshow(rgba_matrix, interpolation='nearest', origin='upper')

# Annotate every non-empty cell with its total count
for i in range(n_categories):
    for j in range(n_categories):
        if total_counts[i, j] <= 0:
            continue
        # Choose the text color from the color actually shown: with alpha < 1
        # the cell color is blended over the (default white) axes background.
        cell_rgb = rgba_matrix[i, j, :3]
        cell_alpha = rgba_matrix[i, j, 3]
        # Rec. 709 luminance weights, applied directly to the RGB values
        lum = 0.2126 * cell_rgb[0] + 0.7152 * cell_rgb[1] + 0.0722 * cell_rgb[2]
        shown_lum = cell_alpha * lum + (1 - cell_alpha) * 1.0
        text_col = "white" if shown_lum < 0.5 else "black"
        ax.text(j, i, int(total_counts[i, j]), ha="center", va="center",
                color=text_col, fontsize=10, fontweight='bold')

# --- 6. Styling ---
ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12, labelpad=10)
ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12, labelpad=10)
ax.set_title('Saturated Confusion Matrix\nColor = System Mixture | Opacity = Density', fontsize=14, pad=20)

ax.set_xticks(np.arange(n_categories))
ax.set_xticklabels(categories)
ax.set_yticks(np.arange(n_categories))
ax.set_yticklabels(categories)

# Remove the frame/spines for a cleaner look
for spine in ax.spines.values():
    spine.set_visible(False)

# Custom legend: one color swatch per functional system. Labels use title case
# (e.g. "Visual Optic Functions"), matching the functional-system scatter plot.
handles = [plt.Rectangle((0, 0), 1, 1, color=color_map[name]) for name in system_names]
labels = [" ".join(word.capitalize() for word in name.split('_')) for name in system_names]
ax.legend(handles, labels, title='Functional Systems', loc='upper left',
          bbox_to_anchor=(1.05, 1), frameon=False)

plt.tight_layout()
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()
|
||
|
||
#
|
||
|
||
|
||
|