https://github.com/crisrom002/collider-descriptors-deformation-learning
Tip revision: df11ecb7cff38bfd9e321d284b1a3c365ca3e388 authored by crisrom002 on 01 May 2024, 10:42:15 UTC
corrected README typo
corrected README typo
Tip revision: df11ecb
train.py
from typing import Optional
import pathlib
import argparse
import numpy as np
import src.models as models
import src.datasets as datasets
import src.losses as losses
import src.utils as utils
import torch
def _batch_to_device(batch, device):
    """Unpack a contact batch and move its tensors to `device`.

    The third entry (`s`) is returned untouched; only `q`, `z`, `x` and
    `x_target` are transferred, matching what every loop in `train` needs.
    """
    q, z, s, x, x_target = batch
    return q.to(device), z.to(device), s, x.to(device), x_target.to(device)


def train(contact_problem: str,
          log_path: pathlib.Path,
          num_workers: int = 4,
          learning_rate: float = 1e-3,
          epochs: int = 100,
          epochs_per_test: int = 1,
          epochs_early_stop: Optional[int] = None,
          tensorboard_log: bool = False):
    """Train a contact deformation model for a given contact interaction,
    logging trained model checkpoints and evaluation losses.

    Keyword arguments:
    contact_problem: Name of the contact problem to be trained.
                     Only "jelly3D_thingi3D" is configured here.
    log_path: Path to store the logging data of the different trainings.
              Per-epoch checkpoints go to `log_path / "checkpoints"`, the best
              model (lowest test loss so far) to `log_path / "model.pth"`, and
              the tensorboard log (if enabled) to `log_path` itself.
    num_workers: Number of worker processes for data loading.
                 May require adjustments for optimal results in different datasets & systems.
    learning_rate: Initial learning rate used during training. Default value of 1e-3.
    epochs: Number of training epochs. Maximum epochs in case of using early stopping.
    epochs_per_test: Number of consecutive training epochs required for a test error evaluation.
                     Must be at least 1.
    epochs_early_stop: Number of consecutive non improving epochs for early stopping.
                       If not provided, early stopping is not used.
    tensorboard_log: Record a tensorboard log with the train & test losses per epoch.

    Raises:
    ValueError: If `contact_problem` is not a configured problem, or if
                `epochs_per_test` is smaller than 1.
    """

    # Fail fast: a zero period would otherwise only surface as a
    # ZeroDivisionError after the (expensive) normalization passes.
    if epochs_per_test < 1:
        raise ValueError(
            f"epochs_per_test must be a positive integer, got {epochs_per_test}")

    seed = 0
    torch.manual_seed(seed)
    print("Torch Seed: ", seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Torch Device: ", device)

    ######

    if contact_problem == "jelly3D_thingi3D":

        path_data = pathlib.Path("./data/contacts/jelly3D_thingi3D/")
        path_object = pathlib.Path("./data/objects/jelly3D/")
        path_collider = pathlib.Path("./data/colliders/thingi3D/")

        dataset_shapes = datasets.ShapeDataset(path_collider)
        pattern = utils.getPattern(path_collider.parents[1])

        model = models.DefModel(path_object,
                                dataset_shapes,
                                patch_pattern = pattern,
                                grid_size_sdf = 50,
                                grid_size_descriptor = 0,
                                mode_frame = models.ModeFrame.RANDOM_TANGENT,
                                biased_frame = False,
                                samples_per_frame = 50,
                                frame_radius = 0.1,
                                patch_radius = 0.2,
                                max_distance = 0.15,
                                mask_factor = 0.3)

        invariance_weight = 0.2

        # Sample ids laid out on a 52x5x7x3x5 grid. The split is made along the
        # first axis: slices 0..29 are used for training, 30..51 for testing.
        # Each side is then flattened and randomly permuted.
        samples_grid_ids = torch.arange(52*5*7*3*5).reshape(52,5,7,3,5)

        data_slice_A = samples_grid_ids[0:30, :, :, :, :].reshape(-1)
        data_slice_A = data_slice_A[torch.randperm(data_slice_A.shape[0])]

        data_slice_B = samples_grid_ids[30:52, :, :, :, :].reshape(-1)
        data_slice_B = data_slice_B[torch.randperm(data_slice_B.shape[0])]

        data_slice_train = data_slice_A
        data_slice_test = data_slice_B

    else:
        # Previously this branch only printed a message and then crashed later
        # with an unrelated UnboundLocalError on `model`.
        raise ValueError(
            f"contact_problem does not exist! Got: {contact_problem!r}")

    model.to(device)

    loader_params_train = {'batch_size': 1024,
                           'shuffle': True,
                           'num_workers': num_workers}

    loader_params_test = {'batch_size': 5024,
                          'shuffle': True,
                          'num_workers': num_workers}

    batch_nodes_train = 2
    batch_nodes_test = 2

    dataset_train = datasets.ContactDataset(device,
                                            model,
                                            path_object,
                                            path_data,
                                            data_slice_train,
                                            near_nodes = True,
                                            batch_nodes = batch_nodes_train)

    dataset_test = datasets.ContactDataset(device,
                                           model,
                                           path_object,
                                           path_data,
                                           data_slice_test,
                                           near_nodes = True,
                                           batch_nodes = batch_nodes_test)

    dataset_train_loader = torch.utils.data.DataLoader(dataset_train, **loader_params_train)
    dataset_test_loader = torch.utils.data.DataLoader(dataset_test, **loader_params_test)

    # Number of batches per pass; used to average accumulated quantities.
    num_batches_train = len(dataset_train_loader)
    num_batches_test = len(dataset_test_loader)

    # Training
    #############################################################

    loss = losses.MSENLoss()
    loss.to(device)

    loss_invariance = losses.MSENLoss()
    loss_invariance.to(device)

    # Only the parameters of `model.collider_net` are optimized.
    optimizer = torch.optim.Adam(model.collider_net.parameters(), lr=learning_rate)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=2, factor=0.9)

    if epochs_early_stop is not None:
        stoper = losses.EarlyStopper(patience=epochs_early_stop, min_delta=0)

    checkpoint_path = log_path / "checkpoints"
    checkpoint_path.mkdir(parents=True, exist_ok=True)

    writer = None
    if tensorboard_log:
        # Imported lazily so tensorboard is only required when logging is requested.
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter(str(log_path.absolute()))

    try:
        # Normalize Network
        # Two passes over the training data: the first accumulates the mean,
        # the second the standard deviation, through the collider_net helpers.
        model.train()

        model.collider_net.clearMean()
        for batch in dataset_train_loader:
            q, z, s, x, x_target = _batch_to_device(batch, device)

            q_patch_relative, patch_shape, x_corr_local = model.getLocalTransform(q, z, s, x, x_target)
            model.collider_net.accumMean(q_patch_relative, patch_shape, x_corr_local)

        model.collider_net.divideMean(num_batches_train)

        model.collider_net.clearStd()
        for batch in dataset_train_loader:
            q, z, s, x, x_target = _batch_to_device(batch, device)

            q_patch_relative, patch_shape, x_corr_local = model.getLocalTransform(q, z, s, x, x_target)
            model.collider_net.accumStd(q_patch_relative, patch_shape, x_corr_local)

        model.collider_net.divideStd(num_batches_train)

        # Normalize Loss
        # Both losses accumulate their normalization from the predictions of
        # the (still untrained) model over one pass of the training data.
        model.train()

        loss.clearNormalization()
        loss_invariance.clearNormalization()
        for batch in dataset_train_loader:
            q, z, s, x, x_target = _batch_to_device(batch, device)

            r = model(q, z, s, x)
            # Deviation of each prediction from the mean over dim 1.
            r_diffs = r - torch.mean(r, dim=1)[:,None,:]

            loss.accumNormalization(r, x_target)
            loss_invariance.accumNormalization(r_diffs, 0.0*x_target)

        loss.divideNormalization(num_batches_train)
        loss_invariance.divideNormalization(num_batches_train)

        print("Loss mse_loss_norm: ", loss.mse_loss_norm)
        print("Loss invariance mse_loss_norm: ", loss_invariance.mse_loss_norm)

        # Step 0 is logged with a fixed train loss of 1.0 rather than a measured one.
        print(0, "Loss Train: ", 1.0)

        if writer is not None:
            writer.add_scalar('Loss Train', 1.0, 0)
            writer.add_scalar('Loss Train Target', 1.0, 0)
            writer.add_scalar('Loss Train Invariance', 1.0, 0)

        # Train loop
        model.train()

        loss_test_best = np.inf

        for i in range(epochs):

            # checkpoint-{i} holds the weights *before* the training pass of
            # epoch i (checkpoint-0 is the normalized, untrained network).
            # NOTE(review): the weights produced by the last training pass are
            # therefore never checkpointed or tested - confirm this is intended.
            print("Saving checkpoint epoch " + str(i))
            filename = checkpoint_path / f"checkpoint-{i}.pth"
            torch.save(model.collider_net.state_dict(), filename)

            if i % epochs_per_test == 0:

                loss_test_target = 0.0
                loss_test_invariance = 0.0
                loss_test = 0.0

                # NOTE(review): the model stays in train() mode during this
                # evaluation; confirm that is intended.
                with torch.no_grad():
                    for batch in dataset_test_loader:
                        q, z, s, x, x_target = _batch_to_device(batch, device)

                        r = model(q, z, s, x)
                        r_diffs = r - torch.mean(r, dim=1)[:,None,:]

                        # Dividing by the batch count makes the running sums
                        # below the per-batch average over the pass.
                        lo = loss(r, x_target) / num_batches_test
                        lor = loss_invariance(r_diffs, 0.0*x_target) / num_batches_test

                        l = (1.0-invariance_weight)*lo + invariance_weight*lor

                        loss_test_target += lo.item()
                        loss_test_invariance += lor.item()
                        loss_test += l.item()

                print(i, "Loss Test: ", loss_test)

                if writer is not None:
                    writer.add_scalar('Loss Test', loss_test, i)
                    writer.add_scalar('Loss Test Target', loss_test_target, i)
                    writer.add_scalar('Loss Test Invariance', loss_test_invariance, i)

                # Keep a copy of the best model seen so far (ties overwrite).
                if loss_test <= loss_test_best:
                    filename = log_path / "model.pth"
                    torch.save(model.collider_net.state_dict(), filename)
                    loss_test_best = loss_test

                if epochs_early_stop is not None and stoper.step(loss_test, epochs_per_test):
                    break

            loss_train_target = 0.0
            loss_train_invariance = 0.0
            loss_train = 0.0

            for batch in dataset_train_loader:
                q, z, s, x, x_target = _batch_to_device(batch, device)

                optimizer.zero_grad()

                r = model(q, z, s, x)
                r_diffs = r - torch.mean(r, dim=1)[:,None,:]

                lo = loss(r, x_target) / num_batches_train
                lor = loss_invariance(r_diffs, 0.0*x_target) / num_batches_train

                l = (1.0-invariance_weight)*lo + invariance_weight*lor

                l.backward()
                optimizer.step()

                loss_train_target += lo.item()
                loss_train_invariance += lor.item()
                loss_train += l.item()

            print(i+1, "Loss Train: ", loss_train)

            if writer is not None:
                writer.add_scalar('Loss Train', loss_train, i+1)
                writer.add_scalar('Loss Train Target', loss_train_target, i+1)
                writer.add_scalar('Loss Train Invariance', loss_train_invariance, i+1)

            # The learning rate schedule is driven by the training loss.
            scheduler.step(loss_train)

    finally:
        # Flush and release the event file even on early stop or error.
        if writer is not None:
            writer.close()
