Compare commits
4 Commits
clean
...
Experiment
| Author | SHA1 | Date | |
|---|---|---|---|
| 2f507bcf20 | |||
| f4bf37f71c | |||
| bc63d1ee72 | |||
| c2ccb8cd11 |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -7,6 +7,10 @@
|
|||||||
__pycache__/
|
__pycache__/
|
||||||
*.pyc
|
*.pyc
|
||||||
|
|
||||||
|
=======
|
||||||
|
/reference/
|
||||||
|
*.svg
|
||||||
|
>>>>>>> Stashed changes
|
||||||
# 2. Ignore virtual environments COMPLETELY
|
# 2. Ignore virtual environments COMPLETELY
|
||||||
# This must come BEFORE the unignore rule
|
# This must come BEFORE the unignore rule
|
||||||
env*/
|
env*/
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ plt.figure(figsize=(10, 8))
|
|||||||
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
|
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
|
||||||
xticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'],
|
xticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'],
|
||||||
yticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'])
|
yticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'])
|
||||||
plt.title('Confusion Matrix: Ground truth EDSS vs interferred EDSS (Categorized 0-10)')
|
#plt.title('Confusion Matrix: Ground truth EDSS vs interferred EDSS (Categorized 0-10)')
|
||||||
plt.xlabel('LLM Generated EDSS')
|
plt.xlabel('LLM Generated EDSS')
|
||||||
plt.ylabel('Ground Truth EDSS')
|
plt.ylabel('Ground Truth EDSS')
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
@@ -168,6 +168,98 @@ print(cm)
|
|||||||
##
|
##
|
||||||
|
|
||||||
|
|
||||||
|
# %% Confusion matrix
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.metrics import confusion_matrix, classification_report
|
||||||
|
import seaborn as sns
|
||||||
|
|
||||||
|
# Load your data from TSV file
|
||||||
|
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
|
||||||
|
df = pd.read_csv(file_path, sep='\t')
|
||||||
|
|
||||||
|
# Replace comma with dot for numeric conversion in GT.EDSS and result.EDSS
|
||||||
|
df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.')
|
||||||
|
df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.')
|
||||||
|
|
||||||
|
# Convert to float (handle invalid entries gracefully)
|
||||||
|
df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce')
|
||||||
|
df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce')
|
||||||
|
|
||||||
|
# Drop rows where either column is NaN
|
||||||
|
df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS'])
|
||||||
|
|
||||||
|
# For confusion matrix, we need to categorize the values
|
||||||
|
# Let's create categories up to 10 (0-1, 1-2, 2-3, ..., 9-10)
|
||||||
|
def categorize_edss(value):
|
||||||
|
if pd.isna(value):
|
||||||
|
return np.nan
|
||||||
|
elif value <= 1.0:
|
||||||
|
return '0-1'
|
||||||
|
elif value <= 2.0:
|
||||||
|
return '1-2'
|
||||||
|
elif value <= 3.0:
|
||||||
|
return '2-3'
|
||||||
|
elif value <= 4.0:
|
||||||
|
return '3-4'
|
||||||
|
elif value <= 5.0:
|
||||||
|
return '4-5'
|
||||||
|
elif value <= 6.0:
|
||||||
|
return '5-6'
|
||||||
|
elif value <= 7.0:
|
||||||
|
return '6-7'
|
||||||
|
elif value <= 8.0:
|
||||||
|
return '7-8'
|
||||||
|
elif value <= 9.0:
|
||||||
|
return '8-9'
|
||||||
|
elif value <= 10.0:
|
||||||
|
return '9-10'
|
||||||
|
else:
|
||||||
|
return '10+'
|
||||||
|
|
||||||
|
# Create categorical versions
|
||||||
|
df_clean['GT.EDSS_cat'] = df_clean['GT.EDSS'].apply(categorize_edss)
|
||||||
|
df_clean['result.EDSS_cat'] = df_clean['result.EDSS'].apply(categorize_edss)
|
||||||
|
|
||||||
|
# Remove any NaN categories
|
||||||
|
df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat'])
|
||||||
|
|
||||||
|
# Create confusion matrix
|
||||||
|
cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'],
|
||||||
|
labels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'])
|
||||||
|
|
||||||
|
# Plot confusion matrix
|
||||||
|
plt.figure(figsize=(10, 8))
|
||||||
|
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
|
||||||
|
xticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'],
|
||||||
|
yticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'])
|
||||||
|
|
||||||
|
# Add legend text above the color bar
|
||||||
|
# Get the colorbar object
|
||||||
|
cbar = ax.collections[0].colorbar
|
||||||
|
# Add text above the colorbar
|
||||||
|
cbar.set_label('Number of Cases', rotation=270, labelpad=20)
|
||||||
|
|
||||||
|
plt.xlabel('LLM Generated EDSS')
|
||||||
|
plt.ylabel('Ground Truth EDSS')
|
||||||
|
#plt.title('Confusion Matrix: Ground truth EDSS vs inferred EDSS (Categorized 0-10)')
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# Print classification report
|
||||||
|
print("Classification Report:")
|
||||||
|
print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat']))
|
||||||
|
|
||||||
|
# Print raw counts
|
||||||
|
print("\nConfusion Matrix (Raw Counts):")
|
||||||
|
print(cm)
|
||||||
|
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# %% Classification
|
# %% Classification
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -662,7 +754,7 @@ print("\nFirst few rows:")
|
|||||||
print(df.head())
|
print(df.head())
|
||||||
|
|
||||||
# Hardcode specific patient names
|
# Hardcode specific patient names
|
||||||
patient_names = ['113c1470']
|
patient_names = ['6b56865d']
|
||||||
|
|
||||||
# Define the functional systems (columns to plot) - adjust based on actual column names
|
# Define the functional systems (columns to plot) - adjust based on actual column names
|
||||||
functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual']
|
functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual']
|
||||||
@@ -1183,7 +1275,6 @@ import matplotlib.colors as mcolors
|
|||||||
# --- Configuration ---
|
# --- Configuration ---
|
||||||
plt.rcParams['font.family'] = 'Arial'
|
plt.rcParams['font.family'] = 'Arial'
|
||||||
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
|
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
|
||||||
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
|
|
||||||
figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg'
|
figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg'
|
||||||
|
|
||||||
# --- 1. Load the Dataset ---
|
# --- 1. Load the Dataset ---
|
||||||
@@ -1212,19 +1303,19 @@ n_categories = len(categories)
|
|||||||
|
|
||||||
def categorize_edss(value):
|
def categorize_edss(value):
|
||||||
if pd.isna(value): return np.nan
|
if pd.isna(value): return np.nan
|
||||||
idx = int(min(max(value, 0), 10) - 0.001) if value > 0 else 0
|
# Ensure value is float to avoid TypeError
|
||||||
|
val = float(value)
|
||||||
|
idx = int(min(max(val, 0), 10) - 0.001) if val > 0 else 0
|
||||||
return categories[min(idx, len(categories)-1)]
|
return categories[min(idx, len(categories)-1)]
|
||||||
|
|
||||||
# --- 4. Prepare Mixed Color Matrix ---
|
# --- 4. Prepare Mixed Color Matrix with Saturation ---
|
||||||
cell_system_counts = np.zeros((n_categories, n_categories, len(system_names)))
|
cell_system_counts = np.zeros((n_categories, n_categories, len(system_names)))
|
||||||
|
|
||||||
for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
|
for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
|
||||||
# CRITICAL FIX: Convert to numeric and drop NaNs in one go
|
# Fix: Ensure numeric conversion to avoid string comparison errors
|
||||||
# 'coerce' turns non-numeric strings into NaN so they don't crash the script
|
|
||||||
temp_df = df[[gt_col, res_col]].copy()
|
temp_df = df[[gt_col, res_col]].copy()
|
||||||
temp_df[gt_col] = pd.to_numeric(temp_df[gt_col], errors='coerce')
|
temp_df[gt_col] = pd.to_numeric(temp_df[gt_col], errors='coerce')
|
||||||
temp_df[res_col] = pd.to_numeric(temp_df[res_col], errors='coerce')
|
temp_df[res_col] = pd.to_numeric(temp_df[res_col], errors='coerce')
|
||||||
|
|
||||||
valid_df = temp_df.dropna()
|
valid_df = temp_df.dropna()
|
||||||
|
|
||||||
for _, row in valid_df.iterrows():
|
for _, row in valid_df.iterrows():
|
||||||
@@ -1233,10 +1324,9 @@ for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
|
|||||||
if gt_cat in category_to_index and res_cat in category_to_index:
|
if gt_cat in category_to_index and res_cat in category_to_index:
|
||||||
cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1
|
cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1
|
||||||
|
|
||||||
# Create an RGB image matrix (initially white/empty)
|
# Create an RGBA image matrix (10x10x4)
|
||||||
rgb_matrix = np.ones((n_categories, n_categories, 3))
|
rgba_matrix = np.zeros((n_categories, n_categories, 4))
|
||||||
|
|
||||||
# Create an Alpha matrix for the "Total Count" intensity
|
|
||||||
total_counts = np.sum(cell_system_counts, axis=2)
|
total_counts = np.sum(cell_system_counts, axis=2)
|
||||||
max_count = np.max(total_counts) if np.max(total_counts) > 0 else 1
|
max_count = np.max(total_counts) if np.max(total_counts) > 0 else 1
|
||||||
|
|
||||||
@@ -1244,53 +1334,602 @@ for i in range(n_categories):
|
|||||||
for j in range(n_categories):
|
for j in range(n_categories):
|
||||||
count_sum = total_counts[i, j]
|
count_sum = total_counts[i, j]
|
||||||
if count_sum > 0:
|
if count_sum > 0:
|
||||||
# Calculate weighted average color
|
|
||||||
mixed_rgb = np.zeros(3)
|
mixed_rgb = np.zeros(3)
|
||||||
for s_idx, s_name in enumerate(system_names):
|
for s_idx, s_name in enumerate(system_names):
|
||||||
weight = cell_system_counts[i, j, s_idx] / count_sum
|
weight = cell_system_counts[i, j, s_idx] / count_sum
|
||||||
system_rgb = mcolors.to_rgb(color_map[s_name])
|
system_rgb = mcolors.to_rgb(color_map[s_name])
|
||||||
mixed_rgb += np.array(system_rgb) * weight
|
mixed_rgb += np.array(system_rgb) * weight
|
||||||
rgb_matrix[i, j] = mixed_rgb
|
|
||||||
|
# Set RGB channels
|
||||||
|
rgba_matrix[i, j, :3] = mixed_rgb
|
||||||
|
|
||||||
|
# Set Alpha channel (Saturation Effect)
|
||||||
|
# Using a square root scale to make lower counts more visible but still "lighter"
|
||||||
|
alpha = np.sqrt(count_sum / max_count)
|
||||||
|
# Ensure alpha is at least 0.1 so it's not invisible
|
||||||
|
rgba_matrix[i, j, 3] = max(0.1, alpha)
|
||||||
|
else:
|
||||||
|
# Empty cells are white
|
||||||
|
rgba_matrix[i, j] = [1, 1, 1, 0]
|
||||||
|
|
||||||
# --- 5. Plotting ---
|
# --- 5. Plotting ---
|
||||||
fig, ax = plt.subplots(figsize=(12, 10))
|
fig, ax = plt.subplots(figsize=(12, 10))
|
||||||
|
|
||||||
# Display the mixed color matrix
|
# Show the matrix
|
||||||
# We use alpha based on count to show density (optional, but recommended)
|
# Note: we use origin='lower' if you want 0-1 at the bottom,
|
||||||
im = ax.imshow(rgb_matrix, interpolation='nearest', origin='upper')
|
# but confusion matrices usually have 0-1 at the top (origin='upper')
|
||||||
|
im = ax.imshow(rgba_matrix, interpolation='nearest', origin='upper')
|
||||||
|
|
||||||
# Add text labels for total counts in each cell
|
# Add count labels
|
||||||
for i in range(n_categories):
|
for i in range(n_categories):
|
||||||
for j in range(n_categories):
|
for j in range(n_categories):
|
||||||
if total_counts[i, j] > 0:
|
if total_counts[i, j] > 0:
|
||||||
# Determine text color based on brightness of background
|
# Background brightness for text contrast
|
||||||
lum = 0.2126 * rgb_matrix[i,j,0] + 0.7152 * rgb_matrix[i,j,1] + 0.0722 * rgb_matrix[i,j,2]
|
bg_color = rgba_matrix[i, j, :3]
|
||||||
text_col = "white" if lum < 0.5 else "black"
|
lum = 0.2126 * bg_color[0] + 0.7152 * bg_color[1] + 0.0722 * bg_color[2]
|
||||||
ax.text(j, i, int(total_counts[i, j]), ha="center", va="center", color=text_col, fontsize=9)
|
# If alpha is low, background is effectively white, so use black text
|
||||||
|
text_col = "white" if (lum < 0.5 and rgba_matrix[i,j,3] > 0.5) else "black"
|
||||||
|
ax.text(j, i, int(total_counts[i, j]), ha="center", va="center",
|
||||||
|
color=text_col, fontsize=10, fontweight='bold')
|
||||||
|
|
||||||
# --- 6. Styling ---
|
# --- 6. Styling ---
|
||||||
ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12)
|
ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12, labelpad=10)
|
||||||
ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12)
|
ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12, labelpad=10)
|
||||||
ax.set_title('Blended Confusion Matrix (Color = Weighted System Mixture)', fontsize=14, pad=20)
|
ax.set_title('Saturated Confusion Matrix\nColor = System Mixture | Opacity = Density', fontsize=14, pad=20)
|
||||||
|
|
||||||
ax.set_xticks(np.arange(n_categories))
|
ax.set_xticks(np.arange(n_categories))
|
||||||
ax.set_xticklabels(categories)
|
ax.set_xticklabels(categories)
|
||||||
ax.set_yticks(np.arange(n_categories))
|
ax.set_yticks(np.arange(n_categories))
|
||||||
ax.set_yticklabels(categories)
|
ax.set_yticklabels(categories)
|
||||||
|
|
||||||
|
# Remove the frame/spines for a cleaner look
|
||||||
|
for spine in ax.spines.values():
|
||||||
|
spine.set_visible(False)
|
||||||
|
|
||||||
# Custom Legend
|
# Custom Legend
|
||||||
handles = [plt.Rectangle((0,0),1,1, color=color_map[name]) for name in system_names]
|
handles = [plt.Rectangle((0,0),1,1, color=color_map[name]) for name in system_names]
|
||||||
labels = [name.replace('_', ' ').capitalize() for name in system_names]
|
labels = [name.replace('_', ' ').capitalize() for name in system_names]
|
||||||
ax.legend(handles, labels, title='Functional Systems', loc='upper left', bbox_to_anchor=(1.05, 1), frameon=False)
|
ax.legend(handles, labels, title='Functional Systems', loc='upper left',
|
||||||
|
bbox_to_anchor=(1.05, 1), frameon=False)
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
|
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
|
||||||
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
|
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
#
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %% Difference Plot Functional system
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# --- Configuration ---
|
||||||
|
# Set the font to Arial for all text in the plot, as per the guidelines
|
||||||
|
plt.rcParams['font.family'] = 'Arial'
|
||||||
|
|
||||||
|
# Define the path to your data file
|
||||||
|
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
|
||||||
|
|
||||||
|
# Define the path to save the color mapping JSON file
|
||||||
|
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
|
||||||
|
|
||||||
|
# Define the path to save the final figure
|
||||||
|
figure_save_path = 'project/visuals/edss_functional_systems_comparison.svg'
|
||||||
|
|
||||||
|
# --- 1. Load the Dataset ---
|
||||||
|
try:
|
||||||
|
# Load the TSV file
|
||||||
|
df = pd.read_csv(data_path, sep='\t')
|
||||||
|
print(f"Successfully loaded data from {data_path}")
|
||||||
|
print(f"Data shape: {df.shape}")
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Error: The file at {data_path} was not found.")
|
||||||
|
# Exit or handle the error appropriately
|
||||||
|
raise
|
||||||
|
|
||||||
|
# --- 2. Define Functional Systems and Create Color Mapping ---
|
||||||
|
# List of tuples containing (ground_truth_column, result_column)
|
||||||
|
functional_systems_to_plot = [
|
||||||
|
('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
|
||||||
|
('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
|
||||||
|
('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
|
||||||
|
('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
|
||||||
|
('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
|
||||||
|
('GT.AMBULATION', 'result.AMBULATION'),
|
||||||
|
('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
|
||||||
|
('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS')
|
||||||
|
]
|
||||||
|
|
||||||
|
# Extract system names for color mapping and legend
|
||||||
|
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]
|
||||||
|
|
||||||
|
# Define a professional color palette (dark blue theme)
|
||||||
|
# This is a qualitative palette with distinct, accessible colors
|
||||||
|
colors = [
|
||||||
|
'#003366', # Dark Blue
|
||||||
|
'#336699', # Medium Blue
|
||||||
|
'#6699CC', # Light Blue
|
||||||
|
'#99CCFF', # Very Light Blue
|
||||||
|
'#FF9966', # Coral
|
||||||
|
'#FF6666', # Light Red
|
||||||
|
'#CC6699', # Magenta
|
||||||
|
'#9966CC' # Purple
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create a dictionary mapping system names to colors
|
||||||
|
color_map = dict(zip(system_names, colors))
|
||||||
|
|
||||||
|
# Ensure the directory for the JSON file exists
|
||||||
|
os.makedirs(os.path.dirname(color_json_path), exist_ok=True)
|
||||||
|
|
||||||
|
# Save the color map to a JSON file
|
||||||
|
with open(color_json_path, 'w') as f:
|
||||||
|
json.dump(color_map, f, indent=4)
|
||||||
|
|
||||||
|
print(f"Color mapping saved to {color_json_path}")
|
||||||
|
|
||||||
|
# --- 3. Calculate Agreement Percentages and Format Legend Labels ---
|
||||||
|
agreement_percentages = {}
|
||||||
|
legend_labels = {}
|
||||||
|
|
||||||
|
for gt_col, res_col in functional_systems_to_plot:
|
||||||
|
system_name = gt_col.split('.')[1]
|
||||||
|
|
||||||
|
# Convert columns to numeric, setting errors to NaN
|
||||||
|
gt_numeric = pd.to_numeric(df[gt_col], errors='coerce')
|
||||||
|
res_numeric = pd.to_numeric(df[res_col], errors='coerce')
|
||||||
|
|
||||||
|
# Ensure we are comparing the same rows
|
||||||
|
common_index = gt_numeric.dropna().index.intersection(res_numeric.dropna().index)
|
||||||
|
gt_data = gt_numeric.loc[common_index]
|
||||||
|
res_data = res_numeric.loc[common_index]
|
||||||
|
|
||||||
|
# Calculate agreement percentage
|
||||||
|
if len(gt_data) > 0:
|
||||||
|
agreement = (gt_data == res_data).mean() * 100
|
||||||
|
else:
|
||||||
|
agreement = 0 # Handle case with no valid data
|
||||||
|
|
||||||
|
agreement_percentages[system_name] = agreement
|
||||||
|
|
||||||
|
# Format the system name for the legend (e.g., "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions")
|
||||||
|
formatted_name = " ".join(word.capitalize() for word in system_name.split('_'))
|
||||||
|
legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)"
|
||||||
|
|
||||||
|
|
||||||
|
# --- 4. Robustly Prepare Error Data for Boxplot ---
|
||||||
|
|
||||||
|
def safe_parse(s):
|
||||||
|
'''Convert to float, handling comma decimals (e.g., '3,5' → 3.5)'''
|
||||||
|
if pd.isna(s):
|
||||||
|
return np.nan
|
||||||
|
if isinstance(s, (int, float)):
|
||||||
|
return float(s)
|
||||||
|
# Replace comma with dot, then strip whitespace
|
||||||
|
s_clean = str(s).replace(',', '.').strip()
|
||||||
|
try:
|
||||||
|
return float(s_clean)
|
||||||
|
except ValueError:
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
plot_data = []
|
||||||
|
for gt_col, res_col in functional_systems_to_plot:
|
||||||
|
system_name = gt_col.split('.')[1]
|
||||||
|
|
||||||
|
# Parse both columns with robust comma handling
|
||||||
|
gt_numeric = df[gt_col].apply(safe_parse)
|
||||||
|
res_numeric = df[res_col].apply(safe_parse)
|
||||||
|
|
||||||
|
# Compute error (only where both are finite)
|
||||||
|
error = res_numeric - gt_numeric
|
||||||
|
|
||||||
|
# Create temp DataFrame
|
||||||
|
temp_df = pd.DataFrame({
|
||||||
|
'system': system_name,
|
||||||
|
'error': error
|
||||||
|
}).dropna() # drop rows where either was unparseable
|
||||||
|
|
||||||
|
plot_data.append(temp_df)
|
||||||
|
|
||||||
|
plot_df = pd.concat(plot_data, ignore_index=True)
|
||||||
|
|
||||||
|
if plot_df.empty:
|
||||||
|
print("⚠️ Warning: No valid numeric error data to plot after robust parsing.")
|
||||||
|
else:
|
||||||
|
print(f"✅ Prepared error data with {len(plot_df)} data points.")
|
||||||
|
# Diagnostic: show a few samples
|
||||||
|
print("\n📌 Sample errors by system:")
|
||||||
|
for sys, grp in plot_df.groupby('system'):
|
||||||
|
print(f" {sys:25s}: n={len(grp)}, mean err = {grp['error'].mean():+.2f}, min = {grp['error'].min():+.2f}, max = {grp['error'].max():+.2f}")
|
||||||
|
|
||||||
|
# Ensure categorical ordering
|
||||||
|
plot_df['system'] = pd.Categorical(
|
||||||
|
plot_df['system'],
|
||||||
|
categories=[name.split('.')[1] for name, _ in functional_systems_to_plot],
|
||||||
|
ordered=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- 5. Prepare Data for Diverging Stacked Bar Plot ---
|
||||||
|
print("\n📊 Preparing diverging stacked bar plot data...")
|
||||||
|
|
||||||
|
# Define bins for error direction
|
||||||
|
def categorize_error(err):
|
||||||
|
if pd.isna(err):
|
||||||
|
return 'missing'
|
||||||
|
elif err < 0:
|
||||||
|
return 'underestimate'
|
||||||
|
elif err > 0:
|
||||||
|
return 'overestimate'
|
||||||
|
else:
|
||||||
|
return 'match'
|
||||||
|
|
||||||
|
# Add category column (only on finite errors)
|
||||||
|
plot_df_clean = plot_df[plot_df['error'].notna()].copy()
|
||||||
|
plot_df_clean['category'] = plot_df_clean['error'].apply(categorize_error)
|
||||||
|
|
||||||
|
# Count by system + category
|
||||||
|
category_counts = (
|
||||||
|
plot_df_clean
|
||||||
|
.groupby(['system', 'category'])
|
||||||
|
.size()
|
||||||
|
.unstack(fill_value=0)
|
||||||
|
.reindex(columns=['underestimate', 'match', 'overestimate'], fill_value=0)
|
||||||
|
)
|
||||||
|
# Reorder systems
|
||||||
|
category_counts = category_counts.reindex(system_names)
|
||||||
|
|
||||||
|
# Prepare for diverging plot:
|
||||||
|
# - Underestimates: plotted to the *left* (negative x)
|
||||||
|
# - Overestimates: plotted to the *right* (positive x)
|
||||||
|
# - Matches: centered (no width needed, or as a bar of width 0.2)
|
||||||
|
underestimate_counts = category_counts['underestimate']
|
||||||
|
match_counts = category_counts['match']
|
||||||
|
overestimate_counts = category_counts['overestimate']
|
||||||
|
|
||||||
|
# For diverging: left = -underestimate, right = overestimate
|
||||||
|
left_counts = underestimate_counts
|
||||||
|
right_counts = overestimate_counts
|
||||||
|
|
||||||
|
# Compute max absolute bar height (for symmetric x-axis)
|
||||||
|
max_bar = max(left_counts.max(), right_counts.max(), 1)
|
||||||
|
plot_range = (-max_bar, max_bar)
|
||||||
|
|
||||||
|
# X-axis positions: 0 = center, left systems to -1, -2, ..., right systems to +1, +2, ...
|
||||||
|
n_systems = len(system_names)
|
||||||
|
positions = np.arange(n_systems)
|
||||||
|
left_positions = -positions - 0.5 # left-aligned underestimates
|
||||||
|
right_positions = positions + 0.5 # right-aligned overestimates
|
||||||
|
|
||||||
|
# --- 6. Create Diverging Stacked Bar Plot ---
|
||||||
|
plt.figure(figsize=(12, 7))
|
||||||
|
|
||||||
|
# Colors: diverging palette
|
||||||
|
colors = {
|
||||||
|
'underestimate': '#E74C3C', # Red (left)
|
||||||
|
'match': '#2ECC71', # Green (center)
|
||||||
|
'overestimate': '#F39C12' # Orange (right)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Plot underestimates (left side)
|
||||||
|
bars_left = plt.barh(
|
||||||
|
left_positions,
|
||||||
|
left_counts.values,
|
||||||
|
height=0.8,
|
||||||
|
left=0, # starts at 0, extends left (since bars are negative width would be wrong; instead use negative values)
|
||||||
|
color=colors['underestimate'],
|
||||||
|
edgecolor='black',
|
||||||
|
linewidth=0.5,
|
||||||
|
alpha=0.9,
|
||||||
|
label='Underestimate'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Plot overestimates (right side)
|
||||||
|
bars_right = plt.barh(
|
||||||
|
right_positions,
|
||||||
|
right_counts.values,
|
||||||
|
height=0.8,
|
||||||
|
left=0,
|
||||||
|
color=colors['overestimate'],
|
||||||
|
edgecolor='black',
|
||||||
|
linewidth=0.5,
|
||||||
|
alpha=0.9,
|
||||||
|
label='Overestimate'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Plot matches (center — narrow bar)
|
||||||
|
# Use a very narrow width (0.2) centered at 0
|
||||||
|
plt.barh(
|
||||||
|
positions,
|
||||||
|
match_counts.values,
|
||||||
|
height=0.2,
|
||||||
|
left=0, # starts at 0, extends right
|
||||||
|
color=colors['match'],
|
||||||
|
edgecolor='black',
|
||||||
|
linewidth=0.5,
|
||||||
|
alpha=0.9,
|
||||||
|
label='Exact Match'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ✨ Better: flip match to be centered symmetrically (left=-match/2, width=match)
|
||||||
|
# For perfect symmetry:
|
||||||
|
for i, count in enumerate(match_counts.values):
|
||||||
|
if count > 0:
|
||||||
|
plt.barh(
|
||||||
|
positions[i],
|
||||||
|
width=count,
|
||||||
|
left=-count/2,
|
||||||
|
height=0.25,
|
||||||
|
color=colors['match'],
|
||||||
|
edgecolor='black',
|
||||||
|
linewidth=0.5,
|
||||||
|
alpha=0.95
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- 7. Styling & Labels ---
|
||||||
|
# Zero reference line
|
||||||
|
plt.axvline(x=0, color='black', linestyle='-', linewidth=1.2, alpha=0.8)
|
||||||
|
|
||||||
|
# X-axis: symmetric around 0
|
||||||
|
plt.xlim(plot_range[0] - max_bar*0.1, plot_range[1] + max_bar*0.1)
|
||||||
|
plt.xticks(rotation=0, fontsize=10)
|
||||||
|
plt.xlabel('Count', fontsize=12)
|
||||||
|
|
||||||
|
# Y-axis: system names at original positions (centered)
|
||||||
|
plt.yticks(positions, [name.replace('_', '\n').replace('and', '&') for name in system_names], fontsize=10)
|
||||||
|
plt.ylabel('Functional System', fontsize=12)
|
||||||
|
|
||||||
|
# Title & layout
|
||||||
|
plt.title('Diverging Error Direction by Functional System\n(Red: Underestimation | Green: Exact | Orange: Overestimation)', fontsize=13, pad=15)
|
||||||
|
|
||||||
|
# Clean axes
|
||||||
|
ax = plt.gca()
|
||||||
|
ax.spines['top'].set_visible(False)
|
||||||
|
ax.spines['right'].set_visible(False)
|
||||||
|
ax.spines['left'].set_visible(False) # We only need bottom axis
|
||||||
|
ax.xaxis.set_ticks_position('bottom')
|
||||||
|
ax.yaxis.set_ticks_position('none')
|
||||||
|
|
||||||
|
# Grid only along x
|
||||||
|
ax.xaxis.grid(True, linestyle=':', alpha=0.5)
|
||||||
|
|
||||||
|
# Legend
|
||||||
|
from matplotlib.patches import Patch
|
||||||
|
legend_elements = [
|
||||||
|
Patch(facecolor=colors['underestimate'], edgecolor='black', label='Underestimate'),
|
||||||
|
Patch(facecolor=colors['match'], edgecolor='black', label='Exact Match'),
|
||||||
|
Patch(facecolor=colors['overestimate'], edgecolor='black', label='Overestimate')
|
||||||
|
]
|
||||||
|
plt.legend(handles=legend_elements, loc='upper right', frameon=False, fontsize=10)
|
||||||
|
|
||||||
|
# Optional: Add counts on bars
|
||||||
|
for i, (left, right, match) in enumerate(zip(left_counts, right_counts, match_counts)):
|
||||||
|
if left > 0:
|
||||||
|
plt.text(-left - max_bar*0.05, left_positions[i], str(left), va='center', ha='right', fontsize=9, color='white', fontweight='bold')
|
||||||
|
if right > 0:
|
||||||
|
plt.text(right + max_bar*0.05, right_positions[i], str(right), va='center', ha='left', fontsize=9, color='white', fontweight='bold')
|
||||||
|
if match > 0:
|
||||||
|
plt.text(match_counts[i]/2, positions[i], str(match), va='center', ha='center', fontsize=8, color='black')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# --- 8. Save & Show ---
|
||||||
|
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
|
||||||
|
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
|
||||||
|
print(f"✅ Diverging bar plot saved to {figure_save_path}")
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %% Difference Gemini easy
|
||||||
|
|
||||||
|
|
||||||
|
# --- 1. Process Error Data ---
|
||||||
|
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]
|
||||||
|
plot_list = []
|
||||||
|
|
||||||
|
for gt_col, res_col in functional_systems_to_plot:
|
||||||
|
sys_name = gt_col.split('.')[1]
|
||||||
|
|
||||||
|
# Robust parsing
|
||||||
|
gt = df[gt_col].apply(safe_parse)
|
||||||
|
res = df[res_col].apply(safe_parse)
|
||||||
|
error = res - gt
|
||||||
|
|
||||||
|
# Calculate counts
|
||||||
|
matches = (error == 0).sum()
|
||||||
|
under = (error < 0).sum()
|
||||||
|
over = (error > 0).sum()
|
||||||
|
total = error.dropna().count()
|
||||||
|
|
||||||
|
# Calculate Percentages
|
||||||
|
# Using max(total, 1) to avoid division by zero
|
||||||
|
divisor = max(total, 1)
|
||||||
|
match_pct = (matches / divisor) * 100
|
||||||
|
under_pct = (under / divisor) * 100
|
||||||
|
over_pct = (over / divisor) * 100
|
||||||
|
|
||||||
|
plot_list.append({
|
||||||
|
'System': sys_name.replace('_', ' ').title(),
|
||||||
|
'Matches': matches,
|
||||||
|
'MatchPct': match_pct,
|
||||||
|
'Under': under,
|
||||||
|
'UnderPct': under_pct,
|
||||||
|
'Over': over,
|
||||||
|
'OverPct': over_pct
|
||||||
|
})
|
||||||
|
|
||||||
|
stats_df = pd.DataFrame(plot_list)
|
||||||
|
|
||||||
|
# --- 2. Plotting ---
|
||||||
|
fig, ax = plt.subplots(figsize=(12, 8)) # Slightly taller for multi-line labels
|
||||||
|
|
||||||
|
color_under = '#E74C3C'
|
||||||
|
color_over = '#3498DB'
|
||||||
|
bar_height = 0.6
|
||||||
|
|
||||||
|
y_pos = np.arange(len(stats_df))
|
||||||
|
|
||||||
|
ax.barh(y_pos, -stats_df['Under'], bar_height, label='Under-scored', color=color_under, edgecolor='white', alpha=0.8)
|
||||||
|
ax.barh(y_pos, stats_df['Over'], bar_height, label='Over-scored', color=color_over, edgecolor='white', alpha=0.8)
|
||||||
|
|
||||||
|
# --- 3. Aesthetics & Labels ---
|
||||||
|
|
||||||
|
for i, row in stats_df.iterrows():
|
||||||
|
# Constructing a detailed label for the left side
|
||||||
|
# Matches (Bold) | Under % | Over %
|
||||||
|
label_text = (
|
||||||
|
f"$\mathbf{{{row['System']}}}$\n"
|
||||||
|
f"Matches: {int(row['Matches'])} ({row['MatchPct']:.1f}%)\n"
|
||||||
|
f"Under: {int(row['Under'])} ({row['UnderPct']:.1f}%) | Over: {int(row['Over'])} ({row['OverPct']:.1f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Position text to the left of the x=0 line
|
||||||
|
ax.text(ax.get_xlim()[0] - 0.5, i, label_text, va='center', ha='right', fontsize=9, color='#333333', linespacing=1.3)
|
||||||
|
|
||||||
|
# Zero line
|
||||||
|
ax.axvline(0, color='black', linewidth=1.2, alpha=0.7)
|
||||||
|
|
||||||
|
# Clean up axes
|
||||||
|
ax.set_yticks([])
|
||||||
|
ax.set_xlabel('Number of Patients with Error', fontsize=11, fontweight='bold', labelpad=10)
|
||||||
|
#ax.set_title('Directional Error Analysis by Functional System', fontsize=14, pad=30)
|
||||||
|
|
||||||
|
# Make X-axis labels absolute
|
||||||
|
ax.set_xticklabels([int(abs(tick)) for tick in ax.get_xticks()])
|
||||||
|
|
||||||
|
# Remove spines
|
||||||
|
for spine in ['top', 'right', 'left']:
|
||||||
|
ax.spines[spine].set_visible(False)
|
||||||
|
|
||||||
|
# Legend
|
||||||
|
ax.legend(loc='upper right', frameon=False, bbox_to_anchor=(1, 1.1))
|
||||||
|
|
||||||
|
# Grid
|
||||||
|
ax.xaxis.grid(True, linestyle='--', alpha=0.3)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
##
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %% name
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# --- Configuration & Theme ---
|
||||||
|
plt.rcParams['font.family'] = 'Arial'
|
||||||
|
figure_save_path = 'project/visuals/functional_systems_magnitude_focus.svg'
|
||||||
|
|
||||||
|
# --- 1. Process Error Data with Magnitude Breakdown ---
|
||||||
|
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]
|
||||||
|
plot_list = []
|
||||||
|
|
||||||
|
for gt_col, res_col in functional_systems_to_plot:
|
||||||
|
sys_name = gt_col.split('.')[1]
|
||||||
|
|
||||||
|
# Robust parsing
|
||||||
|
gt = df[gt_col].apply(safe_parse)
|
||||||
|
res = df[res_col].apply(safe_parse)
|
||||||
|
error = res - gt
|
||||||
|
|
||||||
|
# Granular Counts
|
||||||
|
matches = (error == 0).sum()
|
||||||
|
u_1 = (error == -1).sum()
|
||||||
|
u_2plus = (error <= -2).sum()
|
||||||
|
o_1 = (error == 1).sum()
|
||||||
|
o_2plus = (error >= 2).sum()
|
||||||
|
|
||||||
|
total = error.dropna().count()
|
||||||
|
divisor = max(total, 1)
|
||||||
|
|
||||||
|
plot_list.append({
|
||||||
|
'System': sys_name.replace('_', ' ').title(),
|
||||||
|
'Matches': matches, 'MatchPct': (matches / divisor) * 100,
|
||||||
|
'U1': u_1, 'U2': u_2plus, 'UnderTotal': u_1 + u_2plus,
|
||||||
|
'UnderPct': ((u_1 + u_2plus) / divisor) * 100,
|
||||||
|
'O1': o_1, 'O2': o_2plus, 'OverTotal': o_1 + o_2plus,
|
||||||
|
'OverPct': ((o_1 + o_2plus) / divisor) * 100
|
||||||
|
})
|
||||||
|
|
||||||
|
stats_df = pd.DataFrame(plot_list)
|
||||||
|
|
||||||
|
# --- 2. Plotting ---
|
||||||
|
fig, ax = plt.subplots(figsize=(13, 8))
|
||||||
|
|
||||||
|
# Define Magnitude Colors
|
||||||
|
c_under_dark, c_under_light = '#C0392B', '#E74C3C' # Dark Red (-2+), Soft Red (-1)
|
||||||
|
c_over_dark, c_over_light = '#2980B9', '#3498DB' # Dark Blue (+2+), Soft Blue (+1)
|
||||||
|
bar_height = 0.6
|
||||||
|
y_pos = np.arange(len(stats_df))
|
||||||
|
|
||||||
|
# Plot Under-scored (Stacked: -2+ then -1)
|
||||||
|
ax.barh(y_pos, -stats_df['U2'], bar_height, color=c_under_dark, label='Under -2+', edgecolor='white')
|
||||||
|
ax.barh(y_pos, -stats_df['U1'], bar_height, left=-stats_df['U2'], color=c_under_light, label='Under -1', edgecolor='white')
|
||||||
|
|
||||||
|
# Plot Over-scored (Stacked: +1 then +2+)
|
||||||
|
ax.barh(y_pos, stats_df['O1'], bar_height, color=c_over_light, label='Over +1', edgecolor='white')
|
||||||
|
ax.barh(y_pos, stats_df['O2'], bar_height, left=stats_df['O1'], color=c_over_dark, label='Over +2+', edgecolor='white')
|
||||||
|
|
||||||
|
# --- 3. Aesthetics & Table Labels ---
|
||||||
|
for i, row in stats_df.iterrows():
|
||||||
|
label_text = (
|
||||||
|
f"$\\mathbf{{{row['System']}}}$\n"
|
||||||
|
f"Match: {int(row['Matches'])} ({row['MatchPct']:.1f}%)\n"
|
||||||
|
f"Under: {int(row['UnderTotal'])} ({row['UnderPct']:.1f}%) | Over: {int(row['OverTotal'])} ({row['OverPct']:.1f}%)"
|
||||||
|
)
|
||||||
|
# Position table text to the left
|
||||||
|
ax.text(ax.get_xlim()[0] - 0.5, i, label_text, va='center', ha='right', fontsize=9, color='#333333', linespacing=1.4)
|
||||||
|
|
||||||
|
# Formatting
|
||||||
|
ax.axvline(0, color='black', linewidth=1.2)
|
||||||
|
ax.set_yticks([])
|
||||||
|
ax.set_xlabel('Number of Patients with Error', fontsize=11, fontweight='bold')
|
||||||
|
#ax.set_title('Directional Error Magnitude (Under vs. Over Scoring)', fontsize=14, pad=35)
|
||||||
|
|
||||||
|
# Absolute X-axis labels
|
||||||
|
ax.set_xticklabels([int(abs(tick)) for tick in ax.get_xticks()])
|
||||||
|
|
||||||
|
# Remove spines and add grid
|
||||||
|
for spine in ['top', 'right', 'left']: ax.spines[spine].set_visible(False)
|
||||||
|
ax.xaxis.grid(True, linestyle='--', alpha=0.3)
|
||||||
|
|
||||||
|
# Legend with magnitude info
|
||||||
|
ax.legend(loc='upper right', frameon=False, bbox_to_anchor=(1, 1.1), ncol=2)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
##
|
||||||
|
# %% test
|
||||||
|
# Diagnose: what are the actual differences?
|
||||||
|
print("\n🔍 Raw differences (first 5 rows per system):")
|
||||||
|
for gt_col, res_col in functional_systems_to_plot:
|
||||||
|
gt = df[gt_col].apply(safe_parse)
|
||||||
|
res = df[res_col].apply(safe_parse)
|
||||||
|
diff = res - gt
|
||||||
|
non_zero = (diff != 0).sum()
|
||||||
|
# Check if it's due to floating point noise
|
||||||
|
abs_diff = diff.abs()
|
||||||
|
tiny = (abs_diff > 0) & (abs_diff < 1e-10)
|
||||||
|
print(f"{gt_col.split('.')[1]:25s}: non-zero = {non_zero:3d}, tiny = {tiny.sum():3d}, max abs diff = {abs_diff.max():.12f}")
|
||||||
|
|
||||||
|
##
|
||||||
|
|||||||
135
Data/style2.py
135
Data/style2.py
@@ -1,135 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import seaborn as sns
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import dataframe_image as dfi
|
|
||||||
# Load data
|
|
||||||
df = pd.read_csv("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv", sep='\t')
|
|
||||||
|
|
||||||
# 1. Identify all GT and result columns
|
|
||||||
gt_columns = [col for col in df.columns if col.startswith('GT.')]
|
|
||||||
result_columns = [col for col in df.columns if col.startswith('result.')]
|
|
||||||
|
|
||||||
print("GT Columns found:", gt_columns)
|
|
||||||
print("Result Columns found:", result_columns)
|
|
||||||
|
|
||||||
# 2. Create proper mapping between GT and result columns
|
|
||||||
# Handle various naming conventions (spaces, underscores, etc.)
|
|
||||||
column_mapping = {}
|
|
||||||
|
|
||||||
for gt_col in gt_columns:
|
|
||||||
base_name = gt_col.replace('GT.', '')
|
|
||||||
|
|
||||||
# Clean the base name for matching - remove spaces, underscores, etc.
|
|
||||||
# Try different matching approaches
|
|
||||||
candidates = [
|
|
||||||
f'result.{base_name}', # Exact match
|
|
||||||
f'result.{base_name.replace(" ", "_")}', # With underscores
|
|
||||||
f'result.{base_name.replace("_", " ")}', # With spaces
|
|
||||||
f'result.{base_name.replace(" ", "")}', # No spaces
|
|
||||||
f'result.{base_name.replace("_", "")}' # No underscores
|
|
||||||
]
|
|
||||||
|
|
||||||
# Also try case-insensitive matching
|
|
||||||
candidates.append(f'result.{base_name.lower()}')
|
|
||||||
candidates.append(f'result.{base_name.upper()}')
|
|
||||||
|
|
||||||
# Try to find matching result column
|
|
||||||
matched = False
|
|
||||||
for candidate in candidates:
|
|
||||||
if candidate in result_columns:
|
|
||||||
column_mapping[gt_col] = candidate
|
|
||||||
matched = True
|
|
||||||
break
|
|
||||||
|
|
||||||
# If no exact match found, try partial matching
|
|
||||||
if not matched:
|
|
||||||
# Try to match by removing special characters and comparing
|
|
||||||
base_clean = ''.join(e for e in base_name if e.isalnum() or e in ['_', ' '])
|
|
||||||
for result_col in result_columns:
|
|
||||||
result_base = result_col.replace('result.', '')
|
|
||||||
result_clean = ''.join(e for e in result_base if e.isalnum() or e in ['_', ' '])
|
|
||||||
if base_clean.lower() == result_clean.lower():
|
|
||||||
column_mapping[gt_col] = result_col
|
|
||||||
matched = True
|
|
||||||
break
|
|
||||||
|
|
||||||
print("Column mapping:", column_mapping)
|
|
||||||
|
|
||||||
# 3. Faster, vectorized computation using the corrected mapping
|
|
||||||
data_list = []
|
|
||||||
|
|
||||||
for gt_col, result_col in column_mapping.items():
|
|
||||||
print(f"Processing {gt_col} vs {result_col}")
|
|
||||||
|
|
||||||
# Convert to numeric, forcing errors to NaN
|
|
||||||
s1 = pd.to_numeric(df[gt_col], errors='coerce').astype(float)
|
|
||||||
s2 = pd.to_numeric(df[result_col], errors='coerce').astype(float)
|
|
||||||
|
|
||||||
# Calculate matches (abs difference <= 0.5)
|
|
||||||
diff = np.abs(s1 - s2)
|
|
||||||
matches = (diff <= 0.5).sum()
|
|
||||||
|
|
||||||
# Determine the denominator (total valid comparisons)
|
|
||||||
valid_count = diff.notna().sum()
|
|
||||||
|
|
||||||
if valid_count > 0:
|
|
||||||
percentage = (matches / valid_count) * 100
|
|
||||||
else:
|
|
||||||
percentage = 0
|
|
||||||
|
|
||||||
# Extract clean base name for display
|
|
||||||
base_name = gt_col.replace('GT.', '')
|
|
||||||
|
|
||||||
data_list.append({
|
|
||||||
'GT': base_name,
|
|
||||||
'Match %': round(percentage, 1)
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 4. Prepare Data for Plotting
|
|
||||||
match_df = pd.DataFrame(data_list)
|
|
||||||
match_df = match_df.sort_values('Match %', ascending=False) # Sort for better visual flow
|
|
||||||
|
|
||||||
# 5. Create the Styled Gradient Table
|
|
||||||
def style_agreement_table(df):
|
|
||||||
return (df.style
|
|
||||||
.format({'Match %': '{:.1f}%'}) # Add % sign
|
|
||||||
.background_gradient(cmap='RdYlGn', subset=['Match %'], vmin=50, vmax=100) # Red to Green gradient
|
|
||||||
.set_properties(**{
|
|
||||||
'text-align': 'center',
|
|
||||||
'font-size': '12pt',
|
|
||||||
'border-collapse': 'collapse',
|
|
||||||
'border': '1px solid #D3D3D3'
|
|
||||||
})
|
|
||||||
.set_table_styles([
|
|
||||||
# Style the header
|
|
||||||
{'selector': 'th', 'props': [
|
|
||||||
('background-color', '#404040'),
|
|
||||||
('color', 'white'),
|
|
||||||
('font-weight', 'bold'),
|
|
||||||
('text-transform', 'uppercase'),
|
|
||||||
('padding', '10px')
|
|
||||||
]},
|
|
||||||
# Add hover effect
|
|
||||||
{'selector': 'tr:hover', 'props': [('background-color', '#f5f5f5')]}
|
|
||||||
])
|
|
||||||
.set_caption("EDSS Agreement Analysis: Ground Truth vs. Results (Tolerance ±0.5)")
|
|
||||||
)
|
|
||||||
|
|
||||||
# To display in a Jupyter Notebook:
|
|
||||||
styled_table = style_agreement_table(match_df)
|
|
||||||
styled_table
|
|
||||||
|
|
||||||
dfi.export(styled_table, "styled_table.png")
|
|
||||||
#styled_table.to_html("agreement_report.html")
|
|
||||||
# 6. Save as SVG
|
|
||||||
|
|
||||||
#plt.savefig("agreement_table.svg", format='svg', dpi=300, bbox_inches='tight')
|
|
||||||
#print("Successfully saved agreement_table.svg")
|
|
||||||
|
|
||||||
# Show plot if running in a GUI environment
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import seaborn as sns
|
|
||||||
|
|
||||||
# Sample data (replace with your actual df)
|
|
||||||
df = pd.read_csv("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv", sep='\t')
|
|
||||||
|
|
||||||
# Identify GT and Result columns
|
|
||||||
gt_columns = [col for col in df.columns if col.startswith('GT.')]
|
|
||||||
result_columns = [col for col in df.columns if col.startswith('result.')]
|
|
||||||
|
|
||||||
# Create mapping
|
|
||||||
column_mapping = {}
|
|
||||||
for gt_col in gt_columns:
|
|
||||||
base_name = gt_col.replace('GT.', '')
|
|
||||||
result_col = f'result.{base_name}'
|
|
||||||
if result_col in result_columns:
|
|
||||||
column_mapping[gt_col] = result_col
|
|
||||||
|
|
||||||
# Function to compute match percentage for each GT-Result pair
|
|
||||||
def compute_match_percentages(df, column_mapping):
|
|
||||||
percentages = []
|
|
||||||
for gt_col, result_col in column_mapping.items():
|
|
||||||
count = 0
|
|
||||||
total = len(df)
|
|
||||||
|
|
||||||
for _, row in df.iterrows():
|
|
||||||
gt_val = row[gt_col]
|
|
||||||
result_val = row[result_col]
|
|
||||||
|
|
||||||
# Handle NaN values
|
|
||||||
if pd.isna(gt_val) or pd.isna(result_val):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Handle non-numeric values
|
|
||||||
try:
|
|
||||||
gt_float = float(gt_val)
|
|
||||||
result_float = float(result_val)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
# Skip rows with non-numeric values
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if values are within 0.5 tolerance
|
|
||||||
if abs(gt_float - result_float) <= 0.5:
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
percentage = (count / total) * 100
|
|
||||||
percentages.append({
|
|
||||||
'GT_Column': gt_col,
|
|
||||||
'Result_Column': result_col,
|
|
||||||
'Match_Percentage': round(percentage, 1)
|
|
||||||
})
|
|
||||||
|
|
||||||
return pd.DataFrame(percentages)
|
|
||||||
|
|
||||||
# Compute match percentages
|
|
||||||
match_df = compute_match_percentages(df, column_mapping)
|
|
||||||
|
|
||||||
# Create a pivot table for gradient display (optional but helpful)
|
|
||||||
pivot_table = match_df.set_index(['GT_Column', 'Result_Column'])['Match_Percentage'].unstack(fill_value=0)
|
|
||||||
|
|
||||||
# Apply gradient background
|
|
||||||
cm = sns.light_palette("green", as_cmap=True)
|
|
||||||
styled_table = pivot_table.style.background_gradient(cmap=cm, axis=None)
|
|
||||||
|
|
||||||
# Display result
|
|
||||||
print("Agreement Percentage Table (with gradient):")
|
|
||||||
styled_table
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Save the styled table to a file
|
|
||||||
styled_table.to_html("agreement_report.html")
|
|
||||||
print("Report saved to agreement_report.html")
|
|
||||||
57
figure1.py
57
figure1.py
@@ -263,3 +263,60 @@ plt.legend(frameon=False, loc='upper center', bbox_to_anchor=(0.5, -0.05))
|
|||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.show()
|
plt.show()
|
||||||
##
|
##
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %% name
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
# Data
|
||||||
|
data = {
|
||||||
|
'Visit': [9, 8, 7, 6, 5, 4, 3, 2, 1],
|
||||||
|
'patient_count': [2, 3, 3, 6, 13, 17, 28, 24, 32]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create figure and axis
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 6))
|
||||||
|
|
||||||
|
# Plot the bar chart
|
||||||
|
bars = ax.bar(data['Visit'], data['patient_count'], color='darkblue', label='Patients by Visit Count')
|
||||||
|
|
||||||
|
# Add labels and title
|
||||||
|
ax.set_xlabel('Visit Number (from last to first)', fontsize=12)
|
||||||
|
ax.set_ylabel('Number of Patients', fontsize=12)
|
||||||
|
ax.set_title('Patient Visits by Visit Number', fontsize=14)
|
||||||
|
|
||||||
|
# Invert x-axis to show Visit 9 on the left (descending order) if desired, but keep natural order (1–9 left to right)
|
||||||
|
# For descending order (9→1 from left to right), we'd need to reverse:
|
||||||
|
# Visit = data['Visit'][::-1], patient_count = data['patient_count'][::-1]
|
||||||
|
# But standard practice is ascending (1 to 9), so we'll sort accordingly:
|
||||||
|
# Let's sort by Visit to ensure left-to-right: 1,2,...,9
|
||||||
|
|
||||||
|
# Actually, your current Visit list is [9,8,...,1], which is descending.
|
||||||
|
# Let's sort by Visit for intuitive left-to-right increasing order:
|
||||||
|
sorted_indices = sorted(range(len(data['Visit'])), key=lambda i: data['Visit'][i])
|
||||||
|
visit_sorted = [data['Visit'][i] for i in sorted_indices]
|
||||||
|
count_sorted = [data['patient_count'][i] for i in sorted_indices]
|
||||||
|
|
||||||
|
# Re-plot with sorted x-axis:
|
||||||
|
ax.clear()
|
||||||
|
bars = ax.bar(visit_sorted, count_sorted, color='darkblue', label='Patients by Visit Count')
|
||||||
|
|
||||||
|
# Re-apply labels, etc.
|
||||||
|
ax.set_xlabel('Number of Visits', fontsize=12)
|
||||||
|
ax.set_ylabel('Number of Unique Patients', fontsize=12)
|
||||||
|
#ax.set_title('Number of Patients by Visit Number', fontsize=14)
|
||||||
|
|
||||||
|
# Add legend
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
# Improve layout and grid
|
||||||
|
ax.grid(axis='y', linestyle='--', alpha=0.7)
|
||||||
|
plt.xticks(visit_sorted) # Ensure all integer visit numbers are shown
|
||||||
|
|
||||||
|
# Show the plot
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
##
|
||||||
|
|||||||
Reference in New Issue
Block a user