
Comments (5)

erprogs commented on August 22, 2024

Hello @nguyenvulong,

Thank you for your interest in our work!

When the data loader loads the images, the class mapping is {FAKE: 0, REAL: 1}, because the fake folder is the one it loads first. For our predictions, however, we output FAKE: 1 and REAL: 0.

Data loading and training: {FAKE: 0, REAL: 1}
Prediction: {FAKE: 1, REAL: 0}, since in many deepfake challenges the fake class is 1.

In our current case (tensor([[0.0468, 0.9539]], device='cuda:0')), max_prediction_value returns index 1 with the value abs(1 - 0.9539) = 0.0461, so the class is real.
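The arithmetic above can be checked with a small sketch. It mirrors the return logic of max_prediction_value as quoted in this thread, but on a plain Python list rather than a tensor; `confidence_for` is a hypothetical helper name, not part of the repository.

```python
def confidence_for(mean_val):
    """Mirror the return logic of max_prediction_value on two plain floats."""
    # Index of the larger score (what torch.argmax gives for two distinct values).
    label = 1 if mean_val[1] > mean_val[0] else 0
    # Same value rule as the thread's function: mean_val[0] when index 0 wins,
    # otherwise |1 - mean_val[1]|.
    value = mean_val[0] if mean_val[0] > mean_val[1] else abs(1 - mean_val[1])
    return label, value


label, value = confidence_for([0.0468, 0.9539])
print(label, round(value, 4))  # 1 0.0461
```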

To stay consistent with how the data loader classifies the images, I used the mapping {0: "REAL", 1: "FAKE"}, which I can then easily flip using XOR:

if the prediction is 0 => 1, if the prediction is 1 => 0

We could write an if/else statement, but I found XOR a nicer way to do it.
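As a minimal sketch of the flip described above: XOR with 1 swaps a binary index, and it is equivalent to the if/else form. The mapping is the one quoted in this comment; the names `LABELS`, `flip`, `flip_if_else`, and `real_or_fake` are hypothetical and may not match the repository.

```python
# Mapping quoted in the comment above.
LABELS = {0: "REAL", 1: "FAKE"}


def flip(pred_index):
    """Flip a binary class index with XOR: 0 -> 1, 1 -> 0."""
    return pred_index ^ 1


def flip_if_else(pred_index):
    """The equivalent if/else form, for comparison."""
    return 0 if pred_index == 1 else 1


def real_or_fake(pred_index):
    """Look up the label after flipping the predicted index."""
    return LABELS[flip(pred_index)]


# Index 1 is the case discussed above, which the comment says is real.
print(real_or_fake(1))  # REAL
print(real_or_fake(0))  # FAKE
```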

I hope this answers your question.

from genconvit.

erprogs commented on August 22, 2024

Hello @nguyenvulong,

Thank you for finding the issue with the number of frames being 1, and for your PR. I've checked the code, and it does fail when the number of frames is 1.

I've seen your code, and thank you for your effort. However, I have a simple workaround: check whether mean_val.numel() is 1. If it is, use y_pred directly, since it holds a single prediction rather than a batch. If mean_val.numel() is greater than 1, continue as before.

Then the rest of the code logic remains the same.

import torch


def max_prediction_value(y_pred):
    # Finds the index and value of the maximum prediction value.
    mean_val = torch.mean(y_pred, dim=0)

    # With a single frame, y_pred is 1-D, so the mean collapses to one
    # element; fall back to y_pred itself.
    if mean_val.numel() == 1:
        mean_val = y_pred

    return (
        torch.argmax(mean_val).item(),
        mean_val[0].item()
        if mean_val[0] > mean_val[1]
        else abs(1 - mean_val[1]).item(),
    )
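A quick way to exercise the workaround on both input shapes is sketched below. The function is repeated so the snippet runs on its own. The single-frame tensor uses the values reported elsewhere in this thread; the two-frame batch is made-up illustration data.

```python
import torch


def max_prediction_value(y_pred):
    # Same logic as the workaround above.
    mean_val = torch.mean(y_pred, dim=0)
    if mean_val.numel() == 1:
        mean_val = y_pred
    return (
        torch.argmax(mean_val).item(),
        mean_val[0].item()
        if mean_val[0] > mean_val[1]
        else abs(1 - mean_val[1]).item(),
    )


# One frame: a 1-D tensor with two scores (values from the report in this thread).
single = torch.tensor([0.3401, 0.6532])
# Two frames: a (2, 2) batch (hypothetical values).
batch = torch.tensor([[0.9, 0.1], [0.7, 0.3]])

single_result = max_prediction_value(single)
batch_result = max_prediction_value(batch)
print(single_result)  # index 1, value close to 1 - 0.6532
print(batch_result)   # index 0, value close to the column mean 0.8
```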

I've tested the updated code with 1 frame using the sample videos. The image shows the prediction using 1 frame.

[Image: image_2024-01-21_13-13-15]

Thank you again!


nguyenvulong commented on August 22, 2024

Thanks so much. I will close the PR and this question for now!


nguyenvulong commented on August 22, 2024

Thank you, that's clear. However, I need to discuss max_prediction_value in the case of custom videos (not from the datasets, but collected in the wild).

python prediction.py \
    --p ./test_videos \
    --f 1 \
    --d yours \
    --n ed

def max_prediction_value(y_pred):
    # Finds the index and value of the maximum prediction value.
    mean_val = torch.mean(y_pred, dim=0)
    print(f"mean_val: {mean_val}")
    print(f"y_pred: {y_pred}")
    print(f"y_pred dim: {y_pred.dim()}")
    return (
        torch.argmax(mean_val).item(),
        mean_val[0].item()
        if mean_val[0] > mean_val[1]
        else abs(1 - mean_val[1]).item(),
    )

In this case, the input to max_prediction_value, y_pred (that is, torch.sigmoid(model(df).squeeze()) from pred_vid), has only one dimension (please refer to the values I printed).
torch.mean(y_pred, dim=0) is therefore a 0-dim tensor, so mean_val[0] and mean_val[1] raise the following error.

mean_val: 0.49663394689559937
y_pred: tensor([0.3401, 0.6532], device='cuda:0')
y_pred dim: 1
An error occurred: invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number
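The failure can be reproduced in isolation, without the model or a GPU. The sketch below uses the y_pred values printed above on the CPU; nothing else from the repository is assumed.

```python
import torch

# y_pred as printed in the report: a 1-D tensor holding one frame's two scores.
y_pred = torch.tensor([0.3401, 0.6532])

# Averaging over dim 0 of a 1-D tensor leaves a 0-dim (scalar) tensor.
mean_val = torch.mean(y_pred, dim=0)
print(mean_val.dim())  # 0

# Indexing a 0-dim tensor is what raises the reported error.
try:
    mean_val[0]
    raised = False
except IndexError as err:
    raised = True
    print(f"IndexError: {err}")
```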

My guess is that, in your experience, y_pred is somehow a batch with two dimensions after torch.sigmoid(model(df).squeeze()) from pred_vid. Maybe that's the case with the datasets [dfdc, faceforensics, timit, celeb] but not with a custom dataset?


nguyenvulong commented on August 22, 2024

Update: I know why it happened: the number of frames was 1.
This is the updated code to handle both cases, number of frames = 1 and number of frames > 1.
I will create a PR if you don't mind.

import torch


def max_prediction_value(y_pred):
    if y_pred.dim() == 1 and y_pred.size(0) == 2:
        # When y_pred is a 1D tensor with two elements, no need to take the mean
        pred_label = torch.argmax(y_pred).item()
        pred_val = y_pred[pred_label].item()
        return (pred_label, pred_val)
    else:
        # Compute the mean value across the batch dimension (dim=0)
        mean_val = torch.mean(y_pred, dim=0)
        # Still, check whether mean_val is a 0-dimensional tensor, just to be safe
        if mean_val.dim() == 0:
            mean_val_val = mean_val.item()
            return (0, mean_val_val) if mean_val_val > 0.5 else (1, 1 - mean_val_val)
        # Here mean_val is a 1D tensor with one entry per class
        pred_label = torch.argmax(mean_val).item()
        pred_val = mean_val[pred_label].item()
        return (pred_label, pred_val)
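A possible alternative, sketched here only for illustration and not the code adopted in this thread, is to restore the batch dimension before averaging so that one code path covers both cases. It assumes y_pred always holds two scores per frame. The name `max_prediction_value_2d` is hypothetical. The returned value is the mean score at the predicted index, as in the function above, not the abs(1 - ...) form used in the original.

```python
import torch


def max_prediction_value_2d(y_pred):
    # squeeze() drops the batch dimension for a single frame, giving shape (2,).
    # reshape(-1, 2) turns that back into (1, 2) and leaves (N, 2) unchanged.
    y_pred = y_pred.reshape(-1, 2)
    mean_val = torch.mean(y_pred, dim=0)
    pred_label = torch.argmax(mean_val).item()
    pred_val = mean_val[pred_label].item()
    return pred_label, pred_val


# One frame (values from the report above) and a made-up two-frame batch.
one_frame = max_prediction_value_2d(torch.tensor([0.3401, 0.6532]))
two_frames = max_prediction_value_2d(torch.tensor([[0.9, 0.1], [0.7, 0.3]]))
print(one_frame)   # index 1, value close to 0.6532
print(two_frames)  # index 0, value close to 0.8
```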

