https://gitlab.cs.duke.edu/bartesaghilab/smartscopeAI
Tip revision: 43b29ae8c333a94463e0a4d9ecb97a5d5b6adf92 authored by Alberto Bartesaghi on 03 August 2022, 01:01:21 UTC
Update README.md
Update README.md
Tip revision: 43b29ae
detect_squares.py
from detectron2.structures import BoxMode
import argparse
from detectron2.data.datasets import register_coco_instances
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
# import some common libraries
import numpy as np
import os, json, cv2, random
from detectron2.engine import DefaultTrainer
# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
import detectron2.data.transforms as T
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.data import MetadataCatalog, DatasetCatalog
import glob
import argparse
# from sklearn.preprocessing import StandardScaler
import time
# from sklearn.decomposition import PCA
from pathlib import Path
import glob
import cv2
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from PIL import Image
from numpy import random
from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box, to_shape, check_if_square
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized
import mrcfile
import torchvision
from utils.common_config import get_train_transformations, get_val_transformations,\
get_train_dataset, get_train_dataloader,\
get_val_dataset, get_val_dataloader,\
get_optimizer, get_model, get_criterion,\
adjust_learning_rate
from models.resnet_squares import resnet18
from models.models import ClusteringModel
from models.models import ContrastiveModel
from sklearn.cluster import KMeans
class GrayscalePredictor:
    """
    Single-image predictor for single-channel (grayscale) input.

    Builds the model described by ``cfg``, puts it in eval mode, loads
    ``cfg.MODEL.WEIGHTS`` into it and, when called, runs the model on one image.
    """

    def __init__(self, cfg):
        """
        Parameters
        --------------
        cfg: detectron2 config node. Read here: MODEL.WEIGHTS, DATASETS.TEST,
             INPUT.MIN_SIZE_TEST, INPUT.MAX_SIZE_TEST and INPUT.FORMAT
        """
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.model = build_model(self.cfg)
        self.model.eval()
        # NOTE: self.metadata only exists when a test dataset is configured.
        if len(cfg.DATASETS.TEST):
            self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])

        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )
        self.input_format = cfg.INPUT.FORMAT

    def __call__(self, original_image):
        """
        Run the model on one image and return its predictions.

        Parameters
        --------------
        original_image: 2-D array (H * W). The channel axis is added here after
                        resizing, so an array that already has a channel axis
                        will not survive the transpose below.
        Return
        --------------
        the first (only) element of the model output for this image
        """
        with torch.no_grad():
            # Take the output size from the *input* image, before resizing.
            # detectron2 uses the "height"/"width" entries as the resolution the
            # predictions are mapped back to, so this keeps boxes in the
            # coordinate frame of the image the caller passed in.
            height, width = original_image.shape[:2]
            image = self.aug.get_transform(original_image).apply_image(original_image)
            # (H, W) -> (H, W, 1) -> (1, H, W)
            image = np.expand_dims(image, -1)
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

            inputs = {"image": image, "height": height, "width": width}
            predictions = self.model([inputs])[0]
            return predictions
def auto_contrast(img, cutperc=(0.05, 0.01), to_8bits=True):
    """
    Stretch the contrast of an image by clipping a small percentage of the
    darkest and brightest pixels and rescaling what is left to [0, 1]
    Parameters
    --------------
    img: numpy array holding the image
    cutperc: (low, high) percentage of pixels (0-100) to saturate at the dark
             and bright end of the 256-bin histogram
    to_8bits: when exactly True the result is rounded to uint8 in [0, 255],
              otherwise the float image in [0, 1] is returned
    Return
    --------------
    contrast-stretched copy of img (the input array is not modified)
    """
    hist, edges = np.histogram(img.flatten(), bins=256)
    total = np.sum(hist)
    n_bins = len(hist)

    # Walk up from the darkest bin until cutperc[0] percent of pixels is covered.
    # The bound keeps an oversized percentage from indexing past the histogram.
    min_side = 0
    min_accum = 0
    while min_accum < cutperc[0] and min_side < n_bins:
        min_accum += hist[min_side] / total * 100
        min_side += 1

    # Same walk, downwards from the brightest bin.
    max_side = n_bins - 1
    max_accum = 0
    while max_accum < cutperc[1] and max_side >= 0:
        max_accum += hist[max_side] / total * 100
        max_side -= 1
    max_side = max(max_side, 0)  # never rely on negative-index wrap-around

    low = edges[min_side]
    span = edges[max_side] - low
    if span <= 0:
        # The two cut-offs met or crossed (flat or nearly two-valued image):
        # fall back to the full data range instead of dividing by zero.
        low = edges[0]
        span = edges[-1] - edges[0]

    img = np.clip((img.astype('float32') - low) / span, 0, 1)
    if to_8bits is True:
        return np.round(img * 255).astype('uint8')
    else:
        return img
