'''
Feature analysis module for af2rave.
'''
import glob
from natsort import natsorted
import numpy as np
import mdtraj as md
from pathlib import Path
from .utils import *
from numpy.typing import NDArray
from sklearn.decomposition import PCA
[docs]
class FeatureSelection:
"""
Reads an ensemble of PDB files and performs feature selection.
:param input: The name(s) of the PDB file from reduced MSA.
If a directory is provided, all PDB files in the directory will be loaded.
If a list of PDB files is provided, all files will be loaded.
It can also load trajectory files, as long as they can be loaded by MDTraj.
:param ref_pdb: The name of the reference structure. If none is provided,
the first frame of the input PDB file will be used as the reference.
"""
def __init__(self, input: str | list[str], ref_pdb: str | None = None) -> None:
    """
    Load the structure ensemble and the reference structure.

    :param input: A path or a list of paths. A directory contributes all
        of its ``*.pdb`` files (natural-sorted); any other path is loaded
        as a single structure/trajectory file with ``ref_pdb`` as topology.
    :param ref_pdb: Path of the reference structure. If not given, the
        first loaded frame is used as the reference.
    :raises FileNotFoundError: If an input path does not exist.
    :raises ValueError: If no structure was found in the inputs.
    """
    import warnings

    # A single path is handled as a one-element list.
    if isinstance(input, str):
        input = [input]

    pdb_names = []  # one name per loaded frame
    trj_objs = []   # one trajectory object per loaded input entry

    for pp in input:
        p = Path(pp)
        if not p.exists():
            raise FileNotFoundError(f"'{pp}' does not exist")

        if p.is_dir():
            # Load only the files of *this* directory. Loading the whole
            # accumulated name list would re-read earlier inputs (and fail
            # on synthetic "<file>.frame_<i>" names).
            dir_pdbs = natsorted(glob.glob(f"{pp}/*.pdb"))
            if not dir_pdbs:
                # Nothing to load here; the emptiness check below reports
                # the error if no input yields any structure.
                continue
            pdb_names += dir_pdbs
            trj_objs.append(md.load(dir_pdbs))
        else:
            # UserWarnings raised while loading are deliberately ignored.
            # Catching them as exceptions would leave `trj` unbound
            # whenever warnings are escalated to errors.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", UserWarning)
                trj = md.load(pp, top=ref_pdb)
            # Multi-frame files get one synthetic name per frame.
            if len(trj) > 1:
                pdb_names += [f"{pp}.frame_{i}" for i in range(len(trj))]
            else:
                pdb_names.append(pp)
            trj_objs.append(trj)

    if not pdb_names:
        raise ValueError("No valid structure files found.")

    self._pdb_name = pdb_names
    self._traj: md.Trajectory = md.join(trj_objs)

    # Load reference structure and topology
    if ref_pdb:
        self._ref_pdb = ref_pdb
        self._ref = md.load(self._ref_pdb)
    else:
        self._ref_pdb = self._pdb_name[0]
        self._ref = self._traj[0]
    self._top = self._ref.topology

    # Feature storage: feature name -> per-structure values / atom pair
    self._features: dict = {}
    self._atom_pairs: dict = {}
# ===== Properties =====
@property
def traj(self) -> md.Trajectory:
    """
    All loaded structures as a single MDTraj trajectory.

    :return: The trajectory object held by this instance.
    :rtype: md.Trajectory
    """
    return self._traj
@property
def pdb_name(self) -> list[str]:
'''
The list of pdb names.
'''
return self._pdb_name
@property
def ref_pdb(self) -> str:
'''
The reference pdb name.
'''
return self._ref_pdb
@property
def top(self) -> md.Topology:
    """
    Topology of the reference structure.

    :return: The topology object held by this instance.
    :rtype: md.Topology
    """
    return self._top
