captcha/export_onnx.py

18 lines
438 B
Python
Raw Normal View History

2024-09-30 11:42:12 +08:00
import torch
import torch.onnx
import torchvision.models as models
import cnn_net
def main(model_path: str = "model.pth", onnx_path: str = "model.onnx") -> None:
    """Load trained ConvNet weights and export the network to ONNX.

    Args:
        model_path: Path to the saved ``state_dict`` checkpoint, resolved
            relative to the current working directory. Defaults to
            ``"model.pth"``.
        onnx_path: Destination path for the exported ONNX graph, resolved
            relative to the current working directory. Defaults to
            ``"model.onnx"``.
    """
    model = cnn_net.ConvNet()
    # map_location="cpu": the checkpoint may have been saved from a GPU run;
    # without remapping, loading it on a CPU-only machine raises. Tracing
    # below happens on CPU anyway (the dummy input is created without a
    # device, i.e. on CPU).
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    # Switch to inference behavior before tracing (affects layers such as
    # dropout / batch-norm, if ConvNet contains any).
    model.eval()

    # Example input used to trace the graph: batch 1, 1 channel, 60x160.
    # NOTE(review): presumably matches the grayscale 60x160 captcha images
    # ConvNet was trained on -- confirm against cnn_net.ConvNet.
    dummy_input = torch.randn(1, 1, 60, 160)
    torch.onnx.export(model, dummy_input, onnx_path, verbose=True)
if __name__ == "__main__":
    # Script entry point: run the export only when executed directly,
    # not when this module is imported.
    main()