https://github.com/AstraZeneca/dpp_imp
Raw File
Tip revision: baa623a46174c477c9556112340e9fe5db66955b authored by evolu8 on 14 November 2023, 12:22:36 UTC
Update README.md
Tip revision: baa623a
patcher.py
from .sketchers import *
from .utils import *
from math import ceil, floor
import numpy as np
import inspect


class Patcher:
    """Draw ``n_patches`` (rows, columns) index patches using a sketching sampler.

    Rows are sampled batch by batch: the data is split into batches by a
    ``StratifiedBatches`` helper, the sampler picks ``n_samples_rows`` indices
    inside every batch, and the per-batch picks are concatenated into one row
    index vector per patch. Columns are optionally sub-sampled as well.

    The sampler is used as
    ``sampler(X=..., y=..., k=..., **sampler_args).sketch_idxs(n_patches)``.
    """

    def __init__(self,
                 sampler,
                 n_patches=100,
                 subsample=None,
                 colsample=1,
                 batch_size=None,
                 k_max=None,
                 verbose=0,
                 **kwargs):
        """Store the sampling configuration.

        Parameters
        ----------
        sampler : class or instance
            Either a sampler class, or a sampler instance. For an instance its
            class is re-instantiated later with the keyword arguments found in
            ``sampler.args``.
        n_patches : int
            Number of patches to produce.
        subsample : float or None
            Fraction of rows per patch; resolved in :meth:`init` when falsy.
        colsample : float
            Fraction of columns per patch; columns are only sampled when < 1.
        batch_size : int or None
            Rows per batch; resolved in :meth:`init` when falsy.
        k_max : int or None
            Upper bound on rows sampled per batch; resolved in :meth:`init`
            when falsy.
        verbose, **kwargs
            Accepted for signature compatibility; not used by this class.
        """
        if inspect.isclass(sampler):
            self.sampler = sampler
            self.sampler_args = {}
        else:
            self.sampler = sampler.__class__
            self.sampler_args = sampler.args

        self.n_patches = n_patches
        self.subsample = subsample
        self.colsample = colsample
        self.batch_size = batch_size
        self.k_max = k_max

    def init(self, X, y, y_stratify=None):
        """Resolve batch/sample sizes for ``X`` and build the batch helper.

        Note that the resolved ``k_max``, ``subsample``, ``batch_size`` and
        ``colsample`` are written back onto the instance, so they persist
        across subsequent calls.

        Parameters
        ----------
        X : array with a 2-tuple ``shape`` ``(n, d)``.
        y : targets, used for stratification when ``y_stratify`` is None.
        y_stratify : optional labels used for stratification instead of ``y``.
        """
        n, d = X.shape

        if not self.k_max:
            self.k_max = d

        if not self.subsample:
            # With a fixed batch size, aim for k_max rows out of every batch;
            # otherwise fall back to sampling half of the rows.
            self.subsample = self.k_max/self.batch_size if self.batch_size else 0.5

        self.total_rows = max(int(n*self.subsample), 2)

        if not self.batch_size:
            self.batch_size = min(int(self.k_max/self.subsample), n)
            self.n_samples_rows = self.k_max
        else:
            self.n_samples_rows = min(int(self.batch_size*self.subsample), self.k_max)

        self.n_batches = ceil(self.total_rows/self.n_samples_rows)

        # Compare against None explicitly: ``if y_stratify:`` raises ValueError
        # for NumPy arrays with more than one element.
        stratify_on = y_stratify if y_stratify is not None else y
        self.stf_batches = StratifiedBatches(X, stratify_on, self.batch_size)

        self.n_batches_per_epoch = self.stf_batches.n_batches
        self.n_epochs = ceil(self.n_batches/floor(n/self.batch_size))

        if self.n_epochs == 1:
            # A single pass over the data: use every batch of that pass.
            self.n_batches = self.n_batches_per_epoch

        if self.colsample < 1:
            self.colsample = min(self.colsample,
                                 (self.batch_size*self.n_batches*self.colsample)/d)
            self.n_samples_cols = int(self.colsample*d)

    def _sketch_rows(self, X, y):
        """Sample row indices for every patch into ``self._lst_row_samples``.

        Requires the attributes set by :meth:`init`. The sampler output is
        indexed as (patch, batch, sample); a 2-D output is treated as a single
        patch.
        """
        # NOTE(review): presumably a (batch, row-in-batch) array of row indices
        # into X - confirm against StratifiedBatches.
        batch_idxs = self.stf_batches.shuffle_and_stratify(self.n_epochs)
        batch_idxs = batch_idxs[:self.n_batches]

        batched_X = X[batch_idxs]
        batched_y = y[batch_idxs] if y is not None else None

        row_idxs = self.sampler(X=batched_X,
                                y=batched_y,
                                k=self.n_samples_rows,
                                **self.sampler_args).sketch_idxs(self.n_patches)

        row_idxs = np.array(row_idxs)
        if row_idxs.ndim < 3:
            row_idxs = row_idxs[np.newaxis, :, :]

        # Shift the within-batch indices of batch b by the start offset of that
        # batch inside its epoch.
        shifted = np.array([
            row_idxs[:, b, :] + (b % self.n_batches_per_epoch)*self.batch_size
            for b in range(self.n_batches)
        ])
        # (batch, patch, sample) -> (patch, batch, sample)
        shifted = np.transpose(shifted, (1, 0, 2))

        self._lst_row_samples = [shifted[p].flatten()[:self.total_rows]
                                 for p in range(self.n_patches)]

    def _sketch_columns(self, X):
        """Sample ``n_samples_cols`` column indices per patch.

        The sampler is run on ``X.T`` and the first entry along the second axis
        of its output is kept for every patch.
        """
        sketcher = self.sampler(X=X.T,
                                k=self.n_samples_cols,
                                **self.sampler_args)
        self._lst_col_samples = sketcher.sketch_idxs(self.n_patches)[:, 0, :]

    def sketch_patches_idxs(self, X, y=None, y_stratify=None):
        """Return a list of ``(rows, cols)`` index pairs, one per patch.

        ``cols`` is ``Ellipsis`` (all columns) unless ``colsample < 1``.
        """
        self.init(X, y, y_stratify)
        self._sketch_rows(X, y)

        if self.colsample < 1:
            self._sketch_columns(X)
            return list(zip(self._lst_row_samples, self._lst_col_samples))

        return [(rows, ...) for rows in self._lst_row_samples]


