Comments (6)
Thank you for reporting this issue, I am taking a look.
from ort.
Thanks so much @ashbhandare!!
Apologies for the small errors and mistakes you had to find yourself; I should have done a better job vetting the example myself.
The config looks good. I'll double-check that I see these performance benefits, but this is more than enough to unblock me for the Lightning integration :)
For your reference, this is a partial code with changes from yours that I have used for the above numbers:
import time
import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scaler = GradScaler()
torch.cuda.synchronize()
# warmup before measuring
for x, (idx, targets) in tqdm(enumerate(train_loader), total=len(train_loader)):
    if x == n_warmup:
        break
    optimizer.zero_grad()
    idx = idx.to(device)
    targets = targets.to(device)
    with autocast():
        logits = model(idx)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
torch.cuda.synchronize()

start_time = time.time()
batch_id = 0
for idx, targets in tqdm(train_loader, total=len(train_loader)):
    if batch_id == n_steps:
        break
    optimizer.zero_grad()
    idx = idx.to(device)
    targets = targets.to(device)
    with autocast():
        logits = model(idx)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    batch_id += 1
torch.cuda.synchronize()
print("Time taken", time.time() - start_time)
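The warmup-then-measure pattern above is generic and worth separating out. A minimal CPU-only sketch of the same idea, with a placeholder `step` callable standing in for the training step (not part of the original code):

```python
import time

def benchmark(step, n_warmup, n_steps):
    """Run `step` n_warmup times untimed, then return the wall time of n_steps runs."""
    for _ in range(n_warmup):
        # warmup: one-time work (allocations, profiling, compilation) happens here
        step()
    start = time.time()
    for _ in range(n_steps):
        step()
    return time.time() - start

# Example with a trivial step in place of a training iteration:
elapsed = benchmark(lambda: sum(range(1000)), n_warmup=3, n_steps=10)
print(f"Time taken {elapsed:.4f}")
```

In the GPU version above, `torch.cuda.synchronize()` before and after each phase is what makes the wall-clock numbers meaningful, since CUDA kernels launch asynchronously.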
Can confirm that the above also works for Lightning, thanks so much @ashbhandare!!
I've included AMP precision, but not included DeepSpeed yet!
Hi @SeanNaren, first, in the warmup steps I see you are not doing loss.backward(). This is important because ORT does the device memory allocations required for the backward pass during the profiling window in which the first backward call happens. Also, we need to zero out the grads after every step; this has a minor impact on perf, since unnecessarily accumulating gradients adds a block of compute that won't be optimized by ORT and reduces the perf benefit.
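The two fixes can be sketched on a tiny CPU example; the linear model and random data here are stand-ins for the GPT sample from the issue, not part of it:

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(8, 4)  # stand-in for the actual model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
x = torch.randn(16, 8)
targets = torch.randint(0, 4, (16,))

for step in range(3):          # warmup iterations
    optimizer.zero_grad()      # fix 2: zero grads every step, warmup included
    loss = F.cross_entropy(model(x), targets)
    loss.backward()            # fix 1: run backward during warmup too, so
                               # one-time allocations land before timing starts
    optimizer.step()
```

With ORT the same structure applies; the point is only that the warmup loop must exercise the full forward + backward + zero_grad path that the timed loop will run.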
Secondly, after profiling your sample, I see that increasing the block size to 1024 gives around a 5.4% speedup (ORT: 58.48733854293823 s, Torch: 61.66805934906006 s):
n_embd = 2048
block_size = 1024
n_layer = 6
batch_size = 4
n_head = 16
and switching to the config used in HF gpt2 gave a ~13.8% improvement (ORT: 42.62196612358093 s, Torch: 48.5234797000885 s):
n_embd = 768 #2048
block_size = 1024 #128
n_layer = 12 #6
batch_size = 4
n_head = 12 #16
That being said, we will continue to investigate the issue with the config you have provided.