1 Dataset Processing
Server address: 10.1.16.116
Dataset directory: /home/data1/sebas/projects/sketch-recognition/resnet152-test/dataset
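2 Network Architecture
A three-branch network with ResNet-152 as the backbone. The first two branches are coupled through cross-attention, and the result is then fused with the features of the third branch. Each branch takes its input from a different dataset; every image in these datasets is a 1111x1111x1 grayscale image (white background, black strokes).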
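The structure described above can be sketched roughly as follows. This is a minimal illustration, not the project's actual model: the class name `ThreeBranchNet`, the helper `make_resnet152_backbone`, the use of `nn.MultiheadAttention` for the cross-attention, the residual connection, the mean pooling, and fusion by concatenation are all assumptions. Only the three-branch layout, the ResNet-152 backbone, the single-channel input, and the three outputs (which match the training code) come from these notes.

```python
import torch
import torch.nn as nn


def make_resnet152_backbone() -> nn.Module:
    """ResNet-152 trunk for 1-channel input that returns a spatial feature map."""
    from torchvision.models import resnet152  # imported lazily; only needed here

    net = resnet152(weights=None)
    # The stock stem expects 3 channels; the sketch images have a single channel.
    net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Drop the global pool and fc layer so the output stays (B, 2048, H, W).
    return nn.Sequential(*list(net.children())[:-2])


class ThreeBranchNet(nn.Module):
    """Hypothetical three-branch model: cross-attention between the first two
    branches, then fusion with the third branch."""

    def __init__(self, backbone_fn, feat_dim, num_classes, num_heads=8):
        super().__init__()
        self.external_backbone = backbone_fn()
        self.internal_backbone = backbone_fn()
        self.backbone = backbone_fn()
        # One attention module per direction (assumed design).
        self.ext_from_int = nn.MultiheadAttention(feat_dim, num_heads, batch_first=True)
        self.int_from_ext = nn.MultiheadAttention(feat_dim, num_heads, batch_first=True)
        self.external_head = nn.Linear(feat_dim, num_classes)
        self.internal_head = nn.Linear(feat_dim, num_classes)
        self.fused_head = nn.Linear(3 * feat_dim, num_classes)

    @staticmethod
    def _tokens(feature_map):
        # (B, C, H, W) -> (B, H*W, C): each spatial position becomes one token.
        return feature_map.flatten(2).transpose(1, 2)

    def forward(self, external_images, internal_images, images):
        ext = self._tokens(self.external_backbone(external_images))
        inn = self._tokens(self.internal_backbone(internal_images))
        full = self._tokens(self.backbone(images))

        # Each of the first two branches queries the other one.
        ext_att, _ = self.ext_from_int(query=ext, key=inn, value=inn)
        int_att, _ = self.int_from_ext(query=inn, key=ext, value=ext)

        # Residual connection, then average over tokens to get one vector per image.
        ext_vec = (ext + ext_att).mean(dim=1)
        int_vec = (inn + int_att).mean(dim=1)
        full_vec = full.mean(dim=1)

        # Fuse the attended features with the third branch by concatenation.
        fused = torch.cat([ext_vec, int_vec, full_vec], dim=1)
        return self.external_head(ext_vec), self.internal_head(int_vec), self.fused_head(fused)


# Shape check with a tiny stand-in backbone, so no ResNet-152 has to be built.
def tiny_backbone():
    return nn.Sequential(nn.Conv2d(1, 16, kernel_size=3, stride=4), nn.ReLU())


demo_model = ThreeBranchNet(tiny_backbone, feat_dim=16, num_classes=5, num_heads=4)
demo_inputs = [torch.randn(2, 1, 32, 32) for _ in range(3)]
demo_outputs = demo_model(*demo_inputs)
```

With the real backbone the call would look like `ThreeBranchNet(make_resnet152_backbone, feat_dim=2048, num_classes=...)`, where the class count depends on the dataset. At 1111x1111 resolution each branch produces roughly 35x35 tokens, so attention memory may become a practical concern.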
    import sys

    import torch
    import torch.optim as optim
    from tqdm import tqdm

    # model, criterion, device, lr, epochs and the three training loaders
    # are assumed to be defined earlier in the script.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_acc = 0.0
    train_steps = len(train_loader)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_correct_num = 0
        # zip() stops as soon as the shortest of the three loaders is exhausted.
        train_bar = tqdm(zip(train_external_loader, train_internal_loader, train_loader), file=sys.stdout)

        for step, (external_data, internal_data, data) in enumerate(train_bar):
            external_images, external_labels = external_data
            internal_images, internal_labels = internal_data
            images, labels = data

            optimizer.zero_grad()

            external_images = external_images.to(device)
            external_labels = external_labels.to(device)
            internal_images = internal_images.to(device)
            internal_labels = internal_labels.to(device)
            images = images.to(device)
            labels = labels.to(device)

            # One forward pass yields one set of logits per branch.
            external_outputs, internal_outputs, outputs = model(external_images, internal_images, images)

            # Correct predictions are accumulated over all three outputs.
            _, external_predicted = torch.max(external_outputs, 1)
            train_correct_num += (external_predicted == external_labels).sum().item()

            _, internal_predicted = torch.max(internal_outputs, 1)
            train_correct_num += (internal_predicted == internal_labels).sum().item()

            _, predicted = torch.max(outputs, 1)
            train_correct_num += (predicted == labels).sum().item()

            # The three losses are summed with equal weight.
            external_loss = criterion(external_outputs, external_labels)
            internal_loss = criterion(internal_outputs, internal_labels)
            loss = criterion(outputs, labels)

            total_loss = external_loss + internal_loss + loss
            total_loss.backward()
            optimizer.step()

            train_loss += total_loss.item()
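Two bookkeeping details of the loop above are easy to get wrong when reporting per-epoch numbers. First, `zip()` ends at the shortest of the three loaders, so `len(train_loader)` equals the real step count only if the other two loaders are at least as long. Second, `train_correct_num` sums correct predictions over three outputs, so an accuracy derived from it needs three times the sample count as its denominator. The sketch below illustrates both points in plain Python. The helper names `epoch_steps` and `epoch_metrics` and the batch counts are hypothetical, and it assumes every output sees the same number of samples.

```python
def epoch_steps(*loaders):
    """Number of iterations that zip(*loaders) will actually run."""
    return min(len(loader) for loader in loaders)


def epoch_metrics(loss_sum, correct_num, steps, samples_seen, num_outputs=3):
    """Mean loss per step and accuracy averaged over all outputs.

    samples_seen is the number of samples each output was evaluated on
    (assumed equal for every output).
    """
    mean_loss = loss_sum / steps
    accuracy = correct_num / (samples_seen * num_outputs)
    return mean_loss, accuracy


# Plain lists stand in for the three DataLoaders (made-up batch counts).
train_external_loader = list(range(10))
train_internal_loader = list(range(8))
train_loader = list(range(12))

steps = epoch_steps(train_external_loader, train_internal_loader, train_loader)
ran = sum(1 for _ in zip(train_external_loader, train_internal_loader, train_loader))
print(steps, ran)  # both are 8, not len(train_loader) == 12
```

Passing `total=steps` to `tqdm` would also give the progress bar a length, which it cannot infer from a `zip` object.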
3 Experimental Results