import numpy as np import captcha_settings import torch # 用torch.zeros()函数生成一个4行36列,值全是0的张量。接着循环标签中的各个字符,将字符在captcha_settings.ALL_CHAR_SET_STR中对应的索引获取到,然后将张量中对应位置的0,改成1。最后要返回一个一维的列表,长度是4*36=144 def encode(label): """将字符转为独热码""" cols = len(captcha_settings.ALL_CHAR_SET_STR) rows = captcha_settings.MAX_CAPTCHA result = torch.zeros((rows, cols), dtype=float) for i, char in enumerate(label): j = captcha_settings.ALL_CHAR_SET_STR.index(char) result[i, j] = 1.0 return result.view(1, -1)[0] # 将模型预测的值从一维转成4行36列的二维张量,然后调用torch.argmax()函数寻找每一行最大值(也就是1)的索引。知道索引后就可以从captcha_settings.ALL_CHAR_SET_STR中找到对应的字符 def decode(pred_result): """将独热码转为字符""" pred_result = pred_result.view(-1, len(captcha_settings.ALL_CHAR_SET_STR)) index_list = torch.argmax(pred_result, dim=1) text = "".join([captcha_settings.ALL_CHAR_SET_STR[i] for i in index_list]) return text def main(): label = "ABCD" one_hot_label = encode(label) print(one_hot_label) decoded_label = decode(one_hot_label) print(decoded_label) if __name__ == '__main__': main()