checkatlas

checkatlas.check

checkatlas.check

atlas_info = {'Atlas_name': 'pbmc_6k_v1_v1', 'Atlas_type': 'Cellranger < v3', 'Atlas_extension': '.mtx', 'Atlas_path': '/data/analysis/data_becavin/checkatlas_test/tuto/data5/pbmc_6k_v1_v1/outs/filtered_matrices_mex/hg19/matrix.mtx'} module-attribute

atlas_info = {'Atlas_name': 'pbmc_5k_v3_v3', 'Atlas_type': 'Cellranger >= v3', 'Atlas_extension': '.h5', 'Atlas_path': '/data/analysis/data_becavin/' 'checkatlas_test/tuto/data5/' 'pbmc_5k_v3_v3/outs/' '5k_pbmc_v3_filtered_feature_bc_matrix.h5'}

atlas_info = {'Atlas_name': 'pbmc_5k_v3_v7', 'Atlas_type': 'Cellranger >= v3', 'Atlas_extension': '.h5', 'Atlas_path': '/data/analysis/data_becavin/' 'checkatlas_test/tuto/data5/' 'pbmc_5k_v3_v7/outs/' 'SC3pv3_GEX_Human_PBMC_filtered_feature_bc_matrix.h5'}

atlas_info = {'Atlas_name': 'pbmc_3k_multiome', 'Atlas_type': 'Cellranger >= v3', 'Atlas_extension': '.h5', 'Atlas_path': '/data/analysis/data_becavin/' 'checkatlas_test/tuto/data5/' 'pbmc_3k_multiome/outs/' 'pbmc_unsorted_3k_filtered_feature_bc_matrix.h5'}

checkatlas.atlas

checkatlas.atlas

atlas_sampling(df_annot: pd.DataFrame, type_df: str, args: argparse.Namespace) -> pd.DataFrame

If args.plot_celllimit != 0 and args.plot_celllimit < len(df_annot), the atlas QC table is randomly sampled down to args.plot_celllimit cells for MultiQC.

Parameters:
  • df_annot (DataFrame) –

    Table to sample

  • type_df (str) –

    type of table

  • args (Namespace) –

    arguments of checkatlas workflow

Returns:
  • DataFrame

    pd.DataFrame: Sampled QC table

Source code in checkatlas/atlas.py
def atlas_sampling(
    df_annot: pd.DataFrame, type_df: str, args: argparse.Namespace
) -> pd.DataFrame:
    """
    Randomly down-sample a QC table before it is displayed in MultiQC.

    Sampling only happens when ``args.plot_celllimit`` is non-zero and
    strictly smaller than the number of rows of ``df_annot``; in every other
    case the very same table object is returned untouched.

    Args:
        df_annot (pd.DataFrame): Table to sample
        type_df (str): type of table (only used in log messages)
        args (argparse.Namespace): arguments of checkatlas workflow

    Returns:
        pd.DataFrame: Sampled QC table (or ``df_annot`` itself when no
        sampling is needed)
    """
    cell_limit = args.plot_celllimit
    # Guard clause: a limit of 0 disables sampling, and a limit at or above
    # the table size would not remove anything.
    if cell_limit == 0 or cell_limit >= len(df_annot):
        return df_annot

    logger.debug(f"Sample {type_df} table with {len(df_annot)} cells")
    sampled = df_annot.sample(cell_limit)
    logger.debug(f"{type_df} table sampled to {len(sampled)} cells")
    return sampled
clean_scanpy_atlas(adata: AnnData, atlas_info: dict) -> AnnData

Clean the Scanpy object to be sure to get all information out of it

  • Make var names unique
  • Make var unique for Raw matrix
  • If obs columns whose name matches an OBS_CLUSTERS entry hold integer values (e.g. int32/int64), convert them to categorical
Parameters:
  • adata (AnnData) –

    atlas to analyse

  • atlas_info (dict) –

    info on the atlas

Returns:
  • AnnData( AnnData ) –

    cleaned atlas

Source code in checkatlas/atlas.py
def clean_scanpy_atlas(adata: AnnData, atlas_info: dict) -> AnnData:
    """
    Clean the Scanpy object to be sure to get all information out of it

    - Make var names unique: ``adata.var_names_make_unique()`` is tried up
      to twice; if names are still duplicated, every var name is suffixed
      with ``_<position>``.
    - Make var names unique for the Raw matrix (``_<position>`` suffix).
    - If obs columns whose name contains one of ``OBS_CLUSTERS`` hold
      integer values, convert them to categorical.

    ``adata`` is modified in place and also returned.

    Args:
        adata (AnnData): atlas to analyse
        atlas_info (dict): info on the atlas

    Returns:
        AnnData: cleaned atlas
    """
    logger.debug(f"Clean scanpy: {atlas_info[check.ATLAS_NAME_KEY]}")

    def _all_unique(names) -> bool:
        # True when no name appears twice
        return len(set(names)) == len(names)

    # Make var names unique
    if _all_unique(adata.var_names):
        logger.debug("Var names unique")
    else:
        # A second pass of var_names_make_unique() sometimes resolves
        # collisions with names generated by the first pass.
        for _attempt in range(2):
            logger.debug(
                "Var names not unique, ran : adata.var_names_make_unique()"
            )
            adata.var_names_make_unique()
            if _all_unique(adata.var_names):
                logger.debug("Var names unique")
                break
        else:
            # Still not unique: create unique var_names "by hand"
            logger.debug(
                "Var names not unique, append position index to var names"
            )
            adata.var.index = [
                f"{name}_{i}" for i, name in enumerate(adata.var_names)
            ]
            if _all_unique(adata.var_names):
                logger.debug("Var names unique")

    # Make var unique for Raw matrix
    if adata.raw is not None:
        if _all_unique(adata.raw.var_names):
            logger.debug("Var names for Raw unique")
        else:
            logger.debug("Var names for Raw not unique")
            adata.raw.var.index = [
                f"{name}_{i}" for i, name in enumerate(adata.raw.var_names)
            ]
            if _all_unique(adata.raw.var_names):
                logger.debug("Var names for Raw unique")

    # Cluster labels stored as integers (any width, signed or unsigned) are
    # converted to categorical so they are treated as labels, not numbers.
    for obs_key in adata.obs_keys():
        is_cluster_key = any(pattern in obs_key for pattern in OBS_CLUSTERS)
        if is_cluster_key and pd.api.types.is_integer_dtype(adata.obs[obs_key]):
            adata.obs[obs_key] = pd.Categorical(adata.obs[obs_key])
    return adata
col_annotation_pred(adata: AnnData, min_score: float = 0.5, return_all: bool = False, max_results: int = 5) -> Optional[List[str]]

Detect predicted/cluster annotation columns in AnnData object.

This function identifies columns containing cluster labels or automated cell type predictions (e.g., leiden, louvain, seurat_clusters, celltypist).

Parameters:
  • adata (AnnData) –

    Scanpy AnnData object to analyze

  • min_score (float, default: 0.5 ) –

    Minimum confidence score threshold (0-1). Default: 0.5

  • return_all (bool, default: False ) –

    If True, return with scores. Default: False

  • max_results (int, default: 5 ) –

    Maximum number of columns to return. Default: 5

Returns:
  • Optional[List[str]]

    List[str] or List[Tuple[str, float]] or None: - If return_all=False: List of column names sorted by confidence - If return_all=True: List of (column_name, score) tuples - None if no columns found

Example

import scanpy as sc import checkatlas.atlas as atlas adata = sc.read_h5ad("atlas.h5ad") pred_cols = atlas.col_annotation_pred(adata) print(f"Predicted columns: {pred_cols}")

Get with scores

pred_with_scores = atlas.col_annotation_pred(adata, return_all=True) for col, score in pred_with_scores: ... print(f"{col}: {score:.3f}")

Source code in checkatlas/atlas.py
def col_annotation_pred(
    adata: AnnData,
    min_score: float = 0.5,
    return_all: bool = False,
    max_results: int = 5,
) -> Optional[List[str]]:
    """
    Detect predicted/cluster annotation columns in AnnData object.

    Runs ``CheckAtlasColumnDetector`` on ``adata`` and keeps the first
    ``max_results`` entries of its predicted-annotation candidates (columns
    holding cluster labels or automated cell type predictions, e.g. leiden,
    louvain, seurat_clusters, celltypist).

    Args:
        adata (AnnData): Scanpy AnnData object to analyze
        min_score (float): Minimum confidence score threshold (0-1) for
            predicted columns. Default: 0.5
        return_all (bool): If True, return (column, score) pairs.
            Default: False
        max_results (int): Maximum number of columns to return. Default: 5

    Returns:
        List[str] or List[Tuple[str, float]] or None:
            - If return_all=False: List of column names, in detector order
            - If return_all=True: List of (column_name, score) tuples
            - None if no columns found

    Example:
        >>> import scanpy as sc
        >>> import checkatlas.atlas as atlas
        >>> adata = sc.read_h5ad("atlas.h5ad")
        >>> pred_cols = atlas.col_annotation_pred(adata)
        >>> print(f"Predicted columns: {pred_cols}")
        >>>
        >>> # Get with scores
        >>> pred_with_scores = atlas.col_annotation_pred(adata, return_all=True)
        >>> for col, score in pred_with_scores:
        ...     print(f"{col}: {score:.3f}")
    """
    # Only the predicted threshold is caller-driven; the reference threshold
    # is fixed because reference results are not used here.
    detection = CheckAtlasColumnDetector(adata).detect_all_parameters(
        min_reference_score=0.3, min_predicted_score=min_score
    )
    top_predicted = detection["annotation"]["predicted"][:max_results]

    if not top_predicted:
        return None
    if return_all:
        return top_predicted
    return [column for column, _score in top_predicted]
col_annotation_ref(adata: AnnData, min_score: float = 0.5, return_all: bool = False) -> Optional[str]

Detect reference (ground truth) annotation column in AnnData object.

This function uses intelligent semantic and statistical analysis to identify the most likely reference/ground truth cell type annotation column.

Parameters:
  • adata (AnnData) –

    Scanpy AnnData object to analyze

  • min_score (float, default: 0.5 ) –

    Minimum confidence score threshold (0-1). Default: 0.5

  • return_all (bool, default: False ) –

    If True, return list of all candidates with scores. Default: False

Returns:
  • Optional[str]

    str or List[Tuple[str, float]] or None: - If return_all=False: Best reference column name, or None if none found - If return_all=True: List of (column_name, score) tuples sorted by score

Example

import scanpy as sc import checkatlas.atlas as atlas adata = sc.read_h5ad("atlas.h5ad") ref_col = atlas.col_annotation_ref(adata) print(f"Reference column: {ref_col}")

Get all candidates with scores

all_refs = atlas.col_annotation_ref(adata, return_all=True) for col, score in all_refs: ... print(f"{col}: {score:.3f}")

Source code in checkatlas/atlas.py
def col_annotation_ref(
    adata: AnnData, min_score: float = 0.5, return_all: bool = False
) -> Optional[str]:
    """
    Detect reference (ground truth) annotation column in AnnData object.

    Delegates to ``CheckAtlasColumnDetector`` and reads its reference
    annotation candidates.

    Args:
        adata (AnnData): Scanpy AnnData object to analyze
        min_score (float): Minimum confidence score threshold (0-1) for
            reference columns. Default: 0.5
        return_all (bool): If True, return list of all candidates with
            scores. Default: False

    Returns:
        str or List[Tuple[str, float]] or None:
            - If return_all=False: name of the first (best) reference
              candidate, or None if none found
            - If return_all=True: the full candidate list of
              (column_name, score) tuples, possibly empty

    Example:
        >>> import scanpy as sc
        >>> import checkatlas.atlas as atlas
        >>> adata = sc.read_h5ad("atlas.h5ad")
        >>> ref_col = atlas.col_annotation_ref(adata)
        >>> print(f"Reference column: {ref_col}")
        >>>
        >>> # Get all candidates with scores
        >>> all_refs = atlas.col_annotation_ref(adata, return_all=True)
        >>> for col, score in all_refs:
        ...     print(f"{col}: {score:.3f}")
    """
    detection = CheckAtlasColumnDetector(adata).detect_all_parameters(
        min_reference_score=min_score, min_predicted_score=0.3
    )
    reference_hits = detection["annotation"]["reference"]

    if return_all:
        return reference_hits
    if not reference_hits:
        return None
    best_hit = reference_hits[0]
    return best_hit[0]
col_cluster(adata: AnnData, min_score: float = 0.5, return_all: bool = False, max_results: int = 5) -> Optional[List[str]]

Detect cluster label columns in AnnData object.

Uses the dedicated cluster-label detector (Leiden, Louvain, k‑means, Seurat clusters, PhenoGraph, etc.) which applies semantic 70 % + statistical 30 % scoring tuned for algorithmic clustering outputs.

Parameters:
  • adata (AnnData) –

    Scanpy AnnData object to analyze

  • min_score (float, default: 0.5 ) –

    Minimum confidence score threshold (0-1). Default: 0.5

  • return_all (bool, default: False ) –

    If True, return with scores. Default: False

  • max_results (int, default: 5 ) –

    Maximum number of columns to return. Default: 5

Returns:
  • Optional[List[str]]

    List[str] or List[Tuple[str, float]] or None: - If return_all=False: List of column names sorted by confidence - If return_all=True: List of (column_name, score) tuples - None if no columns found

Example

import scanpy as sc import checkatlas.atlas as atlas adata = sc.read_h5ad("atlas.h5ad") cluster_cols = atlas.col_cluster(adata) print(f"Cluster columns: {cluster_cols}")

Source code in checkatlas/atlas.py
def col_cluster(
    adata: AnnData,
    min_score: float = 0.5,
    return_all: bool = False,
    max_results: int = 5,
) -> Optional[List[str]]:
    """
    Detect cluster label columns in AnnData object.

    Delegates to ``CheckAtlasColumnDetector`` (cluster-label detection for
    Leiden, Louvain, k-means, Seurat clusters, PhenoGraph, etc.) and keeps
    the first ``max_results`` cluster-label candidates.

    Args:
        adata (AnnData): Scanpy AnnData object to analyze
        min_score (float): Minimum confidence score threshold (0-1).
            Default: 0.5
        return_all (bool): If True, return (column, score) pairs.
            Default: False
        max_results (int): Maximum number of columns to return. Default: 5

    Returns:
        List[str] or List[Tuple[str, float]] or None:
            - If return_all=False: List of column names, in detector order
            - If return_all=True: List of (column_name, score) tuples
            - None if no columns found

    Example:
        >>> import scanpy as sc
        >>> import checkatlas.atlas as atlas
        >>> adata = sc.read_h5ad("atlas.h5ad")
        >>> cluster_cols = atlas.col_cluster(adata)
        >>> print(f"Cluster columns: {cluster_cols}")
    """
    detection = CheckAtlasColumnDetector(adata).detect_all_parameters(
        min_cluster_score=min_score
    )
    top_clusters = detection["clustering"]["cluster_labels"][:max_results]

    if not top_clusters:
        return None
    if return_all:
        return top_clusters
    return [column for column, _score in top_clusters]
col_dimred(adata: AnnData, return_all: bool = False, max_results: int = 10) -> Optional[List[str] | List[dict[str, Any]]]

Detect dimensionality reduction representations in AnnData.obsm.

This function identifies embedding keys like X_pca, X_umap, X_tsne, etc.

Parameters:
  • adata (AnnData) –

    Scanpy AnnData object to analyze

  • return_all (bool, default: False ) –

    If True, return with metadata. Default: False

  • max_results (int, default: 10 ) –

    Maximum number of representations to return. Default: 10

Returns:
  • Optional[List[str] | List[dict[str, Any]]]

    List[str] or List[Dict] or None: - If return_all=False: List of obsm keys (e.g., ['X_umap', 'X_pca']) - If return_all=True: List of dicts with 'key', 'shape', 'n_components' - None if no representations found

Example

import scanpy as sc import checkatlas.atlas as atlas adata = sc.read_h5ad("atlas.h5ad") dimred_keys = atlas.col_dimred(adata) print(f"Dimensionality reductions: {dimred_keys}")

Get with metadata

dimred_detailed = atlas.col_dimred(adata, return_all=True) for emb in dimred_detailed: ... print(f"{emb['key']}: {emb['n_components']} components")

Source code in checkatlas/atlas.py
def col_dimred(
    adata: AnnData, return_all: bool = False, max_results: int = 10
) -> Optional[List[str] | List[dict[str, Any]]]:
    """
    Detect dimensionality reduction representations in AnnData.obsm.

    Delegates to ``CheckAtlasColumnDetector`` and keeps the first
    ``max_results`` embedding entries (e.g. X_pca, X_umap, X_tsne).

    Args:
        adata (AnnData): Scanpy AnnData object to analyze
        return_all (bool): If True, return with metadata. Default: False
        max_results (int): Maximum number of representations to return.
            Default: 10

    Returns:
        List[str] or List[Dict] or None:
            - If return_all=False: List of obsm keys (e.g., ['X_umap', 'X_pca'])
            - If return_all=True: List of dicts with 'key', 'shape',
              'n_components'
            - None if no representations found

    Example:
        >>> import scanpy as sc
        >>> import checkatlas.atlas as atlas
        >>> adata = sc.read_h5ad("atlas.h5ad")
        >>> dimred_keys = atlas.col_dimred(adata)
        >>> print(f"Dimensionality reductions: {dimred_keys}")
        >>>
        >>> # Get with metadata
        >>> dimred_detailed = atlas.col_dimred(adata, return_all=True)
        >>> for emb in dimred_detailed:
        ...     print(f"{emb['key']}: {emb['n_components']} components")
    """
    detection = CheckAtlasColumnDetector(adata).detect_all_parameters()
    found = detection["clustering"]["embeddings"][:max_results]

    if not found:
        return None
    if not return_all:
        return [obsm_key for obsm_key, _meta in found]

    # Expose only the metadata fields documented above
    return [
        dict(
            key=obsm_key,
            shape=meta["shape"],
            n_components=meta["n_components"],
        )
        for obsm_key, meta in found
    ]
create_anndata_table(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None

Create an HTML table listing all AnnData keys (obs columns, obsm, var, varm and uns). The HTML code makes every element of the table visible in MultiQC.

Parameters:
  • adata (AnnData) –

    atlas to analyse

  • atlas_info (dict) –

    info dict on the atlas

  • args (Namespace) –

    list of arguments from checkatlas workflow

Source code in checkatlas/atlas.py
def create_anndata_table(
    adata: AnnData, atlas_info: dict, args: argparse.Namespace
) -> None:
    """
    Create an HTML table listing the keys of every AnnData slot.

    A single row is written for the atlas, with one column per slot
    (``obs`` columns and ``obsm``/``var``/``varm``/``uns`` keys). Each cell
    holds the keys wrapped in ``<code>`` tags and separated by ``<br>`` so
    that all elements of the table are visible in MultiQC.

    Args:
        adata (AnnData): atlas to analyse
        atlas_info (dict): info dict on the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]

    logger.debug(f"Create Adata table for {atlas_name}")
    csv_path = files.get_file_path(
        atlas_name, folders.ANNDATA, check.TSV_EXTENSION, args.path
    )

    def _as_html_code_list(keys) -> str:
        # One <code> element per key, one key per line once rendered;
        # str() guards against non-string keys.
        return (
            "<code>"
            + "</code><br><code>".join(str(key) for key in keys)
            + "</code>"
        )

    # Build the row in one go: chained assignment (df[col][row] = ...) does
    # not write back into the frame under pandas Copy-on-Write.
    row = {
        "atlas_obs": _as_html_code_list(adata.obs.columns),
        "obsm": _as_html_code_list(adata.obsm_keys()),
        "var": _as_html_code_list(adata.var_keys()),
        "varm": _as_html_code_list(adata.varm_keys()),
        "uns": _as_html_code_list(adata.uns_keys()),
    }
    df_summary = pd.DataFrame([row], index=[atlas_name])
    # quoting=False evaluates to 0, i.e. csv.QUOTE_MINIMAL
    df_summary.to_csv(csv_path, index=False, quoting=False, sep="\t")
create_metric_annot(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None

Calculate all annotation metrics via the comprehensive cal_annot pipeline (ref-vs-pred, embedding-based, batch/integration, graph).

The pipeline auto-detects reference/predicted columns, embedding keys, and batch labels; runs every metric listed in METRICS_ANNOT; and writes results as tab-separated files in the annotation folder.

Parameters:
  • adata (AnnData) –

    atlas to analyse

  • atlas_info (dict) –

    info of the atlas

  • args (Namespace) –

    list of arguments from checkatlas workflow

Source code in checkatlas/atlas.py
def create_metric_annot(
    adata: AnnData, atlas_info: dict, args: argparse.Namespace
) -> None:
    """
    Run the ``cal_annot`` annotation-metric pipeline and save a wide table.

    Nothing is computed when ``--metric_annot none`` was requested. Otherwise
    ``metrics.cal_annot`` is called with ``all=True`` on the requested
    metrics, writing its own files to the annotation folder. When it returns
    a non-empty frame, the frame is pivoted to wide format and saved as a
    tab-separated file for the atlas.

    Args:
        adata (AnnData): atlas to analyse
        atlas_info (dict): info of the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    requested_metrics = args.metric_annot
    if requested_metrics == ["none"]:
        logger.info("Skipping annotation metrics (--metric_annot none)")
        return

    name = atlas_info[check.ATLAS_NAME_KEY]
    out_dir = folders.get_folder(args.path, folders.ANNOTATION)
    logger.info("Running full annotation pipeline for %s", name)

    context = _try_load_context(atlas_info, args)
    long_df = metrics.cal_annot(
        adata,
        atlas_name=name,
        metric_list=requested_metrics,
        all=True,
        file_dir=out_dir,
        n_jobs=_resolve_n_jobs(args),
        verbose=True,
        preprocess_context=context,
    )

    if long_df.empty:
        logger.warning("No annotation metrics calculated for %s", name)
        return

    tsv_path = files.get_file_path(
        name, folders.ANNOTATION, check.TSV_EXTENSION, args.path
    )
    wide = metrics._pivot_annot_to_wide(long_df, name)
    wide.to_csv(tsv_path, index=False, sep="\t")
    logger.info("Annotation metrics saved to %s", tsv_path)
create_metric_cluster(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None

Calculate all clustering metrics via the comprehensive cal_cluster pipeline. The pipeline auto-detects embeddings and cluster-label columns, runs every metric listed in METRICS_CLUST across all combinations, and writes results as a tab-separated file in the cluster folder.

Parameters:
  • adata (AnnData) –

    atlas to analyse

  • atlas_info (dict) –

    info of the atlas

  • args (Namespace) –

    list of arguments from checkatlas workflow

Source code in checkatlas/atlas.py
def create_metric_cluster(
    adata: AnnData, atlas_info: dict, args: argparse.Namespace
) -> None:
    """
    Run the ``cal_cluster`` clustering-metric pipeline and save a wide table.

    Nothing is computed when no clustering metric was requested or when
    ``--metric_cluster none`` was given. Otherwise ``metrics.cal_cluster``
    runs with a fixed seed (42), writing its own files to the cluster folder.
    When it returns a non-empty frame, the frame is pivoted to wide format
    and saved as a tab-separated file for the atlas.

    Args:
        adata (AnnData): atlas to analyse
        atlas_info (dict): info of the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    requested_metrics = args.metric_cluster
    if not requested_metrics:
        logger.info("Skipping clustering metrics (no metrics requested)")
        return
    if requested_metrics == ["none"]:
        logger.info("Skipping clustering metrics (--metric_cluster none)")
        return

    name = atlas_info[check.ATLAS_NAME_KEY]
    out_dir = folders.get_folder(args.path, folders.CLUSTER)
    logger.info("Running full clustering pipeline for %s", name)

    context = _try_load_context(atlas_info, args)
    long_df = metrics.cal_cluster(
        adata,
        atlas_name=name,
        metric_list=requested_metrics,
        file_dir=out_dir,
        n_jobs=_resolve_n_jobs(args),
        verbose=True,
        seed=42,
        preprocess_context=context,
    )

    if long_df.empty:
        logger.warning("No clustering metrics calculated for %s", name)
        return

    tsv_path = files.get_file_path(
        name, folders.CLUSTER, check.TSV_EXTENSION, args.path
    )
    wide = metrics._pivot_cluster_to_wide(long_df, name)
    wide.to_csv(tsv_path, index=False, sep="\t")
    logger.info("Clustering metrics saved to %s", tsv_path)
create_metric_dimred(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None

Calculate all dimred metrics via the comprehensive cal_dimred pipeline. The pipeline auto-detects all .obsm embedding keys, compares each against adata.X as the high‑dimensional reference, runs every metric listed by --metric_dimred, and writes results as a tab‑separated file in the dimred folder compatible with MultiQC.

Parameters:
  • adata (AnnData) –

    atlas to analyse

  • atlas_info (dict) –

    info of the atlas

  • args (Namespace) –

    list of arguments from checkatlas workflow

Source code in checkatlas/atlas.py
def create_metric_dimred(
    adata: AnnData, atlas_info: dict, args: argparse.Namespace
) -> None:
    """
    Run the ``cal_dimred`` dimensionality-reduction metric pipeline.

    Nothing is computed when ``--metric_dimred none`` was requested.
    Otherwise ``metrics.cal_dimred`` runs with a fixed seed (42) and a
    per-atlas cache directory ``<temp>/<atlas_name>/dimred``. When it returns
    a non-empty frame, the frame is pivoted to wide format and saved as a
    tab-separated file in the dimred folder for MultiQC.

    Args:
        adata (AnnData): atlas to analyse
        atlas_info (dict): info of the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    requested_metrics = args.metric_dimred
    if requested_metrics == ["none"]:
        logger.info("Skipping dimred metrics (--metric_dimred none)")
        return

    name = atlas_info[check.ATLAS_NAME_KEY]
    logger.info("Running full dimred pipeline for %s", name)

    # The returned path is not used here; the call is kept in case
    # get_folder has side effects such as creating the folder (unverified).
    folders.get_folder(args.path, folders.DIMRED)
    # Per-atlas persistent cache under temp/
    temp_root = folders.get_folder(args.path, folders.TEMP)
    cache_dir = os.path.join(temp_root, name, "dimred")

    long_df = metrics.cal_dimred(
        adata,
        atlas_name=name,
        metric_list=requested_metrics,
        file_dir=cache_dir,
        use_cache=True,
        n_jobs=_resolve_n_jobs(args),
        verbose=True,
        seed=42,
    )

    if long_df.empty:
        logger.warning("No dimred metrics calculated for %s", name)
        return

    tsv_path = files.get_file_path(
        name, folders.DIMRED, check.TSV_EXTENSION, args.path
    )
    wide = metrics._pivot_dimred_to_wide(long_df, name)
    wide.to_csv(tsv_path, index=False, sep="\t")
    logger.info("Dimred metrics saved to %s", tsv_path)
create_qc_plots(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None

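Create the atlas QC violin plot. Mitochondrial (MT-) and ribosomal (RPS/RPL) genes are flagged, per-cell QC metrics are computed (genes per cell, total counts, percentage of mitochondrial and ribosomal counts), and the metrics are saved as a violin plot.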

Parameters:
  • adata (AnnData) –

    atlas to analyse

  • atlas_info (dict) –

    info on the atlas

  • args (Namespace) –

    list of arguments from checkatlas workflow

Source code in checkatlas/atlas.py
def create_qc_plots(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None:
    """
    Save a QC violin plot for the atlas.

    Genes starting with "MT-" are flagged as mitochondrial and genes starting
    with "RPS"/"RPL" as ribosomal in ``adata.var``. ``sc.pp.calculate_qc_metrics``
    then adds per-cell metrics to ``adata`` in place. A multi-panel violin
    plot of n_genes_by_counts, total_counts, pct_counts_mt and
    pct_counts_ribo is saved under the working directory.

    Args:
        adata (AnnData): atlas to analyse
        atlas_info (dict): info on the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    name = atlas_info[check.ATLAS_NAME_KEY]
    sc.settings.figdir = folders.get_workingdir(args.path)
    sc.set_figure_params(dpi_save=80)
    # Passed as scanpy's `save` suffix (appended to the plot's default name)
    save_suffix = os.sep + name + check.QC_FIG_EXTENSION
    logger.debug(f"Create QC violin plot for {name}")

    gene_names = adata.var_names
    adata.var["mt"] = gene_names.str.startswith("MT-")  # mitochondrial
    adata.var["ribo"] = gene_names.str.startswith(("RPS", "RPL"))  # ribosomal
    sc.pp.calculate_qc_metrics(
        adata,
        qc_vars=["mt", "ribo"],
        percent_top=None,
        log1p=False,
        inplace=True,
    )

    violin_keys = [
        "n_genes_by_counts",
        "total_counts",
        "pct_counts_mt",
        "pct_counts_ribo",
    ]
    sc.pl.violin(
        adata,
        violin_keys,
        jitter=0.4,
        multi_panel=True,
        show=False,
        save=save_suffix,
    )
create_qc_tables(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None

Create the atlas QC table. Mitochondrial (MT-) and ribosomal (RPS/RPL) genes are flagged, per-cell QC metrics are computed, cells are ranked by each metric, and the (optionally sampled) table is written for MultiQC.

Parameters:
  • adata (AnnData) –

    atlas to analyse

  • atlas_info (dict) –

    info on the atlas

  • args (Namespace) –

    list of arguments from checkatlas workflow

Source code in checkatlas/atlas.py
def create_qc_tables(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None:
    """
    Compute per-cell QC metrics and write them as a TSV table for MultiQC.

    Genes starting with "MT-" (mitochondrial) and "RPS"/"RPL" (ribosomal)
    are flagged in ``adata.var``. Each flag is used as a QC variable only
    when at least one such gene exists. ``sc.pp.calculate_qc_metrics`` then
    fills ``adata.obs`` in place.

    For every selected obs column (``get_viable_obs_qc``) except the cell
    index, a ``cellrank_<column>`` rank is added (1 = highest value). The
    table is down-sampled with ``atlas_sampling``, and the cell index column
    is renumbered before writing.

    Args:
        adata (AnnData): atlas to analyse
        atlas_info (dict): info on the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    qc_path = files.get_file_path(atlas_name, folders.QC, check.TSV_EXTENSION, args.path)
    logger.debug(f"Create QC tables for {atlas_name}")
    qc_genes = []
    # mitochondrial genes
    adata.var["mt"] = adata.var_names.str.startswith("MT-")
    if adata.var["mt"].any():
        qc_genes.append("mt")
        logger.debug(f"Mitochondrial genes in {atlas_name} for QC")
    else:
        logger.debug(f"No mitochondrial genes in {atlas_name} for QC")
    # ribosomal genes (test the "ribo" flag itself, not "mt")
    adata.var["ribo"] = adata.var_names.str.startswith(("RPS", "RPL"))
    if adata.var["ribo"].any():
        qc_genes.append("ribo")
        logger.debug(f"Ribosomal genes in {atlas_name} for QC")
    else:
        logger.debug(f"No ribosomal genes in {atlas_name} for QC")

    sc.pp.calculate_qc_metrics(
        adata,
        qc_vars=qc_genes,
        percent_top=None,
        log1p=False,
        inplace=True,
    )
    # Copy so that adding rank columns never writes into a view of adata.obs
    df_annot = adata.obs[get_viable_obs_qc(adata, args)].copy()
    # Rank cells by each QC metric (iterate over a snapshot of the columns,
    # since rank columns are added inside the loop)
    for header in list(df_annot.columns):
        if header == CELLINDEX_HEADER:
            continue
        df_annot = df_annot.sort_values(header, ascending=False)
        # Positional assignment: rows are already in descending order
        df_annot[f"cellrank_{header}"] = range(1, len(df_annot) + 1)

    # Sample QC table when more cells than args.plot_celllimit are present
    df_annot = atlas_sampling(df_annot, "QC", args)
    df_annot[CELLINDEX_HEADER] = range(1, len(df_annot) + 1)
    df_annot.to_csv(qc_path, index=False, quoting=False, sep="\t")
create_summary_table(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None

Create a table with all summarizing variables

Parameters:
  • adata (AnnData) –

    atlas to analyse

  • atlas_info (dict) –

    info dict of the atlas

  • args (Namespace) –

    list of arguments from checkatlas workflow

Source code in checkatlas/atlas.py
def create_summary_table(
    adata: AnnData, atlas_info: dict, args: argparse.Namespace
) -> None:
    """
    Create a one-row table with all summarizing variables and write it as TSV

    Args:
        adata (AnnData): atlas to analyse
        atlas_info (dict): info dict of the atlas (name, type, extension, path)
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    atlas_type = atlas_info[check.ATLAS_TYPE_KEY]
    atlas_path = atlas_info[check.ATLAS_PATH_KEY]
    logger.debug(f"Create Summary table for {atlas_name}")
    csv_path = files.get_file_path(
        atlas_name, folders.SUMMARY, check.TSV_EXTENSION, args.path
    )
    # Column name -> value; insertion order defines the TSV column order
    summary = {
        "AtlasFileType": atlas_type,
        "NbCells": adata.n_obs,
        "NbGenes": adata.n_vars,
        "AnnData.raw": adata.raw is not None,
        "AnnData.X": adata.X is not None,
        # Report the file extension of the atlas (not its name)
        "File_extension": atlas_info.get(check.ATLAS_EXTENSION_KEY, ""),
        "File_path": atlas_path,
    }
    df_summary = pd.DataFrame(index=[atlas_name], columns=list(summary))
    # .loc writes into the frame itself; chained df[col][row] = ... assignment
    # silently does nothing under pandas Copy-on-Write
    for column, value in summary.items():
        df_summary.loc[atlas_name, column] = value
    df_summary.to_csv(csv_path, index=False, sep="\t")
create_tsne_fig(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None

Display the t-SNE of the atlas. Searches for the obs variable that corresponds to the cell type annotation and uses it to color the plot.

Parameters:
  • adata (AnnData) –

    atlas to analyse

  • atlas_info (dict) –

    info on the atlas

  • args (Namespace) –

    list of arguments from checkatlas workflow

Source code in checkatlas/atlas.py
def create_tsne_fig(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None:
    """
    Display the TSNE of celltypes
    Search for the OBS variable which correspond to the celltype annotation

    The first obsm key containing "tsne" is copied to adata.obsm["X_tsne"]
    (converted to numpy when stored as a DataFrame) and plotted with
    scanpy. The plot is colored by the first viable annotation obs key when
    one exists. Nothing is done when no t-SNE reduction is found.

    Args:
        adata (AnnData): atlas to analyse
        atlas_info (dict): info on the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    sc.set_figure_params(dpi_save=150)
    # Search if tsne reduction exists (any obsm key containing "tsne")
    obsm_keys = get_viable_obsm(adata, args)
    obsm_tsne_keys = [key for key in obsm_keys if "tsne" in key]
    if not obsm_tsne_keys:
        return
    obsm_tsne = obsm_tsne_keys[0]
    logger.debug(f"Create t-SNE figure for {atlas_name} with obsm={obsm_tsne}")
    # Set the t-sne to display; sc.pl.tsne reads adata.obsm["X_tsne"]
    embedding = adata.obsm[obsm_tsne]
    if isinstance(embedding, pd.DataFrame):
        # Transform to numpy if it is a pandas dataframe
        embedding = embedding.to_numpy()
    adata.obsm["X_tsne"] = embedding
    # Setting up figures directory
    sc.settings.figdir = folders.get_workingdir(args.path)
    tsne_path = os.sep + atlas_name + check.TSNE_EXTENSION
    # Exporting tsne
    obs_keys = get_viable_obs_annot(adata, args)
    if obs_keys:
        sc.pl.tsne(adata, color=obs_keys[0], show=False, save=tsne_path)
    else:
        sc.pl.tsne(adata, show=False, save=tsne_path)
create_umap_fig(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None

Display the UMAP of the atlas. Searches for the obs variable that corresponds to the cell type annotation and uses it to color the plot.

Parameters:
  • adata (AnnData) –

    atlas to analyse

  • atlas_info (dict) –

    info on the atlas

  • args (Namespace) –

    list of arguments from checkatlas workflow

Source code in checkatlas/atlas.py
def create_umap_fig(adata: AnnData, atlas_info: dict, args: argparse.Namespace) -> None:
    """
    Display the UMAP of celltypes
    Search for the OBS variable which correspond to the celltype annotation

    The first obsm key containing "umap" is copied to adata.obsm["X_umap"]
    (converted to numpy when stored as a DataFrame) and plotted with
    scanpy. The plot is colored by the first viable annotation obs key when
    one exists. Nothing is done when no UMAP reduction is found.

    Args:
        adata (AnnData): atlas to analyse
        atlas_info (dict): info on the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    sc.set_figure_params(dpi_save=150)
    # Search if umap reduction exists (any obsm key containing "umap")
    obsm_keys = get_viable_obsm(adata, args)
    obsm_umap_keys = [key for key in obsm_keys if "umap" in key]
    if not obsm_umap_keys:
        return
    obsm_umap = obsm_umap_keys[0]
    logger.debug(f"Create UMAP figure for {atlas_name} with obsm={obsm_umap}")
    # Set the umap to display; sc.pl.umap reads adata.obsm["X_umap"]
    embedding = adata.obsm[obsm_umap]
    if isinstance(embedding, pd.DataFrame):
        # Transform to numpy if it is a pandas dataframe
        embedding = embedding.to_numpy()
    adata.obsm["X_umap"] = embedding
    # Setting up figures directory
    sc.settings.figdir = folders.get_workingdir(args.path)
    umap_path = os.sep + atlas_name + check.UMAP_EXTENSION
    # Exporting umap
    obs_keys = get_viable_obs_annot(adata, args)
    if obs_keys:
        sc.pl.umap(adata, color=obs_keys[0], show=False, save=umap_path)
    else:
        sc.pl.umap(adata, show=False, save=umap_path)
get_viable_obs_annot(adata: AnnData, args: argparse.Namespace) -> list

Search adata.obs for categorical keys whose name contains one of the args.obs_cluster values. Keys containing NaN values or with only one category are removed, and the remaining obs_keys are returned sorted alphabetically.

Parameters:
  • adata (AnnData) –

    atlas to analyse

  • args (Namespace) –

    list of arguments from checkatlas workflow

Returns:
  • list( list ) –

    obs_keys

Source code in checkatlas/atlas.py
def get_viable_obs_annot(adata: AnnData, args: argparse.Namespace) -> list:
    """
    Search adata.obs for categorical annotation keys usable as cell labels.

    An obs key is kept when:
      * its name contains at least one of the substrings in args.obs_cluster,
      * its dtype is categorical,
      * it contains no NaN value, and
      * it has more than one category once a "nan" category is removed.

    Args:
        adata (AnnData): atlas to analyse
        args (argparse.Namespace): list of arguments from checkatlas workflow

    Returns:
        list: viable obs_keys, each listed once, sorted alphabetically
    """
    # Get keys from OBS_CLUSTERS; any() collects a key once even when it
    # matches several patterns (it was previously appended once per match)
    obs_keys = [
        obs_key
        for obs_key in adata.obs_keys()
        if any(obs_key_celltype in obs_key for obs_key_celltype in args.obs_cluster)
        and isinstance(adata.obs[obs_key].dtype, pd.CategoricalDtype)
    ]
    # Remove keys with NaN values or without at least two real categories
    obs_keys_final = list()
    for obs_key in obs_keys:
        annotations = adata.obs[obs_key]
        if _object_dtype_isnan(annotations).any():
            continue
        categories_temp = annotations.cat.categories
        # remove nan if found (both real NaN and the "nan" string category)
        categories = categories_temp.dropna()
        if categories.isin(["nan"]).any():
            categories = categories.delete(categories.get_loc("nan"))
        # Add obs_key with more than one category (with Nan removed)
        if len(categories) > 1:
            logger.debug(f"Add obs_key {obs_key} with cat {categories_temp}")
            obs_keys_final.append(obs_key)
    return sorted(obs_keys_final)
get_viable_obs_qc(adata: AnnData, args: argparse.Namespace) -> list

Return the obs keys of adata that are listed in args.qc_display, in the order in which they appear in adata.obs.

Parameters:
  • adata (AnnData) –

    atlas to analyse

  • args (Namespace) –

    list of arguments from checkatlas workflow

Returns:
  • list( list ) –

    obs_keys

Source code in checkatlas/atlas.py
def get_viable_obs_qc(adata: AnnData, args: argparse.Namespace) -> list:
    """
    Select the QC columns of adata.obs requested by the workflow.

    Args:
        adata (AnnData): atlas to analyse
        args (argparse.Namespace): list of arguments from checkatlas workflow;
            args.qc_display holds the names of the QC obs keys to keep

    Returns:
        list: obs_keys present both in adata.obs and in args.qc_display,
        in adata.obs order
    """
    return [qc_key for qc_key in adata.obs_keys() if qc_key in args.qc_display]
get_viable_obsm(adata: AnnData, args: argparse.Namespace) -> list

Search viable obsm keys for dimensionality reduction metric calculation. TODO: no filter on obsm is applied for now, so all obsm keys of the atlas are returned.

Parameters:
  • adata (AnnData) –

    atlas to analyse

  • args (Namespace) –

    list of arguments from checkatlas workflow

Returns:
  • list( list ) –

    obsm_keys

Source code in checkatlas/atlas.py
def get_viable_obsm(adata: AnnData, args: argparse.Namespace) -> list:
    """
    Search viable obsm keys for dimensionality reduction metric calculation.

    TODO: filter the keys (e.g. against args.obsm_dimred); for now no filter
    is applied and every obsm key of the atlas is returned.

    Args:
        adata (AnnData): atlas to analyse
        args (argparse.Namespace): list of arguments from checkatlas workflow
            (currently unused)

    Returns:
        list: obsm_keys
    """
    obsm_keys = adata.obsm_keys()
    logger.debug(f"Add obsm {obsm_keys}")
    return obsm_keys
preprocess_atlas(atlas_info: dict, args=None) -> AnnData

Read adata, clean it, and run task-specific precomputations (column detection, kNN graphs, distance matrices, etc.) based on which metric categories are requested in args.

Precomputed artefacts are persisted under checkatlas_files/temp/<atlas>/<task>/ so that downstream child processes (including Nextflow metric steps) can skip redundant computation.

When args is None the function behaves exactly like the legacy version: read + clean only.

Returns the cleaned AnnData object.

Source code in checkatlas/atlas.py
def preprocess_atlas(atlas_info: dict, args=None) -> AnnData:
    """
    Read adata, clean it, and run task-specific precomputations
    (column detection, kNN graphs, distance matrices, etc.) based
    on which metric categories are requested in *args*.

    Precomputed artefacts are persisted under
    ``checkatlas_files/temp/<atlas>/<task>/`` so that downstream
    child processes (including Nextflow metric steps) can skip
    redundant computation.

    When *args* is ``None`` the function behaves exactly like the
    legacy version: read + clean only.

    Args:
        atlas_info (dict): info dict of the atlas (read by read_atlas;
            the name key is also used to name the temp folders).
        args (argparse.Namespace or None): arguments of checkatlas workflow.

    Returns the cleaned AnnData object.
    """
    adata = read_atlas(atlas_info)
    adata = clean_scanpy_atlas(adata, atlas_info)

    # NOTE(review): presumably False when args is None (legacy path) —
    # confirm against _should_precompute.
    if not _should_precompute(args):
        return adata

    # Which metric categories were requested (semantics of _wants_task
    # assumed from its name — confirm).
    run_cluster = _wants_task(args, "cluster")
    run_annot = _wants_task(args, "annot")
    run_dimred = _wants_task(args, "dimred")

    if not (run_cluster or run_annot or run_dimred):
        return adata

    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    # Optional here: only feeds the fingerprint
    source_path = atlas_info.get(check.ATLAS_PATH_KEY, None)

    # ── 1. Column detection (once) ──────────────────────────────────
    detector = CheckAtlasColumnDetector(adata)
    params = detector.detect_all_parameters()

    # Each detected entry is a (column, metadata) pair; keep column names
    ref_keys = [c for c, _ in params["annotation"]["reference"]]
    pred_keys = [c for c, _ in params["annotation"]["predicted"]]
    cluster_label_keys = [c for c, _ in params["clustering"]["cluster_labels"]]
    batch_keys = [c for c, _ in params.get("batch", [])]
    # Fallback: any obs column whose name contains "batch" (case-insensitive)
    if not batch_keys:
        batch_keys = [col for col in adata.obs.columns if "batch" in col.lower()]

    # Detected embeddings are used for dimred and clustering; only those
    # reporting more than 2 components are used for annotation metrics
    embedding_keys = []
    cluster_embedding_keys = []
    annotation_embedding_keys = []
    for emb, meta in params["clustering"]["embeddings"]:
        embedding_keys.append(emb)
        cluster_embedding_keys.append(emb)
        n_comp = meta.get("n_components", 0)
        if n_comp > 2:
            annotation_embedding_keys.append(emb)
    # Also add ALL .obsm keys for dimred/cluster use;
    # for annotation only include embeddings with > 2 components
    all_obsm_keys = adata.obsm_keys()
    for key in all_obsm_keys:
        if key not in embedding_keys:
            embedding_keys.append(key)
        if key not in cluster_embedding_keys:
            cluster_embedding_keys.append(key)
        if key not in annotation_embedding_keys:
            # Entries whose shape[1] cannot be read (any exception) are
            # silently left out of the annotation embeddings
            try:
                if adata.obsm[key].shape[1] > 2:
                    annotation_embedding_keys.append(key)
            except Exception:
                pass

    # k_max only feeds the fingerprint; each task below passes its own
    # hard-coded k_neighbors value
    k_max = 90  # covers LISI (90) and kBET (25 via subset)
    fingerprint = make_preprocess_fingerprint(
        adata,
        embedding_keys=embedding_keys,
        cluster_label_keys=cluster_label_keys,
        batch_keys=batch_keys,
        k_neighbors=k_max,
        source_path=source_path,
        annotation_embedding_keys=annotation_embedding_keys,
    )

    # ── 2. Early exit: cached context still valid ──────────────────
    # The loaded context is only used as a validity check; adata is
    # returned as-is in that case.
    temp_parent = folders.get_folder(args.path, folders.TEMP)
    existing = load_context(atlas_name, temp_parent, fingerprint)
    if existing is not None:
        logger.info("Precompute context already valid — skipping recomputation")
        return adata

    # ── 3. Build fresh context ─────────────────────────────────────
    # One sub-folder per task under <temp>/<atlas_name>/
    ctx = PreprocessContext(
        atlas_name=atlas_name,
        fingerprint=fingerprint,
        ref_keys=ref_keys,
        pred_keys=pred_keys,
        embedding_keys=embedding_keys,
        annotation_embedding_keys=annotation_embedding_keys,
        cluster_embedding_keys=cluster_embedding_keys,
        cluster_label_keys=cluster_label_keys,
        batch_keys=batch_keys,
        temp_parent_dir=temp_parent,
        dimred_dir=os.path.join(temp_parent, atlas_name, folders.DIMRED),
        annotation_dir=os.path.join(temp_parent, atlas_name, folders.ANNOTATION),
        cluster_dir=os.path.join(temp_parent, atlas_name, folders.CLUSTER),
    )
    for d in (ctx.dimred_dir, ctx.annotation_dir, ctx.cluster_dir):
        os.makedirs(d, exist_ok=True)

    # ── 4. Task-specific precomputation ────────────────────────────
    n_jobs = _resolve_n_jobs(args)

    # NOTE(review): k_neighbors values are hard-coded per task (30/90/30)
    if run_dimred:
        _precompute_dimred(adata, ctx, k_neighbors=30, n_jobs=n_jobs)

    if run_annot:
        _precompute_annot(adata, ctx, k_neighbors=90, n_jobs=n_jobs)

    if run_cluster:
        _precompute_cluster(adata, ctx, k_neighbors=30, n_jobs=n_jobs)

    # ── 5. Persist context ─────────────────────────────────────────
    save_context(ctx)

    return adata
read_atlas(atlas_info: dict) -> AnnData

Read Scanpy or Cellranger data : .h5ad or .h5

Parameters:
  • atlas_info (dict) –

    info dict about the atlas

Returns:
  • AnnData( AnnData ) –

    scanpy object from .h5ad

Source code in checkatlas/atlas.py
def read_atlas(atlas_info: dict) -> AnnData:
    """
    Load an atlas as AnnData, picking the reader from its declared type.

    Cellranger >= v3 and Cellranger < v3 atlases go through the cellranger
    module readers; every other type is read as a Scanpy .h5ad file.

    Args:
        atlas_info (dict): info dict about the atlas

    Returns:
        AnnData: loaded atlas. When an AnnDataReadError is raised, a warning
        is logged and an empty dict is returned instead.
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    atlas_path = atlas_info[check.ATLAS_PATH_KEY]
    logger.info(f"Load {atlas_name} in {atlas_path}")
    try:
        atlas_type = atlas_info[check.ATLAS_TYPE_KEY]
        if atlas_type == cellranger.CELLRANGER_TYPE_CURRENT:
            logger.debug(f"Read Cellranger >= v3 results {atlas_info[check.ATLAS_PATH_KEY]}")
            return cellranger.read_cellranger_current(atlas_info)
        if atlas_type == cellranger.CELLRANGER_TYPE_OBSOLETE:
            logger.debug(f"Read Cellranger < v3 results {atlas_info[check.ATLAS_PATH_KEY]}")
            return cellranger.read_cellranger_obsolete(atlas_info)
        logger.debug(f"Read Scanpy file {atlas_info[check.ATLAS_PATH_KEY]}")
        return sc.read_h5ad(atlas_info[check.ATLAS_PATH_KEY])
    except _io.utils.AnnDataReadError:
        logger.warning(f"AnnDataReadError, cannot read: {atlas_info[check.ATLAS_PATH_KEY]}")
        return dict()

checkatlas.seurat

checkatlas.seurat

check_seurat_install() -> None

Check if Seurat is installed, run installation if not

Source code in checkatlas/seurat.py
def check_seurat_install() -> None:
    """
    Install the Seurat and SeuratObject R packages when they are missing.

    Missing packages are installed from the first CRAN mirror into the
    user library (R_LIBS_USER), which is created and added to .libPaths().
    """
    r_utils = rpackages.importr("utils")
    # First mirror of the list, so no interactive mirror prompt is shown
    r_utils.chooseCRANmirror(ind=1)
    missing = [
        package
        for package in ("Seurat", "SeuratObject")
        if not rpackages.isinstalled(package)
    ]
    if not missing:
        return
    # Create the personal library, then put it on the library path
    robjects.r("""dir.create(Sys.getenv("R_LIBS_USER"), recursive = TRUE)""")
    set_libpaths = """.libPaths(Sys.getenv("R_LIBS_USER"))"""
    robjects.r(set_libpaths)
    logger.debug(f"Set Rlibpaths: {robjects.r(set_libpaths)}")
    r_utils.install_packages(StrVector(missing))
create_anndata_table(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None

Create a table listing the AnnData-like fields (obs, obsm, var, varm, uns) of a Seurat object.

Source code in checkatlas/seurat.py
def create_anndata_table(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Create a table with all AnnData-like fields of a Seurat object

    Writes one row per atlas with HTML-formatted lists of:
      * atlas_obs: colnames(seurat@meta.data)
      * obsm: names(seurat@reductions)
      * var, varm: always empty for Seurat objects
      * uns: names(seurat@misc)

    Args:
        seurat (RS4): Seurat object to analyse
        atlas_info (dict): info dict of the atlas
        args (argparse.Namespace): list of arguments from checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    logger.debug(f"Create Adata table for {atlas_name}")
    csv_path = os.path.join(
        folders.get_folder(args.path, folders.ANNDATA),
        atlas_name + check.TSV_EXTENSION,
    )
    # Create AnnData table
    header = ["atlas_obs", "obsm", "var", "varm", "uns"]
    df_summary = pd.DataFrame(index=[atlas_name], columns=header)

    # Create r_functions
    r_obs = robjects.r("obs <- function(seurat){ return(colnames(seurat@meta.data))}")
    r_obsm = robjects.r("f<-function(seurat){return(names(seurat@reductions))}")
    # seurat@misc is a list: names() gives its entries (colnames() of a
    # list has no dim and is always NULL)
    r_uns = robjects.r("uns <- function(seurat){ return(names(seurat@misc))}")

    def _as_list(r_value) -> list:
        """Convert an R character vector to a list; NULL becomes [""]."""
        if isinstance(r_value, NULLType):
            return [""]
        return list(r_value)

    def _to_html(values) -> str:
        """Render values as <code> items separated by <br>."""
        return "<code>" + "</code><br><code>".join(values) + "</code>"

    fields = {
        "atlas_obs": _as_list(r_obs(seurat)),
        "obsm": _as_list(r_obsm(seurat)),
        "var": [""],
        "varm": [""],
        "uns": _as_list(r_uns(seurat)),
    }
    # .loc writes into the frame; chained df[col][row] = ... assignment is
    # a no-op under pandas Copy-on-Write
    for column, values in fields.items():
        df_summary.loc[atlas_name, column] = _to_html(values)
    df_summary.to_csv(csv_path, index=False, quoting=False, sep="\t")
create_metric_annot(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None

Calculate annotation metrics of a Seurat object by comparing each viable obs key to the first one, which is used as the reference.

Source code in checkatlas/seurat.py
def create_metric_annot(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Compute annotation metrics for a Seurat atlas and write them as TSV.

    The first viable obs key is the reference annotation. Every other viable
    obs key is compared against it with each metric of args.metric_annot
    (metrics.METRICS_ANNOT when it is None). Nothing is written when fewer
    than two viable obs keys are found.

    Args:
        seurat (RS4): Seurat object to analyse
        atlas_info (dict): info dict of the atlas
        args (argparse.Namespace): arguments of checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    annot_folder = folders.get_folder(args.path, folders.ANNOTATION)
    csv_path = os.path.join(annot_folder, atlas_name + check.TSV_EXTENSION)
    requested = args.metric_annot
    metric_annot = metrics.METRICS_ANNOT if requested is None else requested
    df_annot = pd.DataFrame(columns=["Annot_Sample", "Reference", "obs"] + metric_annot)
    obs_keys = get_viable_obs_annot(seurat, args)
    # At least a reference and one annotation to compare are required
    if len(obs_keys) <= 1:
        logger.debug(f"No viable obs_key was found for {atlas_name}")
        return
    logger.debug(f"Calc annotation metrics for {atlas_name}")
    ref_obs = obs_keys[0]
    for obs_key in obs_keys[1:]:
        row = {
            "Annot_Sample": [atlas_name + "_" + obs_key],
            "Reference": [ref_obs],
            "obs": [obs_key],
        }
        for metric in metric_annot:
            logger.debug(
                f"Calc {metric} for {atlas_name} "
                f"with obs {obs_key} vs ref_obs {ref_obs}"
            )
            row[metric] = metrics.calc_metric_annot_seurat(
                metric, seurat, obs_key, ref_obs
            )
        df_annot = pd.concat([df_annot, pd.DataFrame(row)], ignore_index=True, axis=0)
    df_annot.to_csv(csv_path, index=False, sep="\t")
create_metric_cluster(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None

Calculate clustering metrics of a Seurat object for every viable obs key, using the umap reduction.

Source code in checkatlas/seurat.py
def create_metric_cluster(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Compute clustering metrics for a Seurat atlas and write them as TSV.

    Every viable obs key is scored with each metric of args.metric_cluster
    on the "umap" reduction. When a metric returns a (value, running_time)
    tuple, the running time goes to an extra "<metric>_running_time" column.

    Args:
        seurat (RS4): Seurat object to analyse
        atlas_info (dict): info dict of the atlas
        args (argparse.Namespace): arguments of checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    cluster_folder = folders.get_folder(args.path, folders.CLUSTER)
    csv_path = os.path.join(cluster_folder, atlas_name + check.TSV_EXTENSION)
    df_cluster = pd.DataFrame(columns=["Clust_Sample", "obs"] + args.metric_cluster)
    obs_keys = get_viable_obs_annot(seurat, args)
    reduction = "umap"
    if len(obs_keys) == 0:
        logger.debug(f"No viable obs_key was found for {atlas_name}")
        return
    logger.debug(f"Calc clustering metrics for {atlas_name}")
    for obs_key in obs_keys:
        row = {"Clust_Sample": [atlas_name + "_" + obs_key], "obs": [obs_key]}
        for metric in args.metric_cluster:
            logger.debug(
                f"Calc {metric} for {atlas_name} "
                f"with obs {obs_key} and obsm {reduction}"
            )
            result = metrics.calc_metric_cluster_seurat(
                metric, seurat, obs_key, reduction
            )
            if not isinstance(result, tuple):
                row[metric] = result
                continue
            row[metric], row[f"{metric}_running_time"] = result
        df_cluster = pd.concat(
            [df_cluster, pd.DataFrame(row)], ignore_index=True, axis=0
        )
    df_cluster.to_csv(csv_path, index=False, sep="\t")
create_metric_dimred(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None

Calculate dimensionality reduction metrics of a Seurat object. These metrics are not yet available for Seurat: one row per reduction is written and a warning is logged.

Source code in checkatlas/seurat.py
def create_metric_dimred(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Write the dimensionality reduction metrics table of a Seurat atlas.

    Dim reduction metrics are not implemented for Seurat yet. A row is
    written for every viable reduction with empty metric columns, and a
    warning is logged for each requested metric.

    Args:
        seurat (RS4): Seurat object to analyse
        atlas_info (dict): info dict of the atlas
        args (argparse.Namespace): arguments of checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    dimred_folder = folders.get_folder(args.path, folders.DIMRED)
    csv_path = os.path.join(dimred_folder, atlas_name + check.TSV_EXTENSION)
    df_dimred = pd.DataFrame(columns=["Dimred_Sample", "obsm"] + args.metric_dimred)
    obsm_keys = get_viable_obsm(seurat, args)
    if len(obsm_keys) == 0:
        logger.debug(f"No viable obsm_key was found for {atlas_name}")
        return
    logger.debug(f"Calc dim red metrics for {atlas_name}")
    for obsm_key in obsm_keys:
        row = {"Dimred_Sample": [atlas_name + "_" + obsm_key], "obsm": [obsm_key]}
        for metric in args.metric_dimred:
            logger.debug(f"Calc {metric} for {atlas_name} with obsm {obsm_key}")
            # TODO: compare the count matrix with Embeddings(seurat, obsm_key)
            # through metrics.calc_metric_dimred
            logger.warning(
                "!!! Dim reduction metrics not available for Seurat"
                " at the moment !!!"
            )
        df_dimred = pd.concat([df_dimred, pd.DataFrame(row)], ignore_index=True, axis=0)
    df_dimred.to_csv(csv_path, index=False, sep="\t")
create_qc_plots(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None

Display the atlas QC as a violin plot of the metadata variables corresponding to total_RNA, total_UMI, MT_ratio and RT_ratio.

Source code in checkatlas/seurat.py
def create_qc_plots(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Save a violin plot of the QC metadata of a Seurat atlas.

    One violin per key of SEURAT_TO_SCANPY_OBS is drawn with Seurat's
    VlnPlot and saved with ggsave (width 10, height 4, dpi 150) in the QC
    figure folder.

    Args:
        seurat (RS4): Seurat object to analyse
        atlas_info (dict): info dict of the atlas
        args (argparse.Namespace): arguments of checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    qc_folder = folders.get_folder(args.path, folders.QC_FIG)
    qc_path = os.path.join(qc_folder, atlas_name + check.QC_FIG_EXTENSION)
    logger.debug(f"Create QC violin plot for {atlas_name}")
    importr("ggplot2")
    r_violin = robjects.r(
        "vln_plot <- function(seurat, obs, qc_path){"
        "vln <- VlnPlot(seurat, features = obs, ncol = length(obs));"
        "ggsave(qc_path, vln, width = 10, "
        "height = 4, dpi = 150)}"
    )
    qc_features = robjects.StrVector(list(SEURAT_TO_SCANPY_OBS.keys()))
    r_violin(seurat, qc_features, qc_path)
create_qc_tables(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None

Create the QC table of a Seurat atlas from the metadata variables corresponding to total_RNA, total_UMI, MT_ratio and RT_ratio.

Source code in checkatlas/seurat.py
def create_qc_tables(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Write the QC table of a Seurat atlas as TSV.

    The viable QC columns of seurat@meta.data are converted to pandas and
    renamed with their scanpy names (SEURAT_TO_SCANPY_OBS). A
    "cellrank_<metric>" column is added per metric (1 = highest value).
    The table is then sampled with atlas.atlas_sampling and a cell index is
    added.

    Args:
        seurat (RS4): Seurat object to analyse
        atlas_info (dict): info dict of the atlas
        args (argparse.Namespace): arguments of checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    qc_folder = folders.get_folder(args.path, folders.QC)
    qc_path = os.path.join(qc_folder, atlas_name + check.TSV_EXTENSION)
    logger.debug(f"Create QC tables for {atlas_name}")
    obs_keys = get_viable_obs_qc(seurat, args)
    r_meta = robjects.r("obs <- function(seurat){ return(seurat@meta.data)}")
    r_metadata = r_meta(seurat)
    with (ro.default_converter + pandas2ri.converter).context():
        df_metadata = ro.conversion.get_conversion().rpy2py(r_metadata)
        df_annot = df_metadata[obs_keys]
        # rename columns with scanpy names
        df_annot.columns = [SEURAT_TO_SCANPY_OBS[name] for name in df_annot.columns]

        # Rank cell by qc metric (only the metric columns present before ranking)
        metric_headers = [
            name for name in df_annot.columns if name != atlas.CELLINDEX_HEADER
        ]
        for metric_header in metric_headers:
            df_annot = df_annot.sort_values(metric_header, ascending=False)
            df_annot.loc[:, [f"cellrank_{metric_header}"]] = range(
                1, len(df_annot) + 1
            )

        # Sample QC table when more cells than args.plot_celllimit are present
        df_annot = atlas.atlas_sampling(df_annot, "QC", args)
        df_annot.loc[:, [atlas.CELLINDEX_HEADER]] = range(1, len(df_annot) + 1)
        df_annot.to_csv(qc_path, index=False, quoting=False, sep="\t")
create_summary_table(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None

Create a summary table of the Seurat atlas: file type, number of cells and genes, file extension and file path.

Source code in checkatlas/seurat.py
def create_summary_table(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Create a one-row summary table of a Seurat atlas and write it as TSV

    NbCells/NbGenes come from ncol/nrow of the Seurat object. AnnData.raw
    and AnnData.X are fixed values (False/True) for Seurat atlases.

    Args:
        seurat (RS4): Seurat object to analyse
        atlas_info (dict): info dict of the atlas (name, type, extension, path)
        args (argparse.Namespace): arguments of checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    logger.debug(f"Create Summary table for {atlas_name}")
    csv_path = os.path.join(
        folders.get_folder(args.path, folders.SUMMARY),
        atlas_name + check.TSV_EXTENSION,
    )
    r_nrow = robjects.r["nrow"]
    r_ncol = robjects.r["ncol"]
    ncells = r_ncol(seurat)[0]
    ngenes = r_nrow(seurat)[0]
    # NOTE(review): hard-coded, not read from the Seurat object
    x_raw = False
    x_norm = True
    # Column name -> value; insertion order defines the TSV column order
    summary = {
        "AtlasFileType": atlas_info[check.ATLAS_TYPE_KEY],
        "NbCells": ncells,
        "NbGenes": ngenes,
        "AnnData.raw": x_raw,
        "AnnData.X": x_norm,
        "File_extension": atlas_info[check.ATLAS_EXTENSION_KEY],
        "File_path": atlas_info[check.ATLAS_PATH_KEY],
    }
    df_summary = pd.DataFrame(index=[atlas_name], columns=list(summary))
    # .loc writes into the frame; chained df[col][row] = ... assignment is
    # a no-op under pandas Copy-on-Write
    for column, value in summary.items():
        df_summary.loc[atlas_name, column] = value
    df_summary.to_csv(csv_path, index=False, sep="\t")
create_tsne_fig(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None

Display the t-SNE of the Seurat atlas, grouped by the obs variable that corresponds to the cell type annotation.

Source code in checkatlas/seurat.py
def create_tsne_fig(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Display the TSNE of celltypes
    Search for the OBS variable which correspond to the celltype annotation

    The plot is only produced when one of names(seurat) contains "tsne".
    The first key returned by get_viable_obs_annot is used as group.by.
    When no viable key exists, group.by is NULL and DimPlot uses its default
    grouping; previously this raised an IndexError.

    Args:
        seurat (RS4): Seurat object to analyse
        atlas_info (dict): info dict of the atlas
        args (argparse.Namespace): arguments of checkatlas workflow
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    # Search if tsne reduction exists (any name containing "tsne")
    r_names = robjects.r["names"]
    obsm_list = r_names(seurat)
    importr("ggplot2")
    if any("tsne" in name for name in obsm_list):
        logger.debug(f"Create t-SNE figure for {atlas_name}")
        # Setting up figures directory
        tsne_path = os.path.join(
            folders.get_folder(args.path, folders.TSNE),
            atlas_name + check.TSNE_EXTENSION,
        )
        # Exporting tsne
        obs_keys = get_viable_obs_annot(seurat, args)
        group_by = obs_keys[0] if len(obs_keys) > 0 else robjects.NULL
        r_cmd = (
            "tsne <- function(seurat, obs_key, tsne_path){"
            "tsne_plot <- DimPlot(seurat, group.by = obs_key, "
            'reduction = "tsne");'
            "ggsave(tsne_path, tsne_plot, width = 10, "
            "height = 6, dpi = 76)}"
        )
        r_tsne = robjects.r(r_cmd)
        r_tsne(seurat, group_by, tsne_path)
create_umap_fig(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None

Display the UMAP of the Seurat atlas, grouped by the obs variable that corresponds to the cell type annotation.

Source code in checkatlas/seurat.py
def create_umap_fig(seurat: RS4, atlas_info: dict, args=argparse.Namespace) -> None:
    """
    Save a UMAP plot of the atlas grouped by its cell-type annotation.

    The plot is only produced when one of ``names(seurat)`` contains
    "umap". Cells are grouped by the first metadata column returned by
    :func:`get_viable_obs_annot`; when there is none, the figure is skipped
    with a warning instead of failing with an ``IndexError``.

    Args:
        seurat (RS4): Seurat object loaded through rpy2.
        atlas_info (dict): info dict about the atlas (uses its name).
        args (argparse.Namespace): arguments of the checkatlas workflow
            (uses ``args.path``). NOTE(review): the default value is the
            ``argparse.Namespace`` class, not an instance, so ``args``
            must always be passed explicitly.

    Returns:
        None
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    # Search if a umap reduction exists among the Seurat object names
    umap_pattern = re.compile(r".*umap.*")
    r_names = robjects.r["names"]
    obsm_list = r_names(seurat)
    importr("ggplot2")
    if not any(umap_pattern.match(name) for name in obsm_list):
        return
    logger.debug(f"Create UMAP figure for {atlas_name}")
    # Setting up figures directory
    umap_path = os.path.join(
        folders.get_folder(args.path, folders.UMAP),
        atlas_name + check.UMAP_EXTENSION,
    )
    # Exporting umap, grouped by the first viable annotation column
    obs_keys = get_viable_obs_annot(seurat, args)
    if not obs_keys:
        logger.warning(
            f"No viable annotation found for {atlas_name}, skip UMAP figure"
        )
        return
    r_cmd = (
        "umap <- function(seurat, obs_key, umap_path){"
        "umap_plot <- DimPlot(seurat, group.by = obs_key, "
        'reduction = "umap");'
        "ggsave(umap_path, umap_plot, width = 10, "
        "height = 6, dpi = 76)}"
    )
    r_umap = robjects.r(r_cmd)
    r_umap(seurat, obs_keys[0], umap_path)
get_viable_obs_annot(seurat: RS4, args: argparse.Namespace) -> list

Search the Seurat metadata (obs) keys for a match to the OBS_CLUSTERS values given in `args.obs_cluster`, keeping only factor columns. Keys with only one category are removed. The resulting keys are returned sorted alphabetically.

Parameters:
  • seurat (RS4) –

    Seurat object loaded through rpy2.

  • args (Namespace) –

    Arguments of the checkatlas workflow (uses `args.obs_cluster`).

Returns:
  • list( list ) –

    Names of the viable annotation columns, sorted alphabetically.

Source code in checkatlas/seurat.py
def get_viable_obs_annot(seurat: RS4, args: argparse.Namespace) -> list:
    """
    Search the Seurat metadata for columns usable as cell annotations.

    A metadata column is kept when its name contains one of the patterns in
    ``args.obs_cluster`` (OBS_CLUSTERS values), its values are an R factor,
    and that factor has more than one level. Each column is reported once,
    even when its name matches several patterns.

    Args:
        seurat (RS4): Seurat object loaded through rpy2.
        args (argparse.Namespace): arguments of the checkatlas workflow
            (uses ``args.obs_cluster``).

    Returns:
        list: names of the viable metadata columns, sorted alphabetically.
    """
    r_obs = robjects.r("obs <- function(seurat){ return(colnames(seurat@meta.data))}")
    r_annot = robjects.r(
        "type <- function(seurat, obs_key){ " "return(seurat[[obs_key]][[obs_key]])}"
    )
    obs_keys_final = list()
    for obs_key in r_obs(seurat):
        # Keep only columns whose name matches at least one requested pattern
        if not any(pattern in obs_key for pattern in args.obs_cluster):
            continue
        annotations = r_annot(seurat, obs_key)
        # Only categorical (factor) columns can serve as annotations
        if not isinstance(annotations, FactorVector):
            continue
        # Remove keys with only one category
        if len(annotations.levels) != 1:
            logger.debug(f"Add obs_key {obs_key} with cat {annotations.levels}")
            obs_keys_final.append(obs_key)
    return sorted(obs_keys_final)
get_viable_obs_qc(seurat: RS4, args: argparse.Namespace) -> list

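Map each QC name in `args.qc_display` to its Seurat metadata column (through SCANPY_TO_SEURAT_OBS) and keep those present in the Seurat metadata. Keys are returned in the order of `args.qc_display`.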

Parameters:
  • seurat (RS4) –

    Seurat object loaded through rpy2.

  • args (Namespace) –

    Arguments of the checkatlas workflow (uses `args.qc_display`).

Returns:
  • list( list ) –

    Names of the QC columns found in the Seurat metadata.

Source code in checkatlas/seurat.py
def get_viable_obs_qc(seurat: RS4, args: argparse.Namespace) -> list:
    """
    Find the Seurat metadata columns matching the requested QC metrics.

    Each name in ``args.qc_display`` (scanpy naming) is translated with
    ``SCANPY_TO_SEURAT_OBS`` and kept when that column exists in the Seurat
    metadata. Names without a Seurat equivalent are skipped with a warning
    instead of raising ``KeyError``.

    Args:
        seurat (RS4): Seurat object loaded through rpy2.
        args (argparse.Namespace): arguments of the checkatlas workflow
            (uses ``args.qc_display``).

    Returns:
        list: Seurat metadata column names, in the order of
        ``args.qc_display``.
    """
    r_obs = robjects.r("obs <- function(seurat){ return(colnames(seurat@meta.data))}")
    # Fetch the metadata column names once instead of once per QC name
    seurat_obs = set(r_obs(seurat))
    obs_keys = list()
    for scanpy_qc in args.qc_display:
        seurat_qc = SCANPY_TO_SEURAT_OBS.get(scanpy_qc)
        if seurat_qc is None:
            logger.warning(f"No Seurat equivalent for QC {scanpy_qc}, skipped")
            continue
        if seurat_qc in seurat_obs:
            obs_keys.append(seurat_qc)
    return obs_keys
get_viable_obsm(seurat, args)

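Search viable obsm (Seurat reductions) for dimensionality-reduction metric calculation. No filter on the obsm is applied for now: all names in `seurat@reductions` are returned. Parameters: seurat (Seurat object), args (arguments of the checkatlas workflow). Returns: list of reduction names.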

Source code in checkatlas/seurat.py
def get_viable_obsm(seurat, args):
    """
    Return the dimensionality reductions available in a Seurat object.

    All names in ``seurat@reductions`` are returned; no filter (for instance
    on ``args.obsm_dimred``) is applied for now.

    :param seurat: Seurat object loaded through rpy2.
    :param args: arguments of the checkatlas workflow (currently unused).
    :return: list of reduction names.
    """
    r_obsm = robjects.r("f<-function(seurat){return(names(seurat@reductions))}")
    obsm_keys = list(r_obsm(seurat))
    logger.debug(f"Add obsm {obsm_keys}")
    return obsm_keys
read_atlas(atlas_info: dict) -> RS4

Read Seurat object in python using rpy2

Parameters:
  • atlas_info (dict) –

    info dict about the atlas

Returns:
  • RS4( RS4 ) –

    Seurat object, or None if the RDS file does not contain a Seurat object.

Source code in checkatlas/seurat.py
def read_atlas(atlas_info: dict) -> RS4:
    """Read a Seurat object saved as RDS in python using rpy2.

    Args:
        atlas_info (dict): info dict about the atlas (uses its name and path)

    Returns:
        RS4: the Seurat object, or None when the RDS file does not hold an
        object of class "Seurat".
    """
    atlas_name = atlas_info[check.ATLAS_NAME_KEY]
    atlas_path = atlas_info[check.ATLAS_PATH_KEY]
    logger.info(f"Load {atlas_name} in " f"{atlas_path}")
    # Pass the path as an R argument instead of splicing it into R source:
    # quotes or backslashes (e.g. Windows paths) in the path would otherwise
    # break the evaluated R code.
    seurat = robjects.r["readRDS"](str(atlas_path))
    rclass = robjects.r["class"]
    if rclass(seurat)[0] == "Seurat":
        importr("Seurat")
        return seurat
    logger.info(f"{atlas_name} is not a Seurat object")
    return None

checkatlas.cellranger

checkatlas.cellranger

read_cellranger_current(atlas_info: dict) -> AnnData

Read cellranger files.

Load first the 10x HDF5 matrix at the atlas path (e.g. /outs/filtered_feature_bc_matrix.h5), then add (if found): clustering, PCA, UMAP and t-SNE. Args: atlas_info (dict): info on the atlas.

Returns:
  • AnnData( AnnData ) –

    scanpy object from cellranger

Source code in checkatlas/cellranger.py
def read_cellranger_current(atlas_info: dict) -> AnnData:
    """
    Read Cell Ranger (>= v3) output files.

    Load first the 10x HDF5 matrix at the atlas path
    (e.g. /outs/filtered_feature_bc_matrix.h5)
    Then add (if found in the sibling ``analysis`` folder):
    - Clustering (graphclust, and k-means with k=10, or k=5 under ``gex``
      for multiome atlases)
    - PCA
    - UMAP
    - TSNE
    Folders whose path ends in ``atac`` are ignored so that only
    gene-expression results are attached; the first match found is used.

    Args:
        atlas_info (dict): info on the atlas

    Returns:
        AnnData: scanpy object from cellranger
    """

    def _walk_gex(top):
        """Yield os.walk(top) entries whose folder does not end in 'atac'."""
        for root, dirs, files in os.walk(top):
            if not root.endswith("atac"):
                yield root, dirs, files

    def _find_graphclust(top):
        """Return the first '*graphclust/clusters.csv' under top, or ''."""
        for root, dirs, _files in _walk_gex(top):
            for dirname in dirs:
                cluster_path = os.path.join(root, dirname, "clusters.csv")
                if dirname.endswith("graphclust") and os.path.exists(cluster_path):
                    return cluster_path
        return ""

    def _find_kmeans(top):
        """Return (clusters.csv path, k) for kmeans_10 or gex/kmeans_5, else ("", 0)."""
        for root, dirs, _files in _walk_gex(top):
            for dirname in dirs:
                cluster_path = os.path.join(root, dirname, "clusters.csv")
                if not os.path.exists(cluster_path):
                    continue
                # Search the highest kmeans = 10
                if "kmeans_10" in dirname:
                    return cluster_path, 10
                # Or search the highest kmeans = 5 (for multiome atlas)
                if os.path.join("gex", "kmeans_5") in os.path.join(root, dirname):
                    return cluster_path, 5
        return "", 0

    def _find_projection(top):
        """Return the first '*projection.csv' file under top, or ''."""
        for root, _dirs, files in _walk_gex(top):
            for filename in files:
                if filename.endswith("projection.csv"):
                    return os.path.join(root, filename)
        return ""

    cellranger_out_path = os.path.dirname(atlas_info[check.ATLAS_PATH_KEY])
    cellranger_analysis_path = os.path.join(cellranger_out_path, "analysis")
    cellranger_clust_path = os.path.join(cellranger_analysis_path, "clustering")

    # Search clusterings and reductions (first non-ATAC match wins)
    graphclust_path = _find_graphclust(cellranger_clust_path)
    kmeans_path, k_value = _find_kmeans(cellranger_clust_path)
    rna_umap = _find_projection(os.path.join(cellranger_analysis_path, "umap"))
    rna_tsne = _find_projection(os.path.join(cellranger_analysis_path, "tsne"))
    rna_pca = _find_projection(os.path.join(cellranger_analysis_path, "pca"))

    # Manage multiome cellranger files
    dim_red_path = os.path.join(cellranger_analysis_path, "dimensionality_reduction")
    if os.path.exists(dim_red_path):
        gex_path = os.path.join(dim_red_path, "gex")
        if os.path.exists(gex_path):
            rna_umap = os.path.join(gex_path, "umap_projection.csv")
            rna_tsne = os.path.join(gex_path, "tsne_projection.csv")
            rna_pca = os.path.join(gex_path, "pca_projection.csv")

    # Read 10x h5 file
    adata = sc.read_10x_h5(atlas_info[check.ATLAS_PATH_KEY])
    adata.var_names_make_unique()

    # Add cluster (pandas aligns the "Cluster" column on the barcode index)
    if os.path.exists(graphclust_path):
        df_cluster = pd.read_csv(graphclust_path, index_col=0)
        adata.obs["cellranger_graphclust"] = df_cluster["Cluster"]
    if os.path.exists(kmeans_path):
        df_cluster = pd.read_csv(kmeans_path, index_col=0)
        adata.obs["cellranger_kmeans_" + str(k_value)] = df_cluster["Cluster"]

    # Add reduction
    if os.path.exists(rna_umap):
        df_umap = pd.read_csv(rna_umap, index_col=0)
        adata.obsm["X_umap"] = df_umap
    if os.path.exists(rna_tsne):
        df_tsne = pd.read_csv(rna_tsne, index_col=0)
        adata.obsm["X_tsne"] = df_tsne
    if os.path.exists(rna_pca):
        df_pca = pd.read_csv(rna_pca, index_col=0)
        adata.obsm["X_pca"] = df_pca

    return adata
read_cellranger_obsolete(atlas_info: dict) -> AnnData

Read cellranger files.

Load first the 10x Matrix Market folder containing the atlas path (e.g. /outs/filtered_matrices_mex/hg19/matrix.mtx), then add (if found): clustering, PCA, UMAP and t-SNE. Args: atlas_info (dict): info on the atlas.

Returns:
  • AnnData( AnnData ) –

    scanpy object from cellranger

Source code in checkatlas/cellranger.py
def read_cellranger_obsolete(atlas_info: dict) -> AnnData:
    """
    Read Cell Ranger (< v3) output files.

    Load first the Matrix Market folder containing the atlas path
    (e.g. /outs/filtered_matrices_mex/<genome>/matrix.mtx)
    Then add (if found):
    - Clustering (graphclust, and the highest k-means from 15 down to 3)
    - PCA
    - UMAP
    - TSNE
    Projections whose number of rows differs from the number of cells are
    skipped.

    Args:
        atlas_info (dict): info on the atlas

    Returns:
        AnnData: scanpy object from cellranger
    """
    cellranger_path = atlas_info[check.ATLAS_PATH_KEY].replace(CELLRANGER_MATRIX_FILE, "")
    # Two levels above the matrix folder; presumably the cellranger "outs"
    # folder (outs/<matrices>/<genome>/) — TODO confirm for all layouts
    cellranger_out_path = os.path.join(cellranger_path, os.pardir, os.pardir)
    cellranger_analysis_path = os.path.join(cellranger_out_path, "analysis_csv")

    cellranger_umap_path = os.path.join(cellranger_analysis_path, "umap")
    cellranger_tsne_path = os.path.join(cellranger_analysis_path, "tsne")
    cellranger_pca_path = os.path.join(cellranger_analysis_path, "pca")

    # Search graphclust
    graphclust_path = ""
    for root, dirs, _files in os.walk(cellranger_out_path):
        for dirname in dirs:
            if dirname.endswith("graphclust"):
                cluster_path = os.path.join(root, dirname, "clusters.csv")
                if os.path.exists(cluster_path):
                    graphclust_path = cluster_path
                    break
    # Search kmeans
    kmeans_path = ""
    k_value = 0
    for root, dirs, _files in os.walk(cellranger_out_path):
        for dirname in dirs:
            if dirname.endswith("kmeans"):
                # Search the highest kmeans from 15 to 3
                for k in reversed(range(3, 16)):
                    cluster_path = os.path.join(
                        root, dirname, str(k) + "_clusters", "clusters.csv"
                    )
                    if os.path.exists(cluster_path):
                        kmeans_path = cluster_path
                        k_value = k
                        break

    rna_umap = os.path.join(cellranger_umap_path, "projection.csv")
    rna_tsne = os.path.join(cellranger_tsne_path, "projection.csv")
    rna_pca = os.path.join(cellranger_pca_path, "projection.csv")

    # get matrix folder
    matrix_folder = os.path.dirname(atlas_info[check.ATLAS_PATH_KEY])
    adata = sc.read_10x_mtx(matrix_folder)
    adata.var_names_make_unique()

    # Add cluster (pandas aligns the "Cluster" column on the barcode index)
    if os.path.exists(graphclust_path):
        df_cluster = pd.read_csv(graphclust_path, index_col=0)
        adata.obs["cellranger_graphclust"] = df_cluster["Cluster"]
    if os.path.exists(kmeans_path):
        df_cluster = pd.read_csv(kmeans_path, index_col=0)
        adata.obs["cellranger_kmeans_" + str(k_value)] = df_cluster["Cluster"]

    # Add reduction; projections that do not cover every cell are skipped
    if os.path.exists(rna_umap):
        df_umap = pd.read_csv(rna_umap, index_col=0)
        if len(df_umap) == len(adata):
            adata.obsm["X_umap"] = df_umap
    if os.path.exists(rna_tsne):
        df_tsne = pd.read_csv(rna_tsne, index_col=0)
        if len(df_tsne) == len(adata):
            adata.obsm["X_tsne"] = df_tsne
    if os.path.exists(rna_pca):
        df_pca = pd.read_csv(rna_pca, index_col=0)
        if len(df_pca) == len(adata):
            adata.obsm["X_pca"] = df_pca
    return adata

checkatlas.metrics.metrics

checkatlas.metrics.metrics

annotation_to_num(annotation, ref_annotation)

Transforms the annotations from categorical to numerical

Parameters

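annotation: labels to encode. ref_annotation: reference labels to encode.

Returns: tuple (annotation, ref_annotation) with both label sets encoded as integers.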
Source code in checkatlas/metrics/metrics.py
def annotation_to_num(annotation, ref_annotation):
    """
    Transforms the annotations from categorical to numerical.

    Each input is encoded independently with its own ``LabelEncoder``, so
    the integer codes of ``annotation`` and ``ref_annotation`` are not
    aligned with each other (the same code may stand for different labels).

    Parameters
    ----------
    annotation
        Labels to encode (e.g. predicted annotation); any array-like such as
        a pandas Series, a numpy array or a list.
    ref_annotation
        Reference labels to encode; same accepted types as ``annotation``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(annotation, ref_annotation)`` encoded as integers.
    """
    annotation = LabelEncoder().fit_transform(np.asarray(annotation))
    ref_annotation = LabelEncoder().fit_transform(np.asarray(ref_annotation))
    return annotation, ref_annotation
cal_annot(adata, atlas_name=None, metric_list=None, all=False, file_dir=None, n_jobs=-1, verbose=True, preprocess_context=None)

Comprehensive annotation pipeline for all annotation metrics.

Parameters:
  • adata (AnnData) –

    Annotated data matrix.

  • metric_list (list, default: None ) –

    List of metric names to calculate. If provided, overrides all parameter.

  • all (bool, default: False ) –

    If True, calculate all available annotation metrics. If False, calculate a default subset. Ignored if metric_list is provided.

  • file_dir (str, default: None ) –

    Directory path where the results CSV will be saved. If None, saves to current working directory.

  • n_jobs (int, default: -1 ) –

    Number of parallel jobs (-1 = all cores).

  • verbose (bool, default: True ) –

    Whether to print progress information.

  • preprocess_context (PreprocessContext, default: None ) –

    If provided, column detection and kNN precomputation are skipped and the context's precomputed data (kNN graphs, neighbour graphs, batch keys) is reused.

Returns:
  • pd.DataFrame: Results dataframe with columns: [Atlas Name, Metric Name, Reference/Input 1, Prediction/Input 2, Value, Time (s)]

Source code in checkatlas/metrics/metrics.py
def cal_annot(
    adata,
    atlas_name=None,
    metric_list=None,
    all=False,
    file_dir=None,
    n_jobs=-1,
    verbose=True,
    preprocess_context=None,
):
    """
    Comprehensive annotation pipeline for all annotation metrics.

    Args:
        adata (AnnData): Annotated data matrix.
        metric_list (list): List of metric names to calculate.
                           If provided, overrides `all` parameter.
        all (bool): If True, calculate all available annotation metrics.
                    If False, calculate a default subset.
                    Ignored if metric_list is provided.
        file_dir (str): Directory path where the results CSV will be saved.
                       If None, saves to current working directory.
        n_jobs (int): Number of parallel jobs (-1 = all cores).
        verbose (bool): Whether to print progress information.
        preprocess_context (PreprocessContext, optional): If provided,
            column detection and kNN precomputation are skipped and
            the context's precomputed data (kNN graphs, neighbour
            graphs, batch keys) is reused.

    Returns:
        pd.DataFrame: Results dataframe with columns:
                       [Atlas Name, Metric Name, Reference/Input 1, Prediction/Input 2, Value, Time (s)]
    """
    import inspect

    from ..utils.col_detector import CheckAtlasColumnDetector
    from ._jax_utils import _GPU_AVAILABLE as _cal_gpu
    from ._jax_utils import _JAX_AVAILABLE as _cal_jax
    from ._neighbors import NeighborResults, _clear_neighbors_cache
    from ._neighbors import compute_neighbors as _cal_knn

    _USE_JAX = _cal_jax and _cal_gpu

    # Set file directory
    if file_dir is None:
        file_dir = os.getcwd()
    else:
        os.makedirs(file_dir, exist_ok=True)

    # ── Precomputed kNN lookup (populated from context or built locally) ──
    emb_nn = {}

    if preprocess_context is not None:
        ref_keys = preprocess_context.ref_keys
        pred_keys = preprocess_context.pred_keys
        embedding_keys = getattr(preprocess_context, "annotation_embedding_keys", None) or preprocess_context.embedding_keys
        batch_keys = preprocess_context.batch_keys
        if not batch_keys:
            batch_keys = [col for col in adata.obs.columns if "batch" in col.lower()]

        if verbose:
            print("Using precomputed context — skipping column detection")
            print(f"  Reference keys: {ref_keys}")
            print(f"  Predicted keys: {pred_keys}")
            print(f"  Embedding keys: {embedding_keys}")
            print(f"  Batch keys: {batch_keys}")

        # Load precomputed kNN from .npz files
        for emb in embedding_keys:
            _safe = emb.replace("/", "_").replace(" ", "_")
            if emb in preprocess_context.knn_paths or _safe in preprocess_context.knn_paths:
                from ._cache import load_knn

                loaded = load_knn(preprocess_context.annotation_dir, f"knn_{_safe}")
                if loaded is not None:
                    emb_nn[emb] = NeighborResults(
                        indices=loaded[0], distances=loaded[1]
                    )

        # Re-inject precomputed neighbour graphs for graph_connectivity
        for emb, payload in preprocess_context.neighbor_graphs.items():
            key_added = payload.get("key_added", f"neighbors_{emb}")
            if key_added not in adata.uns:
                adata.uns[key_added] = payload.get("uns_entry", {})
            conn_key = (
                payload.get("uns_entry", {})
                .get("connectivities_key", "connectivities")
            )
            dist_key = (
                payload.get("uns_entry", {})
                .get("distances_key", "distances")
            )
            if conn_key in payload and conn_key not in adata.obsp:
                adata.obsp[conn_key] = payload["connectivities"]
            if dist_key in payload and dist_key not in adata.obsp:
                adata.obsp[dist_key] = payload["distances"]

        # ── Load precomputed distance matrices from cluster cache ──
        # (same pattern as cal_cluster lines 835-855)
        precomputed_dists = {}
        if preprocess_context is not None:
            _safe = lambda s: s.replace("/", "_").replace(" ", "_")
            for emb in embedding_keys:
                tri_path = os.path.join(
                    preprocess_context.cluster_dir,
                    f"dist_{_safe(emb)}.tri",
                )
                npy_path = tri_path.replace(".tri", ".npy")
                if os.path.exists(tri_path):
                    n_cells = (
                        adata.obsm[emb].shape[0]
                        if emb in adata.obsm
                        else adata.n_obs
                    )
                    precomputed_dists[emb] = TriangularMatrix(
                        n=n_cells, filepath=tri_path, mode="r"
                    )
                elif os.path.exists(npy_path):
                    precomputed_dists[emb] = np.load(npy_path)
    else:
        # Detect columns
        detector = CheckAtlasColumnDetector(adata)
        params = detector.detect_all_parameters()

        ref_keys = [x[0] for x in params["annotation"]["reference"]]
        pred_keys = [x[0] for x in params["annotation"]["predicted"]]
        embedding_keys = [x[0] for x in params["clustering"]["embeddings"]]

        batch_keys = [x[0] for x in params.get("batch", [])]
        if not batch_keys:
            batch_keys = [col for col in adata.obs.columns if "batch" in col.lower()]

    # Define metrics to run
    if metric_list is not None:
        metrics_list = [m for m in metric_list if m in METRICS_ANNOT]
    elif all:
        metrics_list = METRICS_ANNOT
    else:
        metrics_list = [
            "adj_rand_index",
            "normalized_mutual_info",
            "adj_mutual_info",
        ]
        metrics_list = [m for m in metrics_list if m in METRICS_ANNOT]

    results = []
    atlas_name = atlas_name

    # Categorize metrics based on their input requirements
    # Ref vs Pred
    ref_pred_metrics = [
        "adj_mutual_info",
        "adj_rand_index",
        "fowlkes_mallows",
        "isolated_f1_score",
        "mutual_info",
        "normalized_mutual_info",
        "rand_index",
        "vmeasure",
    ]

    # Embedding + Labels
    emb_label_metrics = ["average_silhouette_width", "dunn_index"]

    # Batch / Integration (adata + batch/label)
    batch_metrics = ["kbet", "pcr"]  # lisi is special (iLISI vs cLISI)

    # Graph Connectivity (adata + neighbors)
    graph_metrics = ["graph_connectivity"]

    # Bio Conservation (adata_before, adata_after) - Skipping for single adata pipeline
    # unless we define strategy.
    bio_metrics = ["cell_cycle_conservation", "highly_variable_genes"]

    # Create progress bar with custom format
    pbar = tqdm(
        metrics_list,
        desc="Calculating Annotation Metrics",
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
    )

    for metric in pbar:
        # Start timing for this metric
        metric_start_time = time.time()

        # Update progress bar with current metric name
        pbar.set_description(f"Processing: {metric}")

        metric_module = getattr(annot, metric)

        try:
            # 1. Ref vs Pred Metrics
            if metric in ref_pred_metrics:
                if not ref_keys or not pred_keys:
                    continue
                for ref in ref_keys:
                    for pred in pred_keys:
                        # Skip if ref == pred
                        if ref == pred:
                            continue
                        try:
                            # Preprocess labels
                            labels_true = adata.obs[ref]
                            labels_pred = adata.obs[pred]
                            # Convert to numeric if needed (some metrics handle it, some don't)
                            # Most sklearn metrics handle strings, but let's be safe or rely on metric impl
                            # checkatlas metrics usually take raw inputs or handle conversion
                            # But calc_metric_annot_scanpy uses annotation_to_num.
                            # We should probably use that helper or do it here.
                            l_pred, l_true = annotation_to_num(labels_pred, labels_true)

                            pair_start = time.time()
                            # Build kwargs dynamically
                            sig = inspect.signature(metric_module.run)
                            kw = {}
                            if "n_jobs" in sig.parameters:
                                kw["n_jobs"] = n_jobs
                            if "verbose" in sig.parameters:
                                kw["verbose"] = False
                            val = metric_module.run(l_pred, l_true, **kw)
                            pair_elapsed = time.time() - pair_start
                            results.append(
                                {
                                    "Atlas Name": atlas_name,
                                    "Metric Name": metric,
                                    "Reference/Input 1": ref,
                                    "Prediction/Input 2": pred,
                                    "Value": val,
                                    "Time (s)": round(pair_elapsed, 3),
                                }
                            )
                        except Exception as e:
                            logger.warning(
                                f"Failed to calculate {metric} for {ref} vs {pred}: {e}"
                            )

            # 2. Embedding + Labels (ASW, Dunn)
            elif metric in emb_label_metrics:
                if not embedding_keys:
                    continue
                # Run for both ref and pred labels? Usually for predicted clusters.
                # But ASW can be run on ground truth too.
                targets = list(set(ref_keys + pred_keys))
                for emb in embedding_keys:
                    if emb not in adata.obsm:
                        continue
                    X_emb = adata.obsm[emb]
                    for label in targets:
                        try:
                            labels = adata.obs[label]
                            # Convert to numeric for ASW/Dunn?
                            # ASW handles labels.
                            pair_start = time.time()
                            sig = inspect.signature(metric_module.run)
                            kw = {}
                            if "n_jobs" in sig.parameters:
                                kw["n_jobs"] = n_jobs
                            if "verbose" in sig.parameters:
                                kw["verbose"] = False
                            if "precomputed_dists" in sig.parameters and emb in precomputed_dists:
                                kw["precomputed_dists"] = precomputed_dists[emb]
                            val = metric_module.run(X_emb, labels, **kw)
                            pair_elapsed = time.time() - pair_start
                            results.append(
                                {
                                    "Atlas Name": atlas_name,
                                    "Metric Name": metric,
                                    "Reference/Input 1": emb,
                                    "Prediction/Input 2": label,
                                    "Value": val,
                                    "Time (s)": round(pair_elapsed, 3),
                                }
                            )
                        except Exception as e:
                            logger.warning(
                                f"Failed to calculate {metric} for {emb} vs {label}: {e}"
                            )

            # 3. LISI (Special Case: iLISI and cLISI)
            elif metric == "lisi":
                # ── Precompute kNN per embedding once (GPU/JAX or CPU) ──
                # Only build if not already loaded from preprocess_context
                if not emb_nn and _USE_JAX:
                    for emb in embedding_keys:
                        X_emb = np.asarray(adata.obsm[emb], dtype=np.float64)
                        emb_nn[emb] = _cal_knn(X_emb, n_neighbors=90, backend="auto")

                # iLISI: needs batch
                if batch_keys:
                    for batch in batch_keys:
                        try:
                            if embedding_keys:
                                for emb in embedding_keys:
                                    X_emb = adata.obsm[emb]
                                    pair_start = time.time()
                                    if emb in emb_nn:
                                        val = metric_module.run_with_neighbors(
                                            emb_nn[emb],
                                            adata.obs[batch],
                                            verbose=False,
                                        )
                                    else:
                                        sig = inspect.signature(metric_module.run)
                                        kw = {}
                                        if "n_jobs" in sig.parameters:
                                            kw["n_jobs"] = n_jobs
                                        if "verbose" in sig.parameters:
                                            kw["verbose"] = False
                                        val = metric_module.run(
                                            X_emb, adata.obs[batch], **kw
                                        )
                                    pair_elapsed = time.time() - pair_start
                                    results.append(
                                        {
                                            "Atlas Name": atlas_name,
                                            "Metric Name": "iLISI",
                                            "Reference/Input 1": emb,
                                            "Prediction/Input 2": batch,
                                            "Value": val,
                                            "Time (s)": round(pair_elapsed, 3),
                                        }
                                    )
                            else:
                                pair_start = time.time()
                                sig = inspect.signature(metric_module.run)
                                kw = {}
                                if "n_jobs" in sig.parameters:
                                    kw["n_jobs"] = n_jobs
                                if "verbose" in sig.parameters:
                                    kw["verbose"] = False
                                val = metric_module.run(adata.X, adata.obs[batch], **kw)
                                pair_elapsed = time.time() - pair_start
                                results.append(
                                    {
                                        "Atlas Name": atlas_name,
                                        "Metric Name": "iLISI",
                                        "Reference/Input 1": "X",
                                        "Prediction/Input 2": batch,
                                        "Value": val,
                                        "Time (s)": round(pair_elapsed, 3),
                                    }
                                )
                        except Exception as e:
                            logger.warning(f"Failed to calculate iLISI for {batch}: {e}")

                # cLISI: needs cell type (ref or pred)
                targets = list(set(ref_keys + pred_keys))
                for label in targets:
                    try:
                        if embedding_keys:
                            for emb in embedding_keys:
                                X_emb = adata.obsm[emb]
                                pair_start = time.time()
                                if emb in emb_nn:
                                    val = metric_module.run_with_neighbors(
                                        emb_nn[emb],
                                        adata.obs[label],
                                        verbose=False,
                                    )
                                else:
                                    sig = inspect.signature(metric_module.run)
                                    kw = {}
                                    if "n_jobs" in sig.parameters:
                                        kw["n_jobs"] = n_jobs
                                    if "verbose" in sig.parameters:
                                        kw["verbose"] = False
                                    val = metric_module.run(X_emb, adata.obs[label], **kw)
                                pair_elapsed = time.time() - pair_start
                                results.append(
                                    {
                                        "Atlas Name": atlas_name,
                                        "Metric Name": "cLISI",
                                        "Reference/Input 1": emb,
                                        "Prediction/Input 2": label,
                                        "Value": val,
                                        "Time (s)": round(pair_elapsed, 3),
                                    }
                                )
                        else:
                            pair_start = time.time()
                            sig = inspect.signature(metric_module.run)
                            kw = {}
                            if "n_jobs" in sig.parameters:
                                kw["n_jobs"] = n_jobs
                            if "verbose" in sig.parameters:
                                kw["verbose"] = False
                            val = metric_module.run(adata.X, adata.obs[label], **kw)
                            pair_elapsed = time.time() - pair_start
                            results.append(
                                {
                                    "Atlas Name": atlas_name,
                                    "Metric Name": "cLISI",
                                    "Reference/Input 1": "X",
                                    "Prediction/Input 2": label,
                                    "Value": val,
                                    "Time (s)": round(pair_elapsed, 3),
                                }
                            )
                    except Exception as e:
                        logger.warning(f"Failed to calculate cLISI for {label}: {e}")

            # 4. Batch Metrics (kBET, PCR) — evaluate per (embedding × batch)
            elif metric in batch_metrics:
                if not batch_keys:
                    continue
                for batch in batch_keys:
                    if embedding_keys:
                        for emb in embedding_keys:
                            try:
                                X_emb = adata.obsm[emb]
                                pair_start = time.time()
                                # kBET: use precomputed kNN from context or JAX
                                if (
                                    metric == "kbet"
                                    and hasattr(metric_module, "run_with_neighbors")
                                ):
                                    if emb in emb_nn:
                                        nn = emb_nn[emb]
                                        if nn.n_neighbors > 25:
                                            nn = nn.subset_neighbors(25)
                                        val = metric_module.run_with_neighbors(
                                            nn, adata.obs[batch], verbose=False
                                        )
                                    elif _USE_JAX:
                                        X_arr = np.asarray(X_emb, dtype=np.float64)
                                        nn = _cal_knn(X_arr, n_neighbors=25, backend="auto")
                                        val = metric_module.run_with_neighbors(
                                            nn, adata.obs[batch], verbose=False
                                        )
                                    else:
                                        sig = inspect.signature(metric_module.run)
                                        kw = {}
                                        if "n_jobs" in sig.parameters:
                                            kw["n_jobs"] = n_jobs
                                        if "verbose" in sig.parameters:
                                            kw["verbose"] = False
                                        val = metric_module.run(X_emb, adata.obs[batch], **kw)
                                else:
                                    sig = inspect.signature(metric_module.run)
                                    kw = {}
                                    if "n_jobs" in sig.parameters:
                                        kw["n_jobs"] = n_jobs
                                    if "verbose" in sig.parameters:
                                        kw["verbose"] = False
                                    val = metric_module.run(X_emb, adata.obs[batch], **kw)
                                pair_elapsed = time.time() - pair_start
                                results.append(
                                    {
                                        "Atlas Name": atlas_name,
                                        "Metric Name": metric,
                                        "Reference/Input 1": emb,
                                        "Prediction/Input 2": batch,
                                        "Value": val,
                                        "Time (s)": round(pair_elapsed, 3),
                                    }
                                )
                            except Exception as e:
                                logger.warning(
                                    f"Failed to calculate {metric} for {emb} vs {batch}: {e}"
                                )
                    else:
                        try:
                            pair_start = time.time()
                            if (
                                metric == "kbet"
                                and hasattr(metric_module, "run_with_neighbors")
                            ):
                                if "X" in emb_nn:
                                    nn = emb_nn["X"]
                                    if nn.n_neighbors > 25:
                                        nn = nn.subset_neighbors(25)
                                    val = metric_module.run_with_neighbors(
                                        nn, adata.obs[batch], verbose=False
                                    )
                                elif _USE_JAX:
                                    X_arr = np.asarray(adata.X, dtype=np.float64)
                                    if hasattr(adata.X, "toarray"):
                                        X_arr = adata.X.toarray().astype(np.float64)
                                    nn = _cal_knn(X_arr, n_neighbors=25, backend="auto")
                                    val = metric_module.run_with_neighbors(
                                        nn, adata.obs[batch], verbose=False
                                    )
                                else:
                                    sig = inspect.signature(metric_module.run)
                                    kw = {}
                                    if "n_jobs" in sig.parameters:
                                        kw["n_jobs"] = n_jobs
                                    if "verbose" in sig.parameters:
                                        kw["verbose"] = False
                                    val = metric_module.run(adata.X, adata.obs[batch], **kw)
                            else:
                                sig = inspect.signature(metric_module.run)
                                kw = {}
                                if "n_jobs" in sig.parameters:
                                    kw["n_jobs"] = n_jobs
                                if "verbose" in sig.parameters:
                                    kw["verbose"] = False
                                val = metric_module.run(adata.X, adata.obs[batch], **kw)
                            pair_elapsed = time.time() - pair_start
                            results.append(
                                {
                                    "Atlas Name": atlas_name,
                                    "Metric Name": metric,
                                    "Reference/Input 1": "X",
                                    "Prediction/Input 2": batch,
                                    "Value": val,
                                    "Time (s)": round(pair_elapsed, 3),
                                }
                            )
                        except Exception as e:
                            logger.warning(
                                f"Failed to calculate {metric} for {batch}: {e}"
                            )

            # 5. Graph Connectivity
            elif metric in graph_metrics:
                # We need embeddings AND labels
                targets = list(set(ref_keys + pred_keys))

                if embedding_keys:
                    for emb in embedding_keys:
                        # Calculate neighbors for this embedding
                        key_added = f"neighbors_{emb}"
                        try:
                            # Ensure neighbors are calculated
                            import scanpy as sc

                            sc.pp.neighbors(adata, use_rep=emb, key_added=key_added)
                        except Exception as e:
                            logger.warning(
                                f"Failed to calculate neighbors for {emb}: {e}"
                            )
                            continue

                        for label in targets:
                            try:
                                pair_start = time.time()
                                sig = inspect.signature(metric_module.run)
                                kw = {
                                    "neighbors_key": key_added,
                                    "label_key": label,
                                }
                                if "n_jobs" in sig.parameters:
                                    kw["n_jobs"] = n_jobs
                                if "verbose" in sig.parameters:
                                    kw["verbose"] = False
                                val = metric_module.run(adata, **kw)
                                pair_elapsed = time.time() - pair_start
                                results.append(
                                    {
                                        "Atlas Name": atlas_name,
                                        "Metric Name": metric,
                                        "Reference/Input 1": emb,
                                        "Prediction/Input 2": label,
                                        "Value": val,
                                        "Time (s)": round(pair_elapsed, 3),
                                    }
                                )
                            except Exception as e:
                                logger.warning(
                                    f"Failed to calculate {metric} for {emb} vs {label}: {e}"
                                )
                else:
                    # No embeddings found, use default neighbors (X or PCA)
                    # We still need labels
                    for label in targets:
                        try:
                            # metric_module.run will calculate neighbors if 'neighbors' key missing
                            pair_start = time.time()
                            sig = inspect.signature(metric_module.run)
                            kw = {"label_key": label}
                            if "n_jobs" in sig.parameters:
                                kw["n_jobs"] = n_jobs
                            if "verbose" in sig.parameters:
                                kw["verbose"] = False
                            val = metric_module.run(adata, **kw)
                            pair_elapsed = time.time() - pair_start
                            results.append(
                                {
                                    "Atlas Name": atlas_name,
                                    "Metric Name": metric,
                                    "Reference/Input 1": "Default",
                                    "Prediction/Input 2": label,
                                    "Value": val,
                                    "Time (s)": round(pair_elapsed, 3),
                                }
                            )
                        except Exception as e:
                            logger.warning(
                                f"Failed to calculate {metric} for Default vs {label}: {e}"
                            )

            # 6. Bio Conservation
            elif metric in bio_metrics:
                # Requires adata_before.
                # If we don't have it, we can't run it properly.
                # We'll skip for now or log warning.
                logger.info(f"Skipping {metric} as it requires 'adata_before'.")

            else:
                logger.warning(f"Metric {metric} not categorized in pipeline.")

        except Exception as e:
            logger.error(f"Error running metric {metric}: {e}")

        # Calculate and display metric execution time
        metric_elapsed = time.time() - metric_start_time
        pbar.set_postfix_str(f"Time: {metric_elapsed:.2f}s", refresh=True)

    df = pd.DataFrame(results)

    # Save MultiQC-compatible wide format to file_dir if provided
    if not df.empty and file_dir is not None and atlas_name is not None:
        os.makedirs(file_dir, exist_ok=True)
        wide_df = _pivot_annot_to_wide(df, atlas_name)
        wide_path = os.path.join(file_dir, f"{atlas_name}.tsv")
        wide_df.to_csv(wide_path, sep="\t", index=False)
        logger.info("MultiQC-compatible annotation table saved to %s", wide_path)

    # Clear LISI and kBET kNN caches to free memory
    try:
        from .annot import lisi as _lisi

        _lisi._clear_knn_cache()
    except Exception:
        pass
    try:
        from .annot import kbet as _kbet

        _kbet._clear_knn_cache()
    except Exception:
        pass

    return df
cal_cluster(adata, atlas_name=None, metric_list=None, all_metrics=True, file_dir=None, n_jobs=-1, verbose=True, seed=42, preprocess_context=None)

Comprehensive clustering assessment pipeline.

Calculates all clustering metrics to evaluate the quality of cluster assignments against embedding representations. Uses CheckAtlasColumnDetector to auto-detect embedding keys and cluster label columns.

Parameters:
  • adata (AnnData) –

    Annotated data matrix.

  • atlas_name (str, default: None ) –

    Name of the atlas for labeling results.

  • metric_list (list, default: None ) –

    List of metric names to calculate. If provided, overrides all_metrics.

  • all_metrics (bool, default: True ) –

    If True, calculate all available cluster metrics. If False, calculate a default subset. Ignored if metric_list is provided.

  • file_dir (str, default: None ) –

    Directory that is created if it does not exist. This function does not save any results itself; the caller is responsible for writing the returned table.

  • n_jobs (int, default: -1 ) –

    Number of parallel jobs (-1 = all cores).

  • verbose (bool, default: True ) –

    Whether to print progress.

  • seed (int, default: 42 ) –

    Random seed for reproducibility.

  • preprocess_context (PreprocessContext, default: None ) –

    If provided, column detection and distance-matrix precomputation are skipped and the context's precomputed data is reused.

Returns:
  • pd.DataFrame: Results dataframe with columns: [Atlas Name, Metric Name, Embedding, Label Key, Value, Time (s)]

Source code in checkatlas/metrics/metrics.py
def cal_cluster(
    adata,
    atlas_name=None,
    metric_list=None,
    all_metrics=True,
    file_dir=None,
    n_jobs=-1,
    verbose=True,
    seed=42,
    preprocess_context=None,
):
    """
    Comprehensive clustering assessment pipeline.

    Evaluates cluster assignments against embedding representations for every
    (embedding, cluster label, metric) combination. Embedding keys and cluster
    label columns are taken from ``preprocess_context`` when given, otherwise
    auto-detected with ``CheckAtlasColumnDetector``.

    Args:
        adata (AnnData): Annotated data matrix.
        atlas_name (str): Name of the atlas, copied into the ``Atlas Name``
            column of every result row.
        metric_list (list, optional): Metric names to calculate. Names not in
            ``METRICS_CLUST`` are dropped (with a warning).
            If provided, overrides ``all_metrics``.
        all_metrics (bool): If True, calculate all metrics in
            ``METRICS_CLUST``. If False, calculate a default subset
            (silhouette, davies_bouldin, calinski_harabasz).
            Ignored if ``metric_list`` is provided.
        file_dir (str, optional): Directory that is created if it does not
            exist. This function does not write any results itself; the
            caller is responsible for saving the returned table.
        n_jobs (int): Number of parallel jobs (-1 = all cores). Forwarded to
            metrics whose ``run`` accepts ``n_jobs``.
        verbose (bool): Whether to print progress and show a progress bar.
        seed (int): Random seed. Forwarded to metrics whose ``run`` accepts
            ``random_state`` or ``seed``.
        preprocess_context (PreprocessContext, optional): If provided,
            column detection is skipped and distance matrices found in
            ``preprocess_context.cluster_dir`` (``dist_<emb>.tri`` or
            ``.npy``) are reused by metrics that accept
            ``precomputed_dists``.

    Returns:
        pd.DataFrame: One row per attempted (embedding, label key, metric)
        with columns [Atlas Name, Metric Name, Embedding, Label Key, Value,
        Time (s)]. Failed metrics are recorded with ``Value = NaN``.
        An empty DataFrame is returned when no cluster labels are available.
    """
    import gc
    import inspect

    from scipy.sparse import issparse

    # The directory is only created here; nothing is written to it.
    if file_dir is not None:
        os.makedirs(file_dir, exist_ok=True)

    precomputed_dists = {}

    if preprocess_context is not None:
        # `or []` guards against the context carrying None instead of a list.
        embedding_keys = list(
            preprocess_context.cluster_embedding_keys
            or preprocess_context.embedding_keys
            or []
        )
        label_keys = list(preprocess_context.cluster_label_keys or [])

        if not label_keys:
            logger.warning(
                "No cluster labels in preprocess context. Skipping."
            )
            return pd.DataFrame()

        if verbose:
            print("Using precomputed context — skipping column detection")
            print(f"  Embeddings: {embedding_keys}")
            print(f"  Cluster labels: {label_keys}")

        def _safe(name):
            # Same filename sanitisation used when the matrices were written.
            return name.replace("/", "_").replace(" ", "_")

        # Load precomputed distance matrices (used by e.g. silhouette)
        for emb in embedding_keys:
            tri_path = os.path.join(
                preprocess_context.cluster_dir, f"dist_{_safe(emb)}.tri"
            )
            npy_path = tri_path.replace(".tri", ".npy")
            if os.path.exists(tri_path):
                from ._triangular import TriangularMatrix

                if emb == "X":
                    n_cells = adata.X.shape[0]
                elif emb in adata.obsm:
                    n_cells = adata.obsm[emb].shape[0]
                else:
                    n_cells = adata.n_obs
                precomputed_dists[emb] = TriangularMatrix(
                    n=n_cells, filepath=tri_path, mode="r"
                )
            elif os.path.exists(npy_path):
                precomputed_dists[emb] = np.load(npy_path)
    else:
        # Imported lazily: only needed when columns must be detected.
        from ..utils.col_detector import CheckAtlasColumnDetector

        if verbose:
            print("Detecting embeddings and cluster labels...")
        detector = CheckAtlasColumnDetector(adata)
        params = detector.detect_all_parameters()

        embedding_keys = [x[0] for x in params["clustering"]["embeddings"]]
        label_keys = [x[0] for x in params["clustering"]["cluster_labels"]]

    if verbose:
        print(f"  Detected embeddings: {embedding_keys}")
        print(f"  Detected cluster labels: {label_keys}")

    if not embedding_keys:
        logger.warning(
            "No embeddings detected in adata.obsm. Trying adata.X as fallback."
        )
        embedding_keys = ["X"]

    if not label_keys:
        logger.warning(
            "No cluster labels detected in adata.obs. Cannot run cluster metrics."
        )
        return pd.DataFrame()

    # Define metrics to run
    if metric_list is not None:
        metrics_list = [m for m in metric_list if m in METRICS_CLUST]
        unknown = [m for m in metric_list if m not in METRICS_CLUST]
        if unknown:
            logger.warning("Ignoring unknown cluster metrics: %s", unknown)
    elif all_metrics:
        # Copy so the module-level constant is never aliased.
        metrics_list = list(METRICS_CLUST)
    else:
        # Default subset: fastest and most commonly used
        metrics_list = [
            m
            for m in ("silhouette", "davies_bouldin", "calinski_harabasz")
            if m in METRICS_CLUST
        ]

    results = []

    # Total combinations
    total_combos = len(embedding_keys) * len(label_keys) * len(metrics_list)
    if verbose:
        print(
            f"\nRunning {len(metrics_list)} metrics × "
            f"{len(embedding_keys)} embeddings × "
            f"{len(label_keys)} label keys = {total_combos} calculations\n"
        )

    pbar = tqdm(
        total=total_combos,
        desc="Calculating Cluster Metrics",
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
        disable=not verbose,
    )

    try:
        for emb_key in embedding_keys:
            # Extract embedding data
            if emb_key == "X":
                X_emb = adata.X
            elif emb_key in adata.obsm:
                X_emb = adata.obsm[emb_key]
            else:
                logger.warning(
                    f"Embedding '{emb_key}' not found in adata.obsm. Skipping."
                )
                pbar.update(len(label_keys) * len(metrics_list))
                continue
            # Metrics expect a dense array; sparse data may live in X or obsm.
            X_emb = X_emb.toarray() if issparse(X_emb) else np.asarray(X_emb)

            for label_key in label_keys:
                if label_key not in adata.obs.columns:
                    logger.warning(
                        f"Label key '{label_key}' not found in adata.obs. Skipping."
                    )
                    pbar.update(len(metrics_list))
                    continue

                labels = np.asarray(adata.obs[label_key])

                # Check we have at least 2 clusters. pd.unique does not sort,
                # so object labels mixing strings and NaN do not raise here.
                n_unique = len(pd.unique(labels))
                if n_unique < 2:
                    logger.warning(
                        f"Label '{label_key}' has < 2 clusters. Skipping."
                    )
                    pbar.update(len(metrics_list))
                    continue

                for metric_name in metrics_list:
                    metric_start_time = time.time()
                    pbar.set_description(f"{emb_key}|{label_key}: {metric_name}")

                    value = np.nan
                    succeeded = False
                    try:
                        metric_module = getattr(cluster, metric_name)

                        # Only pass the keyword arguments the metric accepts.
                        metric_params = inspect.signature(
                            metric_module.run
                        ).parameters
                        kwargs = {}
                        if "n_jobs" in metric_params:
                            kwargs["n_jobs"] = n_jobs
                        if "verbose" in metric_params:
                            kwargs["verbose"] = False
                        if "random_state" in metric_params:
                            kwargs["random_state"] = seed
                        if "seed" in metric_params:
                            kwargs["seed"] = seed
                        if "max_samples" in metric_params:
                            kwargs["max_samples"] = None  # disable subsampling
                        if (
                            "precomputed_dists" in metric_params
                            and emb_key in precomputed_dists
                        ):
                            kwargs["precomputed_dists"] = precomputed_dists[emb_key]

                        value = metric_module.run(X_emb, labels, **kwargs)
                        succeeded = True
                    except Exception as e:
                        # One failing metric must not abort the whole pipeline;
                        # it is recorded below with a NaN value.
                        logger.warning(
                            f"Failed {metric_name} for {emb_key}/{label_key}: {e}"
                        )

                    elapsed = time.time() - metric_start_time

                    # Exactly one row per attempted combination.
                    results.append(
                        {
                            "Atlas Name": atlas_name,
                            "Metric Name": metric_name,
                            "Embedding": emb_key,
                            "Label Key": label_key,
                            "Value": value,
                            "Time (s)": round(elapsed, 3),
                        }
                    )

                    if succeeded:
                        try:
                            shown = f"{value:.4f}"
                        except (TypeError, ValueError):
                            # e.g. None or non-scalar results
                            shown = str(value)
                        pbar.set_postfix_str(
                            f"={shown} ({elapsed:.1f}s)", refresh=True
                        )

                    pbar.update(1)
    finally:
        pbar.close()

    df = pd.DataFrame(results)

    # Summary
    if verbose and not df.empty:
        total_time = df["Time (s)"].sum()
        print(f"\nTotal computation time: {total_time:.2f}s")
        print(
            f"Results: {len(df)} measurements across {len(df['Metric Name'].unique())} metrics"
        )

    gc.collect()
    return df
cal_dimred(adata, atlas_name=None, low_dim_keys=None, high_dim_key='X', metric_list=None, k_neighbors=30, n_samples=None, seed=42, n_jobs=-1, file_dir=None, verbose=True, use_cache=True)

Calculate dimensionality reduction metrics for multiple embeddings.

For each low_dim_key the embedding is compared against the reference high_dim_key (defaults to adata.X — the raw gene expression matrix).

Precomputation (distance matrices + kNN graphs) is performed once for the high‑dimensional reference and once per low‑dimensional embedding key. Memory‑mapped files are used for N×N distance matrices on datasets with > 10 000 cells to keep RAM usage bounded.

Parameters

  • adata : AnnData
  • atlas_name : str, optional – Used only for logging.
  • low_dim_keys : list of str, optional – .obsm keys to evaluate. Defaults to all keys returned by CheckAtlasColumnDetector.get_dimred_embeddings.
  • high_dim_key : str – Reference key. 'X' means adata.X. (Default 'X')
  • metric_list : list of str, optional – Metrics to compute. Defaults to METRICS_DIMRED.
  • k_neighbors : int
  • n_samples : int or None – Number of cells to subsample. None = use all cells.
  • seed : int
  • n_jobs : int
  • file_dir : str, optional – Directory for temporary memmap files. Falls back to the system /tmp.
  • verbose : bool

Returns

  • pd.DataFrame – Long-format table with columns [Metric Name, Low Dim Key, High Dim Key, Value, Time (s)]. The caller is responsible for pivoting and saving.

Source code in checkatlas/metrics/metrics.py
def cal_dimred(
    adata,
    atlas_name=None,
    low_dim_keys=None,
    high_dim_key="X",
    metric_list=None,
    k_neighbors=30,
    n_samples=None,
    seed=42,
    n_jobs=-1,
    file_dir=None,
    verbose=True,
    use_cache=True,
):
    """Calculate dimensionality reduction metrics for multiple embeddings.

    For each ``low_dim_key`` the embedding is compared against the
    reference ``high_dim_key`` (defaults to ``adata.X``).

    Precomputation (kNN graphs and N×N distance matrices) is performed
    once for the high‑dimensional reference and once per low‑dimensional
    embedding key, then shared by every metric.  The backend depends on
    the number of evaluated cells N:

    * JAX + GPU, N ≤ 50 000: one-shot GPU distance matrix + approximate kNN.
    * JAX + GPU, 50 000 < N ≤ 150 000: chunked GPU kNN fused with an
      upper-triangular memmap of the distances.
    * otherwise: scikit-learn on CPU.  Distance matrices are computed only
      when a distance-based metric is requested, and are stored as
      upper-triangular memmaps when N > 10 000 (dense arrays otherwise).

    Parameters
    ----------
    adata : AnnData
    atlas_name : str, optional
        Together with ``use_cache``, enables the persistent precomputation
        cache.  Also used in the empty-selection warning.
    low_dim_keys : list of str, optional
        ``.obsm`` keys to evaluate.  Defaults to all keys returned by
        :meth:`CheckAtlasColumnDetector.get_dimred_embeddings`.  A key equal
        to ``high_dim_key`` is skipped.
    high_dim_key : str
        Reference key.  ``'X'`` means ``adata.X`` (densified if sparse);
        any other value must be an ``.obsm`` key.  (Default ``'X'``)
    metric_list : list of str, optional
        Metrics to compute.  Defaults to :data:`METRICS_DIMRED`.  Names not
        in :data:`METRICS_DIMRED` are dropped.
    k_neighbors : int
        Neighbourhood size.  kNN graphs are built with ``k_neighbors + 1``
        neighbours, and the value is forwarded to metrics accepting it.
    n_samples : int or None
        Number of cells to subsample without replacement (after
        ``np.random.seed(seed)``).  ``None``, or a value ≥ ``adata.n_obs``,
        uses all cells.
    seed : int
        Seed for subsampling.  Also forwarded to metrics accepting ``seed``.
    n_jobs : int
        Parallel jobs for scikit-learn and for metrics accepting ``n_jobs``.
    file_dir : str, optional
        Directory for memmap and cache files.  When omitted,
        ``~/.checkatlas/tmp`` is used.  If that filesystem has less than
        50 GB free, ``.checkatlas_tmp`` is used instead, under ``$TMPDIR``
        (if set and existing), else ``/data/tmp`` (if existing), else
        ``/tmp``.
    verbose : bool
        Print progress bars and a summary.
    use_cache : bool
        When True *and* ``atlas_name`` is given, reuse precomputed
        distances/kNN from a previous run (``load_dimred_cache``), or save
        this run's precomputation (``save_dimred_cache``).  Any memmap file
        that is not saved to the cache is deleted before returning.

    Returns
    -------
    pd.DataFrame
        Long‑format table with columns
        ``[Metric Name, Low Dim Key, High Dim Key, Value, Time (s)]``.
        Failed metrics get ``Value = NaN``.  The table is empty when there
        is nothing to compute.  Caller is responsible for pivoting and
        saving.

    Raises
    ------
    ValueError
        If ``high_dim_key`` is neither ``'X'`` nor a key of ``adata.obsm``.
    """
    import gc
    import inspect
    import uuid

    from sklearn.neighbors import NearestNeighbors

    from ..utils.col_detector import CheckAtlasColumnDetector

    if metric_list is None:
        metric_list = METRICS_DIMRED
    metric_list = [m for m in metric_list if m in METRICS_DIMRED]
    if not metric_list:
        return pd.DataFrame()

    if low_dim_keys is None:
        detector = CheckAtlasColumnDetector(adata)
        low_dim_keys = detector.get_dimred_embeddings()

    low_dim_keys = [k for k in low_dim_keys if k != high_dim_key]
    if not low_dim_keys:
        if verbose:
            print(f"  No embeddings to compare " f"(only key is {high_dim_key}).")
        return pd.DataFrame()

    # ── Cell selection (optional subsampling) ──
    n_obs = adata.n_obs
    if n_samples is not None and n_samples < n_obs:
        np.random.seed(seed)
        sample_indices = np.random.choice(n_obs, n_samples, replace=False)
        n_cells = n_samples
    else:
        sample_indices = np.arange(n_obs)
        n_cells = n_obs

    if high_dim_key == "X":
        high_dim_data = adata.X[sample_indices]
        if hasattr(high_dim_data, "toarray"):
            high_dim_data = high_dim_data.toarray()
    else:
        if high_dim_key not in adata.obsm_keys():
            raise ValueError(f"High-dim key '{high_dim_key}' not found in adata.obsm.")
        high_dim_data = adata.obsm[high_dim_key][sample_indices]

    if n_cells == 0:
        # kNN fitting and the chunked loops below cannot run on zero cells.
        logger.warning("cal_dimred: no cells to evaluate (atlas %s).", atlas_name)
        return pd.DataFrame()

    high_n_features = high_dim_data.shape[1]

    use_memmap = n_cells > 10000
    all_memmap_files = []  # every memmap file created by this call
    temp_dir = _dimred_temp_dir(file_dir, verbose)

    # The persistent cache is only read or written when an atlas name is given.
    caching = bool(use_cache and atlas_name)

    # ── Persistent cache check ───────────────────────────────────
    from_cache = False
    cache_low_dim = {}
    if caching:
        fingerprint = compute_fingerprint(
            n_cells=n_cells,
            n_features=high_n_features,
            embedding_keys=low_dim_keys,
            embedding_shapes={
                k: tuple(adata.obsm[k][sample_indices].shape) for k in low_dim_keys
            },
            k_neighbors=k_neighbors,
            source_path=getattr(adata, "filename", None),
        )
        cached = load_dimred_cache(temp_dir, fingerprint, n_cells, low_dim_keys)
        if cached is not None:
            high_dim_dists = cached["high_dim_dists"]
            high_knn_dists = cached["high_knn_dists"]
            high_knn_indices = cached["high_knn_indices"]
            cache_low_dim = cached["low_dim"]
            from_cache = True
            if verbose:
                print("  [CACHE HIT] Reusing precomputed distances & kNN")

    # Files written by this call are kept only when they will be handed to
    # save_dimred_cache at the end.  In every other case (no caching, no
    # atlas name, or a cache-hit run, which never saves) they are
    # temporaries and are deleted.
    persist = caching and not from_cache

    # Unique per call, so concurrent runs sharing temp_dir never truncate
    # each other's memmaps.
    run_id = str(uuid.uuid4())[:8]
    chunk_size = min(1000, n_cells)

    gpu_ok = _JAX_AVAILABLE and _GPU_AVAILABLE

    if verbose:
        print(
            f"  Reference: {high_dim_key}  "
            f"({n_cells:,} cells  ×  {high_n_features} features)"
        )
        print(f"  Embeddings: {low_dim_keys}")
        print(f"  Metrics: {metric_list}")
        backend = "GPU (JAX)" if gpu_ok else "CPU"
        print(f"  Backend: {backend}")

    # ── GPU budget ──
    # Memory: N² float32 ≈ 4·N² bytes → with intermediate buffers ≈ 4·N²·3.5 bytes
    # A100 40 GB: safe N ≤ 50 000 (≈ 32.6 GB total)
    # For 50k < N ≤ 150k: chunked GPU + memmap (avoid disk IO bottleneck)
    gpu_single_shot = gpu_ok and n_cells <= 50000
    gpu_chunked = gpu_ok and 50000 < n_cells <= 150000

    dist_metrics = frozenset(
        (
            "kruskal_stress",
            "spearman_rho",
            "dCor",
            "trustworthiness",
            "continuity",
        )
    )
    need_dists = bool(set(metric_list) & dist_metrics)

    low_dim_precomputed = {}  # collected for saving to the cache later

    # ── High-dimensional precomputation ──
    if from_cache:
        pass  # precomputation already loaded

    elif gpu_single_shot:
        high_dim_dists = pdist_squareform(high_dim_data)  # dense (n, n)
        high_knn_dists, high_knn_indices = _dimred_gpu_knn(
            high_dim_dists, k_neighbors + 1
        )
        if persist:
            # save_dimred_cache only stores TriangularMatrix memmaps, so copy
            # the dense matrix into one to spare later runs this N²×D work.
            high_dists_path = os.path.join(temp_dir, f"high_dists_{run_id}.tri")
            all_memmap_files.append(high_dists_path)
            high_dim_dists = _dimred_dense_to_tri(
                high_dim_dists, high_dists_path, n_cells
            )

    elif gpu_chunked:
        # Fused streaming kNN + upper-triangle distance memmap in one pass.
        # These chunk sizes are only reported; they are not passed to
        # compute_neighbors.
        q_chunk = 15000
        r_chunk = 10000

        if verbose:
            print(
                f"  Chunked GPU/streaming: {n_cells:,} cells, "
                f"query={q_chunk}, ref={r_chunk}"
            )
            print(f"  Storing in: {temp_dir}")

        high_dists_path = os.path.join(temp_dir, f"high_dists_{run_id}.tri")
        all_memmap_files.append(high_dists_path)
        high_dim_dists = TriangularMatrix(n=n_cells, filepath=high_dists_path, mode="w+")

        if verbose:
            print("  Computing kNN + distances (fused GPU)...")
        high_knn_results = compute_neighbors(
            np.asarray(high_dim_data, dtype=np.float64),
            n_neighbors=k_neighbors + 1,
            backend="auto",
            tri_memmap=high_dim_dists,
        )
        high_knn_dists = high_knn_results.distances
        high_knn_indices = high_knn_results.indices
        high_dim_dists.flush()
        gc.collect()

    else:
        # ── CPU path ──
        nbrs_high = NearestNeighbors(n_neighbors=k_neighbors + 1, n_jobs=n_jobs).fit(
            high_dim_data
        )
        high_knn_dists, high_knn_indices = nbrs_high.kneighbors(high_dim_data)

        # The distance matrix is only needed by distance-based metrics.
        high_dim_dists = None
        if need_dists:
            if verbose:
                print(
                    "  Precomputing high-dim distance matrix "
                    f"({n_cells:,}×{n_cells:,})…"
                )
            if use_memmap:
                high_dists_path = os.path.join(temp_dir, f"high_dists_{run_id}.tri")
                all_memmap_files.append(high_dists_path)
                high_dim_dists = _dimred_pairwise_tri(
                    high_dim_data, high_dists_path, n_cells, chunk_size, n_jobs,
                    desc="High-Dim Distances", verbose=verbose,
                )
            else:
                high_dim_dists = _dimred_pairwise_dense(
                    high_dim_data, n_cells, chunk_size, n_jobs,
                    desc="High-Dim Distances", verbose=verbose,
                )
            gc.collect()

    total_combos = len(low_dim_keys) * len(metric_list)
    if verbose:
        print(
            f"\n  {len(metric_list)} metrics × {len(low_dim_keys)} embeddings"
            f" = {total_combos} calculations\n"
        )

    # TriangularMatrix inputs of the rank-based metrics are densified only up
    # to this many cells; beyond it the O(N²) dense copy is avoided and the
    # metric receives the TriangularMatrix itself.
    dense_max_n = 50_000
    rank_metrics = ("trustworthiness", "continuity")

    results = []
    pbar = tqdm(
        total=total_combos,
        desc="Calculating Dimred Metrics",
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} " "[{elapsed}<{remaining}]",
        disable=not verbose,
    )

    for low_dim_key in low_dim_keys:
        low_dim_data = adata.obsm[low_dim_key][sample_indices]
        safe_key = low_dim_key.replace("/", "_").replace(" ", "_")
        low_tri_path = os.path.join(temp_dir, f"low_dists_{safe_key}_{run_id}.tri")

        low_dim_dists = None
        low_dists_path = None  # set only when this embedding wrote a file

        if from_cache and low_dim_key in cache_low_dim:
            # ── Load from cache ──
            cached_low = cache_low_dim[low_dim_key]
            low_dim_dists = cached_low.get("dists")
            low_knn_dists = cached_low["knn_dists"]
            low_knn_indices = cached_low["knn_indices"]

            # The cache may hold kNN but no distance matrix.  Recompute the
            # matrix once so every metric can share it.  Nothing is saved on
            # a cache hit, so any file written here is a temporary.
            if low_dim_dists is None and low_knn_dists is not None:
                if gpu_single_shot:
                    low_dim_dists = pdist_squareform(low_dim_data)
                elif use_memmap:
                    low_dists_path = low_tri_path
                    all_memmap_files.append(low_dists_path)
                    low_dim_dists = _dimred_pairwise_tri(
                        low_dim_data, low_dists_path, n_cells, chunk_size, n_jobs
                    )
                elif need_dists:
                    low_dim_dists = _dimred_pairwise_dense(
                        low_dim_data, n_cells, chunk_size, n_jobs
                    )
                gc.collect()

        elif gpu_single_shot:
            # ── GPU: dense distances + approximate kNN ──
            low_dim_dists = pdist_squareform(low_dim_data)
            low_knn_dists, low_knn_indices = _dimred_gpu_knn(
                low_dim_dists, k_neighbors + 1
            )
            if persist:
                low_dists_path = low_tri_path
                all_memmap_files.append(low_dists_path)
                low_dim_dists = _dimred_dense_to_tri(
                    low_dim_dists, low_dists_path, n_cells
                )

        elif gpu_chunked:
            # gpu_chunked implies n_cells > 50 000, so a one-shot dense GPU
            # matrix is never used here.
            if low_dim_data.shape[1] <= 200:
                # Narrow embedding: fused GPU streaming kNN, optionally
                # writing the distance memmap in the same pass.
                if need_dists:
                    low_dists_path = low_tri_path
                    all_memmap_files.append(low_dists_path)
                    low_dim_dists = TriangularMatrix(
                        n=n_cells, filepath=low_dists_path, mode="w+"
                    )
                knn = compute_neighbors(
                    np.asarray(low_dim_data, dtype=np.float64),
                    n_neighbors=k_neighbors + 1,
                    backend="auto",
                    tri_memmap=low_dim_dists,
                )
                low_knn_dists = knn.distances
                low_knn_indices = knn.indices
                if low_dim_dists is not None:
                    low_dim_dists.flush()
            else:
                # Wide embedding: sklearn kNN + chunked CPU distance memmap.
                nbrs_low = NearestNeighbors(
                    n_neighbors=k_neighbors + 1, n_jobs=n_jobs
                ).fit(low_dim_data)
                low_knn_dists, low_knn_indices = nbrs_low.kneighbors(low_dim_data)
                low_dists_path = low_tri_path
                all_memmap_files.append(low_dists_path)
                low_dim_dists = _dimred_pairwise_tri(
                    low_dim_data, low_dists_path, n_cells, chunk_size, n_jobs,
                    desc=f"Low-Dim Distances ({low_dim_key})", verbose=verbose,
                )
            gc.collect()

        else:
            # ── CPU path ──
            nbrs_low = NearestNeighbors(n_neighbors=k_neighbors + 1, n_jobs=n_jobs).fit(
                low_dim_data
            )
            low_knn_dists, low_knn_indices = nbrs_low.kneighbors(low_dim_data)

            if need_dists:
                desc = f"Low-Dim Distances ({low_dim_key})"
                if use_memmap:
                    low_dists_path = low_tri_path
                    all_memmap_files.append(low_dists_path)
                    low_dim_dists = _dimred_pairwise_tri(
                        low_dim_data, low_dists_path, n_cells, chunk_size, n_jobs,
                        desc=desc, verbose=verbose,
                    )
                else:
                    low_dim_dists = _dimred_pairwise_dense(
                        low_dim_data, n_cells, chunk_size, n_jobs,
                        desc=desc, verbose=verbose,
                    )
                gc.collect()

        # ── Capture precomputed low-dim data for the cache ──
        if persist:
            low_dim_precomputed[low_dim_key] = {
                "dists": low_dim_dists,
                "knn_indices": low_knn_indices,
                "knn_dists": low_knn_dists,
            }

        for metric_name in metric_list:
            pbar.set_description(f"{low_dim_key}: {metric_name}")
            t_start = time.time()

            # Only pass the keyword arguments each metric's run() declares.
            metric_module = getattr(dimred, metric_name)
            params = inspect.signature(metric_module.run).parameters

            high_mat = high_dim_dists
            low_mat = low_dim_dists
            try:
                if metric_name in rank_metrics:
                    # Dense inputs spare the rank helper per-row gathers.
                    if (
                        isinstance(high_dim_dists, TriangularMatrix)
                        and high_dim_dists.n <= dense_max_n
                    ):
                        high_mat = high_dim_dists.to_dense()
                    if (
                        isinstance(low_dim_dists, TriangularMatrix)
                        and low_dim_dists.n <= dense_max_n
                    ):
                        low_mat = low_dim_dists.to_dense()

                candidates = {
                    "low_dim_key": low_dim_key,
                    "high_dim_key": high_dim_key,
                    "k_neighbors": k_neighbors,
                    "n_samples": n_samples,
                    "n_jobs": n_jobs,
                    "seed": seed,
                    "verbose": False,
                    # "auto" lets trust/continuity pick their fastest backend
                    # based on N and GPU availability.
                    "rank_backend": "auto",
                    "precomputed_high_knn": high_knn_indices,
                    "precomputed_low_knn": low_knn_indices,
                    "precomputed_high_knn_dists": high_knn_dists,
                    "precomputed_low_knn_dists": low_knn_dists,
                    "precomputed_high_dists": high_mat,
                    "precomputed_low_dists": low_mat,
                }
                kwargs = {
                    name: arg for name, arg in candidates.items() if name in params
                }
                value = metric_module.run(adata, **kwargs)
            except Exception as e:
                # One failing metric must not abort the others; record NaN.
                value = np.nan
                logger.warning("Failed %s on %s: %s", metric_name, low_dim_key, e)
            elapsed = round(time.time() - t_start, 3)

            results.append(
                {
                    "Metric Name": metric_name,
                    "Low Dim Key": low_dim_key,
                    "High Dim Key": high_dim_key,
                    "Value": value,
                    "Time (s)": elapsed,
                }
            )
            # Drop any dense copies before the next metric / embedding.
            high_mat = low_mat = candidates = kwargs = None

            pbar.update(1)

        if low_dim_dists is not None:
            _dimred_close_dists(low_dim_dists, low_dim_key)
            # Keep the file only when it will be registered with the cache.
            if not persist and low_dists_path:
                _dimred_remove_file(low_dists_path)
                if low_dists_path in all_memmap_files:
                    all_memmap_files.remove(low_dists_path)
        gc.collect()

    pbar.close()

    if high_dim_dists is not None:
        _dimred_close_dists(high_dim_dists, high_dim_key)
        gc.collect()

    # ── Save to persistent cache ──
    if persist and high_dim_dists is not None:
        low_info = {
            emb: {
                "dists": low_dim_precomputed.get(emb, {}).get("dists"),
                "knn_indices": low_dim_precomputed.get(emb, {}).get("knn_indices"),
                "knn_dists": low_dim_precomputed.get(emb, {}).get("knn_dists"),
            }
            for emb in low_dim_keys
        }
        try:
            save_dimred_cache(
                temp_dir,
                fingerprint,
                high_dim_dists=(
                    high_dim_dists
                    if isinstance(high_dim_dists, TriangularMatrix)
                    else None
                ),
                high_knn_dists=high_knn_dists,
                high_knn_indices=high_knn_indices,
                low_dim_data=low_info,
                low_dim_keys=low_dim_keys,
            )
        except Exception as exc:
            logger.warning("Failed to save dimred cache: %s", exc)

    # Delete leftover temp files that are not part of a saved cache.
    if not persist:
        for fpath in all_memmap_files:
            _dimred_remove_file(fpath)

    if "_dimred_cache" in adata.uns:
        del adata.uns["_dimred_cache"]

    gc.collect()

    df = pd.DataFrame(results)

    if verbose and not df.empty:
        total_time = df["Time (s)"].sum()
        n_results = len(df)
        print(f"\nTotal computation time: {total_time:.2f}s")
        print(f"Results: {n_results} measurements across " f"{len(metric_list)} metrics")

    return df


def _dimred_temp_dir(file_dir, verbose):
    """Return (creating it if needed) the directory for cal_dimred memmaps.

    ``file_dir`` is used when given.  Otherwise ``~/.checkatlas/tmp`` is
    used, unless its filesystem has less than 50 GB free.  In that case
    ``.checkatlas_tmp`` is used under ``$TMPDIR`` (if set and existing),
    else ``/data/tmp`` (if existing), else ``/tmp``.
    """
    import shutil

    if file_dir:
        temp_dir = file_dir
    else:
        default_tmp = os.path.join(os.path.expanduser("~"), ".checkatlas", "tmp")
        temp_dir = default_tmp
        try:
            # shutil.disk_usage raises on a missing path, so create the
            # directory first; otherwise the check is silently skipped on
            # the very first run.
            os.makedirs(default_tmp, exist_ok=True)
            usage = shutil.disk_usage(default_tmp)
        except OSError as exc:
            logger.debug("Could not check free space in %s: %s", default_tmp, exc)
        else:
            if usage.free < 50 * 1024**3:  # < 50 GB free
                fallback = "/data/tmp" if os.path.exists("/data/tmp") else "/tmp"
                # A scheduler-provided TMPDIR (e.g. nextflow) takes precedence.
                nf_tmp = os.environ.get("TMPDIR", "")
                if nf_tmp and os.path.exists(nf_tmp):
                    fallback = nf_tmp
                temp_dir = os.path.join(fallback, ".checkatlas_tmp")
                if verbose:
                    print(f"  Low disk space on default temp; using {temp_dir}")
    os.makedirs(temp_dir, exist_ok=True)
    return temp_dir


def _dimred_pairwise_tri(data, filepath, n_cells, chunk_size, n_jobs,
                         desc=None, verbose=False):
    """Compute pairwise distances of ``data`` row-chunk by row-chunk into a
    new TriangularMatrix memmap at ``filepath`` and return it."""
    from sklearn.metrics import pairwise_distances

    tri = TriangularMatrix(n=n_cells, filepath=filepath, mode="w+")
    for start in tqdm(range(0, n_cells, chunk_size), desc=desc, disable=not verbose):
        stop = min(start + chunk_size, n_cells)
        block = pairwise_distances(data[start:stop], data, n_jobs=n_jobs)
        store_upper_triangle(tri._data, block, start, 0, n_cells)
        tri.flush()
    return tri


def _dimred_pairwise_dense(data, n_cells, chunk_size, n_jobs,
                           desc=None, verbose=False):
    """Compute the full (n_cells, n_cells) float32 distance matrix of
    ``data`` in row chunks and return it."""
    from sklearn.metrics import pairwise_distances

    out = np.zeros((n_cells, n_cells), dtype=np.float32)
    for start in tqdm(range(0, n_cells, chunk_size), desc=desc, disable=not verbose):
        stop = min(start + chunk_size, n_cells)
        out[start:stop, :] = pairwise_distances(data[start:stop], data, n_jobs=n_jobs)
    return out


def _dimred_dense_to_tri(dense, filepath, n_cells, chunk=5000):
    """Copy a dense (n_cells, n_cells) distance matrix into a new
    TriangularMatrix memmap at ``filepath``, flush it and return it."""
    tri = TriangularMatrix(n=n_cells, filepath=filepath, mode="w+")
    step = min(chunk, n_cells)
    for start in range(0, n_cells, step):
        stop = min(start + step, n_cells)
        store_upper_triangle(tri._data, dense[start:stop, :], start, 0, n_cells)
    tri.flush()
    return tri


def _dimred_gpu_knn(dists, k):
    """Return ``(distances, indices)`` of the ``k`` smallest entries per row
    of a dense distance matrix, using ``jax.lax.approx_min_k``."""
    import jax
    import jax.numpy as jnp

    dists_dev = jnp.asarray(dists, dtype=jnp.float32)
    knn_dists_dev, knn_idx_dev = jax.lax.approx_min_k(dists_dev, k=k)
    # Drop the N×N device copy now instead of holding it for the whole run.
    del dists_dev
    return _get_ndarray(knn_dists_dev), _get_ndarray(knn_idx_dev)


def _dimred_close_dists(mat, label):
    """Best-effort release of a distance matrix (TriangularMatrix or array)."""
    try:
        if isinstance(mat, TriangularMatrix):
            _safe_close_triangular(mat)
        else:
            _safe_close_memmap(mat)
    except Exception as exc:  # cleanup must never abort the run
        logger.debug("Could not close distance matrix for %s: %s", label, exc)


def _dimred_remove_file(path):
    """Delete ``path`` if it exists, logging (not raising) on failure."""
    try:
        if os.path.exists(path):
            os.remove(path)
    except OSError as exc:
        logger.warning("Failed to clean up temp file %s: %s", path, exc)
calc_metric_dimred(metric, adata, obsm_key)

Calculate a single dimensionality reduction metric (legacy function).

For comprehensive assessment, use cal_dimred() instead.

Source code in checkatlas/metrics/metrics.py
def calc_metric_dimred(metric, adata, obsm_key):
    """
    Calculate a single dimensionality reduction metric (legacy function).

    For comprehensive assessment, use cal_dimred() instead.

    Parameters
    ----------
    metric : str
        Metric name; must be one of :data:`METRICS_DIMRED`.
    adata : AnnData
        Atlas whose ``adata.X`` is used as the high-dimensional data.
    obsm_key : str
        ``adata.obsm`` key of the low-dimensional embedding.

    Returns
    -------
    tuple or int
        ``(metric_value, running_time_seconds)`` for a known metric, or
        ``-1`` (after logging a warning) when ``metric`` is unknown.
    """
    if metric not in METRICS_DIMRED:
        logger.warning(
            f"{metric} is not a recognized "
            f"dimensionality reduction metric."
            f"\nList of dim. red. metrics: {METRICS_DIMRED}"
        )
        return -1

    start_time = time.time()
    logger.debug("Start %s calc", metric)
    # Get the right module for the metric
    metric_module = getattr(dimred, metric)
    logger.debug("Using metric module %s", metric_module)
    high_dim_counts = adata.X
    low_dim_counts = adata.obsm[obsm_key]
    # NOTE(review): cal_dimred calls ``run(adata, **kwargs)``, whereas this
    # legacy path passes the raw matrices positionally.  Confirm the metric
    # modules still accept this form.
    metric_value = metric_module.run(high_dim_counts, low_dim_counts)
    running_time = time.time() - start_time
    logger.debug("%s calc finished, duration %s", metric, running_time)
    return metric_value, running_time
run_all_metrics(adata=None, atlas_path=None, atlas_name=None, file_dir=None, n_jobs=-1, verbose=True, seed=42, run_annotation=True, all_annot_metrics=True, run_clustering=True, all_cluster_metrics=True, run_dimred=True, all_dimred_metrics=True, low_dim_key='X_umap', high_dim_key='X', k_neighbors=30, n_samples=None, use_memmap=True, temp_dir=None, chunk_size=1000)

Unified metrics pipeline for CheckAtlas.

The standalone entry point that runs ALL metric tasks (annotation, clustering, dimensionality reduction) on a single-cell atlas and produces a consolidated results CSV. Designed for biologists and bioinformaticians — just provide an atlas path and get comprehensive quality metrics.

Usage

from checkatlas.metrics.metrics import run_all_metrics
results = run_all_metrics(atlas_path="my_atlas.h5ad")

Or with a preloaded AnnData:

results = run_all_metrics(adata=my_adata, atlas_name="my_atlas")

Parameters:
  • adata (AnnData, default: None ) –

    Preloaded AnnData object. If None, loads from atlas_path.

  • atlas_path (str, default: None ) –

    Path to .h5ad atlas file. Used if adata is None.

  • atlas_name (str, default: None ) –

    Name for the atlas (used in output). If None, derived from atlas_path filename.

  • file_dir (str, default: None ) –

    Directory for output CSV. Defaults to current directory.

  • n_jobs (int, default: -1 ) –

    Number of parallel jobs (-1 = all cores).

  • verbose (bool, default: True ) –

    Whether to print progress and summary.

  • seed (int, default: 42 ) –

    Random seed for reproducibility.

  • run_annotation (bool, default: True ) –

    Whether to run annotation metrics.

  • all_annot_metrics (bool, default: True ) –

    If True, run all annotation metrics.

  • run_clustering (bool, default: True ) –

    Whether to run clustering metrics.

  • all_cluster_metrics (bool, default: True ) –

    If True, run all clustering metrics.

  • run_dimred (bool, default: True ) –

    Whether to run dimensionality reduction metrics.

  • all_dimred_metrics (bool, default: True ) –

    If True, run all dimred metrics.

  • low_dim_key (str, default: 'X_umap' ) –

    Key in adata.obsm for low-dimensional embedding.

  • high_dim_key (str, default: 'X' ) –

    Key for high-dimensional data ('X' = adata.X).

  • k_neighbors (int, default: 30 ) –

    Number of neighbors for kNN-based metrics.

  • n_samples (int, default: None ) –

    Number of samples for dimred metrics.

  • use_memmap (bool, default: True ) –

    Whether to use memory-mapped files for large distance matrices.

  • temp_dir (str, default: None ) –

    Directory for temporary memmap files.

  • chunk_size (int, default: 1000 ) –

    Chunk size for batched distance computation.

Returns:
  • pd.DataFrame: Consolidated results with columns: [Atlas Name, Task, Metric Name, Input 1, Input 2, Value, Time (s)]

Source code in checkatlas/metrics/metrics.py
def _run_metric_task(
    label, step, title, compute, to_unified, all_results, task_summary, verbose
):
    """Run one task of ``run_all_metrics`` and record its outcome.

    A failure is logged (with traceback) and stored in ``task_summary``
    instead of being raised, so the remaining tasks still run.

    Args:
        label (str): Short task name used in the summary and messages
            ("Annotation", "Clustering", "Dimred").
        step (int): Position of the task in the 3-task pipeline (display only).
        title (str): Header printed when ``verbose`` is True.
        compute (callable): Zero-argument callable returning the task's
            result DataFrame.
        to_unified (callable): Maps the task DataFrame onto the unified schema.
        all_results (list): Receives the unified DataFrame when it is non-empty.
        task_summary (dict): Receives ``{"count", "time"[, "error"]}`` under
            ``label``.
        verbose (bool): Whether to print progress.
    """
    task_start = time.time()
    if verbose:
        print(f"\n{'─'*50}")
        print(f"  TASK {step}/3: {title}")
        print(f"{'─'*50}")

    try:
        df_task = compute()
        if not df_task.empty:
            all_results.append(to_unified(df_task))

        task_elapsed = time.time() - task_start
        n_metrics = len(df_task) if not df_task.empty else 0
        task_summary[label] = {
            "count": n_metrics,
            "time": round(task_elapsed, 2),
        }
        if verbose:
            print(f"  ✓ {label}: {n_metrics} results in {task_elapsed:.1f}s")

    except Exception as e:
        task_elapsed = time.time() - task_start
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception(f"{label} pipeline failed: {e}")
        task_summary[label] = {
            "count": 0,
            "time": round(task_elapsed, 2),
            "error": str(e),
        }
        if verbose:
            print(f"  ✗ {label} failed: {e}")


def run_all_metrics(
    adata=None,
    atlas_path=None,
    atlas_name=None,
    file_dir=None,
    n_jobs=-1,
    verbose=True,
    seed=42,
    # Annotation params
    run_annotation=True,
    all_annot_metrics=True,
    # Clustering params
    run_clustering=True,
    all_cluster_metrics=True,
    # Dimred params
    run_dimred=True,
    all_dimred_metrics=True,
    low_dim_key="X_umap",
    high_dim_key="X",
    k_neighbors=30,
    n_samples=None,
    use_memmap=True,
    temp_dir=None,
    chunk_size=1000,
):
    """
    Unified metrics pipeline for CheckAtlas.

    The standalone entry point that runs ALL metric tasks (annotation, clustering,
    dimensionality reduction) on a single-cell atlas and produces a consolidated
    results CSV. Designed for biologists and bioinformaticians — just provide
    an atlas path and get comprehensive quality metrics.

    Usage:
        >>> from checkatlas.metrics.metrics import run_all_metrics
        >>> results = run_all_metrics(atlas_path="my_atlas.h5ad")

        >>> # Or with a preloaded AnnData:
        >>> results = run_all_metrics(adata=my_adata, atlas_name="my_atlas")

    Args:
        adata (AnnData, optional): Preloaded AnnData object. If None, loads from atlas_path.
        atlas_path (str, optional): Path to .h5ad atlas file. Used if adata is None.
        atlas_name (str, optional): Name for the atlas (used in output). If None,
                                   derived from the atlas_path filename when a path
                                   is given, otherwise "atlas".
        file_dir (str, optional): Directory for output CSV. Defaults to current directory.
        n_jobs (int): Number of parallel jobs (-1 = all cores).
        verbose (bool): Whether to print progress and summary.
        seed (int): Random seed for reproducibility.

        run_annotation (bool): Whether to run annotation metrics.
        all_annot_metrics (bool): If True, run all annotation metrics.

        run_clustering (bool): Whether to run clustering metrics.
        all_cluster_metrics (bool): If True, run all clustering metrics.

        run_dimred (bool): Whether to run dimensionality reduction metrics.
        all_dimred_metrics (bool): If True, run all dimred metrics.
        low_dim_key (str): Key in adata.obsm for low-dimensional embedding.
        high_dim_key (str): Key for high-dimensional data ('X' = adata.X).
        k_neighbors (int): Number of neighbors for kNN-based metrics.
        n_samples (int, optional): Number of samples for dimred metrics.
        use_memmap (bool): Whether to use memory-mapped files for large distance matrices.
        temp_dir (str, optional): Directory for temporary memmap files.
        chunk_size (int): Chunk size for batched distance computation.

    Note:
        ``low_dim_key``, ``use_memmap``, ``temp_dir`` and ``chunk_size`` are
        accepted for API compatibility but are currently NOT forwarded to
        ``cal_dimred``.

    Returns:
        pd.DataFrame: Consolidated results with columns:
                      [Atlas Name, Task, Metric Name, Input 1, Input 2, Value, Time (s)]

    Raises:
        ValueError: If neither ``adata`` nor ``atlas_path`` is provided.
    """
    import gc

    overall_start = time.time()

    # ────────────────────────────────────────────────────────────────────
    # 1. Load atlas
    # ────────────────────────────────────────────────────────────────────
    if adata is None:
        if atlas_path is None:
            raise ValueError("Provide either `adata` (AnnData) or `atlas_path` (str).")
        # scanpy is only needed to read from disk.
        import scanpy as sc

        if verbose:
            print(f"Loading atlas from: {atlas_path}")
        adata = sc.read_h5ad(atlas_path)

    if atlas_name is None:
        atlas_name = (
            os.path.splitext(os.path.basename(atlas_path))[0]
            if atlas_path is not None
            else "atlas"
        )

    if verbose:
        print(f"\n{'='*70}")
        print("  CheckAtlas — Unified Metrics Pipeline")
        print(f"  Atlas: {atlas_name}")
        print(f"  Cells: {adata.n_obs:,}  |  Genes: {adata.n_vars:,}")
        print("  Tasks: ", end="")
        tasks = [
            name
            for name, enabled in (
                ("Annotation", run_annotation),
                ("Clustering", run_clustering),
                ("Dimred", run_dimred),
            )
            if enabled
        ]
        print(" → ".join(tasks))
        print(f"{'='*70}\n")

    # Set output directory
    if file_dir is None:
        file_dir = os.getcwd()
    else:
        os.makedirs(file_dir, exist_ok=True)

    # Consolidated results
    all_results = []
    task_summary = {}

    # ────────────────────────────────────────────────────────────────────
    # 2. Run Annotation Metrics
    # ────────────────────────────────────────────────────────────────────
    if run_annotation:
        _run_metric_task(
            "Annotation",
            1,
            "Annotation Metrics",
            compute=lambda: cal_annot(
                adata,
                atlas_name=atlas_name,
                all=all_annot_metrics,
                file_dir=file_dir,
            ),
            to_unified=lambda df: pd.DataFrame(
                {
                    "Atlas Name": df["Atlas Name"],
                    "Task": "annotation",
                    "Metric Name": df["Metric Name"],
                    "Input 1": df["Reference/Input 1"],
                    "Input 2": df["Prediction/Input 2"],
                    "Value": df["Value"],
                    "Time (s)": df.get("Time (s)", np.nan),
                }
            ),
            all_results=all_results,
            task_summary=task_summary,
            verbose=verbose,
        )

    # ────────────────────────────────────────────────────────────────────
    # 3. Run Clustering Metrics
    # ────────────────────────────────────────────────────────────────────
    if run_clustering:
        _run_metric_task(
            "Clustering",
            2,
            "Clustering Metrics",
            compute=lambda: cal_cluster(
                adata,
                atlas_name=atlas_name,
                all_metrics=all_cluster_metrics,
                file_dir=file_dir,
                n_jobs=n_jobs,
                verbose=verbose,
                seed=seed,
            ),
            to_unified=lambda df: pd.DataFrame(
                {
                    "Atlas Name": df["Atlas Name"],
                    "Task": "clustering",
                    "Metric Name": df["Metric Name"],
                    "Input 1": df["Embedding"],
                    "Input 2": df["Label Key"],
                    "Value": df["Value"],
                    "Time (s)": df["Time (s)"],
                }
            ),
            all_results=all_results,
            task_summary=task_summary,
            verbose=verbose,
        )

    # ────────────────────────────────────────────────────────────────────
    # 4. Run Dimensionality Reduction Metrics
    # ────────────────────────────────────────────────────────────────────
    if run_dimred:
        _run_metric_task(
            "Dimred",
            3,
            "Dimensionality Reduction Metrics",
            compute=lambda: cal_dimred(
                adata,
                atlas_name=atlas_name,
                high_dim_key=high_dim_key,
                # None lets cal_dimred fall back to its own default selection.
                metric_list=METRICS_DIMRED if all_dimred_metrics else None,
                k_neighbors=k_neighbors,
                n_samples=n_samples,
                seed=seed,
                n_jobs=n_jobs,
                file_dir=file_dir,
                verbose=verbose,
            ),
            to_unified=lambda df: pd.DataFrame(
                {
                    "Atlas Name": atlas_name,
                    "Task": "dimred",
                    "Metric Name": df["Metric Name"],
                    "Input 1": df["Low Dim Key"],
                    "Input 2": df["High Dim Key"],
                    "Value": df["Value"],
                    "Time (s)": df["Time (s)"],
                }
            ),
            all_results=all_results,
            task_summary=task_summary,
            verbose=verbose,
        )

    # ────────────────────────────────────────────────────────────────────
    # 5. Consolidate & Save
    # ────────────────────────────────────────────────────────────────────
    if all_results:
        df_all = pd.concat(all_results, ignore_index=True)
    else:
        df_all = pd.DataFrame(
            columns=[
                "Atlas Name",
                "Task",
                "Metric Name",
                "Input 1",
                "Input 2",
                "Value",
                "Time (s)",
            ]
        )

    # Save unified CSV
    if not df_all.empty:
        filename = os.path.join(file_dir, f"checkatlas_all_metrics_{atlas_name}.csv")
        df_all.to_csv(filename, index=False)
        if verbose:
            print(f"\nSaved unified results to: {filename}")
        logger.info(f"Saved unified metrics to {filename}")

    # ────────────────────────────────────────────────────────────────────
    # 6. Print Summary
    # ────────────────────────────────────────────────────────────────────
    overall_elapsed = time.time() - overall_start

    if verbose:
        print(f"\n{'='*70}")
        print("  CheckAtlas — Results Summary")
        print(f"{'='*70}")
        print(f"  {'Task':<20} {'Metrics':>10} {'Time (s)':>12} {'Status':>10}")
        print(f"  {'─'*52}")

        for task_name, info in task_summary.items():
            status = "✗ Error" if "error" in info else "✓ Done"
            print(
                f"  {task_name:<20} {info['count']:>10} {info['time']:>12.2f} {status:>10}"
            )

        print(f"  {'─'*52}")
        total_count = sum(info["count"] for info in task_summary.values())
        print(f"  {'TOTAL':<20} {total_count:>10} {overall_elapsed:>12.2f}")
        print(f"{'='*70}")

        if not df_all.empty:
            print(f"\n  Unique metrics computed: {df_all['Metric Name'].nunique()}")
            print(f"  Total measurements: {len(df_all)}")

            # Top 5 slowest metrics. Coerce to numeric first: nlargest raises
            # on object-dtype columns, and non-numeric times are simply skipped.
            if "Time (s)" in df_all.columns:
                timed = df_all.assign(
                    **{"Time (s)": pd.to_numeric(df_all["Time (s)"], errors="coerce")}
                )
                slowest = timed.nlargest(5, "Time (s)")[
                    ["Task", "Metric Name", "Time (s)"]
                ]
                if not slowest.empty:
                    print("\n  Slowest metrics:")
                    for _, row in slowest.iterrows():
                        print(
                            f"    {row['Task']}/{row['Metric Name']}: {row['Time (s)']:.3f}s"
                        )

        print()

    gc.collect()
    return df_all

checkatlas.utils.folders

checkatlas.utils.folders

checkatlas_folders(path: str) -> None

Check whether the checkatlas folders exist under path.
Create them if needed. All folders are given by DICT_FOLDER.

Parameters:
  • path (str) –

    Search path for atlas given by user

Returns:
  • None( None ) –

    None

Source code in checkatlas/utils/folders.py
def checkatlas_folders(path: str) -> None:
    """Ensure the checkatlas working directory and its sub-folders exist.

    The working directory is ``get_workingdir(path)``. One sub-folder is
    created for every entry of DICT_FOLDER, named ``DICT_FOLDER[key]``. That
    is the same name ``get_folder`` joins onto the working directory, so every
    path returned by ``get_folder`` points to a folder created here.

    Args:
        path (str): Search path for atlas given by user

    Returns:
        None: None
    """
    global_path = get_workingdir(path)
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race.
    os.makedirs(global_path, exist_ok=True)

    # Use the folder names (values), matching get_folder(); the original
    # iterated the keys, which only works when keys and values coincide.
    for folder_name in DICT_FOLDER.values():
        temp_path = os.path.join(global_path, folder_name)
        if not os.path.exists(temp_path):
            logger.debug(f"Create folder: {temp_path}")
            os.makedirs(temp_path, exist_ok=True)
get_folder(path: str, key_folder: str) -> str

Get the folder path, given the search path and the folder key in DICT_FOLDER.

Parameters:
  • path (str) –

    Search path for atlas given by user

  • key_folder (str) –

    key folder in the DICT_FOLDER example: ANNDATA, SUMMARY, UMAP

Returns:
  • str( str ) –

    the folder path

Source code in checkatlas/utils/folders.py
def get_folder(path: str, key_folder: str) -> str:
    """Return the location of one checkatlas sub-folder.

    The result is the working directory derived from ``path`` (see
    ``get_workingdir``) joined with the folder name stored under
    ``key_folder`` in DICT_FOLDER.

    Args:
        path (str): Search path for atlas given by user
        key_folder (str): key folder in the DICT_FOLDER
            example: ANNDATA, SUMMARY, UMAP

    Returns:
        str: the folder path

    Raises:
        KeyError: if ``key_folder`` is not present in DICT_FOLDER.
    """
    working_dir = get_workingdir(path)
    folder_name = DICT_FOLDER[key_folder]
    return os.path.join(working_dir, folder_name)
get_workingdir(path: str) -> str

Return the working directory: the search path joined with WORKING_DIR (checkatlas_files/).

Parameters:
  • path (str) –

    Search path for atlas given by user

Returns:
  • str( str ) –

    os.path.join(path, working_dir)

Source code in checkatlas/utils/folders.py
def get_workingdir(path: str) -> str:
    """Build the checkatlas working directory for a search path.

    Args:
        path (str): Search path for atlas given by user

    Returns:
        str: ``path`` joined with the WORKING_DIR constant
        (``checkatlas_files/``).
    """
    working_dir = os.path.join(path, WORKING_DIR)
    return working_dir

checkatlas.utils.files

checkatlas.utils.files

checkatlas.utils.checkatlas_arguments

checkatlas.utils.checkatlas_arguments

MAX_N_JOBS = 48 module-attribute

Upper limit on the number of CPU threads any CheckAtlas process may consume. Capped at 48 so that on an 80-core workstation at least 32 threads remain free for other users / pipelines. See `cap_n_jobs` for the enforcement helper.

cap_n_jobs(n_jobs)

Cap n_jobs at `MAX_N_JOBS` and treat None / -1 as "use the system default" by resolving to the cap.

Rules
  • None → `MAX_N_JOBS` (covers the legacy path where getattr(args, "n_jobs", -1) returned the fallback).
  • -1 → `MAX_N_JOBS` (sklearn / joblib convention: "all available cores", capped for politeness).
  • values > `MAX_N_JOBS` → `MAX_N_JOBS`.
  • any positive integer ≤ `MAX_N_JOBS` → unchanged.
  • anything else (non-integer, zero, negative) → `MAX_N_JOBS`.
Source code in checkatlas/utils/checkatlas_arguments.py
def cap_n_jobs(n_jobs):
    """Cap *n_jobs* at :data:`MAX_N_JOBS` and treat ``None`` / ``-1``
    as "use the system default" by resolving to the cap.

    Rules:
      * ``None``  → :data:`MAX_N_JOBS` (covers the legacy path where
        ``getattr(args, "n_jobs", -1)`` returned the fallback).
      * ``-1``    → :data:`MAX_N_JOBS` (sklearn / joblib convention:
        "all available cores", capped for politeness).
      * values > :data:`MAX_N_JOBS` → :data:`MAX_N_JOBS`.
      * any positive integer ≤ :data:`MAX_N_JOBS` → unchanged.
      * anything that ``int()`` cannot convert (including ``nan`` and
        ``inf``), zero, or negative → :data:`MAX_N_JOBS`.

    Other values are first passed through ``int()``. Numeric strings are
    parsed and floats are truncated, e.g. ``"4"`` → 4 and ``2.7`` → 2.
    """
    if n_jobs is None or n_jobs == -1:
        return MAX_N_JOBS
    try:
        n_jobs = int(n_jobs)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")); previously escaped uncaught.
        return MAX_N_JOBS
    if n_jobs < 1 or n_jobs > MAX_N_JOBS:
        return MAX_N_JOBS
    return n_jobs
get_version()

Get the checkatlas version from the checkatlas/VERSION file. Returns: the checkatlas version string.

Source code in checkatlas/utils/checkatlas_arguments.py
def get_version():
    """
    Get the checkatlas version from the checkatlas/VERSION file.

    Only the first line of the file is read. Surrounding whitespace, notably
    the trailing newline left by ``readline()``, is stripped.

    :return: checkatlas version
    """
    version_file = files(__package__).joinpath("VERSION")
    # Traversable.open() is the importlib.resources API for reading package
    # data; the builtin open() needs a real filesystem path.
    with version_file.open("r", encoding="utf-8") as file:
        return file.readline().strip()