update gitignore

This commit is contained in:
2026-02-04 15:29:56 +01:00
parent b2e9ccd2b6
commit c2ccb8cd11
2 changed files with 45 additions and 26 deletions

4
.gitignore vendored
View File

@@ -7,6 +7,10 @@
__pycache__/ __pycache__/
*.pyc *.pyc
=======
/reference/
*.svg
>>>>>>> Stashed changes
# 2. Ignore virtual environments COMPLETELY # 2. Ignore virtual environments COMPLETELY
# This must come BEFORE the unignore rule # This must come BEFORE the unignore rule
env*/ env*/

View File

@@ -662,7 +662,7 @@ print("\nFirst few rows:")
print(df.head()) print(df.head())
# Hardcode specific patient names # Hardcode specific patient names
patient_names = ['113c1470'] patient_names = ['6b56865d']
# Define the functional systems (columns to plot) - adjust based on actual column names # Define the functional systems (columns to plot) - adjust based on actual column names
functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual'] functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual']
@@ -1183,7 +1183,6 @@ import matplotlib.colors as mcolors
# --- Configuration --- # --- Configuration ---
plt.rcParams['font.family'] = 'Arial' plt.rcParams['font.family'] = 'Arial'
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv' data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg' figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg'
# --- 1. Load the Dataset --- # --- 1. Load the Dataset ---
@@ -1212,19 +1211,19 @@ n_categories = len(categories)
def categorize_edss(value): def categorize_edss(value):
if pd.isna(value): return np.nan if pd.isna(value): return np.nan
idx = int(min(max(value, 0), 10) - 0.001) if value > 0 else 0 # Ensure value is float to avoid TypeError
val = float(value)
idx = int(min(max(val, 0), 10) - 0.001) if val > 0 else 0
return categories[min(idx, len(categories)-1)] return categories[min(idx, len(categories)-1)]
# --- 4. Prepare Mixed Color Matrix --- # --- 4. Prepare Mixed Color Matrix with Saturation ---
cell_system_counts = np.zeros((n_categories, n_categories, len(system_names))) cell_system_counts = np.zeros((n_categories, n_categories, len(system_names)))
for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot): for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
# CRITICAL FIX: Convert to numeric and drop NaNs in one go # Fix: Ensure numeric conversion to avoid string comparison errors
# 'coerce' turns non-numeric strings into NaN so they don't crash the script
temp_df = df[[gt_col, res_col]].copy() temp_df = df[[gt_col, res_col]].copy()
temp_df[gt_col] = pd.to_numeric(temp_df[gt_col], errors='coerce') temp_df[gt_col] = pd.to_numeric(temp_df[gt_col], errors='coerce')
temp_df[res_col] = pd.to_numeric(temp_df[res_col], errors='coerce') temp_df[res_col] = pd.to_numeric(temp_df[res_col], errors='coerce')
valid_df = temp_df.dropna() valid_df = temp_df.dropna()
for _, row in valid_df.iterrows(): for _, row in valid_df.iterrows():
@@ -1233,10 +1232,9 @@ for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
if gt_cat in category_to_index and res_cat in category_to_index: if gt_cat in category_to_index and res_cat in category_to_index:
cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1 cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1
# Create an RGB image matrix (initially white/empty) # Create an RGBA image matrix (10x10x4)
rgb_matrix = np.ones((n_categories, n_categories, 3)) rgba_matrix = np.zeros((n_categories, n_categories, 4))
# Create an Alpha matrix for the "Total Count" intensity
total_counts = np.sum(cell_system_counts, axis=2) total_counts = np.sum(cell_system_counts, axis=2)
max_count = np.max(total_counts) if np.max(total_counts) > 0 else 1 max_count = np.max(total_counts) if np.max(total_counts) > 0 else 1
@@ -1244,52 +1242,69 @@ for i in range(n_categories):
for j in range(n_categories): for j in range(n_categories):
count_sum = total_counts[i, j] count_sum = total_counts[i, j]
if count_sum > 0: if count_sum > 0:
# Calculate weighted average color
mixed_rgb = np.zeros(3) mixed_rgb = np.zeros(3)
for s_idx, s_name in enumerate(system_names): for s_idx, s_name in enumerate(system_names):
weight = cell_system_counts[i, j, s_idx] / count_sum weight = cell_system_counts[i, j, s_idx] / count_sum
system_rgb = mcolors.to_rgb(color_map[s_name]) system_rgb = mcolors.to_rgb(color_map[s_name])
mixed_rgb += np.array(system_rgb) * weight mixed_rgb += np.array(system_rgb) * weight
rgb_matrix[i, j] = mixed_rgb
# Set RGB channels
rgba_matrix[i, j, :3] = mixed_rgb
# Set Alpha channel (Saturation Effect)
# Using a square root scale to make lower counts more visible but still "lighter"
alpha = np.sqrt(count_sum / max_count)
# Ensure alpha is at least 0.1 so it's not invisible
rgba_matrix[i, j, 3] = max(0.1, alpha)
else:
# Empty cells are white
rgba_matrix[i, j] = [1, 1, 1, 0]
# --- 5. Plotting --- # --- 5. Plotting ---
fig, ax = plt.subplots(figsize=(12, 10)) fig, ax = plt.subplots(figsize=(12, 10))
# Display the mixed color matrix # Show the matrix
# We use alpha based on count to show density (optional, but recommended) # Note: we use origin='lower' if you want 0-1 at the bottom,
im = ax.imshow(rgb_matrix, interpolation='nearest', origin='upper') # but confusion matrices usually have 0-1 at the top (origin='upper')
im = ax.imshow(rgba_matrix, interpolation='nearest', origin='upper')
# Add text labels for total counts in each cell # Add count labels
for i in range(n_categories): for i in range(n_categories):
for j in range(n_categories): for j in range(n_categories):
if total_counts[i, j] > 0: if total_counts[i, j] > 0:
# Determine text color based on brightness of background # Background brightness for text contrast
lum = 0.2126 * rgb_matrix[i,j,0] + 0.7152 * rgb_matrix[i,j,1] + 0.0722 * rgb_matrix[i,j,2] bg_color = rgba_matrix[i, j, :3]
text_col = "white" if lum < 0.5 else "black" lum = 0.2126 * bg_color[0] + 0.7152 * bg_color[1] + 0.0722 * bg_color[2]
ax.text(j, i, int(total_counts[i, j]), ha="center", va="center", color=text_col, fontsize=9) # If alpha is low, background is effectively white, so use black text
text_col = "white" if (lum < 0.5 and rgba_matrix[i,j,3] > 0.5) else "black"
ax.text(j, i, int(total_counts[i, j]), ha="center", va="center",
color=text_col, fontsize=10, fontweight='bold')
# --- 6. Styling --- # --- 6. Styling ---
ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12) ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12, labelpad=10)
ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12) ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12, labelpad=10)
ax.set_title('Blended Confusion Matrix (Color = Weighted System Mixture)', fontsize=14, pad=20) ax.set_title('Saturated Confusion Matrix\nColor = System Mixture | Opacity = Density', fontsize=14, pad=20)
ax.set_xticks(np.arange(n_categories)) ax.set_xticks(np.arange(n_categories))
ax.set_xticklabels(categories) ax.set_xticklabels(categories)
ax.set_yticks(np.arange(n_categories)) ax.set_yticks(np.arange(n_categories))
ax.set_yticklabels(categories) ax.set_yticklabels(categories)
# Remove the frame/spines for a cleaner look
for spine in ax.spines.values():
spine.set_visible(False)
# Custom Legend # Custom Legend
handles = [plt.Rectangle((0,0),1,1, color=color_map[name]) for name in system_names] handles = [plt.Rectangle((0,0),1,1, color=color_map[name]) for name in system_names]
labels = [name.replace('_', ' ').capitalize() for name in system_names] labels = [name.replace('_', ' ').capitalize() for name in system_names]
ax.legend(handles, labels, title='Functional Systems', loc='upper left', bbox_to_anchor=(1.05, 1), frameon=False) ax.legend(handles, labels, title='Functional Systems', loc='upper left',
bbox_to_anchor=(1.05, 1), frameon=False)
plt.tight_layout() plt.tight_layout()
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True) os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight') plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show() plt.show()
# #