feat: export onnx

This commit is contained in:
TaurusXin 2024-09-30 11:42:12 +08:00
parent 2ef0dd8a99
commit 5190b44440
Signed by: taurusxin
GPG Key ID: C334DCA04AC2D2CC
3 changed files with 26 additions and 1 deletions

3
.gitignore vendored
View File

@ -1,8 +1,9 @@
# project files # project files
dataset/ dataset/
*.onnx
*.pth *.pth
# python files # python files
__pycache__/ __pycache__/

View File

@ -60,3 +60,9 @@ python test.py
```shell ```shell
python predict.py python predict.py
``` ```
7. 输出 onnx 模型
```shell
python export_onnx.py
```

18
export_onnx.py Normal file
View File

@ -0,0 +1,18 @@
import torch
import torch.onnx
import torchvision.models as models
import cnn_net
def main():
# Load the pre-trained model
model = cnn_net.ConvNet()
model.load_state_dict(torch.load('model.pth', weights_only=True))
model.eval()
# Export the model to ONNX format
dummy_input = torch.randn(1, 1, 60, 160)
torch.onnx.export(model, dummy_input, "model.onnx", verbose=True)
if __name__ == '__main__':
main()