mnist-learning/webui.py

40 lines
882 B
Python

import base64
import io

import torch
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi.responses import HTMLResponse
from PIL import Image
from pydantic import BaseModel

from mnist_net import MnistNet
from mnist_test import device
from mnist_test import predict
class Request(BaseModel):
    """JSON body accepted by ``POST /api/predict``.

    Attributes:
        image: Base64-encoded image data; may carry a
            ``data:image/png;base64,`` data-URL prefix.
    """

    image: str
# FastAPI application; the route decorators further down register on it.
app = FastAPI()

# Build the network once at import time and load the trained weights.
# ``weights_only=True`` limits torch.load's unpickler to tensors and plain
# containers, so a tampered checkpoint cannot run arbitrary code -- a
# state_dict needs nothing more. ``map_location=device`` places the loaded
# tensors on ``device`` (imported from mnist_test) regardless of where they
# were saved.
# NOTE(review): 'mnist_cnn.pt' is resolved against the process working
# directory, not this file's directory.
model = MnistNet()
model.load_state_dict(
    torch.load('mnist_cnn.pt', map_location=device, weights_only=True)
)
model.to(device)
# Switch to inference mode (affects dropout / batch-norm layers, if MnistNet
# has any).
model.eval()
@app.get('/')
def index():
    """Serve the UI page.

    Reads ``./ui/index.html`` (resolved against the process working
    directory) on every request and returns it as an HTML response.

    Declared as a plain ``def`` because the file read is blocking I/O:
    FastAPI runs sync path operations in a threadpool, so the read does not
    stall the event loop. Named ``index`` so it no longer shares the
    module-level name ``root`` with the POST handler.
    """
    # The context manager closes the handle even if read() raises; the
    # previous code never closed it.
    with open("./ui/index.html", "r", encoding="utf-8") as html_file:
        content = html_file.read()
    return HTMLResponse(content=content)
def _strip_data_url_prefix(image: str) -> str:
    """Return the base64 payload of *image*, dropping any ``data:...,`` header.

    Accepts a bare base64 string or a data URL of any media type (e.g.
    ``data:image/png;base64,...`` or ``data:image/jpeg;base64,...``); Pillow
    identifies the actual format from the decoded bytes.
    """
    if image.startswith('data:'):
        _, sep, payload = image.partition(',')
        if sep:
            return payload
    return image


@app.post("/api/predict")
async def root(req: Request):
    """Classify an image sent as base64 (optionally as a data URL).

    The image is decoded, converted to single-channel grayscale ('L' mode)
    and passed to ``predict(device, model, img)``.

    Returns:
        ``{'result': prediction}`` with whatever ``predict`` returns --
        presumably the predicted digit. NOTE(review): confirm it is
        JSON-serializable (e.g. a Python int, not a tensor).

    Raises:
        HTTPException: 400 if ``req.image`` is not valid base64 or does not
            decode to an image Pillow can open.
    """
    try:
        raw = base64.b64decode(_strip_data_url_prefix(req.image))
    except ValueError as err:  # binascii.Error is a ValueError subclass
        raise HTTPException(
            status_code=400, detail='image is not valid base64'
        ) from err
    try:
        img = Image.open(io.BytesIO(raw)).convert('L')
    except OSError as err:  # PIL.UnidentifiedImageError is an OSError subclass
        raise HTTPException(
            status_code=400, detail='image could not be decoded'
        ) from err
    # NOTE(review): inference runs on the event-loop thread and blocks other
    # requests while it executes; moving it to a threadpool would also mean
    # concurrent use of the shared ``model``.
    prediction = predict(device, model, img)
    return {
        'result': prediction
    }