# plot figures
import math
import os
import pathlib
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from scipy import interpolate
# A runtime error raised when a PMF
# does not have a corresponding correction file
[docs]
class NoCorrectionFileError(RuntimeError):
    """Raised when a PMF has no corresponding GaWTM correction file.

    Args:
        arg (str): human-readable description of the problem
    """

    def __init__(self, arg):
        # Delegate to RuntimeError so that ``args == (arg,)`` and ``str(err)``
        # is the message. Assigning a string to ``self.args`` directly would
        # split it into a tuple of single characters.
        super().__init__(arg)
[docs]
def isGaWTM(pmfFiles):
    """Tell whether a set of PMF files comes from Ga-WTM simulations.

    The decision is based only on file names: the set is considered Ga-WTM
    as soon as one file name ends with a reweighting suffix.

    Args:
        pmfFiles (list[str]): paths to a set of PMFs (and PMF corrections)

    Returns:
        bool: True if at least one file is a Ga-WTM reweighting file
    """
    reweightSuffixes = (
        '.reweightamd1.cumulant.pmf',
        '.reweightamd1.reweight.pmf',
    )
    return any(
        pathlib.Path(entry).name.endswith(reweightSuffixes)
        for entry in pmfFiles
    )
[docs]
def getGaWTMBaseName(filePath):
    """Return the base name of a GaWTM PMF file or of its correction file.

    For example:
        'path/to/001.czar.pmf' -> '001'
        'path/to/001.reweightamd1.cumulant.pmf' -> '001'
        'path/to/step1.czar.pmf' -> 'step1'

    Args:
        filePath (str): path to the file

    Returns:
        str: file name with the directory and the known suffix removed
    """
    name = pathlib.Path(filePath).name
    # Known GaWTM suffixes are stripped as a whole
    for knownSuffix in ('.reweightamd1.cumulant.pmf', '.czar.pmf'):
        if name.endswith(knownSuffix):
            return name[:-len(knownSuffix)]
    # Any other file only loses its last extension
    return pathlib.Path(name).stem
[docs]
def getGaWTMBaseNames(filePath):
    """List the candidate base names of a GaWTM PMF file for flexible pairing.

    The primary base name always comes first. When it ends with ``.abf``
    followed by digits, the name without that trailing part is added as an
    alternative.

    For example::

        'path/to/abf_1.abf1.czar.pmf' -> ['abf_1.abf1', 'abf_1']
        'path/to/001.czar.pmf' -> ['001']
        'path/to/step1.abf2.czar.pmf' -> ['step1.abf2', 'step1']

    Args:
        filePath (str): path to the file

    Returns:
        list[str]: possible base names (primary first, then alternatives)
    """
    primary = getGaWTMBaseName(filePath)
    names = [primary]

    # Split on the last dot; a non-empty head means the dot is not the
    # very first character of the name.
    head, dot, tail = primary.rpartition('.')
    looksLikeAbfTag = tail.startswith('abf') and tail[3:].isdigit()
    if dot and head and looksLikeAbfTag:
        names.append(head)
    return names
