COSFORMER: RETHINKING SOFTMAX IN ATTENTION

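$$
\mathcal{O}_{i}=\sum_{j=1}^{N} \frac{\mathcal{S}\left(Q_{i}, K_{j}\right)}{\sum_{j'=1}^{N} \mathcal{S}\left(Q_{i}, K_{j'}\right)} V_{j}
$$

Here $\mathcal{S}$ is a similarity function. Standard Softmax attention uses $\mathcal{S}(Q_i, K_j)=\exp\left(Q_i K_j^{T}/\sqrt{d}\right)$.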
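A minimal NumPy sketch of the general formula, to make the quadratic cost concrete. It assumes single-head, unbatched row-vector inputs of shape (N, d); `generic_attention` and `softmax_attention` are hypothetical helper names, and the $1/\sqrt{d}$ scaling is the usual Transformer convention. Plugging $\mathcal{S}(q,k)=\exp(qk^T/\sqrt{d})$ into the general form reproduces Softmax attention:

```python
import numpy as np

def generic_attention(Q, K, V, sim):
    """O_i = sum_j S(Q_i, K_j) V_j / sum_j' S(Q_i, K_j'), evaluated pair by pair (O(N^2))."""
    # S[i, j] = sim(Q_i, K_j): every query is compared with every key
    S = np.array([[sim(q, k) for k in K] for q in Q], dtype=float)
    # normalize each row so the weights for query i sum to 1
    return (S / S.sum(axis=1, keepdims=True)) @ V

def softmax_attention(Q, K, V):
    """Standard scaled dot-product Softmax attention, for comparison."""
    logits = Q @ K.T / np.sqrt(Q.shape[1])
    logits -= logits.max(axis=1, keepdims=True)   # shift for numerical stability
    A = np.exp(logits)
    return (A / A.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
N, d = 6, 4
Q, K, V = rng.normal(size=(3, N, d))              # three (N, d) arrays
O_generic = generic_attention(Q, K, V, lambda q, k: np.exp(q @ k / np.sqrt(d)))
O_softmax = softmax_attention(Q, K, V)
print(np.allclose(O_generic, O_softmax))  # True
```

The `S` matrix has N x N entries, which is exactly the quadratic time and memory cost the rest of this note tries to remove.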

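The first goal is to remove the quadratic time and space complexity: $\mathcal{S}$ has to be evaluated for all $N^2$ query-key pairs, where $N$ is the sequence length.
Prior work suggests the following approach.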

$$
\mathcal{S}\left(Q_{i}, K_{j}\right)=\phi\left(Q_{i}\right) \phi\left(K_{j}\right)^{T}
$$

$$
\mathcal{O}_{i}=\frac{\sum_{j=1}^{N}\left(\phi\left(Q_{i}\right) \phi\left(K_{j}\right)^{T}\right) V_{j}}{\sum_{j=1}^{N}\left(\phi\left(Q_{i}\right) \phi\left(K_{j}\right)^{T}\right)}
$$

$$
\left(\phi(Q) \phi(K)^{T}\right) V=\phi(Q)\left(\phi(K)^{T} V\right)
$$

In other words, the task is to find a suitable kernel function $\phi$ that approximates Softmax.

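Because matrix multiplication is associative, $\phi(K)^{T} V$ can be computed first. It is only $d \times d$ ($d$ being the feature dimension), so the cost drops from $O(N^2 d)$ to $O(N d^2)$, i.e. linear in the sequence length $N$.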
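A quick numerical check of the associativity argument, assuming NumPy and using `np.exp` only as an arbitrary positive placeholder for $\phi$ (not the paper's choice). Both evaluation orders give the same result, but only the left-to-right order materializes an $N \times N$ matrix:

```python
import numpy as np

rng = np.random.default_rng(1)
N, d = 1000, 16                      # long sequence, small feature dimension
phi = np.exp                         # placeholder positive feature map (not the paper's choice)
Qf = phi(rng.normal(size=(N, d)))    # phi(Q)
Kf = phi(rng.normal(size=(N, d)))    # phi(K)
V = rng.normal(size=(N, d))

# Left to right: (phi(Q) phi(K)^T) V materializes an N x N matrix
A = Qf @ Kf.T                        # shape (N, N)
out_quadratic = A @ V

# Right to left: phi(Q) (phi(K)^T V) only needs a d x d intermediate
KV = Kf.T @ V                        # shape (d, d)
out_linear = Qf @ KV

print(A.shape, KV.shape)                       # (1000, 1000) (16, 16)
print(np.allclose(out_quadratic, out_linear))  # True
```

For $N=1000$ and $d=16$, each product in the left-to-right order costs about $N^2 d = 1.6\times10^{7}$ multiply-adds, versus $N d^2 = 2.56\times10^{5}$ in the right-to-left order.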


Several choices of $\phi$ were compared.

Based on this comparison, the paper ultimately chooses ReLU.

$$
\mathcal{S}(Q, K)=s\left(\phi_{\text{linear}}(Q), \phi_{\text{linear}}(K)\right)=s\left(Q^{\prime}, K^{\prime}\right)
$$

$$
\phi_{\text{linear}}(x)=\operatorname{ReLU}(x), \qquad Q^{\prime}=\operatorname{ReLU}(Q), \quad K^{\prime}=\operatorname{ReLU}(K)
$$

$$
\mathcal{O}_{i}=\frac{\sum_{j=1}^{N} f\left(\phi_{\text{linear}}\left(Q_{i}\right), \phi_{\text{linear}}\left(K_{j}\right)\right) V_{j}}{\sum_{j=1}^{N} f\left(\phi_{\text{linear}}\left(Q_{i}\right), \phi_{\text{linear}}\left(K_{j}\right)\right)}=\frac{\sum_{j=1}^{N}\left(\operatorname{ReLU}\left(Q_{i}\right) \operatorname{ReLU}\left(K_{j}\right)^{T}\right) V_{j}}{\sum_{j=1}^{N}\left(\operatorname{ReLU}\left(Q_{i}\right) \operatorname{ReLU}\left(K_{j}\right)^{T}\right)}
$$

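Since $\operatorname{ReLU}(Q_i)$ does not depend on $j$, it factors out of both sums, and $\sum_{j}\operatorname{ReLU}(K_j)^{T} V_j$ is computed once and shared by all queries:

$$
\mathcal{O}_{i}=\frac{\operatorname{ReLU}\left(Q_{i}\right) \sum_{j=1}^{N} \operatorname{ReLU}\left(K_{j}\right)^{T} V_{j}}{\operatorname{ReLU}\left(Q_{i}\right) \sum_{j=1}^{N} \operatorname{ReLU}\left(K_{j}\right)^{T}}
$$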
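A minimal sketch of ReLU linear attention in both forms, assuming NumPy and unbatched (N, d) inputs; the function names are hypothetical. The small `eps` in the denominator is our assumption, not part of the formula: ReLU can zero out an entire query row, which would otherwise make the denominator 0.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def relu_attention_quadratic(Q, K, V, eps=1e-6):
    """Reference O(N^2) form: materializes ReLU(Q) ReLU(K)^T."""
    Qp, Kp = relu(Q), relu(K)
    S = Qp @ Kp.T                                      # (N, N), non-negative
    return (S @ V) / (S.sum(axis=1, keepdims=True) + eps)

def relu_attention_linear(Q, K, V, eps=1e-6):
    """Factorized O(N d^2) form: ReLU(Q_i) sum_j ReLU(K_j)^T V_j / ReLU(Q_i) sum_j ReLU(K_j)^T."""
    Qp, Kp = relu(Q), relu(K)
    KV = Kp.T @ V            # sum_j ReLU(K_j)^T V_j, shape (d, d_v), shared by every query
    k_sum = Kp.sum(axis=0)   # sum_j ReLU(K_j), shape (d,)
    numer = Qp @ KV          # (N, d_v)
    denom = Qp @ k_sum       # (N,)
    return numer / (denom[:, None] + eps)

rng = np.random.default_rng(2)
Q, K, V = rng.normal(size=(3, 8, 4))
print(np.allclose(relu_attention_quadratic(Q, K, V), relu_attention_linear(Q, K, V)))  # True
```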

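The non-linear re-weighting in Softmax attention concentrates the distribution of attention weights and thus stabilizes training. The authors also found empirically that it penalizes long-range connections and, in some cases, enforces locality.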

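Plain ReLU linear attention has no such re-weighting, so locality has to be reinforced explicitly. cosFormer proposes a cosine-based re-weighting for this purpose: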

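$$
s\left(Q_{i}^{\prime}, K_{j}^{\prime}\right)=Q_{i}^{\prime} K_{j}^{\prime T} \cos \left(\frac{\pi}{2} \times \frac{i-j}{M}\right)
$$

Here $i$ and $j$ are token positions and $M$ is a length constant. Assuming $M \ge N$, we have $|i-j| < M$, so the cosine argument lies in $(-\pi/2, \pi/2)$: the weight is positive, equals 1 when $i=j$, and shrinks as $|i-j|$ grows, which biases attention toward nearby tokens.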
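To see the locality bias, the sketch below builds the full weight matrix $\cos(\frac{\pi}{2}\cdot\frac{i-j}{M})$. It assumes NumPy and defaults `M` to the sequence length N (an assumption for illustration); `cos_reweight_matrix` is a hypothetical helper name.

```python
import numpy as np

def cos_reweight_matrix(N, M=None):
    """Return C with C[i, j] = cos(pi/2 * (i - j) / M); M defaults to N (assumption)."""
    M = N if M is None else M
    idx = np.arange(N)
    diff = idx[:, None] - idx[None, :]          # matrix of position offsets i - j
    return np.cos(np.pi / 2 * diff / M)

C = cos_reweight_matrix(5)
print(np.round(C[0], 3))  # first row ≈ 1.0, 0.951, 0.809, 0.588, 0.309: weights decay with distance
```

Multiplying this matrix elementwise with $Q'K'^{T}$ gives the scores $s(Q_i', K_j')$. Doing that explicitly is still $O(N^2)$, which is what the cos/sin decomposition avoids.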

Expanding the cosine with $\cos(a-b)=\cos a \cos b+\sin a \sin b$ separates the dependence on $i$ and $j$:

$$
\begin{aligned}
Q_{i}^{\prime} K_{j}^{\prime T} \cos \left(\frac{\pi}{2} \times \frac{i-j}{M}\right) &=Q_{i}^{\prime} K_{j}^{\prime T}\left(\cos \left(\frac{\pi i}{2 M}\right) \cos \left(\frac{\pi j}{2 M}\right)+\sin \left(\frac{\pi i}{2 M}\right) \sin \left(\frac{\pi j}{2 M}\right)\right) \\
&=\left(Q_{i}^{\prime} \cos \left(\frac{\pi i}{2 M}\right)\right)\left(K_{j}^{\prime} \cos \left(\frac{\pi j}{2 M}\right)\right)^{T}+\left(Q_{i}^{\prime} \sin \left(\frac{\pi i}{2 M}\right)\right)\left(K_{j}^{\prime} \sin \left(\frac{\pi j}{2 M}\right)\right)^{T}
\end{aligned}
$$
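A numerical check of this identity in matrix form, assuming NumPy, random non-negative stand-ins for $Q'$ and $K'$, and $M=N$. Scaling row $i$ of $Q'$ and $K'$ by $\cos(\frac{\pi i}{2M})$ and $\sin(\frac{\pi i}{2M})$ turns the re-weighted score matrix into a sum of two plain matrix products:

```python
import numpy as np

rng = np.random.default_rng(3)
N, d, M = 7, 4, 7
Qp = np.maximum(rng.normal(size=(N, d)), 0)   # stand-in for Q' = ReLU(Q)
Kp = np.maximum(rng.normal(size=(N, d)), 0)   # stand-in for K' = ReLU(K)

theta = np.pi * np.arange(N) / (2 * M)        # pi * i / (2M) for each position i

# Left side: Q'_i K'_j^T cos(pi/2 * (i - j) / M), built as an N x N matrix
lhs = (Qp @ Kp.T) * np.cos(theta[:, None] - theta[None, :])

# Right side: scale each row by cos/sin of its own position, then two ordinary products
Q_cos, Q_sin = Qp * np.cos(theta)[:, None], Qp * np.sin(theta)[:, None]
K_cos, K_sin = Kp * np.cos(theta)[:, None], Kp * np.sin(theta)[:, None]
rhs = Q_cos @ K_cos.T + Q_sin @ K_sin.T

print(np.allclose(lhs, rhs))  # True
```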

Substituting into the normalized output, with $Q_{i}^{\cos}=Q_{i}^{\prime}\cos\left(\frac{\pi i}{2M}\right)$, $Q_{i}^{\sin}=Q_{i}^{\prime}\sin\left(\frac{\pi i}{2M}\right)$, $K_{j}^{\cos}=K_{j}^{\prime}\cos\left(\frac{\pi j}{2M}\right)$ and $K_{j}^{\sin}=K_{j}^{\prime}\sin\left(\frac{\pi j}{2M}\right)$:

$$
\mathcal{O}_{i}=\frac{\sum_{j=1}^{N} f\left(Q_{i}^{\prime}, K_{j}^{\prime}\right) V_{j}}{\sum_{j=1}^{N} f\left(Q_{i}^{\prime}, K_{j}^{\prime}\right)}=\frac{\sum_{j=1}^{N} Q_{i}^{\cos}\left(\left(K_{j}^{\cos}\right)^{T} V_{j}\right)+\sum_{j=1}^{N} Q_{i}^{\sin}\left(\left(K_{j}^{\sin}\right)^{T} V_{j}\right)}{\sum_{j=1}^{N} Q_{i}^{\cos}\left(K_{j}^{\cos}\right)^{T}+\sum_{j=1}^{N} Q_{i}^{\sin}\left(K_{j}^{\sin}\right)^{T}}
$$

Ignoring the row-wise normalization, the matrix form is:

$$
\mathcal{O}=\mathcal{S}(Q, K) V=\left(Q^{\cos}\left(K^{\cos}\right)^{T}+Q^{\sin}\left(K^{\sin}\right)^{T}\right) V=Q^{\cos}\left(\left(K^{\cos}\right)^{T} V\right)+Q^{\sin}\left(\left(K^{\sin}\right)^{T} V\right)
$$
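Putting the pieces together, a minimal sketch of non-causal cosFormer attention for a single head, assuming NumPy, unbatched (N, d) inputs, `M` defaulting to N, and an `eps` guard in the denominator (both assumptions for illustration; the function names are hypothetical). The linear version never builds an N x N matrix, yet matches the quadratic reference:

```python
import numpy as np

def cosformer_quadratic(Q, K, V, M=None, eps=1e-6):
    """Reference O(N^2) version: builds the re-weighted N x N score matrix explicitly."""
    N = Q.shape[0]
    M = N if M is None else M
    Qp, Kp = np.maximum(Q, 0), np.maximum(K, 0)          # Q' = ReLU(Q), K' = ReLU(K)
    idx = np.arange(N)
    C = np.cos(np.pi / 2 * (idx[:, None] - idx[None, :]) / M)
    S = (Qp @ Kp.T) * C                                  # s(Q'_i, K'_j)
    return (S @ V) / (S.sum(axis=1, keepdims=True) + eps)

def cosformer_linear(Q, K, V, M=None, eps=1e-6):
    """O(N d^2) version: Q^cos ((K^cos)^T V) + Q^sin ((K^sin)^T V), normalized row-wise."""
    N = Q.shape[0]
    M = N if M is None else M
    Qp, Kp = np.maximum(Q, 0), np.maximum(K, 0)
    theta = np.pi * np.arange(N) / (2 * M)
    c, s = np.cos(theta)[:, None], np.sin(theta)[:, None]
    Q_cos, Q_sin, K_cos, K_sin = Qp * c, Qp * s, Kp * c, Kp * s
    numer = Q_cos @ (K_cos.T @ V) + Q_sin @ (K_sin.T @ V)          # (N, d_v), only d x d_v intermediates
    denom = Q_cos @ K_cos.sum(axis=0) + Q_sin @ K_sin.sum(axis=0)   # (N,)
    return numer / (denom[:, None] + eps)

rng = np.random.default_rng(4)
Q, K, V = rng.normal(size=(3, 10, 4))
print(np.allclose(cosformer_quadratic(Q, K, V), cosformer_linear(Q, K, V)))  # True
```

This covers the bidirectional case only. For causal (autoregressive) attention the sums over $j$ would have to be restricted to $j \le i$, typically via prefix sums, which this sketch does not attempt.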

Experiments also confirm the effectiveness of this re-weighting.

The final results are as follows.