feat: export onnx
This commit is contained in:
parent
2ef0dd8a99
commit
5190b44440
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,8 +1,9 @@
|
|||||||
# project files
|
# project files
|
||||||
dataset/
|
dataset/
|
||||||
|
*.onnx
|
||||||
*.pth
|
*.pth
|
||||||
|
|
||||||
|
|
||||||
# python files
|
# python files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
|
||||||
|
@ -60,3 +60,9 @@ python test.py
|
|||||||
```shell
|
```shell
|
||||||
python predict.py
|
python predict.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
7. 输出 onnx 模型
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python export_onnx.py
|
||||||
|
```
|
||||||
|
18
export_onnx.py
Normal file
18
export_onnx.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import torch
|
||||||
|
import torch.onnx
|
||||||
|
import torchvision.models as models
|
||||||
|
|
||||||
|
import cnn_net
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Load the pre-trained model
|
||||||
|
model = cnn_net.ConvNet()
|
||||||
|
model.load_state_dict(torch.load('model.pth', weights_only=True))
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Export the model to ONNX format
|
||||||
|
dummy_input = torch.randn(1, 1, 60, 160)
|
||||||
|
torch.onnx.export(model, dummy_input, "model.onnx", verbose=True)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in New Issue
Block a user