34 lines
1.4 KiB
Python
34 lines
1.4 KiB
Python
|
import numpy as np
|
|||
|
import captcha_settings
|
|||
|
import torch
|
|||
|
|
|||
|
|
|||
|
# 用torch.zeros()函数生成一个4行36列,值全是0的张量。接着循环标签中的各个字符,将字符在captcha_settings.ALL_CHAR_SET_STR中对应的索引获取到,然后将张量中对应位置的0,改成1。最后要返回一个一维的列表,长度是4*36=144
|
|||
|
def encode(label):
|
|||
|
"""将字符转为独热码"""
|
|||
|
cols = len(captcha_settings.ALL_CHAR_SET_STR)
|
|||
|
rows = captcha_settings.MAX_CAPTCHA
|
|||
|
result = torch.zeros((rows, cols), dtype=float)
|
|||
|
for i, char in enumerate(label):
|
|||
|
j = captcha_settings.ALL_CHAR_SET_STR.index(char)
|
|||
|
result[i, j] = 1.0
|
|||
|
return result.view(1, -1)[0]
|
|||
|
|
|||
|
|
|||
|
# 将模型预测的值从一维转成4行36列的二维张量,然后调用torch.argmax()函数寻找每一行最大值(也就是1)的索引。知道索引后就可以从captcha_settings.ALL_CHAR_SET_STR中找到对应的字符
|
|||
|
def decode(pred_result):
|
|||
|
"""将独热码转为字符"""
|
|||
|
pred_result = pred_result.view(-1, len(captcha_settings.ALL_CHAR_SET_STR))
|
|||
|
index_list = torch.argmax(pred_result, dim=1)
|
|||
|
text = "".join([captcha_settings.ALL_CHAR_SET_STR[i] for i in index_list])
|
|||
|
return text
|
|||
|
|
|||
|
def main():
|
|||
|
label = "ABCD"
|
|||
|
one_hot_label = encode(label)
|
|||
|
print(one_hot_label)
|
|||
|
decoded_label = decode(one_hot_label)
|
|||
|
print(decoded_label)
|
|||
|
|
|||
|
if __name__ == '__main__':
|
|||
|
main()
|