Source code for af2rave.spib.spib

'''
The SPIB data analysis module.
'''

import torch
import numpy as np
import time
import hashlib
import shutil

from .spib_result import SPIBResult
from .wrapper import spib as spib_kernel
from ..colvar import Colvar


[docs] class SPIBProcess(object): ''' This is the af2rave wrapper for SPIB. To initialize, provide the list of Colvar files for SPIB to process. :param traj: The list of trajectory files to process. :type traj: str | list[str] ''' def __init__(self, traj: str | list[str], **kwargs): self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(f"[af2rave.spib] Using device: {self._device}") if isinstance(traj, str): self._traj = [traj] else: self._traj = traj self._n_traj = len(traj) self._kwargs = kwargs # garbage collection self._basename = [] self._traj_data_list = self._load_data() self._traj_labels_list = self._init_default_label() self._min_max_scaling() def _load_data(self) -> list[torch.Tensor]: ''' Load the colvar data from the trajectory files. :return traj_data_list: list of torch.Tensor in designated device :rtype: list[torch.Tensor] ''' traj_data_list = [] for f in self._traj: data = Colvar.from_file(f).data.T data = torch.tensor(data, dtype=torch.float32).to(self._device) traj_data_list.append(data) return traj_data_list def _init_default_label(self): ''' Initialize the default labels for the trajectories. The default labels are one-hot encoded, with the first half of the trajectory labeled as 0 and the second half labeled as 1. :return traj_labels_list: list of torch.Tensor in designated device :rtype: list[torch.Tensor] ''' traj_labels_list = [] for i, f in enumerate(self._traj_data_list): n_data = f.shape[0] scalar_label = np.rint(np.arange(n_data) >= n_data / 2).astype(int) + i * 2 onehot_label = np.eye(self._n_traj * 2)[scalar_label] label = torch.tensor(onehot_label, dtype=torch.float32).to(self._device) traj_labels_list.append(label) return traj_labels_list def _min_max_scaling(self): self._min = torch.tensor(np.inf, dtype=torch.float32).to(self._device) self._max = torch.tensor(-np.inf, dtype=torch.float32).to(self._device) for i, t in enumerate(self._traj_data_list): self._min = torch.minimum(self._min, torch.min(t, 0).values) self._max = torch.maximum(self._max, torch.max(t, 0).values) self._b = self._min self._k = self._max - self._min for i in range(self._n_traj): self._traj_data_list[i] = (self._traj_data_list[i] - self._b) / self._k
[docs] def run(self, time_lag: int, **kwargs): ''' Run SPIB on the loaded data. :param time_lag: The time lag for SPIB. :type time_lag: int :return: SPIBResult object. :rtype: SPIBResult ''' basename = "tmp_" + hashlib.md5(str(time.time()).encode()).hexdigest() self._basename.append(basename) seed = kwargs.get('seed', 42) spib_kernel(self._traj_data_list, self._traj_labels_list, time_lag, base_path=basename, device=self._device, **kwargs) prefix = f"{basename}/model_dt_{time_lag}" postfix = f"{seed}.npy" return SPIBResult(prefix, postfix, self._n_traj, dt=time_lag, b=self._b.cpu().numpy(), k=self._k.cpu().numpy())
def __del__(self): for b in self._basename: shutil.rmtree(b, ignore_errors=True)