https://github.com/jyhmiinlin/pynufft
Raw File
Tip revision: 505b5ef808e2d357b192a6ec1c4d5b4c45606cc9 authored by Jyh-Miin Lin on 14 February 2020, 19:27:23 UTC
commit message
Tip revision: 505b5ef
test_2D_inverse_method.py
    
import os
import sys
# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
import scipy

def test_2D():
    import pkg_resources
    
    DATA_PATH = pkg_resources.resource_filename('pynufft', 'src/data/')
#     PHANTOM_FILE = pkg_resources.resource_filename('pynufft', 'data/phantom_256_256.txt')
    import numpy
    import matplotlib.pyplot
    from pynufft import NUFFT_cpu
    # load example image
#     image = numpy.loadtxt(DATA_PATH +'phantom_256_256.txt')
    image = scipy.misc.ascent()[::2,::2]
    image=image.astype(numpy.float)/numpy.max(image[...])
    #numpy.save('phantom_256_256',image)
    matplotlib.pyplot.imshow(image, cmap=matplotlib.cm.gray)
    matplotlib.pyplot.show()
    print('loading image...')

    
    
    Nd = (256, 256)  # image size
    print('setting image dimension Nd...', Nd)
    Kd = (512, 512)  # k-space size
    print('setting spectrum dimension Kd...', Kd)
    Jd = (6, 6)  # interpolation size
    print('setting interpolation size Jd...', Jd)
    # load k-space points
    # om = numpy.loadtxt(DATA_PATH+'om.txt')
    om = numpy.load(DATA_PATH+'om2D.npz')['arr_0']
    print('setting non-uniform coordinates...')
    matplotlib.pyplot.plot(om[::10,0],om[::10,1],'o')
    matplotlib.pyplot.title('non-uniform coordinates')
    matplotlib.pyplot.xlabel('axis 0')
    matplotlib.pyplot.ylabel('axis 1')
    matplotlib.pyplot.show()

    NufftObj = NUFFT_cpu()
    NufftObj.plan(om, Nd, Kd, Jd)
    
    y = NufftObj.forward(image)
    
    
    print('setting non-uniform data')
    print('y is an (M,) list',type(y), y.shape)
    
    
    W = numpy.ones(Kd, dtype = numpy.complex64)
    for pp in range(0, 200):
        W2 = NufftObj.xx2k(NufftObj.adjoint(NufftObj.forward(NufftObj.k2xx(W))))
        W2 = W2*W2.conj()
        W2 = W2**0.5
        W = (W+0.9)/(W2 + 0.9)
    
    matplotlib.pyplot.subplot(1,2,1)
    matplotlib.pyplot.imshow(W2.real)
    
    matplotlib.pyplot.subplot(1,2,2)
    matplotlib.pyplot.imshow((W/W2).real)
    matplotlib.pyplot.show()
    
#     kspectrum = NufftObj.xx2k( NufftObj.solve(y,solver='bicgstab',maxiter = 100))
    image_restore = NufftObj.solve(y, solver='cg',maxiter=10)
    shifted_kspectrum = numpy.fft.fftshift( numpy.fft.fftn( numpy.fft.fftshift(image_restore)))
    print('getting the k-space spectrum, shape =',shifted_kspectrum.shape)
    print('Showing the shifted k-space spectrum')
    
    matplotlib.pyplot.imshow( shifted_kspectrum.real, cmap = matplotlib.cm.gray, norm=matplotlib.colors.Normalize(vmin=-100, vmax=100))
    matplotlib.pyplot.title('shifted k-space spectrum')
    matplotlib.pyplot.show()
    image2 = NufftObj.solve(y, 'dc', maxiter = 25)
#     image3 = NufftObj.solve(y, 'L1TVLAD',maxiter=100, rho= 1)
    image3 = NufftObj.k2xx(NufftObj.xx2k(NufftObj.adjoint(y))*W)
    print(image3.shape)
    image4 = NufftObj.solve(   y,'L1TVOLS',maxiter=100, rho= 1)
    matplotlib.pyplot.subplot(1,3,1)
    matplotlib.pyplot.imshow(image, cmap=matplotlib.cm.gray, norm=matplotlib.colors.Normalize(vmin=0.0, vmax=1))
    matplotlib.pyplot.subplot(1,3,2)
    matplotlib.pyplot.imshow(image3.real, cmap=matplotlib.cm.gray, norm=matplotlib.colors.Normalize(vmin=0.0, vmax=1))
    matplotlib.pyplot.subplot(1,3,3)
    matplotlib.pyplot.imshow(image4.real, cmap=matplotlib.cm.gray, norm=matplotlib.colors.Normalize(vmin=0.0, vmax=1))
    matplotlib.pyplot.show()  
  
