HR-VITON代码笔记二
train_condition.py

```python
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from networks import make_grid as mkgrid
import argparse
import os
import time
from cp_dataset import CPDataset, CPDatasetTest, CPDataLoader
from networks import ConditionGenerator, VGGLoss, GANLoss, load_checkpoint, save_checkpoint, define_D
from tqdm import tqdm
from tensorboardX import SummaryWriter
from utils import *
from torch.utils.data import Subset
```

这里引入了很多库。

```python
def iou_metric(y_pred_batch, y_true_batch):
    B = y_pred_batch.shape[0]
    iou = 0
    for i in range(B):
        y_pred = y_pred_batch[i]
        y_true = y_true_batch[i]
        # y_pred is not one-hot, so need to threshold it
        y_pred = y_pred > 0.5
        y_pred = y_pred.flatten()
        y_true = y_true.flatten()
        intersection = torch.sum(y_pred[y_true == 1])
        union = torch.sum(y_pred) + torch.sum(y_true)
        iou += (intersection + 1e-7) / (union - intersection + 1e-7) / B
    return iou
```

这个函数用来计算 IoU，即 Intersection over Union（交并比），计算公式为

$$IoU = \frac{\text{Area of Intersection}}{\text{Area of Union}}$$

代码中分母用的是 `union - intersection`，原因是代码里的 "union" 并不是真正的并集，而是把两个区域的面积直接相加，交集部分被多算了一次，所以要减掉一次交集，即

$$|A \cup B| = |A| + |B| - |A \cap B|$$

分子和分母都加上 `1e-7`，是为了避免两个区域都为空时出现 0/0。每个样本的 IoU 在累加前都除以 `B`，因此返回值是整个 batch 的平均 IoU。

```python
def remove_overlap(seg_out, warped_cm):
    assert len(warped_cm.shape) == 4
    warped_cm = warped_cm - (torch.cat([seg_out[:, 1:3, :, :], seg_out[:, 5:, :, :]], dim=1)).sum(dim=1, keepdim=True) * warped_cm
    return warped_cm
```

其中 `torch.cat([seg_out[:, 1:3, :, :], seg_out[:, 5:, :, :]], dim=1)` 表示 segmentation 中可能与 clothes mask 重叠的人体部位。把这些通道在 `dim=1` 上求和后乘以 `warped_cm` 再减掉，相当于 `warped_cm * (1 - 重叠部位之和)`，也就是从衣服 mask 中去掉与这些人体部位重叠的区域。

...