一、在BasicIRSTD中的计算
- 这个类(`PD_FA`)不参与训练,只在测试(评估)阶段使用
- 各累计量只在实例初始化时清零一次,之后每次调用 `update` 都会持续累加,不会自动重置
class PD_FA():
    """Running detection-rate / false-alarm-rate accumulator (BasicIRSTD variant).

    Each ``update`` call labels the connected components of one binary
    prediction and one binary ground-truth mask, pairs components whose
    centroids are closer than 3 pixels, and adds to the running totals:

    * ``target``          -- number of ground-truth components seen so far
    * ``PD``              -- number of ground-truth components that were matched
    * ``dismatch_pixel``  -- predicted pixels outside every matched component
    * ``all_pixel``       -- total number of pixels seen (``size[0] * size[1]``)

    Nothing is cleared between ``update`` calls; call ``reset`` to start over.
    """

    def __init__(self,):
        super(PD_FA, self).__init__()
        self.reset()

    def update(self, preds, labels, size):
        """Accumulate the statistics of one prediction / ground-truth pair.

        Args:
            preds: binary 2-D prediction exposing ``.cpu()`` (presumably a
                torch tensor -- confirm against the caller).
            labels: binary 2-D ground-truth mask of the same kind.
            size: indexable whose first two entries are multiplied together
                and added to the running pixel count.
        """
        predits = np.array((preds).cpu()).astype('int64')
        labelss = np.array((labels).cpu()).astype('int64')

        # 8-connected components (connectivity=2) of prediction and label.
        image = measure.label(predits, connectivity=2)
        coord_image = measure.regionprops(image)
        label = measure.label(labelss, connectivity=2)
        coord_label = measure.regionprops(label)

        self.target += len(coord_label)
        self.image_area_total = [np.array(region.area) for region in coord_image]
        self.distance_match = []
        self.dismatch = []

        # Mask of predicted pixels that belong to a matched component.
        true_img = np.zeros(predits.shape)
        for target_region in coord_label:
            centroid_label = np.array(list(target_region.centroid))
            for m, region in enumerate(coord_image):
                centroid_image = np.array(list(region.centroid))
                distance = np.linalg.norm(centroid_image - centroid_label)
                if distance < 3:
                    self.distance_match.append(distance)
                    true_img[region.coords[:, 0], region.coords[:, 1]] = 1
                    # A predicted component may satisfy only one label.
                    # Deleting during enumeration is safe because we leave
                    # the loop immediately afterwards.
                    del coord_image[m]
                    break

        # Whatever was predicted but not matched counts as false alarm.
        self.dismatch_pixel += (predits - true_img).sum()
        self.all_pixel += size[0] * size[1]
        self.PD += len(self.distance_match)

    def get(self):
        """Return ``(Final_PD, Final_FA)`` for everything accumulated so far.

        ``Final_PD`` is matched targets / total targets and ``Final_FA`` is
        unmatched predicted pixels / total pixels, returned as a Python float.

        Raises:
            ZeroDivisionError: if nothing has been accumulated yet and the
                counters are still plain Python zeros.
        """
        Final_FA = self.dismatch_pixel / self.all_pixel
        Final_PD = self.PD / self.target
        # float() accepts Python numbers, NumPy scalars and one-element torch
        # tensors alike, so the result no longer depends on `size` having
        # been passed as tensors.
        return Final_PD, float(Final_FA)

    def reset(self):
        """Clear every accumulator back to its initial (empty / zero) state."""
        self.image_area_total = []
        self.image_area_match = []
        self.distance_match = []
        self.dismatch = []
        self.dismatch_pixel = 0
        self.all_pixel = 0
        self.PD = 0
        self.target = 0