def detect_grayscale(atlas, device='0', imgsz=2048, thresh=0.2, iou=0.3, weights='runs/square_weights/best_square.pth'):
    """
    Detect squares on a atlas with the single-channel (grayscale) detector
    Parameters
    --------------
    atlas: the atlas.mrc file or a numpy array of the atlas (2-D)
    device: optional, '0' for gpu and 'cpu' for cpu. Only 'cpu' is acted on;
            any other value leaves the detectron2 default device in place
    imgsz: resized atlas size, 2048 is the default. The detector's own test
           size is fixed at 2048 in the config below
    thresh: confidence threshold for detection, 0.2 for default
    iou: IoU threshold of the extra NMS pass applied to the detections
    weights: weights for detector
    Return
    --------------
    all_coords: list with one box tensor per kept detection, scaled back to
                the (padded) atlas size
    all_labels: list with the class name of each kept detection
    img_mult: the resized 8-bit atlas stacked into 3 identical channels
    scale: factor between the (padded) atlas size and imgsz
    """
    label_dict = ['squares', 'contaminated', 'cracked', 'dry', 'fraction', 'small', 'square']

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 8  # faster, and good enough for this toy dataset (default: 512)
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 7
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = thresh
    cfg.MODEL.RPN.PRE_NMS_TOPK_TEST = 2000
    cfg.MODEL.WEIGHTS = weights
    cfg.INPUT.MIN_SIZE_TEST = 2048
    cfg.INPUT.MAX_SIZE_TEST = 2048
    cfg.TEST.DETECTIONS_PER_IMAGE = 400
    # Single-channel input: one mean / std entry instead of three.
    cfg.INPUT.FORMAT = "L"
    cfg.MODEL.PIXEL_MEAN = [26.0]
    cfg.MODEL.PIXEL_STD = [1.0]
    if device == 'cpu':
        cfg.MODEL.DEVICE = 'cpu'

    if isinstance(atlas, str):
        with mrcfile.open(atlas) as mrc:
            atlas = mrc.data

    r, c = atlas.shape
    atlas = cv2.normalize(atlas, None, 0, 255, cv2.NORM_MINMAX)
    atlas = auto_contrast(atlas)

    # Pad non-square atlases with zeros (bottom / right) so that the resize
    # below keeps the aspect ratio.
    max_shape = max(r, c)
    if r != c:
        img = np.zeros((max_shape, max_shape))
        img[:r, :c] = atlas
    else:
        img = atlas

    resized = cv2.resize(img, (imgsz, imgsz))
    resized = cv2.normalize(resized, None, 0, 255, cv2.NORM_MINMAX)
    resized = np.uint8(resized)
    scale = max_shape / imgsz
    img_mult = np.dstack((resized, resized, resized))

    with torch.no_grad():
        predictor_square = GrayscalePredictor(cfg)
        outputs = predictor_square(resized)

    out_inst = outputs["instances"].to("cpu")
    pred_box = out_inst.pred_boxes
    pred_classes = out_inst.pred_classes

    # Extra NMS pass across all detections with the caller's IoU threshold.
    keep = torchvision.ops.nms(pred_box.tensor, out_inst.scores, iou)
    pred_box_keep = pred_box[keep]
    pred_classes_keep = pred_classes[keep]
    # Back from the resized image to (padded) atlas coordinates.
    pred_box_keep.scale(scale, scale)

    # Every kept box is reported; the aspect-ratio filter (check_if_square)
    # was already bypassed, so it is no longer evaluated.
    all_coords = []
    all_labels = []
    for sq_coords, lb in zip(pred_box_keep, pred_classes_keep):
        all_coords.append(sq_coords)
        all_labels.append(label_dict[lb])

    torch.cuda.empty_cache()
    return all_coords, all_labels, img_mult, scale
