40 lines
882 B
Python
40 lines
882 B
Python
from fastapi import FastAPI
|
|
from fastapi.responses import HTMLResponse
|
|
from pydantic import BaseModel
|
|
|
|
from PIL import Image
|
|
|
|
import torch
|
|
|
|
from mnist_net import MnistNet
|
|
from mnist_test import device
|
|
from mnist_test import predict
|
|
|
|
import base64
|
|
import io
|
|
|
|
class Request(BaseModel):
    """JSON request body for the digit-prediction endpoint.

    NOTE(review): this name shadows ``fastapi.Request`` /
    ``starlette.requests.Request``; consider renaming if the raw request
    object is ever needed in this module.
    """

    # Base64-encoded image bytes, optionally preceded by a
    # ``data:image/png;base64,`` data-URL header.
    image: str
|
|
|
|
# The FastAPI application; routes are registered on it via decorators below.
app = FastAPI()

# Build the network once at import time and load the trained weights.
# NOTE(review): 'mnist_cnn.pt' is resolved relative to the process working
# directory, not this file's location.
model = MnistNet()
# weights_only=True restricts torch.load's unpickler to tensors and plain
# containers, which is all a state dict needs, so a tampered checkpoint
# cannot execute arbitrary code during loading.
model.load_state_dict(
    torch.load('mnist_cnn.pt', map_location=device, weights_only=True)
)
model.to(device)
# Switch to inference behavior (affects layers such as dropout/batch-norm,
# if MnistNet contains any).
model.eval()
|
|
|
|
@app.get('/')
async def root():
    """Serve the single-page UI.

    Reads ``./ui/index.html`` (relative to the process working directory)
    on every request and returns its contents as an HTML response.
    """
    # NOTE(review): the POST /api/predict handler in this module is also
    # named ``root``. FastAPI registers each route at decoration time, so
    # both work, but the module-level name ``root`` ends up bound to the
    # later definition.
    # The context manager closes the file handle deterministically instead
    # of leaving it open until garbage collection.
    with open("./ui/index.html", "r", encoding="utf-8") as html_file:
        content = html_file.read()
    return HTMLResponse(content=content)
|
|
|
|
@app.post("/api/predict")
async def root(req: Request):
    """Classify the digit in a base64-encoded image.

    ``req.image`` may be raw base64 or a data URL such as
    ``data:image/png;base64,<payload>``. Any ``data:...,`` header is
    stripped before decoding. The decoded image is converted to PIL mode
    ``'L'`` (8-bit grayscale) and passed to ``predict`` together with the
    module-level ``device`` and ``model``.

    Returns:
        ``{'result': prediction}``, where ``prediction`` is whatever
        ``predict`` returns.

    Raises:
        HTTPException: status 400 if the payload is not valid base64 or
            does not decode to an image PIL can open.
    """
    # Function-scope import keeps this change self-contained; fastapi is
    # already a dependency of this module.
    from fastapi import HTTPException

    payload = req.image
    # Strip a data-URL header of any media type (not only PNG), and only
    # when it is genuinely a prefix: drop everything up to the first comma.
    if payload.startswith('data:'):
        payload = payload.partition(',')[2]

    try:
        raw = base64.b64decode(payload)
    except ValueError as err:  # binascii.Error is a ValueError subclass
        raise HTTPException(
            status_code=400, detail='image is not valid base64'
        ) from err

    try:
        img = Image.open(io.BytesIO(raw)).convert('L')
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise HTTPException(
            status_code=400, detail='image could not be decoded'
        ) from err

    # NOTE(review): inference runs on the event loop thread; if predict is
    # slow, consider offloading it (e.g. a plain ``def`` endpoint).
    prediction = predict(device, model, img)
    return {'result': prediction}
|