18 lines
438 B
Python
18 lines
438 B
Python
import torch
|
|
import torch.onnx
|
|
import torchvision.models as models
|
|
|
|
import cnn_net
|
|
|
|
def main(weights_path="model.pth", onnx_path="model.onnx",
         input_shape=(1, 1, 60, 160), verbose=True):
    """Export the trained ``cnn_net.ConvNet`` to an ONNX file.

    Builds a fresh ``ConvNet``, loads the state dict stored at
    *weights_path* into it, switches it to evaluation mode, and traces it
    with a random dummy input to write an ONNX graph to *onnx_path*.

    Args:
        weights_path: Path of the saved state dict (default ``"model.pth"``).
        onnx_path: Destination of the exported ONNX file
            (default ``"model.onnx"``).
        input_shape: Shape of the random dummy tensor used for tracing.
            The default ``(1, 1, 60, 160)`` is presumably
            (batch, channels, height, width) for ConvNet -- TODO confirm
            against ``cnn_net``.
        verbose: Forwarded to ``torch.onnx.export``; when true the exporter
            prints a description of the exported graph.
    """
    model = cnn_net.ConvNet()
    # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
    # CPU-only machine. The dummy input below is created on the CPU, so the
    # model has to run on the CPU for the export anyway.
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    # Put modules such as dropout/batch-norm (if ConvNet has any) into
    # inference mode so the traced graph reflects inference behavior.
    model.eval()

    # The graph is traced from this input. Without dynamic_axes, its shape
    # becomes the fixed input shape recorded in the ONNX file.
    dummy_input = torch.randn(*input_shape)
    torch.onnx.export(model, dummy_input, onnx_path, verbose=verbose)
|
|
|
|
# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()