[docs]
def pairGaWTMFiles(pmfFiles):
    """Pair GaWTM PMF files with their corresponding correction files.

    Files are classified by file-name suffix and matched by base name:
        - xxx.czar.pmf is paired with xxx.reweightamd1.cumulant.pmf

    A czar file tries each of its candidate base names in turn and takes the
    first one for which a correction exists. Base names are dictionary keys,
    so when two inputs of the same kind share a base name the later one
    replaces the earlier one.

    Args:
        pmfFiles (list[str]): all PMF file paths (including correction files)

    Returns:
        tuple: (paired_list, unpaired_czar_list, orphan_correction_list,
        other_files, wrong_correction_files)
            paired_list: list of tuples (pmf_file_path, correction_file_path)
            unpaired_czar_list: czar.pmf paths without a correction
            orphan_correction_list: correction paths no czar.pmf was paired with
            other_files: paths that are neither czar nor correction files
            wrong_correction_files: .reweightamd1.reweight.pmf paths (wrong type)
    """
    czarByBase = {}          # base name -> czar.pmf path
    candidatesByBase = {}    # base name -> candidate names used for matching
    correctionByBase = {}    # base name -> cumulant correction path
    wrongCorrections = []
    others = []

    # Step 1: classify every input by its file-name suffix
    for path in pmfFiles:
        name = pathlib.Path(path).name
        if name.endswith('.reweightamd1.cumulant.pmf'):
            correctionByBase[getGaWTMBaseName(path)] = path
        elif name.endswith('.reweightamd1.reweight.pmf'):
            wrongCorrections.append(path)
        elif name.endswith('.czar.pmf'):
            base = getGaWTMBaseName(path)
            czarByBase[base] = path
            candidatesByBase[base] = getGaWTMBaseNames(path)
        else:
            others.append(path)

    # Step 2: for each czar file, take the first candidate with a correction
    pairs = []
    lonelyCzar = []
    usedCorrections = set()
    for base, czarPath in czarByBase.items():
        hit = next(
            (name for name in candidatesByBase.get(base, [base])
             if name in correctionByBase),
            None,
        )
        if hit is None:
            lonelyCzar.append(czarPath)
        else:
            pairs.append((czarPath, correctionByBase[hit]))
            usedCorrections.add(hit)

    # Step 3: corrections that no czar file claimed
    orphans = [
        corrPath
        for base, corrPath in correctionByBase.items()
        if base not in usedCorrections
    ]
    return pairs, lonelyCzar, orphans, others, wrongCorrections
[docs]
def correctGaWTM(pmfFile, correctionFile=None):
    """Read a 1D NAMD PMF file and correct it using a cumulant.pmf file.

    The correction is linearly interpolated (and extrapolated outside its
    range) onto the x values of the PMF and added to the second column.

    Args:
        pmfFile (str): path to the pmf file
        correctionFile (str, optional): path to the correction file.
            If None, it is derived from ``pmfFile`` by replacing a trailing
            ``.czar.pmf`` with ``.reweightamd1.cumulant.pmf`` (legacy
            behavior).

    Returns:
        np.array (N*2): corrected 1D PMF

    Raises:
        NoCorrectionFileError: if the correction file does not exist
    """
    pmf = np.loadtxt(pmfFile)

    if correctionFile is None:
        # Strip only a trailing '.czar.pmf'. A blanket str.replace would
        # also alter directory names that contain that text. os.fspath
        # lets os.PathLike objects through as well.
        pmfPath = os.fspath(pmfFile)
        czarSuffix = '.czar.pmf'
        if pmfPath.endswith(czarSuffix):
            pmfPath = pmfPath[:-len(czarSuffix)]
        correctionFile = pmfPath + '.reweightamd1.cumulant.pmf'

    if not os.path.exists(correctionFile):
        raise NoCorrectionFileError(f'{pmfFile} does not have a corresponding correction!')

    correction_data = np.loadtxt(correctionFile)
    correction_interpolate = interpolate.interp1d(
        correction_data[:, 0], correction_data[:, 1], fill_value="extrapolate"
    )
    pmf[:, 1] += correction_interpolate(pmf[:, 0])
    return pmf
[docs]
def readPMF(pmfFile):
    """Load a 1D NAMD PMF file with ``np.loadtxt``.

    Args:
        pmfFile (str): path to the pmf file

    Returns:
        np.array (N*2): 1D PMF
    """
    pmfData = np.loadtxt(pmfFile)
    return pmfData
