checkatlas¶
checkatlas.check¶
checkatlas.check
¶
atlas_info = {'Atlas_name': 'pbmc_6k_v1_v1', 'Atlas_type': 'Cellranger < v3', 'Atlas_extension': '.mtx', 'Atlas_path': '/data/analysis/data_becavin/checkatlas_test/tuto/data5/pbmc_6k_v1_v1/outs/filtered_matrices_mex/hg19/matrix.mtx'}
module-attribute
¶
atlas_info = {'Atlas_name': 'pbmc_5k_v3_v3', 'Atlas_type': 'Cellranger >= v3', 'Atlas_extension': '.h5', 'Atlas_path': '/data/analysis/data_becavin/' 'checkatlas_test/tuto/data5/' 'pbmc_5k_v3_v3/outs/' '5k_pbmc_v3_filtered_feature_bc_matrix.h5'}
atlas_info = {'Atlas_name': 'pbmc_5k_v3_v7', 'Atlas_type': 'Cellranger >= v3', 'Atlas_extension': '.h5', 'Atlas_path': '/data/analysis/data_becavin/' 'checkatlas_test/tuto/data5/' 'pbmc_5k_v3_v7/outs/' 'SC3pv3_GEX_Human_PBMC_filtered_feature_bc_matrix.h5'}
atlas_info = {'Atlas_name': 'pbmc_3k_multiome', 'Atlas_type': 'Cellranger >= v3', 'Atlas_extension': '.h5', 'Atlas_path': '/data/analysis/data_becavin/' 'checkatlas_test/tuto/data5/' 'pbmc_3k_multiome/outs/' 'pbmc_unsorted_3k_filtered_feature_bc_matrix.h5'}
checkatlas.atlas¶
checkatlas.atlas
¶
atlas_sampling(df_annot: pd.DataFrame, type_df: str, args: argparse.Namespace) -> pd.DataFrame
¶
When args.plot_celllimit is positive and smaller than len(df_annot), the atlas QC table is randomly sampled down to args.plot_celllimit rows for MultiQC.
| Parameters: |
|
|---|
| Returns: |
|
|---|
Source code in checkatlas/atlas.py
def atlas_sampling(
    df_annot: pd.DataFrame, type_df: str, args: argparse.Namespace
) -> pd.DataFrame:
    """
    Randomly down-sample a QC table for MultiQC display.

    Sampling happens only when ``args.plot_celllimit`` is a positive number
    smaller than ``len(df_annot)``; otherwise the table is returned
    unchanged. A limit of 0 (or any non-positive value) disables sampling.

    Args:
        df_annot (pd.DataFrame): Table to sample
        type_df (str): type of table (only used in log messages)
        args (argparse.Namespace): arguments of checkatlas workflow; only
            ``args.plot_celllimit`` is read
    Returns:
        pd.DataFrame: Sampled QC table, or ``df_annot`` itself when no
        sampling is needed
    """
    cell_limit = args.plot_celllimit
    # A negative limit used to reach DataFrame.sample(), which raises
    # ValueError; any non-positive limit now means "do not sample".
    if cell_limit <= 0 or cell_limit >= len(df_annot):
        return df_annot
    logger.debug(f"Sample {type_df} table with {len(df_annot)} cells")
    df_annot = df_annot.sample(cell_limit)
    logger.debug(f"{type_df} table sampled to {len(df_annot)} cells")
    return df_annot
clean_scanpy_atlas(adata: AnnData, atlas_info: dict) -> AnnData
¶
Clean the Scanpy object to be sure to get all information out of it
- Make var names unique
- Make var unique for Raw matrix
- Obs columns whose name matches OBS_CLUSTERS and that hold integer labels (int32/int64) are converted to categorical
| Parameters: |
|
|---|
| Returns: |
|
|---|
Source code in checkatlas/atlas.py
def clean_scanpy_atlas(adata: AnnData, atlas_info: dict) -> AnnData:
    """
    Clean the Scanpy object so that downstream checks can read it.

    - Make ``adata.var_names`` unique: ``var_names_make_unique()`` is tried
      up to twice; if names are still duplicated afterwards, every var name
      is suffixed with ``_<position>``.
    - Make ``adata.raw.var_names`` unique (same ``_<position>`` suffixing)
      when a Raw matrix is present.
    - Convert integer-typed obs columns whose name contains one of the
      ``OBS_CLUSTERS`` patterns to ``pd.Categorical``.

    Args:
        adata (AnnData): atlas to analyse (modified in place)
        atlas_info (dict): info on the atlas
    Returns:
        AnnData: cleaned atlas (the same object as ``adata``)
    """
    logger.debug(f"Clean scanpy: {atlas_info[check.ATLAS_NAME_KEY]}")

    def _has_unique(names) -> bool:
        # True when no name occurs twice.
        return len(set(names)) == len(names)

    # Make var names unique. var_names_make_unique() is run up to twice
    # (a second pass sometimes resolves names created by the first one).
    for _ in range(2):
        if _has_unique(adata.var_names):
            break
        logger.debug("Var names not unique, ran : adata.var_names_make_unique()")
        adata.var_names_make_unique()
    if not _has_unique(adata.var_names):
        # Still duplicated: make them unique "by hand" with a position suffix.
        logger.debug("Var names still not unique, add position suffix to var names")
        adata.var.index = [
            f"{name}_{position}" for position, name in enumerate(adata.var_names)
        ]
    if _has_unique(adata.var_names):
        logger.debug("Var names unique")

    # Make var unique for Raw matrix
    if adata.raw is not None:
        if _has_unique(adata.raw.var_names):
            logger.debug("Var names for Raw unique")
        else:
            logger.debug("Var names for Raw not unique, add position suffix")
            adata.raw.var.index = [
                f"{name}_{position}"
                for position, name in enumerate(adata.raw.var_names)
            ]
            if _has_unique(adata.raw.var_names):
                logger.debug("Var names for Raw unique")

    # Integer-coded cluster columns are converted to categorical so they are
    # treated as annotations downstream. Any integer dtype is accepted (the
    # previous check only handled int32 and int64).
    for obs_key in adata.obs_keys():
        if any(
            obs_key_celltype in obs_key for obs_key_celltype in OBS_CLUSTERS
        ) and pd.api.types.is_integer_dtype(adata.obs[obs_key]):
            adata.obs[obs_key] = pd.Categorical(adata.obs[obs_key])
    return adata
col_annotation_pred(adata: AnnData, min_score: float = 0.5, return_all: bool = False, max_results: int = 5) -> Optional[List[str]]
¶
Detect predicted/cluster annotation columns in AnnData object.
This function identifies columns containing cluster labels or automated cell type predictions (e.g., leiden, louvain, seurat_clusters, celltypist).
| Parameters: |
|
|---|
| Returns: |
|
|---|
Example
```python
>>> import scanpy as sc
>>> import checkatlas.atlas as atlas
>>> adata = sc.read_h5ad("atlas.h5ad")
>>> pred_cols = atlas.col_annotation_pred(adata)
>>> print(f"Predicted columns: {pred_cols}")
>>> # Get with scores
>>> pred_with_scores = atlas.col_annotation_pred(adata, return_all=True)
>>> for col, score in pred_with_scores:
...     print(f"{col}: {score:.3f}")
```
Source code in checkatlas/atlas.py
def col_annotation_pred(
    adata: AnnData,
    min_score: float = 0.5,
    return_all: bool = False,
    max_results: int = 5,
) -> Optional[List[str] | List[tuple[str, float]]]:
    """
    Detect predicted/cluster annotation columns in AnnData object.

    This function identifies columns containing cluster labels or automated
    cell type predictions (e.g., leiden, louvain, seurat_clusters, celltypist).

    Args:
        adata (AnnData): Scanpy AnnData object to analyze
        min_score (float): Minimum confidence score threshold (0-1). Default: 0.5
        return_all (bool): If True, return with scores. Default: False
        max_results (int): Maximum number of columns to return. Default: 5
    Returns:
        List[str] or List[Tuple[str, float]] or None:
            - If return_all=False: List of column names sorted by confidence
            - If return_all=True: List of (column_name, score) tuples
            - None if no columns found
    Example:
        >>> import scanpy as sc
        >>> import checkatlas.atlas as atlas
        >>> adata = sc.read_h5ad("atlas.h5ad")
        >>> pred_cols = atlas.col_annotation_pred(adata)
        >>> print(f"Predicted columns: {pred_cols}")
        >>>
        >>> # Get with scores
        >>> pred_with_scores = atlas.col_annotation_pred(adata, return_all=True)
        >>> for col, score in pred_with_scores:
        ...     print(f"{col}: {score:.3f}")
    """
    detector = CheckAtlasColumnDetector(adata)
    # Only the predicted-annotation threshold is user-controlled here; the
    # reference threshold is a fixed value passed to the shared detector.
    results = detector.detect_all_parameters(
        min_reference_score=0.3, min_predicted_score=min_score
    )
    pred_candidates = results["annotation"]["predicted"][:max_results]
    if not pred_candidates:
        return None
    if return_all:
        return pred_candidates
    return [col for col, _score in pred_candidates]
col_annotation_ref(adata: AnnData, min_score: float = 0.5, return_all: bool = False) -> Optional[str]
¶
Detect reference (ground truth) annotation column in AnnData object.
This function uses intelligent semantic and statistical analysis to identify the most likely reference/ground truth cell type annotation column.
| Parameters: |
|
|---|
| Returns: |
|
|---|
Example
```python
>>> import scanpy as sc
>>> import checkatlas.atlas as atlas
>>> adata = sc.read_h5ad("atlas.h5ad")
>>> ref_col = atlas.col_annotation_ref(adata)
>>> print(f"Reference column: {ref_col}")
>>> # Get all candidates with scores
>>> all_refs = atlas.col_annotation_ref(adata, return_all=True)
>>> for col, score in all_refs:
...     print(f"{col}: {score:.3f}")
```
Source code in checkatlas/atlas.py
def col_annotation_ref(
    adata: AnnData, min_score: float = 0.5, return_all: bool = False
) -> Optional[str | list[tuple[str, float]]]:
    """
    Detect reference (ground truth) annotation column in AnnData object.

    This function uses intelligent semantic and statistical analysis to identify
    the most likely reference/ground truth cell type annotation column.

    Args:
        adata (AnnData): Scanpy AnnData object to analyze
        min_score (float): Minimum confidence score threshold (0-1). Default: 0.5
        return_all (bool): If True, return list of all candidates with scores.
            Default: False
    Returns:
        str or List[Tuple[str, float]] or None:
            - If return_all=False: Best reference column name, or None if none
              found
            - If return_all=True: List (possibly empty) of (column_name, score)
              tuples sorted by score
    Example:
        >>> import scanpy as sc
        >>> import checkatlas.atlas as atlas
        >>> adata = sc.read_h5ad("atlas.h5ad")
        >>> ref_col = atlas.col_annotation_ref(adata)
        >>> print(f"Reference column: {ref_col}")
        >>>
        >>> # Get all candidates with scores
        >>> all_refs = atlas.col_annotation_ref(adata, return_all=True)
        >>> for col, score in all_refs:
        ...     print(f"{col}: {score:.3f}")
    """
    detector = CheckAtlasColumnDetector(adata)
    # Only the reference threshold is user-controlled here; the predicted
    # threshold is a fixed value passed to the shared detector.
    results = detector.detect_all_parameters(
        min_reference_score=min_score, min_predicted_score=0.3
    )
    ref_candidates = results["annotation"]["reference"]
    if return_all:
        return ref_candidates
    # Candidates are (column, score) pairs; the first one is the best match.
    return ref_candidates[0][0] if ref_candidates else None
col_cluster(adata: AnnData, min_score: float = 0.5, return_all: bool = False, max_results: int = 5) -> Optional[List[str]]
¶
Detect cluster label columns in AnnData object.
Uses the dedicated cluster-label detector (Leiden, Louvain, k‑means, Seurat clusters, PhenoGraph, etc.) which applies semantic 70 % + statistical 30 % scoring tuned for algorithmic clustering outputs.
| Parameters: |
|
|---|
| Returns: |
|
|---|
Example
```python
>>> import scanpy as sc
>>> import checkatlas.atlas as atlas
>>> adata = sc.read_h5ad("atlas.h5ad")
>>> cluster_cols = atlas.col_cluster(adata)
>>> print(f"Cluster columns: {cluster_cols}")
```
Source code in checkatlas/atlas.py
def col_cluster(
    adata: AnnData,
    min_score: float = 0.5,
    return_all: bool = False,
    max_results: int = 5,
) -> Optional[List[str] | List[tuple[str, float]]]:
    """
    Detect cluster label columns in AnnData object.

    Uses the dedicated cluster-label detector (Leiden, Louvain, k-means,
    Seurat clusters, PhenoGraph, etc.) which applies semantic 70 % +
    statistical 30 % scoring tuned for algorithmic clustering outputs.

    Args:
        adata (AnnData): Scanpy AnnData object to analyze
        min_score (float): Minimum confidence score threshold (0-1). Default: 0.5
        return_all (bool): If True, return with scores. Default: False
        max_results (int): Maximum number of columns to return. Default: 5
    Returns:
        List[str] or List[Tuple[str, float]] or None:
            - If return_all=False: List of column names sorted by confidence
            - If return_all=True: List of (column_name, score) tuples
            - None if no columns found
    Example:
        >>> import scanpy as sc
        >>> import checkatlas.atlas as atlas
        >>> adata = sc.read_h5ad("atlas.h5ad")
        >>> cluster_cols = atlas.col_cluster(adata)
        >>> print(f"Cluster columns: {cluster_cols}")
    """
    detector = CheckAtlasColumnDetector(adata)
    results = detector.detect_all_parameters(min_cluster_score=min_score)
    clust_candidates = results["clustering"]["cluster_labels"][:max_results]
    if not clust_candidates:
        return None
    if return_all:
        return clust_candidates
    return [col for col, _score in clust_candidates]
col_dimred(adata: AnnData, return_all: bool = False, max_results: int = 10) -> Optional[List[str] | List[dict[str, Any]]]
¶
Detect dimensionality reduction representations in AnnData.obsm.
This function identifies embedding keys like X_pca, X_umap, X_tsne, etc.
| Parameters: |
|
|---|
| Returns: |
|
|---|
Example
```python
>>> import scanpy as sc
>>> import checkatlas.atlas as atlas
>>> adata = sc.read_h5ad("atlas.h5ad")
>>> dimred_keys = atlas.col_dimred(adata)
>>> print(f"Dimensionality reductions: {dimred_keys}")
>>> # Get with metadata
>>> dimred_detailed = atlas.col_dimred(adata, return_all=True)
>>> for emb in dimred_detailed:
...     print(f"{emb['key']}: {emb['n_components']} components")
```
Source code in checkatlas/atlas.py
def col_dimred(
    adata: AnnData, return_all: bool = False, max_results: int = 10
) -> Optional[List[str] | List[dict[str, Any]]]:
    """
    List the dimensionality-reduction embeddings stored in ``adata.obsm``.

    Embedding keys such as X_pca, X_umap or X_tsne are found with the
    shared column detector and truncated to ``max_results`` entries.

    Args:
        adata (AnnData): Scanpy AnnData object to analyze
        return_all (bool): If True, return metadata for each embedding.
            Default: False
        max_results (int): Maximum number of representations to return.
            Default: 10
    Returns:
        List[str] or List[Dict] or None:
            - If return_all=False: List of obsm keys (e.g., ['X_umap', 'X_pca'])
            - If return_all=True: List of dicts with 'key', 'shape',
              'n_components'
            - None if no representations found
    Example:
        >>> import scanpy as sc
        >>> import checkatlas.atlas as atlas
        >>> adata = sc.read_h5ad("atlas.h5ad")
        >>> dimred_keys = atlas.col_dimred(adata)
        >>> print(f"Dimensionality reductions: {dimred_keys}")
        >>>
        >>> # Get with metadata
        >>> dimred_detailed = atlas.col_dimred(adata, return_all=True)
        >>> for emb in dimred_detailed:
        ...     print(f"{emb['key']}: {emb['n_components']} components")
    """
    all_params = CheckAtlasColumnDetector(adata).detect_all_parameters()
    found = all_params["clustering"]["embeddings"][:max_results]
    if not found:
        return None
    if not return_all:
        return [emb_key for emb_key, _meta in found]
    detailed = []
    for emb_key, emb_meta in found:
        detailed.append(
            {
                "key": emb_key,
                "shape": emb_meta["shape"],
                "n_components": emb_meta["n_components"],
            }
        )
    return detailed
create_anndata_table(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None
¶
Create an HTML table listing all AnnData keys (obs, obsm, var, varm, uns). The HTML markup makes every element of the table visible in MultiQC. Args: adata (AnnData): atlas to analyse; atlas_info (dict): info dict on the atlas; args (argparse.Namespace): list of arguments from the checkatlas workflow.
Source code in checkatlas/atlas.py
def create_anndata_table(
    adata: AnnData, atlas_info: dict, args: argparse.Namespace
) -> None:
    """
    Create an HTML table listing every AnnData key of the atlas.

    One row (indexed by the atlas name) is built with the columns
    ``atlas_obs``, ``obsm``, ``var``, ``varm`` and ``uns``. Each cell lists
    the corresponding keys wrapped in ``<code>`` tags and separated by
    ``<br>`` so that all elements of the table are visible in MultiQC.
    The table is written as TSV (without the index) to the ANNDATA folder.

    Args:
        adata (AnnData): atlas to analyse
        atlas_info (dict): info dict on the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    logger.debug(f"Create Adata table for {atlas_name}")
    csv_path = files.get_file_path(
        atlas_name, folders.ANNDATA, check.TSV_EXTENSION, args.path
    )

    def _as_html_code_list(keys) -> str:
        # str() so that non-string keys (e.g. integer obs column names) do
        # not make str.join raise TypeError.
        return (
            "<code>"
            + "</code><br><code>".join(str(key) for key in keys)
            + "</code>"
        )

    # Build the single row in one go: chained assignment such as
    # df["col"][row] = value does not write back under pandas Copy-on-Write.
    row = {
        "atlas_obs": _as_html_code_list(adata.obs.columns),
        "obsm": _as_html_code_list(adata.obsm_keys()),
        "var": _as_html_code_list(adata.var_keys()),
        "varm": _as_html_code_list(adata.varm_keys()),
        "uns": _as_html_code_list(adata.uns_keys()),
    }
    df_summary = pd.DataFrame(row, index=[atlas_name])
    df_summary.to_csv(csv_path, index=False, quoting=False, sep="\t")
create_metric_annot(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None
¶
Calculate all annotation metrics via the comprehensive cal_annot
pipeline (ref-vs-pred, embedding-based, batch/integration, graph).
The pipeline auto-detects reference/predicted columns, embedding keys,
and batch labels; runs every metric listed in METRICS_ANNOT; and
writes results as tab-separated files in the annotation folder.
| Parameters: |
|
|---|
Source code in checkatlas/atlas.py
def create_metric_annot(
    adata: AnnData, atlas_info: dict, args: argparse.Namespace
) -> None:
    """
    Run the ``metrics.cal_annot`` pipeline and save the annotation metrics.

    Nothing is computed when ``--metric_annot none`` was given. Otherwise
    ``cal_annot`` is called with ``all=True`` on the requested metric list,
    using the annotation folder as working directory and any precomputed
    context returned by ``_try_load_context``. A non-empty result is pivoted
    to wide format and written as a tab-separated file in the annotation
    folder; an empty result only triggers a warning.

    Args:
        adata (AnnData): atlas to analyse
        atlas_info (dict): info of the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    requested_metrics = args.metric_annot
    if requested_metrics == ["none"]:
        logger.info("Skipping annotation metrics (--metric_annot none)")
        return

    name = atlas_info[check.ATLAS_NAME_KEY]
    out_dir = folders.get_folder(args.path, folders.ANNOTATION)
    logger.info("Running full annotation pipeline for %s", name)
    context = _try_load_context(atlas_info, args)
    annot_df = metrics.cal_annot(
        adata,
        atlas_name=name,
        metric_list=requested_metrics,
        all=True,
        file_dir=out_dir,
        n_jobs=_resolve_n_jobs(args),
        verbose=True,
        preprocess_context=context,
    )

    if annot_df.empty:
        logger.warning("No annotation metrics calculated for %s", name)
        return

    tsv_path = files.get_file_path(
        name, folders.ANNOTATION, check.TSV_EXTENSION, args.path
    )
    wide = metrics._pivot_annot_to_wide(annot_df, name)
    wide.to_csv(tsv_path, index=False, sep="\t")
    logger.info("Annotation metrics saved to %s", tsv_path)
create_metric_cluster(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None
¶
Calculate all clustering metrics via the comprehensive cal_cluster
pipeline. The pipeline auto-detects embeddings and cluster-label columns,
runs every metric listed in METRICS_CLUST across all combinations,
and writes results as a tab-separated file in the cluster folder.
| Parameters: |
|
|---|
Source code in checkatlas/atlas.py
def create_metric_cluster(
    adata: AnnData, atlas_info: dict, args: argparse.Namespace
) -> None:
    """
    Run the ``metrics.cal_cluster`` pipeline and save the clustering metrics.

    Nothing is computed when no metric was requested or when
    ``--metric_cluster none`` was given. Otherwise ``cal_cluster`` is called
    (seed 42) with the cluster folder as working directory and any
    precomputed context returned by ``_try_load_context``. A non-empty
    result is pivoted to wide format and written as a tab-separated file in
    the cluster folder; an empty result only triggers a warning.

    Args:
        adata (AnnData): atlas to analyse
        atlas_info (dict): info of the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    requested_metrics = args.metric_cluster
    if not requested_metrics:
        logger.info("Skipping clustering metrics (no metrics requested)")
        return
    if requested_metrics == ["none"]:
        logger.info("Skipping clustering metrics (--metric_cluster none)")
        return

    name = atlas_info[check.ATLAS_NAME_KEY]
    out_dir = folders.get_folder(args.path, folders.CLUSTER)
    logger.info("Running full clustering pipeline for %s", name)
    context = _try_load_context(atlas_info, args)
    cluster_df = metrics.cal_cluster(
        adata,
        atlas_name=name,
        metric_list=requested_metrics,
        file_dir=out_dir,
        n_jobs=_resolve_n_jobs(args),
        verbose=True,
        seed=42,
        preprocess_context=context,
    )

    if cluster_df.empty:
        logger.warning("No clustering metrics calculated for %s", name)
        return

    tsv_path = files.get_file_path(
        name, folders.CLUSTER, check.TSV_EXTENSION, args.path
    )
    wide = metrics._pivot_cluster_to_wide(cluster_df, name)
    wide.to_csv(tsv_path, index=False, sep="\t")
    logger.info("Clustering metrics saved to %s", tsv_path)
create_metric_dimred(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None
¶
Calculate all dimred metrics via the comprehensive cal_dimred
pipeline. The pipeline auto-detects all .obsm embedding keys,
compares each against adata.X as the high‑dimensional reference,
runs every metric listed by --metric_dimred, and writes results as
a tab‑separated file in the dimred folder compatible with MultiQC.
| Parameters: |
|
|---|
Source code in checkatlas/atlas.py
def create_metric_dimred(
    adata: AnnData, atlas_info: dict, args: argparse.Namespace
) -> None:
    """
    Run the ``metrics.cal_dimred`` pipeline and save the dimred metrics.

    Nothing is computed when ``--metric_dimred none`` was given. Otherwise
    ``cal_dimred`` is called (seed 42, cache enabled) with a per-atlas cache
    directory ``<temp>/<atlas_name>/dimred``. A non-empty result is pivoted
    to wide format and written as a tab-separated file in the dimred folder
    (MultiQC-compatible); an empty result only triggers a warning.

    Args:
        adata (AnnData): atlas to analyse
        atlas_info (dict): info of the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    requested_metrics = args.metric_dimred
    if requested_metrics == ["none"]:
        logger.info("Skipping dimred metrics (--metric_dimred none)")
        return

    name = atlas_info[check.ATLAS_NAME_KEY]
    logger.info("Running full dimred pipeline for %s", name)
    # NOTE(review): the DIMRED folder path is not used below; the call is
    # kept in case folders.get_folder creates the directory - confirm.
    folders.get_folder(args.path, folders.DIMRED)
    # Per-atlas persistent cache under temp/
    temp_root = folders.get_folder(args.path, folders.TEMP)
    cache_dir = os.path.join(temp_root, name, "dimred")
    dimred_df = metrics.cal_dimred(
        adata,
        atlas_name=name,
        metric_list=requested_metrics,
        file_dir=cache_dir,
        use_cache=True,
        n_jobs=_resolve_n_jobs(args),
        verbose=True,
        seed=42,
    )

    if dimred_df.empty:
        logger.warning("No dimred metrics calculated for %s", name)
        return

    tsv_path = files.get_file_path(
        name, folders.DIMRED, check.TSV_EXTENSION, args.path
    )
    wide = metrics._pivot_dimred_to_wide(dimred_df, name)
    wide.to_csv(tsv_path, index=False, sep="\t")
    logger.info("Dimred metrics saved to %s", tsv_path)
create_qc_plots(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None
¶
Create the atlas QC violin plot. QC metrics (n_genes_by_counts, total_counts, pct_counts_mt, pct_counts_ribo) are computed with scanpy from the mitochondrial (MT-) and ribosomal (RPS/RPL) genes.
| Parameters: |
|
|---|
Source code in checkatlas/atlas.py
def create_qc_plots(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None:
    """
    Save a violin plot of the per-cell QC metrics of the atlas.

    Mitochondrial ("MT-") and ribosomal ("RPS"/"RPL") genes are flagged in
    ``adata.var["mt"]`` / ``adata.var["ribo"]``. ``sc.pp.calculate_qc_metrics``
    then adds the QC columns to ``adata.obs`` in place, and ``sc.pl.violin``
    saves a multi-panel figure of n_genes_by_counts, total_counts,
    pct_counts_mt and pct_counts_ribo in the working directory.

    Args:
        adata (AnnData): atlas to analyse (modified in place)
        atlas_info (dict): info on the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    name = atlas_info[check.ATLAS_NAME_KEY]
    sc.settings.figdir = folders.get_workingdir(args.path)
    sc.set_figure_params(dpi_save=80)
    violin_file = os.sep + name + check.QC_FIG_EXTENSION
    logger.debug(f"Create QC violin plot for {name}")

    gene_names = adata.var_names
    adata.var["mt"] = gene_names.str.startswith("MT-")  # mitochondrial
    adata.var["ribo"] = gene_names.str.startswith(("RPS", "RPL"))  # ribosomal
    sc.pp.calculate_qc_metrics(
        adata,
        qc_vars=["mt", "ribo"],
        percent_top=None,
        log1p=False,
        inplace=True,
    )

    qc_metrics = [
        "n_genes_by_counts",
        "total_counts",
        "pct_counts_mt",
        "pct_counts_ribo",
    ]
    sc.pl.violin(
        adata,
        qc_metrics,
        jitter=0.4,
        multi_panel=True,
        show=False,
        save=violin_file,
    )
create_qc_tables(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None
¶
Create the atlas QC table. Searches for the obs variables corresponding to total_RNA, total_UMI, MT_ratio and RT_ratio.
| Parameters: |
|
|---|
Source code in checkatlas/atlas.py
def create_qc_tables(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None:
    """
    Write the atlas QC table used by MultiQC.

    Mitochondrial ("MT-") and ribosomal ("RPS"/"RPL") genes are flagged in
    ``adata.var`` and passed to ``sc.pp.calculate_qc_metrics`` when present.
    The obs columns selected by ``get_viable_obs_qc`` (those listed in
    ``args.qc_display``) are then extracted. A ``cellrank_<metric>`` column
    ranking the cells by each metric (1 = highest value) is added, the table
    is sampled with ``atlas_sampling``, and a fresh ``CELLINDEX_HEADER``
    column is filled. The result is written as TSV in the QC folder.

    Args:
        adata (AnnData): atlas to analyse (QC columns are added in place)
        atlas_info (dict): info on the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    qc_path = files.get_file_path(atlas_name, folders.QC, check.TSV_EXTENSION, args.path)
    logger.debug(f"Create QC tables for {atlas_name}")
    qc_genes = []
    # mitochondrial genes
    adata.var["mt"] = adata.var_names.str.startswith("MT-")
    if adata.var["mt"].any():
        qc_genes.append("mt")
        logger.debug(f"Mitochondrial genes in {atlas_name} for QC")
    else:
        logger.debug(f"No mitochondrial genes in {atlas_name} for QC")
    # ribosomal genes
    adata.var["ribo"] = adata.var_names.str.startswith(("RPS", "RPL"))
    # Bug fix: this test used to look at adata.var["mt"], so ribosomal QC
    # depended on the presence of mitochondrial genes.
    if adata.var["ribo"].any():
        qc_genes.append("ribo")
        logger.debug(f"Ribosomal genes in {atlas_name} for QC")
    else:
        logger.debug(f"No ribosomal genes in {atlas_name} for QC")
    sc.pp.calculate_qc_metrics(
        adata,
        qc_vars=qc_genes,
        percent_top=None,
        log1p=False,
        inplace=True,
    )
    # .copy(): adata.obs[cols] is a slice; the rank columns added below must
    # not be written to (or warn about) adata.obs.
    df_annot = adata.obs[get_viable_obs_qc(adata, args)].copy()
    n_cells = len(df_annot)
    # Rank cells by each QC metric (1 = highest value). The column list is
    # frozen first so the rank columns added in the loop are not ranked.
    for header in list(df_annot.columns):
        if header != CELLINDEX_HEADER:
            df_annot = df_annot.sort_values(header, ascending=False)
            df_annot[f"cellrank_{header}"] = range(1, n_cells + 1)
    # Sample QC table when more cells than args.plot_celllimit are present
    df_annot = atlas_sampling(df_annot, "QC", args)
    df_annot[CELLINDEX_HEADER] = range(1, len(df_annot) + 1)
    df_annot.to_csv(qc_path, index=False, quoting=False, sep="\t")
create_summary_table(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None
¶
Create a table with all summarizing variables
| Parameters: |
|
|---|
Source code in checkatlas/atlas.py
def create_summary_table(
    adata: AnnData, atlas_info: dict, args: argparse.Namespace
) -> None:
    """
    Create a table with all summarizing variables of the atlas.

    One row (indexed by the atlas name) is written as TSV, without the
    index, to the SUMMARY folder. Columns: AtlasFileType, NbCells, NbGenes,
    AnnData.raw, AnnData.X, File_extension, File_path.

    Args:
        adata (AnnData): atlas to analyse
        atlas_info (dict): info dict of the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    atlas_type = atlas_info[check.ATLAS_TYPE_KEY]
    atlas_path = atlas_info[check.ATLAS_PATH_KEY]
    # Bug fix: File_extension used to be filled with the atlas name.
    # atlas_info dicts carry the extension under "Atlas_extension"; prefer the
    # project constant if check defines one, and fall back to the path suffix
    # when the key is missing.
    # NOTE(review): confirm the name of the extension key in checkatlas.check.
    extension_key = getattr(check, "ATLAS_EXTENSION_KEY", "Atlas_extension")
    atlas_extension = atlas_info.get(extension_key) or os.path.splitext(atlas_path)[1]
    logger.debug(f"Create Summary table for {atlas_name}")
    csv_path = files.get_file_path(
        atlas_name, folders.SUMMARY, check.TSV_EXTENSION, args.path
    )
    # Build the single row in one go: chained assignment such as
    # df["col"][row] = value does not write back under pandas Copy-on-Write.
    summary = {
        "AtlasFileType": atlas_type,
        "NbCells": adata.n_obs,
        "NbGenes": adata.n_vars,
        "AnnData.raw": adata.raw is not None,
        "AnnData.X": adata.X is not None,
        "File_extension": atlas_extension,
        "File_path": atlas_path,
    }
    df_summary = pd.DataFrame(summary, index=[atlas_name])
    df_summary.to_csv(csv_path, index=False, sep="\t")
create_tsne_fig(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None
¶
Display the t-SNE of cell types. Searches for the obs variable that corresponds to the cell type annotation.
| Parameters: |
|
|---|
Source code in checkatlas/atlas.py
def create_tsne_fig(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None:
    """
    Save a t-SNE figure of the atlas, coloured by cell-type annotation.

    The first obsm key (from ``get_viable_obsm``) whose name contains "tsne"
    (case-insensitive) is copied to ``adata.obsm["X_tsne"]``; it is converted
    to numpy if stored as a DataFrame. It is then plotted with
    ``sc.pl.tsne`` in the working directory, coloured by the first key
    returned by ``get_viable_obs_annot`` when there is one. No figure is
    produced when no t-SNE embedding is found.

    Args:
        adata (AnnData): atlas to analyse (``obsm["X_tsne"]`` may be set)
        atlas_info (dict): info on the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    sc.set_figure_params(dpi_save=150)
    # Search if a t-SNE reduction exists. The former regex ".*tsne*." was a
    # typo: it matched "tsn" followed by any character (e.g. "X_tsn_pca")
    # and missed upper-case keys such as "X_TSNE".
    obsm_tsne_keys = [
        key for key in get_viable_obsm(adata, args) if "tsne" in key.lower()
    ]
    if not obsm_tsne_keys:
        return
    obsm_tsne = obsm_tsne_keys[0]
    logger.debug(f"Create t-SNE figure for {atlas_name} with obsm={obsm_tsne}")
    # Set the t-SNE to display (sc.pl.tsne reads obsm["X_tsne"])
    embedding = adata.obsm[obsm_tsne]
    if isinstance(embedding, pd.DataFrame):
        # Transform to numpy if it is a pandas dataframe
        embedding = embedding.to_numpy()
    adata.obsm["X_tsne"] = embedding
    # Setting up figures directory
    sc.settings.figdir = folders.get_workingdir(args.path)
    tsne_path = os.sep + atlas_name + check.TSNE_EXTENSION
    # Exporting tsne
    obs_keys = get_viable_obs_annot(adata, args)
    if obs_keys:
        sc.pl.tsne(adata, color=obs_keys[0], show=False, save=tsne_path)
    else:
        sc.pl.tsne(adata, show=False, save=tsne_path)
create_umap_fig(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None
¶
Display the UMAP of cell types. Searches for the obs variable that corresponds to the cell type annotation.
| Parameters: |
|
|---|
Source code in checkatlas/atlas.py
def create_umap_fig(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None:
    """
    Save a UMAP figure of the atlas, coloured by cell-type annotation.

    The first obsm key (from ``get_viable_obsm``) whose name contains "umap"
    (case-insensitive) is copied to ``adata.obsm["X_umap"]``; it is converted
    to numpy if stored as a DataFrame. It is then plotted with
    ``sc.pl.umap`` in the working directory, coloured by the first key
    returned by ``get_viable_obs_annot`` when there is one. No figure is
    produced when no UMAP embedding is found.

    Args:
        adata (AnnData): atlas to analyse (``obsm["X_umap"]`` may be set)
        atlas_info (dict): info on the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    sc.set_figure_params(dpi_save=150)
    # Search if a UMAP reduction exists. The former regex ".*umap*." was a
    # typo: it matched "uma" followed by any character and missed
    # upper-case keys such as "X_UMAP".
    obsm_umap_keys = [
        key for key in get_viable_obsm(adata, args) if "umap" in key.lower()
    ]
    if not obsm_umap_keys:
        return
    obsm_umap = obsm_umap_keys[0]
    logger.debug(f"Create UMAP figure for {atlas_name} with obsm={obsm_umap}")
    # Set the umap to display (sc.pl.umap reads obsm["X_umap"])
    embedding = adata.obsm[obsm_umap]
    if isinstance(embedding, pd.DataFrame):
        # Transform to numpy if it is a pandas dataframe
        embedding = embedding.to_numpy()
    adata.obsm["X_umap"] = embedding
    # Setting up figures directory
    sc.settings.figdir = folders.get_workingdir(args.path)
    umap_path = os.sep + atlas_name + check.UMAP_EXTENSION
    # Exporting umap
    obs_keys = get_viable_obs_annot(adata, args)
    if obs_keys:
        sc.pl.umap(adata, color=obs_keys[0], show=False, save=umap_path)
    else:
        sc.pl.umap(adata, show=False, save=umap_path)
get_viable_obs_annot(adata: AnnData, args: argparse.Namespace) -> list
¶
Search obs_keys for names matching the OBS_CLUSTERS patterns and keep only categorical columns without NaN values. Columns with a single category are removed. The selected obs_keys are returned sorted alphabetically.
| Parameters: |
|
|---|
| Returns: |
|
|---|
Source code in checkatlas/atlas.py
def get_viable_obs_annot(adata: AnnData, args: argparse.Namespace) -> list:
    """
    Select the categorical obs columns usable as cell annotations.

    An obs column is kept when:
    - its name contains one of the patterns in ``args.obs_cluster``,
    - its dtype is categorical,
    - it contains no NaN value (per ``_object_dtype_isnan``),
    - it has more than one category once NaN and a literal "nan"
      category are ignored.

    Args:
        adata (AnnData): atlas to analyse
        args (argparse.Namespace): list of arguments from checkatlas workflow
    Returns:
        list: matching obs_keys, without duplicates, sorted alphabetically
    """
    patterns = args.obs_cluster
    # Each obs key is considered once even when several patterns match it;
    # it used to be appended (and returned) once per matching pattern.
    obs_keys = [
        obs_key
        for obs_key in adata.obs_keys()
        if any(pattern in obs_key for pattern in patterns)
        and isinstance(adata.obs[obs_key].dtype, pd.CategoricalDtype)
    ]
    obs_keys_final = []
    for obs_key in obs_keys:
        annotations = adata.obs[obs_key]
        # Skip annotations that contain NaN values
        if _object_dtype_isnan(annotations).any():
            continue
        categories_temp = annotations.cat.categories
        # Ignore NaN and the literal string "nan" as categories
        categories = categories_temp.dropna()
        if "nan" in categories:
            categories = categories.drop("nan")
        # Keep annotations with at least two real categories; the former
        # "!= 1" test also kept columns whose only category was "nan".
        if len(categories) > 1:
            logger.debug(f"Add obs_key {obs_key} with cat {categories_temp}")
            obs_keys_final.append(obs_key)
    return sorted(obs_keys_final)
get_viable_obs_qc(adata: AnnData, args: argparse.Namespace) -> list
¶
Search obs_keys for names listed in OBS_QC (args.qc_display). Matching obs_keys are returned in the order in which they appear in adata.obs.
| Parameters: |
|
|---|
| Returns: |
|
|---|
Source code in checkatlas/atlas.py
def get_viable_obs_qc(adata: AnnData, args: argparse.Namespace) -> list:
    """
    Return the obs keys of the atlas that are listed in ``args.qc_display``.

    The keys are returned in the order in which they appear in ``adata.obs``,
    not in the order of ``args.qc_display``.

    Args:
        adata (AnnData): atlas to analyse
        args (argparse.Namespace): list of arguments from checkatlas workflow
    Returns:
        list: obs_keys
    """
    wanted = args.qc_display
    return [key for key in adata.obs_keys() if key in wanted]
get_viable_obsm(adata: AnnData, args: argparse.Namespace) -> list
¶
TO DO: search viable obsm keys for the dimensionality-reduction metric calculation. No filter is applied on obsm for now: all adata.obsm keys are returned. Args: adata (AnnData): atlas to analyse; args (argparse.Namespace): list of arguments from the checkatlas workflow.
| Returns: |
|
|---|
Source code in checkatlas/atlas.py
def get_viable_obsm(adata: AnnData, args: argparse.Namespace) -> list:
    """
    Return the obsm keys to use for dimensionality-reduction metrics.

    No filter is applied yet: every key of ``adata.obsm`` is returned and
    ``args`` is currently unused.

    Args:
        adata (AnnData): atlas to analyse
        args (argparse.Namespace): list of arguments from checkatlas workflow
    Returns:
        list: obsm_keys
    """
    # TODO: restrict to the keys listed in args.obsm_dimred (a filter on that
    # option was sketched previously but never enabled).
    keys = adata.obsm_keys()
    logger.debug(f"Add obsm {keys}")
    return keys
preprocess_atlas(atlas_info: dict, args=None) -> AnnData
¶
Read adata, clean it, and run task-specific precomputations (column detection, kNN graphs, distance matrices, etc.) based on which metric categories are requested in args.
Precomputed artefacts are persisted under
checkatlas_files/temp/<atlas>/<task>/ so that downstream
child processes (including Nextflow metric steps) can skip
redundant computation.
When args is None the function behaves exactly like the
legacy version: read + clean only.
Returns the cleaned AnnData object.
Source code in checkatlas/atlas.py
def preprocess_atlas(atlas_info: dict, args=None) -> AnnData:
    """
    Read adata, clean it, and run task-specific precomputations
    (column detection, kNN graphs, distance matrices, etc.) based
    on which metric categories are requested in *args*.

    Precomputed artefacts are persisted under
    ``checkatlas_files/temp/<atlas>/<task>/`` so that downstream
    child processes (including Nextflow metric steps) can skip
    redundant computation.

    When *args* is ``None`` the function behaves exactly like the
    legacy version: read + clean only.

    Args:
        atlas_info (dict): info dict about the atlas; passed to
            ``read_atlas`` and ``clean_scanpy_atlas``. Its name entry names
            the temp sub-folders and its path entry (if present) is part
            of the cache fingerprint.
        args (argparse.Namespace, optional): checkatlas workflow
            arguments. ``args.path`` locates the temp folder; requested
            tasks are queried through ``_wants_task``.

    Returns:
        AnnData: the cleaned atlas. The precompute context is written to
        disk with ``save_context`` as a side effect.
    """
    adata = read_atlas(atlas_info)
    adata = clean_scanpy_atlas(adata, atlas_info)
    if not _should_precompute(args):
        return adata
    run_cluster = _wants_task(args, "cluster")
    run_annot = _wants_task(args, "annot")
    run_dimred = _wants_task(args, "dimred")
    # None of the tasks that use precomputed artefacts was requested.
    if not (run_cluster or run_annot or run_dimred):
        return adata
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    source_path = atlas_info.get(check.ATLAS_PATH_KEY, None)
    # ── 1. Column detection (once) ──────────────────────────────────
    detector = CheckAtlasColumnDetector(adata)
    params = detector.detect_all_parameters()
    # Detector results are pairs whose first item is the column/key name.
    ref_keys = [c for c, _ in params["annotation"]["reference"]]
    pred_keys = [c for c, _ in params["annotation"]["predicted"]]
    cluster_label_keys = [c for c, _ in params["clustering"]["cluster_labels"]]
    batch_keys = [c for c, _ in params.get("batch", [])]
    # Fallback when the detector reports no batch column: every obs column
    # whose name contains "batch" (case-insensitive).
    if not batch_keys:
        batch_keys = [col for col in adata.obs.columns if "batch" in col.lower()]
    embedding_keys = []
    cluster_embedding_keys = []
    annotation_embedding_keys = []
    # Detected embeddings serve dimred and clustering; annotation only keeps
    # those whose reported n_components is greater than 2.
    for emb, meta in params["clustering"]["embeddings"]:
        embedding_keys.append(emb)
        cluster_embedding_keys.append(emb)
        n_comp = meta.get("n_components", 0)
        if n_comp > 2:
            annotation_embedding_keys.append(emb)
    # Also add ALL .obsm keys for dimred/cluster use;
    # for annotation only include embeddings with > 2 components
    all_obsm_keys = adata.obsm_keys()
    for key in all_obsm_keys:
        if key not in embedding_keys:
            embedding_keys.append(key)
        if key not in cluster_embedding_keys:
            cluster_embedding_keys.append(key)
        if key not in annotation_embedding_keys:
            # Best effort: an entry whose .shape[1] cannot be read is
            # silently left out of the annotation embeddings.
            try:
                if adata.obsm[key].shape[1] > 2:
                    annotation_embedding_keys.append(key)
            except Exception:
                pass
    k_max = 90  # covers LISI (90) and kBET (25 via subset)
    # k_max only feeds the fingerprint; the precompute calls in step 4 pass
    # their own k_neighbors values (30 for dimred/cluster, 90 for annot).
    fingerprint = make_preprocess_fingerprint(
        adata,
        embedding_keys=embedding_keys,
        cluster_label_keys=cluster_label_keys,
        batch_keys=batch_keys,
        k_neighbors=k_max,
        source_path=source_path,
        annotation_embedding_keys=annotation_embedding_keys,
    )
    # ── 2. Early exit: cached context still valid ──────────────────
    # NOTE(review): the fingerprint arguments above do not include which
    # tasks (run_cluster / run_annot / run_dimred) were requested. Unless
    # make_preprocess_fingerprint or load_context accounts for that, a
    # context saved by a run that precomputed only some tasks may be
    # reused here and the newly requested tasks skipped — confirm.
    temp_parent = folders.get_folder(args.path, folders.TEMP)
    existing = load_context(atlas_name, temp_parent, fingerprint)
    if existing is not None:
        logger.info("Precompute context already valid — skipping recomputation")
        return adata
    # ── 3. Build fresh context ─────────────────────────────────────
    ctx = PreprocessContext(
        atlas_name=atlas_name,
        fingerprint=fingerprint,
        ref_keys=ref_keys,
        pred_keys=pred_keys,
        embedding_keys=embedding_keys,
        annotation_embedding_keys=annotation_embedding_keys,
        cluster_embedding_keys=cluster_embedding_keys,
        cluster_label_keys=cluster_label_keys,
        batch_keys=batch_keys,
        temp_parent_dir=temp_parent,
        dimred_dir=os.path.join(temp_parent, atlas_name, folders.DIMRED),
        annotation_dir=os.path.join(temp_parent, atlas_name, folders.ANNOTATION),
        cluster_dir=os.path.join(temp_parent, atlas_name, folders.CLUSTER),
    )
    # All three task folders are created, even for tasks not requested.
    for d in (ctx.dimred_dir, ctx.annotation_dir, ctx.cluster_dir):
        os.makedirs(d, exist_ok=True)
    # ── 4. Task-specific precomputation ────────────────────────────
    n_jobs = _resolve_n_jobs(args)
    if run_dimred:
        _precompute_dimred(adata, ctx, k_neighbors=30, n_jobs=n_jobs)
    if run_annot:
        _precompute_annot(adata, ctx, k_neighbors=90, n_jobs=n_jobs)
    if run_cluster:
        _precompute_cluster(adata, ctx, k_neighbors=30, n_jobs=n_jobs)
    # ── 5. Persist context ─────────────────────────────────────────
    save_context(ctx)
    return adata
read_atlas(atlas_info: dict) -> AnnData
¶
Read Scanpy or Cellranger data : .h5ad or .h5
| Parameters: |
|
|---|
| Returns: |
|
|---|
Source code in checkatlas/atlas.py
def read_atlas(atlas_info: dict) -> AnnData:
    """
    Load an atlas as an AnnData object.

    The reader is picked from ``atlas_info[check.ATLAS_TYPE_KEY]``:
    Cellranger >= v3 output, Cellranger < v3 output, or, for any other
    type, a Scanpy ``.h5ad`` file.

    Args:
        atlas_info (dict): info dict about the atlas (name, path and type
            entries are read)

    Returns:
        AnnData: the loaded atlas, or an empty ``dict`` when reading raises
        ``AnnDataReadError``.
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    atlas_path = atlas_info[check.ATLAS_PATH_KEY]
    logger.info(f"Load {atlas_name} in {atlas_path}")
    try:
        atlas_type = atlas_info[check.ATLAS_TYPE_KEY]
        if atlas_type == cellranger.CELLRANGER_TYPE_CURRENT:
            logger.debug(f"Read Cellranger >= v3 results {atlas_path}")
            return cellranger.read_cellranger_current(atlas_info)
        if atlas_type == cellranger.CELLRANGER_TYPE_OBSOLETE:
            logger.debug(f"Read Cellranger < v3 results {atlas_path}")
            return cellranger.read_cellranger_obsolete(atlas_info)
        logger.debug(f"Read Scanpy file {atlas_path}")
        return sc.read_h5ad(atlas_path)
    except _io.utils.AnnDataReadError:
        logger.warning(f"AnnDataReadError, cannot read: {atlas_path}")
        return dict()
checkatlas.seurat¶
checkatlas.seurat
¶
check_seurat_install() -> None
¶
Check if Seurat is installed, run installation if not
Source code in checkatlas/seurat.py
def check_seurat_install() -> None:
    """
    Check that the Seurat R packages are installed and install missing ones.

    ``Seurat`` and ``SeuratObject`` are checked. When at least one is
    missing, the ``R_LIBS_USER`` directory is created, added to R's
    library search path, and the missing packages are installed from the
    first CRAN mirror in the list.
    """
    # import R's utility package
    utils = rpackages.importr("utils")
    # select a mirror for R packages: the first one in the list
    utils.chooseCRANmirror(ind=1)
    packnames = ("Seurat", "SeuratObject")
    names_to_install = [x for x in packnames if not rpackages.isinstalled(x)]
    if names_to_install:
        # create personal library
        robjects.r("""dir.create(Sys.getenv("R_LIBS_USER"), recursive = TRUE)""")
        # Add it to the library path. Evaluate .libPaths() once and log its
        # result, instead of re-running the R call inside the log message.
        lib_paths = robjects.r(""".libPaths(Sys.getenv("R_LIBS_USER"))""")
        logger.debug(f"Set Rlibpaths: {lib_paths}")
        utils.install_packages(StrVector(names_to_install))
create_anndata_table(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None
¶
Create a table with all AnnData-like slots of a Seurat object :param seurat: :param atlas_info: :param args: :return:
Source code in checkatlas/seurat.py
def create_anndata_table(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Write a one-row table listing the AnnData-like slots of a Seurat object.

    Columns:
      - ``atlas_obs``: column names of ``seurat@meta.data``;
      - ``obsm``: names of ``seurat@reductions``;
      - ``var`` and ``varm``: always empty for Seurat;
      - ``uns``: column names of ``seurat@misc`` (empty when R returns NULL).

    Each cell is an HTML list of ``<code>`` items separated by ``<br>``.

    Args:
        seurat (RS4): Seurat object loaded through rpy2
        atlas_info (dict): info dict about the atlas
        args (argparse.Namespace): checkatlas workflow arguments
            (``args.path`` locates the output folder)
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    logger.debug(f"Create Adata table for {atlas_name}")
    csv_path = os.path.join(
        folders.get_folder(args.path, folders.ANNDATA),
        atlas_name + check.TSV_EXTENSION,
    )
    header = ["atlas_obs", "obsm", "var", "varm", "uns"]
    # R accessors for the AnnData-like slots
    r_obs = robjects.r("obs <- function(seurat){ return(colnames(seurat@meta.data))}")
    r_obsm = robjects.r("f<-function(seurat){return(names(seurat@reductions))}")
    r_uns = robjects.r("uns <- function(seurat){ return(colnames(seurat@misc))}")
    obs_list = r_obs(seurat)
    obsm_list = r_obsm(seurat)
    uns_result = r_uns(seurat)
    uns_list = [""] if isinstance(uns_result, NULLType) else uns_result

    def as_code_list(names):
        # Render names as <code> items separated by <br> for the HTML report.
        return "<code>" + "</code><br><code>".join(names) + "</code>"

    slots = {
        "atlas_obs": obs_list,
        "obsm": obsm_list,
        "var": [""],
        "varm": [""],
        "uns": uns_list,
    }
    # Build the row in one go: the former chained `df[col][row] = value`
    # assignments are silently dropped under pandas Copy-on-Write.
    df_summary = pd.DataFrame(
        {column: [as_code_list(names)] for column, names in slots.items()},
        index=[atlas_name],
        columns=header,
    )
    df_summary.to_csv(csv_path, index=False, quoting=False, sep="\t")
create_metric_annot(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None
¶
Calc annotation metrics :param seurat: :param atlas_info: :param args: :return:
Source code in checkatlas/seurat.py
def create_metric_annot(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Compute annotation metrics for a Seurat atlas and write them as TSV.

    The first key returned by ``get_viable_obs_annot`` is the reference.
    Every other viable key is compared to it with each metric of
    ``args.metric_annot`` (``metrics.METRICS_ANNOT`` when that is None).
    One row is written per compared key. Nothing is written when fewer
    than two viable keys exist.

    Args:
        seurat (RS4): Seurat object loaded through rpy2
        atlas_info (dict): info dict about the atlas
        args (argparse.Namespace): checkatlas workflow arguments
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    csv_path = os.path.join(
        folders.get_folder(args.path, folders.ANNOTATION),
        atlas_name + check.TSV_EXTENSION,
    )
    metric_annot = args.metric_annot
    if metric_annot is None:
        metric_annot = metrics.METRICS_ANNOT
    header = ["Annot_Sample", "Reference", "obs"] + metric_annot
    obs_keys = get_viable_obs_annot(seurat, args)
    # A reference plus at least one other annotation is needed to compare.
    if len(obs_keys) < 2:
        logger.debug(f"No viable obs_key was found for {atlas_name}")
        return
    logger.debug(f"Calc annotation metrics for {atlas_name}")
    ref_obs, *compared_keys = obs_keys
    rows = []
    for obs_key in compared_keys:
        row = {
            "Annot_Sample": atlas_name + "_" + obs_key,
            "Reference": ref_obs,
            "obs": obs_key,
        }
        for metric in metric_annot:
            logger.debug(
                f"Calc {metric} for {atlas_name} "
                f"with obs {obs_key} vs ref_obs {ref_obs}"
            )
            row[metric] = metrics.calc_metric_annot_seurat(
                metric, seurat, obs_key, ref_obs
            )
        rows.append(row)
    # Build the table once instead of concatenating row by row onto an
    # empty frame (quadratic, and deprecated dtype handling in pandas).
    df_annot = pd.DataFrame(rows, columns=header)
    df_annot.to_csv(csv_path, index=False, sep="\t")
create_metric_cluster(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None
¶
Calc clustering metrics :param seurat: :param atlas_info: :param args: :return:
Source code in checkatlas/seurat.py
def create_metric_cluster(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Compute clustering metrics for a Seurat atlas and write them as TSV.

    One row is written per viable obs key (see ``get_viable_obs_annot``),
    using the ``umap`` reduction as representation. When a metric returns
    a ``(value, running_time)`` tuple, the time goes to an extra
    ``<metric>_running_time`` column placed after the header columns.

    Args:
        seurat (RS4): Seurat object loaded through rpy2
        atlas_info (dict): info dict about the atlas
        args (argparse.Namespace): checkatlas workflow arguments
            (``args.metric_cluster`` lists the metrics)
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    csv_path = os.path.join(
        folders.get_folder(args.path, folders.CLUSTER),
        atlas_name + check.TSV_EXTENSION,
    )
    header = ["Clust_Sample", "obs"] + args.metric_cluster
    obs_keys = get_viable_obs_annot(seurat, args)
    obsm_key_representation = "umap"
    if len(obs_keys) == 0:
        logger.debug(f"No viable obs_key was found for {atlas_name}")
        return
    logger.debug(f"Calc clustering metrics for {atlas_name}")
    rows = []
    for obs_key in obs_keys:
        row = {
            "Clust_Sample": atlas_name + "_" + obs_key,
            "obs": obs_key,
        }
        for metric in args.metric_cluster:
            logger.debug(
                f"Calc {metric} for {atlas_name} "
                f"with obs {obs_key} and obsm {obsm_key_representation}"
            )
            result = metrics.calc_metric_cluster_seurat(
                metric, seurat, obs_key, obsm_key_representation
            )
            if isinstance(result, tuple):
                metric_value, running_time = result
                row[metric] = metric_value
                row[f"{metric}_running_time"] = running_time
            else:
                row[metric] = result
        rows.append(row)
    # Header columns first, then the optional *_running_time columns in
    # first-seen order (the layout the former row-by-row concat produced).
    extra_columns = [
        column
        for column in dict.fromkeys(key for row in rows for key in row)
        if column not in header
    ]
    df_cluster = pd.DataFrame(rows, columns=header + extra_columns)
    df_cluster.to_csv(csv_path, index=False, sep="\t")
create_metric_dimred(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None
¶
Calc dimensionality reduction metrics :param seurat: :param atlas_info: :param args: :return:
Source code in checkatlas/seurat.py
def create_metric_dimred(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Write the dimensionality-reduction metric table of a Seurat atlas.

    Dim-reduction metrics are not implemented for Seurat yet. The table
    has one row per reduction (``Dimred_Sample``, ``obsm``), and the metric
    columns of ``args.metric_dimred`` are left empty. A warning is logged
    once when metrics were requested.

    Args:
        seurat (RS4): Seurat object loaded through rpy2
        atlas_info (dict): info dict about the atlas
        args (argparse.Namespace): checkatlas workflow arguments
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    csv_path = os.path.join(
        folders.get_folder(args.path, folders.DIMRED),
        atlas_name + check.TSV_EXTENSION,
    )
    header = ["Dimred_Sample", "obsm"] + args.metric_dimred
    obsm_keys = get_viable_obsm(seurat, args)
    if len(obsm_keys) == 0:
        logger.debug(f"No viable obsm_key was found for {atlas_name}")
        return
    logger.debug(f"Calc dim red metrics for {atlas_name}")
    if args.metric_dimred:
        # Logged once rather than once per (metric, reduction) pair.
        logger.warning(
            "!!! Dim reduction metrics not available for Seurat"
            " at the moment !!!"
        )
    # TODO: compute metrics.calc_metric_dimred from the RNA counts and each
    # reduction's Embeddings() once the rpy2 conversion is implemented.
    rows = [
        {"Dimred_Sample": atlas_name + "_" + obsm_key, "obsm": obsm_key}
        for obsm_key in obsm_keys
    ]
    # Metric columns missing from the rows are filled with NaN (empty in TSV).
    df_dimred = pd.DataFrame(rows, columns=header)
    df_dimred.to_csv(csv_path, index=False, sep="\t")
create_qc_plots(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None
¶
Display the atlas QC Search for the OBS variable which correspond to the total_RNA, total_UMI, MT_ratio, RT_ratio :param seurat: :param atlas_info: :param args: :return:
Source code in checkatlas/seurat.py
def create_qc_plots(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Save a violin plot of the Seurat QC metadata columns.

    Plots, one panel each, the keys of ``SEURAT_TO_SCANPY_OBS`` that exist
    in ``seurat@meta.data``. The plot is drawn with ``VlnPlot`` and saved
    with ``ggsave``. Nothing is plotted when none of those columns exist.

    Args:
        seurat (RS4): Seurat object loaded through rpy2
        atlas_info (dict): info dict about the atlas
        args (argparse.Namespace): checkatlas workflow arguments
            (``args.path`` locates the output folder)
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    qc_path = os.path.join(
        folders.get_folder(args.path, folders.QC_FIG),
        atlas_name + check.QC_FIG_EXTENSION,
    )
    logger.debug(f"Create QC violin plot for {atlas_name}")
    importr("ggplot2")
    r_cmd = (
        "vln_plot <- function(seurat, obs, qc_path){"
        "vln <- VlnPlot(seurat, features = obs, ncol = length(obs));"
        "ggsave(qc_path, vln, width = 10, "
        "height = 4, dpi = 150)}"
    )
    r_violin = robjects.r(r_cmd)
    r_colnames = robjects.r(
        "obs <- function(seurat){ return(colnames(seurat@meta.data))}"
    )
    available = set(r_colnames(seurat))
    # Only request features that exist: VlnPlot fails when none is found and
    # ncol = length(obs) would otherwise count missing panels.
    obs_keys = [key for key in SEURAT_TO_SCANPY_OBS if key in available]
    if not obs_keys:
        logger.debug(f"No QC metadata column found for {atlas_name}")
        return
    r_violin(seurat, robjects.StrVector(obs_keys), qc_path)
create_qc_tables(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None
¶
Display the atlas QC of seurat Search for the metadata variable which correspond to the total_RNA, total_UMI, MT_ratio, RT_ratio :param seurat: :param atlas_info: :param args: :return:
Source code in checkatlas/seurat.py
def create_qc_tables(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Write the per-cell QC table of a Seurat atlas as TSV.

    Steps:
      1. read ``seurat@meta.data`` and keep the columns returned by
         ``get_viable_obs_qc``;
      2. rename them with ``SEURAT_TO_SCANPY_OBS``;
      3. for each renamed column add ``cellrank_<column>``, the rank of each
         cell (1 = highest value);
      4. sample the table with ``atlas.atlas_sampling`` and number the kept
         cells in ``atlas.CELLINDEX_HEADER``.

    Args:
        seurat (RS4): Seurat object loaded through rpy2
        atlas_info (dict): info dict about the atlas
        args (argparse.Namespace): checkatlas workflow arguments
            (``args.path`` locates the output folder)
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    qc_path = os.path.join(
        folders.get_folder(args.path, folders.QC),
        atlas_name + check.TSV_EXTENSION,
    )
    logger.debug(f"Create QC tables for {atlas_name}")
    qc_columns = get_viable_obs_qc(seurat, args)
    fetch_metadata = robjects.r("obs <- function(seurat){ return(seurat@meta.data)}")
    r_metadata = fetch_metadata(seurat)
    with (ro.default_converter + pandas2ri.converter).context():
        df_metadata = ro.conversion.get_conversion().rpy2py(r_metadata)
    qc_table = df_metadata[qc_columns]
    # Switch to the Scanpy column names.
    qc_table.columns = [SEURAT_TO_SCANPY_OBS[name] for name in qc_table.columns]
    # Snapshot the column list so the rank columns added below are not ranked.
    metric_columns = list(qc_table.columns)
    for metric in metric_columns:
        if metric == atlas.CELLINDEX_HEADER:
            continue
        # Descending sort: rank 1 goes to the cell with the highest value.
        qc_table = qc_table.sort_values(metric, ascending=False)
        qc_table.loc[:, [f"cellrank_{metric}"]] = range(1, len(qc_table) + 1)
    # Sample QC table when more cells than args.plot_celllimit are present
    qc_table = atlas.atlas_sampling(qc_table, "QC", args)
    qc_table.loc[:, [atlas.CELLINDEX_HEADER]] = range(1, len(qc_table) + 1)
    qc_table.to_csv(qc_path, index=False, quoting=False, sep="\t")
create_summary_table(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None
¶
Create a table with all interesting variables :param seurat: :param atlas_info: :param args: :return:
Source code in checkatlas/seurat.py
def create_summary_table(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Write the one-row summary table of a Seurat atlas as TSV.

    Columns: file type, number of cells (``ncol``) and genes (``nrow``) of
    the Seurat object, fixed ``AnnData.raw`` = False / ``AnnData.X`` = True
    flags, file extension and file path.

    Args:
        seurat (RS4): Seurat object loaded through rpy2
        atlas_info (dict): info dict about the atlas (type, extension and
            path entries are copied into the table)
        args (argparse.Namespace): checkatlas workflow arguments
            (``args.path`` locates the output folder)
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    logger.debug(f"Create Summary table for {atlas_name}")
    csv_path = os.path.join(
        folders.get_folder(args.path, folders.SUMMARY),
        atlas_name + check.TSV_EXTENSION,
    )
    header = [
        "AtlasFileType",
        "NbCells",
        "NbGenes",
        "AnnData.raw",
        "AnnData.X",
        "File_extension",
        "File_path",
    ]
    r_nrow = robjects.r["nrow"]
    r_ncol = robjects.r["ncol"]
    summary = {
        "AtlasFileType": atlas_info[check.ATLAS_TYPE_KEY],
        # ncol of the Seurat object gives the cells, nrow the genes.
        "NbCells": r_ncol(seurat)[0],
        "NbGenes": r_nrow(seurat)[0],
        # Fixed values for Seurat atlases; the object is not inspected.
        "AnnData.raw": False,
        "AnnData.X": True,
        "File_extension": atlas_info[check.ATLAS_EXTENSION_KEY],
        "File_path": atlas_info[check.ATLAS_PATH_KEY],
    }
    # Build the row in one go: the former chained `df[col][row] = value`
    # assignments are silently dropped under pandas Copy-on-Write.
    df_summary = pd.DataFrame([summary], index=[atlas_name], columns=header)
    df_summary.to_csv(csv_path, index=False, sep="\t")
create_tsne_fig(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None
¶
Display the TSNE of celltypes Search for the OBS variable which correspond to the celltype annotation :param seurat: :param atlas_info: :param args: :return:
Source code in checkatlas/seurat.py
def create_tsne_fig(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Save a t-SNE plot of the Seurat atlas coloured by cell annotation.

    The figure is produced only when ``names(seurat)`` contains a name with
    ``tsne`` in it and at least one viable annotation exists. The first key
    returned by ``get_viable_obs_annot`` is used for colouring. A reduction
    literally named ``tsne`` is preferred; otherwise the first matching
    name is used.

    Args:
        seurat (RS4): Seurat object loaded through rpy2
        atlas_info (dict): info dict about the atlas
        args (argparse.Namespace): checkatlas workflow arguments
            (``args.path`` locates the output folder)
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    r_names = robjects.r["names"]
    # Names of the Seurat object containing "tsne" (case-sensitive).
    tsne_names = [str(name) for name in r_names(seurat) if "tsne" in str(name)]
    importr("ggplot2")
    if not tsne_names:
        return
    reduction = "tsne" if "tsne" in tsne_names else tsne_names[0]
    logger.debug(f"Create t-SNE figure for {atlas_name}")
    tsne_path = os.path.join(
        folders.get_folder(args.path, folders.TSNE),
        atlas_name + check.TSNE_EXTENSION,
    )
    obs_keys = get_viable_obs_annot(seurat, args)
    if not obs_keys:
        # Previously an IndexError on obs_keys[0].
        logger.debug(f"No viable obs_key was found for {atlas_name}")
        return
    r_cmd = (
        "tsne <- function(seurat, obs_key, tsne_path, reduc){"
        "tsne_plot <- DimPlot(seurat, group.by = obs_key, "
        "reduction = reduc);"
        "ggsave(tsne_path, tsne_plot, width = 10, "
        "height = 6, dpi = 76)}"
    )
    r_tsne = robjects.r(r_cmd)
    r_tsne(seurat, obs_keys[0], tsne_path, reduction)
create_umap_fig(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None
¶
Display the UMAP of celltypes Search for the OBS variable which correspond to the celltype annotation :param seurat: :param atlas_info: :param args: :return:
Source code in checkatlas/seurat.py
def create_umap_fig(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Save a UMAP plot of the Seurat atlas coloured by cell annotation.

    The figure is produced only when ``names(seurat)`` contains a name with
    ``umap`` in it and at least one viable annotation exists. The first key
    returned by ``get_viable_obs_annot`` is used for colouring. A reduction
    literally named ``umap`` is preferred; otherwise the first matching
    name is used.

    Args:
        seurat (RS4): Seurat object loaded through rpy2
        atlas_info (dict): info dict about the atlas
        args (argparse.Namespace): checkatlas workflow arguments
            (``args.path`` locates the output folder)
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    r_names = robjects.r["names"]
    # Names of the Seurat object containing "umap" (case-sensitive).
    umap_names = [str(name) for name in r_names(seurat) if "umap" in str(name)]
    importr("ggplot2")
    if not umap_names:
        return
    reduction = "umap" if "umap" in umap_names else umap_names[0]
    logger.debug(f"Create UMAP figure for {atlas_name}")
    umap_path = os.path.join(
        folders.get_folder(args.path, folders.UMAP),
        atlas_name + check.UMAP_EXTENSION,
    )
    obs_keys = get_viable_obs_annot(seurat, args)
    if not obs_keys:
        # Previously an IndexError on obs_keys[0].
        logger.debug(f"No viable obs_key was found for {atlas_name}")
        return
    r_cmd = (
        "umap <- function(seurat, obs_key, umap_path, reduc){"
        "umap_plot <- DimPlot(seurat, group.by = obs_key, "
        "reduction = reduc);"
        "ggsave(umap_path, umap_plot, width = 10, "
        "height = 6, dpi = 76)}"
    )
    r_umap = robjects.r(r_cmd)
    r_umap(seurat, obs_keys[0], umap_path, reduction)
get_viable_obs_annot(seurat: RS4, args: argparse.Namespace) -> list
¶
Search in obs_keys a match to OBS_CLUSTERS values ! Remove obs_key with only one category ! Return the matching obs_keys sorted alphabetically
| Parameters: |
|
|---|
| Returns: |
|
|---|
Source code in checkatlas/seurat.py
def get_viable_obs_annot(seurat: RS4, args: argparse.Namespace) -> list:
    """
    Select the Seurat metadata columns usable as annotations/clusterings.

    A column of ``seurat@meta.data`` is kept when:
      - its name contains at least one substring of ``args.obs_cluster``;
      - its values are an R factor;
      - that factor does not have exactly one level.

    Each column is reported once, even when it matches several
    ``args.obs_cluster`` entries.

    Args:
        seurat (RS4): Seurat object loaded through rpy2
        args (argparse.Namespace): checkatlas workflow arguments

    Returns:
        list: selected column names, sorted alphabetically
    """
    r_obs = robjects.r("obs <- function(seurat){ return(colnames(seurat@meta.data))}")
    r_annot = robjects.r(
        "type <- function(seurat, obs_key){ " "return(seurat[[obs_key]][[obs_key]])}"
    )
    obs_keys_final = list()
    for obs_key in r_obs(seurat):
        # any() rather than a nested loop: a key matching several patterns
        # used to be appended (and later evaluated) once per pattern.
        if not any(pattern in obs_key for pattern in args.obs_cluster):
            continue
        annotations = r_annot(seurat, obs_key)
        if not isinstance(annotations, FactorVector):
            continue
        # Remove keys with only one category
        if len(annotations.levels) == 1:
            continue
        logger.debug(f"Add obs_key {obs_key} with cat {annotations.levels}")
        obs_keys_final.append(obs_key)
    return sorted(obs_keys_final)
get_viable_obs_qc(seurat: RS4, args: argparse.Namespace) -> list
¶
Search in obs_keys a match to OBS_QC values Return the matching obs_keys in the same order as OBS_QC
| Parameters: |
|
|---|
| Returns: |
|
|---|
Source code in checkatlas/seurat.py
def get_viable_obs_qc(seurat: RS4, args: argparse.Namespace) -> list:
    """
    Map the requested QC columns to Seurat metadata columns that exist.

    Each name of ``args.qc_display`` (Scanpy naming) is translated with
    ``SCANPY_TO_SEURAT_OBS``. Names without a Seurat equivalent, and names
    whose Seurat column is absent from ``seurat@meta.data``, are skipped.

    Args:
        seurat (RS4): Seurat object loaded through rpy2
        args (argparse.Namespace): checkatlas workflow arguments

    Returns:
        list: Seurat metadata column names, in ``args.qc_display`` order
    """
    r_obs = robjects.r("obs <- function(seurat){ return(colnames(seurat@meta.data))}")
    # Query R once instead of once per requested QC name.
    meta_columns = set(r_obs(seurat))
    obs_keys = list()
    for obs_qc in args.qc_display:
        seurat_name = SCANPY_TO_SEURAT_OBS.get(obs_qc)
        if seurat_name is None:
            # Previously a KeyError aborted the whole QC step.
            logger.debug(f"No Seurat equivalent for QC column {obs_qc}")
            continue
        if seurat_name in meta_columns:
            obs_keys.append(seurat_name)
    return obs_keys
get_viable_obsm(seurat, args)
¶
Search viable obsm for dimensionality reduction metric calc. ! No filter on obsm is applied for now ! :param seurat: :param args: :return:
Source code in checkatlas/seurat.py
def get_viable_obsm(seurat, args):
    """
    List the reductions of a Seurat object for dim-reduction metrics.

    No filter is applied for now: every name in
    ``names(seurat@reductions)`` is returned.

    Args:
        seurat (RS4): Seurat object loaded through rpy2
        args (argparse.Namespace): checkatlas workflow arguments
            (currently unused)

    Returns:
        list: reduction names
    """
    r_obsm = robjects.r("f<-function(seurat){return(names(seurat@reductions))}")
    # Copy into a Python list; the former per-key print() to stdout is gone.
    obsm_keys = list(r_obsm(seurat))
    logger.debug(f"Add obsm {obsm_keys}")
    return obsm_keys
read_atlas(atlas_info: dict) -> RS4
¶
Read Seurat object in python using rpy2
| Parameters: |
|
|---|
| Returns: |
|
|---|
Source code in checkatlas/seurat.py
def read_atlas(atlas_info: dict) -> RS4:
    """Read a Seurat object saved as ``.rds`` in python using rpy2.

    Args:
        atlas_info (dict): info dict about the atlas (name and path entries
            are read)

    Returns:
        RS4: the Seurat object, or None when the file holds an R object of
        another class.
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    atlas_path = atlas_info[check.ATLAS_PATH_KEY]
    logger.info(f"Load {atlas_name} in " f"{atlas_path}")
    # Pass the path as an argument instead of formatting it into R source:
    # quotes or backslashes in the path (e.g. Windows paths) would break or
    # alter the evaluated R code.
    seurat = robjects.r["readRDS"](atlas_path)
    rclass = robjects.r["class"]
    if rclass(seurat)[0] == "Seurat":
        importr("Seurat")
        return seurat
    logger.info(f"{atlas_name} is not a Seurat object")
    return None
checkatlas.cellranger¶
checkatlas.cellranger
¶
read_cellranger_current(atlas_info: dict) -> AnnData
¶
Read cellranger files.
Load first /outs/filtered_feature_bc_matrix.h5 Then add (if found): - Clustering - PCA - UMAP - TSNE Args: atlas_info (dict): info on the atlas
| Returns: |
|
|---|
Source code in checkatlas/cellranger.py
def read_cellranger_current(atlas_info: dict) -> AnnData:
    """
    Read Cellranger >= v3 files.

    Load the 10x ``.h5`` matrix at ``atlas_info[check.ATLAS_PATH_KEY]``
    (typically in ``outs/``), then add, if found in the ``analysis`` folder
    next to it:
    - Clustering: ``obs["cellranger_graphclust"]`` and
      ``obs["cellranger_kmeans_<k>"]``
    - PCA: ``obsm["X_pca"]``
    - UMAP: ``obsm["X_umap"]``
    - TSNE: ``obsm["X_tsne"]``

    Multiome outputs (``analysis/dimensionality_reduction/gex``) are
    handled, and ATAC results are skipped.

    Args:
        atlas_info (dict): info on the atlas

    Returns:
        AnnData: scanpy object from cellranger
    """
    cellranger_out_path = os.path.dirname(atlas_info[check.ATLAS_PATH_KEY])
    cellranger_analysis_path = os.path.join(cellranger_out_path, "analysis")
    cellranger_clust_path = os.path.join(cellranger_analysis_path, "clustering")
    cellranger_umap_path = os.path.join(cellranger_analysis_path, "umap")
    cellranger_tsne_path = os.path.join(cellranger_analysis_path, "tsne")
    cellranger_pca_path = os.path.join(cellranger_analysis_path, "pca")
    # Search graphclust
    # ATAC clusterings are ignored. `break` only leaves the inner loop, so a
    # match found later in the os.walk traversal replaces an earlier one.
    graphclust_path = ""
    for root, dirs, files in os.walk(cellranger_clust_path):
        for dir in dirs:
            if dir.endswith("graphclust"):
                cluster_path = os.path.join(root, dir, "clusters.csv")
                if os.path.exists(cluster_path) and not root.endswith("atac"):
                    graphclust_path = cluster_path
                    break
    # Search kmeans
    # The first directory matching either pattern below wins: found_kmeans
    # keeps later matches from overriding it.
    kmeans_path = ""
    k_value = 0
    found_kmeans = False
    for root, dirs, files in os.walk(cellranger_clust_path):
        for dir in dirs:
            # Search the highest kmeans = 10
            dir_prefix = "kmeans_10"
            if dir_prefix in dir and not found_kmeans:
                cluster_path = os.path.join(root, dir, "clusters.csv")
                if os.path.exists(cluster_path):
                    kmeans_path = cluster_path
                    k_value = 10
                    found_kmeans = True
                    break
            # Or search the highest kmeans = 5 (for multiome atlas)
            dir_prefix = os.path.join("gex", "kmeans_5")
            if dir_prefix in os.path.join(root, dir) and not found_kmeans:
                cluster_path = os.path.join(root, dir, "clusters.csv")
                if os.path.exists(cluster_path):
                    kmeans_path = cluster_path
                    k_value = 5
                    found_kmeans = True
                    break
    # Search umap
    # For UMAP, t-SNE and PCA: ATAC projections are skipped and, since `break`
    # only leaves the file loop, the last match in walk order is kept.
    rna_umap = ""
    for root, dirs, files in os.walk(cellranger_umap_path):
        for file in files:
            if file.endswith("projection.csv") and not root.endswith("atac"):
                rna_umap = os.path.join(root, file)
                break
    # Search t-SNE
    rna_tsne = ""
    for root, dirs, files in os.walk(cellranger_tsne_path):
        for file in files:
            if file.endswith("projection.csv") and not root.endswith("atac"):
                rna_tsne = os.path.join(root, file)
                break
    rna_pca = ""
    for root, dirs, files in os.walk(cellranger_pca_path):
        for file in files:
            if file.endswith("projection.csv") and not root.endswith("atac"):
                rna_pca = os.path.join(root, file)
                break
    # Manage multiome cellranger files
    # These paths override the ones found above without an existence check;
    # existence is checked before each file is read below.
    dim_red_path = os.path.join(cellranger_analysis_path, "dimensionality_reduction")
    if os.path.exists(dim_red_path):
        gex_path = os.path.join(dim_red_path, "gex")
        if os.path.exists(gex_path):
            rna_umap = os.path.join(gex_path, "umap_projection.csv")
            rna_tsne = os.path.join(gex_path, "tsne_projection.csv")
            rna_pca = os.path.join(gex_path, "pca_projection.csv")
    # Read 10x h5 file
    adata = sc.read_10x_h5(atlas_info[check.ATLAS_PATH_KEY])
    adata.var_names_make_unique()
    # Add cluster
    # Cluster tables are indexed by barcode (first CSV column); assignment
    # aligns them on adata.obs_names.
    if os.path.exists(graphclust_path):
        df_cluster = pd.read_csv(graphclust_path, index_col=0)
        adata.obs["cellranger_graphclust"] = df_cluster["Cluster"]
    if os.path.exists(kmeans_path):
        df_cluster = pd.read_csv(kmeans_path, index_col=0)
        adata.obs["cellranger_kmeans_" + str(k_value)] = df_cluster["Cluster"]
    # Add reduction
    # NOTE(review): unlike the Cellranger < v3 reader, projections are not
    # checked to have one row per cell before assignment — confirm intended.
    if os.path.exists(rna_umap):
        df_umap = pd.read_csv(rna_umap, index_col=0)
        adata.obsm["X_umap"] = df_umap
    if os.path.exists(rna_tsne):
        df_tsne = pd.read_csv(rna_tsne, index_col=0)
        adata.obsm["X_tsne"] = df_tsne
    if os.path.exists(rna_pca):
        df_pca = pd.read_csv(rna_pca, index_col=0)
        adata.obsm["X_pca"] = df_pca
    return adata
read_cellranger_obsolete(atlas_info: dict) -> AnnData
¶
Read cellranger files.
Load first the matrix.mtx folder (e.g. outs/filtered_matrices_mex/&lt;genome&gt;/) Then add (if found): - Clustering - PCA - UMAP - TSNE Args: atlas_info (dict): info on the atlas
| Returns: |
|
|---|
Source code in checkatlas/cellranger.py
def read_cellranger_obsolete(atlas_info: dict) -> AnnData:
    """
    Read Cellranger < v3 files.

    Load the ``matrix.mtx`` folder of ``atlas_info[check.ATLAS_PATH_KEY]``
    with ``sc.read_10x_mtx``. Then add, if found under the Cellranger
    ``outs`` folder two levels above the matrix folder:

    - Clustering: ``obs["cellranger_graphclust"]`` and
      ``obs["cellranger_kmeans_<k>"]`` (highest k between 15 and 3)
    - UMAP, t-SNE, PCA from ``analysis_csv``: ``obsm["X_umap"]``,
      ``obsm["X_tsne"]``, ``obsm["X_pca"]``, each only when the
      projection has exactly one row per loaded cell

    Args:
        atlas_info (dict): info on the atlas

    Returns:
        AnnData: scanpy object from cellranger
    """
    import logging

    log = logging.getLogger(__name__)
    cellranger_path = atlas_info[check.ATLAS_PATH_KEY].replace(CELLRANGER_MATRIX_FILE, "")
    cellranger_out_path = os.path.join(cellranger_path, os.pardir, os.pardir)
    cellranger_analysis_path = os.path.join(cellranger_out_path, "analysis_csv")
    cellranger_umap_path = os.path.join(cellranger_analysis_path, "umap")
    cellranger_tsne_path = os.path.join(cellranger_analysis_path, "tsne")
    cellranger_pca_path = os.path.join(cellranger_analysis_path, "pca")
    # Debug output goes through logging instead of print() to stdout.
    log.debug(f"Cellranger outs folder: {cellranger_out_path}")
    log.debug(f"Cellranger analysis folder: {cellranger_analysis_path}")
    log.debug(f"Cellranger UMAP folder: {cellranger_umap_path}")
    # Search graphclust (a later match in walk order replaces an earlier one)
    graphclust_path = ""
    for root, dirnames, _ in os.walk(cellranger_out_path):
        for dirname in dirnames:
            if dirname.endswith("graphclust"):
                cluster_path = os.path.join(root, dirname, "clusters.csv")
                if os.path.exists(cluster_path):
                    graphclust_path = cluster_path
                    break
    # Search kmeans: in each "*kmeans" folder take the highest k from 15 to 3
    kmeans_path = ""
    k_value = 0
    for root, dirnames, _ in os.walk(cellranger_out_path):
        for dirname in dirnames:
            if dirname.endswith("kmeans"):
                for k in reversed(range(3, 16)):
                    cluster_path = os.path.join(
                        root, dirname, str(k) + "_clusters", "clusters.csv"
                    )
                    if os.path.exists(cluster_path):
                        kmeans_path = cluster_path
                        k_value = k
                        break
    rna_umap = os.path.join(cellranger_umap_path, "projection.csv")
    rna_tsne = os.path.join(cellranger_tsne_path, "projection.csv")
    rna_pca = os.path.join(cellranger_pca_path, "projection.csv")
    # get matrix folder
    matrix_folder = os.path.dirname(atlas_info[check.ATLAS_PATH_KEY])
    adata = sc.read_10x_mtx(matrix_folder)
    adata.var_names_make_unique()
    # Add cluster
    if os.path.exists(graphclust_path):
        df_cluster = pd.read_csv(graphclust_path, index_col=0)
        adata.obs["cellranger_graphclust"] = df_cluster["Cluster"]
    if os.path.exists(kmeans_path):
        df_cluster = pd.read_csv(kmeans_path, index_col=0)
        adata.obs["cellranger_kmeans_" + str(k_value)] = df_cluster["Cluster"]
    # Add reduction. Every projection, UMAP included (it previously lacked
    # this guard and could crash the obsm assignment), is only added when
    # it has one row per loaded cell.
    projections = (
        ("X_umap", rna_umap),
        ("X_tsne", rna_tsne),
        ("X_pca", rna_pca),
    )
    for obsm_key, projection_path in projections:
        if os.path.exists(projection_path):
            df_projection = pd.read_csv(projection_path, index_col=0)
            if len(df_projection) == len(adata):
                adata.obsm[obsm_key] = df_projection
    return adata
checkatlas.metrics.metrics¶
checkatlas.metrics.metrics
¶
annotation_to_num(annotation, ref_annotation)
¶
Transforms the annotations from categorical to numerical
Parameters¶
annotation ref_annotation
Returns¶
Source code in checkatlas/metrics/metrics.py
def annotation_to_num(annotation, ref_annotation):
    """
    Transform two categorical annotations into integer labels.

    Each input is encoded with its own ``LabelEncoder``, so the same label
    can receive different codes in the two outputs.

    Parameters
    ----------
    annotation
        Annotation to encode; anything exposing ``to_numpy()`` (e.g. a
        pandas Series).
    ref_annotation
        Reference annotation to encode, same requirements.

    Returns
    -------
    tuple
        ``(encoded_annotation, encoded_ref_annotation)`` as returned by
        ``LabelEncoder.fit_transform``.
    """
    values = annotation.to_numpy()
    ref_values = ref_annotation.to_numpy()
    encoded = LabelEncoder().fit_transform(values)
    encoded_ref = LabelEncoder().fit_transform(ref_values)
    return encoded, encoded_ref
cal_annot(adata, atlas_name=None, metric_list=None, all=False, file_dir=None, n_jobs=-1, verbose=True, preprocess_context=None)
¶
Comprehensive annotation pipeline for all annotation metrics.
| Parameters: |
|
|---|
| Returns: |
|
|---|
Source code in checkatlas/metrics/metrics.py
def _annot_call_kwargs(func, n_jobs, **optional):
    """Select the keyword arguments that a metric ``run`` function accepts.

    ``n_jobs`` and ``verbose=False`` are always offered, together with any
    extra ``optional`` keywords. Only names that appear explicitly in
    ``func``'s signature are kept, so metrics with different signatures can
    be driven uniformly.
    """
    import inspect

    accepted = inspect.signature(func).parameters
    offered = {"n_jobs": n_jobs, "verbose": False, **optional}
    return {name: value for name, value in offered.items() if name in accepted}


def _annot_record(atlas_name, metric, input1, input2, value, start_time):
    """Build one long-format result row for :func:`cal_annot`."""
    return {
        "Atlas Name": atlas_name,
        "Metric Name": metric,
        "Reference/Input 1": input1,
        "Prediction/Input 2": input2,
        "Value": value,
        "Time (s)": round(time.time() - start_time, 3),
    }


def _annot_dense_float64(matrix):
    """Return ``matrix`` as a dense float64 ndarray.

    Sparse input is densified with ``toarray`` *before* the cast, because
    ``np.asarray(sparse_matrix, dtype=np.float64)`` does not densify a
    scipy sparse matrix.
    """
    if hasattr(matrix, "toarray"):
        matrix = matrix.toarray()
    return np.asarray(matrix, dtype=np.float64)


def _annot_run_metric(metric_module, data, labels, neighbors, n_jobs):
    """Run a metric from precomputed neighbours when given, else from ``data``."""
    if neighbors is not None:
        return metric_module.run_with_neighbors(neighbors, labels, verbose=False)
    return metric_module.run(
        data, labels, **_annot_call_kwargs(metric_module.run, n_jobs)
    )


def _annot_load_context_knn(preprocess_context, embedding_keys):
    """Load the per-embedding kNN results cached by the preprocess step.

    Returns a dict ``embedding key -> NeighborResults``; embeddings without a
    cached kNN file (or whose file cannot be loaded) are left out.
    """
    from ._neighbors import NeighborResults

    emb_nn = {}
    for emb in embedding_keys:
        safe = emb.replace("/", "_").replace(" ", "_")
        if emb in preprocess_context.knn_paths or safe in preprocess_context.knn_paths:
            from ._cache import load_knn

            loaded = load_knn(preprocess_context.annotation_dir, f"knn_{safe}")
            if loaded is not None:
                emb_nn[emb] = NeighborResults(indices=loaded[0], distances=loaded[1])
    return emb_nn


def _annot_inject_neighbor_graphs(adata, preprocess_context):
    """Copy the context's precomputed neighbour graphs into ``adata``.

    Existing ``adata.uns`` / ``adata.obsp`` entries are never overwritten.

    Returns:
        dict: embedding key -> ``adata.uns`` key of a graph that can be used
        directly, i.e. whose uns entry names a connectivities matrix that is
        present in ``adata.obsp``.
    """
    usable = {}
    for emb, payload in preprocess_context.neighbor_graphs.items():
        key_added = payload.get("key_added", f"neighbors_{emb}")
        uns_entry = payload.get("uns_entry", {})
        if key_added not in adata.uns:
            adata.uns[key_added] = uns_entry
        conn_key = uns_entry.get("connectivities_key", "connectivities")
        dist_key = uns_entry.get("distances_key", "distances")
        # The payload stores the matrices under the fixed names
        # "connectivities"/"distances"; test for the key that is actually read
        # (the obsp key is usually prefixed with key_added).
        if "connectivities" in payload and conn_key not in adata.obsp:
            adata.obsp[conn_key] = payload["connectivities"]
        if "distances" in payload and dist_key not in adata.obsp:
            adata.obsp[dist_key] = payload["distances"]
        entry = adata.uns.get(key_added)
        conn_name = entry.get("connectivities_key") if hasattr(entry, "get") else None
        if conn_name is not None and conn_name in adata.obsp:
            usable[emb] = key_added
    return usable


def _annot_load_precomputed_dists(adata, preprocess_context, embedding_keys):
    """Load per-embedding distance matrices cached in the cluster directory.

    ``dist_<emb>.tri`` is opened read-only as a ``TriangularMatrix``;
    otherwise ``dist_<emb>.npy`` is loaded with ``np.load``. Embeddings with
    neither file are left out.
    """
    precomputed = {}
    for emb in embedding_keys:
        stem = f"dist_{emb.replace('/', '_').replace(' ', '_')}"
        tri_path = os.path.join(preprocess_context.cluster_dir, f"{stem}.tri")
        npy_path = os.path.join(preprocess_context.cluster_dir, f"{stem}.npy")
        if os.path.exists(tri_path):
            from ._triangular import TriangularMatrix

            n_cells = adata.obsm[emb].shape[0] if emb in adata.obsm else adata.n_obs
            precomputed[emb] = TriangularMatrix(n=n_cells, filepath=tri_path, mode="r")
        elif os.path.exists(npy_path):
            precomputed[emb] = np.load(npy_path)
    return precomputed


def cal_annot(
    adata,
    atlas_name=None,
    metric_list=None,
    all=False,
    file_dir=None,
    n_jobs=-1,
    verbose=True,
    preprocess_context=None,
):
    """
    Comprehensive annotation pipeline for all annotation metrics.

    Args:
        adata (AnnData): Annotated data matrix.
        atlas_name (str, optional): Value of the "Atlas Name" column and file
            stem of the wide TSV written to ``file_dir``.
        metric_list (list): List of metric names to calculate.
            If provided, overrides `all` parameter. Names that are not in
            ``METRICS_ANNOT`` are ignored.
        all (bool): If True, calculate all available annotation metrics.
            If False, calculate a default subset
            (adj_rand_index, normalized_mutual_info, adj_mutual_info).
            Ignored if metric_list is provided.
        file_dir (str): Directory where ``<atlas_name>.tsv`` (MultiQC wide
            format) is written. If None, the current working directory is
            used. Nothing is written when ``atlas_name`` is None or no
            result was produced.
        n_jobs (int): Number of parallel jobs (-1 = all cores).
        verbose (bool): Whether to print the context summary and show the
            progress bar.
        preprocess_context (PreprocessContext, optional): If provided,
            column detection and kNN precomputation are skipped and
            the context's precomputed data (kNN graphs, neighbour
            graphs, distance matrices, batch keys) is reused.

    Returns:
        pd.DataFrame: Long-format results with columns
            [Atlas Name, Metric Name, Reference/Input 1, Prediction/Input 2, Value, Time (s)].
            The ``lisi`` metric is reported as ``iLISI`` rows (per batch
            column) and ``cLISI`` rows (per label column). Bio-conservation
            metrics are skipped because they need an ``adata_before``.
    """
    from ..utils.col_detector import CheckAtlasColumnDetector
    from ._jax_utils import _GPU_AVAILABLE as _cal_gpu
    from ._jax_utils import _JAX_AVAILABLE as _cal_jax
    from ._neighbors import compute_neighbors as _cal_knn

    use_jax = _cal_jax and _cal_gpu
    lisi_k = 90  # neighbours built once per embedding for LISI
    kbet_k = 25  # neighbours used by kBET

    if file_dir is None:
        file_dir = os.getcwd()
    else:
        os.makedirs(file_dir, exist_ok=True)

    def as_list(keys):
        return [] if keys is None else list(keys)

    # ── Resolve the obs / obsm keys to evaluate ──
    if preprocess_context is not None:
        ref_keys = as_list(preprocess_context.ref_keys)
        pred_keys = as_list(preprocess_context.pred_keys)
        embedding_keys = as_list(
            getattr(preprocess_context, "annotation_embedding_keys", None)
            or preprocess_context.embedding_keys
        )
        batch_keys = as_list(preprocess_context.batch_keys)
    else:
        detector = CheckAtlasColumnDetector(adata)
        params = detector.detect_all_parameters()
        ref_keys = [x[0] for x in params["annotation"]["reference"]]
        pred_keys = [x[0] for x in params["annotation"]["predicted"]]
        embedding_keys = [x[0] for x in params["clustering"]["embeddings"]]
        batch_keys = [x[0] for x in params.get("batch", [])]
    if not batch_keys:
        batch_keys = [col for col in adata.obs.columns if "batch" in col.lower()]
    # Ordered de-duplication: set() would make the row order vary between runs.
    label_keys = list(dict.fromkeys(ref_keys + pred_keys))

    # ── Precomputed data (only available through a preprocess context) ──
    emb_nn = {}  # embedding key -> NeighborResults reused by LISI / kBET
    precomputed_dists = {}  # embedding key -> cached distance matrix
    context_graphs = {}  # embedding key -> adata.uns key of a reusable graph
    if preprocess_context is not None:
        if verbose:
            print("Using precomputed context — skipping column detection")
            print(f" Reference keys: {ref_keys}")
            print(f" Predicted keys: {pred_keys}")
            print(f" Embedding keys: {embedding_keys}")
            print(f" Batch keys: {batch_keys}")
        emb_nn = _annot_load_context_knn(preprocess_context, embedding_keys)
        context_graphs = _annot_inject_neighbor_graphs(adata, preprocess_context)
        precomputed_dists = _annot_load_precomputed_dists(
            adata, preprocess_context, embedding_keys
        )

    # ── Metrics to run ──
    if metric_list is not None:
        metrics_list = metric_list
    elif all:
        metrics_list = METRICS_ANNOT
    else:
        metrics_list = ["adj_rand_index", "normalized_mutual_info", "adj_mutual_info"]
    metrics_list = [m for m in metrics_list if m in METRICS_ANNOT]

    # Metric categories by input requirements
    ref_pred_metrics = {
        "adj_mutual_info",
        "adj_rand_index",
        "fowlkes_mallows",
        "isolated_f1_score",
        "mutual_info",
        "normalized_mutual_info",
        "rand_index",
        "vmeasure",
    }
    emb_label_metrics = {"average_silhouette_width", "dunn_index"}
    batch_metrics = {"kbet", "pcr"}  # lisi is special (iLISI vs cLISI)
    graph_metrics = {"graph_connectivity"}
    # Need (adata_before, adata_after): not supported in this single-adata pipeline.
    bio_metrics = {"cell_cycle_conservation", "highly_variable_genes"}

    # Without any embedding, LISI / kBET / PCR fall back to adata.X (reported as "X").
    inputs = embedding_keys if embedding_keys else ["X"]

    def input_matrix(name):
        return adata.obsm[name] if embedding_keys else adata.X

    results = []
    pbar = tqdm(
        metrics_list,
        desc="Calculating Annotation Metrics",
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
        disable=not verbose,
    )
    for metric in pbar:
        metric_start_time = time.time()
        pbar.set_description(f"Processing: {metric}")
        try:
            metric_module = getattr(annot, metric)

            # 1. Reference vs predicted labels
            if metric in ref_pred_metrics:
                for ref in ref_keys:
                    for pred in pred_keys:
                        if ref == pred:
                            continue
                        try:
                            l_pred, l_true = annotation_to_num(
                                adata.obs[pred], adata.obs[ref]
                            )
                            start = time.time()
                            val = metric_module.run(
                                l_pred,
                                l_true,
                                **_annot_call_kwargs(metric_module.run, n_jobs),
                            )
                            results.append(
                                _annot_record(atlas_name, metric, ref, pred, val, start)
                            )
                        except Exception as e:
                            logger.warning(
                                f"Failed to calculate {metric} for {ref} vs {pred}: {e}"
                            )

            # 2. Embedding + labels (ASW, Dunn)
            elif metric in emb_label_metrics:
                for emb in embedding_keys:
                    if emb not in adata.obsm:
                        continue
                    X_emb = adata.obsm[emb]
                    optional = {}
                    if emb in precomputed_dists:
                        optional["precomputed_dists"] = precomputed_dists[emb]
                    for label in label_keys:
                        try:
                            start = time.time()
                            kw = _annot_call_kwargs(metric_module.run, n_jobs, **optional)
                            val = metric_module.run(X_emb, adata.obs[label], **kw)
                            results.append(
                                _annot_record(atlas_name, metric, emb, label, val, start)
                            )
                        except Exception as e:
                            logger.warning(
                                f"Failed to calculate {metric} for {emb} vs {label}: {e}"
                            )

            # 3. LISI: iLISI over batch columns, cLISI over label columns
            elif metric == "lisi":
                # Build the kNN once per embedding on GPU, unless neighbours
                # were already loaded from the preprocess context.
                if not emb_nn and use_jax:
                    for emb in embedding_keys:
                        if emb not in adata.obsm:
                            continue
                        try:
                            emb_nn[emb] = _cal_knn(
                                _annot_dense_float64(adata.obsm[emb]),
                                n_neighbors=lisi_k,
                                backend="auto",
                            )
                        except Exception as e:
                            logger.warning(
                                f"GPU kNN failed for {emb}, using lisi.run instead: {e}"
                            )
                for lisi_name, columns in (("iLISI", batch_keys), ("cLISI", label_keys)):
                    for column in columns:
                        for emb in inputs:
                            try:
                                data = input_matrix(emb)
                                start = time.time()
                                val = _annot_run_metric(
                                    metric_module,
                                    data,
                                    adata.obs[column],
                                    emb_nn.get(emb),
                                    n_jobs,
                                )
                                results.append(
                                    _annot_record(
                                        atlas_name, lisi_name, emb, column, val, start
                                    )
                                )
                            except Exception as e:
                                logger.warning(
                                    f"Failed to calculate {lisi_name} for {emb} vs {column}: {e}"
                                )

            # 4. Batch metrics (kBET, PCR) — per (embedding × batch)
            elif metric in batch_metrics:
                use_knn = metric == "kbet" and hasattr(metric_module, "run_with_neighbors")
                for batch in batch_keys:
                    for emb in inputs:
                        try:
                            data = input_matrix(emb)
                            start = time.time()
                            nn = None
                            if use_knn:
                                nn = emb_nn.get(emb)
                                if nn is not None:
                                    if nn.n_neighbors > kbet_k:
                                        nn = nn.subset_neighbors(kbet_k)
                                elif use_jax:
                                    nn = _cal_knn(
                                        _annot_dense_float64(data),
                                        n_neighbors=kbet_k,
                                        backend="auto",
                                    )
                            val = _annot_run_metric(
                                metric_module, data, adata.obs[batch], nn, n_jobs
                            )
                            results.append(
                                _annot_record(atlas_name, metric, emb, batch, val, start)
                            )
                        except Exception as e:
                            logger.warning(
                                f"Failed to calculate {metric} for {emb} vs {batch}: {e}"
                            )

            # 5. Graph connectivity (neighbour graph + labels)
            elif metric in graph_metrics:
                for emb in embedding_keys or [None]:
                    if emb is None:
                        # No embedding: the metric builds its default graph.
                        input_name, graph_kw = "Default", {}
                    elif emb in context_graphs:
                        # Reuse the graph injected from the preprocess context.
                        input_name = emb
                        graph_kw = {"neighbors_key": context_graphs[emb]}
                    else:
                        key_added = f"neighbors_{emb}"
                        try:
                            import scanpy as sc

                            sc.pp.neighbors(adata, use_rep=emb, key_added=key_added)
                        except Exception as e:
                            logger.warning(f"Failed to calculate neighbors for {emb}: {e}")
                            continue
                        input_name, graph_kw = emb, {"neighbors_key": key_added}
                    for label in label_keys:
                        try:
                            start = time.time()
                            kw = {
                                **graph_kw,
                                "label_key": label,
                                **_annot_call_kwargs(metric_module.run, n_jobs),
                            }
                            val = metric_module.run(adata, **kw)
                            results.append(
                                _annot_record(
                                    atlas_name, metric, input_name, label, val, start
                                )
                            )
                        except Exception as e:
                            logger.warning(
                                f"Failed to calculate {metric} for {input_name} vs {label}: {e}"
                            )

            # 6. Bio conservation
            elif metric in bio_metrics:
                logger.info(f"Skipping {metric} as it requires 'adata_before'.")
            else:
                logger.warning(f"Metric {metric} not categorized in pipeline.")
        except Exception as e:
            logger.error(f"Error running metric {metric}: {e}")
        metric_elapsed = time.time() - metric_start_time
        pbar.set_postfix_str(f"Time: {metric_elapsed:.2f}s", refresh=True)

    df = pd.DataFrame(results)
    # Save the MultiQC-compatible wide table
    if not df.empty and atlas_name is not None:
        os.makedirs(file_dir, exist_ok=True)
        wide_df = _pivot_annot_to_wide(df, atlas_name)
        wide_path = os.path.join(file_dir, f"{atlas_name}.tsv")
        wide_df.to_csv(wide_path, sep="\t", index=False)
        logger.info("MultiQC-compatible annotation table saved to %s", wide_path)

    # Clear LISI and kBET kNN caches to free memory (best effort: a failure
    # here must not discard the computed results).
    try:
        from .annot import lisi as _lisi

        _lisi._clear_knn_cache()
    except Exception as e:
        logger.debug(f"Could not clear LISI kNN cache: {e}")
    try:
        from .annot import kbet as _kbet

        _kbet._clear_knn_cache()
    except Exception as e:
        logger.debug(f"Could not clear kBET kNN cache: {e}")
    return df
cal_cluster(adata, atlas_name=None, metric_list=None, all_metrics=True, file_dir=None, n_jobs=-1, verbose=True, seed=42, preprocess_context=None)
¶
Comprehensive clustering assessment pipeline.
Calculates all clustering metrics to evaluate the quality of cluster assignments against embedding representations. Uses CheckAtlasColumnDetector to auto-detect embedding keys and cluster label columns.
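| Parameters: | adata, atlas_name, metric_list, all_metrics, file_dir, n_jobs, verbose, seed, preprocess_context (see the docstring below) |
|---|---|
| Returns: | pd.DataFrame with columns Atlas Name, Metric Name, Embedding, Label Key, Value, Time (s) |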
Source code in checkatlas/metrics/metrics.py
def _cluster_load_precomputed_dists(adata, preprocess_context, embedding_keys):
    """Load cached distance matrices for :func:`cal_cluster`.

    For each embedding, ``<cluster_dir>/dist_<emb>.tri`` is opened read-only
    as a ``TriangularMatrix``; otherwise ``dist_<emb>.npy`` is loaded with
    ``np.load``. Embeddings with neither file are left out. "/" and " " in
    the embedding key become "_" in the file name.
    """
    precomputed = {}
    for emb in embedding_keys:
        stem = f"dist_{emb.replace('/', '_').replace(' ', '_')}"
        tri_path = os.path.join(preprocess_context.cluster_dir, f"{stem}.tri")
        npy_path = os.path.join(preprocess_context.cluster_dir, f"{stem}.npy")
        if os.path.exists(tri_path):
            from ._triangular import TriangularMatrix

            if emb == "X":
                n_cells = adata.X.shape[0]
            elif emb in adata.obsm:
                n_cells = adata.obsm[emb].shape[0]
            else:
                n_cells = adata.n_obs
            precomputed[emb] = TriangularMatrix(n=n_cells, filepath=tri_path, mode="r")
        elif os.path.exists(npy_path):
            precomputed[emb] = np.load(npy_path)
    return precomputed


def cal_cluster(
    adata,
    atlas_name=None,
    metric_list=None,
    all_metrics=True,
    file_dir=None,
    n_jobs=-1,
    verbose=True,
    seed=42,
    preprocess_context=None,
):
    """
    Comprehensive clustering assessment pipeline.

    Calculates all clustering metrics to evaluate the quality of cluster assignments
    against embedding representations. Uses CheckAtlasColumnDetector to auto-detect
    embedding keys and cluster label columns.

    Args:
        adata (AnnData): Annotated data matrix.
        atlas_name (str): Name of the atlas for labeling results.
        metric_list (list, optional): List of metric names to calculate.
            If provided, overrides ``all_metrics``. Names not in
            ``METRICS_CLUST`` are ignored.
        all_metrics (bool): If True, calculate all available cluster metrics.
            If False, calculate a default subset
            (silhouette, davies_bouldin, calinski_harabasz).
            Ignored if ``metric_list`` is provided.
        file_dir (str, optional): Output directory. It is created when given,
            but this function writes no file; the caller saves the returned
            table.
        n_jobs (int): Number of parallel jobs (-1 = all cores).
        verbose (bool): Whether to print progress.
        seed (int): Random seed, passed as ``random_state``/``seed`` to
            metrics that accept it.
        preprocess_context (PreprocessContext, optional): If provided,
            column detection and distance-matrix precomputation are
            skipped and the context's precomputed data is reused.

    Returns:
        pd.DataFrame: Results dataframe with columns:
            [Atlas Name, Metric Name, Embedding, Label Key, Value, Time (s)].
            A metric that raises is recorded once with ``Value`` = NaN. An
            empty DataFrame is returned when no cluster label is available.
    """
    import gc
    import inspect

    from scipy.sparse import issparse

    from ..utils.col_detector import CheckAtlasColumnDetector

    if file_dir is not None:
        os.makedirs(file_dir, exist_ok=True)

    # ── Resolve embeddings and cluster label columns ──
    if preprocess_context is not None:
        embedding_keys = (
            preprocess_context.cluster_embedding_keys
            or preprocess_context.embedding_keys
        )
        label_keys = preprocess_context.cluster_label_keys
        if not label_keys:
            logger.warning("No cluster labels in preprocess context. Skipping.")
            return pd.DataFrame()
        if verbose:
            print("Using precomputed context — skipping column detection")
            print(f" Embeddings: {embedding_keys}")
            print(f" Cluster labels: {label_keys}")
    else:
        if verbose:
            print("Detecting embeddings and cluster labels...")
        detector = CheckAtlasColumnDetector(adata)
        params = detector.detect_all_parameters()
        embedding_keys = [x[0] for x in params["clustering"]["embeddings"]]
        label_keys = [x[0] for x in params["clustering"]["cluster_labels"]]
        if verbose:
            print(f" Detected embeddings: {embedding_keys}")
            print(f" Detected cluster labels: {label_keys}")

    if not embedding_keys:
        logger.warning(
            "No embeddings detected in adata.obsm. Trying adata.X as fallback."
        )
        embedding_keys = ["X"]
    if not label_keys:
        logger.warning(
            "No cluster labels detected in adata.obs. Cannot run cluster metrics."
        )
        return pd.DataFrame()

    # Cached distance matrices (reused by metrics accepting precomputed_dists)
    precomputed_dists = {}
    if preprocess_context is not None:
        precomputed_dists = _cluster_load_precomputed_dists(
            adata, preprocess_context, embedding_keys
        )

    # ── Metrics to run ──
    if metric_list is not None:
        metrics_list = metric_list
    elif all_metrics:
        metrics_list = METRICS_CLUST
    else:
        # Default subset: fastest and most commonly used
        metrics_list = ["silhouette", "davies_bouldin", "calinski_harabasz"]
    metrics_list = [m for m in metrics_list if m in METRICS_CLUST]

    results = []
    total_combos = len(embedding_keys) * len(label_keys) * len(metrics_list)
    if verbose:
        print(
            f"\nRunning {len(metrics_list)} metrics × "
            f"{len(embedding_keys)} embeddings × "
            f"{len(label_keys)} label keys = {total_combos} calculations\n"
        )
    pbar = tqdm(
        total=total_combos,
        desc="Calculating Cluster Metrics",
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
        disable=not verbose,
    )
    for emb_key in embedding_keys:
        if emb_key == "X":
            X_emb = adata.X
            if issparse(X_emb):
                X_emb = X_emb.toarray()
        elif emb_key in adata.obsm:
            X_emb = np.asarray(adata.obsm[emb_key])
        else:
            logger.warning(f"Embedding '{emb_key}' not found in adata.obsm. Skipping.")
            pbar.update(len(label_keys) * len(metrics_list))
            continue

        for label_key in label_keys:
            if label_key not in adata.obs.columns:
                logger.warning(
                    f"Label key '{label_key}' not found in adata.obs. Skipping."
                )
                pbar.update(len(metrics_list))
                continue
            labels = np.asarray(adata.obs[label_key])
            # pd.unique is hash-based: object labels mixing strings and NaN do
            # not raise the TypeError that np.unique's sort would.
            if len(pd.unique(labels)) < 2:
                logger.warning(f"Label '{label_key}' has < 2 clusters. Skipping.")
                pbar.update(len(metrics_list))
                continue

            for metric_name in metrics_list:
                metric_start_time = time.time()
                pbar.set_description(f"{emb_key}|{label_key}: {metric_name}")
                try:
                    metric_module = getattr(cluster, metric_name)
                    accepted = inspect.signature(metric_module.run).parameters
                    offered = {
                        "n_jobs": n_jobs,
                        "verbose": False,
                        "random_state": seed,
                        "seed": seed,
                        "max_samples": None,  # disable subsampling
                    }
                    if emb_key in precomputed_dists:
                        offered["precomputed_dists"] = precomputed_dists[emb_key]
                    kwargs = {k: v for k, v in offered.items() if k in accepted}
                    value = metric_module.run(X_emb, labels, **kwargs)
                except Exception as e:
                    value = np.nan
                    logger.warning(f"Failed {metric_name} for {emb_key}/{label_key}: {e}")
                elapsed = time.time() - metric_start_time
                results.append(
                    {
                        "Atlas Name": atlas_name,
                        "Metric Name": metric_name,
                        "Embedding": emb_key,
                        "Label Key": label_key,
                        "Value": value,
                        "Time (s)": round(elapsed, 3),
                    }
                )
                # Formatted outside the metric's try: a value that cannot be
                # rendered with ':.4f' must not be recorded again as a failure.
                try:
                    shown = f"{value:.4f}"
                except (TypeError, ValueError):
                    shown = str(value)
                pbar.set_postfix_str(f"={shown} ({elapsed:.1f}s)", refresh=True)
                pbar.update(1)
    pbar.close()

    df = pd.DataFrame(results)
    if verbose and not df.empty:
        total_time = df["Time (s)"].sum()
        print(f"\nTotal computation time: {total_time:.2f}s")
        print(
            f"Results: {len(df)} measurements across {len(df['Metric Name'].unique())} metrics"
        )
    gc.collect()
    return df
cal_dimred(adata, atlas_name=None, low_dim_keys=None, high_dim_key='X', metric_list=None, k_neighbors=30, n_samples=None, seed=42, n_jobs=-1, file_dir=None, verbose=True, use_cache=True)
¶
Calculate dimensionality reduction metrics for multiple embeddings.
For each low_dim_key the embedding is compared against the
reference high_dim_key (defaults to adata.X — the raw gene
expression matrix).
Precomputation (distance matrices + kNN graphs) is performed once for the high‑dimensional reference and once per low‑dimensional embedding key. Memory‑mapped files are used for N×N distance matrices on datasets with > 10 000 cells to keep RAM usage bounded.
Parameters¶
adata : AnnData
atlas_name : str, optional
Used only for logging.
low_dim_keys : list of str, optional
.obsm keys to evaluate. Defaults to all keys returned by
CheckAtlasColumnDetector.get_dimred_embeddings(). A key equal to high_dim_key is always dropped.
high_dim_key : str
Reference key. 'X' means adata.X. (Default 'X')
metric_list : list of str, optional
Metrics to compute. Defaults to METRICS_DIMRED; names not in METRICS_DIMRED are ignored.
k_neighbors : int
Number of neighbours used for the kNN graphs (default 30).
n_samples : int or None
Number of cells to subsample. None = use all cells.
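seed : int
Random seed used when subsampling cells (n_samples).
n_jobs : int
Number of parallel jobs passed to the scikit-learn neighbour and distance computations.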
file_dir : str, optional
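Directory for temporary memmap files and the distance/kNN cache. When omitted,
~/.checkatlas/tmp is used. If that location has less than 50 GB free, a .checkatlas_tmp folder under $TMPDIR is used instead, or /data/tmp or /tmp when $TMPDIR is not usable.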
verbose : bool
Returns¶
pd.DataFrame
Long‑format table with columns
[Metric Name, Low Dim Key, High Dim Key, Value, Time (s)].
Caller is responsible for pivoting and saving.
Source code in checkatlas/metrics/metrics.py
def cal_dimred(
adata,
atlas_name=None,
low_dim_keys=None,
high_dim_key="X",
metric_list=None,
k_neighbors=30,
n_samples=None,
seed=42,
n_jobs=-1,
file_dir=None,
verbose=True,
use_cache=True,
):
"""Calculate dimensionality reduction metrics for multiple embeddings.
For each ``low_dim_key`` the embedding is compared against the
reference ``high_dim_key`` (defaults to ``adata.X`` — the raw gene
expression matrix).
Precomputation (distance matrices + kNN graphs) is performed once for
the high‑dimensional reference and once per low‑dimensional embedding
key. Memory‑mapped files are used for N×N distance matrices on
datasets with > 10 000 cells to keep RAM usage bounded.
Parameters
----------
adata : AnnData
atlas_name : str, optional
Used only for logging.
low_dim_keys : list of str, optional
``.obsm`` keys to evaluate. Defaults to all keys returned by
:meth:`CheckAtlasColumnDetector.get_dimred_embeddings`.
high_dim_key : str
Reference key. ``'X'`` means ``adata.X``. (Default ``'X'``)
metric_list : list of str, optional
Metrics to compute. Defaults to :data:`METRICS_DIMRED`.
k_neighbors : int
n_samples : int or None
Number of cells to subsample. ``None`` = use all cells.
seed : int
n_jobs : int
file_dir : str, optional
Directory for temporary memmap files. Falls back to system
``/tmp``.
verbose : bool
Returns
-------
pd.DataFrame
Long‑format table with columns
``[Metric Name, Low Dim Key, High Dim Key, Value, Time (s)]``.
Caller is responsible for pivoting and saving.
"""
import gc
import inspect
import tempfile
import uuid
from sklearn.metrics import pairwise_distances
from sklearn.neighbors import NearestNeighbors
from ..utils.col_detector import CheckAtlasColumnDetector
if metric_list is None:
metric_list = METRICS_DIMRED
metric_list = [m for m in metric_list if m in METRICS_DIMRED]
if not metric_list:
return pd.DataFrame()
if low_dim_keys is None:
detector = CheckAtlasColumnDetector(adata)
low_dim_keys = detector.get_dimred_embeddings()
low_dim_keys = [k for k in low_dim_keys if k != high_dim_key]
if not low_dim_keys:
if verbose:
print(f" No embeddings to compare " f"(only key is {high_dim_key}).")
return pd.DataFrame()
n_obs = adata.n_obs
if n_samples is not None and n_samples < n_obs:
np.random.seed(seed)
sample_indices = np.random.choice(n_obs, n_samples, replace=False)
n_cells = n_samples
else:
sample_indices = np.arange(n_obs)
n_cells = n_obs
if high_dim_key == "X":
high_dim_data = adata.X[sample_indices]
if hasattr(high_dim_data, "toarray"):
high_dim_data = high_dim_data.toarray()
else:
if high_dim_key not in adata.obsm_keys():
raise ValueError(f"High-dim key '{high_dim_key}' not found in adata.obsm.")
high_dim_data = adata.obsm[high_dim_key][sample_indices]
high_n_features = high_dim_data.shape[1]
use_memmap = n_cells > 10000
all_memmap_files = []
if file_dir:
temp_dir = file_dir
else:
# Try default, fall back to /data if system temp is on a small partition
default_tmp = os.path.join(os.path.expanduser("~"), ".checkatlas", "tmp")
temp_dir = default_tmp
try:
import shutil
_usage = shutil.disk_usage(default_tmp)
if _usage.free < 50 * 1024**3: # < 50 GB free
_fallback = "/data/tmp" if os.path.exists("/data/tmp") else "/tmp"
# Also try nextflow TMPDIR if set
_nf_tmp = os.environ.get("TMPDIR", "")
if _nf_tmp and os.path.exists(_nf_tmp):
_fallback = _nf_tmp
temp_dir = os.path.join(_fallback, ".checkatlas_tmp")
if verbose:
print(f" Low disk space on default temp; using {temp_dir}")
except Exception:
pass
os.makedirs(temp_dir, exist_ok=True)
# ── Persistent cache check ───────────────────────────────────
_from_cache = False
_cache_low_dim = {}
if use_cache and atlas_name:
_fp = compute_fingerprint(
n_cells=n_cells,
n_features=high_n_features,
embedding_keys=low_dim_keys,
embedding_shapes={
k: tuple(adata.obsm[k][sample_indices].shape) for k in low_dim_keys
},
k_neighbors=k_neighbors,
source_path=getattr(adata, "filename", None),
)
_cached = load_dimred_cache(temp_dir, _fp, n_cells, low_dim_keys)
if _cached is not None:
high_dim_dists = _cached["high_dim_dists"]
high_knn_dists = _cached["high_knn_dists"]
high_knn_indices = _cached["high_knn_indices"]
_cache_low_dim = _cached["low_dim"]
_from_cache = True
if verbose:
print(" [CACHE HIT] Reusing precomputed distances & kNN")
run_id = str(uuid.uuid4())[:8]
chunk_size = min(1000, n_cells)
if verbose:
print(
f" Reference: {high_dim_key} "
f"({n_cells:,} cells × {high_n_features} features)"
)
print(f" Embeddings: {low_dim_keys}")
print(f" Metrics: {metric_list}")
backend = "GPU (JAX)" if (_JAX_AVAILABLE and _GPU_AVAILABLE) else "CPU"
print(f" Backend: {backend}")
# ── GPU path: single-kernel distance matrix + kNN (Phase 3) ────
# Memory: N² float32 ≈ 4·N² bytes → with intermediate buffers ≈ 4·N²·3.5 bytes
# A100 40 GB: safe N ≤ 50 000 (≈ 32.6 GB total)
# For 50k < N ≤ 150k: chunked GPU + memmap on /data (avoid disk IO bottleneck)
_GPU_SINGLE_SHOT = _JAX_AVAILABLE and _GPU_AVAILABLE and (n_cells <= 50000)
_GPU_CHUNKED = _JAX_AVAILABLE and _GPU_AVAILABLE and (50000 < n_cells <= 150000)
_DIST_METRICS = frozenset(
(
"kruskal_stress",
"spearman_rho",
"dCor",
"trustworthiness",
"continuity",
)
)
_low_dim_precomputed = {} # collect for saving to cache later
if _from_cache:
pass # precomputation already loaded
elif _GPU_SINGLE_SHOT:
import jax.numpy as jnp
high_dim_dists = pdist_squareform(high_dim_data) # GPU matmul → (n,n) float32
high_dists_jax = jnp.asarray(high_dim_dists, dtype=jnp.float32)
# kNN from precomputed distance matrix on GPU
import jax
high_knn_dists_jax, high_knn_indices_jax = jax.lax.approx_min_k(
high_dists_jax, k=k_neighbors + 1
)
high_knn_dists = _get_ndarray(high_knn_dists_jax)
high_knn_indices = _get_ndarray(high_knn_indices_jax)
# ── Persist GPU distance matrix for cache reuse ───────
# save_dimred_cache only stores TriangularMatrix memmaps;
# convert the GPU numpy array so subsequent runs don't
# have to recompute this expensive (N² × D) matrix.
if use_cache and atlas_name:
_high_tri_path = os.path.join(temp_dir, "high_dists.tri")
_high_tri = TriangularMatrix(
n=n_cells, filepath=_high_tri_path, mode="w+"
)
_chunk = min(5000, n_cells)
for _i in range(0, n_cells, _chunk):
_end = min(_i + _chunk, n_cells)
store_upper_triangle(
_high_tri._data, high_dim_dists[_i:_end, :],
_i, 0, n_cells,
)
_high_tri.flush()
high_dim_dists = _high_tri
elif _GPU_CHUNKED:
# ── Chunked GPU path: fused kNN + distance matrix ──
# kNN: streaming GPU, auto-detect for large atlases
# Distances: upper‑triangle float16 memmap written in same pass
import jax.numpy as jnp
_qchunk = 15000 # query rows per chunk
_rchunk = 10000 # reference columns per chunk
if verbose:
print(
f" Chunked GPU/streaming: {n_cells:,} cells, "
f"query={_qchunk}, ref={_rchunk}"
)
print(f" Storing in: {temp_dir}")
# Create TriangularMatrix BEFORE kNN pass to fuse the operations
high_dists_path = os.path.join(temp_dir, f"high_dists_{run_id}.tri")
all_memmap_files.append(high_dists_path)
high_dim_dists = TriangularMatrix(n=n_cells, filepath=high_dists_path, mode="w+")
if verbose:
print(" Computing kNN + distances (fused GPU)...")
high_knn_results = compute_neighbors(
np.asarray(high_dim_data, dtype=np.float64),
n_neighbors=k_neighbors + 1,
backend="auto",
tri_memmap=high_dim_dists,
)
high_knn_dists = high_knn_results.distances
high_knn_indices = high_knn_results.indices
high_dim_dists.flush()
gc.collect()
else:
# ─── CPU path (unchanged) ───────────────────────────────────
# ── High‑dim kNN ──
nbrs_high = NearestNeighbors(n_neighbors=k_neighbors + 1, n_jobs=n_jobs).fit(
high_dim_data
)
high_knn_dists, high_knn_indices = nbrs_high.kneighbors(high_dim_data)
# ── High‑dim distance matrix (only when distance‑based metrics
# are in the list) ──
need_high_dists = bool(set(metric_list) & _DIST_METRICS)
high_dim_dists = None
if need_high_dists:
if verbose:
print(
" Precomputing high-dim distance matrix "
f"({n_cells:,}×{n_cells:,})…"
)
if use_memmap:
high_dists_path = os.path.join(temp_dir, f"high_dists_{run_id}.tri")
all_memmap_files.append(high_dists_path)
high_dim_dists = TriangularMatrix(
n=n_cells, filepath=high_dists_path, mode="w+"
)
for i in tqdm(
range(0, n_cells, chunk_size),
desc="High-Dim Distances",
disable=not verbose,
):
end = min(i + chunk_size, n_cells)
block = pairwise_distances(
high_dim_data[i:end], high_dim_data, n_jobs=n_jobs
)
store_upper_triangle(high_dim_dists._data, block, i, 0, n_cells)
high_dim_dists.flush()
else:
high_dim_dists = np.zeros((n_cells, n_cells), dtype=np.float32)
for i in tqdm(
range(0, n_cells, chunk_size),
desc="High-Dim Distances",
disable=not verbose,
):
end = min(i + chunk_size, n_cells)
high_dim_dists[i:end, :] = pairwise_distances(
high_dim_data[i:end], high_dim_data, n_jobs=n_jobs
)
gc.collect()
# ── end of precomputation block ──
total_combos = len(low_dim_keys) * len(metric_list)
if verbose:
print(
f"\n {len(metric_list)} metrics × {len(low_dim_keys)} embeddings"
f" = {total_combos} calculations\n"
)
results = []
pbar = tqdm(
total=total_combos,
desc="Calculating Dimred Metrics",
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} " "[{elapsed}<{remaining}]",
disable=not verbose,
)
for low_dim_key in low_dim_keys:
low_dim_data = adata.obsm[low_dim_key][sample_indices]
low_dim_dists = None
low_dists_path = None
if _from_cache and low_dim_key in _cache_low_dim:
# ── Load from cache ───────────────────────────────
_c = _cache_low_dim[low_dim_key]
low_dim_dists = _c.get("dists")
low_knn_dists = _c["knn_dists"]
low_knn_indices = _c["knn_indices"]
# ── Recompute missing distance matrix ─────────────
# Cache may have kNN but not distances (GPU numpy
# arrays are not persisted — save_dimred_cache only
# stores TriangularMatrix memmaps). Compute the
# distance matrix once so every metric can share it.
if low_dim_dists is None and low_knn_dists is not None:
if _GPU_SINGLE_SHOT:
import jax.numpy as _jnp
_dists_np = pdist_squareform(low_dim_data)
if use_cache and atlas_name:
_safe = low_dim_key.replace("/", "_").replace(" ", "_")
_tri_path = os.path.join(
temp_dir, f"low_dists_{_safe}.tri"
)
_tri = TriangularMatrix(
n=n_cells, filepath=_tri_path, mode="w+"
)
_chunk = min(5000, n_cells)
for _i in range(0, n_cells, _chunk):
_end = min(_i + _chunk, n_cells)
store_upper_triangle(
_tri._data, _dists_np[_i:_end, :],
_i, 0, n_cells,
)
_tri.flush()
low_dim_dists = _tri
del _dists_np
else:
low_dim_dists = _dists_np
elif n_cells > 10000:
_safe = low_dim_key.replace("/", "_").replace(" ", "_")
_tri_path = os.path.join(
temp_dir, f"low_dists_{_safe}_{run_id}.tri"
)
_tri = TriangularMatrix(
n=n_cells, filepath=_tri_path, mode="w+"
)
for _i in range(0, n_cells, chunk_size):
_end = min(_i + chunk_size, n_cells)
_block = pairwise_distances(
low_dim_data[_i:_end], low_dim_data,
n_jobs=n_jobs,
)
store_upper_triangle(
_tri._data, _block, _i, 0, n_cells
)
_tri.flush()
low_dim_dists = _tri
else:
need_d = bool(set(metric_list) & _DIST_METRICS)
if need_d:
low_dim_dists = np.zeros(
(n_cells, n_cells), dtype=np.float32
)
for _i in range(0, n_cells, chunk_size):
_end = min(_i + chunk_size, n_cells)
low_dim_dists[_i:_end, :] = pairwise_distances(
low_dim_data[_i:_end], low_dim_data,
n_jobs=n_jobs,
)
gc.collect()
elif _GPU_SINGLE_SHOT:
# ── GPU path: single-kernel distance + kNN ──────────
import jax
import jax.numpy as jnp
low_dim_dists = pdist_squareform(low_dim_data)
low_dists_jax = jnp.asarray(low_dim_dists, dtype=jnp.float32)
low_knn_dists_jax, low_knn_indices_jax = jax.lax.approx_min_k(
low_dists_jax, k=k_neighbors + 1
)
low_knn_dists = _get_ndarray(low_knn_dists_jax)
low_knn_indices = _get_ndarray(low_knn_indices_jax)
# ── Persist GPU distance matrix for cache reuse ──
if use_cache and atlas_name:
_safe = low_dim_key.replace("/", "_").replace(" ", "_")
_tri_path = os.path.join(
temp_dir, f"low_dists_{_safe}.tri"
)
_tri = TriangularMatrix(
n=n_cells, filepath=_tri_path, mode="w+"
)
_chunk = min(5000, n_cells)
for _i in range(0, n_cells, _chunk):
_end = min(_i + _chunk, n_cells)
store_upper_triangle(
_tri._data, low_dim_dists[_i:_end, :],
_i, 0, n_cells,
)
_tri.flush()
low_dim_dists = _tri
elif _GPU_CHUNKED:
# ── GPU for low-dim when features ≤ 200 dims ─────────
# Low-dim embeddings (X_pca=50d, X_umap=2d) fit on GPU easily.
# High-dim X (60k genes) needs CPU memmap for N×N storage.
low_ndim = low_dim_data.shape[1]
if low_ndim <= 200 and n_cells <= 50000:
# One-shot GPU: small feature dim → N² fits
import jax
import jax.numpy as jnp
low_dim_dists = pdist_squareform(low_dim_data)
low_dists_jax = jnp.asarray(low_dim_dists, dtype=jnp.float32)
low_knn_dists_jax, low_knn_indices_jax = jax.lax.approx_min_k(
low_dists_jax, k=k_neighbors + 1
)
low_knn_dists = _get_ndarray(low_knn_dists_jax)
low_knn_indices = _get_ndarray(low_knn_indices_jax)
# ── Persist GPU distance matrix for cache reuse ──
if use_cache and atlas_name:
_safe = low_dim_key.replace("/", "_").replace(" ", "_")
_tri_path = os.path.join(
temp_dir, f"low_dists_{_safe}.tri"
)
_tri = TriangularMatrix(
n=n_cells, filepath=_tri_path, mode="w+"
)
_chunk = min(5000, n_cells)
for _i in range(0, n_cells, _chunk):
_end = min(_i + _chunk, n_cells)
store_upper_triangle(
_tri._data, low_dim_dists[_i:_end, :],
_i, 0, n_cells,
)
_tri.flush()
low_dim_dists = _tri
elif low_ndim <= 200:
# Fused GPU streaming kNN + optional distance matrix
need_low_dists = bool(set(metric_list) & _DIST_METRICS)
if need_low_dists:
low_dists_path = os.path.join(
temp_dir,
f"low_dists_{low_dim_key.replace('/', '_')}_{run_id}.tri",
)
all_memmap_files.append(low_dists_path)
low_dim_dists = TriangularMatrix(
n=n_cells, filepath=low_dists_path, mode="w+"
)
else:
low_dim_dists = None
kNN = compute_neighbors(
np.asarray(low_dim_data, dtype=np.float64),
n_neighbors=k_neighbors + 1,
backend="auto",
tri_memmap=low_dim_dists,
)
low_knn_dists = kNN.distances
low_knn_indices = kNN.indices
if need_low_dists and low_dim_dists is not None:
low_dim_dists.flush()
gc.collect()
else:
# High-dim X — CPU memmap (sklearn) → float16 upper‑triangle
nbrs_low = NearestNeighbors(
n_neighbors=k_neighbors + 1, n_jobs=n_jobs
).fit(low_dim_data)
low_knn_dists, low_knn_indices = nbrs_low.kneighbors(low_dim_data)
low_dists_path = os.path.join(
temp_dir,
f"low_dists_{low_dim_key.replace('/', '_')}_{run_id}.tri",
)
all_memmap_files.append(low_dists_path)
low_dim_dists = TriangularMatrix(
n=n_cells, filepath=low_dists_path, mode="w+"
)
for i in tqdm(
range(0, n_cells, chunk_size),
desc=f"Low-Dim Distances ({low_dim_key})",
disable=not verbose,
):
end = min(i + chunk_size, n_cells)
block = pairwise_distances(
low_dim_data[i:end], low_dim_data, n_jobs=n_jobs
)
store_upper_triangle(low_dim_dists._data, block, i, 0, n_cells)
low_dim_dists.flush()
gc.collect()
else:
# ── CPU path ────────────────────────────────────────
nbrs_low = NearestNeighbors(n_neighbors=k_neighbors + 1, n_jobs=n_jobs).fit(
low_dim_data
)
low_knn_dists, low_knn_indices = nbrs_low.kneighbors(low_dim_data)
need_low_dists = bool(set(metric_list) & _DIST_METRICS)
if need_low_dists:
if use_memmap:
low_dists_path = os.path.join(
temp_dir,
f"low_dists_{low_dim_key.replace('/', '_')}_{run_id}.tri",
)
all_memmap_files.append(low_dists_path)
low_dim_dists = TriangularMatrix(
n=n_cells, filepath=low_dists_path, mode="w+"
)
for i in tqdm(
range(0, n_cells, chunk_size),
desc=f"Low-Dim Distances ({low_dim_key})",
disable=not verbose,
):
end = min(i + chunk_size, n_cells)
block = pairwise_distances(
low_dim_data[i:end], low_dim_data, n_jobs=n_jobs
)
store_upper_triangle(low_dim_dists._data, block, i, 0, n_cells)
low_dim_dists.flush()
else:
low_dim_dists = np.zeros((n_cells, n_cells), dtype=np.float32)
for i in tqdm(
range(0, n_cells, chunk_size),
desc=f"Low-Dim Distances ({low_dim_key})",
disable=not verbose,
):
end = min(i + chunk_size, n_cells)
low_dim_dists[i:end, :] = pairwise_distances(
low_dim_data[i:end], low_dim_data, n_jobs=n_jobs
)
gc.collect()
# ── Capture precomputed low-dim data for cache ────────
if use_cache and atlas_name and not _from_cache:
_low_dim_precomputed[low_dim_key] = {
"dists": low_dim_dists,
"knn_indices": low_knn_indices,
"knn_dists": low_knn_dists,
}
for metric_name in metric_list:
pbar.set_description(f"{low_dim_key}: {metric_name}")
t_start = time.time()
# ── Inspect metric signature to decide on to_dense ──────
metric_module = getattr(dimred, metric_name)
sig = inspect.signature(metric_module.run)
params = sig.parameters
# ── Materialize TriangularMatrix only when needed ───────
# For N ≤ 50k cells to_dense() is cheap (~2.5 GB) and
# saves the _rank_penalty helper a few row-gathers.
# For N > 50k the dense allocation would be O(N²) and
# catastrophic — the helper handles TriangularMatrix
# natively via the GPU chunked / CPU chunked paths.
_DENSE_TO_DENSE_MAX_N = 50_000
_high_mat = high_dim_dists
_low_mat = low_dim_dists
if metric_name in ("trustworthiness", "continuity"):
if (
isinstance(high_dim_dists, TriangularMatrix)
and high_dim_dists.n <= _DENSE_TO_DENSE_MAX_N
):
_high_mat = high_dim_dists.to_dense()
if (
isinstance(low_dim_dists, TriangularMatrix)
and low_dim_dists is not None
and low_dim_dists.n <= _DENSE_TO_DENSE_MAX_N
):
_low_mat = low_dim_dists.to_dense()
try:
kwargs = {}
if "low_dim_key" in params:
kwargs["low_dim_key"] = low_dim_key
if "high_dim_key" in params:
kwargs["high_dim_key"] = high_dim_key
if "k_neighbors" in params:
kwargs["k_neighbors"] = k_neighbors
if "n_samples" in params:
kwargs["n_samples"] = n_samples
if "n_jobs" in params:
kwargs["n_jobs"] = n_jobs
if "seed" in params:
kwargs["seed"] = seed
if "verbose" in params:
kwargs["verbose"] = False
if "rank_backend" in params:
# Trust/continuity support "auto" | "jax_single_shot"
# | "jax_chunked" | "cpu". "auto" picks the
# fastest backend based on N + GPU availability.
kwargs["rank_backend"] = "auto"
if "precomputed_high_knn" in params:
kwargs["precomputed_high_knn"] = high_knn_indices
if "precomputed_low_knn" in params:
kwargs["precomputed_low_knn"] = low_knn_indices
if "precomputed_high_knn_dists" in params:
kwargs["precomputed_high_knn_dists"] = high_knn_dists
if "precomputed_low_knn_dists" in params:
kwargs["precomputed_low_knn_dists"] = low_knn_dists
if "precomputed_high_dists" in params:
kwargs["precomputed_high_dists"] = high_dim_dists
if "precomputed_low_dists" in params:
kwargs["precomputed_low_dists"] = low_dim_dists
value = metric_module.run(adata, **kwargs)
elapsed = round(time.time() - t_start, 3)
results.append(
{
"Metric Name": metric_name,
"Low Dim Key": low_dim_key,
"High Dim Key": high_dim_key,
"Value": value,
"Time (s)": elapsed,
}
)
except Exception as e:
elapsed = round(time.time() - t_start, 3)
logger.warning("Failed %s on %s: %s", metric_name, low_dim_key, e)
results.append(
{
"Metric Name": metric_name,
"Low Dim Key": low_dim_key,
"High Dim Key": high_dim_key,
"Value": np.nan,
"Time (s)": elapsed,
}
)
pbar.update(1)
if low_dim_dists is not None:
try:
if isinstance(low_dim_dists, TriangularMatrix):
_safe_close_triangular(low_dim_dists)
else:
_safe_close_memmap(low_dim_dists)
except Exception:
pass
# Persist when caching; delete otherwise
if not use_cache:
if low_dists_path and os.path.exists(low_dists_path):
try:
os.remove(low_dists_path)
except Exception:
pass
if low_dists_path in all_memmap_files:
all_memmap_files.remove(low_dists_path)
gc.collect()
pbar.close()
if high_dim_dists is not None:
try:
if isinstance(high_dim_dists, TriangularMatrix):
_safe_close_triangular(high_dim_dists)
else:
_safe_close_memmap(high_dim_dists)
except Exception:
pass
gc.collect()
# ── Save to persistent cache ──
if use_cache and atlas_name and not _from_cache and high_dim_dists is not None:
_low_info = {}
for emb in low_dim_keys:
_low_info[emb] = {
"dists": _low_dim_precomputed.get(emb, {}).get("dists"),
"knn_indices": _low_dim_precomputed.get(emb, {}).get("knn_indices"),
"knn_dists": _low_dim_precomputed.get(emb, {}).get("knn_dists"),
}
try:
save_dimred_cache(
temp_dir,
_fp,
high_dim_dists=(
high_dim_dists
if isinstance(high_dim_dists, TriangularMatrix)
else None
),
high_knn_dists=high_knn_dists,
high_knn_indices=high_knn_indices,
low_dim_data=_low_info,
low_dim_keys=low_dim_keys,
)
except Exception as exc:
logger.warning("Failed to save dimred cache: %s", exc)
# Delete leftover temp files (only when NOT caching)
if not use_cache:
for fpath in all_memmap_files:
try:
if os.path.exists(fpath):
os.remove(fpath)
except OSError as exc:
logger.warning("Failed to clean up temp file %s: %s", fpath, exc)
if "_dimred_cache" in adata.uns:
del adata.uns["_dimred_cache"]
gc.collect()
df = pd.DataFrame(results)
if verbose and not df.empty:
total_time = df["Time (s)"].sum()
n_results = len(df)
print(f"\nTotal computation time: {total_time:.2f}s")
print(f"Results: {n_results} measurements across " f"{len(metric_list)} metrics")
return df
calc_metric_dimred(metric, adata, obsm_key)
¶
Calculate a single dimensionality reduction metric (legacy function).
For comprehensive assessment, use cal_dimred() instead.
Source code in checkatlas/metrics/metrics.py
def calc_metric_dimred(metric, adata, obsm_key):
    """
    Calculate a single dimensionality reduction metric (legacy function).

    For comprehensive assessment, use cal_dimred() instead.

    Args:
        metric (str): Metric name; must be listed in METRICS_DIMRED and is
            resolved as an attribute of the ``dimred`` module.
        adata (AnnData): Atlas; ``adata.X`` is used as the high-dimensional
            input.
        obsm_key (str): Key in ``adata.obsm`` holding the low-dimensional
            embedding.

    Returns:
        tuple | int: ``(metric_value, running_time)`` on success, or ``-1``
        when *metric* is not a recognized dimred metric (the ``-1`` sentinel
        is kept for backward compatibility with existing callers).
    """
    if metric not in METRICS_DIMRED:
        logger.warning(
            f"{metric} is not a recognized "
            f"dimensionality reduction metric."
            f"\nList of dim. red. metrics: {METRICS_DIMRED}"
        )
        return -1

    start_time = time.time()
    logger.debug(f"Start {metric} calc")
    # Get the right module for the metric
    metric_module = getattr(dimred, metric)
    # Previously a stray print(); keep the information but send it to the
    # logger instead of polluting stdout.
    logger.debug(f"Using metric module {metric_module}")
    high_dim_counts = adata.X
    low_dim_counts = adata.obsm[obsm_key]
    # Legacy calling convention: run(high_dim, low_dim) — unlike cal_dimred,
    # which calls run(adata, **kwargs).
    metric_value = metric_module.run(high_dim_counts, low_dim_counts)
    running_time = time.time() - start_time
    logger.debug(f"{metric} calc finished, duration {running_time}")
    return metric_value, running_time
run_all_metrics(adata=None, atlas_path=None, atlas_name=None, file_dir=None, n_jobs=-1, verbose=True, seed=42, run_annotation=True, all_annot_metrics=True, run_clustering=True, all_cluster_metrics=True, run_dimred=True, all_dimred_metrics=True, low_dim_key='X_umap', high_dim_key='X', k_neighbors=30, n_samples=None, use_memmap=True, temp_dir=None, chunk_size=1000)
¶
Unified metrics pipeline for CheckAtlas.
The standalone entry point that runs ALL metric tasks (annotation, clustering, dimensionality reduction) on a single-cell atlas and produces a consolidated results CSV. Designed for biologists and bioinformaticians — just provide an atlas path and get comprehensive quality metrics.
Usage:
from checkatlas.metrics.metrics import run_all_metrics
results = run_all_metrics(atlas_path="my_atlas.h5ad")
Or with a preloaded AnnData:
results = run_all_metrics(adata=my_adata, atlas_name="my_atlas")
| Parameters: |
|
|---|
| Returns: |
|
|---|
Source code in checkatlas/metrics/metrics.py
def _annotation_to_unified(df_annot):
    """Map a cal_annot() results table onto the unified results schema."""
    return pd.DataFrame(
        {
            "Atlas Name": df_annot["Atlas Name"],
            "Task": "annotation",
            "Metric Name": df_annot["Metric Name"],
            "Input 1": df_annot["Reference/Input 1"],
            "Input 2": df_annot["Prediction/Input 2"],
            "Value": df_annot["Value"],
            # Fall back to NaN when the table carries no timing column.
            "Time (s)": df_annot.get("Time (s)", np.nan),
        }
    )


def _cluster_to_unified(df_cluster):
    """Map a cal_cluster() results table onto the unified results schema."""
    return pd.DataFrame(
        {
            "Atlas Name": df_cluster["Atlas Name"],
            "Task": "clustering",
            "Metric Name": df_cluster["Metric Name"],
            "Input 1": df_cluster["Embedding"],
            "Input 2": df_cluster["Label Key"],
            "Value": df_cluster["Value"],
            "Time (s)": df_cluster["Time (s)"],
        }
    )


def _dimred_to_unified(df_dimred, atlas_name):
    """Map a cal_dimred() results table onto the unified results schema."""
    return pd.DataFrame(
        {
            # cal_dimred output has no atlas column; broadcast the name.
            "Atlas Name": atlas_name,
            "Task": "dimred",
            "Metric Name": df_dimred["Metric Name"],
            "Input 1": df_dimred["Low Dim Key"],
            "Input 2": df_dimred["High Dim Key"],
            "Value": df_dimred["Value"],
            "Time (s)": df_dimred["Time (s)"],
        }
    )


def _run_metric_task(name, title, compute, to_unified, task_summary, verbose):
    """Run one metric task, record its summary and return its unified table.

    Args:
        name (str): Short task name used in the summary and in messages
            (e.g. "Annotation").
        title (str): Header line printed before the task when verbose.
        compute (callable): Zero-argument callable returning the task's raw
            results DataFrame.
        to_unified (callable): Converts the raw DataFrame to the unified
            schema.
        task_summary (dict): Mutated in place; ``task_summary[name]`` is set
            to ``{"count", "time"}`` plus ``"error"`` on failure.
        verbose (bool): Whether to print progress.

    Returns:
        pd.DataFrame | None: Unified results, or None when the task produced
        no rows or raised.
    """
    task_start = time.time()
    if verbose:
        print(f"\n{'─'*50}")
        print(f" {title}")
        print(f"{'─'*50}")
    try:
        df_task = compute()
        unified = None if df_task.empty else to_unified(df_task)
    except Exception as e:
        task_elapsed = time.time() - task_start
        # logger.exception keeps the traceback that logger.error discarded;
        # one failing task must not abort the remaining ones.
        logger.exception(f"{name} pipeline failed: {e}")
        task_summary[name] = {
            "count": 0,
            "time": round(task_elapsed, 2),
            "error": str(e),
        }
        if verbose:
            print(f" ✗ {name} failed: {e}")
        return None
    task_elapsed = time.time() - task_start
    n_metrics = len(df_task)
    task_summary[name] = {
        "count": n_metrics,
        "time": round(task_elapsed, 2),
    }
    if verbose:
        print(f" ✓ {name}: {n_metrics} results in {task_elapsed:.1f}s")
    return unified


def run_all_metrics(
    adata=None,
    atlas_path=None,
    atlas_name=None,
    file_dir=None,
    n_jobs=-1,
    verbose=True,
    seed=42,
    # Annotation params
    run_annotation=True,
    all_annot_metrics=True,
    # Clustering params
    run_clustering=True,
    all_cluster_metrics=True,
    # Dimred params
    run_dimred=True,
    all_dimred_metrics=True,
    low_dim_key="X_umap",
    high_dim_key="X",
    k_neighbors=30,
    n_samples=None,
    use_memmap=True,
    temp_dir=None,
    chunk_size=1000,
):
    """
    Unified metrics pipeline for CheckAtlas.

    The standalone entry point that runs ALL metric tasks (annotation,
    clustering, dimensionality reduction) on a single-cell atlas and produces
    a consolidated results CSV. Each task is isolated: a failure is logged
    (with traceback), recorded in the summary, and the next task still runs.

    Usage:
        >>> from checkatlas.metrics.metrics import run_all_metrics
        >>> results = run_all_metrics(atlas_path="my_atlas.h5ad")
        >>> # Or with a preloaded AnnData:
        >>> results = run_all_metrics(adata=my_adata, atlas_name="my_atlas")

    Args:
        adata (AnnData, optional): Preloaded AnnData object. If None, loads
            from atlas_path.
        atlas_path (str, optional): Path to .h5ad atlas file. Used if adata
            is None.
        atlas_name (str, optional): Name for the atlas (used in output). If
            None, derived from the atlas_path filename, or "atlas" when adata
            is given.
        file_dir (str, optional): Directory for output CSV (created if
            missing). Defaults to the current directory.
        n_jobs (int): Number of parallel jobs (-1 = all cores).
        verbose (bool): Whether to print progress and summary.
        seed (int): Random seed for reproducibility.
        run_annotation (bool): Whether to run annotation metrics.
        all_annot_metrics (bool): If True, run all annotation metrics.
        run_clustering (bool): Whether to run clustering metrics.
        all_cluster_metrics (bool): If True, run all clustering metrics.
        run_dimred (bool): Whether to run dimensionality reduction metrics.
        all_dimred_metrics (bool): If True, run all dimred metrics (passes
            METRICS_DIMRED); otherwise ``metric_list=None`` is passed.
        low_dim_key (str): Key in adata.obsm for the low-dimensional
            embedding. NOTE: currently not forwarded to cal_dimred, which
            selects the embeddings itself; kept for backward compatibility.
        high_dim_key (str): Key for high-dimensional data ('X' = adata.X).
        k_neighbors (int): Number of neighbors for kNN-based metrics.
        n_samples (int, optional): Number of samples for dimred metrics.
        use_memmap (bool): Whether to use memory-mapped files for large
            distance matrices. Forwarded only if cal_dimred accepts it.
        temp_dir (str, optional): Directory for temporary memmap files.
            Forwarded only if given and cal_dimred accepts it.
        chunk_size (int): Chunk size for batched distance computation.
            Forwarded only if cal_dimred accepts it.

    Returns:
        pd.DataFrame: Consolidated results with columns:
            [Atlas Name, Task, Metric Name, Input 1, Input 2, Value, Time (s)]

    Raises:
        ValueError: If neither ``adata`` nor ``atlas_path`` is given.
    """
    import gc
    import inspect

    overall_start = time.time()
    # ────────────────────────────────────────────────────────────────────
    # 1. Load atlas
    # ────────────────────────────────────────────────────────────────────
    if adata is None and atlas_path is not None:
        # Heavy import, only needed when reading the atlas from disk.
        import scanpy as sc

        if verbose:
            print(f"Loading atlas from: {atlas_path}")
        adata = sc.read_h5ad(atlas_path)
        if atlas_name is None:
            atlas_name = os.path.splitext(os.path.basename(atlas_path))[0]
    elif adata is not None:
        if atlas_name is None:
            atlas_name = "atlas"
    else:
        raise ValueError("Provide either `adata` (AnnData) or `atlas_path` (str).")

    if verbose:
        print(f"\n{'='*70}")
        print(" CheckAtlas — Unified Metrics Pipeline")
        print(f" Atlas: {atlas_name}")
        print(f" Cells: {adata.n_obs:,} | Genes: {adata.n_vars:,}")
        print(" Tasks: ", end="")
        tasks = []
        if run_annotation:
            tasks.append("Annotation")
        if run_clustering:
            tasks.append("Clustering")
        if run_dimred:
            tasks.append("Dimred")
        print(" → ".join(tasks))
        print(f"{'='*70}\n")

    # Set output directory
    if file_dir is None:
        file_dir = os.getcwd()
    else:
        os.makedirs(file_dir, exist_ok=True)

    # Consolidated results
    all_results = []
    task_summary = {}

    # ────────────────────────────────────────────────────────────────────
    # 2. Run Annotation Metrics
    # ────────────────────────────────────────────────────────────────────
    if run_annotation:
        unified = _run_metric_task(
            "Annotation",
            "TASK 1/3: Annotation Metrics",
            lambda: cal_annot(
                adata,
                atlas_name=atlas_name,
                all=all_annot_metrics,
                file_dir=file_dir,
            ),
            _annotation_to_unified,
            task_summary,
            verbose,
        )
        if unified is not None:
            all_results.append(unified)

    # ────────────────────────────────────────────────────────────────────
    # 3. Run Clustering Metrics
    # ────────────────────────────────────────────────────────────────────
    if run_clustering:
        unified = _run_metric_task(
            "Clustering",
            "TASK 2/3: Clustering Metrics",
            lambda: cal_cluster(
                adata,
                atlas_name=atlas_name,
                all_metrics=all_cluster_metrics,
                file_dir=file_dir,
                n_jobs=n_jobs,
                verbose=verbose,
                seed=seed,
            ),
            _cluster_to_unified,
            task_summary,
            verbose,
        )
        if unified is not None:
            all_results.append(unified)

    # ────────────────────────────────────────────────────────────────────
    # 4. Run Dimensionality Reduction Metrics
    # ────────────────────────────────────────────────────────────────────
    if run_dimred:

        def _compute_dimred():
            dimred_kwargs = {
                "atlas_name": atlas_name,
                "high_dim_key": high_dim_key,
                "metric_list": METRICS_DIMRED if all_dimred_metrics else None,
                "k_neighbors": k_neighbors,
                "n_samples": n_samples,
                "seed": seed,
                "n_jobs": n_jobs,
                "file_dir": file_dir,
                "verbose": verbose,
            }
            # use_memmap / temp_dir / chunk_size were documented but silently
            # dropped. Forward them only when cal_dimred declares them so this
            # stays compatible with cal_dimred signatures that lack them.
            try:
                accepted = inspect.signature(cal_dimred).parameters
            except (TypeError, ValueError):
                accepted = {}
            if "use_memmap" in accepted:
                dimred_kwargs["use_memmap"] = use_memmap
            if "chunk_size" in accepted:
                dimred_kwargs["chunk_size"] = chunk_size
            if temp_dir is not None and "temp_dir" in accepted:
                dimred_kwargs["temp_dir"] = temp_dir
            return cal_dimred(adata, **dimred_kwargs)

        unified = _run_metric_task(
            "Dimred",
            "TASK 3/3: Dimensionality Reduction Metrics",
            _compute_dimred,
            lambda df: _dimred_to_unified(df, atlas_name),
            task_summary,
            verbose,
        )
        if unified is not None:
            all_results.append(unified)

    # ────────────────────────────────────────────────────────────────────
    # 5. Consolidate & Save
    # ────────────────────────────────────────────────────────────────────
    if all_results:
        df_all = pd.concat(all_results, ignore_index=True)
    else:
        df_all = pd.DataFrame(
            columns=[
                "Atlas Name",
                "Task",
                "Metric Name",
                "Input 1",
                "Input 2",
                "Value",
                "Time (s)",
            ]
        )

    # Save unified CSV
    if not df_all.empty:
        filename = os.path.join(file_dir, f"checkatlas_all_metrics_{atlas_name}.csv")
        df_all.to_csv(filename, index=False)
        if verbose:
            print(f"\nSaved unified results to: {filename}")
        logger.info(f"Saved unified metrics to {filename}")

    # ────────────────────────────────────────────────────────────────────
    # 6. Print Summary
    # ────────────────────────────────────────────────────────────────────
    overall_elapsed = time.time() - overall_start
    if verbose:
        print(f"\n{'='*70}")
        print(" CheckAtlas — Results Summary")
        print(f"{'='*70}")
        print(f" {'Task':<20} {'Metrics':>10} {'Time (s)':>12} {'Status':>10}")
        print(f" {'─'*52}")
        for task_name, info in task_summary.items():
            status = "✗ Error" if "error" in info else "✓ Done"
            print(
                f" {task_name:<20} {info['count']:>10} {info['time']:>12.2f} {status:>10}"
            )
        print(f" {'─'*52}")
        total_count = sum(info["count"] for info in task_summary.values())
        print(f" {'TOTAL':<20} {total_count:>10} {overall_elapsed:>12.2f}")
        print(f"{'='*70}")
        if not df_all.empty:
            print(f"\n Unique metrics computed: {df_all['Metric Name'].nunique()}")
            print(f" Total measurements: {len(df_all)}")
            # Top 5 slowest metrics
            if "Time (s)" in df_all.columns:
                # nlargest() raises on object dtype; coerce so a stray
                # non-numeric timing cannot crash the summary after all the
                # work (and the CSV) is already done.
                times = pd.to_numeric(df_all["Time (s)"], errors="coerce")
                slowest = df_all.assign(**{"Time (s)": times}).nlargest(
                    5, "Time (s)"
                )[["Task", "Metric Name", "Time (s)"]]
                if not slowest.empty:
                    print("\n Slowest metrics:")
                    for _, row in slowest.iterrows():
                        print(
                            f" {row['Task']}/{row['Metric Name']}: {row['Time (s)']:.3f}s"
                        )
        print()
    gc.collect()
    return df_all
checkatlas.utils.folders¶
checkatlas.utils.folders
¶
checkatlas_folders(path: str) -> None
¶
Check whether the checkatlas folders exist in path.
Create them if needed.
All folders are listed in DICT_FOLDER.
| Parameters: |
|
|---|
| Returns: |
|
|---|
Source code in checkatlas/utils/folders.py
def checkatlas_folders(path: str) -> None:
    """Check whether the checkatlas folders exist in *path*; create them if needed.

    The working directory returned by get_workingdir(path) is created first,
    then one sub-folder per key of DICT_FOLDER.

    Args:
        path (str): Search path for atlas given by user

    Returns:
        None: None
    """
    global_path = get_workingdir(path)
    # makedirs(exist_ok=True) removes the exists()/mkdir() race when several
    # processes share a working directory, and also creates missing parents.
    os.makedirs(global_path, exist_ok=True)
    # NOTE(review): folders are created by DICT_FOLDER *key*, whereas
    # get_folder() joins DICT_FOLDER[key] (the value) — confirm keys and values
    # agree, otherwise get_folder() points at folders never created here.
    for key_folder in DICT_FOLDER:
        temp_path = os.path.join(global_path, key_folder)
        if not os.path.isdir(temp_path):
            logger.debug(f"Create folder: {temp_path}")
        os.makedirs(temp_path, exist_ok=True)
get_folder(path: str, key_folder: str) -> str
¶
Get the folder path given the search path and the folder key in DICT_FOLDER.
| Parameters: |
|
|---|
| Returns: |
|
|---|
Source code in checkatlas/utils/folders.py
def get_folder(path: str, key_folder: str) -> str:
    """Build the path of a checkatlas sub-folder.

    Args:
        path (str): Search path for atlas given by user
        key_folder (str): key folder in the DICT_FOLDER
            example: ANNDATA, SUMMARY, UMAP

    Returns:
        str: ``DICT_FOLDER[key_folder]`` joined onto the working directory
        derived from *path*.
    """
    working_dir = get_workingdir(path)
    folder_name = DICT_FOLDER[key_folder]
    return os.path.join(working_dir, folder_name)
get_workingdir(path: str) -> str
¶
Return the working directory, i.e. the search path joined with the working folder (checkatlas_files/).
| Parameters: |
|
|---|
| Returns: |
|
|---|
Source code in checkatlas/utils/folders.py
def get_workingdir(path: str) -> str:
    """Return the checkatlas working directory for a search path.

    The working directory is *path* joined with WORKING_DIR.

    Args:
        path (str): Search path for atlas given by user

    Returns:
        str: os.path.join(path, WORKING_DIR)
    """
    working_dir = os.path.join(path, WORKING_DIR)
    return working_dir
checkatlas.utils.files¶
checkatlas.utils.files
¶
checkatlas.utils.checkatlas_arguments¶
checkatlas.utils.checkatlas_arguments
¶
MAX_N_JOBS = 48
module-attribute
¶
Upper limit on the number of CPU threads any CheckAtlas process may
consume. Capped at 48 so that on an 80-core workstation at least 32
threads remain free for other users / pipelines. See
:func:`cap_n_jobs` for the enforcement helper.
cap_n_jobs(n_jobs)
¶
Cap `n_jobs` at :data:`MAX_N_JOBS` and treat `None` / `-1`
as "use the system default" by resolving to the cap.
Rules:
- `None` → :data:`MAX_N_JOBS` (covers the legacy path where `getattr(args, "n_jobs", -1)` returned the fallback).
- `-1` → :data:`MAX_N_JOBS` (sklearn / joblib convention: "all available cores", capped for politeness).
- values > :data:`MAX_N_JOBS` → :data:`MAX_N_JOBS`.
- any positive integer ≤ :data:`MAX_N_JOBS` → unchanged.
- anything else (non-integer, zero, negative) → :data:`MAX_N_JOBS`.
Source code in checkatlas/utils/checkatlas_arguments.py
def cap_n_jobs(n_jobs):
    """Cap *n_jobs* at :data:`MAX_N_JOBS` and treat ``None`` / ``-1``
    as "use the system default" by resolving to the cap.

    Rules:

    * ``None`` → :data:`MAX_N_JOBS` (covers the legacy path where
      ``getattr(args, "n_jobs", -1)`` returned the fallback).
    * ``-1`` → :data:`MAX_N_JOBS` (sklearn / joblib convention:
      "all available cores", capped for politeness).
    * values > :data:`MAX_N_JOBS` → :data:`MAX_N_JOBS`.
    * any positive integer ≤ :data:`MAX_N_JOBS` → unchanged.
    * anything ``int()`` cannot convert (e.g. ``"abc"``, NaN, ±infinity),
      or that converts to zero / a negative number → :data:`MAX_N_JOBS`.

    Values accepted by ``int()`` are converted with it, so numeric strings
    (``"8"``) work and non-integral floats are truncated (``3.7`` → ``3``).

    Returns:
        int: A thread count in ``[1, MAX_N_JOBS]``.
    """
    if n_jobs is None or n_jobs == -1:
        return MAX_N_JOBS
    try:
        n_jobs = int(n_jobs)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")) / int(float("-inf")) used to escape
        # and crash the caller instead of falling back to the cap.
        return MAX_N_JOBS
    if n_jobs < 1 or n_jobs > MAX_N_JOBS:
        return MAX_N_JOBS
    return n_jobs
get_version()
¶
Get the checkatlas version from the checkatlas/VERSION file. :return: checkatlas version string
Source code in checkatlas/utils/checkatlas_arguments.py
def get_version():
    """
    Get the checkatlas version from the checkatlas/VERSION file.

    Only the first line of VERSION is used, with surrounding whitespace
    (including the trailing newline) stripped. An empty file yields "".

    :return: checkatlas version
    """
    version_file = files(__package__).joinpath("VERSION")
    # read_text() works for any importlib Traversable (including zipped
    # installs), whereas open() requires a real filesystem path.
    content = version_file.read_text(encoding="utf-8")
    lines = content.splitlines()
    # readline() used to keep the trailing "\n" in the returned version.
    return lines[0].strip() if lines else ""