Python 组 WoC 相关

Python 组 WoC 相关

第二题《这些图片是什么鬼?!》
Kaggle 链接:这些图片是什么鬼?!

题目中要求使用 ResNet50 在训练集上训练,并在 test 集上进行预测。
一开始写的代码如下:

1
2
net = torchvision.models.resnet50(pretrained=True)
net.fc = nn.Linear(net.fc.in_features, 2)

发现效果很差。
通过查阅资料,发现 ResNet50 的全连接层是 2048*1000,直接修改为 2048*2 效果很差,经过添加线性层修改后在训练集上效果变好。

1
2
3
4
5
6
7
net = torchvision.models.resnet50(pretrained=True)
net.fc = nn.Sequential(nn.Linear(net.fc.in_features, net.fc.in_features//2),
nn.ReLU(),
nn.Linear(net.fc.in_features//2, net.fc.in_features//4),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(net.fc.in_features//4, 2))