[docs]
def mergePMF(pmfFiles):
    """Merge several overlapping 1D PMFs into one.

    The PMFs are sorted by their first x value. Each following PMF is then
    attached to the running result at the first row whose x equals that
    PMF's first x:
        - the PMF is shifted by the mean difference over the overlap,
        - the overlapped rows are averaged,
        - the rows beyond the overlap are appended.
    The overlap is matched row by row, so the PMFs are expected to share the
    same grid there. A PMF whose first x is not found in the running result
    is skipped. Finally the second column is shifted so its minimum is zero.

    The input arrays are not modified; all work is done on copies.

    Args:
        pmfFiles (list of np.arrays): list of 1D pmfs

    Returns:
        np.array (N*2): the merged PMF, with its minimum shifted to zero
    """
    numPmfs = len(pmfFiles)
    assert(numPmfs > 0)

    # Indices of the PMFs ordered by their first x value (stable sort)
    order = sorted(range(numPmfs), key=lambda idx: pmfFiles[idx][0][0])

    # Copy, so that the in-place updates below never touch the caller's data
    finalPMF = np.array(pmfFiles[order[0]], dtype=float)

    for idx in order[1:]:
        nextPMF = np.array(pmfFiles[idx], dtype=float)
        for j in range(len(finalPMF)):
            if finalPMF[j][0] == nextPMF[0][0]:
                overlap = len(finalPMF) - j
                # overlapped region: align, then average
                avgDifference = np.average(finalPMF[j:, 1:] - nextPMF[0:overlap, 1:])
                nextPMF[:, 1:] += avgDifference
                finalPMF[j:, 1:] = (finalPMF[j:, 1:] + nextPMF[0:overlap, 1:]) / 2
                # other region: append what lies beyond the overlap
                finalPMF = np.append(finalPMF, nextPMF[overlap:], axis=0)
                break

    finalPMF[:, 1] -= finalPMF[:, 1].min()
    return finalPMF
[docs]
def writePMF(pmfFile, pmf):
    """Save a 1D NAMD PMF to a text file using the ``%g`` number format.

    Args:
        pmfFile (str): path to the pmf file
        pmf (np.array, N*2): pmf to be written
    """
    np.savetxt(fname=pmfFile, X=pmf, fmt='%g')
# ============== History PMF functions ==============
[docs]
def isGaWTMHist(pmfFiles):
    """Tell whether a set of History PMF files comes from Ga-WTM simulations.

    The decision is based only on file names: any correction file, of the
    correct (cumulant) or wrong (reweight) type, in history or single-frame
    form, marks the set as Ga-WTM.

    Args:
        pmfFiles (list[str]): paths to a set of History PMFs (and corrections)

    Returns:
        bool: True if at least one file is a Ga-WTM correction file
    """
    correctionSuffixes = (
        '.reweightamd1.cumulant.hist.pmf',
        '.reweightamd1.cumulant.pmf',
        '.reweightamd1.reweight.hist.pmf',
        '.reweightamd1.reweight.pmf',
    )
    return any(
        pathlib.Path(entry).name.endswith(correctionSuffixes)
        for entry in pmfFiles
    )
[docs]
def getGaWTMHistBaseName(filePath):
    """Return the base name of a GaWTM History PMF file or its correction file.

    For example:
        'path/to/001.hist.czar.pmf' -> '001'
        'path/to/001.reweightamd1.cumulant.hist.pmf' -> '001'
        'path/to/001.reweightamd1.cumulant.pmf' -> '001'
        'path/to/step1.hist.czar.pmf' -> 'step1'

    Args:
        filePath (str): path to the file

    Returns:
        str: file name with the directory and the known suffix removed
    """
    name = pathlib.Path(filePath).name
    # Order matters: '.hist.pmf' is itself the tail of the first suffix,
    # so the longer, more specific suffixes must be tried first.
    orderedSuffixes = (
        '.reweightamd1.cumulant.hist.pmf',
        '.reweightamd1.cumulant.pmf',
        '.hist.czar.pmf',
        '.hist.pmf',
    )
    for knownSuffix in orderedSuffixes:
        if name.endswith(knownSuffix):
            return name[:-len(knownSuffix)]
    # Anything else only loses its last extension
    return pathlib.Path(name).stem
