
Comments (4)

ypwhs commented on June 26, 2024

The nan loss happens because target_length is too long and input_length is too short.

CTCLoss requires input_length >= 2 * target_length + 1. For example, with the label abcd, the output must have room for -a-b-c-d-; otherwise the loss becomes nan.
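The rule can be checked with a small helper. A minimal plain-Python sketch (both function names are hypothetical, not from captcha_break or PyTorch): strictly, CTC only needs one frame per character plus one extra blank frame between each pair of identical neighbouring characters, so 2 * target_length + 1, the length of the fully blank-padded -a-b-c-d- form, is a safe target rather than the exact minimum.

```python
def min_ctc_input_length(label: str) -> int:
    """Smallest number of frames that can emit `label` under CTC:
    one frame per character, plus a blank between equal neighbours."""
    repeats = sum(1 for a, b in zip(label, label[1:]) if a == b)
    return len(label) + repeats


def conservative_ctc_input_length(label: str) -> int:
    """The 2 * L + 1 rule: room for a blank around every character."""
    return 2 * len(label) + 1


for label in ["abcd", "aabb"]:
    print(label, min_ctc_input_length(label), conservative_ctc_input_length(label))
```

Sizing the model for the conservative value leaves headroom for any 4-character label, including ones with repeated characters.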

Relevant code:

        if (s < 2*max_target_length+1)
          log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s] = neginf;
      }

https://github.com/pytorch/pytorch/blob/v1.1.0/aten/src/ATen/native/cuda/LossCTC.cu#L139-L141
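To see why a too-short input breaks the loss, here is a minimal pure-Python sketch of the standard CTC forward (alpha) recursion, which CTCLoss computes in log space. It is not PyTorch's implementation: it uses plain probabilities, index 0 as the blank, and the helper names `ctc_forward_prob` and `uniform` are hypothetical. With too few frames to reach the end of the extended label, the total probability is exactly 0, so the loss -log p is infinite, which typically surfaces as nan during training.

```python
BLANK = 0  # index 0 plays the role of the CTC blank in this sketch


def ctc_forward_prob(probs, label):
    """Probability of `label` given per-frame class distributions `probs`
    (a T x C list of lists), summed over all CTC alignments."""
    # Extended label with blanks around every symbol: -a-b-c-d-, length 2L+1
    ext = [BLANK]
    for c in label:
        ext += [c, BLANK]
    S, T = len(ext), len(probs)

    # At t = 0 a path may start on the leading blank or on the first symbol
    alpha = [0.0] * S
    alpha[0] = probs[0][ext[0]]
    if S > 1:
        alpha[1] = probs[0][ext[1]]

    for t in range(1, T):
        new = [0.0] * S
        for s in range(S):
            a = alpha[s]                  # stay on the same position
            if s >= 1:
                a += alpha[s - 1]         # advance by one position
            # skip the blank in between only if the two symbols differ
            if s >= 2 and ext[s] != BLANK and ext[s] != ext[s - 2]:
                a += alpha[s - 2]
            new[s] = a * probs[t][ext[s]]
        alpha = new

    # A valid path ends on the last symbol or on the trailing blank
    return alpha[S - 1] + (alpha[S - 2] if S > 1 else 0.0)


def uniform(T, C):
    """T frames of a uniform distribution over C classes (blank included)."""
    return [[1.0 / C] * C for _ in range(T)]


abcd = [1, 2, 3, 4]
print(ctc_forward_prob(uniform(3, 5), abcd))      # 0.0: 3 frames cannot emit 4 symbols
print(ctc_forward_prob(uniform(4, 5), abcd) > 0)  # True
```

`nn.CTCLoss(zero_infinity=True)` zeroes such infinite losses and their gradients, but that only hides the problem; the real fix is giving the model enough output frames.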

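Solutions:

  1. Increase the image width, which increases input_length.
  2. Adjust the pooling layers to pool less along the horizontal axis, which increases input_length.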
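Both remedies act on the width of the final feature map, which becomes input_length. A rough sketch under stated assumptions: the 192 px width and five pooling blocks are hypothetical numbers, convolutions are assumed to preserve width ("same" padding), each pooling has kernel equal to stride with no padding, and `output_length` is a made-up helper, not part of captcha_break.

```python
def output_length(width, horizontal_strides):
    """Sequence length left after a stack of pooling layers, assuming the
    convolutions keep the width and each pooling divides the width by its
    horizontal stride (floor division, kernel == stride, no padding)."""
    length = width
    for stride in horizontal_strides:
        length //= stride
    return length


n_len = 4                     # characters per captcha
required = 2 * n_len + 1      # the 2 * target_length + 1 rule

setups = {
    "192 px, five (2, 2) poolings": (192, [2, 2, 2, 2, 2]),
    "192 px, last pooling (2, 1), vertical only": (192, [2, 2, 2, 2, 1]),
    "384 px, five (2, 2) poolings": (384, [2, 2, 2, 2, 2]),
}
for name, (width, strides) in setups.items():
    T = output_length(width, strides)
    print(f"{name}: input_length = {T}, enough = {T >= required}")
```

Under these assumptions the first setup yields 6 frames (too short for 4 characters), while either doubling the width or dropping one horizontal downsampling yields 12.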

from captcha_break.

AemaH commented on June 26, 2024

But after adjusting input_length by setting n_input_length = 9 (target_length*2+1), or 10, at the top of the ipynb file, I get:

RuntimeError: Expected tensor to have size at least 6 at dimension 1, but got size 64 for argument #2 'targets' (while checking arguments for ctc_loss_gpu)

I looked up the error, and the explanation I found was that input_length is too long.


ypwhs commented on June 26, 2024

n_input_length cannot be set arbitrarily; it must match the model's output size:

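# Test the model's output size
# (Model, n_classes, height and width are defined earlier in the notebook)
import torch

model = Model(n_classes, input_shape=(3, height, width))
inputs = torch.zeros((32, 3, height, width))  # dummy batch of 32 images
outputs = model(inputs)
print(outputs.shape)  # (T, N, C): input_length, batch size, number of classes

# torch.Size([12, 32, 37])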

The 12 here is input_length. Run this code to see what output length your image size produces after passing through the model, then set n_input_length to that value.
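A small guard can catch both failure modes in this thread before training starts: n_input_length disagreeing with the model's T (the likely cause of the RuntimeError above) and T being too short for the labels (the nan loss). A sketch assuming the (T, N, C) layout printed above, with `tuple(outputs.shape)` passed in; `check_n_input_length` is a hypothetical helper.

```python
def check_n_input_length(output_shape, n_input_length, n_len):
    """Validate n_input_length against a model output of shape (T, N, C),
    the (time, batch, classes) layout nn.CTCLoss expects for log_probs."""
    T, _batch, _classes = output_shape
    if n_input_length != T:
        raise ValueError(
            f"n_input_length={n_input_length} does not match the model's T={T}")
    if T < 2 * n_len + 1:
        raise ValueError(
            f"T={T} < 2*n_len+1={2 * n_len + 1}: widen the image or pool less horizontally")
    return T


print(check_n_input_length((12, 32, 37), 12, 4))  # 12
for shape, n_in in [((6, 32, 37), 9), ((6, 32, 37), 6)]:
    try:
        check_n_input_length(shape, n_in, 4)
    except ValueError as err:
        print("rejected:", err)
```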


AemaH commented on June 26, 2024

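Yes. When I first changed n_input_length, I also compared against the original and set it to 6 based on this output, not realizing that CTCLoss had this requirement. I really need to read up on it.
Many thanks for the pointers. Following your advice, I removed one conv + pooling block; input_length is now 12, which meets the requirement, and training runs normally. 😊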

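A side question about a different notebook: in the TensorFlow ipynb files, you evaluate the model with a hand-written evaluate() function or class, and it is always very slow. In the CNN_2019 file, for example, training takes a little over ten minutes in total, but this function, even with epoch set to 1, has been running for almost two hours with no sign of finishing.
From the function, it looks like it evaluates the recognizer by iterating over all possible combinations? At the top I added lowercase letters to the character set; could that be why the computation time shot up?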
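If the evaluation really did enumerate every possible label (the thread does not confirm this), its cost would scale as charset_size ** n_len. A quick count, assuming the default character set is digits plus uppercase letters (36 symbols, consistent with the 37 output classes above once the blank is included) and n_len = 4, as implied by 2 * target_length + 1 = 9:

```python
import string

n_len = 4  # label length implied by 2 * target_length + 1 = 9
digits_upper = string.digits + string.ascii_uppercase   # assumed default set
with_lower = digits_upper + string.ascii_lowercase      # after adding lowercase

for name, charset in [("digits+upper", digits_upper), ("digits+upper+lower", with_lower)]:
    print(name, len(charset), len(charset) ** n_len)
```

Under that assumption, adding lowercase letters grows the label space from 1,679,616 to 14,776,336, roughly 8.8 times larger.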

