feat: export onnx
This commit is contained in:
parent
2ef0dd8a99
commit
5190b44440
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,8 +1,9 @@
|
||||
# project files
|
||||
dataset/
|
||||
|
||||
*.onnx
|
||||
*.pth
|
||||
|
||||
|
||||
# python files
|
||||
__pycache__/
|
||||
|
||||
|
@ -60,3 +60,9 @@ python test.py
|
||||
```shell
|
||||
python predict.py
|
||||
```
|
||||
|
||||
7. 输出 onnx 模型
|
||||
|
||||
```shell
|
||||
python export_onnx.py
|
||||
```
|
||||
|
18
export_onnx.py
Normal file
18
export_onnx.py
Normal file
@ -0,0 +1,18 @@
|
||||
import torch
|
||||
import torch.onnx
|
||||
import torchvision.models as models
|
||||
|
||||
import cnn_net
|
||||
|
||||
def main():
|
||||
# Load the pre-trained model
|
||||
model = cnn_net.ConvNet()
|
||||
model.load_state_dict(torch.load('model.pth', weights_only=True))
|
||||
model.eval()
|
||||
|
||||
# Export the model to ONNX format
|
||||
dummy_input = torch.randn(1, 1, 60, 160)
|
||||
torch.onnx.export(model, dummy_input, "model.onnx", verbose=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue
Block a user