captcha/one_hot_encoding.py

34 lines
1.4 KiB
Python
Raw Permalink Normal View History

2024-09-30 11:05:35 +08:00
import numpy as np
import captcha_settings
import torch
def encode(label, char_set=None, max_len=None):
    """Convert a captcha label string into a flat one-hot vector.

    A ``(max_len, len(char_set))`` matrix of zeros is built. For the i-th
    character of ``label``, the column at that character's position in
    ``char_set`` is set to 1.0. The matrix is returned flattened row by row,
    so its length is ``max_len * len(char_set)``. Rows past the end of a
    shorter label stay all-zero.

    Args:
        label: Text to encode. Every character must appear in ``char_set``.
        char_set: Ordered sequence of allowed characters. Defaults to
            ``captcha_settings.ALL_CHAR_SET_STR``.
        max_len: Number of character slots. Defaults to
            ``captcha_settings.MAX_CAPTCHA``.

    Returns:
        1-D ``torch.float64`` tensor of length ``max_len * len(char_set)``.

    Raises:
        IndexError: If ``label`` has more than ``max_len`` characters.
        ValueError: If a character of ``label`` is not in ``char_set``.
    """
    if char_set is None:
        char_set = captcha_settings.ALL_CHAR_SET_STR
    if max_len is None:
        max_len = captcha_settings.MAX_CAPTCHA

    if len(label) > max_len:
        # Same exception type that out-of-range tensor indexing used to raise,
        # but with a message that identifies the offending label.
        raise IndexError(
            f"label {label!r} has {len(label)} characters; "
            f"at most {max_len} are allowed"
        )

    # dtype=torch.float64 is what the previous `dtype=float` resolved to.
    result = torch.zeros((max_len, len(char_set)), dtype=torch.float64)
    for row, char in enumerate(label):
        try:
            col = char_set.index(char)
        except ValueError:
            raise ValueError(
                f"character {char!r} in label {label!r} is not in the character set"
            ) from None
        result[row, col] = 1.0
    # Flatten row-major: slot 0's columns first, then slot 1's, and so on.
    return result.reshape(-1)
def decode(pred_result, char_set=None):
    """Convert a flat one-hot (or per-class score) tensor back into text.

    All elements of ``pred_result`` are regrouped into rows of
    ``len(char_set)`` values. Each row yields the character at its argmax
    column. A batched tensor is therefore decoded into one concatenated
    string. If a row has tied maxima (e.g. an all-zero padding row from
    ``encode``), ``torch.argmax`` picks a single index, which presumably is
    the first one. Verify this against the installed torch version if it
    matters.

    Args:
        pred_result: Tensor whose element count is a multiple of
            ``len(char_set)``. It may be non-contiguous.
        char_set: Ordered sequence of characters matching the encoding
            columns. Defaults to ``captcha_settings.ALL_CHAR_SET_STR``.

    Returns:
        The decoded string, one character per row.

    Raises:
        RuntimeError: If the element count is not a multiple of
            ``len(char_set)`` (raised by ``reshape``).
    """
    if char_set is None:
        char_set = captcha_settings.ALL_CHAR_SET_STR
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # transposed or sliced tensors.
    scores = pred_result.reshape(-1, len(char_set))
    # One .tolist() call converts all indices to Python ints at once, instead
    # of converting (and, on GPU, syncing) element by element.
    indices = torch.argmax(scores, dim=1).tolist()
    return "".join(char_set[i] for i in indices)
def main():
    """Round-trip demo: encode the sample label "ABCD", print it, decode it, print the text."""
    sample_text = "ABCD"
    encoded = encode(sample_text)
    print(encoded)
    print(decode(encoded))


if __name__ == '__main__':
    main()