Pytorch学习笔记#2: 搭建神经网络训练MNIST手写数字数据集
创始人
2024-05-29 11:10:41
0

学习自https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

导入并预处理数据集

pytorch中数据导入和预处理主要用torch.utils.data.DataLoader 和 torch.utils.data.Dataset
Dataset 存储样本及其相应的标签,DataLoader在数据上生成一个可迭代对象(Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset.)

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# Download training data from open datasets.
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),
)# Download test data from open datasets.
test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor(),
)

将数据集作为参数传递给 DataLoader。 这在我们的数据集上包装了一个可迭代对象,并支持自动批处理、采样、混洗和多进程数据加载。并且每一个batch大小为64。

batch_size = 64# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)for X, y in test_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break

搭建神经网络

MNIST手写数字数据集的图片是2828的,所以第一层的输入为2828。
因为识别结果是0~9这10种,所以最后一层的输出就是10个。

我们需要定义神经网络结构,这部分在__init__(self)部分实现。
且我们需要forward部分定义网络正向传播的方法。

class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork().to(device)
print(model)

训练模型

首先,我们需要先定义损失函数和优化器(优化梯度下降算法)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) # lr为学习率

在一次循环中,神经网络通过forward进行预测(我们写的forward函数),然后再利用预测误差。通过反向传播来进行梯度下降(pytorch帮我们实现)。

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

开始训练!

epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")

在这里插入图片描述

相关内容

热门资讯

克罗恩病关节炎:身体与心灵的双... 哎呀,说到这个克罗恩病关节炎,真是让我又爱又恨啊!你可能觉得我这话有点夸张,但当你真正感受到它带来的...
胃病不能吃什么-胃病发作时,这... 哎呀,说到胃病,我这肚子就开始隐隐作痛了。你们知道吗,胃病发作的时候,有些食物简直就是“毒药”!首先...
exe文件用手机能打开-手机能... 哎呀,今天咱们来聊聊这个听起来有点玄乎的话题——手机能打开exe文件?我得先说,这可不是闹着玩的!e...
华为新视通视频会议方案:技术进... 大家好!今天我要和大家聊聊华为新出的那个“新视通”视频会议方案。哇,这个名字听起来就很有未来感,对吧...
exagear模拟器使用方法-... 嘿,各位游戏迷们,今天我要来聊聊一个超级酷炫的工具——Exagear模拟器!是不是有时候看到那些老旧...
sd卡格式化后数据自动恢复-S... 哎呀,说到这个SD卡格式化后数据自动恢复的事儿,真是让人又爱又恨!你知道吗,有一次我不小心把SD卡给...
帝国cms模板文件目录-帝国 ... 哎呀,说到这个帝国CMS的模板文件目录,我可是有一肚子的话要说!这玩意儿啊,真是个让人又爱又恨的小妖...
台湾快递单号查询网:让人心急又... 哎呀,说到这个台湾快递单号查询网,我真是又爱又恨!每次网购完,最让人心急的就是等快递了。你知道的,那...
酷管家损坏照片修复:拯救珍贵记... 哎呀,说到这个酷管家损坏照片修复,我这心里就五味杂陈啊!记得那次,我辛辛苦苦攒了几个月的照片,突然间...
live linux-体验自由... 嘿,朋友们,今天咱们聊聊那个让人爱不释手的LiveLinux!想象一下,你的电脑不再是被某个大公司绑...
phantomjs win10... 哎呀,说到在Win10上安装PhantomJS,这可真是一次让人又爱又恨的经历!你知道的,作为一个对...
西软酒店管理系统 官网-西软酒... 大家好,我是一名酒店前台的小姐姐,今天我要来聊聊我们酒店用的那个“西软酒店管理系统”官网,真的是让我...
出生医学证明大小太奇葩,让人又... 你知道吗,每次提到那个小小的出生医学证明,我心里就五味杂陈。不是因为它有多重要,而是因为它的大小,简...
路由器宽带叠加:让你家网速翻倍... 哎呀呀,说到这个路由器宽带叠加,我这小心脏就扑通扑通跳个不停!你知道吗,自从我用了这个神奇的技巧,我...
finaldata4.1注册码... 哎呀,说到这个FinalData4.1的注册码,真是让人又爱又恨啊!你知道吗,这玩意儿就像是那个关键...
多媒体教室讲台3d模型-多媒体... 哇,今天我要给大家带来一个超级酷炫的话题——我们的多媒体教室讲台3D模型!你们有没有想过,那个每天站...
手机看交通监控摄像头-在手机上... 在这个快节奏的城市生活中,我找到了一个小小的秘密花园——手机上的交通监控摄像头。每当我觉得压力山大,...
win10 老驱动-Win10... 哎呀,说到这个Win10的老驱动,我就一肚子火!你们知道吗,每次系统更新,我的电脑就像要和我闹别扭一...
急性呼吸衰竭定义-急性呼吸衰竭... 想象一下,你的肺突然像被重物压住,每一次呼吸都变成了一场挣扎。这就是急性呼吸衰竭的感觉,它像是一个不...
华天动力oa办公系统-华天动力... 大家好!我是你们的小助手,今天想和大家聊聊那个让我们的工作日变得稍微有点乐趣的家伙——华天动力OA办...