Gemnet-oc框架
2024-09-15 17:13:22 4 举报
AI智能生成
一个机器学习流程基础解析
作者其他创作
大纲/内容
1. 图结构提取
图结构
edge_index:shape[2*N],对应边节点的索引
distance:边的长度
vector:-(两原子坐标向量差)/distance(归一化)
cell_offsets:边上的原子是否跨越了晶胞边界
num_neighbors:所有原子的总邻居数
target_neighbor_idx:表示第i个节点的第j条边,以edge_index[1]计算得到
a2a_graph:原子相互作用图
if 原子、边相互作用 isTrue:edge_index\distance\vector\cell_offsets\num_neighbors\target_neighbor_idx
else:{}
mian_graph:原子相互连接图
edge_index\distance\vector\cell_offsets
a2ee2a_graph:原子和边、边和原子相互作用图
if 原子、边相互作用 is True:edge_index\distance\vector\cell_offsets\num_neighbors\target_neighbor_idx
else:{}
qint_graph:四体相互作用图
if 原子、边 and 四体相互作用 is True:edge_index\distance\vector\cell_offsets\num_neighbors\target_neighbor_idx
else {}
id_swap:对main_graph的边进行无向化处理,此变量中保存了图中正反对应边的索引
三元组信息
in:三元组输入边索引
out:三元组输出边索引
out_agg:类似target_neighbor_idx,三元组第i条边的第j条边
trip_idx_e2e:边和边相互作用
trip_idx_ae:原子和边相互作用
trip_idx_e2a:边和原子相互作用
quad_idx:四元组信息
if 原子、边 and 四体相互作用 is True:
四元组信息:四元组定义d->b->a<-c
四元组信息:四元组定义d->b->a<-c
triplet_in
adj_edges:记录输入三元组d->b->a的邻接信息
row:输入三元组中的目标节点a
col:输入三元组中每条边的源节点d
val:每条边的索引
size:张量大小(row长度*col长度)
nnz:非零元素的数量,表示张量中记录的边的数量
density:稀疏矩阵的密度,表示非零元素占总元素的比例
in:d->b索引;in和out来自不同的图,in来自qint_graph;out来自main_graph(可能是通过对比两图差别获得的)
out:b->a索引
triplet_out
in:b->a索引
out:c->a索引
out
c->a的索引
trip_in_to_quad:输入三元组索引
trip_out_to_quad:输出三元组索引
out_agg:类似前
else: {}
2. 基函数嵌入
get_bases
get_bases
rbf:二体相互作用(距离)
rbf类型
gaussian
d:输入距离/变量;u:高斯基函数中心点,峰值位置,β:尺度参数,控制基函数宽度,值越大函数越尖锐,越小函数越平缓
特点:高斯基函数对距离d 的变化敏感,其值在距离中心μ 附近较大,随着距离增大迅速衰减至接近零。这使得它非常适合描述局部相互作用。
缺点:只适合局部距离变化,不能有效描述周期性或全局特征。
spherical_bessel
球贝塞尔基函数:描述对称系统中波动的特征函数,常用于球坐标系中
特点:球贝塞尔基函数呈现周期性或震荡行为,适合描述周期性的物理现象,同时可以捕捉全局特性,不局限于局部相互作用
缺点:计算复杂度高,难以稳定处理
bernstein
伯恩斯坦基函数常用语多项式逼近和曲线拟合
特点:可以形成任意多项式逼近,可以通过增加基函数的数量来控制逼近的精度。通过加权组合不同次幂的x和1-x值,既能捕捉局部特征,也能生成全局光滑的曲线
缺点:需要较多基函数组合才能描述复杂的几何特征
RadialBasis
参数设置
参数设置
num_radial:基函数频率,这个数值控制了阈值内数据划分的细致程度
例如:阈值为1,此值为10,表示将1划分为10份;这个值控制了输入嵌入向量的长度
例如:阈值为1,此值为10,表示将1划分为10份;这个值控制了输入嵌入向量的长度
cutoff:截止距离
rbf:设置基函数类型
envelope:包络函数类型
用于让基函数在接近cutoff时平滑衰减
用于让基函数在接近cutoff时平滑衰减
polynomial
exponential
scale_basis:是否对基函数输出进行缩放,以增强数值稳定性
模型中的几种rbf配置
radial_basis
最基础的rbf函数,
用于main_graph的距离嵌入
最基础的rbf函数,
用于main_graph的距离嵌入
num_radial=num_radial=128
函数变量 输入变量 实际值 下同
函数变量 输入变量 实际值 下同
cutoff=cutoff=12
rbf=rbf={name:gaussian}
envelope=envelope={name:polynomial,exponent:5}
scale_basis=scale_basis=false
radial_basis_spherical
后续多种基函数使用了这个配置
后续多种基函数使用了这个配置
rbf=rbf_spherical={name:gaussian}
注:rbf_spherical如果没有设置默认和rbf变量一致,这里就是这种情况
注:rbf_spherical如果没有设置默认和rbf变量一致,这里就是这种情况
其余参数和变量与radial_basis一致
radial_basis_spherical_qint
四体相互作用rbf
四体相互作用rbf
cutoff=cutoff_qint=12
rbf=rbf_spherical={name:gaussian}
其余参数和变量与radial_basis一致
radial_basis_aeaint
原子和边/边和原子相互作用
原子和边/边和原子相互作用
cutoff=cutoff_aeaint=12
其余参数和变量与radial_basis一致
radial_basis_spherical_aeaint
只在边和原子相互作用时使用
只在边和原子相互作用时使用
rbf=rbf_spherical={name:gaussian}
cutoff=cutoff_aeaint=12
其余参数和变量与radial_basis一致;这里为什么只有在边和原子相互作用开启时才使用rbf_spherical这个参数呢?
输出
basis_rad_main_raw:main_graph["distance"]使用radial_basis嵌入
basis_rad_a2a_raw:a2a_graph["distance"]使用radial_basis_aint嵌入
if atom_interaction is True
if atom_interaction is True
cbf:三体相互作用(角度)
cbf类型
gaussian:同上
spherical_harmonics
基于球面坐标的基函数,广泛应用于三维空间的角度描述
特点:周期性和复杂角度的相互作用
缺点:计算复杂度较高,局部角度变化较小时,可能不如高斯基函数灵活
CircularBasisLayer
参数设置
参数设置
num_spherical:同上num_radial
radial_basis:一个rbf实例,cbf输出的值一个是对距离的嵌入,所以有一个rbf实例
cbf:设置基函数类型,对角度cos值嵌入
scale_basis:同上
模型中的几种cbf配置
cbf_basis_tint
三体相互作用
三体相互作用
num_spherical=num_spherical=7
rbf=radial_basis_spherical(rbf配置部分)
cbf=cbf={name:spherical_harmonics}
scale_basis=scale_basis=False
cbf_basis_qint
四体相互作用
if quad_interaction is True
四体相互作用
if quad_interaction is True
radial_basis=radial_basis_spherical_qint
其余配置和变量与cbf_basis_tint一致
cbf_basis_aeint
原子和边相互作用
原子和边相互作用
这个变量和cbf_basis_tint完全一致
cbf_basis_eaint
边和原子相互作用
边和原子相互作用
radial_basis=radial_basis_spherical_aeaint
输出
cosφ_cab,角度cos值
basis_rad_cir_e2e_raw:main_graph["distance"]使用radial_basis_spherical嵌入
basis_cir_e2e_raw:cosφ_cab 使用sbf_basis_tint嵌入
sbf:四体相互作用(二面角)
if quad_interaction is True:
(这里默认是False)
if quad_interaction is True:
(这里默认是False)
sbf类型
sphericlal_harmonics:同上
legendre_outer
勒让德基函数外积,外积指的是计算两个向量之间的相互作用,具体来说是将它们各自的勒让德多项式的结果组合在一起。
通过计算这些外积,可以捕捉更高阶的角度变化特征。
通过计算这些外积,可以捕捉更高阶的角度变化特征。
优点:可以适应不同的角度相互作用,比球谐函数更通用,
可以减少特征之间的相互影响,
并且数值相对稳定,适用于高阶角度计算
可以减少特征之间的相互影响,
并且数值相对稳定,适用于高阶角度计算
缺点:在复杂三维球面对称结构中效果不如球谐函数,计算开销较大
gaussian_outer
通过计算两个或多个高斯基函数的外积,可以捕捉不同距离之间的相互关系
优缺点:同上高斯基函数
SphericalBasisLayer
参数设置
参数设置
num_spherical:同上
radial_basis:sbf中同样会对距离进行嵌入,所以有一个rbf实例
sbf:设置基函数类型
scal_basis:用上
模型中的sbf配置
sbf_basis_qint
只有四体作用时用到sbf
只有四体作用时用到sbf
num_spherical=num_spherical=7
radial_basis=radial_basis_spherical
sbf=sbf={name:legendre_outer}
scale_basis=scale_basis=False
输出变量解析
cosφ_cab_q:四体cab cos值;cosφ_adb:adb cos值; angle_cabd:二面角
basis_rad_cir_qint_raw:qint_graph["distance"]使用radial_basis_spherical_qint嵌入
basis_cir_qint_raw:cosφ_adb 使用cbf_basis_qint
basis_rad_sph_qint_raw:main_graph["distance"]使用radial_basis_spherical嵌入
basis_sph_qint_raw:cosφ_cab_q和angle_cabd使用sbf_basis_qint嵌入
cbf-原子和边相互作用
if atom_edge_interaction is True:
(这里默认是True)
if atom_edge_interaction is True:
(这里默认是True)
输出变量解析
basis_rad_a2ee2a_interaction:a2ee2a_graph["distance"] radical_basis_graph嵌入
问题:为什么只有a2e的时候要对这个距离进行嵌入?
问题:为什么只有a2e的时候要对这个距离进行嵌入?
cosφ_cab_a2e:main_graph和a2ee2a_graph ["vector"]夹角cos值
basis_rad_cir_a2e_raw:main_graph["distance"]使用radial_basis_spherical嵌入
basis_cir_a2e_raw:cosφ_cab_a2e使用cbf_basis_aeint嵌入
cbf-边和原子相互作用
if edge_atom_interaction is True:
(这里默认是True)
if edge_atom_interaction is True:
(这里默认是True)
输出变量解析
cosφ_cab_e2a:a2ee2a_graph和main_graph ["vector"]夹角cos值
这里和cosφ_cab_a2e的差别是夹角取值的边不同(角不同)
这里和cosφ_cab_a2e的差别是夹角取值的边不同(角不同)
basis_red_cir_e2a_raw:a2ee2a_graph["graph"] 使用radical_basis_spherical_aeaint嵌入
basis_cir_e2a_raw:cosφ_cab_e2a使用cbf_basis_eaint嵌入
全连接层映射设置
init_shared_basis_layers
init_shared_basis_layers
Dense
参数设置
in_features:输入特征长度
out_features:输出特征长度
bias:偏置项
activation:激活函数名字
silu或者swish
σ(x)是Sigmod函数,该函数的输出是x和Sigmod激活的乘积,因此它是一个平滑的非线性函数
特点:在整个范围内的导数是连续的,这有助于模型的优化和梯度流动
与其他激活函数不同,他的输出是由输入x本身调控的其他激活函数通常将输入所谓一个整体处理
它是非单调的,尽管激活函数大部分是正值,但他允许部分负值通过,相比于ReLu,提供了更细粒度的控制
与其他激活函数不同,他的输出是由输入x本身调控的其他激活函数通常将输入所谓一个整体处理
它是非单调的,尽管激活函数大部分是正值,但他允许部分负值通过,相比于ReLu,提供了更细粒度的控制
缺点:计算复杂度比ReLu略高
可能导致过拟合
没有负饱和,简单理解输入中包含过多负值可能不理想
可能导致过拟合
没有负饱和,简单理解输入中包含过多负值可能不理想
if activation is None
torch.nn.Identity():不对输入进行任何操作输出
所有Dense层
配置都是一样的
输入128维度向量线性变换为16维度后输出
配置都是一样的
输入128维度向量线性变换为16维度后输出
in_feature=num_radial=128
out_feature=emb_size_rbf=16
activation=None
bias=False
变量名
mlp_rbf_qint
mlp_rbf_aeint
mlp_rbf_eaint
mlp_rbf_tint
mlp_rbf_h
mlp_rbf_out
BasisEmbedding
参数设置
设置权重矩阵的形状
if num_spherical:
(emb_size_interm,num_radial)
else:
(num_radial,num_spherical,emb_size_interm)
设置权重矩阵的形状
if num_spherical:
(emb_size_interm,num_radial)
else:
(num_radial,num_spherical,emb_size_interm)
num_radial:基函数维度
emb_size_interm:嵌入的中间尺寸大小
num_spherical:球谐函数的数量
一共有三种配置
rbf
num_radial=num_radial=128
emb_size_interm=emb_size_rbf=7
num_spherical=None
变量名
mlp_rbf_aint
cbf
num_radial=num_radial=128
emb_size_interm=emb_size_cbf=16
num_spherical=num_spherical=7
变量名
mlp_cbf_qint
mlp_cbf_aeint
mlp_cbf_eaint
mlp_cbf_tint
sbf
num_radial=num_radial=128
emb_size_interm=emb_size_sbf=32
num_spherical=num_spherical*2=14
变量名
mlp_sbf_qint
前向传播
rad_basis:径向基函数,表示与每条边相关的基函数,形状为(num_edges,num_radial)或(num_edges,num_orders*num_radial)
sph_basis:球谐基函数,表示三元或四元组中每个三元组相关的基函数,形状为(num_triplets,num_spherical)
idx_rad_outer和id_rad_inner:用于聚合径向基函数的索引,分别指代与每个基函数相关联的原子和每个原子对应的径向基函数
idx_sph_outer和idx_sph_inner:用于聚合球谐基函数的索引,分别指代与每个基函数相关联的边缘和每条边缘对应的球谐基函数
num_atoms:总原子数,用于确定零填充矩阵的大小
输出变量
bases_e2e
rad
mlp_rbf_tinr
mlp_rbf_tinr
rad_basis=basis_rad_main_raw
cir
mlp_cbf_tint
mlp_cbf_tint
rad_basis=basis_rad_cir_e2e_raw
sph_basis=basis_cir_e2e_raw
idx_sph_outer=trip_idx_e2e["out"]
idx_sph_inner=trip_idx_e2e["out_agg"]
basis_atom_update
mlp_rbf_h
rad_basis=basis_rad_main_raw
basis_out_put
注意:这里和basis_atom_updata是完全一样的
可能后续需要分别做处理
注意:这里和basis_atom_updata是完全一样的
可能后续需要分别做处理
mlp_rbf_out
rad_basis=basis_rad_main_raw
bases_qint:dict
这类字典中rad属性表示二体相互作用
cir三体,sph四体
if atom_edge_interaction is True
else {}
这类字典中rad属性表示二体相互作用
cir三体,sph四体
if atom_edge_interaction is True
else {}
rad
mlp_rbf_qint
mlp_rbf_qint
rad_basis=basis_rad_main_raw
cir
mlp_cbf_qint
mlp_cbf_qint
rad_basis=basis_rad_cir_qint_raw
sph_basis=basis_cir_qint_raw
idx_sph_outer=quad_idx["triplet_in"]["out"]
sph
mlp_sbf_qint
mlp_sbf_qint
rad_basis=basis_rad_qint_raw
sph_basis=basis_sph_qint_raw
idx_sph_outer=quad_idx["out"]
idx_sph_inner=quad_idx["out_agg"]
bases_a2e
if atom_edge_interaction is True
else {}
if atom_edge_interaction is True
else {}
rad
mlp_rbf_aeint
mlp_rbf_aeint
rad_basis=basis_rad_a2ee2a_raw
cir
mlp_cbf_aeint
mlp_cbf_aeint
rad_basis=basis_rad_cir_a2e_raw
sph_basis=basis_cir_a2e_raw
idx_sph_outer=trip_idx_a2e["outer"]
idx_sph_inner=trip_idx_a2e["out_agg"]
bases_e2a
if edge_atom_interaction is True
else {}
if edge_atom_interaction is True
else {}
rad
mlp_rbf_eaint
mlp_rbf_eaint
rad_basis=basis_rad_main_raw
cir
mlp_cbf_eaint
mlp_cbf_eaint
rad_basis=basis_rad_cir_e2a_raw
sph_basis=basis_cir_e2a_raw
idx_rad_outer=a2ee2a_graph["edge_index"][1]
idx_rad_inner=a2ee2a_graph["target_neighbor_idx"]
idx_sph_outer=trip_idx_e2a["out"]
idx_sph_inner=trip_idx_e2a["out_agg"]
num_atoms=num_atoms
basis_a2a_rad
if atom_interaction is True:
else None
if atom_interaction is True:
else None
mlp_rbf_aint
rad_basis=basis_rad_a2a_raw
idx_rad_outer=a2a_graph["edge_index"][1]
idx_rad_inner=a2a_graph["target_neighbor_idx]
num_atoms=num_atoms
输出变量整理
basis_rad_main_raw
basis_atom_update
basis_output
bases_qint
bases_e2e
bases_a2e
bases_e2a
basis_a2a_rad
3.原子和边嵌入
原子嵌入
AtomEmbediing
AtomEmbediing
参数设置
emb_size:每种原子对应的嵌入向量的维度大小
num_elements:包含的原子类型
示例:num_elements=10,emb_size=64,那么每个原子类型都会被映射到64维的向量
torch.nn.init.uniform:初始化嵌入矩阵的权重,这里的权重初始化为均匀分布
atom_emb配置
emb_size=emb_size_atom=256
num_elements=num_elements=83
边嵌入
EdgeEmbedding
EdgeEmbedding
参数设置
atom_features:原子嵌入维度大小
edge_features:边嵌入维度大小
out_features:输出维度大小
activation:dense层激活函数名字
前向传播
h:原子AtomEmbeding信息
m:边经过rbf嵌入信息
edge_index:边的索引信息
将原子嵌入的信息按照边的关系拼接起来,进行线性变换,
线性变换层设置输入长度为2*atom_feature+edge_features;输出长度为out_features
线性变换层设置输入长度为2*atom_feature+edge_features;输出长度为out_features
4. 交互层
InteractionBlock
InteractionBlock
参数设置
5. 输出模块
out_blocks
out_blocks
继承类AtomUpdateBlock
对原子嵌入进行更新
对原子嵌入进行更新
OutputBlock
参数设置
参数设置
emb_size_atom:每个原子的嵌入向量的维度
emb_size_edge:每条边嵌入向量的维度
emb_size_rbf:径向基函数的嵌入大小
nHidden:AtomUpdateBlock之前使用的残差快ResudyakBlock的数量
nHidden_afteratom:AtomUpdateBlock之后使用的残差块数量
activation:激活函数的名称
direct_forces:是否直接预测力,False的时候会通过能量梯度计算力
默认设置
层数=num_blocks+1=4+1
为什么要设置成num_blocks+1?
为什么要设置成num_blocks+1?
emb_size_atom=emb_size_atom=256
emb_size_edge=emb_size_edge=512
emb_size_rbf=emb_size_rbf=16
nHidden=num_atom=3
nHidden_afteratom=num_output_afteratom=3
activation=activation=silu
direct_forces=direct_forces=True
输出参数
x_E,x_F=out_blocks[0](h,m,basis_output,idx_t)
h:原子嵌入向量
m:边嵌入向量
basis_output:main_graph["distance"]使用rbf嵌入
idx_t:main_graph["edge_index"][1],边节点输出索引
收藏
0 条评论
下一页