class Patcher2:
    """Patch sampler that draws every patch with an independent sampler run.

    Wraps a single-patch ``Patcher`` (built with ``n_patches=1``) and invokes
    it once per requested patch.
    """

    def __init__(self,
                 sampler,
                 n_patches=100,
                 subsample=0.5,
                 colsample=1,
                 batch_size=None,
                 k_max=None,
                 verbose=0,
                 **kwargs):
        """Store the configuration and build the wrapped single-patch Patcher.

        ``subsample``, ``colsample``, ``batch_size`` and ``k_max`` are forwarded
        to ``Patcher``; ``verbose`` and ``**kwargs`` are accepted but unused.
        """
        self.n_patches = n_patches
        self.colsample = colsample
        self.sampler = sampler
        self.patcher1 = Patcher(sampler=sampler,
                                n_patches=1,
                                subsample=subsample,
                                colsample=colsample,
                                batch_size=batch_size,
                                k_max=k_max)

    def _sketch_rows(self, X, y):
        """Collect one row-index array per patch in ``self._lst_row_samples``."""
        row_samples = []
        for _ in range(self.n_patches):
            self.patcher1._sketch_rows(X, y)
            # The wrapped patcher yields a one-element list; extend (rather
            # than append) so each entry is the index array itself and not a
            # list wrapping it.
            row_samples.extend(self.patcher1._lst_row_samples)

        self._lst_row_samples = row_samples

    def _sketch_columns(self, X):
        """Collect one column-index array per patch in ``self._lst_col_samples``."""
        col_samples = []
        for _ in range(self.n_patches):
            # The wrapped patcher produces a single column sample per call, so
            # it is called once per patch; reusing one call would make the zip
            # in sketch_patches_idxs stop after the first patch.
            self.patcher1._sketch_columns(X)
            col_samples.extend(self.patcher1._lst_col_samples)

        self._lst_col_samples = col_samples

    def sketch_patches_idxs(self, X, y=None, y_stratify=None):
        """Return a list of ``(rows, cols)`` index pairs, one per patch.

        ``cols`` is ``Ellipsis`` (all columns) unless ``colsample < 1``.
        ``y_stratify`` is forwarded to the wrapped patcher's ``init``.
        """
        self.patcher1.init(X, y, y_stratify)
        self._sketch_rows(X, y)

        if self.colsample < 1:
            self._sketch_columns(X)
            return list(zip(self._lst_row_samples, self._lst_col_samples))

        return [(rows, ...) for rows in self._lst_row_samples]