二、在SCTransNet中的计算
- 这个类(`PD_FA`)不参与训练,只在测试(评估)阶段使用
- 各累计量只在实例初始化时清零一次,之后每次调用 `update` 都会持续累加,不会自动重置
class PD_FA():
    """Running detection-rate / false-alarm-rate accumulator (SCTransNet variant).

    Each ``update`` call labels the connected components of one binary
    prediction and one binary ground-truth mask, pairs components whose
    centroids are closer than 3 pixels, and adds to the running totals:

    * ``target``          -- number of ground-truth components seen so far
    * ``PD``              -- number of ground-truth components that were matched
    * ``dismatch_pixel``  -- summed area of predicted components left unmatched
    * ``all_pixel``       -- total number of pixels seen (``size[0] * size[1]``)

    Nothing is cleared between ``update`` calls; call ``reset`` to start over.
    """

    def __init__(self, ):
        super(PD_FA, self).__init__()
        self.reset()

    @staticmethod
    def _match_regions(pred_regions, label_regions, max_distance=3):
        """Pair each label region with the first close-enough predicted region.

        Regions only need ``.centroid`` and ``.area``.  A predicted region is
        consumed by the first label it matches.  ``pred_regions`` itself is
        left untouched.

        Returns:
            ``(distances, matched_areas, unmatched_areas)`` -- the centroid
            distance of every match, the areas of the matched predicted
            regions, and the areas of the predicted regions nobody claimed.
        """
        remaining = list(pred_regions)
        distances = []
        matched_areas = []
        for target_region in label_regions:
            centroid_label = np.array(list(target_region.centroid))
            for m, region in enumerate(remaining):
                centroid_image = np.array(list(region.centroid))
                distance = np.linalg.norm(centroid_image - centroid_label)
                if distance < max_distance:
                    distances.append(distance)
                    matched_areas.append(np.array(region.area))
                    # Safe to delete while enumerating: we break right away.
                    del remaining[m]
                    break
        unmatched_areas = [np.array(region.area) for region in remaining]
        return distances, matched_areas, unmatched_areas

    def update(self, preds, labels, size):
        """Accumulate the statistics of one prediction / ground-truth pair.

        Args:
            preds: binary prediction exposing ``.cpu()`` (presumably a 2-D
                torch tensor -- confirm against the caller).
            labels: binary ground-truth mask of the same kind.
            size: indexable whose first two entries are multiplied together
                and added to the running pixel count.
        """
        predits = np.array((preds).cpu()).astype('int64')
        labelss = np.array((labels).cpu()).astype('int64')

        # 8-connected components (connectivity=2) of prediction and label.
        image = measure.label(predits, connectivity=2)
        coord_image = measure.regionprops(image)
        label = measure.label(labelss, connectivity=2)
        coord_label = measure.regionprops(label)

        self.target += len(coord_label)
        self.image_area_total = [np.array(region.area) for region in coord_image]

        # False alarms are the regions that really stayed unmatched.  The
        # previous area-value filter (`x not in image_area_match`) dropped any
        # unmatched region that happened to share its area with a matched
        # one, which under-counted false-alarm pixels.
        (self.distance_match,
         self.image_area_match,
         self.dismatch) = self._match_regions(coord_image, coord_label)

        self.dismatch_pixel += np.sum(self.dismatch)
        self.all_pixel += size[0] * size[1]
        self.PD += len(self.distance_match)

    def get(self):
        """Return ``(Final_PD, Final_FA)`` for everything accumulated so far.

        ``Final_PD`` is matched targets / total targets and ``Final_FA`` is
        unmatched predicted pixels / total pixels, returned as a Python float.

        Raises:
            ZeroDivisionError: if nothing has been accumulated yet and the
                counters are still plain Python zeros.
        """
        Final_FA = self.dismatch_pixel / self.all_pixel
        Final_PD = self.PD / self.target
        # float() accepts Python numbers, NumPy scalars and one-element torch
        # tensors alike, so the result no longer depends on `size` having
        # been passed as tensors.
        return Final_PD, float(Final_FA)

    def reset(self):
        """Clear every accumulator back to its initial (empty / zero) state."""
        self.image_area_total = []
        self.image_area_match = []
        self.distance_match = []
        self.dismatch = []
        self.dismatch_pixel = 0
        self.all_pixel = 0
        self.PD = 0
        self.target = 0