@property
def features(self) -> dict[str, NDArray[np.float64]]:
'''
The features dictionary. The key is the feature name and the value is the feature array.
'''
return self._features
@property
def atom_pairs(self) -> dict[str, NDArray[np.int_]]:
'''
The atom pairs dictionary. The key is the feature name and the value is the atom pairs.
'''
return self._atom_pairs
@property
def feature_array(self) -> NDArray[np.float64]:
'''
The feature array, with each feature stacked column-wise.
:return: The feature array.
:rtype: np.ndarray[float]
'''
return np.column_stack(list(self.features.values())) if self.features else np.empty((0, 0))
def __len__(self) -> int:
'''
The number of structures in the trajectory.
'''
return len(self.pdb_name)
def __getitem__(self, key):
try:
idx = self._pdb_name.index(key)
except ValueError:
raise KeyError(f"Structure '{key}' not found in the trajectory.")
return self._traj[idx]
# ===== Preprocessing =====
def _select_and_validate(self,
selection: str,
min_atoms: int | None = 1) -> NDArray[np.int_]:
"""
Select atoms from the trajectory and validate the selection
to ensure it contains at least `min_atoms` atoms.
:param str selection:
The selection string.
:param int min_atoms:
The minimum number of atoms required.
Default: 1.
Set to `None` to disable the check.
:return: An array of atom indices.
:raises ValueError:
If the selection is invalid or does not contain enough atoms.
"""
try:
atom_indices = self._top.select(selection)
except Exception as e:
raise ValueError(f"Invalid selection: {selection}. Error: {e}")
if min_atoms is not None and len(atom_indices) < min_atoms:
raise ValueError(f"Selection '{selection}' contains only {len(atom_indices)} atoms, "
f"which is less than the required {min_atoms}.")
return atom_indices
[docs]
def get_rmsd(self, selection: str = "name CA") -> dict[str, float]:
    """
    Compute the RMSD of every structure against the reference structure.

    The reference is the one set up in the constructor.

    :param str selection:
        Selection string of the atoms entering the RMSD. At least two
        atoms must be selected.
    :return: Mapping from pdb name to RMSD value. Units: Angstrom.
    :rtype: dict[str, float]
    """
    atom_idx = self._select_and_validate(selection, 2)
    # The factor of 10 converts the MDTraj result to Angstrom.
    rmsd_angstrom = md.rmsd(self._traj, self._ref, atom_indices=atom_idx) * 10
    return dict(zip(self._pdb_name, rmsd_angstrom))
@property
def peptide_bond_stats(self) -> dict[str, NDArray[np.float64]]:
    """
    Mean and standard deviation of the peptide bond lengths per structure.

    For every chain, the backbone ``C`` atom of each residue is paired with
    the backbone ``N`` atom that follows it in the selection. Chains that
    hold fewer than two such atoms (e.g. ligands, ions or single residues)
    contain no peptide bond and are skipped.

    :return: Mapping from pdb name to an array ``[mean, std]``.
        Units: Angstrom.
    :raises ValueError: If a chain has a different number of backbone
        ``C`` and ``N`` atoms, or if no peptide bond is found at all.
    """
    bond_pairs = []
    for chain in self.top.chains:
        cid = chain.index
        # min_atoms=None: a chain without backbone atoms is skipped below
        # instead of aborting the whole computation.
        idx_c = self._select_and_validate(f'protein and chainid {cid} and name C', None)
        idx_n = self._select_and_validate(f'protein and chainid {cid} and name N', None)
        if len(idx_c) < 2 or len(idx_n) < 2:
            continue
        if len(idx_c) != len(idx_n):
            raise ValueError(f"Chain {cid} has {len(idx_c)} backbone C atoms but "
                             f"{len(idx_n)} backbone N atoms; cannot pair peptide bonds.")
        # C of residue i is bonded to N of residue i + 1.
        bond_pairs.append(np.column_stack((idx_n[1:], idx_c[:-1])))

    if not bond_pairs:
        raise ValueError("No peptide bonds found in the reference topology.")

    all_pairs = np.vstack(bond_pairs)
    distances = md.compute_distances(self._traj, atom_pairs=all_pairs) * 10  # Angstrom
    stats = np.column_stack((distances.mean(1), distances.std(1)))
    return dict(zip(self._pdb_name, stats))
