Pytorch学习笔记#2: 搭建神经网络训练MNIST手写数字数据集
创始人
2024-05-29 11:10:41
0

学习自https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

导入并预处理数据集

pytorch中数据导入和预处理主要用torch.utils.data.DataLoader 和 torch.utils.data.Dataset
Dataset 存储样本及其相应的标签,DataLoader在数据上生成一个可迭代对象(Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset.)

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# Download training data from open datasets.
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),
)# Download test data from open datasets.
test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor(),
)

将数据集作为参数传递给 DataLoader。 这在我们的数据集上包装了一个可迭代对象,并支持自动批处理、采样、混洗和多进程数据加载。并且每一个batch大小为64。

batch_size = 64# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)for X, y in test_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break

搭建神经网络

MNIST手写数字数据集的图片是2828的,所以第一层的输入为2828。
因为识别结果是0~9这10种,所以最后一层的输出就是10个。

我们需要定义神经网络结构,这部分在__init__(self)部分实现。
且我们需要forward部分定义网络正向传播的方法。

class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork().to(device)
print(model)

训练模型

首先,我们需要先定义损失函数和优化器(优化梯度下降算法)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) # lr为学习率

在一次循环中,神经网络通过forward进行预测(我们写的forward函数),然后再利用预测误差。通过反向传播来进行梯度下降(pytorch帮我们实现)。

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

开始训练!

epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")

在这里插入图片描述

相关内容

热门资讯

安卓4.4系统tv软件,探索安... 亲爱的读者们,你是否曾为家里的电视屏幕增添一些智能的魔力而烦恼?别担心,今天我要给你带来一个超级实用...
安卓系统的研究人物,安卓系统发... 你知道吗?在科技飞速发展的今天,安卓系统可是占据了智能手机市场的大半壁江山。而在这片广阔的天地里,有...
山寨苹果刷会安卓系统,安卓系统... 你知道吗?在科技圈里,总有一些让人眼前一亮的小秘密。今天,我要给你揭秘一个关于山寨苹果刷安卓系统的神...
安卓系统新用户登录,畅享智能生... 你刚刚入手了一台全新的安卓手机,是不是有点小激动呢?别急,别急,让我来给你详细介绍一下安卓系统新用户...
安卓8.0系统推荐版本,体验流... 你有没有发现,手机系统更新换代的速度简直就像小孩子的成长一样快?这不,安卓8.0系统已经悄悄地来到了...
安卓系统怎么分享位置吗,一键操... 你是不是也有过这样的经历:和朋友约好见面,却因为找不到对方而急得团团转?别担心,今天就来教你怎么在安...
安卓系统更新加速器,畅享极速升... 你有没有发现,手机更新系统的时候总是慢吞吞的,让人等得心痒痒?别急,今天就来给你安利一款神器——安卓...
百答系统和安卓系统区别,差异解... 你有没有想过,为什么你的手机里装了那么多应用,却还是觉得信息不够全面?其实,这背后的大脑——操作系统...
安卓锁系统设置软件,软件设置与... 手机里的秘密可多了去了,是不是有时候你也会觉得,这手机里的信息要是被别人看到了可怎么办呢?别担心,今...
安卓电视u盘游戏系统,轻松畅享... 你有没有想过,家里的安卓电视也能玩上那些刺激的电脑游戏呢?没错,就是那种让你一玩就停不下来的游戏!今...
挂载安卓系统为读写权限,读写权... 你有没有想过,你的手机里那些神奇的安卓系统,竟然可以赋予某些应用读写权限?这听起来是不是有点像科幻电...
安卓12系统怎么打补丁,保障设... 亲爱的安卓用户们,你是否也遇到了系统卡顿、bug频发的小烦恼呢?别急,今天就来给你支个招——安卓12...
客厅电脑用安卓系统好吗,体验智... 亲爱的读者,你是不是在为客厅电脑选择操作系统而烦恼呢?安卓系统,这个我们日常手机上常见的操作系统,是...
安卓系统能看访客记录,轻松查看... 你有没有想过,你的安卓手机里藏着一个小秘密?没错,就是访客记录!是的,你没听错,你的手机里竟然能查看...
印度安卓系统电脑推荐,性能卓越... 你有没有想过,在印度这片神奇的土地上,用一台安卓系统电脑会是怎样的体验呢?想象阳光洒在泰姬陵的白色大...
安卓系统合作公司,安卓系统合作... 你知道吗?在科技的世界里,安卓系统可是个超级明星呢!它不仅拥有庞大的用户群体,还吸引了一大批合作公司...
苹果表有安卓系统时间,时间同步... 你有没有发现,最近苹果表也开始支持安卓系统了?没错,就是那个一直以封闭著称的苹果,竟然也开始拥抱安卓...
原生安卓系统裁剪图片,原生安卓... 你有没有发现,用原生安卓系统拍照,有时候拍出来的照片分辨率超高,但就是有点大,想裁剪却不知道怎么操作...
安卓系统蓝牙开关APP,安卓系... 你有没有遇到过这种情况:手机里的安卓系统蓝牙开关总是让人摸不着头脑?有时候想开蓝牙,却找不到开关在哪...
安卓系统能登录ios系统王者吗... 你有没有想过,安卓系的手机能不能登录iOS系统的王者荣耀呢?这可是个让人好奇不已的问题哦!毕竟,两个...