if __name__ == '__main__':

    # Command line entry point: every option maps one-to-one onto a keyword
    # argument of `train`, so the parsed namespace is forwarded unchanged.
    cli = argparse.ArgumentParser()

    # (flag, add_argument keyword options), in the order they are registered.
    option_table = (
        ('--contact_problem',
         dict(type=str, required=True,
              help='Name of the contact problem to be trained')),
        ('--log_path',
         dict(type=pathlib.Path, required=True,
              help='Path to store the logging data of the different trainings. '
                   'Model checkpoints (.pth) and tensorboard log are stored in this path')),
        ('--num_workers',
         dict(type=int, default=4,
              help='Number of worker processes for data loading. '
                   'May require adjustments for optimal results in different datasets & systems')),
        ('--learning_rate',
         dict(type=float, default=1e-3,
              help='Initial learning rate used during training. Default value of 1e-3')),
        ('--epochs',
         dict(type=int, default=100,
              help='Number of training epochs. Maximum epochs in case of using early stoping')),
        ('--epochs_per_test',
         dict(type=int, default=1,
              help='Number of consecutive training epochs required for a test error evaluation')),
        ('--epochs_early_stop',
         dict(type=int,
              help='Number of consecutive non improving epoch for early stoping')),
        ('--tensorboard_log',
         dict(action='store_true',
              help='Record a tensorboard log with the train & test losses per epoch')),
    )

    for flag, options in option_table:
        cli.add_argument(flag, **options)

    train(**vars(cli.parse_args()))