Comments (4)
loss nan 是因为 target_length 过长,input_length 过短。
CTCLoss 要求 input_length >= 2 * target_length + 1
,比如 abcd
这个label,输出必须能放得下 -a-b-c-d-
,不然就会 nan。
有关代码:
if (s < 2*max_target_length+1)
log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s] = neginf;
}
https://github.com/pytorch/pytorch/blob/v1.1.0/aten/src/ATen/native/cuda/LossCTC.cu#L139-L141
解决方法:
- 增加图像宽度,增加 input_length
- 调整池化层,减少水平方向的池化,增加 input_length
from captcha_break.
但调整了input_length的长度,在ipynb文件开头设置了n_input_length = 9
「target_length*2+1」,或者10,就会出现:
RuntimeError: Expected tensor to have size at least 6 at dimension 1, but got size 64 for argument #2 'targets' (while checking arguments for ctc_loss_gpu)
这样的问题,查询了一下来源,说是input_length过长而导致。
from captcha_break.
n_input_length
不是随便设置的,要和模型输出尺寸一致:
测试模型输出尺寸
model = Model(n_classes, input_shape=(3, height, width))
inputs = torch.zeros((32, 3, height, width))
outputs = model(inputs)
outputs.shape
# torch.Size([12, 32, 37])
这里的 12 就是 input_length,你可以跑一下这段代码,看看你的图像尺寸输入到模型以后,输出的 length 是多少,然后再修改 n_input_length 。
from captcha_break.
是的,一开始的时候对于n_input_length 进行修改的时候,也是比对原文 按照这里的输出,修改为了6,结果想不到原来CTCLoss是有要求的,回头确实要好好看一下。
万分感谢您的指点,按照您的指导,我删去了一个卷积层+池化层的block,现在的input_lenght变为了12,满足了要求,训练也可以正常继续下去了。😊
顺带插一句隔壁的题外话,在TensorFlow相关的ipynb文件中,您对于模型的evaluate都是编写了一个evalute()
函数或者类,这个评估函数算起来总是十分的慢,比如CNN_2019文件中,训练模型总共加起来不过十多分钟,结果在运行这个函数的时候,已经修改epoch为1了,依旧快两个小时,还没有结束的预兆。
按照该函数的意思,应该是遍历全部的可能组合结果对于该识别器进行评估?,在开头设置的时候,我加入了小写字母的考虑,难道是这个原因,导致的计算时间猛增?
from captcha_break.
Related Issues (20)
- 如何把保存下来的ctc模型载入继续训练呢?
- 您可否告知下这几个文件是独立运行的吗,如何训练自己的中文验证码呢 HOT 5
- RNN分类之后, 在评估处怎么获得每个字符的概率呢 HOT 1
- 楼主您好,请问3500常用汉字的验证码识别,该模型大小够吗? HOT 5
- 多行验证码如何识别呢? HOT 4
- loss为负数且不断减小
- cannot import name '_imaging' from 'PIL'
- 变长标签怎么处理
- 如果验证码最后两位相同,似乎一定识别错误 HOT 4
- 运行winpy/main.py遇到的问题 HOT 5
- train和val的acc都可以到99%,但是eval,只有0.00265,这是怎么回事呢
- 尝试把n_class+1程序可以运行,但是不知道对不对 HOT 1
- 请问如果是不定长的验证码 可以使用吗 HOT 2
- 使用CTC, 识别时不限制4个字符长度,识别率如何? HOT 3
- CTC模型不定长输出问题 HOT 1
- 效果不理想
- 训练完了怎么用啊,纯小白 HOT 2
- cnn_2019.ipynb(防止 tensorflow 占用所有显存)tensorflow2.0要怎么改 HOT 1
- tensorflow 2.0 训练的时候 日志不显示 不知道 训练到哪一步 HOT 1
- 请问我改如何替换掉ctc_2019中的lambda方法呢 因为lambda在加载保存的模型会有错误 HOT 1
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from captcha_break.