
Comments (7)

LetiP commented on September 13, 2024

Great! Best of luck and may the cluster be with you. 🤞

from mm-shap.

LetiP commented on September 13, 2024

Hello @ChengYuChuan , thanks for the question!

Unfortunately, none of the MM-SHAP script parameters reduces VRAM utilization.

  • The num_samples parameter determines how many examples you want to evaluate, but evaluation is sequential (samples are processed one after the other), so there is no reduction here. It seems you do not have enough VRAM to interpret even one sample.
  • You should not change the patch_size, because otherwise the sequence lengths for text and image become too different, which negatively affects the interpretation.

Proposed Solution:

So the only thing I can imagine is to reduce the batch size of the shap.Explainer, which is defined here. I think all you need to do is change this line to shap.Explainer(get_model_prediction, custom_masker, silent=True, batch_size=1), or to the maximum batch_size that fits in your VRAM.
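For intuition, here is a minimal stdlib-only sketch (not the actual shap internals) of why a smaller batch_size caps peak memory: the explainer evaluates many masked variants of each sample, and batch_size controls how many of those variants are pushed through the model at once. The function and predict callable below are hypothetical stand-ins for illustration.

```python
def predict_in_batches(masked_inputs, predict_fn, batch_size=1):
    """Run predict_fn over masked_inputs in chunks of at most batch_size.

    Only one chunk is in flight at a time, so peak memory scales with
    batch_size rather than with the total number of masked variants.
    """
    outputs = []
    for start in range(0, len(masked_inputs), batch_size):
        chunk = masked_inputs[start:start + batch_size]
        outputs.extend(predict_fn(chunk))
    return outputs

# Toy "model" that just negates each input value.
preds = predict_in_batches(list(range(10)), lambda chunk: [-v for v in chunk],
                           batch_size=3)
print(preds)  # same result as batch_size=10, but smaller peak footprint
```

With batch_size=1 the loop runs once per masked variant; the result is identical, only slower, which is the trade-off the suggestion above makes.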

Hope this helps!


ChengYuChuan commented on September 13, 2024

Hello @LetiP, thanks for your quick reply.

I am trying to reduce the batch size as you said, but the weird thing is that the Tried to allocate 16.00 MiB figure did not change when I decreased the batch size.

May I ask what your machine's specs are?

I am considering running the whole program on a Mac M1 chip with 16 GB of RAM, or deploying it on Google Colab, but the CUDA parts seem to support only NVIDIA GPUs. Also, this project is based on Python 3.6, and I am worried that Python 3.8 would not work with the environment setup... (If I run it on a Mac M1, I need to use the latest PyTorch 1.12 with Python 3.8.)

Have you ever heard of someone running the project on macOS or Colab? (Sorry, I know this question may be outside your experience. :/)


LetiP commented on September 13, 2024

Where exactly does the code throw OOM errors? From what you just said, it might even be at the model-loading stage, so the code never gets into the SHAP analysis where decreasing the batch size might help.

I used a GPU with 24 GB of VRAM. I see no reason why the code would not work with Python 3.8; definitely worth a try. I ran everything on Linux on a local cluster, so sorry, I have no experience with Mac or Colab for this project.

Just thinking that maybe you could also try to:

  • load the model in float16 and see if it supports half precision?
  • load the model here with torch.no_grad() around it.
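As a back-of-the-envelope illustration of the half-precision suggestion above (the parameter count here is a made-up example, not ALBEF's real size), weight memory alone halves when going from float32 to float16:

```python
# Hypothetical 200M-parameter model, for illustration only.
num_params = 200_000_000
fp32_bytes = num_params * 4   # float32: 4 bytes per weight
fp16_bytes = num_params * 2   # float16: 2 bytes per weight

print(f"fp32 weights: {fp32_bytes / 2**30:.2f} GiB")  # ~0.75 GiB
print(f"fp16 weights: {fp16_bytes / 2**30:.2f} GiB")  # ~0.37 GiB
```

Activations and the CUDA context also consume VRAM, so halving the weights does not guarantee a fit on a 2 GB card, but it is the cheapest thing to try.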


ChengYuChuan commented on September 13, 2024

Sorry, I should have described the problem more specifically in the beginning.

Yes, I think you are right: the problem is narrowed down to the model-loading part.

In the beginning, I thought it was only an evaluation between different metrics, so I assumed a normal PC could handle it.
I only have an NVIDIA GeForce MX150 GPU with 2 GB of VRAM...

Here is my error message.

(shap) C:\Users\Louis\PycharmProjects\NLP\MM-SHAP>python mm-shap_albef_dataset.py
Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.
  0%|                                                                                          | 0/534 [00:03<?, ?it/s]
Traceback (most recent call last):
  File "mm-shap_albef_dataset.py", line 280, in <module>
    model_prediction = model(image, text_input)
  File "C:\Users\Louis\anaconda3\envs\shap\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "mm-shap_albef_dataset.py", line 84, in forward
    image_embeds = self.visual_encoder(image)
  File "C:\Users\Louis\anaconda3\envs\shap\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Louis\PycharmProjects\NLP\MM-SHAP\ALBEF\models\vit.py", line 171, in forward
    x = blk(x, register_blk==i)
  File "C:\Users\Louis\anaconda3\envs\shap\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Louis\PycharmProjects\NLP\MM-SHAP\ALBEF\models\vit.py", line 92, in forward
    x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
  File "C:\Users\Louis\anaconda3\envs\shap\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Louis\PycharmProjects\NLP\MM-SHAP\ALBEF\models\vit.py", line 63, in forward
    attn = (q @ k.transpose(-2, -1)) * self.scale
RuntimeError: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 2.00 GiB total capacity; 1008.19 MiB already allocated; 11.44 MiB free; 1.04 GiB reserved in total by PyTorch)
  1. Do you mean to change the model path here? (My current pick is checkp = "mscoco"; the options are refcoco, mscoco, vqa, flickr30k.)

    model_path = f'ALBEF/checkpoints/{checkp}.pth' # largest model: ALBEF.pth, smaller: ALBEF_4M.pth, refcoco, mscoco, vqa, flickr30k

  2. The original code is:

# create an explainer with model and image masker
# change here to overcome GPU out of memory
explainer = shap.Explainer(get_model_prediction, custom_masker, silent=True, batch_size=32)
shap_values = explainer(X)
mm_score = compute_mm_score(nb_text_tokens, shap_values)

When I want to activate torch.no_grad(), how should I manage this?
Here is my first thought:

with torch.no_grad():
    explainer = shap.Explainer(get_model_prediction, custom_masker, silent=True, batch_size = 32)
shap_values = explainer(X)
mm_score = compute_mm_score(nb_text_tokens, shap_values)


LetiP commented on September 13, 2024

OK, so it is clear that your problem is not the SHAP (interpretability) code, but plain model inference. Given your hardware, you cannot even load the ALBEF model; the code does not get to the interpretation.
Since this script is about interpreting a model, you need enough memory to run the model, which you apparently do not have.

I think you need to edit this to:

    with torch.no_grad():
        checkpoint = torch.load(model_path, map_location='cpu')
        msg = model.load_state_dict(checkpoint, strict=False)

It is also best to put torch.no_grad() around the line that throws the error message for you, just in case:

    model_prediction = model(image, text_input)
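Putting the suggestion together, here is a minimal sketch of the no_grad pattern, using a tiny stand-in nn.Module instead of the real ALBEF model (which would not fit here):

```python
import torch

# Tiny stand-in for the ALBEF model; any nn.Module behaves the same way.
model = torch.nn.Linear(4, 2)
model.eval()

image_like = torch.randn(3, 4)  # placeholder input tensor

# Inside torch.no_grad() no autograd graph is built, so intermediate
# activations that backprop would otherwise retain are freed immediately,
# lowering peak (V)RAM during inference.
with torch.no_grad():
    prediction = model(image_like)

print(prediction.requires_grad)  # False: no gradient bookkeeping attached
```

The same context manager can wrap both the torch.load / load_state_dict calls and the model(image, text_input) forward pass, since neither needs gradients for SHAP evaluation.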


ChengYuChuan commented on September 13, 2024

@LetiP Thank you so much.

After talking to my lecturer, I switched the task to a cluster.

I will implement the project there.

Therefore, I won't encounter this issue anymore.

Thank you again. :D

