first commit
This commit is contained in:
commit
58e344ef13
|
@ -0,0 +1,7 @@
|
|||
__pycache__
|
||||
|
||||
data/*
|
||||
|
||||
!data/.gitkeep
|
||||
|
||||
*.pt
|
|
@ -0,0 +1,43 @@
|
|||
# MNIST 数据集
|
||||
|
||||
使用 PyTorch 训练
|
||||
|
||||
## 依赖安装
|
||||
|
||||
首先根据 [PyTorch 官网](https://pytorch.org/get-started/locally/) 的说明安装 PyTorch 和 torchvision(请选择与本机匹配的版本,以便启用 GPU 加速)
|
||||
|
||||
然后安装其他依赖
|
||||
|
||||
```shell
|
||||
pip install fastapi uvicorn pillow tqdm
|
||||
```
|
||||
|
||||
## 训练模型
|
||||
|
||||
首先你当然要先训练模型啦
|
||||
|
||||
```shell
|
||||
python mnist_train.py --save-model
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
训练完模型后,本仓库包含两种使用方法,即本地文件和 Web UI
|
||||
|
||||
### 本地文件
|
||||
|
||||
将图片文件放到 `input` 文件夹内(图片必须是 28×28 像素的灰度 PNG),然后执行
|
||||
|
||||
```shell
|
||||
python mnist_test.py
|
||||
```
|
||||
|
||||
### Web UI
|
||||
|
||||
执行以下命令启动 Web UI
|
||||
|
||||
```shell
|
||||
uvicorn webui:app
|
||||
```
|
||||
|
||||
然后打开 <http://127.0.0.1:8000> 即可看到网页
|
|
@ -0,0 +1 @@
|
|||
# ignore data but keep folder
|
Binary file not shown.
After Width: | Height: | Size: 2.2 KiB |
Binary file not shown.
After Width: | Height: | Size: 2.3 KiB |
Binary file not shown.
After Width: | Height: | Size: 2.2 KiB |
Binary file not shown.
After Width: | Height: | Size: 2.2 KiB |
|
@ -0,0 +1,28 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class MnistNet(nn.Module):
    """Small convolutional classifier producing 10 log-probabilities per image.

    Two 3x3 convolutions are followed by 2x2 max-pooling, dropout and two
    fully connected layers. ``fc1`` takes 9216 flattened features, which is
    what the convolution/pooling stack yields for a 1x28x28 input
    (64 channels * 12 * 12).
    """

    def __init__(self):
        super().__init__()
        # Creation order is kept stable so a seeded initialisation is reproducible.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return log-softmax class scores of shape (batch, 10) for input batch ``x``."""
        # Convolutional feature extractor.
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))

        # Classifier head on the flattened feature maps (batch dimension kept).
        hidden = F.relu(self.fc1(features.flatten(start_dim=1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)
|
|
@ -0,0 +1,56 @@
|
|||
import torch
|
||||
from mnist_net import MnistNet
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
|
||||
import os
|
||||
|
||||
# Pick the compute backend once at import time: CUDA first, then Apple MPS,
# falling back to the CPU when neither accelerator is available.
use_cuda = torch.cuda.is_available()
use_mps = torch.backends.mps.is_available()

device = torch.device("cuda" if use_cuda else "mps" if use_mps else "cpu")

print(f"Using device: {device}")
|
||||
|
||||
# Preprocessing applied to every image before inference: convert to a tensor,
# then normalise with a fixed mean/std pair.
# NOTE(review): these constants presumably have to match the normalisation
# used when the model was trained - confirm against the training script.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])
|
||||
|
||||
def predict(device, model, img):
    """Classify a single image and return the predicted class index.

    Args:
        device: torch device the preprocessed tensor is moved to. The model
            is presumably expected to already live on that device.
        model: callable network; its output is reduced with ``argmax`` over
            dimension 1, so it must return one row of class scores.
        img: image accepted by the module-level ``transform`` pipeline.

    Returns:
        int: index of the highest-scoring class.
    """
    # Preprocess, move to the target device and add the batch dimension.
    batch = transform(img).to(device).unsqueeze(0)

    # Pure inference: disable autograd so no graph is recorded for this call.
    with torch.no_grad():
        output = model(batch)

    return output.argmax(dim=1, keepdim=True).item()
|
||||
|
||||
if __name__ == '__main__':
    # Restore the trained weights and switch to inference mode.
    model = MnistNet()
    model.load_state_dict(torch.load('mnist_cnn.pt', map_location=device))
    model.to(device)
    model.eval()

    folder = './input'
    predictions = []

    # Sorted so the printed predictions follow a stable, alphabetical file order.
    for image_name in sorted(os.listdir(folder)):
        # The context manager releases the file handle after each image.
        # convert('L') forces a single channel, matching what the web API
        # does, so RGB/RGBA PNGs no longer break the 1-channel network input.
        with Image.open(os.path.join(folder, image_name)) as img:
            predictions.append(predict(device, model, img.convert('L')))

    print(predictions)
|
|
@ -0,0 +1,171 @@
|
|||
import argparse
|
||||
|
||||
from mnist_net import MnistNet
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torchvision import datasets, transforms
|
||||
from torch.optim.lr_scheduler import StepLR
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch over ``train_loader`` with a tqdm progress bar.

    Args:
        args: parsed settings; ``log_interval`` and ``dry_run`` are read.
        model: network to optimise (put into training mode here).
        device: torch device each batch is moved to.
        train_loader: iterable yielding ``(data, target)`` batches.
        optimizer: optimiser stepping the model parameters.
        epoch: epoch number, used only for the progress-bar label.
    """
    model.train()
    label = f"Train Epoch {epoch}"
    with tqdm(train_loader, desc=label, unit="batch") as t:
        for step, (inputs, labels) in enumerate(t):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Standard step: clear gradients, forward, backward, update.
            optimizer.zero_grad()
            loss = F.nll_loss(model(inputs), labels)
            loss.backward()
            optimizer.step()

            # Refresh the loss shown on the progress bar every log_interval batches.
            if step % args.log_interval == 0:
                t.set_postfix(loss=round(loss.item(), 3))

            # A dry run stops after the first batch.
            if args.dry_run:
                break
|
||||
|
||||
|
||||
def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy.

    Args:
        model: network to evaluate (put into eval mode here).
        device: torch device each batch is moved to.
        test_loader: iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute whose length is the sample count.
    """
    model.eval()
    total_loss = 0
    num_correct = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            log_probs = model(inputs)

            # Sum (not mean) per batch so the division below yields a per-sample average.
            total_loss += F.nll_loss(log_probs, labels, reduction="sum").item()

            # Predicted class = index of the largest log-probability.
            predicted = log_probs.argmax(dim=1)
            num_correct += predicted.eq(labels.view_as(predicted)).sum().item()

    num_samples = len(test_loader.dataset)
    average_loss = total_loss / num_samples

    print(
        "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            average_loss,
            num_correct,
            num_samples,
            100.0 * num_correct / num_samples,
        )
    )
|
||||
|
||||
|
||||
def _build_arg_parser():
    """Create the command-line parser holding every training setting."""
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=14,
        metavar="N",
        help="number of epochs to train (default: 14)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1.0,
        metavar="LR",
        help="learning rate (default: 1.0)",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.7,
        metavar="M",
        help="Learning rate step gamma (default: 0.7)",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--no-mps",
        action="store_true",
        default=False,
        help="disables macOS GPU training",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument(
        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--save-model",
        action="store_true",
        default=False,
        help="For Saving the current Model",
    )
    return parser


def main():
    """Train MnistNet on MNIST, evaluate after every epoch, optionally save the weights.

    Settings come from the command line (see ``_build_arg_parser``). The
    dataset is downloaded into ``./data`` on first use, and with
    ``--save-model`` the final state dict is written to ``mnist_cnn.pt``.
    """
    args = _build_arg_parser().parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    use_mps = not args.no_mps and torch.backends.mps.is_available()

    torch.manual_seed(args.seed)

    # Preference order: CUDA, then Apple MPS, then CPU.
    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}\n")

    # Shuffle the training data on every backend. Previously the flag lived in
    # the CUDA-only kwargs, so CPU/MPS runs trained on unshuffled data while
    # the test loader was shuffled for no benefit.
    train_kwargs = {"batch_size": args.batch_size, "shuffle": True}
    test_kwargs = {"batch_size": args.test_batch_size}
    if use_cuda:
        # Loader options applied only when training on CUDA.
        cuda_kwargs = {"num_workers": 1, "pin_memory": True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    dataset1 = datasets.MNIST("./data", train=True, download=True, transform=transform)
    dataset2 = datasets.MNIST("./data", train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = MnistNet().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # Multiply the learning rate by gamma after every epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
        print("Model saved to mnist_cnn.pt")


if __name__ == "__main__":
    main()
|
|
@ -0,0 +1,106 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>绘图应用</title>
|
||||
<style>
|
||||
#canvas, #canvas2 {
|
||||
background-color: black;
|
||||
cursor: crosshair;
|
||||
}
|
||||
|
||||
#result {
|
||||
margin-top: 10px;
|
||||
font-weight: bold;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<canvas id="canvas" width="400" height="400"></canvas>
|
||||
<button onclick="submitDrawing()">提交</button>
|
||||
<button onclick="clearCanvas()">清除</button>
|
||||
<div>结果: <span id="result"></span></div>
|
||||
|
||||
<script>
|
||||
// Drawing surface and its 2D context, shared by every handler in this script.
const canvas = document.getElementById("canvas");
const context = canvas.getContext("2d");

// Set by startDrawing, cleared by stopDrawing; draw() is a no-op while false.
let isDrawing = false;

// Press starts a stroke, movement extends it, release or leaving the canvas ends it.
const mouseHandlers = {
    mousedown: startDrawing,
    mousemove: draw,
    mouseup: stopDrawing,
    mouseout: stopDrawing,
};
for (const [eventName, handler] of Object.entries(mouseHandlers)) {
    canvas.addEventListener(eventName, handler);
}
|
||||
|
||||
// Mark the stroke as active and pass the press event on to draw(),
// so the stroke starts at the position where the button went down.
function startDrawing(e) {
    isDrawing = true;
    draw(e);
}
|
||||
|
||||
// Extend the current stroke to the mouse position; ignored unless a stroke is active.
function draw(e) {
    if (isDrawing) {
        // Translate viewport coordinates into canvas-relative coordinates.
        const bounds = canvas.getBoundingClientRect();
        const x = e.clientX - bounds.left;
        const y = e.clientY - bounds.top;

        // Thick white round-capped pen.
        Object.assign(context, {
            lineWidth: 50,
            lineCap: "round",
            strokeStyle: "white",
        });

        context.lineTo(x, y);
        context.stroke();

        // Start a fresh path at the current point so the next segment
        // continues from here instead of re-stroking the whole path.
        context.beginPath();
        context.moveTo(x, y);
    }
}
|
||||
|
||||
// End the active stroke and reset the path so the next stroke
// is not connected to the previous one.
function stopDrawing() {
    context.beginPath();
    isDrawing = false;
}
|
||||
|
||||
// Downscale the drawing to 28x28 and send it to the prediction API.
function submitDrawing() {
    const tempCanvas = document.createElement("canvas");
    tempCanvas.width = 28;
    tempCanvas.height = 28;

    // Draw the visible canvas straight onto the 28x28 one. This replaces the
    // PNG -> Image -> onload round trip, which did redundant encoding work
    // and attached the onload handler only after src had been assigned.
    const tempContext = tempCanvas.getContext("2d");
    tempContext.drawImage(canvas, 0, 0, 28, 28);

    submitToAPI(tempCanvas.toDataURL("image/png"));
}
|
||||
|
||||
// Erase everything drawn on the canvas.
function clearCanvas() {
    const { width, height } = canvas;
    context.clearRect(0, 0, width, height);
}
|
||||
|
||||
// POST the image (a data URL) to /api/predict and show the returned digit in #result.
function submitToAPI(imageDataUrl) {
    fetch('/api/predict', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
        },
        body: JSON.stringify({ image: imageDataUrl }),
    })
    .then(response => {
        // fetch() only rejects on network failure; turn HTTP error statuses
        // into rejections too, so an error body is never read as a result.
        if (!response.ok) {
            throw new Error(`HTTP ${response.status}`);
        }
        return response.json();
    })
    .then(data => {
        console.log('API返回的结果:', data.result);
        document.querySelector("#result").innerText = data.result.toString();
    })
    .catch(error => {
        console.error('提交失败', error);
    });
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
|
@ -0,0 +1,39 @@
|
|||
from fastapi import FastAPI
|
||||
from fastapi.responses import HTMLResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
|
||||
from mnist_net import MnistNet
|
||||
from mnist_test import device
|
||||
from mnist_test import predict
|
||||
|
||||
import base64
|
||||
import io
|
||||
|
||||
class Request(BaseModel):
    """JSON body accepted by ``POST /api/predict``."""

    # Base64-encoded image; the handler strips a leading
    # ``data:image/png;base64,`` prefix before decoding.
    image: str
|
||||
|
||||
app = FastAPI()

# Load the trained weights once at import time so every request reuses the
# same model instance, placed on the selected device and in eval mode.
model = MnistNet()
model.load_state_dict(torch.load('mnist_cnn.pt', map_location=device))
model.to(device).eval()
|
||||
|
||||
@app.get('/')
async def root():
    """Serve the drawing page stored in ``./ui/index.html``."""
    # The context manager closes the file on every request; the handle was
    # previously left open each time the page was served.
    with open("./ui/index.html", "r", encoding="utf-8") as html_file:
        content = html_file.read()
    return HTMLResponse(content=content)
|
||||
|
||||
@app.post("/api/predict")
async def root(req: Request):
    """Decode the submitted image and return the predicted digit.

    Args:
        req: request body whose ``image`` field holds a base64-encoded image,
            either raw or wrapped in a ``data:<type>;base64,`` URL.

    Returns:
        dict: ``{'result': <predicted class index>}``.

    Raises:
        HTTPException: status 400 when the payload is not valid base64 or
            cannot be opened as an image.
    """
    # NOTE(review): this handler reuses the name ``root`` of the GET handler;
    # both routes stay registered because the decorators run at definition
    # time, but a distinct function name would be clearer.
    from fastapi import HTTPException  # local import keeps this block self-contained

    payload = req.image
    if payload.startswith('data:'):
        # Data URL: the base64 body is everything after the first comma,
        # whatever media type the prefix declares.
        payload = payload.partition(',')[2]

    try:
        raw = base64.b64decode(payload)
        # Force a single channel, which is what the model input expects.
        img = Image.open(io.BytesIO(raw)).convert('L')
    except (ValueError, OSError) as err:
        # Malformed client input is a client error, not a server failure.
        raise HTTPException(
            status_code=400, detail="image must be a base64-encoded image"
        ) from err

    prediction = predict(device, model, img)
    return {'result': prediction}
|
Loading…
Reference in New Issue