changes_to_alphafold.txt
diff -b -u -r ../clean/alphafold/alphafold/common/residue_constants.py alphafold/common/residue_constants.py
--- ../clean/alphafold/alphafold/common/residue_constants.py	2021-12-31 13:40:56.246103000 -0800
+++ alphafold/common/residue_constants.py	2022-07-07 11:31:02.882022000 -0700
@@ -16,6 +16,7 @@
 
 import collections
 import functools
+import os
 from typing import List, Mapping, Tuple
 
 import numpy as np
@@ -402,8 +403,9 @@
     residue_virtual_bonds: dict that maps resname --> list of Bond tuples
     residue_bond_angles: dict that maps resname --> list of BondAngle tuples
   """
-  stereo_chemical_props_path = (
-      'alphafold/common/stereo_chemical_props.txt')
+  stereo_chemical_props_path = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)), 'stereo_chemical_props.txt'
+  )
   with open(stereo_chemical_props_path, 'rt') as f:
     stereo_chemical_props = f.read()
   lines_iter = iter(stereo_chemical_props.splitlines())
Only in alphafold/common: stereo_chemical_props.txt
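
Note: the residue_constants.py hunk above makes load_stereo_chemical_props() locate stereo_chemical_props.txt relative to the module itself rather than the current working directory (hence the new "import os" and the copy of the data file noted by the "Only in" line). A minimal standalone sketch of the same pattern, assuming a data file shipped next to the module:

    import os

    # Resolve a data file that lives alongside this .py file, independent of cwd.
    _MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
    stereo_chemical_props_path = os.path.join(_MODULE_DIR, 'stereo_chemical_props.txt')

    with open(stereo_chemical_props_path, 'rt') as f:
        stereo_chemical_props = f.read()
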
diff -b -u -r ../clean/alphafold/alphafold/model/config.py alphafold/model/config.py
--- ../clean/alphafold/alphafold/model/config.py	2021-12-31 13:40:56.338870000 -0800
+++ alphafold/model/config.py	2022-07-07 12:33:20.067939000 -0700
@@ -93,6 +93,8 @@
     }
 }
 
+
+
 CONFIG = ml_collections.ConfigDict({
     'data': {
         'common': {
@@ -119,6 +121,7 @@
             'use_templates': False,
         },
         'eval': {
+            'crop_size': 256,
             'feat': {
                 'aatype': [NUM_RES],
                 'all_atom_mask': [NUM_RES, None],
@@ -132,7 +135,9 @@
                 'atom14_gt_positions': [NUM_RES, None, None],
                 'atom37_atom_exists': [NUM_RES, None],
                 'backbone_affine_mask': [NUM_RES],
-                'backbone_affine_tensor': [NUM_RES, None],
+                'backbone_affine_tensor': [NUM_RES, None], #not used
+                'backbone_translation': [NUM_RES, None],
+                'backbone_rotation': [NUM_RES, None],
                 'bert_mask': [NUM_MSA_SEQ, NUM_RES],
                 'chi_angles': [NUM_RES, None],
                 'chi_mask': [NUM_RES, None],
@@ -321,9 +326,10 @@
             }
         },
         'global_config': {
+            'mixed_precision': False,
             'deterministic': False,
             'subbatch_size': 4,
-            'use_remat': False,
+            'use_remat': True,
             'zero_init': True
         },
         'heads': {
@@ -350,7 +356,7 @@
                 'filter_by_resolution': True,
                 'max_resolution': 3.0,
                 'min_resolution': 0.1,
-                'weight': 0.01
+                'weight': 0.00
             },
             'structure_module': {
                 'num_layer': 8,
@@ -379,7 +385,7 @@
                     'weight_frac': 0.5,
                     'length_scale': 10.,
                 },
-                'structural_violation_loss_weight': 1.0,
+                'structural_violation_loss_weight': 0.0,
                 'violation_tolerance_factor': 12.0,
                 'weight': 1.0
             },
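
Note: the config.py hunks above (1) give the eval pipeline a fixed crop_size of 256, (2) register backbone_translation and backbone_rotation as eval features next to the now-unused backbone_affine_tensor, (3) add a global mixed_precision flag and turn on use_remat to trade compute for memory, and (4) zero out the experimentally_resolved head weight (0.01 -> 0.00) and the structural violation loss weight. A hedged sketch of setting these knobs through ml_collections once the patch is applied (the 'model_2_ptm' preset is just an example):

    from alphafold.model import config as af_config

    cfg = af_config.model_config('model_2_ptm')
    cfg.data.eval.crop_size = 256                        # fixed finetuning crop
    cfg.model.global_config.mixed_precision = True       # flag added by this patch
    cfg.model.global_config.use_remat = True             # rematerialization saves memory
    cfg.model.heads.experimentally_resolved.weight = 0.0
    cfg.model.heads.structure_module.structural_violation_loss_weight = 0.0
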
diff -b -u -r ../clean/alphafold/alphafold/model/features.py alphafold/model/features.py
--- ../clean/alphafold/alphafold/model/features.py	2021-12-31 13:40:56.345586000 -0800
+++ alphafold/model/features.py	2022-07-11 08:42:56.737359000 -0700
@@ -35,8 +35,8 @@
   if cfg.common.use_templates:
     feature_names += cfg.common.template_features
 
-  with cfg.unlocked():
-    cfg.eval.crop_size = num_res
+  #with cfg.unlocked():
+  #  cfg.eval.crop_size = num_res
 
   return cfg, feature_names
 
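Note: with the runtime override in make_data_config() commented out, eval.crop_size is no longer forced up to the full sequence length, so the 256 set in config.py actually takes effect and every example is cropped or padded to one fixed shape, as finetuning batches require. A hedged usage sketch, assuming make_data_config still takes the full model config plus num_res as in stock AlphaFold:

    from alphafold.model import config as af_config
    from alphafold.model import features

    model_cfg = af_config.model_config('model_2_ptm')
    data_cfg, feature_names = features.make_data_config(model_cfg, num_res=350)
    # Unpatched code would report 350 here; with this patch the configured 256 is kept.
    print(data_cfg.eval.crop_size)
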
diff -b -u -r ../clean/alphafold/alphafold/model/folding.py alphafold/model/folding.py
--- ../clean/alphafold/alphafold/model/folding.py	2021-12-31 13:40:56.353568000 -0800
+++ alphafold/model/folding.py	2022-07-07 12:42:47.500375000 -0700
@@ -542,7 +542,6 @@
       value.update(compute_renamed_ground_truth(
           batch, value['final_atom14_positions']))
     sc_loss = sidechain_loss(batch, value, self.config)
-
     ret['loss'] = ((1 - self.config.sidechain.weight_frac) * ret['loss'] +
                    self.config.sidechain.weight_frac * sc_loss['loss'])
     ret['sidechain_fape'] = sc_loss['fape']
@@ -632,8 +631,9 @@
   affine_trajectory = quat_affine.QuatAffine.from_tensor(value['traj'])
   rigid_trajectory = r3.rigids_from_quataffine(affine_trajectory)
 
-  gt_affine = quat_affine.QuatAffine.from_tensor(
-      batch['backbone_affine_tensor'])
+  #gt_affine = quat_affine.QuatAffine.from_tensor(
+  #    batch['backbone_affine_tensor'])
+  gt_affine = quat_affine.QuatAffine(quaternion=None, translation=batch['backbone_translation'], rotation=batch['backbone_rotation'], unstack_inputs=True)
   gt_rigid = r3.rigids_from_quataffine(gt_affine)
   backbone_mask = batch['backbone_affine_mask']
 
@@ -754,6 +754,7 @@
       residue_constants.van_der_waals_radius[name[0]]
       for name in residue_constants.atom_types
   ]
+  atomtype_radius = np.array(atomtype_radius) # PB fix
   atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
       atomtype_radius, batch['residx_atom14_to_atom37'])
 
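Note: the folding.py hunks above make backbone_loss() build the ground-truth frames from separate backbone_translation / backbone_rotation features instead of the packed 7-component backbone_affine_tensor, and convert atomtype_radius to a numpy array so utils.batched_gather accepts it. A hedged standalone sketch of the QuatAffine construction used there (shapes and values are illustrative):

    import numpy as np
    from alphafold.model import quat_affine, r3

    num_res = 8                                                     # illustrative
    backbone_translation = np.zeros((num_res, 3), dtype=np.float32)
    backbone_rotation = np.tile(np.eye(3, dtype=np.float32), (num_res, 1, 1))

    # Same call as the patched backbone_loss(): quaternion omitted, per-residue
    # rotation matrices and translations supplied directly and unstacked.
    gt_affine = quat_affine.QuatAffine(
        quaternion=None,
        translation=backbone_translation,
        rotation=backbone_rotation,
        unstack_inputs=True)
    gt_rigid = r3.rigids_from_quataffine(gt_affine)
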
diff -b -u -r ../clean/alphafold/alphafold/model/model.py alphafold/model/model.py
--- ../clean/alphafold/alphafold/model/model.py	2021-12-31 13:40:56.360988000 -0800
+++ alphafold/model/model.py	2022-07-11 08:43:49.033746000 -0700
@@ -26,6 +26,7 @@
 import tensorflow.compat.v1 as tf
 import tree
 
+from functools import partial
 
 def get_confidence_metrics(
     prediction_result: Mapping[str, Any]) -> Mapping[str, Any]:
@@ -54,17 +55,21 @@
     self.config = config
     self.params = params
 
-    def _forward_fn(batch):
+    def _forward_fn(batch,
+                    is_training=False,
+                    compute_loss=False,
+                    ensemble_representations=False):
       model = modules.AlphaFold(self.config.model)
       return model(
           batch,
-          is_training=False,
-          compute_loss=False,
-          ensemble_representations=True)
-
-    self.apply = jax.jit(hk.transform(_forward_fn).apply)
-    self.init = jax.jit(hk.transform(_forward_fn).init)
-
+          is_training=is_training,
+          compute_loss=compute_loss,
+          ensemble_representations=False)
+
+    self.apply = jax.jit(hk.transform(partial(_forward_fn, is_training=True, compute_loss=True)).apply)
+    self.init = jax.jit(hk.transform(partial(_forward_fn, is_training=True, compute_loss=False)).init)
+    self.apply_infer = jax.jit(hk.transform(partial(_forward_fn, is_training=False, compute_loss=True)).apply)
+    self.apply_predict = jax.jit(hk.transform(partial(_forward_fn, is_training=False, compute_loss=False)).apply)
   def init_params(self, feat: features.FeatureDict, random_seed: int = 0):
     """Initializes the model parameters.
 
@@ -130,7 +135,7 @@
     self.init_params(feat)
     logging.info('Running predict with shape(feat) = %s',
                  tree.map_structure(lambda x: x.shape, feat))
-    result = self.apply(self.params, jax.random.PRNGKey(0), feat)
+    result = self.apply_predict(self.params, jax.random.PRNGKey(0), feat) # was apply(
     # This block is to ensure benchmark timings are accurate. Some blocking is
     # already happening when computing get_confidence_metrics, and this ensures
     # all outputs are blocked on.
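
Note: the model.py hunks above parameterize _forward_fn over is_training / compute_loss (ensembling stays off) and expose four jitted transforms: apply (training, with loss), init (training, no loss), apply_infer (eval, with loss) and apply_predict (eval, no loss, now used by predict()). A hedged sketch of driving the training-mode apply with jax.value_and_grad, assuming the compute_loss=True forward returns an (outputs, loss) pair as in stock modules.AlphaFold, and that `runner` is a RunModel instance and `batch` a processed feature dict:

    import jax
    import jax.numpy as jnp

    def loss_fn(params, key, batch):
        # runner.apply is the is_training=True, compute_loss=True transform above.
        _, loss = runner.apply(params, key, batch)
        return jnp.asarray(loss).mean()

    loss, grads = jax.value_and_grad(loss_fn)(
        runner.params, jax.random.PRNGKey(0), batch)
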
diff -b -u -r ../clean/alphafold/alphafold/model/modules.py alphafold/model/modules.py
--- ../clean/alphafold/alphafold/model/modules.py	2021-12-31 13:40:56.368886000 -0800
+++ alphafold/model/modules.py	2022-07-08 11:36:13.436809000 -0700
@@ -30,7 +30,7 @@
 import haiku as hk
 import jax
 import jax.numpy as jnp
-
+import jmp
 
 def softmax_cross_entropy(logits, labels):
   """Computes softmax cross entropy given logits and one-hot class labels."""
@@ -146,15 +146,16 @@
 
     num_ensemble = jnp.asarray(ensembled_batch['seq_length'].shape[0])
 
-    if not ensemble_representations:
-      assert ensembled_batch['seq_length'].shape[0] == 1
-
+#    if not ensemble_representations:
+#      print('ENSEMBLED_BATCH', ensembled_batch['seq_length'].shape[0], ensembled_batch['seq_length'])
+#      assert ensembled_batch['seq_length'].shape[0] == 1
     def slice_batch(i):
       b = {k: v[i] for k, v in ensembled_batch.items()}
       b.update(non_ensembled_batch)
       return b
 
     # Compute representations for each batch element and average.
+
     evoformer_module = EmbeddingsAndEvoformer(
         self.config.embeddings_and_evoformer, self.global_config)
     batch0 = slice_batch(0)
@@ -331,7 +332,6 @@
       else:
         num_ensemble = batch_size
         ensembled_batch = batch
-
       non_ensembled_batch = jax.tree_map(lambda x: x, prev)
 
       return impl(
@@ -357,17 +357,16 @@
         # The value for each ensemble batch is the same, so arbitrarily taking
         # 0-th.
         num_iter = batch['num_iter_recycling'][0]
-
         # Add insurance that we will not run more
         # recyclings than the model is configured to run.
         num_iter = jnp.minimum(num_iter, self.config.num_recycle)
       else:
         # Eval mode or tests: use the maximum number of iterations.
         num_iter = self.config.num_recycle
-
       body = lambda x: (x[0] + 1,  # pylint: disable=g-long-lambda
                         get_prev(do_call(x[1], recycle_idx=x[0],
                                          compute_loss=False)))
+
       if hk.running_init():
         # When initializing the Haiku module, run one iteration of the
         # while_loop to initialize the Haiku modules used in `body`.
@@ -669,11 +668,19 @@
 
     v = jnp.einsum('bka,ac->bkc', m_data, v_weights)
 
-    q_avg = utils.mask_mean(q_mask, q_data, axis=1)
+    if self.global_config.mixed_precision:
+        big_n = 6e4
+        small_n = 7e-5
+    else:
+        big_n = 1e9
+        small_n = 1e-10
+
+    q_avg = utils.mask_mean(q_mask, q_data, axis=1, eps=small_n)
 
     q = jnp.einsum('ba,ahc->bhc', q_avg, q_weights) * key_dim**(-0.5)
     k = jnp.einsum('bka,ac->bkc', m_data, k_weights)
-    bias = (1e9 * (q_mask[:, None, :, 0] - 1.))
+
+    bias = (big_n * (q_mask[:, None, :, 0] - 1.))
     logits = jnp.einsum('bhc,bkc->bhk', q, k) + bias
     weights = jax.nn.softmax(logits)
     weighted_avg = jnp.einsum('bhk,bkc->bhc', weights, v)
@@ -743,7 +750,12 @@
     assert len(msa_mask.shape) == 2
     assert c.orientation == 'per_row'
 
-    bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
+    if self.global_config.mixed_precision:
+        big_n = 6e4
+    else:
+        big_n = 1e9
+
+    bias = (big_n * (msa_mask - 1.))[:, None, None, :]
     assert len(bias.shape) == 4
 
     msa_act = hk.LayerNorm(
@@ -810,7 +822,13 @@
     msa_act = jnp.swapaxes(msa_act, -2, -3)
     msa_mask = jnp.swapaxes(msa_mask, -1, -2)
 
-    bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
+    if self.global_config.mixed_precision:
+        big_n = 6e4
+    else:
+        big_n = 1e9
+
+
+    bias = (big_n * (msa_mask - 1.))[:, None, None, :]
     assert len(bias.shape) == 4
 
     msa_act = hk.LayerNorm(
@@ -922,7 +940,13 @@
       pair_act = jnp.swapaxes(pair_act, -2, -3)
       pair_mask = jnp.swapaxes(pair_mask, -1, -2)
 
-    bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
+
+    if self.global_config.mixed_precision:
+        big_n = 6e4
+    else:
+        big_n = 1e9
+
+    bias = (big_n * (pair_mask - 1.))[:, None, None, :]
     assert len(bias.shape) == 4
 
     pair_act = hk.LayerNorm(
@@ -1096,7 +1120,7 @@
                & (batch['resolution'] <= self.config.max_resolution)).astype(
                    jnp.float32)
 
-    output = {'loss': loss}
+    output = {'loss': loss, 'lddt_ca': lddt_ca, 'ca_mask': all_atom_mask[None, :, 1:2].astype(jnp.float32)}
     return output
 
 
@@ -1145,8 +1169,9 @@
     predicted_affine = quat_affine.QuatAffine.from_tensor(
         value['structure_module']['final_affines'])
     # Shape (num_res, 7)
-    true_affine = quat_affine.QuatAffine.from_tensor(
-        batch['backbone_affine_tensor'])
+    #true_affine = quat_affine.QuatAffine.from_tensor(
+    #    batch['backbone_affine_tensor'])
+    true_affine = quat_affine.QuatAffine(quaternion=None, translation=batch['backbone_translation'], rotation=batch['backbone_rotation'], unstack_inputs=True)
     # Shape (num_res)
     mask = batch['backbone_affine_mask']
     # Shape (num_res, num_res)
@@ -1487,7 +1512,7 @@
         c.chunk_size,
         batched_args=[left_act],
         nonbatched_args=[],
-        low_memory=True,
+        low_memory=not is_training,
         input_subbatch_dim=1,
         output_subbatch_dim=0)
 
@@ -1596,6 +1621,20 @@
 
     safe_key, *sub_keys = safe_key.split(10)
     sub_keys = iter(sub_keys)
+    if self.global_config.mixed_precision:
+        mp_string = 'p=f32,c=f16,o=f32'
+    else:
+        mp_string = 'p=f32,c=f32,o=f32'
+
+    get_policy = lambda: jmp.get_policy(mp_string)
+    mp_policy = get_policy()
+    hk.mixed_precision.set_policy(TriangleMultiplication, mp_policy)
+    hk.mixed_precision.set_policy(TriangleAttention, mp_policy)
+    hk.mixed_precision.set_policy(Transition, mp_policy)
+    hk.mixed_precision.set_policy(MSARowAttentionWithPairBias, mp_policy)
+    hk.mixed_precision.set_policy(MSAColumnAttention, mp_policy)
+    hk.mixed_precision.set_policy(MSAColumnGlobalAttention, mp_policy)
+    hk.mixed_precision.set_policy(OuterProductMean, mp_policy)
 
     msa_act = dropout_wrapper_fn(
         MSARowAttentionWithPairBias(
@@ -2069,7 +2108,12 @@
         jnp.transpose(template_pair_representation, [1, 2, 0, 3]),
         [num_res * num_res, num_templates, num_channels])
 
-    bias = (1e9 * (template_mask[None, None, None, :] - 1.))
+    if self.global_config.mixed_precision:
+        big_n = 6e4
+    else:
+        big_n = 1e9
+
+    bias = (big_n * (template_mask[None, None, None, :] - 1.))
 
     template_pointwise_attention_module = Attention(
         self.config.attention, self.global_config, query_num_channels)
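
Note: the modules.py hunks above add optional mixed-precision training. When global_config.mixed_precision is set, a jmp policy ('p=f32,c=f16,o=f32': float32 parameters and outputs, float16 compute) is attached to the Evoformer sub-modules via hk.mixed_precision.set_policy. Because float16 overflows above ~65504 and loses normal precision below ~6.1e-5, the attention masking bias is capped at 6e4 (instead of 1e9) and the mask_mean epsilon is raised to 7e-5 (instead of 1e-10) in that mode. The remaining hunks relax the single-ensemble assertion, return per-residue lddt_ca values and a CA mask from the predicted-LDDT loss, and only use low_memory sub-batching outside of training. A minimal standalone sketch of the jmp/Haiku pattern (hk.Linear stands in for the AlphaFold module classes named in the diff):

    import haiku as hk
    import jax
    import jax.numpy as jnp
    import jmp

    # float32 parameters and outputs, float16 compute (same policy string as above).
    mp_policy = jmp.get_policy('p=f32,c=f16,o=f32')

    # The diff attaches this policy to specific classes (TriangleAttention,
    # Transition, ...); here hk.Linear is used as a stand-in.
    hk.mixed_precision.set_policy(hk.Linear, mp_policy)

    forward = hk.transform(lambda x: hk.Linear(32)(x))
    x = jnp.ones((4, 16), jnp.float32)
    params = forward.init(jax.random.PRNGKey(0), x)   # parameters stored in f32
    y = forward.apply(params, None, x)                # compute in f16, y cast back to f32
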