修改了算法代码,并建立了一个简单的训练脚本。修改 Bert 以支持二维输入,移除 PPO 的 permute 参数。

This commit is contained in:
2025-10-22 16:56:12 +08:00
parent b626702cbb
commit 3f7e183c4b
101 changed files with 3837 additions and 39 deletions

View File

@@ -28,17 +28,26 @@ class Bert(nn.Module):
self.classifier.train()
def forward(self, x, mask=None):
# x可以是2D (batch_size, input_dim) 或 3D (batch_size, seq_len, feature_dim)
is_2d_input = (x.dim() == 2)
if is_2d_input:
# 如果输入是2D,添加一个序列维度
x = x.unsqueeze(1) # (batch_size, 1, input_dim)
# x: (batch_size, seq_len, input_dim)
# 线性投影
x = self.projection(x) # (batch_size, input_dim, embed_dim)
x = self.projection(x) # (batch_size, seq_len, embed_dim)
batch_size = x.size(0)
if self.CLS:
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat([cls_tokens, x], dim=1) # (batch_size, 29, embed_dim)
x = torch.cat([cls_tokens, x], dim=1) # (batch_size, seq_len+1, embed_dim)
# 添加位置编码
x = x + self.pos_embed
# 添加位置编码(截取前 seq_len 个位置以匹配序列长度;切片只能截断,不会扩展)
seq_len = x.size(1)
pos_embed = self.pos_embed[:, :seq_len, :]
x = x + pos_embed
# 转置为(seq_len, batch_size, embed_dim)
x = x.permute(1, 0, 2)