@property
def nonbonded_pairs(self) -> list[tuple[int, int]]:
'''
The non-bonded atom pairs in the structure.
'''
from itertools import combinations
traj_noH = self._select_and_validate('not element H')
pairs = set(combinations(traj_noH, 2))
bonded_pairs = {(b[0].index, b[1].index) for b in self._top.bonds}
pairs.difference_update(bonded_pairs)
nb_pairs = list(pairs)
return nb_pairs
@property
def minimum_nonbonded_distance(self) -> dict[str, float]:
from scipy.spatial import KDTree
heavy_atoms = self._select_and_validate('not element H')
nb_pairs = set(self.nonbonded_pairs)
full_coords = self._traj.xyz
search_radius = 0.5
def min_dist_from_pairs(frame_coords, pairs):
if not pairs:
return search_radius * 10
i0 = np.fromiter((p[0] for p in pairs), int)
i1 = np.fromiter((p[1] for p in pairs), int)
a = frame_coords[i0]
b = frame_coords[i1]
diff = a - b
sq = np.einsum('ij,ij->i', diff, diff)
return 10.0 * np.sqrt(sq.min())
def min_dist_frame(f_idx):
# f_idx is the frame index
# c is the subset of heavy atom coordinates for the KDTree
c = full_coords[f_idx, heavy_atoms, :]
tree = KDTree(c)
found_pairs = tree.query_pairs(r=search_radius, output_type='set')
global_pairs = {(heavy_atoms[i], heavy_atoms[j]) for i, j in found_pairs}
global_pairs.intersection_update(nb_pairs)
return min_dist_from_pairs(full_coords[f_idx], global_pairs)
min_dist = [min_dist_frame(i) for i in range(len(full_coords))]
return {pdb: r for pdb, r in zip(self._pdb_name, min_dist)}
# ===== Filtering =====
[docs]
def rmsd_filter(self, selection="name CA", rmsd_cutoff: float = 10.0) -> list[str]:
'''
Filter structures with a RMSD cutoff.
Filter structures that are too irrelavant by dropping those with RMSD
larger than a cutoff (in Angstrom). This returns a list of pdb names.
The filter can be subsequently applied by the apply_filter method.
:param float rmsd_cutoff:
The RMSD cutoff value. Default: 10.0 Angstrom
:param str selection:
The selection string to the atoms to calculate the RMSD.
Default: "name CA"
:return: The pdb names of the selected structures
:rtype: list[str]
:raises ValueError: If no structures meet the cutoff criteria.
'''
rmsd = self.get_rmsd(selection)
mask = [k for k, v in rmsd.items() if v <= rmsd_cutoff]
if len(mask) == 0:
raise ValueError(f"No structures are below the RMSD cutoff of {rmsd_cutoff} Angstrom.")
return mask
[docs]
def peptide_bond_filter(self, mean_cutoff=1.4, std_cutoff=0.2) -> list[str]:
'''
Filter structures with a peptide bond cutoff.
Some AlphaFold2 generated structures have unrealistic backbone structures,
often characterized with too long or too short peptide bonds.
The mean and standard deviation of the peptide bond lengths are calculated for each structure.
If the mean is larger than the cutoff, or the standard deviation
is larger than the cutoff, the structure will be filtered out.
:param float mean_cutoff:
Maximum allowed mean peptide bond length per structure.
Default: 1.4 Angstrom
:param float std_cutoff:
Maximum allowed standard deviation of peptide bond length per structure.
Default: 0.2 Angstrom
:return: The pdb names of the selected structures
:rtype: list[str]
:raises ValueError: If no structures meet the cutoff criteria.
'''
mask = [k for k, (m, s) in self.peptide_bond_stats.items()
if m <= mean_cutoff and s <= std_cutoff
]
if len(mask) == 0:
raise ValueError("No structures are below the peptide bond cutoffs of "
f"mean={mean_cutoff} Angstrom and "
f"std={std_cutoff} Angstrom."
)
return mask
[docs]
def steric_clash_filter(self, min_non_bonded_cutoff=1.1) -> list[str]:
'''
Filter structures based on non-bonded heavy atom distances.
Some AlphaFold2-generated structures have steric clashes
between non-bonded atoms. This method filters out structures
where non-bonded heavy atom distances are too short, leading
to overlap in van der Waals radii.
:param float min_non_bonded_cutoff:
Minimum allowed non-bonded heavy atom distance.
Default: 1.1 Angstrom
:return: The pdb names of the selected structures
:rtype: list[str]
:raises ValueError: If no structures meet the cutoff criteria.
'''
min_nb_dists = self.minimum_nonbonded_distance
mask = [k for k, v in min_nb_dists.items() if v >= min_non_bonded_cutoff]
if len(mask) == 0:
raise ValueError("No structures are above the dist cutoff of "
f"{min_non_bonded_cutoff} Angstrom.")
return mask
[docs]
def apply_filter(self, *args: list[str]) -> None:
    """
    Apply one or more masks to the trajectory.

    Each mask is a list of pdb names to keep. When several masks are
    given, only the names present in all of them are kept. The retained
    structures are stored in natural-sorted name order, and the stored
    features are sliced accordingly.

    Example:

    .. code-block:: python

        fs.apply_filter(mask)
        fs.apply_filter(mask1, mask2)

    :param list[str] mask: The mask(s) to apply.
    :raises ValueError: If no mask is given, if the masks have no name in
        common, or if a name does not belong to a loaded structure.
    """
    if not args:
        # set.intersection() without arguments fails with an opaque TypeError.
        raise ValueError("At least one mask must be provided.")

    mask = natsorted(set.intersection(*map(set, args)))

    # Check if the intersection is empty
    if len(mask) == 0:
        raise ValueError("No structures are selected by the filter.")

    # Name -> position lookup; the first occurrence wins, like list.index.
    position = {}
    for i, name in enumerate(self._pdb_name):
        position.setdefault(name, i)

    # Check if the mask is valid
    non_exist = [m for m in mask if m not in position]
    if non_exist:
        raise ValueError(f"Invalid mask. Some structures do not exist: {non_exist}")

    # Apply the mask: extract all retained frames in a single slice.
    idx_slices = [position[m] for m in mask]
    self._pdb_name = mask
    self._traj = self._traj.slice(np.asarray(idx_slices, dtype=int), copy=False)

    # Keep the stored features aligned with the retained structures.
    if self._features:
        self._features = {k: v[idx_slices] for k, v in self._features.items()}
