captcha/one_hot_encoding.py

34 lines
1.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import captcha_settings
import torch
def encode(label, *, char_set=None, max_len=None):
    """Convert a captcha label string into a flattened one-hot tensor.

    A ``(max_len, len(char_set))`` tensor of zeros is created. For each
    character of ``label``, the column at that character's index in
    ``char_set`` is set to 1.0 in the matching row. The tensor is then
    flattened to 1-D, so its length is ``max_len * len(char_set)``. The
    original author's note gives 4 * 36 = 144 for this project's settings.

    Rows past the end of ``label`` stay all-zero, so labels shorter than
    ``max_len`` are accepted.

    Args:
        label: Text to encode; every character must appear in ``char_set``.
        char_set: Alphabet that defines the column indices. Defaults to
            ``captcha_settings.ALL_CHAR_SET_STR``.
        max_len: Number of rows, i.e. the maximum label length. Defaults to
            ``captcha_settings.MAX_CAPTCHA``.

    Returns:
        A 1-D ``torch.float64`` tensor of length ``max_len * len(char_set)``.

    Raises:
        ValueError: If ``label`` contains a character not in ``char_set``.
        IndexError: If ``label`` is longer than ``max_len``.
    """
    if char_set is None:
        char_set = captcha_settings.ALL_CHAR_SET_STR
    if max_len is None:
        max_len = captcha_settings.MAX_CAPTCHA
    # The original passed dtype=float, which torch maps to float64; keep it
    # so existing consumers see the same dtype.
    result = torch.zeros((max_len, len(char_set)), dtype=torch.float64)
    for row, char in enumerate(label):
        if row >= max_len:
            # Same exception type raw tensor indexing raised before, but with
            # a message that names the real problem.
            raise IndexError(
                f"label {label!r} is longer than the maximum of {max_len} characters"
            )
        try:
            col = char_set.index(char)
        except ValueError:
            raise ValueError(
                f"character {char!r} in label {label!r} is not in the character set"
            ) from None
        result[row, col] = 1.0
    # Row-major flatten: row 0's columns first, then row 1's, and so on.
    return result.view(-1)
def decode(pred_result, *, char_set=None):
    """Convert a flattened one-hot (or score) tensor back into text.

    ``pred_result`` is reshaped into rows of ``len(char_set)`` values.
    ``torch.argmax`` then picks the highest-scoring column of each row, and
    that column index is looked up in ``char_set``. The original author's
    note gives 4 rows of 36 for this project's settings.

    Args:
        pred_result: Tensor whose total element count is a multiple of
            ``len(char_set)``, e.g. the output of ``encode`` or raw model
            scores. It does not need to be contiguous.
        char_set: Alphabet that defines the column indices. Defaults to
            ``captcha_settings.ALL_CHAR_SET_STR``.

    Returns:
        The decoded string, one character per row. An all-zero row (such as
        padding from a short label) decodes to ``char_set[0]``, because
        ``torch.argmax`` returns the first maximal index on ties.
    """
    if char_set is None:
        char_set = captcha_settings.ALL_CHAR_SET_STR
    # reshape (not view) so non-contiguous inputs, such as slices or
    # transposes of a model's output, are accepted as well.
    scores = pred_result.reshape(-1, len(char_set))
    # Convert the indices to Python ints in one call rather than indexing the
    # string with a 0-d tensor per row (a per-element host transfer on GPU).
    indices = torch.argmax(scores, dim=1).tolist()
    return "".join(char_set[i] for i in indices)
def main():
    """Demo round trip: one-hot encode a sample label, print it, decode it back."""
    sample = "ABCD"
    encoded = encode(sample)
    print(encoded)
    print(decode(encoded))


if __name__ == '__main__':
    main()