Revision 12cd680cc614ed8aade4956a430e288e05425e78 authored by Yijie Tang on 10 April 2024, 13:52:17 UTC, committed by Yijie Tang on 10 April 2024, 13:52:17 UTC
1 parent 5934d01
Raw File
keyframeSet.py
import torch
import numpy as np
import random

from helper_functions.sampling_helper import pixel_rc_to_indices, sample_pixels_uniformly
from helper_functions.geometry_helper import extract_first_kf_pose, compute_surface_center


class KeyframeSet():
    # @param num_kf: max keyframe number;
    def __init__(self, config, H, W, num_kf, device) -> None:
        """Allocate the keyframe database.

        Args:
            config: nested config dict; ``config["sampling"]["kf_n_rays_h"]`` and
                ``config["sampling"]["kf_n_rays_w"]`` are read here, further keys by the
                allocation helpers called at the end.
            H, W: full image height and width.
            num_kf: maximum number of keyframes that can be stored.
            device: device that pose tensors are moved to by other methods of this class.
        """
        self.config = config
        self.device = device
        self.H, self.W = H, W
        self.keyframes = {}

        # frame_Id of each stored keyframe, Tensor(N, ); stays None until the first insert
        self.frame_ids = None
        self.collected_kf_num = torch.zeros((1, ), dtype=torch.int64).share_memory_()

        # every keyframe is kept downsampled to an n_rays_h x n_rays_w grid of pixels
        sampling_cfg = self.config["sampling"]
        self.n_rays_h = sampling_cfg["kf_n_rays_h"]
        self.n_rays_w = sampling_cfg["kf_n_rays_w"]
        self.num_rays_to_save = self.n_rays_h * self.n_rays_w
        self.row_indices, self.col_indices = sample_pixels_uniformly(self.H, self.W, self.n_rays_h, self.n_rays_w)
        # last axis holds [direction, rgb, depth], 7 channels in total
        self.rays = torch.zeros((num_kf, self.num_rays_to_save, 7))

        # localMLP bookkeeping, then loop-closure buffers
        self.create_MLP_data(num_kf)
        self.create_overlapping_pts_data()


    def __len__(self):
        return len(self.frame_ids)


    def get_length(self):
        return self.__len__()


    # @brief: create vars related to localMLPs
    def create_MLP_data(self, num_kf):
        num_localMLP = self.config["mapping"]["localMLP_num"]
        # this tensor records the information of each localMLP:
        # (1) col[0:1]: whether this localMLP is created and used (1/0);
        # (2) col[1:4]: xyz center (in World Coordinate System);
        # (3) col[5:]: xyz axis-aligned length (in World Coordinate System)
        self.localMLP_info = torch.zeros( (num_localMLP, 7) ).share_memory_()

        self.localMLP_max_len = torch.tensor(self.config["mapping"]["localMLP_max_len"])[None, ...].repeat((num_localMLP, 1)).share_memory_()  # Tensor(num_localMLP, 3)
        self.localMLP_adjacent = torch.zeros( (self.config["mapping"]["localMLP_num"], self.config["mapping"]["localMLP_num"]) ).share_memory_()  # whether 2 localMLPs are adjacent(0/1), Tensor(num_localMLP, num_localMLP)

        # this tensor records the related localMLP_Id of each keyframe (-1 means none)
        self.keyframe_localMLP = torch.full( (num_kf, 2), fill_value=-1 ).share_memory_()

        # keyframe_Id of each submap's first keyframe
        self.localMLP_first_kf = torch.full( (num_localMLP, ), fill_value=-1 ).share_memory_()

        self.keyframe_mutex_mask = torch.zeros((1, num_kf), dtype=torch.int64).share_memory_()


    # @brief: create vars for loop closure optimization
    def create_overlapping_pts_data(self):
        self.ovlp_rays_h = self.config["mapping"]["overlapping"]["n_rays_h"]
        self.ovlp_rays_w = self.config["mapping"]["overlapping"]["n_rays_w"]
        self.ovlp_depth = torch.zeros( (1, self.ovlp_rays_h * self.ovlp_rays_w) ).share_memory_()
        self.ovlp_rays = torch.zeros( (1, self.ovlp_rays_h * self.ovlp_rays_w, 3) ).share_memory_()
        self.ovlp_pts_mask = torch.zeros( (1, self.ovlp_rays_h * self.ovlp_rays_w), dtype=torch.bool ).share_memory_()

        self.near_kf_num = 10
        self.nearest_kf_Ids = torch.full( (1, self.near_kf_num), fill_value=-1, dtype=torch.int64 ).share_memory_()
        self.nearest_kf_mask = torch.zeros( (1, self.near_kf_num, self.ovlp_rays_h * self.ovlp_rays_w) ).share_memory_()


    # @brief: Sampling strategy for current keyframe rays
    # @param rays: Tensor(1, H * W, 7);
    def sample_single_keyframe_rays(self, rays):
        """Keep only the precomputed downsampled pixels of one keyframe.

        Args:
            rays: Tensor(1, H * W, 7).
        Returns:
            Tensor(1, num_rays_to_save, 7).
        """
        flat_pixel_ids = pixel_rc_to_indices(self.row_indices, self.col_indices, self.H, self.W)
        return rays[:, flat_pixel_ids]


    # @brief: modify self.keyframe_localMLP;
    # @param kf_Id: Tensor(, );
    # @param localMLP_Id1: Tensor(, );
    # @param localMLP_Id2: Tensor(, ), default: None.
    def add_keyframe_localMLP(self, kf_Id, localMLP_Id1, localMLP_Id2=None):
        if localMLP_Id2 is None:
            if self.keyframe_localMLP[kf_Id][0] == -1:
                self.keyframe_localMLP[kf_Id][0] = localMLP_Id1
            else:
                self.keyframe_localMLP[kf_Id][1] = localMLP_Id1
        else:
            self.keyframe_localMLP[kf_Id][0] = localMLP_Id1
            self.keyframe_localMLP[kf_Id][1] = localMLP_Id2


    # @brief: write the center and axis-aligned lengths of a localMLP into self.localMLP_info (self.localMLP_first_kf is not modified)
    # @param localMLP_Id: Tensor(, );
    # @param localMLP_center: Tensor(3, );
    # @param localMLP_len: Tensor(3, );
    def modify_localMLP_info(self, localMLP_Id, localMLP_center, localMLP_len):
        self.localMLP_info[localMLP_Id][1:4] = localMLP_center
        self.localMLP_info[localMLP_Id][4:7] = localMLP_len


    # @brief: mark localMLP_Id1 and localMLP_Id2 as adjacent localMLPs (having overlapping keyframes); no-op if either is None
    def add_adjcent_pair(self, localMLP_Id1, localMLP_Id2):
        if localMLP_Id1 is not None and localMLP_Id2 is not None:
            self.localMLP_adjacent[localMLP_Id1][localMLP_Id2] = 1
            self.localMLP_adjacent[localMLP_Id2][localMLP_Id1] = 1


    # @brief: find all adjacent localMLP pairs;
    #-@return adja_pairs: all adjacent localMLP pairs, each line represents a pair, Tensor(n, 2);
    #-@return part_localMLP: localMLP_Ids of all participated localMLPs, Tensor(m, ).
    def find_adjacent_localMLP_pair(self):
        localMLP_num = self.localMLP_adjacent.shape[0]
        adja_pairs = []
        part_localMLP = []
        for i in range(localMLP_num):
            for j in range(localMLP_num):
                if j <= i:
                    continue
                if self.localMLP_adjacent[i][j] > 0:
                    adja_pairs.append( torch.tensor([j, i], dtype=torch.int32) )
                    if i not in part_localMLP:
                        part_localMLP.append(i)
                    if j not in part_localMLP:
                        part_localMLP.append(j)
        adja_pairs = torch.stack(adja_pairs, 0)  # Tensor(n, 2)
        adja_pairs = torch.sort(adja_pairs, -1)[0]
        part_localMLP = torch.tensor(part_localMLP)  # localMLP_Ids of all participated localMLPs, Tensor(m, )
        part_localMLP = torch.sort(part_localMLP)[0]
        return adja_pairs, part_localMLP


    # @brief: fill localMLP info and localMLP's first keyframe for a newly created localMLP;
    #-@return: Tensor(, ).
    def modify_new_localMLP_info(self, localMLP_center, localMLP_len, kf_Id):
        new_localMLP_Id = torch.count_nonzero(self.localMLP_info[:, 0])  # next available localMLP_Id
        updated_line = torch.cat( [torch.ones((1, )), localMLP_center, localMLP_len], 0 )  # Tensor(7, )
        if new_localMLP_Id < self.localMLP_info.shape[0]:
            self.localMLP_info[new_localMLP_Id] = updated_line
            self.localMLP_first_kf[new_localMLP_Id] = kf_Id
        else:
            kf_Id_tensor = torch.tensor([kf_Id], dtype=torch.int64)
            self.localMLP_info = torch.cat([self.localMLP_info, updated_line.unsqueeze(0)], 0).share_memory_()
            self.localMLP_first_kf = torch.cat([self.localMLP_first_kf, kf_Id_tensor], 0).share_memory_()

            localMLP_max_len_new = torch.tensor(self.config["mapping"]["localMLP_max_len"])[None, ...]
            self.localMLP_max_len = torch.cat([self.localMLP_max_len, localMLP_max_len_new], 0).share_memory_()

            localMLP_adjacent_new_l = torch.zeros_like(self.localMLP_adjacent[-1].unsqueeze(0))
            self.localMLP_adjacent = torch.cat([self.localMLP_adjacent, localMLP_adjacent_new_l], 0)
            localMLP_adjacent_new_c = torch.zeros_like(self.localMLP_adjacent[:, -1].unsqueeze(0))
            self.localMLP_adjacent = torch.cat([self.localMLP_adjacent, localMLP_adjacent_new_c], -1).share_memory_()
        return new_localMLP_Id


    # @brief: insert the frame_Id (of a keyframe) to list
    def attach_ids(self, frame_ids):
        if self.frame_ids is None:
            self.frame_ids = torch.ones((1, )) * frame_ids
        else:
            frame_ids = torch.ones((1, )) * frame_ids
            self.frame_ids = torch.cat([self.frame_ids, frame_ids], dim=0)


    # @brief: Add keyframe rays to the keyframe database: (1) add frame_Id of keyframe; (2) store the rays
    def add_keyframe(self, batch):
        """Store one keyframe: record its frame_Id and its downsampled rays.

        Args:
            batch: dict providing 'direction', 'rgb', 'depth' (per pixel) and 'frame_id'.
        """
        # per-pixel [direction, rgb, depth], flattened to Tensor(1, H * W, 7)
        packed = torch.cat([batch['direction'], batch['rgb'], batch['depth'][..., None]], dim=-1)
        packed = packed.reshape(1, -1, packed.shape[-1])
        kept = self.sample_single_keyframe_rays(packed)  # Tensor(1, num_rays_to_save, 7)
        self.attach_ids(batch['frame_id'])
        # the new keyframe occupies the slot right after the previously stored ones
        slot = len(self.frame_ids) - 1
        self.rays[slot] = kept


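    # @brief: mark overlapping keyframes (keyframe_ref == -2) that are bound to the active localMLP in self.keyframe_mutex_mask;
    # @param active_localMLP_Id: Tensor(, );
    # @param keyframe_ref: Tensor(N_kf, );
    # @param kf_num: number of keyframes to process, int;
    #-@return: None; self.keyframe_mutex_mask[0, :kf_num] is overwritten with -1 for marked keyframes and 0 otherwise.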
    def update_mutex_mask(self, active_localMLP_Id, keyframe_ref, kf_num):
        mask1 = torch.where(keyframe_ref[:kf_num] == -2, torch.ones_like(keyframe_ref[:kf_num]), torch.zeros_like(keyframe_ref[:kf_num])).to(torch.bool)  # Tensor(kf_num, )

        mask2_1 = self.keyframe_localMLP[:kf_num, 0] == active_localMLP_Id
        mask2_2 = self.keyframe_localMLP[:kf_num, 1] == active_localMLP_Id
        mask2 = torch.logical_or(mask2_1, mask2_2)

        final_mask = -1 * torch.logical_and(mask1, mask2).to(torch.int64)  # Let ActiveMap process optimize first
        self.keyframe_mutex_mask[0][:kf_num] = final_mask


    # @brief: get the corresponding localMLP_Id of each keyframe
    def get_kf_localMLP_Id(self):
        condition1 = torch.where(self.keyframe_localMLP[:, 0] < 0, torch.ones_like(self.keyframe_localMLP[:, 0]), torch.zeros_like(self.keyframe_localMLP[:, 0]))
        condition2 = torch.where(self.keyframe_localMLP[:, 1] < 0, torch.ones_like(self.keyframe_localMLP[:, 1]), torch.zeros_like(self.keyframe_localMLP[:, 1]))
        condition_flag = condition1 + condition2  # Tensor(N_kf, ), dtype=tf.int32
        selected_localMLP_Id = torch.where(condition_flag == 0, self.keyframe_localMLP[:, 1], self.keyframe_localMLP[:, 0])  # Tensor(N_kf, ), dtype=tf.int32
        selected_localMLP_Id = torch.where(selected_localMLP_Id >= 0, selected_localMLP_Id, torch.zeros_like(selected_localMLP_Id))  # Tensor(N_kf, ), dtype=tf.int32
        return selected_localMLP_Id


    # @brief: extract a given localMLP's first keyframe pose in World Coordinate System;
    # @param kf_localMLP_Ids: corresponding localMLP_Id of each keyframe, Tensor(n, )/Tensor(, );
    # @param kf_poses: tensor storing world poses of all first keyframes and overlapping keyframes, Tensor(n, 4, 4);
    #-@return first_kf_pose: first keyframes' poses of given localMLPs, Tensor(n, 4, 4)/Tensor(4, 4);
    #-@return first_kf_Ids: first keyframes' kf_Id of given localMLPs, Tensor(n, )/Tensor(, ).
    def extract_first_kf_pose(self, kf_localMLP_Ids, kf_poses):
        first_kf_Ids = self.localMLP_first_kf[kf_localMLP_Ids]  # first keyframe's kf_Id of each localMLP, Tensor(n, )/Tensor(, )
        first_kf_pose = kf_poses[first_kf_Ids]  # first keyframe's pose(c2w, in World Coordinate System) of each localMLP, Tensor(n, 4, 4)/Tensor(4, 4)
        return first_kf_pose, first_kf_Ids


    # @brief: get poses in World Coordinate System of selected keyframes;
    # @param kf_Ids: Tensor(n, );
    # @param localMLP_idx: the related localMLP index that given local pose belongs to, Tensor(n, 1);
    # @param keyframe_ref: Tensor(n, );
    # @param est_c2w_data: local pose of each given keyframe, Tensor(n, 4, 4);
    # @param first_kf_pose: each localMLP's first kf pose(in World Coordinate System), Tensor(localMLP_num, 4, 4);
    #-@return poses_world: Tensor(n, 4, 4).
    def extract_kf_world_poses(self, kf_Ids, localMLP_idx, keyframe_ref, est_c2w_data, first_kf_pose):
        kf_frame_Ids = kf_Ids * self.config["mapping"]["keyframe_every"]  # Tensor(n, )
        ref_localMLP_Id = torch.gather(self.keyframe_localMLP[kf_Ids], 1, localMLP_idx).squeeze(-1)  # corresponding localMLP_Id of each selected keyframe, Tensor(n, )
        first_poses = first_kf_pose[ref_localMLP_Id]  # corresponding localMLP's pose of each selected keyframe, Tensor(n, 4, 4)

        poses_local = est_c2w_data[kf_frame_Ids]
        keyframe_ref = keyframe_ref[..., None, None].to(self.device)  # Tensor(n, 1, 1)
        poses_world = torch.where(keyframe_ref == -1, first_poses, first_poses @ poses_local)
        return poses_world


    # @brief: find keyframes which are bound to more than 1 localMLPs
    def find_ovlp_kf_Ids(self, kf_num=None):
        if kf_num is None:
            kf_num = self.collected_kf_num[0].clone()
        keyframe_localMLP = self.keyframe_localMLP[:kf_num, :].clone()
        condition1 = torch.where( keyframe_localMLP[:, 0] >= 0, torch.ones_like(keyframe_localMLP[:, 0]), torch.zeros_like(keyframe_localMLP[:, 0]) )
        condition2 = torch.where(keyframe_localMLP[:, 1] >= 0, torch.ones_like(keyframe_localMLP[:, 0]), torch.zeros_like(keyframe_localMLP[:, 0]))
        ovlp_kf_Ids = torch.where(condition1 * condition2 > 0)[0]
        return ovlp_kf_Ids


    # @brief: given a keyframe center and other related keyframe_Ids, compute the distance between this center and each related keyframe's surface center (no sorting is performed);
    # @param kf_center: Tensor(3, );
    # @param related_kf_Ids: keyframe_Ids of related keyframes, Tensor(n, );
    # @param related_kf_pose: world poses of related keyframes, Tensor(n, 4, 4);
    #-@return: distances, Tensor(n, ).
    def sort_center_dist_kf(self, kf_center, related_kf_Ids, related_kf_pose):
        """Distance from kf_center to the surface center of each related keyframe (no sorting).

        Args:
            kf_center: Tensor(3, ).
            related_kf_Ids: keyframe_Ids of related keyframes, Tensor(n, ).
            related_kf_pose: poses of related keyframes, Tensor(n, 4, 4); applied as
                R @ p + t to camera-space points, i.e. presumably camera-to-world.
        Returns:
            Tensor(n, ).
        """
        # surface center of each related keyframe in its own camera frame, Tensor(n, 3)
        cam_centers = compute_surface_center(self.rays[related_kf_Ids])
        rot = related_kf_pose[:, :3, :3]
        trans = related_kf_pose[:, :3, 3]
        # batched R @ p written as a broadcast multiply + row sum
        world_centers = (rot * cam_centers[:, None, :]).sum(-1) + trans
        return torch.norm(world_centers - kf_center[None, ...], dim=-1)


    # @brief: Sample rays from self.rays as well as frame_ids
    # @param bs: number of rays to sample, int;
    #-@return sample_rays: sampled rays, Tensor(bs, 7);
    #-@return kf_ids: corresponding keyframe_Id of each ray, Tensor(bs, ).
    def sample_global_rays(self, bs):
        num_kf = self.get_length()  # collected keyframe num so far, int
        idxs = torch.tensor( random.sample(range(num_kf * self.num_rays_to_save), bs) )  # ray_Id of sampled rays, Tensor(bs, ), device=cpu
        sample_rays = self.rays[:num_kf].reshape(-1, 7)[idxs]  # Tensor(bs, 7), device=cpu

        kf_ids = torch.div(idxs, self.num_rays_to_save, rounding_mode="floor")  # corresponding keyframe_Id of each ray, Tensor(bs, )
        frame_ids = self.frame_ids[kf_ids]

        return sample_rays, kf_ids


    # @brief: Sample rays from given keyframes;
    # @param kf_Ids: keyframe_Ids sampled from, Tensor(n, );
    # @param bs: pixel num to sample, int;
    #-@return sample_rays: sampled rays, Tensor(bs, 7);
    #-@return kf_indices: corresponding keyframe indices of each ray, Tensor(bs, ).
    def sample_rays_from_given(self, kf_Ids, bs):
        num_kf = kf_Ids.shape[0]
        idxs = torch.tensor( random.sample( range(num_kf * self.num_rays_to_save), bs ) )  # ray_Id of sampled rays, Tensor(bs, ), device=cpu
        sample_rays = self.rays[kf_Ids].reshape(-1, 7)[idxs]  # Tensor(bs, 7), device=cpu

        kf_indices = torch.div(idxs, self.num_rays_to_save, rounding_mode="floor")  # corresponding keyframe indices of each ray, Tensor(bs, )
        return sample_rays, kf_indices


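    # @brief: for each of the first num_kf keyframes, count how many of its localMLP slots equal localMLP_Id;
    #-@return: Tensor(num_kf, ), 0/1.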
    def get_related_keyframes(self, localMLP_Id, num_kf):
        keyframe_localMLP = self.keyframe_localMLP[:num_kf, :]
        related_mat = torch.where(keyframe_localMLP == localMLP_Id, torch.ones_like(keyframe_localMLP), torch.zeros_like(keyframe_localMLP))
        keyframes_mask = torch.sum(related_mat, dim=-1)  # Tensor(num_kf, )
        return keyframes_mask


    def get_related_keyframes2(self, localMLP_Id, num_kf, localMLP_Id_exclude):
        keyframe_localMLP = self.keyframe_localMLP[:num_kf, :]

        # Step 1: filter out all keyframes that relates to localMLP_Id
        related_mat1 = torch.where(keyframe_localMLP == localMLP_Id, torch.ones_like(keyframe_localMLP), torch.zeros_like(keyframe_localMLP))
        keyframes_mask1 = torch.sum(related_mat1, dim=-1)  # Tensor(num_kf, )

        # Step 2: filter out all keyframes that relates to localMLP_Id_exclude
        related_mat2 = torch.where(keyframe_localMLP == localMLP_Id_exclude, torch.ones_like(keyframe_localMLP), torch.zeros_like(keyframe_localMLP))
        keyframes_mask2 = torch.sum(related_mat2, dim=-1)  # Tensor(num_kf, )

        keyframes_mask = torch.logical_and( keyframes_mask1.to(torch.bool), torch.logical_not(keyframes_mask2) )
        return keyframes_mask


    # @brief: given a localMLP_Id and the keyframe-localMLP relationships of some keyframes, determine for each keyframe whether the given localMLP is its first or second related localMLP;
    # @param keyframe_localMLP: Tensor(n, 2);
    # @param localMLP_Id: Tensor(, );
    #-@return: Tensor(n, ), -1/0/1.
    def get_related_localMLP_index(self, keyframe_localMLP, localMLP_Id):
        col1_mask = torch.where(keyframe_localMLP[:, 0] == localMLP_Id, torch.ones_like(keyframe_localMLP[:, 0]), torch.zeros_like(keyframe_localMLP[:, 0]))
        col2_mask = torch.where(keyframe_localMLP[:, 1] == localMLP_Id, 2*torch.ones_like(keyframe_localMLP[:, 1]), torch.zeros_like(keyframe_localMLP[:, 1]))
        idx = torch.stack([col1_mask, col2_mask], dim=-1)  # Tensor(n, 2)
        hit_idx = torch.max(idx, -1)[0] - 1  # 0/1: given localMLP is this keyframe's first/second related localMLP; -1: given localMLP is not related to this keyframe
        return hit_idx


    # @brief: convert selected keyframes' local poses (maybe in their first or second related localMLP's CS) to local poses in the given localMLP's Local Coordinate System;
    # @param keyframe_localMLP: keyframe-localMLP relationships of selected keyframes, Tensor(n, 2);
    # @param hit_idx: given localMLP is this keyframe's first/second related localMLP(0/1), Tensor(n, );
    # @param kf_poses: keyframe poses passed to extract_first_kf_pose (presumably world poses indexed by kf_Id), Tensor(n_kf, 4, 4);
    # @param given_first_kf_pose: first keyframe pose(in World Coordinate System) of given localMLP, Tensor(4, 4)
    # @param poses_local: Tensor(n, 4, 4);
    #-@return: selected keyframes' local poses in given localMLP's Local Coordinate System, Tensor(n, 4, 4);
    def convert_given_local_pose(self, keyframe_localMLP, hit_idx, kf_poses, given_first_kf_pose, poses_local):
        """Re-express keyframe local poses in the given localMLP's coordinate system.

        Rows with hit_idx == 0 are returned unchanged. All other rows are mapped as
        inv(given_first_kf_pose) @ anchor @ poses_local. Here anchor is the first-keyframe pose of
        each keyframe's slot-0 localMLP, from the module-level extract_first_kf_pose helper
        (not the method of the same name).

        Returns:
            Tensor(n, 4, 4).
        """
        hit_idx = hit_idx.to(self.device)
        anchor_world = extract_first_kf_pose(keyframe_localMLP[:, 0], self.localMLP_first_kf, kf_poses)  # Tensor(n, 4, 4)
        world_to_given = given_first_kf_pose.inverse().unsqueeze(0)  # Tensor(1, 4, 4)
        reexpressed = world_to_given @ anchor_world @ poses_local
        keep_as_is = hit_idx[..., None, None] == 0
        return torch.where(keep_as_is, poses_local, reexpressed)


    # @brief: get world pose of given keyframes;
    # @param keyframe_Ids: Tensor(n, );
    # @param keyframe_ref: Tensor(n, );
    # @param kf_poses: Tensor(n_kf, 4, 4);
    # @param poses_local: Tensor(n, 4, 4);
    #-@return: world pose of asking keyframes, Tensor(n, 4, 4).
    def convert_given_world_pose(self, keyframe_Ids, keyframe_ref, kf_poses, poses_local):
        """World poses of the given keyframes.

        Keyframes with keyframe_ref == -1 take their pose straight from kf_poses. Every other
        keyframe composes its local pose with the first-keyframe pose of its slot-0 localMLP.
        That pose comes from the module-level extract_first_kf_pose helper.

        Returns:
            Tensor(n, 4, 4).
        """
        slot0_localMLP = self.keyframe_localMLP[keyframe_Ids][:, 0]
        anchor_world = extract_first_kf_pose(slot0_localMLP, self.localMLP_first_kf, kf_poses)  # Tensor(n, 4, 4)
        chained = anchor_world @ poses_local
        stored = kf_poses[keyframe_Ids]  # Tensor(n, 4, 4)
        use_stored = keyframe_ref[..., None, None] == -1
        return torch.where(use_stored, stored, chained)


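    # @brief: mask of the first num_kf keyframes bound to localMLP_Id, excluding overlapping keyframes whose overlap_kf_flag equals process_flag;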
    # @param localMLP_Id: Tensor(, );
    # @param num_kf: number of keyframes collected so far, int;
    # @param overlap_kf_flag: Tensor(num_kf, );
    # @param process_flag: Process_flag of invoking process (1: ActiveMap process, -1: InactiveMap process), int;
    #-@return: Tensor(num_kf, ).
    def get_related_keyframes_exclude(self, localMLP_Id, num_kf, overlap_kf_flag, process_flag):
        # Step 1: find all related keyframes
        keyframe_localMLP = self.keyframe_localMLP[:num_kf, :]
        related_mat = torch.where(keyframe_localMLP == localMLP_Id, torch.ones_like(keyframe_localMLP), torch.zeros_like(keyframe_localMLP))
        keyframes_mask = torch.sum(related_mat, dim=-1)  # Tensor(num_kf, ), 0/1

        # Step 2: excluded related overlapping keyframes, which are optimized by given process last time
        if torch.count_nonzero(overlap_kf_flag[:num_kf]) > 0:
            condition = (overlap_kf_flag[:num_kf] == process_flag)
            overlap_mask = torch.where(condition, torch.zeros_like(keyframes_mask), torch.ones_like(keyframes_mask))
            keyframes_mask = overlap_mask * keyframes_mask

        return keyframes_mask


    # @brief: sample rays in a submap globally (which will sample pixels from first and last keyframes individually);
    # @param first_kf_Id: keyframe_Id of this localMLP's first keyframe, Tensor(, );
    # @param related_kf_ids: keyframe_Ids of this localMLP's related keyframes, Tensor(n', );
    # @param pix_num: total number of rays to sample, int;
    # -@return sampled_rays: sampled rays from related keyframes, Tensor(pix_num, 7);
    # -@return kf_ids: corresponding keyframe_Id of each sampled ray, Tensor(pix_num, ).
    def sample_rays_in_submap(self, first_kf_Id, related_kf_ids, pix_num):
        # Step 1: find all keyframes related to given localMLP
        related_kf_num = related_kf_ids.shape[0]

        # Step 2: sample rays from first keyframe and other related keyframes respectively
        # 2.1: sampling from first keyframe
        pix_num_first = max(pix_num // related_kf_num,  pix_num // 10)
        idx_first = torch.tensor(random.sample(range(self.num_rays_to_save), pix_num_first))  # ray_Id of sampled rays, Tensor(pix_num_first, )
        first_rays = self.rays[first_kf_Id].reshape((-1, 7))[idx_first]  # Tensor(pix_num_first, 7)
        first_kf_indices = torch.zeros_like(idx_first)  # corresponding keyframe infices of each ray sampled from first keyframe, Tensor(pix_num_first, )
        first_kf_ids = torch.ones_like(idx_first) * first_kf_Id  # corresponding keyframe infices of each ray sampled from first keyframe, Tensor(pix_num_first, )

        if related_kf_num > 1:
            if related_kf_num > 2:
                # 2.2: sampling from latest keyframe
                last_kf_Id = related_kf_ids[-1]
                pix_num_last = max(pix_num // related_kf_num,  pix_num // 5)
                idx_last = torch.tensor(random.sample(range(self.num_rays_to_save), pix_num_last))  # ray_Id of sampled rays, Tensor(pix_num_last, )
                last_rays = self.rays[last_kf_Id].reshape((-1, 7))[idx_last]  # Tensor(pix_num_last, 7)
                last_kf_indices = torch.ones_like(idx_last) * (related_kf_num - 1)  # corresponding keyframe infices of each ray sampled from last keyframe, Tensor(pix_num_last, )
                last_kf_ids = torch.ones_like(idx_last) * last_kf_Id  # corresponding keyframe infices of each ray sampled from first keyframe, Tensor(pix_num_first, )

                other_kf_ids = related_kf_ids[1:-1]
                pix_num_other = pix_num - pix_num_first - pix_num_last
                other_kf_num = related_kf_num - 2
            else:
                other_kf_ids = related_kf_ids[1:]
                pix_num_other = pix_num - pix_num_first
                other_kf_num = related_kf_num - 1

            # 2.3: sampling from other related keyframes (except first keyframe)
            idx_other = torch.tensor(random.sample(range(other_kf_num * self.num_rays_to_save), pix_num_other))  # ray_Id of sampled rays, Tensor(pix_num_other, )
            other_rays = self.rays[other_kf_ids].reshape((-1, 7))[idx_other]  # Tensor(pix_num_other, 7)

            other_kf_indices = torch.div(idx_other, self.num_rays_to_save, rounding_mode="floor")  # corresponding keyframe infices of each ray sampled from other keyframes, Tensor(pix_num_other, )
            other_kf_ids = other_kf_ids[other_kf_indices]  # corresponding keyframe_Id of each ray, Tensor(pix_num_other, )
            other_kf_indices = other_kf_indices + 1

            if related_kf_num > 2:
                sampled_rays = torch.cat([first_rays, other_rays, last_rays], dim=0)
                kf_indices = torch.cat([first_kf_indices, other_kf_indices, last_kf_indices], dim=0)
                kf_ids = torch.cat([first_kf_ids, other_kf_ids, last_kf_ids], dim=0)
            else:
                sampled_rays = torch.cat([first_rays, other_rays], dim=0)
                kf_indices = torch.cat([first_kf_indices, other_kf_indices], dim=0)
                kf_ids = torch.cat([first_kf_ids, other_kf_ids], dim=0)
        else:
            sampled_rays = first_rays
            kf_indices = first_kf_indices
            kf_ids = first_kf_ids
        return sampled_rays, kf_ids, kf_indices


    # @brief: sample rays in given keyframes;
    # @param given_kf_ids: keyframe_Ids of given related keyframes, Tensor(n', );
    # @param pix_num: total pixel number to sample, int;
    #-@return sampled_rays: sampled rays from related keyframes, Tensor(pix_num, 7);
    #-@return kf_ids: corresponding keyframe_Id of each sampled ray, Tensor(pix_num, ).
    def sample_rays_in_given_kf(self, given_kf_ids, pix_num):
        # Step 1: find all keyframes related to given localMLP
        given_kf_num = given_kf_ids.shape[0]

        # Step 2: sample rays from given keyframes
        idx = torch.tensor(random.sample(range(given_kf_num * self.num_rays_to_save), pix_num))  # ray_Id of sampled rays, Tensor(pix_num_other, )
        sampled_rays = self.rays[given_kf_ids].reshape((-1, 7))[idx]  # Tensor(pix_num_other, 7)

        kf_indices = torch.div(idx, self.num_rays_to_save, rounding_mode="floor")  # corresponding keyframe infices of each ray sampled from other keyframes, Tensor(pix_num_other, )
        kf_ids = given_kf_ids[kf_indices]  # corresponding keyframe_Id of each ray, Tensor(pix_num_other, )

        return sampled_rays, kf_ids, kf_indices


    # @brief: extract related vars of a given localMLP_Id;
    # @param localMLP_Id: given localMLP_Id, Tensor(, );
    # @param kf_poses: tensor storing all first keyframes' world poses, Tensor(num_kf, 4, 4);
    # @param est_c2w_data: tensor storing all keyframes' local poses(in its first related localMLP), Tensor(num_frame, 4, 4);
    # @param kf_ref: keyframe ref type of each keyframe, Tensor(num_kf, );
    # @param process_flag: Process_flag of invoking process (1: ActiveMap process, -1: InactiveMap process), int;
    #-@return first_kf_pose: first keyframe's pose in World Coordinate System of given localMLP, Tensor(4, 4);
    #-@return first_kf_Id: first keyframe's kf_Id of given localMLP, Tensor(, );
    #-@return poses_local: local pose (in given localMLP's coordinate system) of related keyframes, Tensor(selected_num_kf, 4, 4);
    #-@return avail_kf_Ids: keyframe_Ids of all available keyframes related to given localMLP, Tensor(selected_num_kf, );
    #-@return avail_kf_frame_Ids: frame_Ids of all available keyframes related to given localMLP, Tensor(selected_num_kf, );
    #-@return avail_kf_ref: Tensor(n', );
    #-@return avail_ovlp_kf_idx: indices of available overlapping keyframes in avail_kf_Ids, Tensor(k', )
    #-@return avail_ovlp_kf_Ids: keyframe_Ids of available overlapping keyframes, Tensor(k', ).
    def extract_localMLP_vars(self, localMLP_Id, kf_poses, est_c2w_data, kf_ref, process_flag):
        """Gather the keyframes (and their poses) that the calling process may use for one localMLP.

        Args:
            localMLP_Id: given localMLP_Id, Tensor(, ).
            kf_poses: tensor storing all first keyframes' world poses, indexed by kf_Id, Tensor(num_kf, 4, 4).
            est_c2w_data: local pose of every frame (in its first related localMLP), indexed by
                frame_Id, Tensor(num_frame, 4, 4).
            kf_ref: keyframe ref type of each keyframe, Tensor(num_kf, ); below, -1 is treated as
                "first keyframe of some localMLP" and -2 as "overlapping keyframe".
            process_flag: 1 for the ActiveMap process, -1 for the InactiveMap process.

        Returns:
            first_kf_pose: world pose of the given localMLP's first keyframe, Tensor(4, 4).
            first_kf_Id: kf_Id of that first keyframe, Tensor(, ).
            poses_local: poses of the available keyframes in the given localMLP's coordinate
                system, Tensor(n', 4, 4).
            avail_kf_Ids: kf_Ids of available related keyframes (ascending), Tensor(n', ).
            avail_kf_frame_Ids: their frame_Ids (kf_Id * keyframe_every), Tensor(n', ).
            avail_kf_ref: their keyframe_ref values, Tensor(n', ).
            avail_ovlp_kf_idx: positions of overlapping keyframes within avail_kf_Ids, Tensor(k', ).
            avail_ovlp_kf_Ids: kf_Ids of those overlapping keyframes, Tensor(k', ).
        """
        num_kf = self.collected_kf_num[0].clone()  # collected keyframe num so far, Tensor(, )
        # Step 1: get overlapping keyframes mutex mask (only those overlapping keyframes which bind to currrent active localMLP and another inactive localMLP have non-zero values)
        # Keyframes whose mutex value equals this process's flag are excluded below (update_mutex_mask writes -1,
        # so with process_flag == -1 the InactiveMap process skips them and leaves them to ActiveMap).
        ovlp_mutex = self.keyframe_mutex_mask.clone()[0, :num_kf]  # Tensor(num_kf, ), 0/1/-1
        ovlp_mutex_mask = torch.where(ovlp_mutex==process_flag, torch.zeros_like(ovlp_mutex), torch.ones_like(ovlp_mutex))  # Tensor(num_kf, ), 0/1

        # Step 2: find first keyframe of given localMLP (world pose and keyframe_Id)
        first_kf_pose, first_kf_Id = self.extract_first_kf_pose(localMLP_Id, kf_poses)  # first keyframe's pose in World Coordinate System / kf_Id of given localMLP, Tensor(4, 4)/Tensor(, )
        first_kf_pose = first_kf_pose.detach()

        # Step 3: find all available keyframes (1.it must be related keyframe; 2.for overlapping keyframe, its last optimization must be done in another process)
        related_kf_mask = self.get_related_keyframes(localMLP_Id, num_kf)  # Tensor(num_kf, ), 0/1
        kf_mask = related_kf_mask * ovlp_mutex_mask  # mask of all available keyframes, Tensor(num_kf, ) 0/1

        avail_kf_Ids = torch.where(kf_mask > 0)[0]  # keyframe_Ids of all available keyframes (ascending), Tensor(n', )
        avail_kf_ref = kf_ref[avail_kf_Ids]  # keyframe_ref type of all available keyframes, Tensor(n', ), n(>=0)/-1/-2
        avail_kf_frame_Ids = avail_kf_Ids * self.config["mapping"]["keyframe_every"]  # frame_Ids of all available keyframes, Tensor(n', )

        avail_ovlp_kf_idx = torch.where(avail_kf_ref == -2)[0]  # overlapping keyframes' indices in avail_kf_Ids, Tensor(k', )
        avail_ovlp_kf_Ids = avail_kf_Ids[avail_ovlp_kf_idx]  # keyframe_Ids of all available overlapping keyframes, Tensor(k', )

        # Step 4: extract local pose of all available keyframes in localMLP_Id's Local Coordinate System
        # 4.1: local pose of all ordinary keyframes (whose keyframe_ref >= 0)
        first_pose_local = torch.eye(4).to(self.device)
        poses_local = est_c2w_data[avail_kf_frame_Ids]  # local pose of all related keyframes (indexed by frame_Id); advanced indexing yields a copy
        # NOTE(review): assumes avail_kf_Ids[0] == first_kf_Id (the localMLP's first keyframe has the
        # smallest kf_Id among the available ones and was not masked out); raises IndexError if
        # avail_kf_Ids is empty -- confirm callers guarantee this.
        poses_local[0] = first_pose_local

        # 4.2: for available keyframes which are first keyframe of another localMLP: firstly extract their world poses, and then convert them to local poses
        ano_first_kf_idx = torch.where( torch.logical_and(avail_kf_Ids != first_kf_Id, avail_kf_ref == -1) )[0]
        if ano_first_kf_idx.shape[0] > 0:
            ano_first_kf_Ids = avail_kf_Ids[ano_first_kf_idx]
            ano_first_kf_poses_world = kf_poses[ano_first_kf_Ids]  # Tensor(m', 4, 4)
            # world -> given localMLP's local frame: inv(first_kf_pose) @ world pose
            ano_first_kf_poses_local = first_kf_pose.inverse().unsqueeze(0) @ ano_first_kf_poses_world
            poses_local[ano_first_kf_idx] = ano_first_kf_poses_local

        # 4.3: for available keyframes which are overlapping keyframes (runs after 4.2, so it overrides any row both steps touch)
        if avail_ovlp_kf_idx.shape[0] > 0:
            ovlp_pose_local = poses_local[avail_ovlp_kf_idx]  # Tensor(k', 4, 4)
            keyframe_localMLP = self.keyframe_localMLP[avail_ovlp_kf_Ids]  # Tensor(k', 2)
            localMLP_hit_dix = self.get_related_localMLP_index(keyframe_localMLP, localMLP_Id)  # 0/1/-1 slot of localMLP_Id, Tensor(k', )
            ovlp_pose_local_given = self.convert_given_local_pose(keyframe_localMLP, localMLP_hit_dix, kf_poses, first_kf_pose, ovlp_pose_local)  # Tensor(k', 4, 4)
            poses_local[avail_ovlp_kf_idx] = ovlp_pose_local_given

        return first_kf_pose, first_kf_Id, poses_local, avail_kf_Ids, avail_kf_frame_Ids, avail_kf_ref, avail_ovlp_kf_idx, avail_ovlp_kf_Ids


    # @brief: extract related vars of a given localMLP_Id, restricted to a caller-given set of keyframes;
    # @param localMLP_Id: given localMLP_Id, Tensor(, );
    # @param given_kf_Ids: keyframe_Ids whose local poses are wanted (any order; sorted ascending here), Tensor(n', );
    # @param kf_poses: tensor storing all first keyframes' world poses, Tensor(num_kf, 4, 4);
    # @param est_c2w_data: tensor storing all keyframes' local poses(in its first related localMLP), Tensor(num_frame, 4, 4);
    # @param kf_ref: keyframe ref type of each keyframe (n(>=0) / -1: first keyframe of a localMLP / -2: overlapping keyframe), Tensor(num_kf, );
    #-@return first_kf_pose: first keyframe's pose in World Coordinate System of given localMLP (detached), Tensor(4, 4);
    #-@return first_kf_Id: first keyframe's kf_Id of given localMLP, Tensor(, );
    #-@return poses_local: local pose (in given localMLP's coordinate system) of the given keyframes, Tensor(n', 4, 4);
    #-@return given_kf_Ids: the given keyframe_Ids, sorted ascending, Tensor(n', );
    #-@return given_kf_frame_Ids: frame_Ids of the given keyframes (kf_Id * keyframe_every), Tensor(n', );
    #-@return given_kf_ref: keyframe_ref type of the given keyframes, Tensor(n', );
    #-@return given_ovlp_kf_idx: indices of overlapping keyframes (ref == -2) in given_kf_Ids, Tensor(k', );
    #-@return given_ovlp_kf_Ids: keyframe_Ids of those overlapping keyframes, Tensor(k', ).
    def extract_localMLP_vars_given(self, localMLP_Id, given_kf_Ids, kf_poses, est_c2w_data, kf_ref):
        given_kf_Ids = torch.sort(given_kf_Ids)[0]

        # Step 1: find first keyframe of given localMLP (world pose and keyframe_Id)
        first_kf_pose, first_kf_Id = self.extract_first_kf_pose(localMLP_Id, kf_poses)
        first_kf_pose = first_kf_pose.detach()

        # Step 2: per-keyframe ref types / frame_Ids, and locate the overlapping keyframes (ref == -2)
        given_kf_ref = kf_ref[given_kf_Ids]  # Tensor(n', ), n(>=0)/-1/-2
        given_kf_frame_Ids = given_kf_Ids * self.config["mapping"]["keyframe_every"]  # Tensor(n', )

        given_ovlp_kf_idx = torch.where(given_kf_ref == -2)[0]  # indices in given_kf_Ids, Tensor(k', )
        given_ovlp_kf_Ids = given_kf_Ids[given_ovlp_kf_idx]  # keyframe_Ids, Tensor(k', )

        # Step 3: extract local pose of all given keyframes in given localMLP_Id's Local Coordinate System
        # 3.1: local pose of all ordinary keyframes (whose keyframe_ref >= 0), indexed by frame_Id.
        # Advanced indexing returns a copy, so the in-place writes below do not modify est_c2w_data.
        poses_local = est_c2w_data[given_kf_frame_Ids]  # Tensor(n', 4, 4)

        # The localMLP's first keyframe defines its coordinate system (see 3.2: local = first^-1 @ world),
        # so its local pose is identity. Match it by value instead of assuming it is at position 0:
        # smaller kf_Ids may precede it after sorting, and given_kf_Ids may be empty.
        # index_put requires matching dtype/device, hence building eye() like poses_local.
        is_first_kf = given_kf_Ids == first_kf_Id
        poses_local[is_first_kf] = torch.eye(4, dtype=poses_local.dtype, device=poses_local.device)

        # 3.2: given keyframes which are the first keyframe of ANOTHER localMLP: take their world poses,
        # then convert them into this localMLP's frame
        ano_first_kf_idx = torch.where(torch.logical_and(~is_first_kf, given_kf_ref == -1))[0]
        if ano_first_kf_idx.shape[0] > 0:
            ano_first_kf_Ids = given_kf_Ids[ano_first_kf_idx]
            ano_first_kf_poses_world = kf_poses[ano_first_kf_Ids]  # Tensor(m', 4, 4)
            poses_local[ano_first_kf_idx] = first_kf_pose.inverse().unsqueeze(0) @ ano_first_kf_poses_world

        # 3.3: given keyframes which are overlapping keyframes: re-express their stored local poses
        # in the given localMLP's coordinate system
        if given_ovlp_kf_idx.shape[0] > 0:
            ovlp_pose_local = poses_local[given_ovlp_kf_idx]  # Tensor(k', 4, 4)
            keyframe_localMLP = self.keyframe_localMLP[given_ovlp_kf_Ids]
            localMLP_hit_idx = self.get_related_localMLP_index(keyframe_localMLP, localMLP_Id)
            poses_local[given_ovlp_kf_idx] = self.convert_given_local_pose(keyframe_localMLP, localMLP_hit_idx, kf_poses, first_kf_pose, ovlp_pose_local)  # Tensor(k', 4, 4)

        return first_kf_pose, first_kf_Id, poses_local, given_kf_Ids, given_kf_frame_Ids, given_kf_ref, given_ovlp_kf_idx, given_ovlp_kf_Ids
back to top