fba_matting's Introduction

FBA Matting Open In Colab HuggingFace PWC License: MIT Arxiv

Official repository for the paper F, B, Alpha Matting. The paper and project are under heavy revision for peer-reviewed publication, so I am not able to release the training code yet.
Marco Forte1, François Pitié1

1 Trinity College Dublin

Requirements

GPU memory >= 11 GB for inference on the Adobe Composition-1k testing set, and more generally for resolutions above 1920x1080.

Packages:

  • torch >= 1.4
  • numpy
  • opencv-python

Additional Packages for jupyter notebook

  • matplotlib
  • gdown (to download model inside notebook)

Models

These models have been trained on the Adobe Image Matting Dataset. They are covered by the Adobe Deep Image Matting Dataset License Agreement, so they can only be used and distributed for non-commercial purposes.
More results from this model are available on alphamatting.com, the videomatting.com benchmark, and the supplementary materials PDF.

Model Name    | File Size | SAD  | MSE | Grad | Conn
FBA (Table 4) | 139 MB    | 26.4 | 5.4 | 10.6 | 21.5

Prediction

We provide a script, demo.py, and a Jupyter notebook, both of which produce the foreground, background and alpha predictions of our model. The test-time augmentation code will be made available soon.
In the torchscript notebook we show how to convert the model to TorchScript.

In this video I demonstrate how to create a trimap in Pinta/Paint.NET.

Training

Training code is not released at this time. It may be released upon acceptance of the paper. Here are the key takeaways from our work regarding training.

  • Use a batch-size of 1, and use Group Normalisation and Weight Standardisation in your network.
  • Train with clipping of the alpha instead of sigmoid.
  • The L1 alpha, compositional loss and Laplacian loss are beneficial; a gradient loss is not needed (see the loss sketch after this list).
  • For foreground prediction, we extend the foreground to the entire image and define the loss on the entire image, or at least on the unknown region. We found this works better than computing the loss only where alpha > 0. Code for foreground extension
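
Below is a minimal sketch of the L1 alpha and compositional losses mentioned above. It is not the authors' training code; the tensor shapes, the unknown-region mask, and the normalisation are assumptions.

import torch

def l1_alpha_loss(alpha_pred, alpha_gt, mask, eps=1e-8):
    # alpha_*: [B, 1, H, W] in [0, 1]; mask: [B, 1, H, W], 1 inside the unknown region
    return (torch.abs(alpha_pred - alpha_gt) * mask).sum() / (mask.sum() + eps)

def compositional_loss(alpha_pred, fg_gt, bg_gt, image, mask, eps=1e-8):
    # Recompose the image with the predicted alpha and ground-truth F/B,
    # then penalise the L1 difference to the real image inside the unknown region.
    comp = alpha_pred * fg_gt + (1 - alpha_pred) * bg_gt
    return (torch.abs(comp - image) * mask).sum() / (mask.sum() + eps)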

Citation

@article{forte2020fbamatting,
  title   = {F, B, Alpha Matting},
  author  = {Marco Forte and François Pitié},
  journal = {CoRR},
  volume  = {abs/2003.07711},
  year    = {2020},
}

Related works of ours

  • 99% accurate interactive object selection with just a few clicks: PDF, Code

fba_matting's People

Contributors

  • leonelhs
  • marcoforte

fba_matting's Issues

Foreground extension script doesn't work

I ran the foreground extension script you pointed to, but it does not seem to extend the foreground to the entire image. The surrounding areas are still black. Can you explain how you extend the foreground? Did you use some other method to extend the foreground to the entire image?

AttributeError: 'Args' object has no attribute 'seek'. in notebook

Using FBA Matting.ipynb in Google Colab, at the source code below:

class Args:
  encoder = 'resnet50_GN_WS'
  decoder = 'fba_decoder'
  weights = 'FBA.pth'
args=Args()
try:
    model = build_model(args)
except:
    !gdown  https://drive.google.com/uc?id=1T_oiKDE_biWf2kqexMEN7ObWqtXAzbB1
    model = build_model(args)

the following error occurred.

modifying input layer to accept 11 channels
Downloading...
From: https://drive.google.com/uc?id=1T_oiKDE_biWf2kqexMEN7ObWqtXAzbB1
To: /content/drive/MyDrive/FBA Matting/FBA.pth
100% 139M/139M [00:01<00:00, 81.2MB/s]
modifying input layer to accept 11 channels
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/torch/serialization.py in _check_seekable(f)
    307     try:
--> 308         f.seek(f.tell())
    309         return True

AttributeError: 'Args' object has no attribute 'seek'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
14 frames
AttributeError: 'Args' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
AttributeError: 'Args' object has no attribute 'seek'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/torch/serialization.py in raise_err_msg(patterns, e)
    302                                 + " Please pre-load the data into a buffer like io.BytesIO and"
    303                                 + " try to load from it instead.")
--> 304                 raise type(e)(msg)
    305         raise e
    306 

AttributeError: 'Args' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

Please tell me the cause of the error and how to fix it.

Issue in input trimaps

In the function read_trimap I saw these two lines

# trimap[trimap_im == 1, 1] = 1
# trimap[trimap_im == 0, 0] = 1

Wouldn't this binarize the trimap? Ideally, shouldn't the trimap be continuous between 0 and 1?

Thanks for this great work, by the way. The paper has some really promising results.
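
For context, here is a hedged sketch of what those two commented-out lines would construct, assuming trimap_im is a grayscale trimap with values in [0, 1]; this is my own reading, not the repository's documented behaviour. On this reading the two channels only mark the definite regions, and unknown pixels are simply left at zero in both channels rather than being binarised away.

import numpy as np

def make_two_channel_trimap(trimap_im):
    h, w = trimap_im.shape
    trimap = np.zeros((h, w, 2), dtype=np.float32)
    trimap[trimap_im == 1, 1] = 1   # definite foreground -> channel 1
    trimap[trimap_im == 0, 0] = 1   # definite background -> channel 0
    return trimap                   # unknown pixels stay 0 in both channels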

Random foreground composition vs solving

In paper you wrote:

To further increase the dataset diversity, we randomly composite a new foreground object with 50% probability, as in [34].

Could you please provide more details on how exactly you did this? Did you composite the foregrounds before or after solving for colors in areas where alpha is 0?
And did you do something (averaging, maybe?) to preserve correct foreground colors during composition where both foregrounds' alpha is 0?

I noticed that if I composite solved foregrounds, the "background" foreground's colors end up dominating where alpha is 0.

Here is an example of compositing two foregrounds in different orders. Look at the "orange"/green background.
Screenshot 2021-04-06 at 09:25:33

And here is what I got when solving the combined foreground one more time (though I think double solving could lead to incorrect colors in semi-transparent regions).
Screenshot 2021-04-06 at 09:28:51
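
For reference, a hedged sketch of the standard "over" operation for merging two (foreground, alpha) pairs; this is a generic formulation, not confirmed to be the authors' procedure, and the array shapes are assumptions.

import numpy as np

def composite_over(fg1, a1, fg2, a2, eps=1e-6):
    # fg*: (H, W, 3) float arrays; a*: (H, W, 1) float arrays with values in [0, 1]
    a = a1 + a2 * (1 - a1)                             # combined alpha
    fg = (a1 * fg1 + (1 - a1) * a2 * fg2) / (a + eps)  # un-premultiplied combined foreground
    return fg, a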

Gradient loss

Could you please point which gradient loss implementation you used?

P.S.
There is no detailed description or source code in the linked paper "Learning-based sampling for natural image matting".

resnet_bn backbone

Can you provide pretrained weights based on a resnet_bn backbone? The current resnet50_GN_WS backbone contains the nn.GroupNorm operator, making it difficult to deploy.

Loss function and fba_fusion()

In fba_decoder there is an fba_fusion() function.
Is the loss calculated on the output of fba_fusion(), or on the output of conv_up4() followed by the clamp and sigmoid?

Trimap Generation

The paper does not mention the method used to generate a trimap from a given alpha matte.
Can you please elaborate on how you generated the trimaps from the mattes in the Composition-1k dataset?
Or, in case a matte is not given, how can we generate a trimap?
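
For readers with the same question, here is a hedged sketch of the common erosion/dilation approach (also mentioned in a later issue); the kernel size and iteration count are arbitrary choices, and this is not necessarily how the authors generated their trimaps.

import cv2
import numpy as np

def trimap_from_mask(mask, kernel_size=10, iterations=5):
    # mask: uint8 binary mask with values in {0, 255}
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    eroded = cv2.erode(mask, kernel, iterations=iterations)
    dilated = cv2.dilate(mask, kernel, iterations=iterations)
    trimap = np.full_like(mask, 128)   # everything starts as unknown
    trimap[eroded == 255] = 255        # definite foreground
    trimap[dilated == 0] = 0           # definite background
    return trimap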

about the exclusion loss

Hi, I have read your paper. Regarding the exclusion loss, can it be written like this in PyTorch?

(grad() is assumed to be a function returning the horizontal or vertical gradient)
foreground_grad = torch.abs(grad(foreground))
background_grad = torch.abs(grad(background))
elementwise_mul = foreground_grad.mul(background_grad)
loss = torch.sum(elementwise_mul)

Looking forward to your reply.

Parameter Size

Is it correct to count your model as having 33.08585262298584M parameters (measured with a 1024x1024 input)?
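
For reference, a minimal way to count parameters in PyTorch, assuming model has been built as in the demo notebook; note that the parameter count does not depend on the input resolution.

num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params / 1e6:.2f}M parameters")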

Artifacts when applying to green screen removal

Hello!

I am trying to apply the FBA Matting method to removing green screen from images (portraits of people shot on a green screen).
Overall, it works well, but it leaves some green pixels mixed into the hair (I attach an example here).

I can see 2 possible reasons for this issue:

  1. As the result heavily depends on the quality of the trimaps, better trimaps are needed. I am currently generating trimaps using DeepLabV3. I tried adjusting some parameters (conf_threshold, dilation); it got better but is not perfect yet. Do you know of any other parameters that should be tuned, or other models I can try for trimap creation?
  2. This model was trained on Adobe Image Matting Dataset that does not contain portraits with green screen. Do you think it would make a big difference if I retrained this model on my custom dataset of green screen portraits?

Thank you in advance for any help! I would truly appreciate it!
black_happy_17-gigapixel-standard-scale-2_00x_swapped_bg

Infering on large images

As the model is trained on small patches, will it have trouble inferring on large 2000x2000 images, since it has never seen details at this larger size? Don't all matting models suffer from this, as they are all trained on 320x320 patches?

about the input channels

Thanks for your great work. It seems that you normalise the convolution weights. Why? And how does this affect performance?
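
For readers unfamiliar with the term, the question appears to refer to Weight Standardisation (the "WS" in resnet50_GN_WS). A minimal sketch of the general technique follows; it is not the repository's code.

import torch.nn as nn
import torch.nn.functional as F

class WSConv2d(nn.Conv2d):
    # Conv2d whose weights are standardised over (in_channels, kH, kW) per output filter
    def forward(self, x):
        w = self.weight
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        std = w.std(dim=(1, 2, 3), keepdim=True) + 1e-5
        return F.conv2d(x, (w - mean) / std, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)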

Generating training material - F/B reconstruction

In the paper you mention that "Levin's F,B estimation technique" was used. That paper however only gives a very brief mention of what method was used for reconstructing.

Now I'm trying to see if it would be doable to fine-tune this model to work with higher resolutions (around 1K), which means I want to generate some extra training data. Would you have any resources that could help me implement Levin's? (Or even better, how did you do it?)

About the color-spill problem and Fig.2. in your paper

Hi, I have read your paper on arXiv. In the paper, the color-spill problem is described and an example is given in Fig. 2. I tried to reproduce your result in Fig. 2 using the same foreground from the Composition-1k dataset, but the composite image is totally different from yours. Besides, I tried to use the closed-form method to estimate a new foreground, but got a different one. My new fg is black in the areas where alpha equals zero.

So how did you composite the original foreground onto the background? And how did you re-estimate the new foreground? Can you please share your code?

Thanks!

Evaluation on alphamatting.com removed

Hi,
I noticed that FBA_Matting has been removed from the alphamatting.com evaluation. Is there a specific reason for this, and are the results given in your publication still valid?

Kind regards, Julian

About the number of input channels

In paper 3.1 :

First, we increase the number of input channels from 3 to 9 to allow for
the extra trimap. We encode the trimap using Gaussian blurs of the definite
foreground and background masks at three different scales (in a similar way
to the method of [19] in interactive segmentation). This encoding differs from
existing approaches in deep image matting, as they usually encode the trimap as
a single channel with value 1 if foreground, 0.5 for unknown and 0 for background.

I know the output channels is 7 (a=1, F=3, B=3)
I know the output has 7 channels (alpha = 1, F = 3, B = 3).
But why are there 9 input channels?
In your code I saw that the inputs are the image and the trimap, which would give 4 input channels. So why does section 3.1 of the paper say the input has 9 channels?

TTA details

The effect of your proposed model and training method is really amazing. "This, combined with longer training at 45 epochs and TTA, has a bigger impact than the choices of loss functions." It seems TTA also plays an important role, so I want to know the details of TTA, such as example code or related info. Looking forward to your reply!

Pyramid Laplacian loss

Thanks for the great work! Your results are really impressive. I noticed that you suggested using laplacian loss for training. May I ask for the implementation details of it?
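
Since several people have asked, here is a hedged sketch of one common Laplacian pyramid loss formulation. It is not necessarily the authors' implementation: the average-pool downsampling is a simple stand-in for a proper Gaussian pyramid, and the 2**i level weights are an assumption.

import torch
import torch.nn.functional as F

def downsample(x):
    # stand-in for Gaussian blur + downsample
    return F.avg_pool2d(x, 2)

def laplacian_pyramid(x, levels=5):
    pyr = []
    for _ in range(levels):
        down = downsample(x)
        up = F.interpolate(down, size=x.shape[-2:], mode='bilinear', align_corners=False)
        pyr.append(x - up)   # band-pass residual at this level
        x = down
    pyr.append(x)            # low-frequency residual
    return pyr

def laplacian_loss(pred, target, levels=5):
    loss = 0.0
    pyr_pred = laplacian_pyramid(pred, levels)
    pyr_target = laplacian_pyramid(target, levels)
    for i, (p, t) in enumerate(zip(pyr_pred, pyr_target)):
        loss = loss + (2 ** i) * torch.abs(p - t).mean()
    return loss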

How to export TorchScript file

I want to load the model in C++ with libtorch, but libtorch can only load jit.save files.

sm = torch.jit.script(model)

does not work: it can't convert Union[int, Tuple[int, int]] for MaxUnpool2d.

How can I fix this issue?

training code

Good day,
has the training code been released for this model?
I would like to train it on a new dataset.
Does the model receive pairs of images and ground-truth alpha masks as training input?
Thank you for any tips.
Best

Radam not converging

Hello, thank you for your great work.
I am trying to write the training code.
I started by implementing the alpha loss, composition loss and regression loss, and I used the RAdam optimizer,
RAdam(group_weight(net, 1e-5, 1e-5, 0.0001)), as you mentioned.
To test convergence and check that everything is right in my code, I took only 2 images with hair from the Adobe dataset and started training on these 2 images.
But it seems the loss is not converging; however, when using the classical Adam optimizer I get convergence in a few iterations.
The group_weight function seems to deal with the parameters of the ResNet encoder, so I think I should use a pretrained ResNet.
Could you share the pretrained ResNet you used and, if possible, the code to load it into the matting module?
Thanks in advance

About the input channels

Hi.
As you mentioned in the paper, you changed the input channels from 3 to 9. However, when I look into your code, in models.py line 53, I found that num_channels = 3 + 6 + 2. Does this mean an input of 11 channels? What do the 3, 6 and 2 indicate?
Looking forward to your reply. Thanks.
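
For what it is worth, my reading of that line, pieced together from the trimap encoding code elsewhere in the repository, is below; it is an assumption, not confirmed by the authors.

num_image_channels  = 3   # RGB image
num_click_channels  = 6   # Gaussian-of-distance-transform trimap encoding: 2 masks x 3 scales
num_trimap_channels = 2   # two-channel binary trimap (definite background, definite foreground)
num_channels = num_image_channels + num_click_channels + num_trimap_channels  # = 11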

Questions about training and test details

Hi, thank you for sharing your code and paper!

Recently, I’m reproducing your work. I have some questions about the training and test details.

  1. In the paper, it is said that the RAdam optimizer is used with momentum 0.9 and weight decay 0.0001. But I didn't find a "momentum" parameter in the official RAdam optimizer code. Did you modify the official code or just set beta1 to 0.9?

  2. About the weight decay, there are two descriptions in the paper: a) weight decay 0.0001 in the RAdam optimizer, b) weight decay of 0.005 and 1e-5 applied to the convolutional weights and GN parameters.
    How did you set these in your code? I tried setting the weight decay in the optimizer to 0.0001 and adding an L2 loss on the conv weights and GN weights & biases with weights 0.005 and 1e-5, because I think an L2 loss here is equivalent to weight decay. Is that the same as yours?

  3. The input resolution for training patches. Training patches of size 640, 480 and 320 are randomly cropped during training. After that, did you resize them to a fixed size? If not, how do you train the model with batch size 6?

  4. The input channel for test. In section 3.6, “During inference, the full-resolution input images and trimaps are concatenated as 4-channel input and fed into the network.” But in the code, 11-channel input is used. Is it a typo?

  5. About the re-estimated foreground. I tried to re-estimate the training foreground images, but only succeeded in 411 of 431 fg. 20 of them failed to be optimized. Did you have the same problem? Besides, when calculating the alpha * foreground error during test, are the ground-truth foreground images re-estimated by closed form?

Thank you !

question about resize method

Hello, I notice that you use cv2.INTER_LANCZOS4 to resize the input trimap. Why not cv2.INTER_NEAREST or another method?

Question: Normal Training Information

Hi, recently I have been reproducing your project and debugging my reimplementation. Could I ask at which training epoch the evaluation metrics on the Composition-1k dataset reach a reasonable level? This would help me tell whether my current training is functioning normally or needs further adjustment. Thank you so much!

Sensitivity to trimap

Thanks for your great paper. I wanted to know how sensitive your network is to the trimap input. I want to attach your network after my segmentation model, using dilation/erosion to create the trimap. Can you suggest how much I should dilate/erode, and whether this is a good method for your model? Thanks!

Also will you be releasing training code any time soon?

Possible non-fatal issue in decoder

As far as I understand, here https://github.com/MarcoForte/FBA_Matting/blob/master/networks/models.py#L230
the resnet backbone returns the following feature maps: [original_image, conv_bn_relu out, layer1 out, layer2 out, layer3 out, layer4 out]

In the decoder https://github.com/MarcoForte/FBA_Matting/blob/master/networks/models.py#L350
you concatenate: (x, conv_out[-6][:, :3], img, two_chan_trimap).

But conv_out[-6][:, :3] is the same as img. Are you sure the image should be concatenated twice?

Results

In Section 4.1 you say "The laplacian loss proposed by Hou et al. [13] gives a significant reduction in errors across all metrics. We note this network, training and loss configuration is enough to achieve state of the art results, see Table 4."

I am unsure as to what differentiated your model from others and led to quite a substantial new state of the art, since by this point you have not used fg/bg prediction or data augmentation.
From what I gather it can be attributed to:

  • Batch size=1 with GN and WS
  • 4 loss functions of laplacian, L1, composite, gradient
  • 9 channel input
  • Adding pyramid pooling in the decoder
  • Clipping instead of sigmoid

Is there anything I am missing? I am only referring to Ours_alpha in Table 4. Can you comment on why you think this performed so much better than the previous models listed in Table 4?

No train code

Hi~ Thank you for the wonderful work.
I found your work on alphamatting.com, and I am very glad that you open-sourced your code so quickly.
However, I do not find the training code in the repository.
Would you also open-source the training code?

interactiveMatting by click

Do you think it is possible to achieve matting with an interactive click-based method, for example by combining "99% accurate interactive object selection with just a few clicks" with the FBA matting from your paper?

maybe a mistake about gradient exclusion loss?

In Table 1, the gradient exclusion loss is calculated on the foreground and background of the original image,
but I think it should be calculated on the predicted foreground and background. Is this wrong?

About checkpoint!

Hi !
Your work shows strong results on matting!
Could you please share your best checkpoint with us so we can test your network?

BTW,
Could you release the code of your losses?

A bug in the code, influencing the training when batch size > 1

https://github.com/MarcoForte/FBA-Matting/blob/76751dd752d4d3b40bf58c64185fc77c0195cbeb/networks/models.py#L263

Hi, I think you forgot to set the "keepdim" parameter to True in the "torch.sum()" operations.
The correct one should be
alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (torch.sum((F - B) * (F - B), 1, keepdim=True) + la)

Without keeping dim, the output alpha size would become [batch, batch, height, width], due to the wrong broadcast, while it is supposed to be [batch, 1, height, width].

Apparently, when batch size > 1, the size of the alpha prediction is not correct. Thus the loss calculation is negatively influenced, because the alpha, fg, and bg predictions are concatenated together in the last step of the forward process.
https://github.com/MarcoForte/FBA-Matting/blob/76751dd752d4d3b40bf58c64185fc77c0195cbeb/networks/models.py#L361

For example, we set the batch size 4. With this bug, we get an alpha prediction of size [4, 4, height, width]. The fg and bg prediction are both of size [4, 3, height, width]. After concatenation, we get an output with size [4, 10, height, width], instead of the supposed [4, 7, height, width].

When calculating loss, we would slice the output by indices to extract the alpha, fg, and bg predictions.

alpha_pred = output[:, 0:1, :, :]
fg_pred = output[:, 1:4, :, :]
bg_pred = output[:, 4:7, :, :]

Here fg_pred is actually part of alpha_pred, because the first 4 channels are the alpha prediction, instead of only 1 channel. The same applies to bg_pred. The loss for the fg and bg predictions is meaningless here.

My experimental results for models using batch size > 1 proved this bug. The errors are extremely high. I'm wondering if this bug has a negative influence on your experiments.
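
For completeness, a standalone demonstration of the shape problem with random tensors (the shapes are chosen only to make the unintended broadcast visible):

import torch

batch, h, w = 4, 8, 8
alpha = torch.rand(batch, 1, h, w)
img = torch.rand(batch, 3, h, w)
fg = torch.rand(batch, 3, h, w)
bg = torch.rand(batch, 3, h, w)
la = 0.1

bad = (alpha * la + torch.sum((img - bg) * (fg - bg), 1)) / \
      (torch.sum((fg - bg) * (fg - bg), 1) + la)
fixed = (alpha * la + torch.sum((img - bg) * (fg - bg), 1, keepdim=True)) / \
        (torch.sum((fg - bg) * (fg - bg), 1, keepdim=True) + la)

print(bad.shape)    # torch.Size([4, 4, 8, 8]) -- unintended broadcast
print(fixed.shape)  # torch.Size([4, 1, 8, 8])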

Matting on Mobile Devices

Thanks, and congratulations on SOTA! Wow... You guys are on top of the alphamatting charts.

Sorry for the newbie questions

  1. Can this algorithm be used on mobile devices for image matting? Or do you feel it is too heavy for that task? Can you point me in the right direction?

  2. Can automated trimap generation libraries be used? Or does the quality of the trimap matter too much for getting good results? I am wondering how to automate this without any manual intervention.
    e.g. https://github.com/lnugraha/trimap_generator
    or
    opencv erode and dilate methods

Mistake in the code?

Hello!

Thank you for releasing your implementation. However, it looks like fba_fusion doesn't do what you intend. Or am I missing something?

Indeed, before calling the fba_fusion function, you've defined alpha, F, B as follows:

        alpha = torch.clamp(output[:, 0][:, None], 0, 1)

        F = torch.sigmoid(output[:, 1:4])
        B = torch.sigmoid(output[:, 4:7])

        alpha, F, B = fba_fusion(alpha, img, F, B)

So, you are broadcasting alpha so that it is of size (B, 1, H, W).
Moreover, F and B are each of size (B, 3, H, W).

Now, if we look at how you compute alpha in the fba_fusion module, we have:

def fba_fusion(alpha, img, F, B):
    F = ((alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B))
    B = ((1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F)

    F = torch.clamp(F, 0, 1)
    B = torch.clamp(B, 0, 1)
    la = 0.1
    alpha = (alpha * la + torch.sum((img - B) * (F - B), 1)) / (torch.sum((F - B) * (F - B), 1) + la)
    alpha = torch.clamp(alpha, 0, 1)
    return alpha, F, B

So, we have (by using the broadcasting rules)

size = ((B, 1, H, W) * scalar + sum((B, 3, H, W), dim=1)) / (sum((B, 3, H, W), dim=1) + scalar)
     = ((B, 1, H, W) + (B, H, W)) / (B, H, W)
     = ((B, 1, H, W) + (1, B, H, W)) / (1, B, H, W)
     = (B, B, H, W) / (B, B, H, W)
     = (B, B, H, W)

So, in the end, alpha is of size (B, B, H, W)

Weren't you supposed to add keepdim=True in torch.sum?
Did your final .pth model use this flawed operation?

I hope you can reply to my enquiries.
Thank you

Matting Within closed boundaries

I have implemented and tested your model on at least 300 images. My observation is that it works great for matting images that have open curves.
Issue No. 1 -
man_44Image_12XF0T1X

But in the case when curves/hair shape are closed, results have background color within closed curves,
1ayo-ogunseinde-RrD8ypt8cjY-unsplash1Image_SV6VEYTT

The results are still pretty impressive, but this issue gets more and more noticeable for more complex images.
Thank you, your work is great and very helpful.

Distance Transform

Hey, I do not understand how the distance map is being used here and what the clicks variable is supposed to represent:

import cv2
import numpy as np

def dt(a):
    # Euclidean distance transform of a binary mask (distance to the nearest zero pixel)
    return cv2.distanceTransform((a * 255).astype(np.uint8), cv2.DIST_L2, 0)

def trimap_transform(trimap):
    h, w = trimap.shape[0], trimap.shape[1]

    clicks = np.zeros((h, w, 6))
    for k in range(2):  # k = 0: definite background mask, k = 1: definite foreground mask
        if(np.count_nonzero(trimap[:, :, k]) > 0):
            dt_mask = -dt(1 - trimap[:, :, k])**2
            L = 320
            # Gaussians of the squared distance at three scales (sigma = 0.02L, 0.08L, 0.16L)
            clicks[:, :, 3*k] = np.exp(dt_mask / (2 * ((0.02 * L)**2)))
            clicks[:, :, 3*k+1] = np.exp(dt_mask / (2 * ((0.08 * L)**2)))
            clicks[:, :, 3*k+2] = np.exp(dt_mask / (2 * ((0.16 * L)**2)))

    return clicks

Can you please explain this?
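
For what it is worth, a hedged usage sketch of the function above; the trimap layout (channel 0 = definite background, channel 1 = definite foreground) and the strip placement are assumptions made only for illustration.

import numpy as np

trimap = np.zeros((480, 640, 2), dtype=np.float32)
trimap[:100, :, 1] = 1    # top strip marked as definite foreground
trimap[380:, :, 0] = 1    # bottom strip marked as definite background

clicks = trimap_transform(trimap)
print(clicks.shape)       # (480, 640, 6): three Gaussian-of-distance maps per mask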
