Virtual Adversarial Training文章解读+算法流程+核心代码详解
Virtual Adversarial Training
本博客仅做算法流程疏导,具体细节请参见原文
原文
Github代码
解读
对比Adversarial Training和VAT
VAT(Virtual Adversarial Training)和adversarial training类似。对原始训练样本添加一个比较小的扰动,会大概率使分类器分类出现错误,而我们一般希望分类器将原始样本和添加一个较小扰动的样本(加噪版本)分为同一类别,所以将扰动版本的数据也作为训练样本添加进训练,这样就增加了分类器的泛化能力。
传统的adversarial training 的扰动方向一般通过损失函数确定,即取损失函数上升的方向添加一个扰动。无标记样本没有标签,就无法算损失函数,故传统方法不适用,所以一般的adversarial training仅在监督学习中使用较多,而virtual adversarial training的创新在于能在无标记样本上实现扰动的计算,因为没用使用标签进行运算,而是用模型预测的结果替代标签,类似于persudo label,这就是virtual的含义
Adversarial Training
adversarial training的数学表达如下,其中样本及标记$(x,y)$,当前epoch模型的参数$\theta$:
损失函数:$J(\theta)=\frac{1}{N}\sum^{N}{i=1}L(x,\theta)$
其中,单项损失计算表达式为:$L(x,\theta)=D(y,p(y|x+r,\theta))$
扰动方向:$r=argmax{|r|<\xi}D(y,p(y|x+r,\theta))$
简单叙述为:找到一个扰动$r$,且$r$的大小受限,即$|r|<\xi$,使其损失函数$L(x,\theta)=D(y,p(y|x+r,\theta))$取最大值,即在此$r$下上升最多。
VAT
同样形式的,virtual adversarial training 的数学表达式如下,其中其中样本及标记$(x,y)$,当前epoch模型的参数$\theta$,前一个epoch的模型参数为$\hat{\theta}$:
损失函数同上形式:$J(\theta)=\frac{1}{N}\sum^N{i=1}L(x,\theta)$
单项损失表达式==不同==(LDS称为局部平滑度):$L(x,\theta)=D(p(y|x,\hat\theta),p(y|x+r,\theta))=LDS(x,\theta)$
扰动方向:$r=argmax{|r|<\xi}D(p(y|x,\theta),p(y|x+r,\theta))$
简单叙述为:找到一个扰动$r$,且$r$的大小受限,即$|r|<\xi$,使其损失函数$LDS(x,\theta)$取的最大值,即在此$r$下上升最多。
代码详解
代码核心就一个VAT_Loss的计算。整个框架的Loss=Classfier_Loss + VAT_Loss。其中Classfier_Loss损失函数为一般的监督网络的损失函数。VAT_Loss计算如下:
def vat_loss(model, ul_x, ul_y, xi=1e-6, eps=2.5, num_iters=1):
# find r_adv
d = torch.Tensor(ul_x.size()).normal_()
for i in range(num_iters):
d = xi *_l2_normalize(d)
d = Variable(d.cuda(), requires_grad=True)
y_hat = model(ul_x + d)
delta_kl = kl_div_with_logit(ul_y.detach(), y_hat)
delta_kl.backward()
d = d.clone().cpu()
model.zero_grad()
d = _l2_normalize(d)
d = Variable(d.cuda())
r_adv = eps * d
# compute lds
y_hat = model(ul_x + r_adv.detach())
delta_kl = kl_div_with_logit(ul_y.detach(), y_hat)
return delta_kl
其中对r_adv的计算采用的是一种快速计算方法。具体理论请查阅原文
v_loss = vat_loss(model, inputs_All, logits_All, eps=args.epsilon)
loss = v_loss+ce_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
完整损失函数Loss=Classfier_Loss + VAT_Loss反向梯度传播更新网络即可。
本文作者: Joffrey-Luo Cheng
本文链接: http://lcjoffrey.top/2021/12/04/VAT/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!