Pytorch学习笔记#2: 搭建神经网络训练MNIST手写数字数据集
创始人
2024-05-29 11:10:41
0

学习自https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

导入并预处理数据集

pytorch中数据导入和预处理主要用torch.utils.data.DataLoader 和 torch.utils.data.Dataset
Dataset 存储样本及其相应的标签,DataLoader在数据上生成一个可迭代对象(Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset.)

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# Download training data from open datasets.
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),
)# Download test data from open datasets.
test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor(),
)

将数据集作为参数传递给 DataLoader。 这在我们的数据集上包装了一个可迭代对象,并支持自动批处理、采样、混洗和多进程数据加载。并且每一个batch大小为64。

batch_size = 64# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)for X, y in test_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break

搭建神经网络

MNIST手写数字数据集的图片是2828的,所以第一层的输入为2828。
因为识别结果是0~9这10种,所以最后一层的输出就是10个。

我们需要定义神经网络结构,这部分在__init__(self)部分实现。
且我们需要forward部分定义网络正向传播的方法。

class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork().to(device)
print(model)

训练模型

首先,我们需要先定义损失函数和优化器(优化梯度下降算法)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) # lr为学习率

在一次循环中,神经网络通过forward进行预测(我们写的forward函数),然后再利用预测误差。通过反向传播来进行梯度下降(pytorch帮我们实现)。

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

开始训练!

epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")

在这里插入图片描述

相关内容

热门资讯

安卓系统越用越耗电,揭秘原因与... 你有没有发现,自从你把手机换成安卓系统后,电池续航能力好像大不如前了?是不是觉得每天都要带着充电宝出...
适合安卓系统的dj软件,打造个... 你有没有想过,在安卓手机上也能享受到DJ的乐趣呢?没错,现在就有很多适合安卓系统的DJ软件,让你随时...
制作音乐的软件安卓系统 音乐爱好者们,你是否曾梦想着在安卓手机上轻松制作出属于自己的音乐作品?别再羡慕那些专业音乐制作人啦!...
安卓系统时间久了好卡 手机用久了是不是感觉安卓系统越来越卡?是不是每次打开应用都要等上好一会儿,甚至有时候直接卡死?别急,...
车机系统安卓10教程,轻松掌握... 你有没有发现,现在越来越多的车机系统都开始支持安卓系统了呢?这不,安卓10已经悄悄地走进了我们的爱车...
安卓系统升级翻译 你知道吗?最近安卓系统又来了一次大升级,这可真是让人兴奋不已呢!想象你的手机就像穿上了新衣,焕然一新...
怎么查苹果是安卓系统,技术揭秘... 你有没有想过,你的苹果设备竟然可能是安卓系统?别惊讶,这可不是天方夜谭。有时候,我们可能会买到一些被...
安卓平板windows系统ap... 你有没有发现,最近安卓平板上出现了一个新趋势?那就是越来越多的用户开始尝试将Windows系统应用到...
领克05安卓系统,智能科技与驾... 你有没有发现,现在的汽车越来越智能了?这不,最近我试驾了一款叫做领克05的车型,它搭载的安卓系统简直...
安卓手机都有双系统吗,安卓手机... 你有没有想过,你的安卓手机是不是也有双系统呢?这可不是什么科幻小说里的情节,而是现实中许多手机用户都...
名爵zs安卓9.0系统,驾驭未... 你有没有听说最近名爵zs的新鲜事儿?没错,就是那个升级到了安卓9.0系统的名爵zs!哎呀呀,这可真是...
安卓车载系统哪家好用点,安卓车... 你有没有发现,随着科技的发展,汽车已经不仅仅是一个代步工具了,它更像是一个移动的智能生活空间。而在这...
安吉达订餐系统安卓下载 你有没有想过,点外卖也能变得如此轻松有趣?没错,就是那个让你在手机上就能轻松订餐的神奇工具——安吉达...
智能安卓点歌系统价格 你有没有想过,在一场热闹的聚会中,点歌环节竟然也能变得如此智能和便捷?没错,就是那个神奇的智能安卓点...
安卓系统的勿扰权限,智能守护您... 你有没有发现,手机里的安卓系统越来越智能了?不过,有时候它也会让人有点头疼,比如那个让人又爱又恨的“...
安卓系统比较好的,卓越性能与丰... 你有没有发现,现在手机市场上安卓系统简直成了香饽饽?不管是年轻人还是老年人,都对安卓手机爱不释手。今...
安卓系统文件夹名称,揭秘隐藏文... 你有没有发现,每次打开安卓手机,里面那些文件夹的名称都那么有趣,有时候甚至让人猜不透它们到底藏着什么...
安卓系统电视应用未安装,安卓电... 你有没有遇到过这种情况?家里的安卓系统电视上突然有个应用没安装,让你心里直痒痒,想赶紧弄明白怎么解决...
安卓系统投影电脑桌面,电脑桌面... 你有没有想过,你的安卓手机里的精彩内容,竟然可以无缝地投影到电脑桌面上?是的,你没有听错,这就是我们...
安卓系统老是多出照片,揭秘多出... 手机里的照片越来越多,是不是你也遇到了安卓系统里照片层出不穷的问题呢?这可真是让人头疼啊!今天,就让...