#     matplotlib.pyplot.imshow(image2.real, cmap=matplotlib.cm.gray, norm=matplotlib.colors.Normalize(vmin=0.0, vmax=1))
#     matplotlib.pyplot.show()
    maxiter =25
    counter = 1
    for solver in ('dc','bicg','bicgstab','cg', 'gmres','lgmres',  'lsqr'):
        print(counter, solver)
        if 'lsqr' == solver:
            image2 = NufftObj.solve(y, solver,iter_lim=maxiter)
        else:
            image2 = NufftObj.solve(y, solver,maxiter=maxiter)
#     image2 = NufftObj.solve(y, solver='bicgstab',maxiter=30)
        matplotlib.pyplot.subplot(2,4,counter)
        matplotlib.pyplot.imshow(image2.real, cmap=matplotlib.cm.gray, norm=matplotlib.colors.Normalize(vmin=0.0, vmax=1))
        matplotlib.pyplot.title(solver)
#         print(counter, solver)
        counter += 1
    matplotlib.pyplot.show()

# def test_asoperator():
#     
#     import pkg_resources
#     
#     DATA_PATH = pkg_resources.resource_filename('pynufft', 'data/')
# #     PHANTOM_FILE = pkg_resources.resource_filename('pynufft', 'data/phantom_256_256.txt')
#     import numpy
#     import matplotlib.pyplot
#     # load example image
# #     image = numpy.loadtxt(DATA_PATH +'phantom_256_256.txt')
#     image = scipy.misc.ascent()
#     
#     image = scipy.misc.imresize(image, (256,256))
#     
#     image=image.astype(numpy.float)/numpy.max(image[...])
#     #numpy.save('phantom_256_256',image)
#     matplotlib.pyplot.imshow(image, cmap=matplotlib.cm.gray)
#     matplotlib.pyplot.show()
#     print('loading image...')
# 
#     
#     
#     Nd = (256, 256)  # image size
#     print('setting image dimension Nd...', Nd)
#     Kd = (512, 512)  # k-space size
#     print('setting spectrum dimension Kd...', Kd)
#     Jd = (6, 6)  # interpolation size
#     print('setting interpolation size Jd...', Jd)
#     # load k-space points
#     # om = numpy.loadtxt(DATA_PATH+'om.txt')
#     om = numpy.load(DATA_PATH+'om2D.npz')['arr_0']
#     print('setting non-uniform coordinates...')
#     matplotlib.pyplot.plot(om[::10,0],om[::10,1],'o')
#     matplotlib.pyplot.title('non-uniform coordinates')
#     matplotlib.pyplot.xlabel('axis 0')
#     matplotlib.pyplot.ylabel('axis 1')
#     matplotlib.pyplot.show()
# 
#     NufftObj = NUFFT()
#     NufftObj.plan(om, Nd, Kd, Jd)
#         
#     print('NufftObj.dtype=',NufftObj.dtype, '  NufftObj.shape=',NufftObj.shape)
#     
#     y1 = NufftObj.forward(image)
#     x_vec= numpy.reshape(image,  (numpy.prod(NufftObj.st['Nd']), ) , order='F')
#     
#     y2 = NufftObj._matvec(x_vec)
#     
#     A = scipy.sparse.linalg.LinearOperator(NufftObj.shape, matvec=NufftObj._matvec, rmatvec=NufftObj._adjoint, dtype=NufftObj.dtype)
# #     scipy.sparse.linalg.aslinearoperator(A)
#     
#     print(type(A))
# #     y1 = A.matvec(x_vec)
#     print('y1.shape',numpy.shape(y1))  
#     import time
#     t0=time.time()
#     KKK = scipy.sparse.linalg.lsqr(A, y1, )
#     print(numpy.shape(KKK))  
#     
#     
#     print(time.time() - t0)
#     
#     x2 = numpy.reshape(KKK[0], NufftObj.st['Nd'], order='F')
#     
#     
#     matplotlib.pyplot.subplot(2,1,1)
#     matplotlib.pyplot.imshow(x2.real,cmap=matplotlib.cm.gray)
#     matplotlib.pyplot.subplot(2,1,2)
#     matplotlib.pyplot.imshow(image,cmap=matplotlib.cm.gray)
#     
#     matplotlib.pyplot.show()
#     
#     
#     print('y1 y2 close? ', numpy.allclose(y1, y2))
        

if __name__ == '__main__':
    """
    Test the module pynufft
    """
    test_2D()
#     test_asoperator()
#     test_installation()
back to top