三、在SeRankDet中的计算
- 这个类(`PD_FA`)不参与训练,只在测试(评估)阶段使用
- 各累计量(按阈值分档 `bins` 分别统计)只在实例初始化时清零一次,之后每次调用 `update` 都会持续累加,不会自动重置
class PD_FA():
    """Per-threshold detection-rate / false-alarm accumulator (SeRankDet variant).

    Predictions are binarised at ``bins + 1`` evenly spaced thresholds
    (``iBin * 255 / bins``).  For every threshold the class keeps, as arrays
    of length ``bins + 1``:

    * ``target`` -- number of ground-truth components seen
    * ``PD``     -- number of ground-truth components that were matched
    * ``FA``     -- summed area of predicted components left unmatched

    Nothing is cleared between ``update`` calls; call ``reset`` to start over.
    """

    def __init__(self, nclass, bins, cfg):
        super(PD_FA, self).__init__()
        self.nclass = nclass
        self.bins = bins
        self.image_area_total = []
        self.image_area_match = []
        self.FA = np.zeros(self.bins + 1)
        self.PD = np.zeros(self.bins + 1)
        self.target = np.zeros(self.bins + 1)
        # cfg.data['crop_size'] gives the (square) side length every sample
        # is reshaped to.
        self.cfg = cfg

    @staticmethod
    def _match_regions(pred_regions, label_regions, max_distance=3):
        """Pair each label region with the first close-enough predicted region.

        Regions only need ``.centroid`` and ``.area``.  A predicted region is
        consumed by the first label it matches.  ``pred_regions`` itself is
        left untouched.

        Returns:
            ``(distances, matched_areas, unmatched_areas)`` -- the centroid
            distance of every match, the areas of the matched predicted
            regions, and the areas of the predicted regions nobody claimed.
        """
        remaining = list(pred_regions)
        distances = []
        matched_areas = []
        for target_region in label_regions:
            centroid_label = np.array(list(target_region.centroid))
            for m, region in enumerate(remaining):
                centroid_image = np.array(list(region.centroid))
                distance = np.linalg.norm(centroid_image - centroid_label)
                if distance < max_distance:
                    distances.append(distance)
                    matched_areas.append(np.array(region.area))
                    # Safe to delete while enumerating: we break right away.
                    del remaining[m]
                    break
        unmatched_areas = [np.array(region.area) for region in remaining]
        return distances, matched_areas, unmatched_areas

    def update(self, preds, labels):
        """Accumulate one batch at every threshold bin.

        Args:
            preds: batch of score maps indexed as ``preds[b, :, :, :]`` and
                exposing ``.size()`` / ``.cpu()`` (presumably a torch tensor
                on a 0-255 scale -- confirm against the caller).  Each sample
                must hold ``crop_size * crop_size`` values.
            labels: batch of ground-truth masks with the same layout.

        Raises:
            ZeroDivisionError: if ``bins`` is 0.
        """
        crop_size = self.cfg.data['crop_size']
        thresholds = [iBin * (255 / self.bins) for iBin in range(self.bins + 1)]
        batch = preds.size()[0]

        for b in range(batch):
            # The ground truth does not depend on the threshold, so label it
            # once per sample instead of once per bin.
            labelss = np.array((labels[b, :, :, :]).cpu()).astype('int64')
            labelss = np.reshape(labelss, (crop_size, crop_size))
            label = measure.label(labelss, connectivity=2)
            coord_label = measure.regionprops(label)

            for iBin, score_thresh in enumerate(thresholds):
                predits = np.array((preds[b, :, :, :] > score_thresh).cpu()).astype('int64')
                predits = np.reshape(predits, (crop_size, crop_size))
                image = measure.label(predits, connectivity=2)
                coord_image = measure.regionprops(image)

                self.target[iBin] += len(coord_label)
                self.image_area_total = [np.array(region.area) for region in coord_image]

                # False alarms are the regions that really stayed unmatched.
                # The previous area-value filter
                # (`x not in image_area_match`) dropped any unmatched region
                # sharing its area with a matched one, under-counting FA.
                (self.distance_match,
                 self.image_area_match,
                 self.dismatch) = self._match_regions(coord_image, coord_label)

                self.FA[iBin] += np.sum(self.dismatch)
                self.PD[iBin] += len(self.distance_match)

    def get(self, img_num):
        """Return ``(Final_FA, Final_PD)`` arrays, one entry per threshold bin.

        Args:
            img_num: number of images accumulated; together with
                ``crop_size ** 2`` it forms the pixel count FA is divided by.

        Note the return order is FA first, then PD.  The division is
        element-wise on NumPy arrays, so a bin with zero targets gives a
        non-finite PD entry rather than raising.
        """
        Final_FA = self.FA / ((self.cfg.data['crop_size'] * self.cfg.data['crop_size']) * img_num)
        Final_PD = self.PD / self.target
        return Final_FA, Final_PD

    def reset(self):
        """Zero the per-bin ``FA``, ``PD`` and ``target`` accumulators."""
        self.FA = np.zeros([self.bins + 1])
        self.PD = np.zeros([self.bins + 1])
        self.target = np.zeros(self.bins + 1)