Compare commits
3 Commits
05a8338c72
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
5190b44440
|
|||
|
2ef0dd8a99
|
|||
|
665bc28b36
|
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,8 +1,9 @@
|
|||||||
# project files
|
# project files
|
||||||
dataset/
|
dataset/
|
||||||
|
*.onnx
|
||||||
*.pth
|
*.pth
|
||||||
|
|
||||||
|
|
||||||
# python files
|
# python files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
|
||||||
|
|||||||
14
README.md
14
README.md
@@ -31,6 +31,10 @@ pip install -r requirements.txt
|
|||||||
|
|
||||||
3. 根据提示生成数据集,生成3次数据集,分别用于训练,用于测试,用于验证。
|
3. 根据提示生成数据集,生成3次数据集,分别用于训练,用于测试,用于验证。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python captcha_gen.py
|
||||||
|
```
|
||||||
|
|
||||||
建议的数据集长度如下:
|
建议的数据集长度如下:
|
||||||
|
|
||||||
| 数据集 | 长度 |
|
| 数据集 | 长度 |
|
||||||
@@ -39,10 +43,6 @@ pip install -r requirements.txt
|
|||||||
| Test | 1000 |
|
| Test | 1000 |
|
||||||
| Predict | 30 |
|
| Predict | 30 |
|
||||||
|
|
||||||
```shell
|
|
||||||
python captcha_gen.py
|
|
||||||
```
|
|
||||||
|
|
||||||
4. 训练模型
|
4. 训练模型
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
@@ -60,3 +60,9 @@ python test.py
|
|||||||
```shell
|
```shell
|
||||||
python predict.py
|
python predict.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
7. 输出 onnx 模型
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python export_onnx.py
|
||||||
|
```
|
||||||
|
|||||||
18
export_onnx.py
Normal file
18
export_onnx.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
import torch
|
||||||
|
import torch.onnx
|
||||||
|
import torchvision.models as models
|
||||||
|
|
||||||
|
import cnn_net
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Load the pre-trained model
|
||||||
|
model = cnn_net.ConvNet()
|
||||||
|
model.load_state_dict(torch.load('model.pth', weights_only=True))
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Export the model to ONNX format
|
||||||
|
dummy_input = torch.randn(1, 1, 60, 160)
|
||||||
|
torch.onnx.export(model, dummy_input, "model.onnx", verbose=True)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@@ -29,8 +29,9 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# random pickup some test images
|
# random pickup some test images
|
||||||
pickup_count = 30
|
pickup_count = int(input("输入验证集的图片长度(默认30):") or 30)
|
||||||
pickup_rect = [5, 6]
|
input_rect = input("输入图片的行数和列数(默认5x6):")
|
||||||
|
pickup_rect = [int(i) for i in input_rect.split("x")] if input_rect else [5, 6]
|
||||||
files = os.listdir(captcha_settings.PREDICT_DATASET_PATH)
|
files = os.listdir(captcha_settings.PREDICT_DATASET_PATH)
|
||||||
images_picked = random.sample(files, pickup_count)
|
images_picked = random.sample(files, pickup_count)
|
||||||
|
|
||||||
@@ -38,7 +39,7 @@ def main():
|
|||||||
fig, axes = plt.subplots(nrows=pickup_rect[0], ncols=pickup_rect[1], figsize=(10, 8))
|
fig, axes = plt.subplots(nrows=pickup_rect[0], ncols=pickup_rect[1], figsize=(10, 8))
|
||||||
for i, image_name in enumerate(images_picked):
|
for i, image_name in enumerate(images_picked):
|
||||||
real_text = image_name.split(".")[0].split("_")[-1]
|
real_text = image_name.split(".")[0].split("_")[-1]
|
||||||
file_path = os.path.join(captcha_settings.TEST_DATASET_PATH, image_name)
|
file_path = os.path.join(captcha_settings.PREDICT_DATASET_PATH, image_name)
|
||||||
pred_text = predict(model, file_path)
|
pred_text = predict(model, file_path)
|
||||||
correct = real_text == pred_text
|
correct = real_text == pred_text
|
||||||
axes[i//pickup_rect[1], i%pickup_rect[1]].imshow(plt.imread(file_path))
|
axes[i//pickup_rect[1], i%pickup_rect[1]].imshow(plt.imread(file_path))
|
||||||
|
|||||||
Reference in New Issue
Block a user