Software Heritage archive
Origin: https://github.com/awreed/Neural-Volumetric-Reconstruction-for-Coherent-SAS
Archived: 28 October 2023, 08:36:07 UTC
SWHIDs for the browsed objects:
  content:   swh:1:cnt:db7273e424da892ff5f114b689dd087474d40ec9
  directory: swh:1:dir:ccf26889638e429deaa3e8d7a0fa77153a6fff49
  revision:  swh:1:rev:007dfc65380872d7ae27477eff7e7b5cbb766f71
  snapshot:  swh:1:snp:9e2c8e761f723f56d02b61d1740453a509f14182

Tip revision: 007dfc65380872d7ae27477eff7e7b5cbb766f71, authored by Albert on 27 October 2023, 21:59:43 UTC
"Added scripts and documentation for lower memory GPUs"

sas_utils.py

import scipy.signal
import numpy as np
import torch
from tqdm import tqdm
import math
import constants as c
import matplotlib.pyplot as plt
import scipy.fftpack
from data_schemas import WfmCropSettings
from sampling import find_voxels_within_fov


def find_indeces_within_scene(x, corners):
    # Return the indices of points x that lie inside the axis-aligned
    # bounding box spanned by corners
    return torch.where((x[..., 0] >= corners[..., 0].min()) &
                       (x[..., 0] <= corners[..., 0].max()) &
                       (x[..., 1] >= corners[..., 1].min()) &
                       (x[..., 1] <= corners[..., 1].max()) &
                       (x[..., 2] >= corners[..., 2].min()) &
                       (x[..., 2] <= corners[..., 2].max()))[0].long()


def figure_to_tensorboard(writer, fig, fig_name, global_step):
    fig.canvas.draw()
    # Convert the figure to numpy array, read the pixel values and reshape the array
    fig_img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    fig_img = fig_img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    # Normalize into 0-1 range for TensorBoard(X). Swap axes for newer versions where API expects colors in first dim
    fig_img = fig_img / 255.0
    writer.add_image(fig_name, fig_img.transpose(2, 0, 1), global_step)


def view_fft(y, fs, N=None, path=None):
    assert y.ndim == 1
    if N is None:
        N = y.shape[0]

    T = 1 / fs
    xf = np.linspace(0, 1 / (2. * T), N // 2)
    yf = scipy.fftpack.fft(y)

    plt.figure(figsize=(10, 7))
    plt.subplot(2, 1, 1)
    plt.plot(xf, 2.0 / N * np.abs(yf[:N // 2]))
    plt.title("Freq Domain")
    plt.xlabel('Hz')
    plt.subplot(2, 1, 2)
    plt.plot(y)
    plt.title("Time Domain")
    plt.xlabel('Samples')
    plt.tight_layout()
    # Save before show(), since show() may leave the figure blank afterwards
    if path is not None:
        plt.savefig(path)
    plt.show()


def matplotlib_render(mag, thresh, x_voxels, y_voxels, z_voxels, x_corners, y_corners, z_corners, save_path):
    mag = np.abs(mag)
    mag = mag.ravel()

    u = mag.mean()
    std = mag.std()
    # Hide voxels below the magnitude threshold; NaNs are not rendered
    mag[mag < (u + thresh * std)] = np.nan

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.clear()
    im = ax.scatter(x_voxels,
                    y_voxels,
                    z_voxels,
                    c=mag, alpha=0.5)
    ax.set_xlim3d(
        (x_corners.min(), x_corners.max()))
    ax.set_ylim3d(
        (y_corners.min(), y_corners.max()))
    ax.set_zlim3d(
        (z_corners.min(), z_corners.max()))
    plt.grid(True)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    fig.colorbar(im)
    fig.savefig(save_path)
    plt.close(fig)
    return fig


# Magnitude of a 2-channel (real, imag) tensor; the epsilon keeps gradients finite at zero
def comp_mag(x):
    return torch.sqrt(x[..., 0] ** 2 + x[..., 1] ** 2 + 1e-5)


# ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
def finite_difference_normal(x, model, epsilon=1e-3):
    bound = 1e2
    # x: [N, 3]
    dx_pos = torch.relu(model(
        (x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)))[..., 0]).clamp(-bound, bound)
    dx_neg = torch.relu(model(
        (x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)))[..., 0]).clamp(-bound, bound)
    dy_pos = torch.relu(model(
        (x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)))[..., 0]).clamp(-bound, bound)
    dy_neg = torch.relu(model(
        (x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)))[..., 0]).clamp(-bound, bound)
    dz_pos = torch.relu(model(
        (x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)))[..., 0]).clamp(-bound, bound)
    dz_neg = torch.relu(model(
        (x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)))[..., 0]).clamp(-bound, bound)

    normal = torch.stack([
        .5 * (dx_pos - dx_neg) / epsilon,
        .5 * (dy_pos - dy_neg) / epsilon,
        .5 * (dz_pos - dz_neg) / epsilon
    ], dim=-1)

    return -normal
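

# Illustrative usage sketch (toy stand-in model, not the network used by this
# repo): the model's first output channel is the distance to the origin, so
# the expected central-difference gradient is the radial direction.
def _demo_finite_difference_normal():
    def model(p):
        # Channel 0: distance of each point p to the origin
        return torch.linalg.norm(p, dim=-1, keepdim=True)

    x = torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
    n = finite_difference_normal(x, model)
    # The central difference of |p| is p / |p|; the function returns its negation
    expected = -x / torch.linalg.norm(x, dim=-1, keepdim=True)
    assert torch.allclose(n, expected, atol=1e-3)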


"""Hilbert transform in pytorch"""


def hilbert_torch(x):
    N = x.shape[-1]

    # add extra dimension that will be removed later.
    if x.ndim == 1:
        x = x[None, :]

    # Take the forward Fourier transform
    Xf = torch.fft.fft(x, dim=-1)
    h = torch.zeros_like(x)

    if N % 2 == 0:
        h[:, 0] = h[:, N // 2] = 1
        h[:, 1:N // 2] = 2
    else:
        h[:, 0] = 1
        h[:, 1:(N + 1) // 2] = 2

    # Take inverse Fourier transform
    x_hilbert = torch.fft.ifft(Xf * h.to(Xf.device), dim=-1).squeeze()

    return x_hilbert
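

# Illustrative sanity-check sketch (assumes only this file's imports):
# hilbert_torch should agree with scipy.signal.hilbert on a real tone.
def _demo_hilbert_torch():
    t = np.linspace(0, 1, 1000, endpoint=False)
    x = np.cos(2 * np.pi * 50 * t)
    ours = hilbert_torch(torch.from_numpy(x)).numpy()
    ref = scipy.signal.hilbert(x)
    assert np.allclose(ours, ref, atol=1e-8)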


"""Method to remove the mean amplitude and phase response from time series"""


def remove_room(ts):
    # ts.shape = [nangles, n_samples]
    ang = np.angle(ts)
    cm = np.mean(ts, 0)
    # cm.shape = [n_samples]
    dang = np.angle(np.exp(1j * (np.angle(cm[None, ...]) - np.angle(ts))))
    # dang.shape = [n_angles, n_samples]
    beta = 1. / (1 + np.abs(dang) ** 2)
    alpha = 1. / (1 + (np.abs(cm)[None, ...] - np.abs(ts)) ** 2)
    rm = np.abs(cm)[None, ...] * np.exp(1j * ang)
    nts = ts - alpha * beta * rm

    return nts


"""Method to create LFM waveform
Fs: sample rate
n_samples: number of samples in padded lfm
f_start: LFM start frequency
f_stop: LFM stop frequency
t_dur: LFM duration
window: Option to apply Tukey Window
win_ratio: Tukey window ratio
"""


def gen_real_lfm(Fs, f_start, f_stop, t_dur, window=True, win_ratio=0.1, phase=0):
    times = np.linspace(0, t_dur - 1 / Fs, num=int(t_dur * Fs))
    LFM = scipy.signal.chirp(times, f_start, t_dur, f_stop, phi=phase)

    if window:
        tuk_win = scipy.signal.windows.tukey(len(LFM), win_ratio)
        LFM = tuk_win * LFM

    return LFM
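

# Illustrative usage sketch: generate a 20-30 kHz LFM at an assumed 100 kHz
# sample rate (example values, not the project's settings) and inspect it
# with view_fft.
def _demo_gen_real_lfm():
    fs = 100_000
    lfm = gen_real_lfm(fs, f_start=20_000, f_stop=30_000, t_dur=1e-3)
    view_fft(lfm, fs)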


# Modulate a (possibly complex) baseband signal up to carrier fc along the last dimension of x
def modulate_signal(x, fs, fc, keep_quadrature=False):
    if x.ndim == 1:
        x = x[None, :]

    modulate_vec = np.exp(1j * 2 * np.pi * fc * np.arange(0, x.shape[-1], 1) / fs)

    x_mod = x * modulate_vec[None, :]
    x_mod = x_mod.squeeze()

    if keep_quadrature:
        return 2 * x_mod
    else:
        return 2 * x_mod.real


# baseband along the last dimension of x
def baseband_signal(x, fs, fc):
    if x.ndim == 1:
        x = x[None, :]

    demodvect = np.exp(-1j * 2 * np.pi * fc * np.arange(0, x.shape[-1], 1) / fs)
    x_demod = x * demodvect[None, :]
    x_demod = x_demod.squeeze()

    # LPF
    b, a = scipy.signal.butter(5, fc * 2 / fs)
    x_demod = scipy.signal.filtfilt(b, a, x_demod)

    return x_demod
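

# Illustrative round-trip sketch (fs and fc are assumed example values):
# baseband a passband LFM to complex I/Q, then modulate it back up. Recovery
# is approximate because the Butterworth LPF trims the band edges.
def _demo_baseband_roundtrip():
    fs, fc = 100_000, 25_000
    x = gen_real_lfm(fs, f_start=20_000, f_stop=30_000, t_dur=1e-2)
    iq = baseband_signal(x, fs, fc)        # complex baseband
    x_back = modulate_signal(iq, fs, fc)   # real passband again
    assert np.corrcoef(x, x_back)[0, 1] > 0.9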


def match_filter_all(x, kernel):
    assert x.ndim == 2

    data_rc = torch.zeros((x.shape[0], x.shape[1]), dtype=torch.complex128)

    if not torch.is_tensor(x):
        x = torch.from_numpy(x)
    if not torch.is_tensor(kernel):
        kernel = torch.from_numpy(kernel)

    fft_kernel = torch.zeros((x.shape[1]), dtype=kernel.dtype)
    fft_kernel[:kernel.shape[0]] = kernel
    fft_kernel = torch.fft.fft(hilbert_torch(fft_kernel))

    for i in tqdm(range(x.shape[0]), desc='Match filtering'):
        data_rc[i, ...] = replica_correlate_torch(x[i, ...], fft_kernel)

    return data_rc.detach().cpu().numpy()


def replica_correlate_torch(x, kernel):
    assert not torch.is_complex(x), "x should be real"
    # Forward fourier transform of received waveform
    x_hil = hilbert_torch(x)
    x_fft = torch.fft.fft(x_hil)

    # Definition of cross-correlation
    x_rc = torch.fft.ifft(x_fft * torch.conj(kernel))

    return x_rc
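

# Illustrative pulse-compression sketch (all values are assumed for the
# example): embed a delayed LFM in a zero waveform, match filter it, and
# check that the compressed peak lands at the inserted delay.
def _demo_match_filter():
    fs = 100_000
    delay = 300  # samples
    lfm = gen_real_lfm(fs, f_start=20_000, f_stop=30_000, t_dur=1e-3)
    x = np.zeros((1, 2048))
    x[0, delay:delay + lfm.shape[0]] = lfm
    rc = match_filter_all(x, lfm)
    assert np.abs(rc[0]).argmax() == delay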


"""Interpolation using fft"""


def interpfft(x, r):
    nx = len(x)
    X = torch.fft.fft(x)

    Xint = torch.zeros(int(len(X) * r), dtype=X.dtype)
    nxint = len(Xint)

    if len(x) % 2 == 0:
        Xint[0:nx // 2] = X[0:nx // 2]
        Xint[nx // 2] = X[nx // 2] / 2
        Xint[nxint - nx // 2] = X[nx // 2] / 2
        Xint[nxint - nx // 2 + 1:] = X[nx // 2 + 1:]
    else:
        Xint[0:math.floor(nx / 2) + 1] = X[:math.floor(nx / 2) + 1]
        Xint[nxint - math.floor(nx / 2):] = X[math.floor(nx / 2) + 1:]

    xint = torch.fft.ifft(Xint) * r

    if torch.is_complex(x):
        return xint
    else:
        return xint.real
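

# Illustrative sketch: FFT interpolation of a bandlimited tone by r = 4
# should pass through the original samples (values are example choices).
def _demo_interpfft():
    n, r = 64, 4
    t = torch.arange(n, dtype=torch.float64)
    x = torch.sin(2 * np.pi * 3 * t / n)
    xi = interpfft(x, r)
    assert torch.allclose(xi[::r], x, atol=1e-10)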


"""
Same as crop_wfm except accounts for beamwidth of TX/RX. Finite beamwidth means voxels may not be within FOV
"""


def crop_wfm_beamwidth(tx_coords, rx_coords, tx_vec, rx_vec, tx_bw, rx_bw, voxels, wfm_length, fs, speed_of_sound,
                       same_tx_per_k_rx=1, pad=.05, device='cpu'):
    all_dists_min = []
    all_dists_max = []

    count = 0
    valid_indices = []
    # Find all the valid distances from tx/rx to voxels (voxels within both FOV)
    for tx, rx, tx_v, rx_v in tqdm(zip(tx_coords, rx_coords, tx_vec, rx_vec),
                                   desc="Cropping waveforms"):
        # Only update tx when it changes
        if count % same_tx_per_k_rx == 0:
            _, in_tx_fov_voxels = find_voxels_within_fov(trans_pos=tx,
                                                         tx_vec=tx_v,
                                                         origin=torch.tensor([0., 0., -1.]),
                                                         voxels=voxels,
                                                         bw=tx_bw,
                                                         device=device)

        _, in_both_fov_voxels = find_voxels_within_fov(trans_pos=rx,
                                                       tx_vec=rx_v,
                                                       origin=torch.tensor([0., 0., -1.]),
                                                       voxels=in_tx_fov_voxels,
                                                       bw=rx_bw,
                                                       device=device)

        if in_both_fov_voxels.shape[0] > 0:
            valid_indices.append(count)

            d1 = np.sqrt(np.sum((tx[None, ...] - in_both_fov_voxels) ** 2, axis=-1))
            d2 = np.sqrt(np.sum((rx[None, ...] - in_both_fov_voxels) ** 2, axis=-1))

            tot_dist = d1 + d2
            min_val = tot_dist.ravel().min()
            max_val = tot_dist.ravel().max()

            all_dists_min.append(min_val)
            all_dists_max.append(max_val)

        count = count + 1


    # Crop the waveforms based off of these distances
    all_dists_min = np.array(all_dists_min)
    all_dists_max = np.array(all_dists_max)
    # Pad by waveform length and some scalar offset
    min_dist = all_dists_min.min() - pad - (wfm_length / fs * speed_of_sound)
    max_dist = all_dists_max.max() + pad + (wfm_length / fs * speed_of_sound)

    assert max_dist > min_dist, "Sanity check failed"

    min_sample = math.floor(min_dist / speed_of_sound * fs)

    # Update the min dist based off the rounded down sample
    min_dist = min_sample / fs * speed_of_sound

    # Convert the cropped distance span into a time duration
    t_dur = (max_dist - min_dist) / speed_of_sound

    num_samples = math.ceil(t_dur * fs)

    # Update the max dist based off the rounded up sample
    max_dist = ((min_sample + num_samples) / fs) * speed_of_sound

    wfm_crop_settings = WfmCropSettings()
    wfm_crop_settings[c.MIN_SAMPLE] = min_sample
    wfm_crop_settings[c.MIN_DIST] = min_dist
    wfm_crop_settings[c.MAX_DIST] = max_dist
    wfm_crop_settings[c.NUM_SAMPLES] = num_samples

    return wfm_crop_settings, np.array(valid_indices)


def crop_wfm(tx_coords, rx_coords, corners, wfm_length, fs, speed_of_sound, pad=.05):
    assert tx_coords.shape[0] == rx_coords.shape[0]
    # [1, num_tx, 3] - [num_corners, 1, 3] = [num_corners, num_tx, 3]
    d1 = np.sqrt(np.sum((tx_coords[None, ...] - corners[:, None, :]) ** 2, axis=-1))
    d2 = np.sqrt(np.sum((rx_coords[None, ...] - corners[:, None, :]) ** 2, axis=-1))

    # TODO Should really pad waveform to proper length with zeros prior to deconvolution.
    min_dist = (d1 + d2).ravel().min() - pad - (min(wfm_length, 100) / fs * speed_of_sound)
    max_dist = (d1 + d2).ravel().max() + pad + (min(wfm_length, 100) / fs * speed_of_sound)

    assert max_dist > min_dist, "Sanity check failed"

    min_sample = math.floor(min_dist / speed_of_sound * fs)

    # Update the min dist based off the rounded down sample
    min_dist = min_sample / fs * speed_of_sound

    # Convert the cropped distance span into a time duration
    t_dur = (max_dist - min_dist) / speed_of_sound

    num_samples = math.ceil(t_dur * fs)

    # Update the max dist based off the rounded up sample
    max_dist = ((min_sample + num_samples) / fs) * speed_of_sound

    wfm_crop_settings = WfmCropSettings()
    wfm_crop_settings[c.MIN_SAMPLE] = min_sample
    wfm_crop_settings[c.MIN_DIST] = min_dist
    wfm_crop_settings[c.MAX_DIST] = max_dist
    wfm_crop_settings[c.NUM_SAMPLES] = num_samples

    return wfm_crop_settings


def radial_delay_wfms_fast(tsd, weights):
    # [batch_size, num_radial, 1] * [1, num_radial, num_samples] = [batch_size, num_radial, num_samples]
    tsd_scaled = weights[..., None] * tsd[None, ...]
    tsd_sum = torch.sum(tsd_scaled, 1)

    return tsd_sum


# Fixed SNR Wiener filter
def wiener_deconvolution(signal_fft, kernel_fft, lambd):
    if signal_fft.ndim == 2:
        if kernel_fft.ndim == 1:
            kernel_fft = kernel_fft[None, :]
    deconvolved = torch.real(torch.fft.ifft(signal_fft * torch.conj(kernel_fft) /
                                            (kernel_fft * torch.conj(kernel_fft) + lambd ** 2)))
    return deconvolved
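

# Illustrative sketch (kernel and lambda are assumed example choices): blur a
# spike train with a known one-sided exponential kernel, then deconvolve. The
# kernel has no spectral nulls, so a small lambda recovers the spikes closely.
def _demo_wiener_deconvolution():
    n = 256
    sig = torch.zeros(n, dtype=torch.float64)
    sig[[40, 90, 200]] = 1.0
    kern = torch.zeros(n, dtype=torch.float64)
    kern[:16] = torch.exp(-torch.arange(16, dtype=torch.float64) / 4.0)
    blurred_fft = torch.fft.fft(sig) * torch.fft.fft(kern)  # convolve in freq domain
    est = wiener_deconvolution(blurred_fft, torch.fft.fft(kern), lambd=1e-3)
    assert torch.allclose(est, sig, atol=1e-3)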


# Source https://github.com/ashawkey/stable-dreamfusion/blob/5c8b53f8e8fc041e98bd7d3d210bdd62e7d6fae2/nerf/utils.py#L39
def safe_normalize(x, eps=1e-4):
    return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))


def no_rc_kernel_from_waveform(wfm, num_samples):
    assert wfm.ndim == 1
    sig = np.zeros((num_samples), dtype=wfm.dtype)
    sig[:wfm.shape[0]] = wfm
    sig = torch.from_numpy(sig)
    # A complex waveform is already analytic; a real one gets a Hilbert
    # transform first so the kernel is one-sided in frequency
    if wfm.dtype == complex:
        kernel = torch.fft.fft(sig)
    else:
        kernel = torch.fft.fft(hilbert_torch(sig))
    return kernel


def kernel_from_waveform(wfm, num_samples):
    assert wfm.ndim == 1
    sig = np.zeros((num_samples))
    sig[:wfm.shape[0]] = wfm
    sig = torch.from_numpy(sig)
    kernel = torch.fft.fft(hilbert_torch(sig))
    kernel_cc = kernel * torch.conj(kernel)
    return kernel_cc


def correct_group_delay(wfm, gd, fs):
    assert np.isreal(wfm).all()
    num_samples = wfm.shape[-1]
    df = fs / num_samples
    f_ind = np.linspace(0, int(num_samples - 1), num=int(num_samples),
                        dtype=np.float64)
    f = f_ind * df
    f[f > (fs / 2)] -= fs
    w = (2 * math.pi * f)

    tau = gd / fs

    phase = np.array([tau * w])

    complex_phase = np.zeros_like(phase) + 1j * phase

    pr = np.exp(complex_phase)

    wfm_correct_ifft = np.fft.fft(wfm) * pr

    wfm_correct = np.fft.ifft(wfm_correct_ifft, axis=1).real

    return wfm_correct
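

# Illustrative sketch: correcting a group delay of 2 samples should advance
# an impulse by 2 samples (fs here is an arbitrary example value).
def _demo_correct_group_delay():
    fs = 100.0
    wfm = np.zeros(64)
    wfm[10] = 1.0
    out = correct_group_delay(wfm, gd=2.0, fs=fs)
    assert np.abs(out[0]).argmax() == 8  # impulse moved from sample 10 to 8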


# Precompute delayed time series by applying a linear phase ramp (one time
# delay per distance) to the kernel spectrum, then inverse transforming
def precompute_time_series(dists, min_dist, kernel, speed_of_sound, fs, num_samples):
    df = fs / num_samples
    f_ind = torch.linspace(0, int(num_samples - 1), steps=int(num_samples),
                           dtype=torch.float64)
    f = f_ind * df
    f[f > (fs / 2)] -= fs
    w = (2 * math.pi * f).to(dists.device)

    tau = ((dists) - min_dist) / speed_of_sound

    phase = tau[:, None] * w[None, :]

    complex_phase = torch.complex(real=torch.zeros_like(phase).to(phase.device),
                                  imag=-1 * phase)

    pr = torch.exp(complex_phase)

    tsd_fft = kernel[None, :] * pr
    tsd = torch.fft.ifft(tsd_fft, dim=1)

    return tsd


# Simulate received waveforms (raw and match-filtered) by delaying the kernels
# according to round-trip distances from tx/rx to each voxel
def delay_waveforms(tx_pos, rx_pos, weights, voxels, kernel, kernel_no_rc, min_dist, group_delay, fs, speed_of_sound):
    assert tx_pos.shape[0] == rx_pos.shape[0]
    assert kernel.ndim == 1
    num_samples = kernel.shape[0]

    df = fs / (num_samples)
    f_ind = torch.linspace(0, int(num_samples - 1), steps=int(num_samples),
                           dtype=torch.float64)
    f = f_ind * df
    f[f > (fs / 2)] -= fs
    w = (2 * math.pi * f).to(weights.device)

    data_rc = torch.zeros((tx_pos.shape[0], num_samples), dtype=torch.complex128)
    data = torch.zeros((tx_pos.shape[0], num_samples), dtype=torch.float64)

    for i in tqdm(range(tx_pos.shape[0]), desc='Simulating waveforms...'):
        d1 = torch.sqrt(torch.sum((voxels - tx_pos[i, :][None, ...]) ** 2, dim=1))
        d2 = torch.sqrt(torch.sum((voxels - rx_pos[i, :][None, ...]) ** 2, dim=1))

        tau = ((d1 + d2 + (group_delay / fs) * speed_of_sound) - min_dist) / speed_of_sound

        phase = tau[:, None] * w[None, :]

        complex_phase = torch.complex(real=torch.zeros_like(phase).to(phase.device),
                                      imag=-1 * phase)

        pr = torch.exp(complex_phase)

        tsd_fft = kernel[None, :] * pr
        tsd = torch.fft.ifft(tsd_fft, dim=1)
        tsd_scaled = weights[:, None] * tsd
        tsd_sum = torch.sum(tsd_scaled, 0)
        data_rc[i, :] = tsd_sum

        tsd_fft = kernel_no_rc[None, :] * pr
        tsd = torch.fft.ifft(tsd_fft, dim=1).real
        tsd_scaled = weights[:, None] * tsd
        tsd_sum = torch.sum(tsd_scaled, 0)
        data[i, :] = tsd_sum

    return data, data_rc


def range_normalize(x):
    return (x - x.min()) / (x.max() - x.min())
