diff --git a/.gitignore b/.gitignore index 90908ea..eee2541 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,9 @@ # project files dataset/ - +*.onnx *.pth + # python files __pycache__/ diff --git a/README.md b/README.md index 1ddd2ba..11efa46 100644 --- a/README.md +++ b/README.md @@ -60,3 +60,9 @@ python test.py ```shell python predict.py ``` + +7. 输出 onnx 模型 + +```shell +python export_onnx.py +``` diff --git a/export_onnx.py b/export_onnx.py new file mode 100644 index 0000000..a492199 --- /dev/null +++ b/export_onnx.py @@ -0,0 +1,18 @@ +import torch +import torch.onnx +import torchvision.models as models + +import cnn_net + +def main(): + # Load the pre-trained model + model = cnn_net.ConvNet() + model.load_state_dict(torch.load('model.pth', weights_only=True)) + model.eval() + + # Export the model to ONNX format + dummy_input = torch.randn(1, 1, 60, 160) + torch.onnx.export(model, dummy_input, "model.onnx", verbose=True) + +if __name__ == '__main__': + main() \ No newline at end of file