Comments (3)
But a simple example from the docs works as expected:
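import torch
import torch.nn as nn
from warpctc_pytorch import CTCLoss

# Raw activations of shape (T=2, N=1, C=5); warp-ctc applies softmax internally
out = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).requires_grad_(True)
print(out.shape)
labels = torch.IntTensor([1, 2])
label_lengths = torch.IntTensor([2])
input_lengths = torch.IntTensor([2])
# size_average=True divides by the batch size (1 here), so it matches reduction='sum'
warp_ctc = CTCLoss(size_average=True)
torch_ctc = nn.CTCLoss(reduction='sum')
loss1 = warp_ctc(out, labels, input_lengths, label_lengths)
print(loss1)
# nn.CTCLoss expects log-probabilities, so apply log_softmax over the class dimension
loss2 = torch_ctc(out.log_softmax(2), labels, input_lengths, label_lengths)
print(loss2)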
Output:
torch.Size([2, 1, 5])
tensor([2.4629], grad_fn=<_CTCBackward>)
tensor(2.4629, grad_fn=<SumBackward0>)
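As a hand check of the value both losses print: with T = 2 frames and the two-symbol target 1 2, the only valid CTC alignment is class 1 at frame 0 followed by class 2 at frame 1 (there is no room for blanks or repeats), so the loss reduces to -log(p_0(1) * p_1(2)) with softmax probabilities. A minimal pure-Python sketch of that arithmetic; the helper name softmax is ours, not part of either library:

```python
import math

def softmax(xs):
    # Numerically stable softmax over one frame of activations
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Same activations as the example: frame 0 peaks at class 1, frame 1 at class 2
frames = [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]
target = [1, 2]

# With T == len(target) and no repeated labels, the only alignment is the target itself
prob = 1.0
for t, label in enumerate(target):
    prob *= softmax(frames[t])[label]

loss = -math.log(prob)
print(f"{loss:.4f}")  # 2.4629
```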
@SeanNaren, any ideas about the origin of the "double free or corruption (!prev)" error on CPU, or of the zeros returned on CUDA?
from warp-ctc.
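import torch
import torch.nn as nn
from warpctc_pytorch import CTCLoss

# Raw activations of shape (T=368, N=2, C=29); warp-ctc applies softmax itself
out = torch.randn(368, 2, 29)
# Lengths and labels are int32 (IntTensor / .int()), as warp-ctc expects
input_lengths = torch.IntTensor([119, 179])
# Targets for both utterances concatenated: 148 = 59 + 89
labels = torch.randint(1, 27, size=(148,)).int()
label_lengths = torch.IntTensor([59, 89])
warp_ctc = CTCLoss()
torch_ctc = nn.CTCLoss(reduction='sum')
loss1 = warp_ctc(out, labels, input_lengths, label_lengths)
print(loss1)
# nn.CTCLoss expects log-probabilities
loss2 = torch_ctc(out.log_softmax(2), labels, input_lengths, label_lengths)
print(loss2)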
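When the two losses disagree, or one of them crashes or returns zeros, an independent reference can help tell which side is wrong. Below is a minimal pure-Python sketch of the standard CTC forward (alpha) recursion in log space. It assumes per-frame log-probabilities as input (what nn.CTCLoss receives after log_softmax), at least one frame, and blank index 0 (the default blank in both libraries); the function names are hypothetical and it is meant for small cross-checks, not training.

```python
import math

NEG_INF = float("-inf")

def logsumexp(*xs):
    # log(sum(exp(x))) that tolerates -inf entries
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_softmax(xs):
    # Normalize one frame of raw activations into log-probabilities
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]

def ctc_loss_reference(log_probs, target, blank=0):
    """Negative log-likelihood of `target` under CTC.

    log_probs: list of T >= 1 frames, each a list of C log-probabilities.
    target: list of label ids (must not contain `blank`).
    """
    # Extended target with blanks interleaved: [b, y1, b, y2, ..., b]
    ext = [blank]
    for label in target:
        ext += [label, blank]
    S = len(ext)
    T = len(log_probs)

    # alpha[s]: log-prob of all partial paths ending at ext[s] in the current frame
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]

    for t in range(1, T):
        new = [NEG_INF] * S
        for s in range(S):
            candidates = [alpha[s]]
            if s >= 1:
                candidates.append(alpha[s - 1])
            # Skipping a blank is allowed only between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                candidates.append(alpha[s - 2])
            new[s] = logsumexp(*candidates) + log_probs[t][ext[s]]
        alpha = new

    # Valid paths end on the last label or on the trailing blank
    total = alpha[0] if S == 1 else logsumexp(alpha[S - 1], alpha[S - 2])
    return -total

# Reproduces the small example from the first comment
frames = [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]
print(round(ctc_loss_reference([log_softmax(f) for f in frames], [1, 2]), 4))  # 2.4629
```

To cross-check the larger script, one could pass out.log_softmax(2)[:int(input_lengths[n]), n].tolist() for each batch item n together with that item's slice of the concatenated labels, then sum the per-item results to compare with reduction='sum'.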
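Remember the following:
- torch.tensor([119, 179]) returns torch.int64, not torch.int32
- torch.randint(1, 27, size=(148,)) returns torch.int64, not torch.int32
Both need .int() before being passed to warp-ctc, which works with int32 (IntTensor) labels and lengths.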
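A quick way to confirm these dtypes and convert them before calling warp-ctc is sketched below. It only uses standard PyTorch calls (torch.tensor, torch.randint, Tensor.int()); the helper name to_int32 is hypothetical, and the int32 requirement is taken from the IntTensor usage in the warp-ctc examples above.

```python
import torch

def to_int32(*tensors):
    # Hypothetical helper: cast label/length tensors to int32 (torch.IntTensor)
    return tuple(t.int() for t in tensors)

input_lengths = torch.tensor([119, 179])
labels = torch.randint(1, 27, size=(148,))
label_lengths = torch.tensor([59, 89])

print(input_lengths.dtype, labels.dtype)  # torch.int64 torch.int64

labels, input_lengths, label_lengths = to_int32(labels, input_lengths, label_lengths)
print(input_lengths.dtype, labels.dtype)  # torch.int32 torch.int32
```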
Related Issues (20)
- Use pytest-flake8 or pytest-flakes instead of pytest-pep8
- Mixed precision and warp-ctc
- GPU execution requested, but not compiled with GPU support
- error when make
- Is CUDA 11.1 supported?
- Cant print Grad tensor from CTCLoss function
- Error in python setup.py install
- make error
- cmake error
- Import error: libtorch_cpu.so: cannot open shared object file: No such file or directory
- setup.py install error
- make: *** No targets specified and no makefile found. Stop
- Doubts about the difference between pytorch's own ctcloss and warp-ctc
- Installed successfully but import fails with "torch.utils.ffi is deprecated. Please use cpp extensions instead"
- fail to make
- libcudart.so.10.2:
- #error This file requires compiler and library support for the ISO C++ 2011 standard. This support is currently experimental, and must be enabled with the -std=c++11 or -std=gnu++11 compiler options.
- Error using AnimateDiff-Evolved plugin
- Error in setup process (about version): version issue