// demo.cpp
#include "gms_matcher.h"
//#define USE_GPU
#ifdef USE_GPU
#include <opencv2/cudafeatures2d.hpp>
using cuda::GpuMat;
#endif
// Forward declarations: both functions are defined further down in this file
// but are called before their definitions.

// Matches ORB features between img1 and img2, filters the matches with GMS
// and displays the surviving matches in a window.
void GmsMatch(Mat &img1, Mat &img2);
// Builds a side-by-side image of src1 and src2 with the given inlier matches
// drawn on top; `type` selects the drawing style (1 or 2).
Mat DrawInlier(Mat &src1, Mat &src2, vector<KeyPoint> &kpt1, vector<KeyPoint> &kpt2, vector<DMatch> &inlier, int type);
// Loads the two demo images from ../data and runs the GMS matching demo on them.
// If either image cannot be read, an error is printed and nothing is matched.
void runImagePair() {
    Mat img1 = imread("../data/01.jpg");
    Mat img2 = imread("../data/02.jpg");

    // imread() signals failure by returning an empty matrix rather than throwing,
    // so the result has to be checked before the images are used.
    if (img1.empty() || img2.empty()) {
        std::cerr << "Failed to load ../data/01.jpg and/or ../data/02.jpg" << std::endl;
        return;
    }

    GmsMatch(img1, img2);
}
// Program entry point: optionally selects CUDA device 0 (when built with
// USE_GPU) and then runs the image-pair matching demo.
int main()
{
#ifdef USE_GPU
    // getCudaEnabledDeviceCount() returns 0 when OpenCV was built without CUDA
    // and -1 when the CUDA driver is missing or incompatible, so only a
    // strictly positive count means device 0 can actually be selected.
    const int deviceCount = cuda::getCudaEnabledDeviceCount();
    if (deviceCount > 0) {
        cuda::setDevice(0);
    } else {
        std::cerr << "No usable CUDA device found (device count = " << deviceCount << ")." << std::endl;
    }
#endif // USE_GPU
    runImagePair();
    return 0;
}
// Runs the full matching demo on one image pair:
//   1. detects up to 10000 ORB keypoints per image (FAST threshold 0),
//   2. brute-force matches the descriptors with Hamming distance
//      (on the GPU when USE_GPU is defined, otherwise on the CPU),
//   3. filters the raw matches with gms_matcher,
//   4. prints the inlier count and shows the inlier matches in a window,
//      blocking until a key is pressed.
// Returns early with a message on stderr when an input image is empty or when
// no descriptors could be computed for one of the images.
void GmsMatch(Mat &img1, Mat &img2) {
    if (img1.empty() || img2.empty()) {
        std::cerr << "GmsMatch: empty input image, nothing to match." << std::endl;
        return;
    }

    vector<KeyPoint> kp1, kp2;
    Mat d1, d2;
    vector<DMatch> matches_all, matches_gms;

    Ptr<ORB> orb = ORB::create(10000);
    orb->setFastThreshold(0);
    orb->detectAndCompute(img1, Mat(), kp1, d1);
    orb->detectAndCompute(img2, Mat(), kp2, d2);

    // Without descriptors on both sides there is nothing to match or filter.
    if (d1.empty() || d2.empty()) {
        std::cerr << "GmsMatch: no ORB descriptors found in at least one image." << std::endl;
        return;
    }

#ifdef USE_GPU
    GpuMat gd1(d1), gd2(d2);
    Ptr<cuda::DescriptorMatcher> matcher = cv::cuda::DescriptorMatcher::createBFMatcher(NORM_HAMMING);
    matcher->match(gd1, gd2, matches_all);
#else
    BFMatcher matcher(NORM_HAMMING);
    matcher.match(d1, d2, matches_all);
#endif

    // GMS filter
    std::vector<bool> vbInliers;
    gms_matcher gms(kp1, img1.size(), kp2, img2.size(), matches_all);
    int num_inliers = gms.GetInlierMask(vbInliers, false, false);
    cout << "Get total " << num_inliers << " matches." << endl;

    // collect matches: keep matches_all[i] for every index flagged in the mask
    if (num_inliers > 0) {
        matches_gms.reserve(static_cast<size_t>(num_inliers));
    }
    for (size_t i = 0; i < vbInliers.size(); ++i) {
        if (vbInliers[i]) {
            matches_gms.push_back(matches_all[i]);
        }
    }

    // draw matching
    Mat show = DrawInlier(img1, img2, kp1, kp2, matches_gms, 1);
    imshow("show", show);
    waitKey();
}
// Composes src1 (left) and src2 (right) onto one black 8-bit 3-channel canvas
// and draws the given matches on it.
//   src1, src2 : images to place side by side; the canvas is as tall as the
//                taller one and as wide as both together.
//   kpt1, kpt2 : keypoints indexed by DMatch::queryIdx / DMatch::trainIdx.
//   inlier     : matches to draw.
//   type       : 1 -> one line per match in Scalar(0, 255, 255);
//                2 -> one line per match in Scalar(255, 0, 0), then small
//                     circles on both endpoints, drawn in a second pass so
//                     they sit on top of all lines;
//                any other value -> only the side-by-side images.
// Returns the composed canvas.
Mat DrawInlier(Mat &src1, Mat &src2, vector<KeyPoint> &kpt1, vector<KeyPoint> &kpt2, vector<DMatch> &inlier, int type) {
    const int canvasRows = max(src1.rows, src2.rows);
    const int canvasCols = src1.cols + src2.cols;

    Mat output(canvasRows, canvasCols, CV_8UC3, Scalar(0, 0, 0));
    src1.copyTo(output(Rect(0, 0, src1.cols, src1.rows)));
    src2.copyTo(output(Rect(src1.cols, 0, src2.cols, src2.rows)));

    // Keypoints of the second image are shifted right by the width of the first.
    const Point2f shift((float)src1.cols, 0.f);

    switch (type) {
    case 1:
        for (vector<DMatch>::const_iterator m = inlier.begin(); m != inlier.end(); ++m) {
            const Point2f from = kpt1[m->queryIdx].pt;
            const Point2f to = kpt2[m->trainIdx].pt + shift;
            line(output, from, to, Scalar(0, 255, 255));
        }
        break;

    case 2:
        // First pass: connecting lines.
        for (vector<DMatch>::const_iterator m = inlier.begin(); m != inlier.end(); ++m) {
            const Point2f from = kpt1[m->queryIdx].pt;
            const Point2f to = kpt2[m->trainIdx].pt + shift;
            line(output, from, to, Scalar(255, 0, 0));
        }
        // Second pass: endpoint markers, drawn after every line.
        for (vector<DMatch>::const_iterator m = inlier.begin(); m != inlier.end(); ++m) {
            const Point2f from = kpt1[m->queryIdx].pt;
            const Point2f to = kpt2[m->trainIdx].pt + shift;
            circle(output, from, 1, Scalar(0, 255, 255), 2);
            circle(output, to, 1, Scalar(0, 255, 0), 2);
        }
        break;

    default:
        break;
    }

    return output;
}