class Patcher3:
    """Patch sampler that draws patches over several reshuffled runs.

    Each run reshuffles the batches and asks the sampler for up to
    ``n_patches_per_run`` patches; runs are repeated until ``n_patches``
    patches exist. Optionally returns a weight per patch.

    The row sampler is used as
    ``sampler(X=..., y=..., k=..., **sampler_args).sketch_idxs(n_patches)``.
    """

    def __init__(self,
                 sampler,
                 n_patches=100,
                 subsample=None,
                 colsample=1,
                 batch_size=None,
                 k_max=None,
                 verbose=0,
                 deterministic=True,
                 weighted=True,
                 **kwargs):
        """Store the sampling configuration.

        Parameters
        ----------
        sampler : class or instance
            Either a sampler class, or a sampler instance. For an instance its
            class is re-instantiated later with the keyword arguments found in
            ``sampler.args``.
        n_patches : int
            Number of patches to produce.
        subsample : float or None
            Fraction of rows per patch; resolved in :meth:`init` when falsy.
        colsample : float
            Fraction of columns per patch; columns are only sampled when < 1.
        batch_size : int or None
            Rows per batch; resolved in :meth:`init` when falsy.
        k_max : int or None
            Upper bound on rows sampled per batch; resolved in :meth:`init`
            when falsy.
        deterministic : bool
            Forwarded to ``shuffle_and_stratify`` together with a per-run seed.
        weighted : bool
            When true, :meth:`sketch_patches_idxs` also returns patch weights.
        verbose, **kwargs
            Accepted for signature compatibility; not used by this class.
        """
        if inspect.isclass(sampler):
            self.sampler = sampler
            self.sampler_args = {}
        else:
            self.sampler = sampler.__class__
            self.sampler_args = sampler.args

        self.n_patches = n_patches
        self.subsample = subsample
        self.colsample = colsample
        self.batch_size = batch_size
        self.k_max = k_max
        self.deterministic = deterministic
        self.weighted = weighted

    def init(self, X, y, y_stratify=None):
        """Resolve batch/sample sizes for ``X`` and build the batch helper.

        Note that the resolved ``k_max``, ``subsample``, ``batch_size`` and
        ``colsample`` are written back onto the instance, so they persist
        across subsequent calls.

        Parameters
        ----------
        X : array with a 2-tuple ``shape`` ``(n, d)``.
        y : targets, used for stratification when ``y_stratify`` is None.
        y_stratify : optional labels used for stratification instead of ``y``.
        """
        n, d = X.shape

        if not self.k_max:
            self.k_max = d

        if not self.subsample:
            # With a fixed batch size, aim for k_max rows out of every batch;
            # otherwise fall back to sampling half of the rows.
            self.subsample = self.k_max/self.batch_size if self.batch_size else 0.5

        self.total_rows = max(int(n*self.subsample), 2)

        if not self.batch_size:
            self.batch_size = min(int(self.k_max/self.subsample), n)
            self.n_samples_rows = self.k_max
        else:
            # At least 2 rows per batch, never more than k_max.
            self.n_samples_rows = min(max(int(self.batch_size*self.subsample), 2),
                                      self.k_max)

        self.n_batches = floor(self.total_rows/self.n_samples_rows)

        # Compare against None explicitly: ``if y_stratify:`` raises ValueError
        # for NumPy arrays with more than one element.
        stratify_on = y_stratify if y_stratify is not None else y
        self.stf_batches = StratifiedBatches(X, stratify_on, self.batch_size)

        self.n_batches_per_epoch = self.stf_batches.n_batches
        self.n_epochs = ceil(self.n_batches/floor(n/self.batch_size))

        if self.n_epochs == 1:
            # A single pass over the data: use every batch of that pass.
            self.n_batches = self.n_batches_per_epoch

        if self.colsample < 1:
            self.colsample = min(self.colsample,
                                 (self.batch_size*self.n_batches*self.colsample)/d)
            self.n_samples_cols = int(self.colsample*d)

        # Number of patches requested from the sampler per shuffle, and how
        # many shuffles are needed to reach n_patches.
        self.n_patches_per_run = min(int(self.batch_size/self.n_samples_rows),
                                     self.n_patches)
        self.n_repetitions = ceil(self.n_patches/self.n_patches_per_run)

    def _sketch_rows(self, X, y):
        """Sample row indices for every patch into ``self._lst_row_samples``.

        Requires the attributes set by :meth:`init`. The sampler output is
        indexed as (patch, batch, sample); a 2-D output is treated as a single
        patch.
        """
        patch_rows = []
        n_patches_left = self.n_patches

        for rep in range(self.n_repetitions):
            # NOTE(review): presumably a (batch, row-in-batch) array of row
            # indices into X - confirm against StratifiedBatches.
            batch_idxs = self.stf_batches.shuffle_and_stratify(
                self.n_epochs,
                deterministic=self.deterministic,
                seed=rep*self.n_epochs)
            batch_idxs = batch_idxs[:self.n_batches]

            batched_X = X[batch_idxs]
            batched_y = y[batch_idxs] if y is not None else None

            # Full runs first; the last run only produces the remainder.
            if n_patches_left > self.n_patches_per_run:
                n_patches = self.n_patches_per_run
                n_patches_left -= self.n_patches_per_run
            else:
                n_patches = n_patches_left

            idxs = self.sampler(X=batched_X,
                                y=batched_y,
                                k=self.n_samples_rows,
                                **self.sampler_args).sketch_idxs(n_patches)

            idxs = np.array(idxs)
            if idxs.ndim < 3:
                idxs = idxs[np.newaxis, :, :]

            # Shift the within-batch indices of batch b by the start offset of
            # that batch inside its epoch.
            shifted = np.array([
                idxs[:, b, :] + (b % self.n_batches_per_epoch)*self.batch_size
                for b in range(self.n_batches)
            ])
            # (batch, patch, sample) -> (patch, batch, sample)
            shifted = np.transpose(shifted, (1, 0, 2))

            patch_rows.extend(shifted)

        self._lst_row_samples = [patch_rows[p].flatten()[:self.total_rows]
                                 for p in range(self.n_patches)]

    def _sketch_columns(self, X):
        """Sample column indices for every patch from that patch's own rows.

        The sampler receives the list of transposed row-subsets of ``X`` and
        ``sketch_idxs`` is called without arguments.
        """
        per_patch_X = [X[rows].T for rows in self._lst_row_samples]
        self._lst_col_samples = self.sampler(X=per_patch_X,
                                             k=self.n_samples_cols,
                                             **self.sampler_args).sketch_idxs()

    def sketch_patches_idxs(self, X, y=None, y_stratify=None):
        """Return the patch indices, plus patch weights when ``weighted``.

        Returns
        -------
        list of ``(rows, cols)`` pairs (``cols`` is ``Ellipsis`` unless
        ``colsample < 1``); when ``self.weighted`` is true, a tuple
        ``(that list, weights)`` is returned instead. Within each run the
        weights decrease linearly from 1 down to ``1/n_patches_per_run``.
        """
        self.init(X, y, y_stratify)
        self._sketch_rows(X, y)

        if self.colsample < 1:
            self._sketch_columns(X)
            lst_samples = list(zip(self._lst_row_samples, self._lst_col_samples))
        else:
            lst_samples = [(rows, ...) for rows in self._lst_row_samples]

        if not self.weighted:
            return lst_samples

        per_run = self.n_patches_per_run
        run_weights = [(per_run - rank)/per_run for rank in range(per_run)]
        weights = (run_weights*self.n_repetitions)[:self.n_patches]
        return lst_samples, weights
back to top