手撕代码GTNC

TSGO的一些说明

偏导数与内积?

文中对输入样本X的偏导数可以定义为如下的形式

WP(X;W)=limh0P(X;W+hW)P(X;W)h\partial_{W} P(X ; W)=\lim _{h \rightarrow 0} \frac{P(X ; W+h W)-P(X ; W)}{h}

上述是偏导数的定义,其中XX表示变量而WW表示函数,分号的作用可以参考Stackoverflow。简单来说一般的符号定义格式为:functionname(variable;parameters),而等式右侧则是导数的定义。而上述公式之所以可以写为如下形式

WP(X;W)=W,P(X;W)W\partial_{W} P(X ; W)=\left\langle W, \frac{\partial P(X ; W)}{\partial W}\right\rangle

是因为全微分在一点的值可以通过拆解变成内积的形式,证明及例子请参考如下资料

Quora: Why is the directional derivative the dot product of the gradient and unit vector?

总结一下TSGO可以应用到TN上的条件如下

(1)要最小化的损失函数是样本的概率分布的函数,文中采用比较常见的负对数概率形式,如下

f=1AXAlogP(X)f=-\frac{1}{A} \sum_{X \in \mathcal{A}} \log P(X)

(2)ψ\mid\psi\rangle的归一化(例如ψψ\langle\psi \mid \psi\rangle) 成为通过单个张量的范数完成,文中就是将其等效于对正交中心张量归一化

(3)任何张量都可以通过无误差或有控制误差的TN变换来表示ψ\mid\psi\rangle的范数

对于MPS而言,通过中心正交形式满足条件(2),用测量和转换满足条件(3)

TSGO如何避免梯度消失和梯度爆炸?

梯度更新的公式大家都比较熟悉,如下

T[n~](T[n~]ηfT[n~])T^{[\tilde{n}]} \leftarrow\left(T^{[\tilde{n}]}-\eta \frac{\partial f}{\partial T^{[\tilde{n}]}}\right)

中心张量的优化可以解释为希尔伯特空间的旋转,学习率由旋转角度控制。为了使ψ\mid\psi\rangledψ\mid d\psi\rangle的方向旋转,按照上式子更新中心张量,随后对该张量进行正交化

T[n~]T[n~]T[n~]T^{[\tilde{n}]} \leftarrow \frac{T^{[\tilde{n}]}}{|T[\tilde{n}]|}

图1所示的几何关系,很容易看出,学习率η和旋转角度满足如下式子

η=tanθ\eta=\tan \theta

学习率可以通过旋转角度进行稳健控制,旋转角度自然约束为0<θπ20<\theta \ll \frac{\pi}{2} ,因此学习率可以严格的控制,不会出现0或1的情况,从而避免出现了梯度爆炸

重新回顾冉老师《张量网络基础课程》

看了矩阵乘积态的规范自由度与正交形式这一节,对MPS进行**规范变换(gauge transformation)**的一个优点就是:改变MPS中的tensor不会改变其所表示的量子态,上面提到TSGO需要对张量进行梯度更新,但进行梯度更新后MPS所表示的量子态不能改变,因此需要保证MPS态始终时正交的

MPS的规范自由度:对于同一个量子态,可由多组不同的张量组成的MPS态来表示其系数。引入约束条件,可以固定MPS的规范自由度,使MPS态表示唯一,常用的约束条件为构成MPS张量的正交条件

K中心正交形式,多次SVD或QR变换可以从K中心正交形式变为KK'中心正交形式。

MPS的正则形式(canonical form),MPS链上每一处二分得到的纠缠谱都被显示的写出来,可以定义无穷长体系的系统。表示如下

φs1s2sN=As1(1)Λ(1)As21(2)Λ(2)Λ(N2)AsN1(1)(N1)Λ(N1)AsN(N)T\varphi_{s_{1} s_{2} \ldots s_{N}}=A_{s_{1}}^{(1)} \Lambda^{(1)} A_{s_{2}^{1}}^{(2)} \Lambda^{(2)} \ldots \Lambda^{(N-2)} A_{s_{N-1}^{(1)}}^{(N-1)} \Lambda^{(N-1)} A_{s_{N}}^{(N) T}

N个Tensor和N-1个对角向量构成

TSGO/GTNC代码

这两篇论文的代码基本上都差不多,用的是同一个结构,只是更新方式略有差别,项目结构如下目录树所示,下面对其中一些和新的组件进行说明。

│  start_train_gtn.py
│ start_train_gtnc.py

├─dataset
│ ├─FashionMNIST
│ └─MNIST

├─data_trained
│ ├─GTN
│ └─GTNC
└─library
BasicClass.py
BasicFunctions_szz.py
MLclass.py
MPSclass.py
MPSMLclass.py
Parameters.py
Programclass.py
TNclass.py
wheel_functions.py
  • 文件start_train_gtn.py主要用MPS来封装一些生成型张量网络的基本操作,例如成mps表示的gtn、初始化环境张量、生成和更新环境张量、计算内积、初始化gtn参数、一轮循环更新等操作。

  • 文件start_train_gtn.py主要是将生成型张量网络用于解决分类问题,它集成了gtn的定义的所有参数,也新增了一些参数例如参数保存路径,还有计算内积、计算准确率、量子态映射成label等操作

  • libraty/目录下的文件封装了一些MPS、ML的一些基本操作,可以进行移植复用

