计算机视觉（四）：使用 ResNet-50 实现 CIFAR-10 分类
'''
@File : resnet50.py
@Contact : zhangzhilong2022@gmail.com
@Modify Time 2024/3/18 17:39
@Author huahai2022
@Description
'''
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
# Preprocessing pipeline applied to every CIFAR-10 image.
transform = transforms.Compose([
    # Upscale the small CIFAR-10 images to the 224-pixel input size that the
    # ImageNet-pretrained ResNet-50 was trained on.
    transforms.Resize(224),
    transforms.ToTensor(),
    # Normalise with the ImageNet channel mean/std so the input distribution
    # matches what the pretrained weights expect.
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
# Mini-batch size shared by the training and test loaders.
batch_size = 128

# CIFAR-10 training and test splits, stored under ./data (download=True
# fetches the archive when needed); both use the same preprocessing.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Batch iterators: only the training split is shuffled.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)
# Build a ResNet-50 initialised with torchvision's pretrained weights.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=` -- confirm the installed version before switching.
model = models.resnet50(pretrained=True)

# Replace the final fully connected layer with a 10-way classifier so the
# output size matches the number of CIFAR-10 classes.
num_features = model.fc.in_features
model.fc = nn.Linear(in_features=num_features, out_features=10)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=0.001,
    momentum=0.9,
)
# Training routine: one full pass over the training data.
def train(model, train_loader, criterion, optimizer):
    """Run one training epoch over ``train_loader``.

    Every batch is moved to the device that holds the model's parameters, so
    the function works on CPU-only machines as well as on CUDA.

    Args:
        model: Network to optimise; it is switched to training mode.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable, invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose ``zero_grad()`` and ``step()`` are called
            once per batch.
    """
    model.train()
    # Follow the model's own device instead of hard-coding CUDA; fall back to
    # the CPU for a model that has no parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
# Evaluation routine: top-1 accuracy over a data loader.
def test(model, test_loader):
    """Return the top-1 accuracy of ``model`` on ``test_loader`` in percent.

    Every batch is moved to the device that holds the model's parameters, so
    the function works on CPU-only machines as well as on CUDA.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        test_loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the loader
        yields no samples.
    """
    model.eval()
    # Follow the model's own device instead of hard-coding CUDA; fall back to
    # the CPU for a model that has no parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # The predicted class is the index of the largest logit per row.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return 100 * correct / total


# The name starts with "test", which matches pytest's default collection
# prefix; opt out so test runners importing it do not treat it as a test case.
test.__test__ = False
# Use the GPU when CUDA is available, otherwise stay on the CPU, and move the
# model to the chosen device.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = model.to(device)

# Alternate one training epoch with one evaluation pass and report the test
# accuracy after every epoch.
num_epochs = 10
for epoch in range(num_epochs):
    train(model, train_loader, criterion, optimizer)
    accuracy = test(model, test_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Test Accuracy: {accuracy:.2f}%')
在残差（跳跃）连接做逐元素相加时，输入和输出的形状不一定相同。针对这一点有两种处理方法：一是对输入做下采样（例如步长为 2 的 1×1 卷积），把它的形状调整到与输出一致；二是利用广播机制。