[docs]
def getGaWTMHistBaseNames(filePath):
    """List the candidate base names of a GaWTM History PMF file.

    The primary base name always comes first. When it ends with ``.abf``
    followed by digits, the name without that trailing part is added as an
    alternative, to support flexible pairing.

    For example::

        'path/to/abf_1.abf1.hist.czar.pmf' -> ['abf_1.abf1', 'abf_1']
        'path/to/001.hist.czar.pmf' -> ['001']

    Args:
        filePath (str): path to the file

    Returns:
        list[str]: possible base names (primary first, then alternatives)
    """
    primary = getGaWTMHistBaseName(filePath)
    names = [primary]

    # Split on the last dot; a non-empty head means the dot is not the
    # very first character of the name.
    head, dot, tail = primary.rpartition('.')
    looksLikeAbfTag = tail.startswith('abf') and tail[3:].isdigit()
    if dot and head and looksLikeAbfTag:
        names.append(head)
    return names
[docs]
def pairGaWTMHistFiles(pmfFiles):
    """Pair GaWTM History PMF files with their corresponding correction files.

    Files are classified by file-name suffix and matched by base name:
        - xxx.hist.czar.pmf is paired with xxx.reweightamd1.cumulant.hist.pmf
          or xxx.reweightamd1.cumulant.pmf

    For a given base name a history correction always takes precedence over
    a single-frame one. Among several single-frame corrections the first one
    seen is kept. A czar file tries each of its candidate base names in turn
    and takes the first one for which a correction exists.

    Args:
        pmfFiles (list[str]): all History PMF file paths (including
            correction files)

    Returns:
        tuple: (paired_list, unpaired_czar_list, orphan_correction_list,
        other_files, wrong_correction_files)
            ``paired_list`` is a list of tuples
            ``(pmf_file_path, correction_file_path, is_hist_correction)``.
            ``is_hist_correction`` is True if the correction is a .hist.pmf
            and False if it is a single-frame .pmf. The other lists hold the
            unpaired hist.czar.pmf files, the orphan correction files, the
            other PMF files, and the wrong .reweightamd1.reweight(.hist).pmf
            correction files.
    """
    czarByBase = {}          # base name -> hist.czar.pmf path
    candidatesByBase = {}    # base name -> candidate names used for matching
    correctionByBase = {}    # base name -> (path, is_hist_correction)
    wrongCorrections = []
    others = []

    wrongSuffixes = (
        '.reweightamd1.reweight.hist.pmf',
        '.reweightamd1.reweight.pmf',
    )

    # Step 1: classify every input by its file-name suffix
    for path in pmfFiles:
        name = pathlib.Path(path).name
        if name.endswith('.reweightamd1.cumulant.hist.pmf'):
            # A history correction always overrides what was stored before
            correctionByBase[getGaWTMHistBaseName(path)] = (path, True)
        elif name.endswith('.reweightamd1.cumulant.pmf'):
            # A single-frame correction never replaces an existing entry
            correctionByBase.setdefault(getGaWTMHistBaseName(path), (path, False))
        elif name.endswith(wrongSuffixes):
            wrongCorrections.append(path)
        elif name.endswith('.hist.czar.pmf'):
            base = getGaWTMHistBaseName(path)
            czarByBase[base] = path
            candidatesByBase[base] = getGaWTMHistBaseNames(path)
        else:
            others.append(path)

    # Step 2: for each czar file, take the first candidate with a correction
    pairs = []
    lonelyCzar = []
    usedCorrections = set()
    for base, czarPath in czarByBase.items():
        hit = next(
            (name for name in candidatesByBase.get(base, [base])
             if name in correctionByBase),
            None,
        )
        if hit is None:
            lonelyCzar.append(czarPath)
        else:
            corrPath, isHist = correctionByBase[hit]
            pairs.append((czarPath, corrPath, isHist))
            usedCorrections.add(hit)

    # Step 3: corrections that no czar file claimed
    orphans = [
        corrPath
        for base, (corrPath, _) in correctionByBase.items()
        if base not in usedCorrections
    ]
    return pairs, lonelyCzar, orphans, others, wrongCorrections