# ===== Feature selection =====
def _get_atom_pairs(self, selection: str | tuple[str, str]) -> NDArray[np.int_]:
"""
Get the atom pairs from the selection string.
- If `selection` is a string, it returns all pairs of atoms in the selection.
- If `selection` is a tuple of two strings, it returns all pairs of atoms between the two selections.
:param selection: A string representing a single selection or a tuple of two selection strings.
:type selection: str | tuple[str, str]
:return: A NumPy array of atom pairs.
:raises ValueError: If `selection` is not a string or a tuple of two strings.
"""
from itertools import combinations, product
if isinstance(selection, str):
atom_index = self._select_and_validate(selection, min_atoms=2)
return np.array(list(combinations(atom_index, 2)), dtype=np.int_)
if isinstance(selection, tuple) and len(selection) == 2:
idx_a = self._select_and_validate(selection[0])
idx_b = self._select_and_validate(selection[1])
return np.array(list(product(idx_a, idx_b)), dtype=np.int_)
raise ValueError("Selection must be a string or a tuple of two strings.")
[docs]
def rank_feature(self,
                 selection: str | tuple[str, str] | list[str | tuple[str, str]] = "name CA"
                 ) -> tuple[list[str], NDArray[np.float64]]:
    """
    Rank the features by the coefficient of variation (CV).

    The argument ``selection`` can be:

    - A `string`:
        Computes all pairs of atoms within the selection.
    - A `tuple` of two strings:
        Computes all pairs of atoms between the two selections.
    - A `list` of strings or tuples:
        Computes atom pairs for each selection in the list.

    Pairs of an atom with itself and repeated pairs are dropped. The
    distances and atom pairs of all remaining features are stored in
    ``features`` and ``atom_pairs``.

    :param selection: The selection string(s) used to determine atom pairs.
    :return:
        - names: A list of feature names, sorted by descending CV.
        - cv: A NumPy array containing the coefficient of variation values
          in the same order. Features with an undefined CV (zero mean
          distance) are placed last.
    :raises ValueError: If `selection` is not a valid type, or if it
        yields no usable atom pair.
    """
    if isinstance(selection, (str, tuple)):
        atom_pairs = self._get_atom_pairs(selection)
    elif isinstance(selection, list):
        if not selection:
            raise ValueError("Selection list cannot be empty.")
        # Shape: (n_pairs, 2)
        atom_pairs = np.vstack([self._get_atom_pairs(s) for s in selection])
    else:
        raise ValueError("Selection must be a string, a tuple of two strings, or a list of them.")

    atom_pairs = np.asarray(atom_pairs, dtype=np.int_).reshape(-1, 2)

    # Screen atom_pairs: drop self-pairs, then exact duplicates (which
    # occur when selections overlap), keeping the first occurrence order.
    atom_pairs = atom_pairs[atom_pairs[:, 0] != atom_pairs[:, 1]]
    if len(atom_pairs) > 0:
        _, first_seen = np.unique(atom_pairs, axis=0, return_index=True)
        atom_pairs = atom_pairs[np.sort(first_seen)]
    if len(atom_pairs) == 0:
        raise ValueError("The selection does not yield any valid atom pair.")

    # Compute pairwise distances in nanometers, convert to Angstroms
    # Shape: (n_structures, n_pairs)
    pw_dist = md.compute_distances(self._traj, atom_pairs, periodic=False) * 10

    # Generate feature names
    names = [f"{representation(self._top, i)}-{representation(self._top, j)}"
             for i, j in atom_pairs
             ]

    # Store features
    for name, pwd, ap in zip(names, pw_dist.T, atom_pairs):
        self._features[name] = pwd
        self._atom_pairs[name] = ap

    # Compute coefficient of variation (CV)
    mean_dist = np.mean(pw_dist, axis=0)
    std_dist = np.std(pw_dist, axis=0)

    # Handle division errors safely
    with np.errstate(divide='ignore', invalid='ignore'):
        cv = np.where(mean_dist != 0, std_dist / mean_dist, np.nan)

    # Rank features by CV in descending order. argsort places NaN last, so
    # the reversal would rank undefined CVs first; move them to the end.
    rank = np.argsort(cv)[::-1]
    undefined = np.isnan(cv[rank])
    rank = np.concatenate((rank[~undefined], rank[undefined]))

    names_sorted = [names[i] for i in rank]
    return names_sorted, cv[rank]
