
Comments (14)

woctezuma commented on August 22, 2024

Even faster these days: you get a 4x4 grid instead of a 3x3 grid on Replicate, after the same duration.

However, this is based on Dall-E MEGA instead of Dall-E Mini, so results might differ. Not sure if better or worse.


pcuenca commented on August 22, 2024

generating 1 image took 27 mins on dalle-playground (using 117% CPU), whereas this pytorch model runs in 2.7 mins (using 145% CPU)! GPU looks less than half utilized. I haven't checked whether pytorch is the process that's using the GPU.

I think the model runs on CPU by default. I tried to move all models and tensors to the mps device and fix some incompatibilities (a few ops are not yet supported by the MPS backend). Inference was faster and GPU utilization was close to 100%, but generation did not work properly. I'm still trying to identify what the problem could be.
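For reference, a minimal sketch of that kind of device move (the nn.Linear stand-in and shapes are illustrative, not min-dalle's actual model):

```python
import torch
import torch.nn as nn

# Minimal sketch (not min-dalle's actual code): pick MPS when available,
# then move both the module and its inputs to that device.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = nn.Linear(16, 16).to(device)    # stand-in for the real model
x = torch.randn(1, 16, device=device)   # inputs must live on the same device

with torch.no_grad():
    y = model(x)
print(y.device)  # mps:0 on Apple Silicon, cpu otherwise
```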


kuprel commented on August 22, 2024

Awesome!


Birch-san commented on August 22, 2024

@pcuenca wait, you got it running on-GPU? And it was faster? That's massively different from the result I got.

Here's how I made it run on MPS:
Birch-san/min-dalle@main...mps
There's other stuff in that branch too, like generating multiple images, re-using the text encoding between images, and measuring how long each step takes.

What I found was that it ran way slower. I left it overnight and it didn't finish generating even 1 image (it got to something like the 145th token of 255).
And tbh the CPU usage (~117%) and GPU usage (less than half) looked identical to when it ran on-CPU.

Did I do something wrong? I just slapped device_type on everything I could.
I'm using torch==1.13.0.dev20220628 (a recent nightly).
I ran with PYTORCH_ENABLE_MPS_FALLBACK=1 and --mega --torch --text='kunkka playing basketball with Shrek' --copies=3, with dalle-mega proper, not the fp16 version.
Only one operation had to use the fallback to CPU: aten::sort.values_stable.
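For anyone reproducing this, a minimal sketch of the fallback setup; note the environment variable has to be set before torch is imported:

```python
import os

# PYTORCH_ENABLE_MPS_FALLBACK is read when torch initializes, so set it
# before importing torch; without it, an unsupported op such as
# aten::sort.values_stable raises NotImplementedError instead of
# silently running on the CPU.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)
```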


Birch-san commented on August 22, 2024

generation did not work properly

It's worth knowing that the MPS backend does have some silent errors where it will produce incorrect output (or at least transfer the wrong result to CPU). Here's the really wacky phenomenon that I found:
pytorch/pytorch#79383
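A quick way to probe for that class of bug (a hypothetical sanity check, not from the linked issue): run the same chain of ops on CPU and on MPS and compare after transferring back.

```python
import torch

# Requires an MPS-enabled build of PyTorch.
x = torch.randn(4, 8, 16)

cpu_out = x.permute(0, 2, 1).reshape(4, -1)
mps_out = x.to("mps").permute(0, 2, 1).reshape(4, -1).cpu()

# If the backend (or the MPS -> CPU copy) mishandles the strided view,
# this comparison fails.
print(torch.allclose(cpu_out, mps_out, atol=1e-6))
```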


pcuenca commented on August 22, 2024

@Birch-san These are my changes so far: main...pcuenca:min-dalle:mps-device

I tried to use workarounds for unsupported ops, except for multinomial. You need to use PYTORCH_ENABLE_MPS_FALLBACK=1 for the backend to automatically fall back to the CPU when it encounters that operation. I also tried to replace it with argmax, which should produce something reasonable, but it did not help with generation.
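Roughly, the substitution looks like this (a sketch; shapes are illustrative, not min-dalle's actual ones):

```python
import torch

logits = torch.randn(1, 16384)   # per-step token logits
probs = logits.softmax(dim=-1)

sampled = torch.multinomial(probs, num_samples=1)  # unsupported on MPS without the CPU fallback
greedy = probs.argmax(dim=-1, keepdim=True)        # deterministic, supported on MPS
```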

I may have introduced a problem somewhere, but if you disable the MPS device by returning self here, everything works right.


pcuenca commented on August 22, 2024

it's worth knowing that the MPS backend does have some silent errors where it will produce incorrect output. here's the really wacky one I found: pytorch/pytorch#79383

That's very interesting. I'll try to debug generation tomorrow. Thanks!


kuprel commented on August 22, 2024

I was also looking into getting this model running on the phone. Apple says that for transformers in PyTorch, the dimensions aren't in optimal order for the Neural Engine: https://machinelearning.apple.com/research/neural-engine-transformers

They also convert all the linear layers to convs and use a different einsum pattern.
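The gist of that trick, as a hedged sketch (not Apple's actual ane-transformers code): a Linear over (B, S, C) tensors becomes a 1x1 Conv2d over (B, C, 1, S), the layout the Neural Engine prefers.

```python
import torch
import torch.nn as nn

# Illustrative: reuse a Linear's weights as a 1x1 conv and check equivalence.
linear = nn.Linear(64, 128)

conv = nn.Conv2d(64, 128, kernel_size=1)
conv.weight.data = linear.weight.data.view(128, 64, 1, 1)
conv.bias.data = linear.bias.data

x = torch.randn(2, 10, 64)                       # (batch, seq, features)
y_linear = linear(x)                             # (2, 10, 128)
y_conv = conv(x.transpose(1, 2).unsqueeze(2))    # (2, 64, 1, 10) -> (2, 128, 1, 10)
print(torch.allclose(y_linear, y_conv.squeeze(2).transpose(1, 2), atol=1e-5))
```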


Birch-san commented on August 22, 2024

I was also looking into getting this model on the phone. Apple says that for transformers in pytorch, the dimensions aren't in optimal order for the neural engine: https://machinelearning.apple.com/research/neural-engine-transformers

They also convert all the linear layers to convs and use a different einsum pattern

That's just the Neural Engine. PyTorch's MPS backend targets the GPU, and JAX's IREE/Vulkan backend does too. Dunno what TensorFlow targets, but I'll definitely take "targeting 48 GPU cores" as a step up from "targeting 10 CPU cores".

It sounds like the Neural Engine is not suitable for training anyway, only inference:
pytorch/pytorch#47688 (comment)


pcuenca commented on August 22, 2024

The Neural Engine is much faster than the GPU, so it makes sense to apply those optimizations. Not all operations are supported, however, and it's hard to know whether the system decided to run your model on the Neural Engine or the GPU.

I wasn't trying to do that yet, though. I just wanted to test inference on the MPS backend (GPU) of my M1 Mac to see how it compares with the CPU and with NVIDIA GPUs. If we did a conversion to Core ML, we would then be able to compare Neural Engine inference speed against PyTorch+MPS performance.
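If we got there, the conversion path would presumably look something like this coremltools sketch (untested here; the Linear stand-in and shapes are placeholders for the real model):

```python
import torch
import coremltools as ct

# Trace with TorchScript, then convert and let Core ML schedule the model
# across CPU / GPU / Neural Engine.
model = torch.nn.Linear(16, 16).eval()   # stand-in for the real model
example = torch.randn(1, 16)
traced = torch.jit.trace(model, example)

mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(shape=example.shape)],
    compute_units=ct.ComputeUnit.ALL,    # system picks ANE / GPU / CPU
)
```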


Birch-san commented on August 22, 2024

@pcuenca

That's very interesting. I'll try to debug generation tomorrow. Thanks!

If it is indeed the problem of transferring from MPS to CPU, then we should try @qqaatw's idea for transferring as contiguous memory.

pytorch/pytorch#79383 (comment)
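As I read that suggestion, the workaround amounts to something like this (illustrative; requires an MPS-enabled build):

```python
import torch

# Make the tensor contiguous on the MPS device before copying back, so the
# transfer never has to interpret a strided view.
x = torch.randn(2, 3, 4, device="mps")
view = x.permute(2, 0, 1)       # non-contiguous view on MPS

safe = view.contiguous().cpu()  # materialize first, then transfer
```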


Birch-san commented on August 22, 2024

@pcuenca if I slap .contiguous() at the end of every torch.{reshape,view,unsqueeze,permute}() call (i.e. the functions which perform reshaping, and which may use a view to do so), we get an image that is merely bad rather than pitch-black:
[generated images: "kunkka playing basketball with Shrek"]

Birch-san@8b83231
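The shape of that change, roughly (illustrative, not the actual diff):

```python
import torch

# Force a real memory layout after every reshaping op that may return a
# view, before the result is consumed (requires an MPS-enabled build).
x = torch.randn(2, 3, 4, device="mps")

y = x.permute(0, 2, 1).contiguous()
z = y.reshape(2, -1).contiguous()
w = z.unsqueeze(0).contiguous()
```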


Birch-san commented on August 22, 2024

Oh, there's one final reshape() that I missed. But adding .contiguous() to that makes things worse rather than better:

[generated image: "kunkka playing basketball with Shrek"]

Birch-san@43e7e92


Birch-san commented on August 22, 2024

I also tried using .contiguous() on any tensor that would be transferred to the MPS device:
Birch-san@b1cf6c2

Still black.
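For completeness, that attempt amounts to roughly this (illustrative):

```python
import torch

# Guarantee contiguity on the host side before the CPU -> MPS copy
# (requires an MPS-enabled build).
cpu_tensor = torch.randn(3, 4).t()  # .t() returns a non-contiguous view

mps_tensor = cpu_tensor.contiguous().to("mps")
```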

