From 5190b4444081224f8fbc2ef39dd004c1a97fe733 Mon Sep 17 00:00:00 2001 From: TaurusXin Date: Mon, 30 Sep 2024 11:42:12 +0800 Subject: [PATCH] feat: export onnx --- .gitignore | 3 ++- README.md | 6 ++++++ export_onnx.py | 18 ++++++++++++++++++ 3 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 export_onnx.py diff --git a/.gitignore b/.gitignore index 90908ea..eee2541 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,9 @@ # project files dataset/ - +*.onnx *.pth + # python files __pycache__/ diff --git a/README.md b/README.md index 1ddd2ba..11efa46 100644 --- a/README.md +++ b/README.md @@ -60,3 +60,9 @@ python test.py ```shell python predict.py ``` + +7. 输出 onnx 模型 + +```shell +python export_onnx.py +``` diff --git a/export_onnx.py b/export_onnx.py new file mode 100644 index 0000000..a492199 --- /dev/null +++ b/export_onnx.py @@ -0,0 +1,18 @@ +import torch +import torch.onnx +import torchvision.models as models + +import cnn_net + +def main(): + # Load the pre-trained model + model = cnn_net.ConvNet() + model.load_state_dict(torch.load('model.pth', weights_only=True)) + model.eval() + + # Export the model to ONNX format + dummy_input = torch.randn(1, 1, 60, 160) + torch.onnx.export(model, dummy_input, "model.onnx", verbose=True) + +if __name__ == '__main__': + main() \ No newline at end of file