# ===== Format conversion =====
[docs]
def get_chimera_plotscript(self,
                           feature_name: list[str],
                           add_header: bool = True
                           ) -> str:
    """
    Generate a Chimera plotscript to visualize the selected features.

    For every feature a ``distance`` command between its two atoms is
    emitted. For an atom whose Chimera representation does not contain
    "CA", a ``show`` command for its residue is added as well. The
    commands are de-duplicated and sorted.

    :param feature_name: A list of feature names to visualize. A single
        name given as a string is treated as a one-element list.
    :param add_header: Whether to add the "open xxx.pdb" header to the plotscript.
    :return: The Chimera plotscript as a string.
    :raises ValueError: If `feature_name` is None or contains invalid names.
    """
    if feature_name is None:
        raise ValueError("`feature_name` cannot be None.")
    if isinstance(feature_name, str):
        # Iterating a bare string would look up its single characters.
        feature_name = [feature_name]

    plotscript_lines = set()

    for fn in feature_name:
        if fn not in self._atom_pairs:
            raise ValueError(f"Feature name '{fn}' not found in stored atom pairs.")

        i, j = self._atom_pairs[fn]
        for atom_index, atom_repr in ((i, chimera_representation(self.top, i)),
                                      (j, chimera_representation(self.top, j))):
            if "CA" not in atom_repr:
                plotscript_lines.add(f"show :{resid(self.top, atom_index)} a")

        atom_i = chimera_representation(self.top, i)
        atom_j = chimera_representation(self.top, j)
        plotscript_lines.add(f"distance {atom_i} {atom_j}")

    plotscript = "\n".join(sorted(plotscript_lines)) + "\n"  # Sorting ensures deterministic output

    if add_header:
        plotscript = f"open {self.ref_pdb}\n{plotscript}"

    return plotscript
