"""FastAPI service that serves a web UI and a JSON prediction endpoint.

Routes:
    GET  /             -> the HTML page stored at ``./ui/index.html``
    POST /api/predict  -> ``{"result": <prediction>}`` for a base64 image
"""

import base64
import io

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from PIL import Image
from pydantic import BaseModel

from mnist_net import MnistNet
from mnist_test import device, predict


class Request(BaseModel):
    """JSON body accepted by ``POST /api/predict``."""

    # Base64-encoded image, either bare or wrapped in a ``data:`` URL
    # (e.g. ``data:image/png;base64,<payload>``).
    image: str


app = FastAPI()

# The weights are loaded once at import time so that every request reuses
# the same model instance instead of reloading 'mnist_cnn.pt' per call.
model = MnistNet()
model.load_state_dict(torch.load('mnist_cnn.pt', map_location=device))
model.to(device)
model.eval()


@app.get('/')
async def root():
    """Return the contents of ``./ui/index.html`` as an HTML response.

    The path is relative, so it is resolved against the process's current
    working directory. The file is re-read on every request.
    """
    # ``with`` guarantees the handle is closed; the previous bare ``open``
    # leaked one file descriptor per request.
    with open("./ui/index.html", "r", encoding="utf-8") as html_file:
        content = html_file.read()
    return HTMLResponse(content=content)


def _decode_image(payload: str) -> Image.Image:
    """Decode a base64 (optionally ``data:`` URL) string into a mode-'L' image.

    Args:
        payload: Base64 text, with or without a leading ``data:...,`` header.

    Returns:
        The decoded picture converted to Pillow mode ``'L'``.

    Raises:
        ValueError: If the payload is not valid base64 (``binascii.Error``
            is a ``ValueError`` subclass).
        OSError: If Pillow cannot identify or read the decoded bytes.
    """
    payload = payload.strip()
    if payload.startswith("data:"):
        # A data URL is "data:<mediatype>[;base64],<data>"; keep only <data>.
        # This covers image/png as before and any other media type as well.
        _, _, payload = payload.partition(",")
    raw = base64.b64decode(payload)
    return Image.open(io.BytesIO(raw)).convert('L')


@app.post("/api/predict")
async def predict_digit(req: Request):
    """Run the loaded model on the submitted image.

    Args:
        req: Request body whose ``image`` field holds the base64 picture.

    Returns:
        A dict ``{'result': prediction}`` where ``prediction`` is whatever
        ``mnist_test.predict`` returns for the decoded image.

    Raises:
        HTTPException: 400 when ``image`` cannot be decoded into a picture.
    """
    try:
        img = _decode_image(req.image)
    except (ValueError, OSError) as err:
        # Bad client input is a 400, not an unhandled 500.
        raise HTTPException(
            status_code=400,
            detail="'image' must be a base64-encoded image",
        ) from err
    prediction = predict(device, model, img)
    return {'result': prediction}