"""
This module contains helper functions for the ´models´ subpackage.
"""
import logging
import os
import pickle
import dill
from collections import OrderedDict
from typing import Optional, Tuple, Literal
from collections import defaultdict
import scipy.sparse as sp
from scipy.sparse import isspmatrix_csr
from sklearn.preprocessing import normalize
import numpy as np
import pandas as pd
import torch
import scanpy as sc
import anndata
import anndata as ad
from anndata import AnnData, concat
from scipy.sparse import csr_matrix, hstack
from sklearn.neighbors import KNeighborsTransformer
import matplotlib
import matplotlib.pyplot as plt
logger = logging.getLogger(__name__)
def save_model_with_fallback(model, file_path):
try:
# 尝试使用 pickle 保存
with open(file_path, "wb") as f:
pickle.dump(model, f)
print(f"Model saved successfully using pickle at {file_path}")
except (AttributeError, pickle.PicklingError) as e:
# 如果 pickle 保存失败,捕获异常并使用 joblib
print(f"Pickle failed with error: {e}, switching to dill...")
with open(file_path, "wb") as f:
dill.dump(model, f)
print(f"Model saved successfully using dill at {file_path}")
def load_model_with_fallback(file_path):
try:
# 尝试使用 pickle 加载
with open(file_path, "rb") as f:
model = pickle.load(f)
print(f"Model loaded successfully using pickle from {file_path}")
except (AttributeError, pickle.UnpicklingError, EOFError) as e:
# 如果 pickle 加载失败,捕获异常并使用 joblib
print(f"Pickle failed with error: {e}, switching to dill...")
with open(file_path, "rb") as f:
model = dill.load(f)
print(f"Model loaded successfully using dill from {file_path}")
return model
def load_saved_files(
dir_path: str,
query_adata: Optional[AnnData] = None,
ref_adata_name: str = "adata_ref.h5ad",
batch_key: Optional[str] = None,
map_location: Optional[Literal["cpu", "cuda"]] = None,
) -> Tuple[OrderedDict, dict, np.ndarray, ad.AnnData]:
"""
Helper to load saved model files.
Parts of the implementation are adapted from
https://github.com/scverse/scvi-tools/blob/master/scvi/model/base/_utils.py#L55
(01.10.2022)
Parameters
----------
dir_path:
Path where the saved model files are stored.
query_adata:
Query anndata object.
ref_adata_name:
Name of the reference anndata object.
batch_key:
Batch key for the reference anndata object.
map_location:
Memory location where to map the model files to.
Returns
----------
model_state_dict:
The stored model state dict.
var_names:
The stored variable names.
attr_dict:
The stored attributes.
adata_concat:
The concatenated anndata object.
"""
attr_path = os.path.join(dir_path, "attr.pkl")
adata_path = os.path.join(dir_path, ref_adata_name)
var_names_path = os.path.join(dir_path, "var_names.csv")
model_path = os.path.join(dir_path, "model_params.pt")
if os.path.exists(adata_path):
adata_ref = ad.read(adata_path)
adata_ref.X = (
adata_ref.layers["counts"].copy()
if "counts" in adata_ref.layers.keys()
else adata_ref.X
)
else:
raise ValueError("Dir path contains no saved reference anndata")
var_names = np.genfromtxt(var_names_path, delimiter=",", dtype=str)
if query_adata is not None:
query_adata = validate_var_names(query_adata, var_names)
if batch_key is None:
raise ValueError("batch_key is required when query_adata is provided.")
# 检查 batch_key 是否存在于 adata_ref.obs
if batch_key not in adata_ref.obs:
raise ValueError(f"The column '{batch_key}' does not exist in adata_ref.obs.")
# 给 query_adata 添加 batch_key 并标记为新 batch
if batch_key not in query_adata.obs:
new_batch_label = "new_batch" # 可以根据需要动态设置
query_adata.obs[batch_key] = new_batch_label
# 合并数据集
adata_concat = anndata.concat(
[adata_ref, query_adata],
label="projection",
keys=["reference", "query"],
index_unique=None,
join="outer",
)
# 验证合并后的结果是否包含 sample_col
if batch_key not in adata_concat.obs:
raise ValueError(f"The column '{batch_key}' is missing in adata_concat.obs.")
del adata_concat.obsm['garfield_latent'] # remove garfield_latent
else:
adata_concat = adata_ref
model_state_dict = torch.load(model_path, map_location=map_location)
attr_dict = load_model_with_fallback(attr_path)
# with open(attr_path, "rb") as handle:
# attr_dict = pickle.load(handle)
return model_state_dict, var_names, attr_dict, adata_concat
def validate_var_names(adata, source_var_names):
# Warning for gene percentage
user_var_names = adata.var_names
try:
percentage = (
len(user_var_names.intersection(source_var_names)) / len(user_var_names)
) * 100
percentage = round(percentage, 4)
if percentage != 100:
logger.warning(
f"WARNING: Query shares {percentage}% of its genes with the reference."
"This may lead to inaccuracy in the results."
)
except Exception:
logger.warning("WARNING: Something is wrong with the reference genes.")
user_var_names = user_var_names.astype(str)
new_adata = adata
# Get genes in reference that are not in query
ref_genes_not_in_query = []
for name in source_var_names:
if name not in user_var_names:
ref_genes_not_in_query.append(name)
if len(ref_genes_not_in_query) > 0:
print(
"Query data is missing expression data of ",
len(ref_genes_not_in_query),
" genes which were contained in the reference dataset.",
)
print("The missing information will be filled with zeroes.")
filling_X = np.zeros((len(adata), len(ref_genes_not_in_query)))
if isinstance(adata.X, csr_matrix):
filling_X = csr_matrix(filling_X) # support csr sparse matrix
new_target_X = hstack((adata.X, filling_X))
else:
new_target_X = np.concatenate((adata.X, filling_X), axis=1)
new_target_vars = adata.var_names.tolist() + ref_genes_not_in_query
new_adata = AnnData(new_target_X, dtype="float32")
new_adata.var_names = new_target_vars
new_adata.obs = adata.obs.copy()
if len(user_var_names) - (len(source_var_names) - len(ref_genes_not_in_query)) > 0:
print(
"Query data contains expression data of ",
len(user_var_names) - (len(source_var_names) - len(ref_genes_not_in_query)),
" genes that were not contained in the reference dataset. This information "
"will be removed from the query data object for further processing.",
)
# remove unseen gene information and order anndata
new_adata = new_adata[:, source_var_names].copy()
print(new_adata)
return new_adata
[docs]
def weighted_knn_trainer(train_adata, train_adata_emb, n_neighbors=50):
"""
Trains a weighted KNN classifier on ``train_adata``.
Parameters
----------
train_adata: :class:`~anndata.AnnData`
Annotated dataset to be used to train KNN classifier with ``label_key`` as the target variable.
train_adata_emb: str
Name of the obsm layer to be used for calculation of neighbors. If set to "X", anndata.X will be
used
n_neighbors: int
Number of nearest neighbors in KNN classifier.
"""
print(f"Weighted KNN with n_neighbors = {n_neighbors} ... ")
k_neighbors_transformer = KNeighborsTransformer(
n_neighbors=n_neighbors,
mode="distance",
algorithm="brute",
metric="euclidean",
n_jobs=-1,
)
if train_adata_emb == "X":
train_emb = train_adata.X
elif train_adata_emb in train_adata.obsm.keys():
train_emb = train_adata.obsm[train_adata_emb]
else:
raise ValueError(
"train_adata_emb should be set to either 'X' or the name of the obsm layer to be used!"
)
k_neighbors_transformer.fit(train_emb)
return k_neighbors_transformer
[docs]
def weighted_knn_transfer(
query_adata,
query_adata_emb,
ref_adata_obs,
label_keys,
knn_model,
threshold=1,
pred_unknown=False,
mode="package",
):
"""
Annotates ``query_adata`` cells with an input trained weighted KNN classifier.
Parameters
----------
query_adata: :class:`~anndata.AnnData`
Annotated dataset to be used to queryate KNN classifier. Embedding to be used
query_adata_emb: str
Name of the obsm layer to be used for label transfer. If set to "X",
query_adata.X will be used
ref_adata_obs: :class:`pd.DataFrame`
obs of ref Anndata
label_keys: str
Names of the columns to be used as target variables (e.g. cell_type) in ``query_adata``.
knn_model: :class:`~sklearn.neighbors._graph.KNeighborsTransformer`
knn model trained on reference adata with weighted_knn_trainer function
threshold: float
Threshold of uncertainty used to annotating cells as "Unknown". cells with
uncertainties higher than this value will be annotated as "Unknown".
Set to 1 to keep all predictions. This enables one to later on play
with thresholds.
pred_unknown: bool
``False`` by default. Whether to annotate any cell as "unknown" or not.
If `False`, ``threshold`` will not be used and each cell will be annotated
with the label which is the most common in its ``n_neighbors`` nearest cells.
mode: str
Has to be one of "paper" or "package". If mode is set to "package",
uncertainties will be 1 - P(pred_label), otherwise it will be 1 - P(true_label).
"""
if not type(knn_model) == KNeighborsTransformer:
raise ValueError(
"knn_model should be of type sklearn.neighbors._graph.KNeighborsTransformer!"
)
if query_adata_emb == "X":
query_emb = query_adata.X
elif query_adata_emb in query_adata.obsm.keys():
query_emb = query_adata.obsm[query_adata_emb]
else:
raise ValueError(
"query_adata_emb should be set to either 'X' or the name of the obsm layer to be used!"
)
top_k_distances, top_k_indices = knn_model.kneighbors(X=query_emb)
stds = np.std(top_k_distances, axis=1)
stds = (2.0 / stds) ** 2
stds = stds.reshape(-1, 1)
top_k_distances_tilda = np.exp(-np.true_divide(top_k_distances, stds))
weights = top_k_distances_tilda / np.sum(
top_k_distances_tilda, axis=1, keepdims=True
)
cols = ref_adata_obs.columns[ref_adata_obs.columns.str.startswith(label_keys)]
uncertainties = pd.DataFrame(columns=cols, index=query_adata.obs_names)
pred_labels = pd.DataFrame(columns=cols, index=query_adata.obs_names)
for i in range(len(weights)):
for j in cols:
y_train_labels = ref_adata_obs[j].values
unique_labels = np.unique(y_train_labels[top_k_indices[i]])
best_label, best_prob = None, 0.0
for candidate_label in unique_labels:
candidate_prob = weights[
i, y_train_labels[top_k_indices[i]] == candidate_label
].sum()
if best_prob < candidate_prob:
best_prob = candidate_prob
best_label = candidate_label
if pred_unknown:
if best_prob >= threshold:
pred_label = best_label
else:
pred_label = "Unknown"
else:
pred_label = best_label
if mode == "package":
uncertainties.iloc[i][j] = max(1 - best_prob, 0)
else:
raise Exception("Inquery Mode!")
pred_labels.iloc[i][j] = pred_label
print("Label transfer finished!", flush=True)
return pred_labels, uncertainties