diff --git a/.gitignore b/.gitignore index 072bc4e..c114207 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,10 @@ __pycache__/ *.pyc +======= +/reference/ +*.svg +>>>>>>> Stashed changes # 2. Ignore virtual environments COMPLETELY # This must come BEFORE the unignore rule env*/ diff --git a/Data/show_plots.py b/Data/show_plots.py index 786809e..c14e5a5 100644 --- a/Data/show_plots.py +++ b/Data/show_plots.py @@ -662,7 +662,7 @@ print("\nFirst few rows:") print(df.head()) # Hardcode specific patient names -patient_names = ['113c1470'] +patient_names = ['6b56865d'] # Define the functional systems (columns to plot) - adjust based on actual column names functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual'] @@ -1183,7 +1183,6 @@ import matplotlib.colors as mcolors # --- Configuration --- plt.rcParams['font.family'] = 'Arial' data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv' -color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json' figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg' # --- 1. Load the Dataset --- @@ -1212,19 +1211,19 @@ n_categories = len(categories) def categorize_edss(value): if pd.isna(value): return np.nan - idx = int(min(max(value, 0), 10) - 0.001) if value > 0 else 0 + # Ensure value is float to avoid TypeError + val = float(value) + idx = int(min(max(val, 0), 10) - 0.001) if val > 0 else 0 return categories[min(idx, len(categories)-1)] -# --- 4. Prepare Mixed Color Matrix --- +# --- 4. Prepare Mixed Color Matrix with Saturation --- cell_system_counts = np.zeros((n_categories, n_categories, len(system_names))) for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot): - # CRITICAL FIX: Convert to numeric and drop NaNs in one go - # 'coerce' turns non-numeric strings into NaN so they don't crash the script + # Fix: Ensure numeric conversion to avoid string comparison errors temp_df = df[[gt_col, res_col]].copy() temp_df[gt_col] = pd.to_numeric(temp_df[gt_col], errors='coerce') temp_df[res_col] = pd.to_numeric(temp_df[res_col], errors='coerce') - valid_df = temp_df.dropna() for _, row in valid_df.iterrows(): @@ -1233,10 +1232,9 @@ for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot): if gt_cat in category_to_index and res_cat in category_to_index: cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1 -# Create an RGB image matrix (initially white/empty) -rgb_matrix = np.ones((n_categories, n_categories, 3)) +# Create an RGBA image matrix (10x10x4) +rgba_matrix = np.zeros((n_categories, n_categories, 4)) -# Create an Alpha matrix for the "Total Count" intensity total_counts = np.sum(cell_system_counts, axis=2) max_count = np.max(total_counts) if np.max(total_counts) > 0 else 1 @@ -1244,52 +1242,69 @@ for i in range(n_categories): for j in range(n_categories): count_sum = total_counts[i, j] if count_sum > 0: - # Calculate weighted average color mixed_rgb = np.zeros(3) for s_idx, s_name in enumerate(system_names): weight = cell_system_counts[i, j, s_idx] / count_sum system_rgb = mcolors.to_rgb(color_map[s_name]) mixed_rgb += np.array(system_rgb) * weight - rgb_matrix[i, j] = mixed_rgb + + # Set RGB channels + rgba_matrix[i, j, :3] = mixed_rgb + + # Set Alpha channel (Saturation Effect) + # Using a square root scale to make lower counts more visible but still "lighter" + alpha = np.sqrt(count_sum / max_count) + # Ensure alpha is at least 0.1 so it's not invisible + rgba_matrix[i, j, 3] = max(0.1, alpha) + else: + # Empty cells are white + rgba_matrix[i, j] = [1, 1, 1, 0] # --- 5. Plotting --- fig, ax = plt.subplots(figsize=(12, 10)) -# Display the mixed color matrix -# We use alpha based on count to show density (optional, but recommended) -im = ax.imshow(rgb_matrix, interpolation='nearest', origin='upper') +# Show the matrix +# Note: we use origin='lower' if you want 0-1 at the bottom, +# but confusion matrices usually have 0-1 at the top (origin='upper') +im = ax.imshow(rgba_matrix, interpolation='nearest', origin='upper') -# Add text labels for total counts in each cell +# Add count labels for i in range(n_categories): for j in range(n_categories): if total_counts[i, j] > 0: - # Determine text color based on brightness of background - lum = 0.2126 * rgb_matrix[i,j,0] + 0.7152 * rgb_matrix[i,j,1] + 0.0722 * rgb_matrix[i,j,2] - text_col = "white" if lum < 0.5 else "black" - ax.text(j, i, int(total_counts[i, j]), ha="center", va="center", color=text_col, fontsize=9) + # Background brightness for text contrast + bg_color = rgba_matrix[i, j, :3] + lum = 0.2126 * bg_color[0] + 0.7152 * bg_color[1] + 0.0722 * bg_color[2] + # If alpha is low, background is effectively white, so use black text + text_col = "white" if (lum < 0.5 and rgba_matrix[i,j,3] > 0.5) else "black" + ax.text(j, i, int(total_counts[i, j]), ha="center", va="center", + color=text_col, fontsize=10, fontweight='bold') # --- 6. Styling --- -ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12) -ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12) -ax.set_title('Blended Confusion Matrix (Color = Weighted System Mixture)', fontsize=14, pad=20) +ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12, labelpad=10) +ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12, labelpad=10) +ax.set_title('Saturated Confusion Matrix\nColor = System Mixture | Opacity = Density', fontsize=14, pad=20) ax.set_xticks(np.arange(n_categories)) ax.set_xticklabels(categories) ax.set_yticks(np.arange(n_categories)) ax.set_yticklabels(categories) +# Remove the frame/spines for a cleaner look +for spine in ax.spines.values(): + spine.set_visible(False) + # Custom Legend handles = [plt.Rectangle((0,0),1,1, color=color_map[name]) for name in system_names] labels = [name.replace('_', ' ').capitalize() for name in system_names] -ax.legend(handles, labels, title='Functional Systems', loc='upper left', bbox_to_anchor=(1.05, 1), frameon=False) +ax.legend(handles, labels, title='Functional Systems', loc='upper left', + bbox_to_anchor=(1.05, 1), frameon=False) plt.tight_layout() os.makedirs(os.path.dirname(figure_save_path), exist_ok=True) plt.savefig(figure_save_path, format='svg', bbox_inches='tight') plt.show() - - #