https://github.com/phbradley/alphafold_finetune
Tip revision: af1f2f7507975ffc734ae57a928786e7f90f93b1 authored by phbradley on 03 December 2022, 14:42:01 UTC
add --data_dir argument to run_finetuning.py commands
changes_to_alphafold.txt
diff -b -u -r ../clean/alphafold/alphafold/common/residue_constants.py alphafold/common/residue_constants.py
--- ../clean/alphafold/alphafold/common/residue_constants.py 2021-12-31 13:40:56.246103000 -0800
+++ alphafold/common/residue_constants.py 2022-07-07 11:31:02.882022000 -0700
@@ -16,6 +16,7 @@
import collections
import functools
+import os
from typing import List, Mapping, Tuple
import numpy as np
@@ -402,8 +403,9 @@
residue_virtual_bonds: dict that maps resname --> list of Bond tuples
residue_bond_angles: dict that maps resname --> list of BondAngle tuples
"""
- stereo_chemical_props_path = (
- 'alphafold/common/stereo_chemical_props.txt')
+ stereo_chemical_props_path = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), 'stereo_chemical_props.txt'
+ )
with open(stereo_chemical_props_path, 'rt') as f:
stereo_chemical_props = f.read()
lines_iter = iter(stereo_chemical_props.splitlines())
Only in alphafold/common: stereo_chemical_props.txt
diff -b -u -r ../clean/alphafold/alphafold/model/config.py alphafold/model/config.py
--- ../clean/alphafold/alphafold/model/config.py 2021-12-31 13:40:56.338870000 -0800
+++ alphafold/model/config.py 2022-07-07 12:33:20.067939000 -0700
@@ -93,6 +93,8 @@
}
}
+
+
CONFIG = ml_collections.ConfigDict({
'data': {
'common': {
@@ -119,6 +121,7 @@
'use_templates': False,
},
'eval': {
+ 'crop_size': 256,
'feat': {
'aatype': [NUM_RES],
'all_atom_mask': [NUM_RES, None],
@@ -132,7 +135,9 @@
'atom14_gt_positions': [NUM_RES, None, None],
'atom37_atom_exists': [NUM_RES, None],
'backbone_affine_mask': [NUM_RES],
- 'backbone_affine_tensor': [NUM_RES, None],
+ 'backbone_affine_tensor': [NUM_RES, None], #not used
+ 'backbone_translation': [NUM_RES, None],
+ 'backbone_rotation': [NUM_RES, None],
'bert_mask': [NUM_MSA_SEQ, NUM_RES],
'chi_angles': [NUM_RES, None],
'chi_mask': [NUM_RES, None],
@@ -321,9 +326,10 @@
}
},
'global_config': {
+ 'mixed_precision': False,
'deterministic': False,
'subbatch_size': 4,
- 'use_remat': False,
+ 'use_remat': True,
'zero_init': True
},
'heads': {
@@ -350,7 +356,7 @@
'filter_by_resolution': True,
'max_resolution': 3.0,
'min_resolution': 0.1,
- 'weight': 0.01
+ 'weight': 0.00
},
'structure_module': {
'num_layer': 8,
@@ -379,7 +385,7 @@
'weight_frac': 0.5,
'length_scale': 10.,
},
- 'structural_violation_loss_weight': 1.0,
+ 'structural_violation_loss_weight': 0.0,
'violation_tolerance_factor': 12.0,
'weight': 1.0
},
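
Note on the config.py hunks above: the new entries are ordinary ml_collections ConfigDict fields, so they can also be adjusted per run instead of editing the file. A minimal sketch, assuming the stock model presets and the field paths shown in the hunks (the values here are illustrative only):

from alphafold.model import config

cfg = config.model_config('model_2')  # any stock preset
with cfg.unlocked():
    cfg.data.eval.crop_size = 256                        # fixed crop => static shapes under jit
    cfg.model.global_config.use_remat = True             # rematerialize to cut activation memory
    cfg.model.global_config.mixed_precision = False      # flag introduced by this patch
    cfg.model.heads.structure_module.structural_violation_loss_weight = 0.0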
diff -b -u -r ../clean/alphafold/alphafold/model/features.py alphafold/model/features.py
--- ../clean/alphafold/alphafold/model/features.py 2021-12-31 13:40:56.345586000 -0800
+++ alphafold/model/features.py 2022-07-11 08:42:56.737359000 -0700
@@ -35,8 +35,8 @@
if cfg.common.use_templates:
feature_names += cfg.common.template_features
- with cfg.unlocked():
- cfg.eval.crop_size = num_res
+ #with cfg.unlocked():
+ # cfg.eval.crop_size = num_res
return cfg, feature_names
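
The features.py change above stops overwriting eval.crop_size with each target's length, so every example is padded or cropped to the fixed value from the config (256 after this patch) and a jit-compiled training step sees one static shape. A minimal sketch of the resulting behaviour, assuming the surrounding make_data_config helper and an arbitrary num_res:

from alphafold.model import config, features

model_config = config.model_config('model_2')
data_cfg, feature_names = features.make_data_config(model_config, num_res=137)
assert data_cfg.eval.crop_size == 256   # was overwritten to 137 before the patch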
diff -b -u -r ../clean/alphafold/alphafold/model/folding.py alphafold/model/folding.py
--- ../clean/alphafold/alphafold/model/folding.py 2021-12-31 13:40:56.353568000 -0800
+++ alphafold/model/folding.py 2022-07-07 12:42:47.500375000 -0700
@@ -542,7 +542,6 @@
value.update(compute_renamed_ground_truth(
batch, value['final_atom14_positions']))
sc_loss = sidechain_loss(batch, value, self.config)
-
ret['loss'] = ((1 - self.config.sidechain.weight_frac) * ret['loss'] +
self.config.sidechain.weight_frac * sc_loss['loss'])
ret['sidechain_fape'] = sc_loss['fape']
@@ -632,8 +631,9 @@
affine_trajectory = quat_affine.QuatAffine.from_tensor(value['traj'])
rigid_trajectory = r3.rigids_from_quataffine(affine_trajectory)
- gt_affine = quat_affine.QuatAffine.from_tensor(
- batch['backbone_affine_tensor'])
+ #gt_affine = quat_affine.QuatAffine.from_tensor(
+ # batch['backbone_affine_tensor'])
+ gt_affine = quat_affine.QuatAffine(quaternion=None, translation=batch['backbone_translation'], rotation=batch['backbone_rotation'], unstack_inputs=True)
gt_rigid = r3.rigids_from_quataffine(gt_affine)
backbone_mask = batch['backbone_affine_mask']
@@ -754,6 +754,7 @@
residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types
]
+ atomtype_radius = np.array(atomtype_radius) # PB fix
atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
atomtype_radius, batch['residx_atom14_to_atom37'])
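
Two things happen in the folding.py hunks above: the ground-truth backbone frames are built directly from backbone_rotation / backbone_translation instead of re-deriving them from the packed backbone_affine_tensor, and atomtype_radius is converted to a NumPy array before the gather. The latter is presumably needed because utils.batched_gather ends up in jnp.take, which rejects a plain Python list on current JAX; an illustrative stand-alone example (values made up):

import numpy as np
import jax.numpy as jnp

radii_list = [1.7, 1.55, 1.52, 1.8]     # per-atom-type radii as a Python list
idx = jnp.array([[0, 3], [2, 1]])
# jnp.take(radii_list, idx)             # TypeError on current JAX: a list is not a valid array argument
radii = np.array(radii_list)            # the one-line 'PB fix' from the hunk above
print(jnp.take(radii, idx))             # works, shape (2, 2)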
diff -b -u -r ../clean/alphafold/alphafold/model/model.py alphafold/model/model.py
--- ../clean/alphafold/alphafold/model/model.py 2021-12-31 13:40:56.360988000 -0800
+++ alphafold/model/model.py 2022-07-11 08:43:49.033746000 -0700
@@ -26,6 +26,7 @@
import tensorflow.compat.v1 as tf
import tree
+from functools import partial
def get_confidence_metrics(
prediction_result: Mapping[str, Any]) -> Mapping[str, Any]:
@@ -54,17 +55,21 @@
self.config = config
self.params = params
- def _forward_fn(batch):
+ def _forward_fn(batch,
+ is_training=False,
+ compute_loss=False,
+ ensemble_representations=False):
model = modules.AlphaFold(self.config.model)
return model(
batch,
- is_training=False,
- compute_loss=False,
- ensemble_representations=True)
-
- self.apply = jax.jit(hk.transform(_forward_fn).apply)
- self.init = jax.jit(hk.transform(_forward_fn).init)
-
+ is_training=is_training,
+ compute_loss=compute_loss,
+ ensemble_representations=False)
+
+ self.apply = jax.jit(hk.transform(partial(_forward_fn, is_training=True, compute_loss=True)).apply)
+ self.init = jax.jit(hk.transform(partial(_forward_fn, is_training=True, compute_loss=False)).init)
+ self.apply_infer = jax.jit(hk.transform(partial(_forward_fn, is_training=False, compute_loss=True)).apply)
+ self.apply_predict = jax.jit(hk.transform(partial(_forward_fn, is_training=False, compute_loss=False)).apply)
def init_params(self, feat: features.FeatureDict, random_seed: int = 0):
"""Initializes the model parameters.
@@ -130,7 +135,7 @@
self.init_params(feat)
logging.info('Running predict with shape(feat) = %s',
tree.map_structure(lambda x: x.shape, feat))
- result = self.apply(self.params, jax.random.PRNGKey(0), feat)
+ result = self.apply_predict(self.params, jax.random.PRNGKey(0), feat) # was apply(
# This block is to ensure benchmark timings are accurate. Some blocking is
# already happening when computing get_confidence_metrics, and this ensures
# all outputs are blocked on.
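
The model.py hunks above split the single jitted forward pass into four variants of the same function via functools.partial: apply (is_training=True, compute_loss=True) for training steps, init (is_training=True, compute_loss=False) for parameter initialization, apply_infer (is_training=False, compute_loss=True) for evaluating the loss without dropout, and apply_predict (is_training=False, compute_loss=False), which predict() now calls. A rough usage sketch; the parameter path is a placeholder and feat stands for a processed FeatureDict from the usual pipeline:

import jax
from alphafold.model import config, data, model

cfg = config.model_config('model_2')
params = data.get_model_haiku_params('model_2', '/path/to/params')   # placeholder path
runner = model.RunModel(cfg, params)

key = jax.random.PRNGKey(0)
# feat = runner.process_features(raw_features, random_seed=0)        # raw_features from your data pipeline
outputs_and_loss = runner.apply(runner.params, key, feat)            # dropout on, loss computed
prediction = runner.apply_predict(runner.params, key, feat)          # dropout off, no loss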
diff -b -u -r ../clean/alphafold/alphafold/model/modules.py alphafold/model/modules.py
--- ../clean/alphafold/alphafold/model/modules.py 2021-12-31 13:40:56.368886000 -0800
+++ alphafold/model/modules.py 2022-07-08 11:36:13.436809000 -0700
@@ -30,7 +30,7 @@
import haiku as hk
import jax
import jax.numpy as jnp
-
+import jmp
def softmax_cross_entropy(logits, labels):
"""Computes softmax cross entropy given logits and one-hot class labels."""
@@ -146,15 +146,16 @@
num_ensemble = jnp.asarray(ensembled_batch['seq_length'].shape[0])
- if not ensemble_representations:
- assert ensembled_batch['seq_length'].shape[0] == 1
-
+# if not ensemble_representations:
+# print('ENSEMBLED_BATCH', ensembled_batch['seq_length'].shape[0], ensembled_batch['seq_length'])
+# assert ensembled_batch['seq_length'].shape[0] == 1
def slice_batch(i):
b = {k: v[i] for k, v in ensembled_batch.items()}
b.update(non_ensembled_batch)
return b
# Compute representations for each batch element and average.
+
evoformer_module = EmbeddingsAndEvoformer(
self.config.embeddings_and_evoformer, self.global_config)
batch0 = slice_batch(0)
@@ -331,7 +332,6 @@
else:
num_ensemble = batch_size
ensembled_batch = batch
-
non_ensembled_batch = jax.tree_map(lambda x: x, prev)
return impl(
@@ -357,17 +357,16 @@
# The value for each ensemble batch is the same, so arbitrarily taking
# 0-th.
num_iter = batch['num_iter_recycling'][0]
-
# Add insurance that we will not run more
# recyclings than the model is configured to run.
num_iter = jnp.minimum(num_iter, self.config.num_recycle)
else:
# Eval mode or tests: use the maximum number of iterations.
num_iter = self.config.num_recycle
-
body = lambda x: (x[0] + 1, # pylint: disable=g-long-lambda
get_prev(do_call(x[1], recycle_idx=x[0],
compute_loss=False)))
+
if hk.running_init():
# When initializing the Haiku module, run one iteration of the
# while_loop to initialize the Haiku modules used in `body`.
@@ -669,11 +668,19 @@
v = jnp.einsum('bka,ac->bkc', m_data, v_weights)
- q_avg = utils.mask_mean(q_mask, q_data, axis=1)
+ if self.global_config.mixed_precision:
+ big_n = 6e4
+ small_n = 7e-5
+ else:
+ big_n = 1e9
+ small_n = 1e-10
+
+ q_avg = utils.mask_mean(q_mask, q_data, axis=1, eps=small_n)
q = jnp.einsum('ba,ahc->bhc', q_avg, q_weights) * key_dim**(-0.5)
k = jnp.einsum('bka,ac->bkc', m_data, k_weights)
- bias = (1e9 * (q_mask[:, None, :, 0] - 1.))
+
+ bias = (big_n * (q_mask[:, None, :, 0] - 1.))
logits = jnp.einsum('bhc,bkc->bhk', q, k) + bias
weights = jax.nn.softmax(logits)
weighted_avg = jnp.einsum('bhk,bkc->bhc', weights, v)
@@ -743,7 +750,12 @@
assert len(msa_mask.shape) == 2
assert c.orientation == 'per_row'
- bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
+ if self.global_config.mixed_precision:
+ big_n = 6e4
+ else:
+ big_n = 1e9
+
+ bias = (big_n * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
msa_act = hk.LayerNorm(
@@ -810,7 +822,13 @@
msa_act = jnp.swapaxes(msa_act, -2, -3)
msa_mask = jnp.swapaxes(msa_mask, -1, -2)
- bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
+ if self.global_config.mixed_precision:
+ big_n = 6e4
+ else:
+ big_n = 1e9
+
+
+ bias = (big_n * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
msa_act = hk.LayerNorm(
@@ -922,7 +940,13 @@
pair_act = jnp.swapaxes(pair_act, -2, -3)
pair_mask = jnp.swapaxes(pair_mask, -1, -2)
- bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
+
+ if self.global_config.mixed_precision:
+ big_n = 6e4
+ else:
+ big_n = 1e9
+
+ bias = (big_n * (pair_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
pair_act = hk.LayerNorm(
@@ -1096,7 +1120,7 @@
& (batch['resolution'] <= self.config.max_resolution)).astype(
jnp.float32)
- output = {'loss': loss}
+ output = {'loss': loss, 'lddt_ca': lddt_ca, 'ca_mask': all_atom_mask[None, :, 1:2].astype(jnp.float32)}
return output
@@ -1145,8 +1169,9 @@
predicted_affine = quat_affine.QuatAffine.from_tensor(
value['structure_module']['final_affines'])
# Shape (num_res, 7)
- true_affine = quat_affine.QuatAffine.from_tensor(
- batch['backbone_affine_tensor'])
+ #true_affine = quat_affine.QuatAffine.from_tensor(
+ # batch['backbone_affine_tensor'])
+ true_affine = quat_affine.QuatAffine(quaternion=None, translation=batch['backbone_translation'], rotation=batch['backbone_rotation'], unstack_inputs=True)
# Shape (num_res)
mask = batch['backbone_affine_mask']
# Shape (num_res, num_res)
@@ -1487,7 +1512,7 @@
c.chunk_size,
batched_args=[left_act],
nonbatched_args=[],
- low_memory=True,
+ low_memory=not is_training,
input_subbatch_dim=1,
output_subbatch_dim=0)
@@ -1596,6 +1621,20 @@
safe_key, *sub_keys = safe_key.split(10)
sub_keys = iter(sub_keys)
+ if self.global_config.mixed_precision:
+ mp_string = 'p=f32,c=f16,o=f32'
+ else:
+ mp_string = 'p=f32,c=f32,o=f32'
+
+ get_policy = lambda: jmp.get_policy(mp_string)
+ mp_policy = get_policy()
+ hk.mixed_precision.set_policy(TriangleMultiplication, mp_policy)
+ hk.mixed_precision.set_policy(TriangleAttention, mp_policy)
+ hk.mixed_precision.set_policy(Transition, mp_policy)
+ hk.mixed_precision.set_policy(MSARowAttentionWithPairBias, mp_policy)
+ hk.mixed_precision.set_policy(MSAColumnAttention, mp_policy)
+ hk.mixed_precision.set_policy(MSAColumnGlobalAttention, mp_policy)
+ hk.mixed_precision.set_policy(OuterProductMean, mp_policy)
msa_act = dropout_wrapper_fn(
MSARowAttentionWithPairBias(
@@ -2069,7 +2108,12 @@
jnp.transpose(template_pair_representation, [1, 2, 0, 3]),
[num_res * num_res, num_templates, num_channels])
- bias = (1e9 * (template_mask[None, None, None, :] - 1.))
+ if self.global_config.mixed_precision:
+ big_n = 6e4
+ else:
+ big_n = 1e9
+
+ bias = (big_n * (template_mask[None, None, None, :] - 1.))
template_pointwise_attention_module = Attention(
self.config.attention, self.global_config, query_num_channels)
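
The modules.py changes above add optional mixed precision. When global_config.mixed_precision is set, a jmp policy built from the string 'p=f32,c=f16,o=f32' (parameters stored in float32, compute in float16, outputs cast back to float32) is attached to the Evoformer sub-modules, every 1e9 masking bias is capped at 6e4 because float16 overflows to inf above ~65504, and the mask_mean epsilon is raised to 7e-5, just above float16's smallest normal value (~6.1e-5). A self-contained sketch of the policy mechanism, using a stand-in module rather than the real AlphaFold blocks:

import haiku as hk
import jax
import jax.numpy as jnp
import jmp

policy = jmp.get_policy('p=f32,c=f16,o=f32')   # same string the patch builds when mixed_precision=True

class ToyTransition(hk.Module):
    # Stand-in for one of the Evoformer blocks the patch attaches the policy to.
    def __call__(self, x):
        # With the policy set, inputs are cast to f16 on entry and the module's
        # output is cast back to f32 on exit; parameters are kept in f32.
        return hk.Linear(x.shape[-1])(x)

hk.mixed_precision.set_policy(ToyTransition, policy)

@hk.transform
def fwd(x):
    return ToyTransition()(x)

x = jnp.ones((4, 8), jnp.float32)
params = fwd.init(jax.random.PRNGKey(0), x)
print(fwd.apply(params, None, x).dtype)        # float32, from the o=f32 part of the policy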