Skip to main content
  • Home
  • Development
  • Documentation
  • Donate
  • Operational login
  • Browse the archive

swh logo
SoftwareHeritage
Software
Heritage
Archive
Features
  • Search

  • Downloads

  • Save code now

  • Add forge now

  • Help

Revision d60db86886186e80622deaa91045caccaf4103d3 authored by Diego del Alamo on 17 February 2022, 16:16:20 UTC, committed by GitHub on 17 February 2022, 16:16:20 UTC
Update predict.py
1 parent 714006d
  • Files
  • Changes
  • 732ab36
  • /
  • scripts
  • /
  • predict.py
Raw File Download

To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
Select below a type of object currently browsed in order to display its associated SWHID and permalink.

  • revision
  • directory
  • content
revision badge
swh:1:rev:d60db86886186e80622deaa91045caccaf4103d3
directory badge
swh:1:dir:918316435ea4041acce3baede725569425cbcd95
content badge
swh:1:cnt:7338b21f3efcf83231098517cb362d88b89afac3

This interface enables to generate software citations, provided that the root directory of browsed objects contains a citation.cff or codemeta.json file.
Select below a type of object currently browsed in order to generate citations for them.

  • revision
  • directory
  • content
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
predict.py
from . import util
import os
import numpy as np
import random
import sys

from alphafold.common import protein
from alphafold.model import data
from alphafold.model import config
from alphafold.model import model

from typing import Any, List, Mapping, NoReturn

from absl import logging


def set_config(
    use_templates: bool,
    max_msa_clusters: int,
    max_extra_msa: int,
    max_recycles: int,
    model_id: int,
    n_struct_module_repeats: int,
    n_features_in: int,
    monomer: bool = True,
    model_params: int = 0,
) -> model.RunModel:

    r"""Generated Runner object for AlphaFold

    Parameters
    ----------
    use_templates : Whether templates are used
    max_msa_cluster : How many sequences to use in MSA
    max_extra_msa : How many extra sequences to include for summary stats
    max_recycles : Number of recycling iterations
    model_id : Which AF2 model to use
    n_struct_module_repeats : Number of passes through structure module
    n_features_in : Unclear
    monomer : Predicting as a monomer (set to False if using AlphaFold-multimer)
    model_params : Which AF2 model config to use

    Returns
    ----------
    AlphaFold RunModel object

    """

    if model_id not in range(1, 6):
        logging.warning("model_id must be between 1 and 5!")
        if use_templates:
            model_id = random.randint(1, 2)
        else:
            model_id = random.randint(1, 5)

    # Match model_params to model_id
    # Sometimes we don't want to do this, for example,
    #   to reproduce output from ColabFold (which only uses models 1 and 3)

    name = f"model_{ model_params }_ptm"
    if not monomer:
        name = f"model_{ model_params }_multimer"

    cfg = config.model_config(name)

    #### Provide config settings

    #### MSAs

    cfg.data.eval.num_ensemble = 1
    if max_msa_clusters > 0:
        cfg.data.eval.max_msa_clusters = min(n_features_in, max_msa_clusters)
    if max_extra_msa > 0:
        cfg.data.common.max_extra_msa = max(
            1, min(n_features_in - max_msa_clusters, max_extra_msa)
        )

    #### Recycle and number of iterations

    if monomer:
        cfg.data.common.num_recycle = max_recycles
    cfg.model.num_recycle = max_recycles
    cfg.model.heads.structure_module.num_layer = n_struct_module_repeats

    #### Templates

    t = use_templates  # for brevity

    cfg.data.common.use_templates = use_templates
    cfg.model.embeddings_and_evoformer.template.embed_torsion_angles = t
    cfg.model.embeddings_and_evoformer.template.enabled = t
    cfg.data.common.reduce_msa_clusters_by_max_templates = t
    cfg.data.eval.subsample_templates = t

    p = data.get_model_haiku_params(model_name=name, data_dir=".")

    logging.debug("Prediction parameters:")
    logging.debug("\tModel ID: {}".format(model_id))
    logging.debug("\tUsing templates: {}".format(t))
    logging.debug(
        "\tMaximum MSA clusters: {}".format(cfg.data.eval.max_msa_clusters)
    )
    logging.debug(
        "\tMaximum extra MSA clusters: {}".format(
            cfg.data.common.max_extra_msa
        )
    )
    logging.debug(
        "\tNumber recycling iterations: {}".format(cfg.model.num_recycle)
    )
    logging.debug(
        "\tNumber of structure module repeats: {}".format(
            cfg.model.heads.structure_module.num_layer
        )
    )

    return model.RunModel(cfg, p)


def run_one_job(
    runner: model.RunModel, features_in: dict, random_seed: int, outname: str
) -> Mapping[str, Any]:
    r"""Runs one AF2 job with input parameters

    Parameters
    ----------
    runner : AlphaFold2 job runner
    features_in : Input features, including MSA and templates
    random_seed : Random seed
    outname : Name of PDB file to write

    Returns
    ----------
    None

    """

    # Do one last bit of processing
    features = runner.process_features(features_in, random_seed=random_seed)

    # Generate the model
    result = runner.predict(features, random_seed)
    pred = protein.from_prediction(features, result)

    # Write to file
    to_pdb(outname, pred, result["plddt"], features_in["residue_index"])

    return result


