From b2e9ccd2b6091aab055a527542ab1417ebc502d9 Mon Sep 17 00:00:00 2001 From: Shahin Ramezanzadeh Date: Mon, 26 Jan 2026 02:02:19 +0100 Subject: [PATCH] adding some visualizations --- Data/show_plots.py | 435 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 434 insertions(+), 1 deletion(-) diff --git a/Data/show_plots.py b/Data/show_plots.py index 8de8133..786809e 100644 --- a/Data/show_plots.py +++ b/Data/show_plots.py @@ -748,7 +748,7 @@ plt.show() ## -# %% name +# %% Table import pandas as pd import matplotlib.pyplot as plt import seaborn as sns @@ -860,4 +860,437 @@ for ax in axes: plt.tight_layout() plt.show() + + + ## + + + + +# %% Histogram Fig1 +import pandas as pd +import matplotlib.pyplot as plt +import matplotlib.font_manager as fm +import json +import os + +def create_visit_frequency_plot( + file_path, + output_dir='/home/shahin/Lab/Doktorarbeit/Barcelona/Data', + output_filename='visit_frequency_distribution.svg', + fontsize=10, + color_scheme_path='colors.json' +): + """ + Creates a publication-ready bar chart of patient visit frequency. + + Args: + file_path (str): Path to the input TSV file. + output_dir (str): Directory to save the output SVG file. + output_filename (str): Name of the output SVG file. + fontsize (int): Font size for all text elements (labels, title). + color_scheme_path (str): Path to the JSON file containing the color palette. + """ + # --- 1. Load Data and Color Scheme --- + try: + df = pd.read_csv(file_path, sep='\t') + print("Data loaded successfully.") + # Sort data for easier visual comparison + df = df.sort_values(by='Visits Count') + except FileNotFoundError: + print(f"Error: The file was not found at {file_path}") + return + + try: + with open(color_scheme_path, 'r') as f: + colors = json.load(f) + # Select a blue from the sequential palette for the bars + bar_color = colors['sequential']['blues'][-2] # A saturated blue + except FileNotFoundError: + print(f"Warning: Color scheme file not found at {color_scheme_path}. Using default blue.") + bar_color = '#2171b5' # A common matplotlib blue + + # --- 2. Set up the Plot with Scientific Style --- + plt.figure(figsize=(7.94, 6)) # Single-column width (7.94 cm) with appropriate height + + # Set the font to Arial + arial_font = fm.FontProperties(family='Arial', size=fontsize) + plt.rcParams['font.family'] = 'Arial' + plt.rcParams['font.size'] = fontsize + + # --- 3. Create the Bar Chart --- + ax = plt.gca() + bars = plt.bar( + x=df['Visits Count'], + height=df['Unique Patients'], + color=bar_color, + edgecolor='black', + linewidth=0.5, # Minimum line thickness + width=0.7 + ) + + # --- NEW: Explicitly set x-ticks and labels to ensure all are shown --- + # Get the unique visit counts to use as tick labels + visit_counts = df['Visits Count'].unique() + # Set the x-ticks to be at the center of each bar + ax.set_xticks(visit_counts) + # Set the x-tick labels to be the visit counts, using the specified font + ax.set_xticklabels(visit_counts, fontproperties=arial_font) + # --- END OF NEW SECTION --- + + # --- 4. Customize Axes and Layout (Nature style) --- + # Display only left and bottom axes + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + + # Turn off axis ticks (the marks, not the labels) + plt.tick_params(axis='both', which='both', length=0) + + # Remove grid lines + plt.grid(False) + + # Set background to white (no shading) + ax.set_facecolor('white') + plt.gcf().set_facecolor('white') + + # --- 5. Add Labels and Title --- + plt.xlabel('Number of Visits', fontproperties=arial_font, labelpad=10) + plt.ylabel('Number of Unique Patients', fontproperties=arial_font, labelpad=10) + plt.title('Distribution of Patient Visit Frequency', fontproperties=arial_font, pad=20) + + # --- 6. Add y-axis values on top of each bar --- + # This adds the count of unique patients directly above each bar. + ax.bar_label(bars, fmt='%d', padding=3) + + # --- 7. Export the Figure --- + # Ensure the output directory exists + os.makedirs(output_dir, exist_ok=True) + + full_output_path = os.path.join(output_dir, output_filename) + plt.savefig(full_output_path, format='svg', dpi=300, bbox_inches='tight') + print(f"\nFigure saved as '{full_output_path}'") + + # --- 8. (Optional) Display the Plot --- + # plt.show() + +# --- Main execution --- +if __name__ == '__main__': + # Define the file path + input_file = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/visit_freuency.tsv' + + # Call the function to create and save the plot + create_visit_frequency_plot( + file_path=input_file, + fontsize=10 # Using a 10 pt font size as per guidelines + ) + +## + + + +# %% Scatter Plot functional system + +import pandas as pd +import matplotlib.pyplot as plt +import json +import os + +# --- Configuration --- +# Set the font to Arial for all text in the plot, as per the guidelines +plt.rcParams['font.family'] = 'Arial' + +# Define the path to your data file +data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv' + +# Define the path to save the color mapping JSON file +color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json' + +# Define the path to save the final figure +figure_save_path = 'project/visuals/edss_functional_systems_comparison.svg' + +# --- 1. Load the Dataset --- +try: + # Load the TSV file + df = pd.read_csv(data_path, sep='\t') + print(f"Successfully loaded data from {data_path}") + print(f"Data shape: {df.shape}") +except FileNotFoundError: + print(f"Error: The file at {data_path} was not found.") + # Exit or handle the error appropriately + raise + +# --- 2. Define Functional Systems and Create Color Mapping --- +# List of tuples containing (ground_truth_column, result_column) +functional_systems_to_plot = [ + ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'), + ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'), + ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'), + ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'), + ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'), + ('GT.AMBULATION', 'result.AMBULATION'), + ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'), + ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS') +] + +# Extract system names for color mapping and legend +system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot] + +# Define a professional color palette (dark blue theme) +# This is a qualitative palette with distinct, accessible colors +colors = [ + '#003366', # Dark Blue + '#336699', # Medium Blue + '#6699CC', # Light Blue + '#99CCFF', # Very Light Blue + '#FF9966', # Coral + '#FF6666', # Light Red + '#CC6699', # Magenta + '#9966CC' # Purple +] + +# Create a dictionary mapping system names to colors +color_map = dict(zip(system_names, colors)) + +# Ensure the directory for the JSON file exists +os.makedirs(os.path.dirname(color_json_path), exist_ok=True) + +# Save the color map to a JSON file +with open(color_json_path, 'w') as f: + json.dump(color_map, f, indent=4) + +print(f"Color mapping saved to {color_json_path}") + +# --- 3. Calculate Agreement Percentages and Format Legend Labels --- +agreement_percentages = {} +legend_labels = {} + +for gt_col, res_col in functional_systems_to_plot: + system_name = gt_col.split('.')[1] + + # Convert columns to numeric, setting errors to NaN + gt_numeric = pd.to_numeric(df[gt_col], errors='coerce') + res_numeric = pd.to_numeric(df[res_col], errors='coerce') + + # Ensure we are comparing the same rows + common_index = gt_numeric.dropna().index.intersection(res_numeric.dropna().index) + gt_data = gt_numeric.loc[common_index] + res_data = res_numeric.loc[common_index] + + # Calculate agreement percentage + if len(gt_data) > 0: + agreement = (gt_data == res_data).mean() * 100 + else: + agreement = 0 # Handle case with no valid data + + agreement_percentages[system_name] = agreement + + # Format the system name for the legend (e.g., "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions") + formatted_name = " ".join(word.capitalize() for word in system_name.split('_')) + legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)" + +# --- 4. Reshape Data for Plotting --- +plot_data = [] +for gt_col, res_col in functional_systems_to_plot: + system_name = gt_col.split('.')[1] + + # Convert columns to numeric, setting errors to NaN + gt_numeric = pd.to_numeric(df[gt_col], errors='coerce') + res_numeric = pd.to_numeric(df[res_col], errors='coerce') + + # Create a temporary DataFrame with the numeric data + temp_df = pd.DataFrame({ + 'system': system_name, + 'ground_truth': gt_numeric, + 'inference': res_numeric + }) + + # Drop rows where either value is NaN, as they cannot be plotted + temp_df = temp_df.dropna() + + plot_data.append(temp_df) + +# Concatenate all the temporary DataFrames into one +plot_df = pd.concat(plot_data, ignore_index=True) + +if plot_df.empty: + print("Warning: No valid numeric data to plot after conversion. The plot will be blank.") +else: + print(f"Prepared plot data with {len(plot_df)} data points.") + +# --- 5. Create the Scatter Plot --- +plt.figure(figsize=(10, 8)) + +# Plot each functional system with its assigned color and formatted legend label +for system, group in plot_df.groupby('system'): + plt.scatter( + group['ground_truth'], + group['inference'], + label=legend_labels[system], + color=color_map[system], + alpha=0.7, + s=30 + ) + +# Add a diagonal line representing perfect agreement (y = x) +# This line helps visualize how close the predictions are to the ground truth +if not plot_df.empty: + plt.plot( + [plot_df['ground_truth'].min(), plot_df['ground_truth'].max()], + [plot_df['ground_truth'].min(), plot_df['ground_truth'].max()], + color='black', + linestyle='--', + linewidth=0.8, + alpha=0.7 + ) + +# --- 6. Apply Styling and Labels --- +plt.xlabel('Ground Truth', fontsize=12) +plt.ylabel('LLM Inference', fontsize=12) +plt.title('Comparison of EDSS Functional Systems: Ground Truth vs. LLM Inference', fontsize=14) + +# Apply scientific visualization styling rules +ax = plt.gca() +ax.spines['top'].set_visible(False) +ax.spines['right'].set_visible(False) +ax.tick_params(axis='both', which='both', length=0) # Remove ticks +ax.grid(False) # Remove grid lines +plt.legend(title='Functional System', frameon=False, fontsize=10) + +# --- 7. Save and Display the Figure --- +# Ensure the directory for the figure exists +os.makedirs(os.path.dirname(figure_save_path), exist_ok=True) + +plt.savefig(figure_save_path, format='svg', bbox_inches='tight') +print(f"Figure successfully saved to {figure_save_path}") + +# Display the plot +plt.show() +## + + + + +# %% Confusion Matrix functional systems + +import pandas as pd +import matplotlib.pyplot as plt +import json +import os +import numpy as np +import matplotlib.colors as mcolors + +# --- Configuration --- +plt.rcParams['font.family'] = 'Arial' +data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv' +color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json' +figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg' + +# --- 1. Load the Dataset --- +df = pd.read_csv(data_path, sep='\t') + +# --- 2. Define Functional Systems and Colors --- +functional_systems_to_plot = [ + ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'), + ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'), + ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'), + ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'), + ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'), + ('GT.AMBULATION', 'result.AMBULATION'), + ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'), + ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS') +] + +system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot] +colors = ['#003366', '#336699', '#6699CC', '#99CCFF', '#FF9966', '#FF6666', '#CC6699', '#9966CC'] +color_map = dict(zip(system_names, colors)) + +# --- 3. Categorization Function --- +categories = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'] +category_to_index = {cat: i for i, cat in enumerate(categories)} +n_categories = len(categories) + +def categorize_edss(value): + if pd.isna(value): return np.nan + idx = int(min(max(value, 0), 10) - 0.001) if value > 0 else 0 + return categories[min(idx, len(categories)-1)] + +# --- 4. Prepare Mixed Color Matrix --- +cell_system_counts = np.zeros((n_categories, n_categories, len(system_names))) + +for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot): + # CRITICAL FIX: Convert to numeric and drop NaNs in one go + # 'coerce' turns non-numeric strings into NaN so they don't crash the script + temp_df = df[[gt_col, res_col]].copy() + temp_df[gt_col] = pd.to_numeric(temp_df[gt_col], errors='coerce') + temp_df[res_col] = pd.to_numeric(temp_df[res_col], errors='coerce') + + valid_df = temp_df.dropna() + + for _, row in valid_df.iterrows(): + gt_cat = categorize_edss(row[gt_col]) + res_cat = categorize_edss(row[res_col]) + if gt_cat in category_to_index and res_cat in category_to_index: + cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1 + +# Create an RGB image matrix (initially white/empty) +rgb_matrix = np.ones((n_categories, n_categories, 3)) + +# Create an Alpha matrix for the "Total Count" intensity +total_counts = np.sum(cell_system_counts, axis=2) +max_count = np.max(total_counts) if np.max(total_counts) > 0 else 1 + +for i in range(n_categories): + for j in range(n_categories): + count_sum = total_counts[i, j] + if count_sum > 0: + # Calculate weighted average color + mixed_rgb = np.zeros(3) + for s_idx, s_name in enumerate(system_names): + weight = cell_system_counts[i, j, s_idx] / count_sum + system_rgb = mcolors.to_rgb(color_map[s_name]) + mixed_rgb += np.array(system_rgb) * weight + rgb_matrix[i, j] = mixed_rgb + +# --- 5. Plotting --- +fig, ax = plt.subplots(figsize=(12, 10)) + +# Display the mixed color matrix +# We use alpha based on count to show density (optional, but recommended) +im = ax.imshow(rgb_matrix, interpolation='nearest', origin='upper') + +# Add text labels for total counts in each cell +for i in range(n_categories): + for j in range(n_categories): + if total_counts[i, j] > 0: + # Determine text color based on brightness of background + lum = 0.2126 * rgb_matrix[i,j,0] + 0.7152 * rgb_matrix[i,j,1] + 0.0722 * rgb_matrix[i,j,2] + text_col = "white" if lum < 0.5 else "black" + ax.text(j, i, int(total_counts[i, j]), ha="center", va="center", color=text_col, fontsize=9) + +# --- 6. Styling --- +ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12) +ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12) +ax.set_title('Blended Confusion Matrix (Color = Weighted System Mixture)', fontsize=14, pad=20) + +ax.set_xticks(np.arange(n_categories)) +ax.set_xticklabels(categories) +ax.set_yticks(np.arange(n_categories)) +ax.set_yticklabels(categories) + +# Custom Legend +handles = [plt.Rectangle((0,0),1,1, color=color_map[name]) for name in system_names] +labels = [name.replace('_', ' ').capitalize() for name in system_names] +ax.legend(handles, labels, title='Functional Systems', loc='upper left', bbox_to_anchor=(1.05, 1), frameon=False) + +plt.tight_layout() +os.makedirs(os.path.dirname(figure_save_path), exist_ok=True) +plt.savefig(figure_save_path, format='svg', bbox_inches='tight') +plt.show() + + + +# + + +