牛客暑期多校第九场补题
Note 这是一条需要注意的普通信息。 I题 题解 把 $L$ 和 $R$ 的高 $\dfrac{n}{2}$ 位计作 $L_{1}$ 和 $R_{1}$,把低 $\dfrac{n}{2}$ 位计作 $L_{2}$ 和 $R_{2}$。大概的思路即为统计 $[L_{1}, R_{1}]$ 和 $[L_{2}, R_{2}]$ 分别有多少个平方数,处理一些细节并组合起来即可。这道题数字还蛮大,用python来实现更简单一点。 代码 from math import sqrt, isqrt def count_square(l, r): if (l > r): return 0 l = int(l) r = int(r) ans = isqrt(r) - isqrt(l) if (isqrt(l) * isqrt(l) == l): ans = ans + 1 return ans # 读取 n n = int(input("")) # 读取两个字符串 str1, str2 = input().split() mid_index = n // 2 # 取出字符串的高 n/2 位和第 n/2 位 high_str1 = str1[:mid_index] low_str1 = str1[mid_index:] high_str2 = str2[:mid_index] low_str2 = str2[mid_index:] # 转化为整数 l1 = int(high_str1) r1 = int(low_str1) l2 = int(high_str2) r2 = int(low_str2) cnt1 = count_square(l1 + 1, l2 - 1) cnt2 = count_square(0, 10 ** (n//2) - 1) ans = cnt1 * cnt2 if (isqrt(l1) * isqrt(l1) == l1): ans += count_square(r1, 10 ** (n//2) - 1) if (l1 != l2): if (isqrt(l2) * isqrt(l2) == l2): ans += count_square(0, r2) print("%d" % ans) K题 题目描述 ...
HR-VITON代码笔记二
train_condition.py import torch import torch.nn as nn from torchvision.utils import make_grid from networks import make_grid as mkgrid import argparse import os import time from cp_dataset import CPDataset, CPDatasetTest, CPDataLoader from networks import ConditionGenerator, VGGLoss, GANLoss, load_checkpoint, save_checkpoint, define_D from tqdm import tqdm from tensorboardX import SummaryWriter from utils import * from torch.utils.data import Subset 引入了很多库。 def iou_metric(y_pred_batch, y_true_batch): B = y_pred_batch.shape[0] iou = 0 for i in range(B): y_pred = y_pred_batch[i] y_true = y_true_batch[i] # y_pred is not one-hot, so need to threshold it y_pred = y_pred > 0.5 y_pred = y_pred.flatten() y_true = y_true.flatten() intersection = torch.sum(y_pred[y_true == 1]) union = torch.sum(y_pred) + torch.sum(y_true) iou += (intersection + 1e-7) / (union - intersection + 1e-7) / B return iou 这个函数是用来计算IoU的,即Intersection Over Union,计算公式为 $$IoU = \frac{\text{Area of Intersection}}{\text{Area of Union}}$$ 这里代码中分母用的是union - intersection的原因是,代码中的"union"不是真正的并集,而是把他们两个的区域面积直接加起来,多算了一次交集的面积,所以要减掉。 def remove_overlap(seg_out, warped_cm): assert len(warped_cm.shape) == 4 warped_cm = warped_cm - (torch.cat([seg_out[:, 1:3, :, :], seg_out[:, 5:, :, :]], dim=1)).sum(dim=1, keepdim=True) * warped_cm return warped_cm 其中torch.cat([seg_out[:, 1:3, :, :], seg_out[:, 5:, :, :]], dim=1)表示segmentation中可能和clothes mask重叠的人体部位。 ...