def predict_structure_from_templates(
    seq: str,
    outname: str,
    a3m_lines: str,
    template_path: str,
    model_id: int = -1,
    model_params: int = -1,
    random_seed: int = -1,
    max_msa_clusters: int = -1,
    max_extra_msa: int = -1,
    max_recycles: int = 3,
    n_struct_module_repeats: int = 8,
) -> NoReturn:

    r"""Predicts the structure.

    Parameters
    ----------
    seq : Sequence
    outname : Name of output PDB
    a3m_lines : String of entire alignment
    template_paths : Where to locate templates
    model_id : Which AF2 model to run (must be 1 or 2 for templates)
    model_params : Which parameters to provide to AF2 model
    random_seed : Random seed
    max_msa_clusters : Number of sequences to use
    max_extra_msa : Number of extra seqs for summary stats
    max_recycles : Number of iterations through AF2
    n_struct_module_repeats : Number of passes through structural refinement
    move_prefix : Prefix for temporary files (deleted after fxn completion)

    Returns
    ----------
    None

    """

    if random_seed == -1:
        random_seed = random.randrange(sys.maxsize)

    if model_id not in (1, 2):
        model_id = random.randint(1, 2)

    if model_params not in (1, 2):
        model_params = random.randint(1, 2)

    # Assemble the dictionary of input features
    features_in = util.setup_features(
        seq, a3m_lines, util.mk_template(seq, a3m_lines, template_path).features
    )

    # Run the models
    model_runner = set_config(
        True,
        max_msa_clusters,
        max_extra_msa,
        max_recycles,
        model_id,
        n_struct_module_repeats,
        len(features_in["msa"]),
        model_params=model_params,
    )

    result = run_one_job(model_runner, features_in, random_seed, outname)

    del model_runner

    return result


def predict_structure_no_templates(
    seq: str,
    outname: str,
    a3m_lines: str,
    model_id: int = -1,
    model_params: int = -1,
    random_seed: int = -1,
    max_msa_clusters: int = -1,
    max_extra_msa: int = -1,
    max_recycles: int = 3,
    n_struct_module_repeats: int = 8,
) -> NoReturn:

    r"""Predicts the structure.

    Parameters
    ----------
    seq : Sequence
    outname : Name of output PDB
    a3m_lines : String of entire alignment
    model_id : Which AF2 model to run (must be 1 or 2 for templates)
    random_seed : Random seed
    max_msa_clusters : Number of sequences to use
    max_extra_msa : Number of extra seqs for summary stats
    max_recycles : Number of iterations through AF2
    n_struct_module_repeats : Number of passes through structural refinement

    Returns
    ----------
    None

    """

    # Set AF2 model details
    if model_id not in range(1, 6):
        model_id = random.randint(1, 5)

    if model_params not in range(1, 6):
        model_params = model_id

    if random_seed == -1:
        random_seed = random.randrange(sys.maxsize)

    features_in = util.setup_features(seq, a3m_lines, util.mk_mock_template(seq))

    model_runner = set_config(
        False,
        max_msa_clusters,
        max_extra_msa,
        max_recycles,
        model_id,
        n_struct_module_repeats,
        len(features_in["msa"]),
        model_params=model_params,
    )

    result = run_one_job(model_runner, features_in, random_seed, outname)

    del model_runner

    return result


def to_pdb(
    outname, pred, plddts, res_idx  # type unknown but check?  # type unknown but check?
) -> NoReturn:

    r"""Writes unrelaxed PDB to file

    Parameters
    ----------
    outname : Name of output PDB
    pred : Prediction to write to PDB
    plddts : Predicted errors
    res_idx : Residues to print (default=all)

    Returns
    ----------
    None

    """

    with open(outname, "w") as outfile:
        outfile.write(protein.to_pdb(pred))

    with open(f"b_{ outname }", "w") as outfile:
        for line in open(outname, "r").readlines():
            if line[0:6] == "ATOM  ":
                seq_id = int(line[22:26].strip()) - 1
                seq_id = np.where(res_idx == seq_id)[0][0]
                outfile.write(
                    "{}A{}{:6.2f}{}".format(
                        line[:21], line[22:60], plddts[seq_id], line[66:]
                    )
                )

    os.rename(f"b_{ outname }", outname)
The diff you're trying to view is too large. Only the first 1000 changed files have been loaded.
Showing with 0 additions and 0 deletions (0 / 0 diffs computed)
swh spinner

Computing file changes ...

back to top

Software Heritage — Copyright (C) 2015–2026, The Software Heritage developers. License: GNU AGPLv3+.
The source code of Software Heritage itself is available on our development forge.
The source code files archived by Software Heritage are available under their own copyright and licenses.
Terms of use: Archive access, API— Content policy— Contact— JavaScript license information— Web API