def detect(atlas, device='0', imgsz=2048, thresh=0.2, iou=0.3, weights='runs/square_weights/best_square.pth'):
    """
    Detect squares on a atlas with the 3-channel detector
    Parameters
    --------------
    atlas: the atlas.mrc file or a numpy array of the atlas (2-D)
    device: optional, '0' for gpu and 'cpu' for cpu. Only 'cpu' is acted on;
            any other value leaves the detectron2 default device in place
    imgsz: resized atlas size, 2048 is the default. The detector's own test
           size is fixed at 2048 in the config below
    thresh: confidence threshold for detection, 0.2 for default
    iou: IoU threshold of the extra NMS pass applied to the detections
    weights: weights for detector
    Return
    --------------
    all_coords: list with one box tensor per kept detection, scaled back to
                the (padded) atlas size
    all_labels: list with the class name of each kept detection
    img_mult: the resized 8-bit atlas stacked into 3 identical channels
              (this is the image the detector was run on)
    scale: factor between the (padded) atlas size and imgsz
    """
    label_dict = ['squares', 'contaminated', 'cracked', 'dry', 'fraction', 'small', 'square']

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 8  # faster, and good enough for this toy dataset (default: 512)
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 7
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = thresh
    cfg.MODEL.RPN.PRE_NMS_TOPK_TEST = 2000
    cfg.MODEL.WEIGHTS = weights
    cfg.INPUT.MIN_SIZE_TEST = 2048
    cfg.INPUT.MAX_SIZE_TEST = 2048
    cfg.TEST.DETECTIONS_PER_IMAGE = 400
    if device == 'cpu':
        cfg.MODEL.DEVICE = 'cpu'

    if isinstance(atlas, str):
        with mrcfile.open(atlas) as mrc:
            atlas = mrc.data

    r, c = atlas.shape
    atlas = cv2.normalize(atlas, None, 0, 255, cv2.NORM_MINMAX)
    atlas = auto_contrast(atlas)

    # Pad non-square atlases with zeros (bottom / right) so that the resize
    # below keeps the aspect ratio.
    max_shape = max(r, c)
    if r != c:
        img = np.zeros((max_shape, max_shape))
        img[:r, :c] = atlas
    else:
        img = atlas

    resized = cv2.resize(img, (imgsz, imgsz))
    resized = cv2.normalize(resized, None, 0, 255, cv2.NORM_MINMAX)
    resized = np.uint8(resized)
    scale = max_shape / imgsz
    # The stock predictor expects 3 channels: replicate the gray image.
    img_mult = np.dstack((resized, resized, resized))

    with torch.no_grad():
        predictor_square = DefaultPredictor(cfg)
        outputs = predictor_square(img_mult)

    out_inst = outputs["instances"].to("cpu")
    pred_box = out_inst.pred_boxes
    pred_classes = out_inst.pred_classes

    # Extra NMS pass across all detections with the caller's IoU threshold.
    keep = torchvision.ops.nms(pred_box.tensor, out_inst.scores, iou)
    pred_box_keep = pred_box[keep]
    pred_classes_keep = pred_classes[keep]
    # Back from the resized image to (padded) atlas coordinates.
    pred_box_keep.scale(scale, scale)

    # Every kept box is reported; the aspect-ratio filter (check_if_square)
    # was already bypassed, so it is no longer evaluated.
    all_coords = []
    all_labels = []
    for sq_coords, lb in zip(pred_box_keep, pred_classes_keep):
        all_coords.append(sq_coords)
        all_labels.append(label_dict[lb])

    torch.cuda.empty_cache()
    return all_coords, all_labels, img_mult, scale
# detect()