
Comments (6)

ashbhandare commented on August 15, 2024

Thank you for reporting this issue, I am taking a look.


SeanNaren commented on August 15, 2024

Thanks so much @ashbhandare!!

Apologies for the small errors and mistakes you had to find yourself; I should've done a better job vetting the example myself.

The config looks good. I'll double-check that I can see these performance benefits, but this is more than enough to unblock me for the Lightning integration :)


ashbhandare commented on August 15, 2024

For your reference, this is the partial code (with changes from yours) that I used for the numbers above:

import time

import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

# model, device, train_loader, n_warmup, and n_steps are defined as in the original sample
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scaler = GradScaler()
torch.cuda.synchronize()
# warmup before measuring; note that backward and zero_grad run here as well
for x, (idx, targets) in tqdm(enumerate(train_loader), total=len(train_loader)):
    if x == n_warmup:
        break
    optimizer.zero_grad()
    idx = idx.to(device)
    targets = targets.to(device)
    with autocast():
        logits = model(idx)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

# timed run over n_steps batches
torch.cuda.synchronize()
start_time = time.time()
batch_id = 0
for idx, targets in tqdm(train_loader, total=len(train_loader)):
    if batch_id == n_steps:
        break
    optimizer.zero_grad()
    idx = idx.to(device)
    targets = targets.to(device)
    with autocast():
        logits = model(idx)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    batch_id += 1

torch.cuda.synchronize()
print("Time taken", time.time() - start_time)


SeanNaren commented on August 15, 2024

Can confirm that the above also works for Lightning, thanks so much @ashbhandare!!


SeanNaren commented on August 15, 2024

I've included AMP precision, but haven't included DeepSpeed yet!


ashbhandare commented on August 15, 2024

Hi @SeanNaren, first, in the warmup steps I see you are not doing the loss.backward(). This is important because ORT does the on-device memory allocations required for the backward pass during the profiling window in which the first backward call happens. Also, we need to zero out the grads after every step; this has a minor impact on perf, since unnecessarily accumulating gradients adds a block of compute that won't be optimized by ORT and reduces the perf benefit.
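
To make those two fixes concrete, here is a minimal sketch of the warmup loop with both applied (same names as in the sample code above):

for x, (idx, targets) in enumerate(train_loader):
    if x == n_warmup:
        break
    optimizer.zero_grad()            # zero grads every step so nothing accumulates across iterations
    idx = idx.to(device)
    targets = targets.to(device)
    with autocast():
        logits = model(idx)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    scaler.scale(loss).backward()    # run backward in warmup so ORT's on-device allocations
    scaler.step(optimizer)           # for the backward pass happen during profiling
    scaler.update()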

Secondly, after profiling your sample, I see that increasing the block size to 1024 gives around a 5.4% speedup (ORT: 58.48733854293823 s, Torch: 61.66805934906006 s):

    n_embd = 2048
    block_size = 1024
    n_layer = 6
    batch_size = 4
    n_head = 16

and switching to the config used in HF GPT-2 gave a ~13.8% improvement (ORT: 42.62196612358093 s, Torch: 48.5234797000885 s):

    n_embd = 768 #2048
    block_size = 1024 #128
    n_layer = 12 #6
    batch_size = 4
    n_head = 12 #16

That being said, we will continue to investigate the issue with the config you have provided.
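
As an illustration of how those hyperparameters fit together, here is a minimal sketch assuming a minGPT-style model; GPTConfig and GPT are assumptions standing in for whatever model class the sample actually uses:

# hypothetical minGPT-style construction; adjust to the actual model class in the sample
config = GPTConfig(
    vocab_size=vocab_size,   # vocab_size comes from the dataset in the original sample
    block_size=1024,         # larger block size than the originally reported 128
    n_layer=12,
    n_head=12,
    n_embd=768,
)
model = GPT(config).to(device)
batch_size = 4               # used when building train_loader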

