34 lines
1.4 KiB
Python
34 lines
1.4 KiB
Python
import numpy as np
|
||
import captcha_settings
|
||
import torch
|
||
|
||
|
||
# 用torch.zeros()函数生成一个4行36列,值全是0的张量。接着循环标签中的各个字符,将字符在captcha_settings.ALL_CHAR_SET_STR中对应的索引获取到,然后将张量中对应位置的0,改成1。最后要返回一个一维的列表,长度是4*36=144
|
||
def encode(label):
|
||
"""将字符转为独热码"""
|
||
cols = len(captcha_settings.ALL_CHAR_SET_STR)
|
||
rows = captcha_settings.MAX_CAPTCHA
|
||
result = torch.zeros((rows, cols), dtype=float)
|
||
for i, char in enumerate(label):
|
||
j = captcha_settings.ALL_CHAR_SET_STR.index(char)
|
||
result[i, j] = 1.0
|
||
return result.view(1, -1)[0]
|
||
|
||
|
||
# 将模型预测的值从一维转成4行36列的二维张量,然后调用torch.argmax()函数寻找每一行最大值(也就是1)的索引。知道索引后就可以从captcha_settings.ALL_CHAR_SET_STR中找到对应的字符
|
||
def decode(pred_result):
|
||
"""将独热码转为字符"""
|
||
pred_result = pred_result.view(-1, len(captcha_settings.ALL_CHAR_SET_STR))
|
||
index_list = torch.argmax(pred_result, dim=1)
|
||
text = "".join([captcha_settings.ALL_CHAR_SET_STR[i] for i in index_list])
|
||
return text
|
||
|
||
def main():
|
||
label = "ABCD"
|
||
one_hot_label = encode(label)
|
||
print(one_hot_label)
|
||
decoded_label = decode(one_hot_label)
|
||
print(decoded_label)
|
||
|
||
if __name__ == '__main__':
|
||
main() |