[docs]
def readHistPMF(histPmfFile):
    """Read a History PMF file and return its frames.

    Frames are separated by blank lines. Lines starting with ``#`` are
    ignored, as are lines with fewer than two fields or whose first two
    fields are not numbers. Only the first two columns are kept.

    Args:
        histPmfFile (str): path to the history PMF file

    Returns:
        list[np.array]: list of PMF frames, each as an (N, 2) array
    """
    frames = []
    rows = []
    with open(histPmfFile, 'r') as handle:
        for rawLine in handle:
            text = rawLine.strip()

            if text == '':
                # A blank line closes the frame being accumulated, if any
                if rows:
                    frames.append(np.array(rows))
                    rows = []
            elif not text.startswith('#'):
                fields = text.split()
                if len(fields) < 2:
                    continue
                try:
                    point = [float(fields[0]), float(fields[1])]
                except ValueError:
                    continue
                rows.append(point)

    # The file may end without a trailing blank line
    if rows:
        frames.append(np.array(rows))
    return frames
[docs]
def correctGaWTMHist(histPmfFile, correctionFile, is_hist_correction=True):
    """Apply a GaWTM correction to every frame of a History PMF file.

    The correction is linearly interpolated (and extrapolated outside its
    range) onto the x values of each frame and added to the second column.

    Args:
        histPmfFile (str): path to the history PMF file (.hist.czar.pmf)
        correctionFile (str): path to the correction file
        is_hist_correction (bool): True if the correction file is in history
            format (.hist.pmf), so frame i is corrected with correction
            frame i. False if it is single-frame (.pmf), so the same
            correction is applied to all frames.

    Returns:
        list[np.array]: list of corrected PMF frames

    Raises:
        NoCorrectionFileError: if the correction file does not exist
    """
    pmf_frames = readHistPMF(histPmfFile)

    if not os.path.exists(correctionFile):
        raise NoCorrectionFileError(f'{histPmfFile} does not have a corresponding correction!')

    def _makeInterpolator(table):
        # Linear interpolator over the first two columns of a correction
        return interpolate.interp1d(
            table[:, 0], table[:, 1], fill_value="extrapolate"
        )

    def _applyCorrection(frame, interpolator):
        # Work on a copy so the frames that were read stay untouched
        shifted = frame.copy()
        shifted[:, 1] += interpolator(frame[:, 0])
        return shifted

    if not is_hist_correction:
        # Single-frame correction: one interpolator shared by all frames
        sharedInterpolator = _makeInterpolator(np.loadtxt(correctionFile))
        return [_applyCorrection(frame, sharedInterpolator) for frame in pmf_frames]

    # History correction: bring it to the same frame count, then go frame by frame
    correction_frames = readHistPMF(correctionFile)
    if len(correction_frames) != len(pmf_frames):
        correction_frames = interpolateHistPMFFrames([correction_frames], len(pmf_frames))[0]

    return [
        _applyCorrection(frame, _makeInterpolator(correction_frames[position]))
        for position, frame in enumerate(pmf_frames)
    ]
[docs]
def interpolateHistPMFFrames(all_hist_pmfs, target_num_frames):
    """Resample each history PMF to a common number of frames.

    A history that already has ``target_num_frames`` frames is returned as
    the very same list object (no copy). An empty history yields an empty
    list. Any other history is resampled by linear interpolation between
    neighbouring frames, with the first and last frames mapped onto the
    first and last target frames.

    Args:
        all_hist_pmfs (list[list[np.array]]): list of history PMFs, each a
            list of frames
        target_num_frames (int): number of frames to interpolate to

    Returns:
        list[list[np.array]]: history PMFs with a uniform frame count
    """

    def _resample(frames):
        # Build target_num_frames frames out of a non-empty frame list
        count = len(frames)
        resampled = []
        for step in range(target_num_frames):
            # Fractional position of this target frame among the source frames
            if target_num_frames > 1:
                position = step * (count - 1) / (target_num_frames - 1)
            else:
                position = 0
            below = int(position)
            above = min(below + 1, count - 1)
            fraction = position - below

            base = frames[below]
            if fraction == 0 or below == above:
                # Falls exactly on a source frame
                resampled.append(base.copy())
                continue

            # Bring the upper frame onto the lower frame's x grid before
            # mixing, in case the two grids differ
            neighbour = frames[above]
            onBaseGrid = interpolate.interp1d(
                neighbour[:, 0], neighbour[:, 1], fill_value="extrapolate"
            )
            blended = base.copy()
            blended[:, 1] = (1 - fraction) * base[:, 1] + fraction * onBaseGrid(base[:, 0])
            resampled.append(blended)
        return resampled

    result = []
    for frames in all_hist_pmfs:
        if len(frames) == target_num_frames:
            result.append(frames)
        elif len(frames) == 0:
            result.append([])
        else:
            result.append(_resample(frames))
    return result
