From 58e344ef137c0837085420f8859fe353379394ed Mon Sep 17 00:00:00 2001
From: TaurusXin
Date: Fri, 29 Dec 2023 19:53:46 +0800
Subject: [PATCH] first commit

---
 .gitignore     |   7 ++
 README.md      |  43 +++++++++++++
 data/.gitkeep  |   1 +
 input/1.png    | Bin 0 -> 2259 bytes
 input/2.png    | Bin 0 -> 2343 bytes
 input/3.png    | Bin 0 -> 2291 bytes
 input/4.png    | Bin 0 -> 2205 bytes
 mnist_net.py   |  28 ++++++++
 mnist_test.py  |  56 ++++++++++++++++
 mnist_train.py | 171 +++++++++++++++++++++++++++++++++++++++++++++++++
 ui/index.html  | 106 ++++++++++++++++++++++++++++++
 webui.py       |  39 +++++++++++
 12 files changed, 451 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 README.md
 create mode 100644 data/.gitkeep
 create mode 100644 input/1.png
 create mode 100644 input/2.png
 create mode 100644 input/3.png
 create mode 100644 input/4.png
 create mode 100644 mnist_net.py
 create mode 100644 mnist_test.py
 create mode 100644 mnist_train.py
 create mode 100644 ui/index.html
 create mode 100644 webui.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6473974
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+__pycache__
+
+data/*
+
+!data/.gitkeep
+
+*.pt
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..ce7414e
--- /dev/null
+++ b/README.md
@@ -0,0 +1,43 @@
+# MNIST Dataset
+
+Training with PyTorch
+
+## Installing dependencies
+
+First install PyTorch by following the [PyTorch website](https://pytorch.org/get-started/locally/) so that GPU acceleration can be enabled
+
+Then install the remaining dependencies
+
+```shell
+pip install fastapi uvicorn pillow tqdm
+```
+
+## Training the model
+
+Naturally, you have to train the model first
+
+```shell
+python mnist_train.py --save-model
+```
+
+## Usage
+
+Once the model is trained, this repository offers two ways to use it: local files and a Web UI
+
+### Local files
+
+Put your image files into the input folder (they must be 28x28 grayscale PNGs), then run
+
+```shell
+python mnist_test.py
+```
+
+### Web UI
+
+Start the Web UI with
+
+```shell
+uvicorn webui:app
+```
+
+Then open http://127.0.0.1:8000 (uvicorn's default address) in your browser to see the page
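The README's local-file path expects 28x28 grayscale PNGs. As a minimal preparation sketch, not part of this patch and with placeholder file names, an arbitrary digit photo can be converted with Pillow, which is already in the dependency list above:

```python
# Hypothetical helper: turn an arbitrary digit image into the 28x28
# grayscale PNG that mnist_test.py expects. "my_digit.jpg" is a placeholder.
from PIL import Image

img = Image.open("my_digit.jpg").convert("L").resize((28, 28))
img.save("input/my_digit.png")
```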
diff --git a/data/.gitkeep b/data/.gitkeep
new file mode 100644
index 0000000..350bb18
--- /dev/null
+++ b/data/.gitkeep
@@ -0,0 +1 @@
+# ignore data but keep folder
\ No newline at end of file
diff --git a/input/1.png b/input/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..1b67edd79eb272b8df080ee3ca21cb89adad8987
GIT binary patch
[binary payload: 2259 bytes, base85 data omitted]
diff --git a/input/2.png b/input/2.png
new file mode 100644
GIT binary patch
[binary payload: 2343 bytes, base85 data omitted; index header lost in extraction]
diff --git a/input/3.png b/input/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..d2b2288cdaf683fb2927d35234f5ccd273a856ae
GIT binary patch
[binary payload: 2291 bytes, base85 data omitted]
diff --git a/input/4.png b/input/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..69349d650bdd9f84d850456a05086203b8e52873
GIT binary patch
[binary payload: 2205 bytes, base85 data omitted]
diff --git a/mnist_net.py b/mnist_net.py
new file mode 100644
index 0000000..4e1957b
--- /dev/null
+++ b/mnist_net.py
@@ -0,0 +1,28 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class MnistNet(nn.Module):
+    def __init__(self):
+        super(MnistNet, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
\ No newline at end of file
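The 9216 in fc1 follows from the shapes: two 3x3 valid convolutions take a 28x28 input to 26x26 and then 24x24, the 2x2 max-pool halves that to 12x12, and 64 channels * 12 * 12 = 9216. A quick shape check, as a sketch that assumes mnist_net.py is importable from the working directory:

```python
import torch
from mnist_net import MnistNet

net = MnistNet().eval()        # eval() disables the dropout layers
x = torch.randn(1, 1, 28, 28)  # a single fake 28x28 grayscale image
with torch.no_grad():
    out = net(x)
print(out.shape)               # torch.Size([1, 10]): log-probabilities for digits 0-9
```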
diff --git a/mnist_test.py b/mnist_test.py
new file mode 100644
index 0000000..5042e53
--- /dev/null
+++ b/mnist_test.py
@@ -0,0 +1,56 @@
+import torch
+from mnist_net import MnistNet
+from torchvision import transforms
+from PIL import Image
+
+import os
+
+# find device
+device = None
+
+use_cuda = torch.cuda.is_available()
+use_mps = torch.backends.mps.is_available()
+
+if use_cuda:
+    device = torch.device("cuda")
+elif use_mps:
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")
+
+print(f"Using device: {device}")
+
+# preprocessing: the same normalization used during training
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+)
+
+def predict(device, model, img):
+    # transform image and add a batch dimension
+    img = transform(img)
+    img = img.to(device)
+    img = img.unsqueeze(0)
+
+    # predict: take the index of the highest log-probability
+    output = model(img)
+    pred = output.argmax(dim=1, keepdim=True)
+    return pred.item()
+
+if __name__ == '__main__':
+    # load model
+    model = MnistNet()
+    model.load_state_dict(torch.load('mnist_cnn.pt', map_location=device))
+    model.to(device)
+
+    # set model to eval mode
+    model.eval()
+
+    # predict every image in the input folder
+    folder = './input'
+    predictions = []
+
+    for image_name in sorted(os.listdir(folder)):
+        img = Image.open(os.path.join(folder, image_name)).convert('L')
+        predictions.append(predict(device, model, img))
+
+    print(predictions)
\ No newline at end of file
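predict() also works on a single file. One caveat, stated as an assumption rather than something this script enforces: MNIST digits are white strokes on a black background, so a dark-on-light PNG will likely be mispredicted unless it is inverted first, for example with Pillow's ImageOps.invert:

```python
# Sketch: single-image prediction reusing the device pick and predict()
# from mnist_test.py (importing it also runs its device-detection code).
import torch
from PIL import Image, ImageOps
from mnist_net import MnistNet
from mnist_test import device, predict

model = MnistNet()
model.load_state_dict(torch.load("mnist_cnn.pt", map_location=device))
model.to(device)
model.eval()

img = Image.open("input/1.png").convert("L")
# img = ImageOps.invert(img)  # uncomment for dark-on-light digits
print(predict(device, model, img))
```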
diff --git a/mnist_train.py b/mnist_train.py
new file mode 100644
index 0000000..09d6e6c
--- /dev/null
+++ b/mnist_train.py
@@ -0,0 +1,171 @@
+import argparse
+
+from mnist_net import MnistNet
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+from torch.optim.lr_scheduler import StepLR
+from tqdm import tqdm
+
+
+def train(args, model, device, train_loader, optimizer, epoch):
+    model.train()
+    with tqdm(train_loader, desc=f"Train Epoch {epoch}", unit="batch") as t:
+        for batch_idx, (data, target) in enumerate(t):
+            data, target = data.to(device), target.to(device)
+            optimizer.zero_grad()
+            output = model(data)
+            loss = F.nll_loss(output, target)
+            loss.backward()
+            optimizer.step()
+
+            if batch_idx % args.log_interval == 0:
+                t.set_postfix(loss=round(loss.item(), 3))
+
+            if args.dry_run:
+                break
+
+
+def test(model, device, test_loader):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            test_loss += F.nll_loss(
+                output, target, reduction="sum"
+            ).item()  # sum up batch loss
+            pred = output.argmax(
+                dim=1, keepdim=True
+            )  # get the index of the max log-probability
+            correct += pred.eq(target.view_as(pred)).sum().item()
+
+    test_loss /= len(test_loader.dataset)
+
+    print(
+        "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
+            test_loss,
+            correct,
+            len(test_loader.dataset),
+            100.0 * correct / len(test_loader.dataset),
+        )
+    )
+
+
+def main():
+    # Training settings
+    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=64,
+        metavar="N",
+        help="input batch size for training (default: 64)",
+    )
+    parser.add_argument(
+        "--test-batch-size",
+        type=int,
+        default=1000,
+        metavar="N",
+        help="input batch size for testing (default: 1000)",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=14,
+        metavar="N",
+        help="number of epochs to train (default: 14)",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1.0,
+        metavar="LR",
+        help="learning rate (default: 1.0)",
+    )
+    parser.add_argument(
+        "--gamma",
+        type=float,
+        default=0.7,
+        metavar="M",
+        help="Learning rate step gamma (default: 0.7)",
+    )
+    parser.add_argument(
+        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+    )
+    parser.add_argument(
+        "--no-mps",
+        action="store_true",
+        default=False,
+        help="disables macOS GPU training",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        default=False,
+        help="quickly check a single pass",
+    )
+    parser.add_argument(
+        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+    )
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status",
+    )
+    parser.add_argument(
+        "--save-model",
+        action="store_true",
+        default=False,
+        help="For Saving the current Model",
+    )
+    args = parser.parse_args()
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+    use_mps = not args.no_mps and torch.backends.mps.is_available()
+
+    torch.manual_seed(args.seed)
+
+    if use_cuda:
+        device = torch.device("cuda")
+    elif use_mps:
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+
+    print(f"Using device: {device}\n")
+
+    train_kwargs = {"batch_size": args.batch_size}
+    test_kwargs = {"batch_size": args.test_batch_size}
+    if use_cuda:
+        cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
+        train_kwargs.update(cuda_kwargs)
+        test_kwargs.update(cuda_kwargs)
+
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+    )
+    dataset1 = datasets.MNIST("./data", train=True, download=True, transform=transform)
+    dataset2 = datasets.MNIST("./data", train=False, transform=transform)
+    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
+    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
+
+    model = MnistNet().to(device)
+    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+
+    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+    for epoch in range(1, args.epochs + 1):
+        train(args, model, device, train_loader, optimizer, epoch)
+        test(model, device, test_loader)
+        scheduler.step()
+
+    if args.save_model:
+        torch.save(model.state_dict(), "mnist_cnn.pt")
+        print("Model saved to mnist_cnn.pt")
+
+if __name__ == "__main__":
+    main()
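The default lr of 1.0 is conventional for Adadelta, and StepLR(step_size=1, gamma=0.7) multiplies it by 0.7 after every epoch, so the effective rate at epoch 14 is 0.7^13, roughly 0.0097. A side calculation, not part of the patch:

```python
# Effective Adadelta learning rate per epoch under StepLR(step_size=1, gamma=0.7)
lr = 1.0
for epoch in range(1, 15):
    print(f"epoch {epoch:2d}: lr = {lr:.4f}")
    lr *= 0.7
```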
diff --git a/ui/index.html b/ui/index.html
new file mode 100644
index 0000000..e19c55d
--- /dev/null
+++ b/ui/index.html
@@ -0,0 +1,106 @@
+[markup lost in extraction: this 106-line file is the drawing page, titled
+"绘图应用" ("Drawing App"), with a canvas for sketching a digit and a
+"结果：" ("Result:") label that displays the model's prediction]
diff --git a/webui.py b/webui.py
new file mode 100644
index 0000000..1cb5a98
--- /dev/null
+++ b/webui.py
@@ -0,0 +1,39 @@
+from fastapi import FastAPI
+from fastapi.responses import HTMLResponse
+from pydantic import BaseModel
+
+from PIL import Image
+
+import torch
+
+from mnist_net import MnistNet
+from mnist_test import device
+from mnist_test import predict
+
+import base64
+import io
+
+class Request(BaseModel):
+    image: str
+
+app = FastAPI()
+
+model = MnistNet()
+model.load_state_dict(torch.load('mnist_cnn.pt', map_location=device))
+model.to(device)
+model.eval()
+
+@app.get('/')
+async def root():
+    with open("./ui/index.html", "r", encoding="utf-8") as html_file:
+        content = html_file.read()
+    return HTMLResponse(content=content)
+
+@app.post("/api/predict")
+async def predict_digit(req: Request):
+    img = base64.b64decode(req.image.replace('data:image/png;base64,', ''))
+    img = Image.open(io.BytesIO(img)).convert('L')
+    prediction = predict(device, model, img)
+    return {
+        'result': prediction
+    }
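To exercise the endpoint without the drawing page, a hypothetical client can post the same payload ui/index.html would send: a base64 data-URL of a PNG, as the replace() call in predict_digit implies. This sketch assumes uvicorn's default address of http://127.0.0.1:8000 and uses the third-party requests library, which is not in the README's dependency list:

```python
import base64
import requests  # pip install requests

with open("input/1.png", "rb") as f:
    data_url = "data:image/png;base64," + base64.b64encode(f.read()).decode()

resp = requests.post("http://127.0.0.1:8000/api/predict", json={"image": data_url})
print(resp.json())  # e.g. {"result": 1}
```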