"""
Compute the tumor/stroma mask by binning the image into grid cells
and clustering those bins with a self-organizing map (SOM).
"""
import logging
import cv2
import os
import time
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
from minisom import MiniSom
from itertools import product
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from typing import List, Optional, Tuple, Dict, Any, Union
# from mpl_toolkits.axes_grid1 import make_axes_locatable
from gridgene.logger import get_logger
# TODO make bins overlapping?
class GetBins:
    """
    Bin spatial transcriptomics data into grid cells and create AnnData objects.

    Attributes
    ----------
    bin_size : int
        Edge length of a square bin, in the same units as the ``X``/``Y`` columns.
    unique_targets : List[str]
        Target names used (in this order) as the variable axis of every binned AnnData.
    adata : Optional[ad.AnnData]
        Most recently built or preprocessed AnnData; ``None`` until a ``get_bin_*``
        method has run.
    eval_som_statistical_df : Optional[pd.DataFrame]
        Never populated by this class; kept so existing code reading it still works.
    logger : logging.Logger
        Logger used for progress messages.
    """

    def __init__(self, bin_size: int, unique_targets: List[str], logger: Optional[logging.Logger] = None):
        """
        Initialize GetBins.

        Parameters
        ----------
        bin_size : int
            Size of bins in pixels.
        unique_targets : List[str]
            List of target genes.
        logger : Optional[logging.Logger], optional
            Logger instance; when None a logger named after this class is created
            via ``get_logger``. By default None.
        """
        self.bin_size = bin_size
        self.unique_targets = unique_targets
        self.adata = None
        self.eval_som_statistical_df = None
        # Default logger is named after this class (module.GetBins).
        self.logger = logger if logger is not None else get_logger(f'{__name__}.GetBins')
        self.logger.info("Initialized GetBins")

    def get_bin_df(self, df: pd.DataFrame, df_name: str) -> ad.AnnData:
        """
        Convert a DataFrame of spatial points with target labels into a binned AnnData object.

        Side effect: ``df`` gains (or has overwritten) the integer-division grid columns
        ``x_grid`` and ``y_grid``.

        Parameters
        ----------
        df : pd.DataFrame
            DataFrame with columns ['X', 'Y', 'target'] representing positions and target labels.
        df_name : str
            Identifier for the dataset, stored in ``adata.obs['name']``.

        Returns
        -------
        ad.AnnData
            One observation per occupied bin (named ``grid_{x}_{y}``), one variable per
            entry of ``unique_targets`` holding the count of that target in the bin.
            The result is also stored in ``self.adata``.
        """
        # Grid index of each row (written back to df; callers may rely on these columns).
        df['x_grid'] = df['X'] // self.bin_size
        df['y_grid'] = df['Y'] // self.bin_size

        # Counts per (x_grid, y_grid) bin and target, one column per target.
        quadrant_counts = df.groupby(['x_grid', 'y_grid', 'target']).size().unstack(fill_value=0)
        # Force exactly the configured targets, in order; absent targets get 0 counts.
        quadrant_counts = quadrant_counts.reindex(columns=self.unique_targets, fill_value=0)
        bin_index = quadrant_counts.index

        # Per-bin coordinate summaries computed in a single pass and aligned by label to the
        # count rows, so a bin whose rows all lack a target cannot shift later values.
        coords = (
            df.groupby(['x_grid', 'y_grid'])
            .agg(
                x_mean=('X', 'mean'),
                y_mean=('Y', 'mean'),
                x_first=('X', 'first'),
                y_first=('Y', 'first'),
            )
            .reindex(bin_index)
        )

        adata = sc.AnnData(X=quadrant_counts.values)
        adata.obs_names = [f"grid_{x}_{y}" for x, y in bin_index]
        adata.var_names = quadrant_counts.columns
        adata.obs['name'] = df_name

        # NOTE(review): these multiply raw X/Y coordinates (not grid indices) by bin_size,
        # which scales every coordinate by bin_size. Kept as-is for compatibility with
        # downstream consumers of obsm["spatial"]; confirm this scaling is intended.
        adata.obs['x_centroid'] = coords['x_mean'].to_numpy() * self.bin_size
        adata.obs['y_centroid'] = coords['y_mean'].to_numpy() * self.bin_size

        adata.obs['x_grid'] = bin_index.get_level_values('x_grid').to_numpy()
        adata.obs['y_grid'] = bin_index.get_level_values('y_grid').to_numpy()

        # First (non-null) X/Y seen in each bin, scaled the same way as the centroids.
        adata.obs['x_coord'] = coords['x_first'].to_numpy() * self.bin_size
        adata.obs['y_coord'] = coords['y_first'].to_numpy() * self.bin_size

        adata.obsm["spatial"] = adata.obs[["x_centroid", "y_centroid"]].to_numpy(copy=True)
        self.adata = adata
        return adata

    def get_bin_cohort(self, df_list: List[pd.DataFrame], df_name_list: List[str], cohort_name: str) -> None:
        """
        Bin several datasets and concatenate them into one cohort AnnData (``self.adata``).

        Parameters
        ----------
        df_list : List[pd.DataFrame]
            List of DataFrames to process (see ``get_bin_df`` for the expected columns).
        df_name_list : List[str]
            Dataset names, one per DataFrame in ``df_list``.
        cohort_name : str
            Name of the cohort, stored in ``adata.obs['cohort']`` for every bin.

        Raises
        ------
        ValueError
            If the two lists differ in length or are empty.
        """
        if len(df_list) != len(df_name_list):
            raise ValueError(
                f'df_list and df_name_list must have the same length '
                f'({len(df_list)} != {len(df_name_list)})'
            )
        if not df_list:
            raise ValueError('df_list must contain at least one DataFrame')

        start_time = time.time()
        adata_list = []
        for df, df_name in zip(df_list, df_name_list):
            adata = self.get_bin_df(df, df_name)
            adata.obs['cohort'] = cohort_name
            adata_list.append(adata)

        combined_adata = ad.concat(adata_list, join='outer')
        # Bin names (grid_x_y) repeat across datasets; make them unique so obs
        # indexing stays unambiguous (duplicates get a numeric suffix).
        combined_adata.obs_names_make_unique()
        self.adata = combined_adata

        self.logger.info(f'Time to get bins for {len(df_list)} dataframes: {time.time() - start_time:.2f} seconds')
        self.logger.info(f'Number of bins: {len(combined_adata)}')
        self.logger.info(f'Number of genes: {len(combined_adata.var_names)}')

    def preprocess_bin(self, min_counts: int = 10, adata: Optional[ad.AnnData] = None) -> None:
        """
        Filter, record QC metrics for, and log-normalize the binned AnnData in place.

        Bins with fewer than ``min_counts`` total counts are removed. The raw counts are
        kept in ``layers['counts']``; ``X`` is then total-count normalized and log1p
        transformed. The processed object is stored in ``self.adata``.

        Parameters
        ----------
        min_counts : int, optional
            Minimum total counts per bin to retain it, by default 10
        adata : Optional[ad.AnnData], optional
            AnnData object to preprocess (defaults to ``self.adata``), by default None

        Raises
        ------
        ValueError
            If no AnnData is passed and none has been built yet.
        """
        if adata is None:
            adata = self.adata
        if adata is None:
            raise ValueError(
                'No AnnData to preprocess: call get_bin_df/get_bin_cohort first or pass `adata`.'
            )

        sc.pp.filter_cells(adata, min_counts=min_counts)
        # Preserve raw counts before normalization overwrites X.
        adata.layers["counts"] = adata.X.copy()
        # asarray(...).ravel() yields 1-D columns for dense and scipy-sparse X alike
        # (a sparse row-sum is an (n, 1) matrix).
        adata.obs['total_counts'] = np.asarray(adata.X.sum(axis=1)).ravel()
        adata.obs['n_genes_by_counts'] = np.asarray((adata.X > 0).sum(axis=1)).ravel()
        sc.pp.normalize_total(adata, inplace=True)
        sc.pp.log1p(adata)
        self.adata = adata