[docs]
def mergeHistPMF(all_hist_pmfs):
    """Merge several history PMFs frame by frame.

    Every history is first resampled to the largest frame count found among
    the inputs. Then, for each frame index, the frames of all non-empty
    histories are merged with ``mergePMF``.

    Args:
        all_hist_pmfs (list[list[np.array]]): list of history PMFs, each a
            list of frames from one window

    Returns:
        list[np.array]: merged history PMF (list of merged frames). Empty if
        there is no input or no input has any frame.
    """
    if not all_hist_pmfs:
        return []

    frame_count = max(map(len, all_hist_pmfs))
    if frame_count == 0:
        return []

    # Give every window the same number of frames
    aligned = interpolateHistPMFFrames(all_hist_pmfs, frame_count)
    # Windows without any frame take no part in the merge
    populated = [window for window in aligned if window]

    merged_frames = []
    for position in range(frame_count):
        same_time = [window[position] for window in populated]
        if same_time:
            merged_frames.append(mergePMF(same_time))
    return merged_frames
[docs]
def writeHistPMF(histPmfFile, frames):
    """Write a list of PMF frames as a History PMF file.

    Each frame is written as a ``# 1`` line, a second header line holding
    the first x value, the grid spacing, the number of points and a trailing
    ``0`` (only for non-empty frames), a blank line, one line per data row
    and a closing blank line.

    Args:
        histPmfFile (str): path to the output history PMF file
        frames (list[np.array]): list of PMF frames to write
    """
    with open(histPmfFile, 'w') as out:
        for frame in frames:
            # Frame header
            out.write('# 1\n')
            numPoints = len(frame)
            if numPoints > 0:
                first = frame[0, 0]
                last = frame[-1, 0]
                # Spacing is taken from the end points; a single-point
                # frame falls back to 0.1
                if numPoints > 1:
                    spacing = (last - first) / (numPoints - 1)
                else:
                    spacing = 0.1
                out.write(f'# {first:.14e} {spacing:.14e} {numPoints} 0\n')
            out.write('\n')

            # Frame data
            for row in frame:
                out.write(f' {row[0]:.14e} {row[1]:.14e}\n')

            # Blank line separating frames
            out.write('\n')
[docs]
def plotPMF(pmf):
    """Show a line plot of a PMF (second column against first column).

    Args:
        pmf (np.array, N*2): pmf to be plotted
    """
    coordinate = pmf[:, 0]
    freeEnergy = pmf[:, 1]
    plt.plot(coordinate, freeEnergy)
    plt.xlabel('Transition coordinate')
    plt.ylabel('ΔG (kcal/mol)')
    plt.show()
[docs]
def plotHysteresis(forwardProfile, backwardProfile):
    """Show the forward and backward free-energy profiles on one plot.

    The two curves, labelled 'Forward' and 'Backward', describe the
    hysteresis between forward and backward simulations.

    Args:
        forwardProfile (np.array, N*2): forward free-energy profile
        backwardProfile (np.array, N*2): backward free-energy profile
    """
    labelledProfiles = (
        (forwardProfile, 'Forward'),
        (backwardProfile, 'Backward'),
    )
    for profile, legendLabel in labelledProfiles:
        plt.plot(profile[:, 0], profile[:, 1], label=legendLabel)
    plt.xlabel('Lambda')
    plt.ylabel('ΔG (kcal/mol)')
    plt.legend()
    plt.show()
