commit 58e344ef137c0837085420f8859fe353379394ed
Author: TaurusXin
Date:   Fri Dec 29 19:53:46 2023 +0800

    first commit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6473974
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+__pycache__
+
+data/*
+
+!data/.gitkeep
+
+*.pt
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..ce7414e
--- /dev/null
+++ b/README.md
@@ -0,0 +1,43 @@
+# MNIST Dataset
+
+Trained with PyTorch
+
+## Installing dependencies
+
+First install PyTorch by following the [official instructions](https://pytorch.org/get-started/locally/) so that GPU acceleration is enabled.
+
+Then install the remaining dependencies:
+
+```shell
+pip install fastapi uvicorn pillow tqdm
+```
+
+## Training the model
+
+You need to train the model first:
+
+```shell
+python mnist_train.py --save-model
+```
+
+## Usage
+
+Once the model is trained, this repository offers two ways to use it: local files and a Web UI.
+
+### Local files
+
+Put image files into the `input` folder (they must be 28×28 grayscale PNGs), then run
+
+```shell
+python mnist_test.py
+```
+
+### Web UI
+
+Start the Web UI with
+
+```shell
+uvicorn webui:app
+```
+
+Then open the address uvicorn reports (http://127.0.0.1:8000 by default) in a browser to see the page.
diff --git a/data/.gitkeep b/data/.gitkeep
new file mode 100644
index 0000000..350bb18
--- /dev/null
+++ b/data/.gitkeep
@@ -0,0 +1 @@
+# ignore data but keep folder
\ No newline at end of file
diff --git a/input/1.png b/input/1.png
new file mode 100644
index 0000000..1b67edd
Binary files /dev/null and b/input/1.png differ
diff --git a/input/2.png b/input/2.png
new file mode 100644
index 0000000..e7ca1ce
Binary files /dev/null and b/input/2.png differ
diff --git a/input/3.png b/input/3.png
new file mode 100644
index 0000000..d2b2288
Binary files /dev/null and b/input/3.png differ
diff --git a/input/4.png b/input/4.png
new file mode 100644
index 0000000..69349d6
Binary files /dev/null and b/input/4.png differ
diff --git a/mnist_net.py b/mnist_net.py
new file mode 100644
index 0000000..4e1957b
--- /dev/null
+++ b/mnist_net.py
@@ -0,0 +1,28 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class MnistNet(nn.Module):
+    def __init__(self):
+        super(MnistNet, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
\ No newline at end of file
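The fully connected layer `fc1` takes 9216 inputs because the two 3×3 convolutions shrink the 28×28 input to 24×24 and the 2×2 max-pool halves that to 12×12, giving 64 × 12 × 12 = 9216 flattened features. A quick shape sanity check, as a sketch assuming `mnist_net.py` is importable from the working directory:

```python
import torch
from mnist_net import MnistNet

model = MnistNet()
dummy = torch.zeros(1, 1, 28, 28)   # one grayscale 28x28 image
out = model(dummy)
print(out.shape)                    # torch.Size([1, 10]): log-probabilities for digits 0-9

# Shape flow: 28x28 -> conv1 (3x3) -> 26x26 -> conv2 (3x3) -> 24x24
#             -> max_pool2d(2) -> 12x12 -> flatten -> 64 * 12 * 12 = 9216
```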
diff --git a/mnist_test.py b/mnist_test.py
new file mode 100644
index 0000000..5042e53
--- /dev/null
+++ b/mnist_test.py
@@ -0,0 +1,56 @@
+import torch
+from mnist_net import MnistNet
+from torchvision import transforms
+from PIL import Image
+
+import os
+
+# find device
+device = None
+
+use_cuda = torch.cuda.is_available()
+use_mps = torch.backends.mps.is_available()
+
+if use_cuda:
+    device = torch.device("cuda")
+elif use_mps:
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")
+
+print(f"Using device: {device}")
+
+# load image
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+)
+
+def predict(device, model, img):
+    # transform image
+    img = transform(img)
+    img = img.to(device)
+    img = img.unsqueeze(0)
+
+    # predict
+    output = model(img)
+    pred = output.argmax(dim=1, keepdim=True)
+    return pred.item()
+
+if __name__ == '__main__':
+    # load model
+    model = MnistNet()
+    model.load_state_dict(torch.load('mnist_cnn.pt', map_location=device))
+    model.to(device)
+
+    # set model to eval mode
+    model.eval()
+
+    # load image
+    folder = './input'
+    predictions = []
+
+    for image_name in sorted(os.listdir(folder)):
+        img = Image.open(os.path.join(folder, image_name))
+        predictions.append(predict(device, model, img))
+
+    print(predictions)
\ No newline at end of file
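`mnist_test.py` feeds every file in `input/` to the model unchanged, so images must already be 28×28 grayscale as the README requires. A rough preparation sketch with Pillow; the helper name, file paths, and the inversion step are illustrative assumptions, not part of this commit:

```python
from PIL import Image, ImageOps

def to_mnist_format(src_path, dst_path, invert=False):
    """Convert an arbitrary image into a 28x28 grayscale PNG for mnist_test.py."""
    img = Image.open(src_path).convert("L")  # grayscale
    img = img.resize((28, 28))               # MNIST resolution
    if invert:
        img = ImageOps.invert(img)           # MNIST digits are white on a black background
    img.save(dst_path)

# e.g. a dark-on-light photo of a handwritten digit:
to_mnist_format("my_digit.png", "input/5.png", invert=True)
```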
diff --git a/mnist_train.py b/mnist_train.py
new file mode 100644
index 0000000..09d6e6c
--- /dev/null
+++ b/mnist_train.py
@@ -0,0 +1,171 @@
+import argparse
+
+from mnist_net import MnistNet
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+from torch.optim.lr_scheduler import StepLR
+from tqdm import tqdm
+
+
+def train(args, model, device, train_loader, optimizer, epoch):
+    model.train()
+    with tqdm(train_loader, desc=f"Train Epoch {epoch}", unit="batch") as t:
+        for batch_idx, (data, target) in enumerate(t):
+            data, target = data.to(device), target.to(device)
+            optimizer.zero_grad()
+            output = model(data)
+            loss = F.nll_loss(output, target)
+            loss.backward()
+            optimizer.step()
+
+            if batch_idx % args.log_interval == 0:
+                t.set_postfix(loss=round(loss.item(), 3))
+
+                if args.dry_run:
+                    break
+
+
+def test(model, device, test_loader):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            test_loss += F.nll_loss(
+                output, target, reduction="sum"
+            ).item()  # sum up batch loss
+            pred = output.argmax(
+                dim=1, keepdim=True
+            )  # get the index of the max log-probability
+            correct += pred.eq(target.view_as(pred)).sum().item()
+
+    test_loss /= len(test_loader.dataset)
+
+    print(
+        "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
+            test_loss,
+            correct,
+            len(test_loader.dataset),
+            100.0 * correct / len(test_loader.dataset),
+        )
+    )
+
+
+def main():
+    # Training settings
+    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=64,
+        metavar="N",
+        help="input batch size for training (default: 64)",
+    )
+    parser.add_argument(
+        "--test-batch-size",
+        type=int,
+        default=1000,
+        metavar="N",
+        help="input batch size for testing (default: 1000)",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=14,
+        metavar="N",
+        help="number of epochs to train (default: 14)",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1.0,
+        metavar="LR",
+        help="learning rate (default: 1.0)",
+    )
+    parser.add_argument(
+        "--gamma",
+        type=float,
+        default=0.7,
+        metavar="M",
+        help="Learning rate step gamma (default: 0.7)",
+    )
+    parser.add_argument(
+        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+    )
+    parser.add_argument(
+        "--no-mps",
+        action="store_true",
+        default=False,
+        help="disables macOS GPU training",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        default=False,
+        help="quickly check a single pass",
+    )
+    parser.add_argument(
+        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+    )
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status",
+    )
+    parser.add_argument(
+        "--save-model",
+        action="store_true",
+        default=False,
+        help="For Saving the current Model",
+    )
+    args = parser.parse_args()
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+    use_mps = not args.no_mps and torch.backends.mps.is_available()
+
+    torch.manual_seed(args.seed)
+
+    if use_cuda:
+        device = torch.device("cuda")
+    elif use_mps:
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+
+    print(f"Using device: {device}\n")
+
+    train_kwargs = {"batch_size": args.batch_size}
+    test_kwargs = {"batch_size": args.test_batch_size}
+    if use_cuda:
+        cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
+        train_kwargs.update(cuda_kwargs)
+        test_kwargs.update(cuda_kwargs)
+
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+    )
+    dataset1 = datasets.MNIST("./data", train=True, download=True, transform=transform)
+    dataset2 = datasets.MNIST("./data", train=False, transform=transform)
+    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
+    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
+
+    model = MnistNet().to(device)
+    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+
+    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+    for epoch in range(1, args.epochs + 1):
+        train(args, model, device, train_loader, optimizer, epoch)
+        test(model, device, test_loader)
+        scheduler.step()
+
+    if args.save_model:
+        torch.save(model.state_dict(), "mnist_cnn.pt")
+        print("Model saved to mnist_cnn.pt")
+
+
+if __name__ == "__main__":
+    main()
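Both the training and test scripts normalize inputs with `Normalize((0.1307,), (0.3081,))`; those constants are the mean and standard deviation of the MNIST training images. A sketch of how they can be re-derived (this loads all 60,000 training images into memory, roughly 180 MB):

```python
import torch
from torchvision import datasets, transforms

# Stack the training images as a [60000, 1, 28, 28] tensor with values in [0, 1]
ds = datasets.MNIST("./data", train=True, download=True, transform=transforms.ToTensor())
imgs = torch.stack([img for img, _ in ds])

print(imgs.mean().item(), imgs.std().item())  # ~0.1307, ~0.3081
```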
diff --git a/ui/index.html b/ui/index.html
new file mode 100644
index 0000000..e19c55d
--- /dev/null
+++ b/ui/index.html
@@ -0,0 +1,106 @@
[The 106-line markup and script of the drawing page did not survive this extract; only the page title "绘图应用" and the result label "结果:" remain.]
diff --git a/webui.py b/webui.py
new file mode 100644
index 0000000..1cb5a98
--- /dev/null
+++ b/webui.py
@@ -0,0 +1,39 @@
+from fastapi import FastAPI
+from fastapi.responses import HTMLResponse
+from pydantic import BaseModel
+
+from PIL import Image
+
+import torch
+
+from mnist_net import MnistNet
+from mnist_test import device
+from mnist_test import predict
+
+import base64
+import io
+
+class Request(BaseModel):
+    image: str
+
+app = FastAPI()
+
+model = MnistNet()
+model.load_state_dict(torch.load('mnist_cnn.pt', map_location=device))
+model.to(device)
+model.eval()
+
+@app.get('/')
+async def root():
+    with open("./ui/index.html", "r", encoding="utf-8") as html_file:
+        content = html_file.read()
+    return HTMLResponse(content=content)
+
+@app.post("/api/predict")
+async def predict_digit(req: Request):
+    img = base64.b64decode(req.image.replace('data:image/png;base64,', ''))
+    img = Image.open(io.BytesIO(img)).convert('L')
+    prediction = predict(device, model, img)
+    return {
+        'result': prediction
+    }
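The `/api/predict` endpoint expects a JSON body whose `image` field is a PNG encoded as a base64 data URL, which is what the drawing page sends. A minimal client sketch using only the standard library; the address assumes uvicorn's default `127.0.0.1:8000`, and `input/1.png` is one of the sample images added in this commit:

```python
import base64
import json
import urllib.request

# Encode a local PNG the same way the drawing page does: a data URL with a base64 payload.
with open("input/1.png", "rb") as f:
    data_url = "data:image/png;base64," + base64.b64encode(f.read()).decode()

req = urllib.request.Request(
    "http://127.0.0.1:8000/api/predict",
    data=json.dumps({"image": data_url}).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))  # e.g. {'result': 1}
```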