class GetContour:
    """
    Perform SOM clustering on spatial bins and evaluate clusters.

    Attributes
    ----------
    adata : ad.AnnData
        Binned data; ``run_som`` writes ``obs['cluster_som']`` into it.
    eval_som_statistical_df : Optional[pd.DataFrame]
        Top-ranked features per cluster, filled by ``eval_som_statistical``.
    logger : logging.Logger
        Logger used for progress messages.
    """

    def __init__(self, adata: ad.AnnData, logger: Optional[logging.Logger] = None):
        """
        Initialize GetContour.

        Parameters
        ----------
        adata : ad.AnnData
            AnnData object containing binned spatial transcriptomics data.
        logger : Optional[logging.Logger], optional
            Logger instance; when None a logger named after this class is created
            via ``get_logger``. By default None.
        """
        self.adata = adata
        self.eval_som_statistical_df = None
        # Use the package logger instead of configuring the root logger
        # (logging.basicConfig) from library code.
        self.logger = logger if logger is not None else get_logger(f'{__name__}.GetContour')

    def run_som(
        self,
        som_shape: Tuple[int, int] = (2, 1),
        n_iter: int = 5000,
        sigma: float = 0.5,
        learning_rate: float = 0.5,
        random_state: int = 42
    ) -> None:
        """
        Train a MiniSom on ``adata.X`` and assign each bin to its winning SOM node.

        The node at (row, col) becomes cluster id ``row * som_shape[1] + col``; ids are
        stored as a categorical in ``adata.obs['cluster_som']``.

        Parameters
        ----------
        som_shape : Tuple[int, int], optional
            Shape of the SOM grid (rows, columns), by default (2, 1)
        n_iter : int, optional
            Number of iterations for SOM training, by default 5000
        sigma : float, optional
            Width of the Gaussian neighborhood function, by default 0.5
        learning_rate : float, optional
            Learning rate for SOM training, by default 0.5
        random_state : int, optional
            Random seed for reproducibility, by default 42
        """
        start = time.time()
        data = self.adata.X
        # MiniSom works on dense arrays; densify scipy-sparse input.
        if hasattr(data, 'toarray'):
            data = data.toarray()

        n_rows, n_cols = som_shape
        som = MiniSom(n_rows, n_cols, self.adata.shape[1],
                      sigma=sigma, learning_rate=learning_rate, random_seed=random_state)
        som.train_random(data, n_iter)

        # Row-major flattening of the winning node, matching enumeration order of
        # product(range(n_rows), range(n_cols)).
        clusters = [int(row) * n_cols + int(col) for row, col in map(som.winner, data)]
        self.adata.obs['cluster_som'] = pd.Categorical(clusters)

        self.logger.info(f'Time to run som on {self.adata.n_obs} bins: {time.time() - start:.2f}')
        self.logger.info(f'Number of clusters: {len(set(clusters))}')
        self.logger.info(f'number of bins in each cluster: {self.adata.obs["cluster_som"].value_counts()}')

    def eval_som_statistical(self, top_n: int = 20) -> None:
        """
        Rank features per SOM cluster (t-test) and log the top ones.

        Results of all groups are concatenated into ``self.eval_som_statistical_df``
        with an extra ``group`` column. Requires ``adata.obs['cluster_som']``.

        Parameters
        ----------
        top_n : int, optional
            Number of top features to retrieve for each cluster, by default 20
        """
        sc.tl.rank_genes_groups(self.adata, "cluster_som", method="t-test")
        stats = []
        groups = self.adata.uns['rank_genes_groups']['names'].dtype.names
        for group in groups:
            df = sc.get.rank_genes_groups_df(self.adata, group)
            df['group'] = group
            df_sorted = df.sort_values(by='scores', ascending=False).head(top_n)
            self.logger.info(f"n top genes for group {group}")
            self.logger.info("\n" + df_sorted.to_string())
            stats.append(df_sorted)
        self.eval_som_statistical_df = pd.concat(stats, ignore_index=True)

    def create_cluster_image(self, adata: ad.AnnData, grid_size: int) -> np.ndarray:
        """
        Reconstruct an image from cluster annotations in the AnnData object.

        Every bin paints a ``grid_size`` x ``grid_size`` square with ``cluster_som + 1``;
        pixels not covered by any bin stay 0. Array rows follow ``x_grid`` and columns
        follow ``y_grid``. If bins share a grid cell, the later one wins.

        Parameters
        ----------
        adata : ad.AnnData
            AnnData with ``obs`` columns ``cluster_som``, ``x_grid`` and ``y_grid``.
        grid_size : int
            Size of each grid cell in pixels.

        Returns
        -------
        np.ndarray
            2D float array of cluster IDs (+1) as pixel values; shape (0, 0) when
            ``adata`` has no observations.
        """
        obs = adata.obs
        if len(obs) == 0:
            return np.zeros((0, 0))

        max_x_grid = obs['x_grid'].max()
        max_y_grid = obs['y_grid'].max()
        image_shape = (int((max_x_grid + 1) * grid_size), int((max_y_grid + 1) * grid_size))
        reconstructed_image = np.zeros(image_shape)

        # Iterate over plain arrays instead of DataFrame.iterrows (much cheaper per row).
        for cluster, x_grid, y_grid in zip(obs['cluster_som'].to_numpy(),
                                           obs['x_grid'].to_numpy(),
                                           obs['y_grid'].to_numpy()):
            x_start = int(x_grid * grid_size)
            y_start = int(y_grid * grid_size)
            # Shift by 1 so cluster 0 is distinguishable from the 0-valued background.
            reconstructed_image[x_start:x_start + grid_size, y_start:y_start + grid_size] = cluster + 1
        return reconstructed_image

    def plot_som(
        self,
        som_image: np.ndarray,
        cmap: Optional[Any] = None,
        path: Optional[str] = None,
        show: bool = False,
        figsize: Tuple[int, int] = (10, 10),
        ax: Optional[plt.Axes] = None,
        legend_labels: Optional[Dict[int, str]] = None
    ) -> plt.Axes:
        """
        Visualize the SOM cluster map.

        Parameters
        ----------
        som_image : np.ndarray
            2D array representing the SOM clusters.
        cmap : Optional[Any], optional
            Colormap (name or Colormap object); None uses Matplotlib's default image
            colormap (``rcParams['image.cmap']``). By default None.
        path : Optional[str], optional
            Directory in which to save ``SOM_clustering.png``, by default None (not saved)
        show : bool, optional
            Whether to display the plot, by default False
        figsize : Tuple[int, int], optional
            Size of the figure when a new one is created, by default (10, 10)
        ax : Optional[plt.Axes], optional
            Matplotlib Axes to plot on, by default None (creates new figure)
        legend_labels : Optional[Dict[int, str]], optional
            Mapping of pixel values to labels; each legend color is
            ``cmap(value / max(values))``, which matches the image only when its
            value range starts at 0. By default None (no legend).

        Returns
        -------
        plt.Axes
            The matplotlib Axes object containing the plot.
        """
        if ax is None:
            _, ax = plt.subplots(figsize=figsize)

        # Resolve once so the colormap is callable for the legend (None and str both work).
        colormap = plt.get_cmap(cmap)
        ax.imshow(som_image, cmap=colormap, interpolation='none', origin='lower')
        ax.set_title('SOM clustering')
        # NOTE(review): create_cluster_image puts x_grid on array rows, which imshow draws
        # vertically; confirm these axis labels for images produced that way.
        ax.set_xlabel('X-axis')
        ax.set_ylabel('Y-axis')

        if legend_labels:
            # Guard against division by zero when the only key is 0.
            denominator = max(legend_labels.keys()) or 1
            handles = [mpatches.Patch(color=colormap(idx / denominator), label=label)
                       for idx, label in legend_labels.items()]
            ax.legend(handles=handles, loc='center left', bbox_to_anchor=(1.05, 0.5), title="Clusters")

        if path is not None:
            save_path = os.path.join(path, 'SOM_clustering.png')
            # Save the figure that owns `ax`, not whichever figure is current.
            ax.figure.savefig(save_path, dpi=1000, bbox_inches='tight')
            self.logger.info(f'Plot saved at {save_path}')
        if show:
            plt.show()
        return ax