first commit
This commit is contained in:
commit
58e344ef13
|
@ -0,0 +1,7 @@
|
||||||
|
__pycache__
|
||||||
|
|
||||||
|
data/*
|
||||||
|
|
||||||
|
!data/.gitkeep
|
||||||
|
|
||||||
|
*.pt
|
|
@ -0,0 +1,43 @@
|
||||||
|
# MNIST 数据集
|
||||||
|
|
||||||
|
使用 PyTorch 训练
|
||||||
|
|
||||||
|
## 依赖安装
|
||||||
|
|
||||||
|
首先根据 [PyTorch 官网](https://pytorch.org/get-started/locally/) 安装 PyTorch,以便启用 GPU 加速
|
||||||
|
|
||||||
|
然后安装其他依赖
|
||||||
|
|
||||||
|
```shell
|
||||||
|
pip install fastapi uvicorn pillow tqdm
|
||||||
|
```
|
||||||
|
|
||||||
|
## 训练模型
|
||||||
|
|
||||||
|
首先你当然要先训练模型啦
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python mnist_train.py --save-model
|
||||||
|
```
|
||||||
|
|
||||||
|
## 使用方法
|
||||||
|
|
||||||
|
训练完模型后,本仓库包含两种使用方法,即本地文件和 Web UI
|
||||||
|
|
||||||
|
### 本地文件
|
||||||
|
|
||||||
|
将图片文件放到 input 文件夹内,必须是 28*28 的灰度png格式,然后执行
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python mnist.test.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Web UI
|
||||||
|
|
||||||
|
执行以下命令启动 Web UI
|
||||||
|
|
||||||
|
```shell
|
||||||
|
uvicorn webui:app
|
||||||
|
```
|
||||||
|
|
||||||
|
然后打开 <http://127.0.0.1:8000> 即可看到网页
|
|
@ -0,0 +1 @@
|
||||||
|
# ignore data but keep folder
|
Binary file not shown.
After Width: | Height: | Size: 2.2 KiB |
Binary file not shown.
After Width: | Height: | Size: 2.3 KiB |
Binary file not shown.
After Width: | Height: | Size: 2.2 KiB |
Binary file not shown.
After Width: | Height: | Size: 2.2 KiB |
|
@ -0,0 +1,28 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class MnistNet(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MnistNet, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(1, 32, 3, 1)
|
||||||
|
self.conv2 = nn.Conv2d(32, 64, 3, 1)
|
||||||
|
self.dropout1 = nn.Dropout(0.25)
|
||||||
|
self.dropout2 = nn.Dropout(0.5)
|
||||||
|
self.fc1 = nn.Linear(9216, 128)
|
||||||
|
self.fc2 = nn.Linear(128, 10)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = F.relu(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = F.relu(x)
|
||||||
|
x = F.max_pool2d(x, 2)
|
||||||
|
x = self.dropout1(x)
|
||||||
|
x = torch.flatten(x, 1)
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = F.relu(x)
|
||||||
|
x = self.dropout2(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
output = F.log_softmax(x, dim=1)
|
||||||
|
return output
|
|
@ -0,0 +1,56 @@
|
||||||
|
import torch
|
||||||
|
from mnist_net import MnistNet
|
||||||
|
from torchvision import transforms
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# find device
|
||||||
|
device = None
|
||||||
|
|
||||||
|
use_cuda = torch.cuda.is_available()
|
||||||
|
use_mps = torch.backends.mps.is_available()
|
||||||
|
|
||||||
|
if use_cuda:
|
||||||
|
device = torch.device("cuda")
|
||||||
|
elif use_mps:
|
||||||
|
device = torch.device("mps")
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
# load image
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
|
||||||
|
)
|
||||||
|
|
||||||
|
def predict(device, model, img):
|
||||||
|
# transform image
|
||||||
|
img = transform(img)
|
||||||
|
img = img.to(device)
|
||||||
|
img = img.unsqueeze(0)
|
||||||
|
|
||||||
|
# predict
|
||||||
|
output = model(img)
|
||||||
|
pred = output.argmax(dim=1, keepdim=True)
|
||||||
|
return pred.item()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# load model
|
||||||
|
model = MnistNet()
|
||||||
|
model.load_state_dict(torch.load('mnist_cnn.pt', map_location=device))
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
# set model to eval mode
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# load image
|
||||||
|
folder = './input'
|
||||||
|
predictions = []
|
||||||
|
|
||||||
|
for image_name in sorted(os.listdir(folder)):
|
||||||
|
img = Image.open(os.path.join(folder, image_name))
|
||||||
|
predictions.append(predict(device, model, img))
|
||||||
|
|
||||||
|
print(predictions)
|
|
@ -0,0 +1,171 @@
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from mnist_net import MnistNet
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.optim as optim
|
||||||
|
from torchvision import datasets, transforms
|
||||||
|
from torch.optim.lr_scheduler import StepLR
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def train(args, model, device, train_loader, optimizer, epoch):
|
||||||
|
model.train()
|
||||||
|
with tqdm(train_loader, desc=f"Train Epoch {epoch}", unit="batch") as t:
|
||||||
|
for batch_idx, (data, target) in enumerate(t):
|
||||||
|
data, target = data.to(device), target.to(device)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
output = model(data)
|
||||||
|
loss = F.nll_loss(output, target)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
if batch_idx % args.log_interval == 0:
|
||||||
|
t.set_postfix(loss=round(loss.item(), 3))
|
||||||
|
|
||||||
|
if args.dry_run:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def test(model, device, test_loader):
|
||||||
|
model.eval()
|
||||||
|
test_loss = 0
|
||||||
|
correct = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for data, target in test_loader:
|
||||||
|
data, target = data.to(device), target.to(device)
|
||||||
|
output = model(data)
|
||||||
|
test_loss += F.nll_loss(
|
||||||
|
output, target, reduction="sum"
|
||||||
|
).item() # sum up batch loss
|
||||||
|
pred = output.argmax(
|
||||||
|
dim=1, keepdim=True
|
||||||
|
) # get the index of the max log-probability
|
||||||
|
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||||
|
|
||||||
|
test_loss /= len(test_loader.dataset)
|
||||||
|
|
||||||
|
print(
|
||||||
|
"Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
|
||||||
|
test_loss,
|
||||||
|
correct,
|
||||||
|
len(test_loader.dataset),
|
||||||
|
100.0 * correct / len(test_loader.dataset),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Training settings
|
||||||
|
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
metavar="N",
|
||||||
|
help="input batch size for training (default: 64)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-batch-size",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
metavar="N",
|
||||||
|
help="input batch size for testing (default: 1000)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--epochs",
|
||||||
|
type=int,
|
||||||
|
default=14,
|
||||||
|
metavar="N",
|
||||||
|
help="number of epochs to train (default: 14)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lr",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
metavar="LR",
|
||||||
|
help="learning rate (default: 1.0)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gamma",
|
||||||
|
type=float,
|
||||||
|
default=0.7,
|
||||||
|
metavar="M",
|
||||||
|
help="Learning rate step gamma (default: 0.7)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-cuda", action="store_true", default=False, help="disables CUDA training"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-mps",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="disables macOS GPU training",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dry-run",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="quickly check a single pass",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-interval",
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
metavar="N",
|
||||||
|
help="how many batches to wait before logging training status",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-model",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="For Saving the current Model",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||||
|
use_mps = not args.no_mps and torch.backends.mps.is_available()
|
||||||
|
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
|
||||||
|
if use_cuda:
|
||||||
|
device = torch.device("cuda")
|
||||||
|
elif use_mps:
|
||||||
|
device = torch.device("mps")
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
print(f"Using device: {device}\n")
|
||||||
|
|
||||||
|
train_kwargs = {"batch_size": args.batch_size}
|
||||||
|
test_kwargs = {"batch_size": args.test_batch_size}
|
||||||
|
if use_cuda:
|
||||||
|
cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
|
||||||
|
train_kwargs.update(cuda_kwargs)
|
||||||
|
test_kwargs.update(cuda_kwargs)
|
||||||
|
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
|
||||||
|
)
|
||||||
|
dataset1 = datasets.MNIST("./data", train=True, download=True, transform=transform)
|
||||||
|
dataset2 = datasets.MNIST("./data", train=False, transform=transform)
|
||||||
|
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
|
||||||
|
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
|
||||||
|
|
||||||
|
model = MnistNet().to(device)
|
||||||
|
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
|
||||||
|
|
||||||
|
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
|
||||||
|
for epoch in range(1, args.epochs + 1):
|
||||||
|
train(args, model, device, train_loader, optimizer, epoch)
|
||||||
|
test(model, device, test_loader)
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
if args.save_model:
|
||||||
|
torch.save(model.state_dict(), "mnist_cnn.pt")
|
||||||
|
print("Model saved to mnist_cnn.pt")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,106 @@
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>绘图应用</title>
|
||||||
|
<style>
|
||||||
|
#canvas, #canvas2 {
|
||||||
|
background-color: black;
|
||||||
|
cursor: crosshair;
|
||||||
|
}
|
||||||
|
|
||||||
|
#result {
|
||||||
|
margin-top: 10px;
|
||||||
|
font-weight: bold;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<canvas id="canvas" width="400" height="400"></canvas>
|
||||||
|
<button onclick="submitDrawing()">提交</button>
|
||||||
|
<button onclick="clearCanvas()">清除</button>
|
||||||
|
<div>结果: <span id="result"></span></div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
const canvas = document.getElementById("canvas");
|
||||||
|
const context = canvas.getContext("2d");
|
||||||
|
|
||||||
|
let isDrawing = false;
|
||||||
|
|
||||||
|
canvas.addEventListener("mousedown", startDrawing);
|
||||||
|
canvas.addEventListener("mousemove", draw);
|
||||||
|
canvas.addEventListener("mouseup", stopDrawing);
|
||||||
|
canvas.addEventListener("mouseout", stopDrawing);
|
||||||
|
|
||||||
|
function startDrawing(e) {
|
||||||
|
isDrawing = true;
|
||||||
|
draw(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
function draw(e) {
|
||||||
|
if (!isDrawing) return;
|
||||||
|
|
||||||
|
const rect = canvas.getBoundingClientRect();
|
||||||
|
const mouseX = e.clientX - rect.left;
|
||||||
|
const mouseY = e.clientY - rect.top;
|
||||||
|
|
||||||
|
context.lineWidth = 50;
|
||||||
|
context.lineCap = "round";
|
||||||
|
context.strokeStyle = "white";
|
||||||
|
|
||||||
|
context.lineTo(mouseX, mouseY);
|
||||||
|
context.stroke();
|
||||||
|
context.beginPath();
|
||||||
|
context.moveTo(mouseX, mouseY);
|
||||||
|
}
|
||||||
|
|
||||||
|
function stopDrawing() {
|
||||||
|
isDrawing = false;
|
||||||
|
context.beginPath();
|
||||||
|
}
|
||||||
|
|
||||||
|
function submitDrawing() {
|
||||||
|
const imageDataUrl = canvas.toDataURL("image/png");
|
||||||
|
|
||||||
|
const image = new Image();
|
||||||
|
image.src = imageDataUrl;
|
||||||
|
|
||||||
|
image.onload = function() {
|
||||||
|
const tempCanvas = document.createElement("canvas");
|
||||||
|
const tempContext = tempCanvas.getContext("2d");
|
||||||
|
tempCanvas.width = 28;
|
||||||
|
tempCanvas.height = 28;
|
||||||
|
tempContext.drawImage(image, 0, 0, 28, 28);
|
||||||
|
|
||||||
|
const croppedImageDataUrl = tempCanvas.toDataURL("image/png");
|
||||||
|
|
||||||
|
submitToAPI(croppedImageDataUrl);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
function clearCanvas() {
|
||||||
|
// 清除Canvas上的绘图
|
||||||
|
context.clearRect(0, 0, canvas.width, canvas.height);
|
||||||
|
}
|
||||||
|
|
||||||
|
function submitToAPI(imageDataUrl) {
|
||||||
|
fetch('/api/predict', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({ image: imageDataUrl }),
|
||||||
|
})
|
||||||
|
.then(response => response.json())
|
||||||
|
.then(data => {
|
||||||
|
console.log('API返回的结果:', data.result);
|
||||||
|
document.querySelector("#result").innerText = data.result.toString();
|
||||||
|
})
|
||||||
|
.catch(error => {
|
||||||
|
console.error('提交失败', error);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
|
@ -0,0 +1,39 @@
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mnist_net import MnistNet
|
||||||
|
from mnist_test import device
|
||||||
|
from mnist_test import predict
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
|
||||||
|
class Request(BaseModel):
|
||||||
|
image: str
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
model = MnistNet()
|
||||||
|
model.load_state_dict(torch.load('mnist_cnn.pt', map_location=device))
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
@app.get('/')
|
||||||
|
async def root():
|
||||||
|
html_file = open("./ui/index.html", "r", encoding="utf-8")
|
||||||
|
content = html_file.read()
|
||||||
|
return HTMLResponse(content=content)
|
||||||
|
|
||||||
|
@app.post("/api/predict")
|
||||||
|
async def root(req: Request):
|
||||||
|
img = base64.b64decode(req.image.replace('data:image/png;base64,', ''))
|
||||||
|
img = Image.open(io.BytesIO(img)).convert('L')
|
||||||
|
prediction = predict(device, model, img)
|
||||||
|
return {
|
||||||
|
'result': prediction
|
||||||
|
}
|
Loading…
Reference in New Issue