[docs]
def saveHysteresis(forwardProfile, backwardProfile, filePath):
    """Save the hysteresis data to a tab-separated text file.

    The file has three columns: the first column of the forward profile,
    the forward free energy and the backward free energy. The lambda values
    of the backward profile are not written.

    Args:
        forwardProfile (np.array, N*2): forward free-energy profile data
        backwardProfile (np.array, N*2): backward free-energy profile data
        filePath (str): path to save the data file
    """
    lambdas = forwardProfile[:, 0]
    forwardEnergy = forwardProfile[:, 1]
    backwardEnergy = backwardProfile[:, 1]
    # Column order: Lambda, Forward_dG, Backward_dG
    table = np.column_stack((lambdas, forwardEnergy, backwardEnergy))
    columnTitles = 'Lambda\tForward_dG(kcal/mol)\tBackward_dG(kcal/mol)'
    np.savetxt(filePath, table, fmt='%g', header=columnTitles, delimiter='\t')
[docs]
def calcRMSD(inputArray):
    """Compute the RMSD of an array with respect to the zero array.

    This is the square root of the mean of the squared elements.

    Args:
        inputArray (1D np.array): the input array (must not be empty)

    Returns:
        float: RMSD with respect to (0,0,0,...0)
    """
    squaredTotal = sum(value * value for value in inputArray)
    meanSquare = squaredTotal / len(inputArray)
    return math.sqrt(meanSquare)
[docs]
def readFrame(input):
    """Read one frame of a Colvars hist file and return its RMSD.

    Lines are consumed from the current position of ``input``. Leading
    blank lines and lines starting with ``#`` are skipped. The second
    column of each data line is collected until a blank line or the end of
    the file closes the frame.

    A frame that is closed by the end of the file (no trailing blank line)
    is still returned. ``False`` is returned only when the end of the file
    is reached before any data line was read.

    Args:
        input (python file object): input object

    Returns:
        float or bool: RMSD of the frame with respect to the zero array, or
        False if there is no frame left to read
    """
    G = []
    while True:
        line = input.readline()

        if not line:
            # End of file: hand back the pending frame instead of dropping
            # it. The next call will find nothing and return False.
            return calcRMSD(G) if G else False

        fields = line.split()
        if not fields:
            if G:
                # A blank line after data closes the frame
                break
            # Blank lines before any data are skipped
            continue

        if fields[0].startswith('#'):
            continue

        G.append(float(fields[1]))

    return calcRMSD(G)
[docs]
def parseHistFile(histPath):
    """Parse a hist.czar.pmf file into a per-frame RMSD list.

    Frames are read one after another with ``readFrame`` until it returns
    ``False``.

    Args:
        histPath (str): path to a hist.czar.pmf file

    Returns:
        list: time evolution of the RMSD with respect to the zero array,
        one entry per frame
    """
    rmsd = []
    with open(histPath, 'r') as ifile:
        frameValue = readFrame(ifile)
        # Identity test on purpose: an RMSD of 0.0 is a valid value and
        # must not be mistaken for the end marker
        while frameValue is not False:
            rmsd.append(frameValue)
            frameValue = readFrame(ifile)
    return rmsd
[docs]
def plotConvergence(rmsdList):
    """Show the time evolution of the PMF RMSD, with frames numbered from 1.

    Args:
        rmsdList (list or 1D np.array, float): time evolution of the RMSD
            with respect to the zero array
    """
    frameNumbers = range(1, len(rmsdList) + 1)
    plt.plot(frameNumbers, rmsdList)
    plt.xlabel('Frame')
    plt.ylabel('RMSD (Colvars Unit)')
    plt.show()
[docs]
def saveConvergence(rmsdList, filePath):
    """Save the PMF RMSD convergence data to a tab-separated text file.

    The file has two columns: the frame number (starting at 1) and the RMSD
    of that frame.

    Args:
        rmsdList (list or 1D np.array, float): time evolution of the RMSD
            with respect to the zero array
        filePath (str): path to save the data file
    """
    frameNumbers = np.arange(1, len(rmsdList) + 1)
    table = np.column_stack((frameNumbers, rmsdList))
    columnTitles = 'Frame\tRMSD(Colvars_Unit)'
    np.savetxt(filePath, table, fmt='%g', header=columnTitles, delimiter='\t')