## gd2.py, from https://github.com/tiga1231/sgd2
## tip revision 13ca3d978626473e59fdddd641e457edc57208bc ("update OS note", authored by jack on 02 March 2022)
from utils import utils
from utils import vis as V
from utils.CrossingDetector import CrossingDetector


import criteria as C
import quality as Q

import time
import itertools

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

import pickle as pkl


        
        
def is_interactive():
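    ## heuristic: a script run as `python file.py` gives __main__ a __file__
    ## attribute, while notebooks and REPLs do not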
    import __main__ as main
    return not hasattr(main, '__file__')

if is_interactive():
    from tqdm.notebook import tqdm
    from IPython import display
else:
    from tqdm import tqdm
    display = None



## TODO: custom lr_scheduler with a linear slowdown on plateau

class GD2:
    def __init__(self, G, device='cpu'):
        self.G = G
        
        
        self.D, self.adj_sparse, self.k2i = utils.shortest_path(G)
        
        self.adj = torch.from_numpy((self.adj_sparse+self.adj_sparse.T).toarray())
        self.D = torch.from_numpy(self.D)
        self.i2k = {i:k for k,i in self.k2i.items()}
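        ## stress weights w_ij = 1/(d_ij^2 + eps); the commented lines below
        ## are alternative exponents left from experiments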
        self.W = 1/(self.D**2+1e-6)
#         self.W = 1/(self.D**1.3+1e-6)
#         self.W = 1/(self.D**0.5+1e-6)

        self.degrees = np.array([self.G.degree(self.i2k[i]) for i in range(len(self.G))])
        self.maxDegree = max(dict(self.G.degree).values())

        self.edge_indices = [(self.k2i[e0], self.k2i[e1]) for e0,e1 in self.G.edges]
        self.node_indices = range(len(self.G))
        self.node_index_pairs = np.c_[
            np.repeat(self.node_indices, len(self.G)),
            np.tile(self.node_indices, len(self.G))
        ]
        self.node_index_pairs = self.node_index_pairs[self.node_index_pairs[:,0]<self.node_index_pairs[:,1]]
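        ## e.g. for a 3-node graph this leaves [[0,1],[0,2],[1,2]]: all
        ## unordered index pairs i<j, later sampled for the stress and
        ## vertex resolution criteria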
        
        self.node_edge_pairs = list(itertools.product(self.node_indices, self.edge_indices))
        
        
        incident_edge_groups = [
            [(G.degree(k), self.k2i[k], self.k2i[n])
            for n in G.neighbors(k)]
            for k in G.nodes
        ]
        
        incident_edge_pairs = [
            [
                (i[0],)+i[1:]+j[1:] 
                for i,j in itertools.product(ieg, ieg) 
                if i<j
            ] 
            for ieg in incident_edge_groups
        ]
        self.incident_edge_pairs = sum(incident_edge_pairs, [])
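        ## each entry is a 5-tuple (deg(k), i_k, i_n1, i_k, i_n2): two edges
        ## sharing vertex k, used by the angular resolution criterion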


        ## init
#         self.pos = torch.randn(len(self.G.nodes), 2, device=device).requires_grad_(True)
#         self.pos = (0.5*len(self.G.nodes)**0.5) * torch.randn(len(self.G.nodes), 2, device=device)
        self.pos = (len(self.G.nodes)**0.5) * torch.randn(len(self.G.nodes), 2, device=device)
        self.pos = self.pos.requires_grad_(True)
#         self.pos = torch.randn(len(self.G.nodes), 2, device=device)
#         self.pos[:,0] *= 20
#         self.pos.requires_grad_(True)
        
        
        self.qualities_by_time = []
        self.i = 0
        self.runtime = 0
        self.last_time_eval = 0 ## runtime at the most recent quality evaluation
        self.last_time_vis = 0  ## runtime at the most recent visualization
        self.iters = []
        self.loss_curve = []
        self.sample_sizes = {}
        
        self.crossing_detector = CrossingDetector()
        self.crossing_detector_loss_fn = nn.BCELoss()
        self.crossing_pos_loss_fn = nn.BCELoss(reduction='sum')
#         self.crossing_detector_optimizer = optim.SGD(self.crossing_detector.parameters(), lr=0.1)
        self.crossing_detector_optimizer = optim.Adam(self.crossing_detector.parameters(), lr=0.01)
        ## filter out incident edge pairs
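        ## (e1<e2 deduplicates unordered pairs; len(set(e1+e2))==4 keeps only
        ## edge pairs with four distinct endpoints, i.e. non-incident pairs)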
        self.non_incident_edge_pairs = [
            [self.k2i[e1[0]], self.k2i[e1[1]], self.k2i[e2[0]], self.k2i[e2[1]]] 
            for e1,e2 in itertools.product(G.edges, G.edges) 
            if e1<e2 and len(set(e1+e2))==4
        ]
        self.device = device
        
        
    def grad_clamp(self, l, c, weight, optimizer, ref=1):
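        ## backprop one criterion's loss and rescale any per-node gradient whose
        ## norm exceeds weight*ref, storing the result in self.grads[c].
        ## note: currently unused; the per-criterion calls in optimize() are
        ## commented out in favor of a single combined backward pass.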
        optimizer.zero_grad()
        l.backward(retain_graph=True)
        grad = self.pos.grad.clone()
        grad_norm = grad.norm(dim=1)
        is_large = grad_norm > weight*ref
        grad[is_large] = grad[is_large] / grad_norm[is_large].view(-1,1) * weight*ref
        self.grads[c] = grad
                    
        
    def optimize(self,
        criteria_weights={'stress':1.0}, 
        sample_sizes={'stress':128},
        evaluate=None,
        evaluate_interval=None,
        evaluate_interval_unit='iter',
        max_iter=int(1e4),
        time_limit=7200,
        grad_clamp=4,
        vis_interval=100,
        vis_interval_unit='iter',
        clear_output=False,
        optimizer_kwargs=None,
        scheduler_kwargs=None,
        criteria_kwargs=None,         
    ):
        if criteria_kwargs is None:
            criteria_kwargs = {c:dict() for c in criteria_weights}
        self.sample_sizes = sample_sizes
        ## shortcut of object attributes
        G = self.G
        D, k2i = self.D, self.k2i
        i2k = self.i2k
        adj = self.adj
        W = self.W
        pos = self.pos
        degrees = self.degrees
        maxDegree = self.maxDegree
        device = self.device
        
        self.init_sampler(criteria_weights)
            
        ## measure runtime (note: t0 is reset at the top of every iteration below)
        t0 = time.time()

        
        ## define the optimizer
        if optimizer_kwargs is None:
            optimizer_kwargs = {}
        if optimizer_kwargs.get('mode', 'SGD') == 'SGD':
            optimizer_kwargs_default = dict(
                lr=1, 
                momentum=0.7, 
                nesterov=True
            )
            Optimizer = optim.SGD
        elif optimizer_kwargs.get('mode', 'SGD') == 'Adam':
            optimizer_kwargs_default = dict(lr=0.0001)
            Optimizer = optim.Adam
        else:
            raise ValueError(f"unsupported optimizer mode: {optimizer_kwargs['mode']}")
        for k,v in optimizer_kwargs.items():
            if k!='mode':
                optimizer_kwargs_default[k] = v
        optimizer = Optimizer([pos], **optimizer_kwargs_default)
        
        
        
#         patience = np.ceil(np.log2(len(G)+1))*100
#         if 'stress' in criteria_weights and sample_sizes['stress'] < 16:
#             patience += 100 * 16/sample_sizes['stress']
        ## a size-dependent patience was tried, but is overridden by the fixed value below
        # patience = np.ceil(np.log2(len(G)+1)) * 300 * max(1, 16/min(sample_sizes.values()))
        patience = 20000
        scheduler_kwargs_default = dict(
            factor=0.9, 
            patience=patience, 
            min_lr=1e-5, 
            verbose=True
        )
        if scheduler_kwargs is not None:
            for k,v in scheduler_kwargs.items():
                scheduler_kwargs_default[k] = v
        scheduler_kwargs = scheduler_kwargs_default
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

        iterBar = tqdm(range(max_iter))

        ## smoothed loss curve during training
        s = 0.5**(1/100) ## smoothing factor for the loss curve; this value sets its 'half-life' to 100 iterations
        weighted_sum_of_loss, total_weight = 0, 0
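        ## EMA recurrence used in the loop: sum_t = s*sum_{t-1} + loss_t and
        ## weight_t = s*weight_{t-1} + 1, so each curve point sum_t/weight_t
        ## is a bias-corrected exponential moving average of the loss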
        vr_target_dist, vr_target_weight = 1, 0
            
        ## start training
        for iter_index in iterBar:
            t0 = time.time()
            ## optimization
            loss = self.pos[0,0]*0 ## zero-valued dummy loss, keeps loss a tensor attached to pos
            self.grads = {}
            ref = 1
            for c, weight in criteria_weights.items():
                if callable(weight):
                    weight = weight(iter_index)
                    
                if weight == 0:
                    continue
                
                if c == 'stress':
                    sample = self.sample(c)
                    l = weight * C.stress(
                        pos, D, W, 
                        sample=sample, reduce='mean')
                    loss += l
#                     self.grad_clamp(l, c, weight, optimizer, ref)

                    
                    
                elif c == 'ideal_edge_length':
                    sample = self.sample(c)
                    l = weight * C.ideal_edge_length(
                        pos, G, k2i, targetLengths=None, 
                        sampleSize=sample_sizes[c],
                        reduce='mean',
                    )
                    loss += l
#                     self.grad_clamp(l, c, weight, optimizer, ref)

                    

                elif c == 'neighborhood_preservation':
                    sample = self.sample(c).tolist()
                    l = weight * C.neighborhood_preseration(  ## sic: spelling follows the function name in criteria
                        pos, G, adj, 
                        k2i, i2k, 
                        degrees, maxDegree,
                        sample=sample,
#                         n_roots=sample_sizes[c], 
                        depth_limit=2
                    )
                    loss += l
#                     self.grad_clamp(l, c, weight, optimizer, ref)


                elif c == 'crossings':

                    ## neural crossing detector
                    sample = self.sample(c)
                    sample = torch.stack(sample, 1)
                    edge_pair_pos = self.pos[sample].view(-1,8)
                    labels = utils.are_edge_pairs_crossed(edge_pair_pos)
                    
#                     edge_pair_pos = self.pos[sample].view(-1,8)
#                     labels = utils.are_edge_pairs_crossed(edge_pair_pos)
#                     if labels.sum() < 1:
#                         sample = torch.cat([sample, self.sample_crossings(c)], dim=0)
#                         edge_pair_pos = self.pos[sample].view(-1,8)
#                         labels = utils.are_edge_pairs_crossed(edge_pair_pos)   
                        
                        
                    ## train crossing detector
                    self.crossing_detector.train()
                    for _ in range(2):
                        preds = self.crossing_detector(edge_pair_pos.detach().to(device)).view(-1)
                        loss_nn = self.crossing_detector_loss_fn(
                            preds, 
                            (
#                                 labels.float()*0.7+0.15 + 0.10*(2*torch.rand(labels.shape)-1)
                                labels.float()
                            ).to(device)
                        )
                        self.crossing_detector_optimizer.zero_grad()
                        loss_nn.backward()
                        self.crossing_detector_optimizer.step()
                    
                    ## crossing loss: target 0 ('no crossing') for every sampled edge pair
                    self.crossing_detector.eval()
                    preds = self.crossing_detector(edge_pair_pos.to(device)).view(-1)
                    l = weight * self.crossing_pos_loss_fn(preds, (labels.float()*0).to(device))
                    loss += l
#                     self.grad_clamp(l, c, weight, optimizer, ref)
                    
                elif c == 'crossing_angle_maximization':
                    sample = self.sample(c)
                    sample = torch.stack(sample, dim=-1)
                    pos_segs = pos[sample.flatten()].view(-1,4,2)
                    sample_labels = utils.are_edge_pairs_crossed(pos_segs.view(-1,8))
                    if sample_labels.sum() < 1:
                        sample = self.sample_crossings(c)
                        pos_segs = pos[sample.flatten()].view(-1,4,2)
                        sample_labels = utils.are_edge_pairs_crossed(pos_segs.view(-1,8))
                    
                    l = weight * C.crossing_angle_maximization(
                        pos, G, k2i, i2k,
                        sample = sample,
                        sample_labels = sample_labels,
#                         sampleSize=sample_sizes[c], 
#                         sampleOn='crossings') ## SLOW for large sample size
#                         sampleOn='edges'
                    )
                    loss += l
#                     self.grad_clamp(l, c, weight, optimizer, ref)


                elif c == 'aspect_ratio':
                    sample = self.sample(c)
                    l = weight * C.aspect_ratio(
                        pos, sample=sample,
                        **criteria_kwargs.get(c, {})
#                         sampleSize=sample_sizes[c],
                    )
                    loss += l
#                     self.grad_clamp(l, c, weight, optimizer, ref)


                elif c == 'angular_resolution':
#                     sample = [self.i2k[i.item()] for i in self.sample(c)]
                    sample = self.sample(c)
                    l = weight * C.angular_resolution(
                        pos, G, k2i, 
                        sampleSize=sample_sizes[c],
                        sample=sample,
                    )
                    
                    loss += l
#                     self.grad_clamp(l, c, weight, optimizer, ref)


                elif c == 'vertex_resolution':
                    sample = self.sample(c)
                    l, vr_target_dist, vr_target_weight = C.vertex_resolution(
                        pos, 
                        sample=sample, 
                        target=1/len(G)**0.5, 
                        prev_target_dist=vr_target_dist,
                        prev_weight=vr_target_weight
                    )
                    l = weight * l
                    loss += l
#                     self.grad_clamp(l, c, weight, optimizer, ref)


                elif c == 'gabriel':
                    sample = self.sample(c)
                    l = weight * C.gabriel(
                        pos, G, k2i, 
                        sample=sample,
#                         sampleSize=sample_sizes[c],
                    )
                    loss += l
#                     self.grad_clamp(l, c, weight, optimizer, ref)


                else:
                    print(f'Criterion not supported: {c}')
#             if len(self.grads) > 0:
#                 pos.grad = sum(g for c,g in self.grads.items())
#                 pos.grad.clamp_(-grad_clamp, grad_clamp)
#                 optimizer.step()
#                 ref = pos.grad.norm(dim=1).max()

            optimizer.zero_grad()
            loss.backward()
            pos.grad.clamp_(-grad_clamp, grad_clamp)
            optimizer.step()
            
            self.runtime += time.time() - t0
            if self.runtime > time_limit:
                qualities = self.evaluate(qualities=evaluate)
                self.qualities_by_time.append(dict(
                    time=self.runtime,
                    iter=self.i,
                    qualities=qualities
                ))
                break
            
            
            if (vis_interval is not None and vis_interval>0) and (
                ## option 1: if unit='iter'
                (vis_interval_unit == 'iter' and (self.i%vis_interval == 0 or self.i == max_iter-1)) 
                or  
                ## option 2: if unit='sec'
                (vis_interval_unit == 'sec' and (
                    self.runtime - self.last_time_vis >= vis_interval
                ))
            ):
#             if vis_interval is not None and vis_interval>0 \
#             and self.i%vis_interval==vis_interval-1:
                pos_numpy = pos.detach().cpu().numpy()
                pos_G = {k:pos_numpy[k2i[k]] for k in G.nodes}
                if display is not None and clear_output:
                    display.clear_output(wait=True)
                V.plot(
                    G, pos_G,
                    self.loss_curve, 
                    self.i, self.runtime, 
                    node_size=0,
                    edge_width=0.6,
                    show=True, save=False
                )
                self.last_time_vis = self.runtime


#             if loss.isnan():
#                 raise Exception('loss is nan')
            if pos.isnan().any():
                raise Exception('pos is nan')

            if self.i % 100 == 99:
                iterBar.set_postfix({'loss': loss.item(), })    
            
            ## track the loss as an exponential moving average, used for lr scheduling
            ## and for the loss curve visualization
            weighted_sum_of_loss = weighted_sum_of_loss*s + loss.item()
            total_weight = total_weight*s+1
            if self.i % 10 == 0:
                self.loss_curve.append(weighted_sum_of_loss/total_weight)
                self.iters.append(self.i)
                if scheduler is not None:
                    scheduler.step(self.loss_curve[-1])


            lr = optimizer.param_groups[0]['lr']
            if lr <= scheduler.min_lrs[0]:
                break

                
            ## if evaluation is enabled
            if (evaluate_interval is not None and evaluate_interval>0) and (
                ## option 1: if unit='iter'
                (evaluate_interval_unit == 'iter' and (self.i%evaluate_interval == 0 or self.i == max_iter-1)) 
                or  
                ## option 2: if unit='sec'
                (evaluate_interval_unit == 'sec' and (
                    self.runtime - self.last_time_eval >= evaluate_interval or self.i == max_iter-1
                ))
            ):
                     
                qualities = self.evaluate(qualities=evaluate)
                self.qualities_by_time.append(dict(
                    time=self.runtime,
                    iter=self.i,
                    qualities=qualities
                ))
                self.last_time_eval = self.runtime
                
            self.i+=1
            
        
        ## attach pos to G.nodes        
        pos_numpy = pos.detach().cpu().numpy()
        for k in G.nodes:
            G.nodes[k]['pos'] = pos_numpy[k2i[k],:]
        
        ## prepare result
        return self.get_result_dict(evaluate, sample_sizes)
    
    
    def init_sampler(self, criteria_weights):
        self.samplers = {}
        self.dataloaders = {}
        for c,w in criteria_weights.items():
            if w == 0:
                continue
            if c == 'stress':
                self.dataloaders[c] = DataLoader(
                    self.node_index_pairs, 
                    batch_size=self.sample_sizes[c],
                    shuffle=True)
            elif c == 'ideal_edge_length':
                self.dataloaders[c] = DataLoader(
                    self.edge_indices, 
                    batch_size=self.sample_sizes[c], 
                    shuffle=True)
            elif c == 'neighborhood_preservation':
                self.dataloaders[c] = DataLoader(
                    range(len(self.G.nodes)), 
                    batch_size=self.sample_sizes[c], 
                    shuffle=True)
            elif c == 'crossings':
                self.dataloaders[c] = DataLoader(
                    self.non_incident_edge_pairs, 
                    batch_size=self.sample_sizes[c], 
                    shuffle=True)
            elif c == 'crossing_angle_maximization':
                self.dataloaders[c] = DataLoader(
                    self.non_incident_edge_pairs, 
                    batch_size=self.sample_sizes[c], 
                    shuffle=True)
            elif c == 'aspect_ratio':
                self.dataloaders[c] = DataLoader(
                    range(len(self.G.nodes)), 
                    batch_size=self.sample_sizes[c],
                    shuffle=True
                )
            elif c == 'angular_resolution':
#                 self.dataloaders[c] = DataLoader(
#                     range(len(self.G.nodes)), 
#                     batch_size=self.sample_sizes[c],
#                     shuffle=True)
                self.dataloaders[c] = DataLoader(
                    self.incident_edge_pairs, 
                    batch_size=self.sample_sizes[c],
                    shuffle=True)
            elif c == 'vertex_resolution':
                self.dataloaders[c] = DataLoader(
                    self.node_index_pairs, 
                    batch_size=self.sample_sizes[c],
                    shuffle=True)
            elif c == 'gabriel':
                self.dataloaders[c] = DataLoader(
                    self.node_edge_pairs, 
                    batch_size=self.sample_sizes[c],
                    shuffle=True
                )
    
    def sample_crossings(self, c='crossing_angle_maximization', mode='use_existing_crossings'):
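        ## lazily builds a DataLoader over the edge pairs that currently cross
        ## in the layout; when the loader is exhausted (or mode='new'), the
        ## crossings are recomputed from the current positions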
        if not hasattr(self, 'crossing_loaders'):
            self.crossing_loaders = {}    
        if not hasattr(self, 'crossing_samplers'):
            self.crossing_samplers = {}
        
        if mode == 'new' or c not in self.crossing_loaders:
            crossing_segs = utils.find_crossings(self.pos, list(self.G.edges), self.k2i)
            self.crossing_loaders[c] = DataLoader(crossing_segs, batch_size=self.sample_sizes[c])
            self.crossing_samplers[c] = iter(self.crossing_loaders[c])
#             print(f'finding new crossings...{crossing_segs.shape}')
        try:
            sample = next(self.crossing_samplers[c])
        except StopIteration:
            sample = self.sample_crossings(c, mode='new')
        return sample

    
    def sample(self, criterion):
        if criterion not in self.samplers:
            self.samplers[criterion] = iter(self.dataloaders[criterion])
        try:
            sample = next(self.samplers[criterion])
        except StopIteration:
            self.samplers[criterion] = iter(self.dataloaders[criterion])
            sample = next(self.samplers[criterion])
        return sample
        
    
    def get_result_dict(self, evaluate, sample_sizes):
        return dict(
            iter=self.i, 
            loss_curve=self.loss_curve,
            runtime=self.runtime,
            qualities_by_time=self.qualities_by_time,
            qualities=self.qualities_by_time[-1]['qualities'],
            sample_sizes=self.sample_sizes,
            pos=self.pos,
        )
    
    
    def evaluate(
        self,
        pos=None,
        qualities={'stress'},
        verbose=False,
        mode='original'
    ):
        
        if pos is None:
            pos = self.pos
            
        qualityMeasures = dict()
        if qualities == 'all':
            qualities = {
                'stress',
                'ideal_edge_length',
                'neighborhood_preservation',
                'crossings',
                'crossing_angle_maximization',
                'aspect_ratio',
                'angular_resolution',
                'vertex_resolution',
                'gabriel',
            }


        for q in qualities:
            if verbose:
                print(f'Evaluating {q}...', end='')
                
            t0 = time.time()
            if q == 'stress':
                if mode == 'original':
                    qualityMeasures[q] = Q.stress(pos, self.D, self.W, None)
                elif mode == 'best_scale':
                    s = utils.best_scale_stress(pos, self.D, self.W)
                    qualityMeasures[q] = Q.stress(pos*s, self.D, self.W, None)
                    
            elif q == 'ideal_edge_length':
                if mode == 'original':
                    qualityMeasures[q] = Q.ideal_edge_length(pos, self.G, self.k2i)
                elif mode == 'best_scale':
                    s = utils.best_scale_ideal_edge_length(pos, self.G, self.k2i)
                    qualityMeasures[q] = Q.ideal_edge_length(s*pos, self.G, self.k2i)

            elif q == 'neighborhood_preservation':
                qualityMeasures[q] = 1 - Q.neighborhood_preservation(
                    pos, self.G, self.adj, self.i2k)
            elif q == 'crossings':
                qualityMeasures[q] = Q.crossings(pos, self.edge_indices)
            elif q == 'crossing_angle_maximization':
                qualityMeasures[q] = Q.crossing_angle_maximization(
                    pos, self.G.edges, self.k2i)
            elif q == 'aspect_ratio':
                qualityMeasures[q] = 1 - Q.aspect_ratio(pos)
            elif q == 'angular_resolution':
                qualityMeasures[q] = 1 - Q.angular_resolution(pos, self.G, self.k2i)
            elif q == 'vertex_resolution':
                qualityMeasures[q] = 1 - Q.vertex_resolution(pos, target=1/len(self.G)**0.5)
            elif q == 'gabriel':
                qualityMeasures[q] = 1 - Q.gabriel(pos, self.G, self.k2i)
           
            if verbose:
                print(f'done in {time.time()-t0:.2f}s')
        return qualityMeasures
    
    def save(self, fn='result.pkl'):
        with open(fn, 'wb') as f:
            pkl.dump(dict(
                G=self.G,
                pos=self.pos,
                i2k=self.i2k,
                k2i=self.k2i,
                iter=self.i,
                runtime=self.runtime,
                loss_curve=self.loss_curve,
                qualities_by_time=self.qualities_by_time,
            ), f)
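

## ---------------------------------------------------------------------------
## Minimal usage sketch (not part of the original module). It assumes a
## networkx graph, since the constructor feeds G to utils.shortest_path and
## uses G.degree/G.edges above; criterion names and arguments follow
## optimize(). Treat it as an illustration, not the authors' canonical driver.
if __name__ == '__main__':
    import networkx as nx

    G = nx.grid_2d_graph(6, 6)  ## any connected networkx graph
    gd2 = GD2(G, device='cpu')
    result = gd2.optimize(
        criteria_weights={'stress': 1.0, 'ideal_edge_length': 0.05},
        sample_sizes={'stress': 128, 'ideal_edge_length': 64},
        evaluate={'stress'},
        evaluate_interval=100,
        evaluate_interval_unit='iter',
        max_iter=2000,
        vis_interval=-1,  ## <=0 disables plotting, e.g. for headless runs
        optimizer_kwargs={'mode': 'SGD', 'lr': 1},
        scheduler_kwargs={'patience': 1000},
    )
    print(result['qualities'])
    gd2.save('result.pkl')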
        