# ===== Clustering =====
[docs]
def regular_space_clustering(self,
feature_name: list[str],
min_dist: float,
max_centers: int = 100,
batch_size: int = 100,
randomseed: int = 0) -> tuple[NDArray[np.float64], NDArray[np.int_]]:
"""
Performs regular space clustering on the selected dimensions of features.
:param list[str] feature_name:
List of feature names to use for clustering.
:param float min_dist:
Minimum distance between cluster centers. Unit: Angstrom.
:param int max_centers:
Maximum number of cluster centers. Default: 100.
:param int batch_size:
Number of points to process in each batch. Default: 100.
:param int randomseed:
Random seed for the permutation.
:return: A tuple containing:
- center (np.ndarray): Cluster center coordinates.
- center_id (np.ndarray): Indices of the cluster centers.
:raises ValueError: If `max_centers` is exceeded.
"""
if not feature_name:
raise ValueError("Feature list cannot be empty.")
# Extract feature time series and transpose to shape (npoints, nfeatures)
z = np.array([self._features[fn] for fn in feature_name], dtype=np.float64).T
npoints, ndim = z.shape
# Determine the reference index
idx = self.pdb_name.index(self.ref_pdb) if self.ref_pdb in self.pdb_name else 0
# Generate a random permutation while ensuring the reference index remains fixed
rng = np.random.default_rng(seed=randomseed)
perm = rng.permutation(npoints)
lookup = (np.arange(npoints) + np.where(perm == idx)[0][0]) % npoints
perm = perm[lookup]
data = z[perm]
# Initialize cluster centers
center_id = np.full(max_centers, -1, dtype=np.int_)
center_id[0] = perm[0]
ncenter = 1
i = 1
while i < npoints:
x_active = data[i:i + batch_size]
current_centers = data[center_id[center_id != -1]]
# Compute Euclidean distances normalized by sqrt(ndim)
distances = np.linalg.norm(x_active[:, np.newaxis, :] - current_centers[np.newaxis, :, :], axis=2) / np.sqrt(ndim)
# Find indices of points that are at least `min_dist` away from all cluster centers
valid_indices = np.nonzero(np.all(distances > min_dist, axis=1))[0]
if valid_indices.size > 0:
center_id[ncenter] = perm[i + valid_indices[0]]
ncenter += 1
i += valid_indices[0] + 1
else:
i += batch_size
if ncenter >= max_centers:
raise ValueError(f"{i}/{npoints} clustered. "
f"Exceeded the maximum number of cluster centers ({max_centers}). "
f"Consider increasing `min_dist`.")
center_id = center_id[center_id != -1]
return center_id
[docs]
def pca(self, n_components: int = 2, **kwargs) -> tuple[PCA, NDArray[np.float64]]:
    """
    Run a Principal Component Analysis (PCA) over all stored features.

    :param n_components: The number of principal components to compute.
    :param kwargs: Additional keyword arguments to pass to the PCA constructor.
    :return: A tuple containing the fitted PCA object and the transformed data.
    :raises ValueError: If no features are stored, or if `n_components`
        is larger than the number of stored features.
    """
    if not self.features:
        raise ValueError("No features available for PCA.")

    # One row per structure, one column per feature.
    series = [self._features[name] for name in self.features]
    matrix = np.array(series, dtype=np.float64).T

    n_features = matrix.shape[1]
    if n_features < n_components:
        raise ValueError(f"Number of components ({n_components}) cannot exceed available features ({n_features}).")

    model = PCA(n_components=n_components, **kwargs)
    projected = model.fit_transform(matrix)
    return model, projected