修改了算法代码,并建立了一个简单的训练脚本.修改bert处理二维输入,移除PPO的permute参数
This commit is contained in:
@@ -28,17 +28,26 @@ class Bert(nn.Module):
|
||||
self.classifier.train()
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
# x可以是2D (batch_size, input_dim) 或 3D (batch_size, seq_len, feature_dim)
|
||||
is_2d_input = (x.dim() == 2)
|
||||
|
||||
if is_2d_input:
|
||||
# 如果输入是2D,添加一个序列维度
|
||||
x = x.unsqueeze(1) # (batch_size, 1, input_dim)
|
||||
|
||||
# x: (batch_size, seq_len, input_dim)
|
||||
# 线性投影
|
||||
x = self.projection(x) # (batch_size, input_dim, embed_dim)
|
||||
x = self.projection(x) # (batch_size, seq_len, embed_dim)
|
||||
|
||||
batch_size = x.size(0)
|
||||
if self.CLS:
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
x = torch.cat([cls_tokens, x], dim=1) # (batch_size, 29, embed_dim)
|
||||
x = torch.cat([cls_tokens, x], dim=1) # (batch_size, seq_len+1, embed_dim)
|
||||
|
||||
# 添加位置编码
|
||||
x = x + self.pos_embed
|
||||
# 添加位置编码(截断或扩展以匹配序列长度)
|
||||
seq_len = x.size(1)
|
||||
pos_embed = self.pos_embed[:, :seq_len, :]
|
||||
x = x + pos_embed
|
||||
|
||||
# 转置为(seq_len, batch_size, embed_dim)
|
||||
x = x.permute(1, 0, 2)
|
||||
|
||||
Reference in New Issue
Block a user