Source code for privileged_residues.postproc

import numpy as np
import pyrosetta

from pyrosetta.rosetta.core import scoring
from pyrosetta.rosetta.core.kinematics import MoveMap
from pyrosetta.rosetta.protocols.minimization_packing import MinMover

from rmsd import rmsd

[docs]def filter_clash_minimize(pose, hits, clash_cutoff = 35.0, rmsd_cutoff = 0.5, sfx = None, mmap = None, limit = 0): """Filter match output for clashes, then minimize the remaining structures against the target pose. Parameters ---------- pose : pyrosetta.Pose Target structure. hits : np.ndarray Set of functional group matches against positions in the target. clash_cutoff : float Maximum tolerated increase in score terms during clash checking. sfx : pyrosetta.rosetta.core.scoring.ScoreFunction, optional Scorefunction to use during minimization. If left as None, a default scorefunction is constructed. mmap : pyrosetta.rosetta.protocols.minimization_packing.MinMover, optional Movemap to use during minimization. If left as None, a default movemap is constructed. Yields ------ pyrosetta.Pose Next target structure and matched functional group, minimized and in complex. """ if (not sfx): sfx = scoring.ScoreFunctionFactory.create_score_function("beta_nov16") sfx.set_weight(scoring.hbond_bb_sc, 2.0) sfx.set_weight(scoring.hbond_sc, 2.0) if (not mmap): mmap = MoveMap() mmap.set_bb(False) mmap.set_chi(False) mmap.set_jump(1, True) minmov = MinMover(mmap, sfx, "dfpmin_armijo_nonmonotone", 0.01, False) for n, (hash, match) in enumerate(hits): proto_pose = pose.clone() proto_pose.append_pose_by_jump(match, len(proto_pose.residues)) sfx(proto_pose) fa_rep = proto_pose.energies().total_energies()[scoring.fa_rep] fa_elec = proto_pose.energies().total_energies()[scoring.fa_elec] if (fa_rep + fa_elec > clash_cutoff): continue minmov.apply(proto_pose) minimized = proto_pose.split_by_chain(proto_pose.num_chains()) ori_coords = np.array([atom.xyz() for res in match.residues for atom in res.atoms()]) min_coords = np.array([atom.xyz() for res in minimized.residues for atom in res.atoms()]) if (rmsd(ori_coords, min_coords) < rmsd_cutoff): yield (hash, minimized) if (limit and n + 1 >= limit): break return []