项目中涉及到的一些类的关系如下图所示(有点丑~,大家凑合看吧,哈哈哈)

箭头开始表示父类,箭头方向表示子类,由于子类可以使用父类的各种方法和函数,因此子类中的函数和变量就有些混乱,加上模型近百参数,就不是很好读,下面补充三个点

1.feature_map函数进行特征映射,代码中定义了三种映射方式

1607252837997

(1)many_body_Hilbert_space

C10(sinπ2)1(cosπ2)0=sinπ2=image[:,:,0]C11(sinπ2)0(cosπ2)1=cosπ2=image[:,:,1]\sqrt{C_1^0}(sin\frac{\pi}{2})^1(cos\frac{\pi}{2})^0=sin\frac{\pi}{2}=image[:,:,0]\\ \sqrt{C_1^1}(sin\frac{\pi}{2})^0(cos\frac{\pi}{2})^1=cos\frac{\pi}{2}=image[:,:,1]

(2)linear_map

x0(x0,1x0)x_0 \rightarrow(x_0,1-x_0)

(3)sqrt_linear_map

x0(x0,(1x0))x_0\rightarrow (\sqrt{x_0},\sqrt{(1-x_0)})

2.Cost function的计算

补充下torch.mul(input, other, *, out=None)的用法

该函数的作用是input中的每个元素乘标量other,返回类型为Tensor的乘积结果,参数各类情况如下

  • other是个标量,那么input中每个元素都乘该标量
  • other是个张量且input.shape==other.shape时,两张量对应元素相乘,相当于Hadamard积
  • other是个张量但input.shape!=other.shape时,阶数少的张量要与阶数多的张量维数匹配,乘积的结果与张量阶数高的shape一致,例如tensor1(3,5,2)和tensor2(5,2),乘积结果shape为(3,5,2)

i,j表示虚拟指标,v表示映射维度(2),p表示像素指标(784),m表示样本指标(NcN_c标签为c的样本个数)

2log(i=1nj=1nTivj2)log(Nc)2Ncm(Epm+log(pvImpvTivjEm,v))2log(\sqrt{\sum_{i=1}^{n} \sum_{j=1}^{n}\left|T_{ivj}\right|^{2}})-log(N_c)-\frac{2}{N_c}\sum_{m}{(E_{pm}+log(\sum_{pv}{\mid I_{mpv}T_{ivj}E_{m,v}}\mid))}

1607252859072

3.Gradient 的计算

fψ=2ψ2AXAXXψ\frac{\partial f}{\partial|\psi\rangle}=2|\psi\rangle-\frac{2}{A} \sum_{X \in \mathcal{A}} \frac{|X\rangle}{\langle X \mid \psi\rangle}

1607252876364

4.numpy.random.permutation

随机排列序列,如果x是一个多维数组,那么一般仅重置其第一个索引,下面给出两个例子

# example 1:
np.random.permutation(10)
# array([8, 2, 4, 9, 7, 3, 6, 5, 1, 0])
# example 2:
arr = np.arange(9).reshape((3, 3))
np.random.permutation(arr)
# array([[6, 7, 8],
# [0, 1, 2],
# [3, 4, 5]])

这种方式能够很方便的帮助我们shuffle数据,以增强数据初始分布的随机性,当然random.shuffle、Dataloader模块也都包含一些数据shuffle函数,可根据情况进行选择。

除了上述几个知识点之外,MLclass.pyMPSclass.py中还封装了大量的MPS正交形式、特征映射函数等,比较难理解的地方添加了注释,通 过代码的注释可以详细了解其作用。

尽管GTNC是通过梯度下降的方式去更新张量,可是它主要更新了一个张量(也就是正交中心张量),因此涉及到的参数很少,模型能有较好的效果,另外模型的cost function并不会更新,而是每次以固定的公式进行计算。

关于MPS辅助指标的维数

设张量的维数为d1××dNd_1\times\dots\times d_N ,则不进行维数裁剪的TT分解且不考虑亏秩,第n个辅助指标的维数为

min(d1××dn,dn+1××dN)\min \left(d_{1} \times \cdots \times d_{n}, d_{n+1} \times \cdots \times d_{N}\right)

其实当时对这个公式理解的也不是很深,对指数爆炸一直是云里雾里,时隔许久当黄延同学重新抛给我这个问题时,我联想到GTNC代码中virtual bond变化的规律,突然有了新的想法,做出了下面这张图。

1607252910643

最下面的两个曲线图是黄延同学完成的,画的很棒,非常形象的解释了指数爆炸问题以及裁剪维数χ\chi在其中发挥的具体作用

  • Copyright: Copyright is owned by the author. For commercial reprints, please contact the author for authorization. For non-commercial reprints, please indicate the source.

扫一扫,分享到微信

微信分享二维码
  • Copyrights © 2015-2024 YuleZhang's Blog
  • Visitors: | Views:

请我喝杯咖啡吧~

支付宝
微信