
tensorflow-ghm-loss


This is a simple TensorFlow implementation of the loss weights from "Gradient Harmonized Single-stage Detector", published as an oral paper at AAAI 2019.

Original paper (arXiv): https://arxiv.org/abs/1811.05181
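For reference, the quantities the GHM classification weights are built from can be summarized as below. This is our reading of the paper, not a transcription of tf_ghm_loss.py. Here p_i is the prediction (`predict`), p_i^* the target (`target`), M the number of bins (`bins`), α the moving-average parameter (`alpha`), and N the number of valid samples (assumed to be those with `valid_mask` = 1). The last region is assumed to also include g = 1.

```latex
\begin{aligned}
g_i &= \lvert p_i - p_i^{*} \rvert, \qquad \varepsilon = \tfrac{1}{M}, \qquad r_j = [(j-1)\varepsilon,\ j\varepsilon), \quad j = 1, \dots, M \\
\mathrm{ind}(g) &= j \ \text{ such that } g \in r_j, \qquad R_j = \#\{\, i : g_i \in r_j \,\} \\
\widehat{GD}(g) &= \frac{R_{\mathrm{ind}(g)}}{\varepsilon} = R_{\mathrm{ind}(g)} \cdot M, \qquad \hat{\beta}_i = \frac{N}{\widehat{GD}(g_i)} \\
L_{\text{GHM-C}} &= \frac{1}{N} \sum_{i=1}^{N} \hat{\beta}_i \, L_{\text{CE}}(p_i, p_i^{*}) \\
S_j^{(k)} &= \alpha \, S_j^{(k-1)} + (1 - \alpha) \, R_j^{(k)}, \qquad \widehat{GD}(g) = S_{\mathrm{ind}(g)} \cdot M \quad \text{(moving-average variant, iteration } k\text{)}
\end{aligned}
```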


Proposed GHM Function

You can get the GHM weights from the get_ghm_weight() function in tf_ghm_loss.py, and then use these weights to reweight your loss term as described in the paper. A brief summary of the function is given below:

import tensorflow as tf

def get_ghm_weight(predict, target, valid_mask, bins=10, alpha=0.75,
                   dtype=tf.float32, name='GHM_weight'):
    """Get Gradient Harmonized (GHM) weights.

    This is an implementation of the GHM weights described
    in https://arxiv.org/abs/1811.05181.

    Args:
        predict:
            The prediction of the classification branch, in [0, 1].
            -shape [batch_num, category_num].
        target:
            The target of the classification branch, in {0, 1}.
            -shape [batch_num, category_num].
        valid_mask:
            The valid mask, in {0, 1}; 0 where the sample is ignored.
            -shape [batch_num, category_num].
        bins:
            The number of bins (unit regions) used to approximate
            the gradient density.
        alpha:
            The moving average parameter.
        dtype:
            The dtype for all operations.

    Returns:
        weights:
            The beta value of each sample, as described in the paper.
    """
    ...  # Body omitted here; see tf_ghm_loss.py.

Toy Demo ☕

The demo uses the following inputs:

  • prediction: [1., 0., 0.5, 0.]
  • target: [1., 0., 0., 1.]

You can find more details in tf_ghm_loss.py.

Run

python tf_ghm_loss.py

Output

update 1 times:  [[0.5        0.5        0.72727275 0.72727275]]
update 100 times:  [array([[0.20000002, 0.20000002, 0.40000004, 0.40000004]], dtype=float32)]
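The `alpha` argument smooths the per-bin counts across updates, which is why the two output lines differ. Below is a minimal pure-Python sketch of that moving average, S_j ← α·S_j + (1 − α)·R_j with weights N / (S_ind(g)·M). The class name `MovingGHMWeights` is hypothetical, and so is the initial value of S_j: we assume a uniform prior of N / bins per bin. Under that assumption the numbers happen to agree with the output above (0.5 and ≈0.727 after one update, ≈0.2 and ≈0.4 after 100), but this is only a numerical observation, not confirmation of how tf_ghm_loss.py is written.

```python
from typing import List, Optional, Sequence


class MovingGHMWeights:
    """Hypothetical helper: GHM weights with a per-bin exponential moving average.

    ASSUMPTION: the moving averages S_j start from a uniform value N / bins.
    This initialisation is illustrative, not taken from tf_ghm_loss.py.
    """

    def __init__(self, bins: int = 10, alpha: float = 0.75) -> None:
        self.bins = bins
        self.alpha = alpha
        self.acc_sum: Optional[List[float]] = None  # S_j, created on first update

    def update(self, predict: Sequence[float], target: Sequence[float],
               valid_mask: Sequence[float]) -> List[float]:
        # Gradient norm g = |p - p*| and its unit-region (bin) index.
        g = [abs(p - t) for p, t in zip(predict, target)]
        valid = [m > 0 for m in valid_mask]
        idx = [min(int(x * self.bins), self.bins - 1) for x in g]
        # R_j: valid samples per bin in this batch.
        counts = [0] * self.bins
        for j, v in zip(idx, valid):
            if v:
                counts[j] += 1
        n_valid = max(sum(valid), 1)
        if self.acc_sum is None:
            # Assumed uniform prior: every bin starts with N / bins samples.
            self.acc_sum = [n_valid / self.bins] * self.bins
        # S_j <- alpha * S_j + (1 - alpha) * R_j
        self.acc_sum = [self.alpha * s + (1.0 - self.alpha) * r
                        for s, r in zip(self.acc_sum, counts)]
        # beta_i = N / (S_ind(g_i) * bins); ignored samples get weight 0.
        return [n_valid / (self.acc_sum[j] * self.bins) if v else 0.0
                for j, v in zip(idx, valid)]


if __name__ == "__main__":
    ghm = MovingGHMWeights(bins=10, alpha=0.75)
    p, t, m = [1.0, 0.0, 0.5, 0.0], [1.0, 0.0, 0.0, 1.0], [1, 1, 1, 1]
    print("update 1 times:", ghm.update(p, t, m))  # ≈ [0.5, 0.5, 0.727, 0.727]
    for _ in range(99):
        weights = ghm.update(p, t, m)
    print("update 100 times:", weights)  # ≈ [0.2, 0.2, 0.4, 0.4]
```

Because the same batch is fed repeatedly, each S_j converges to that batch's count R_j, so the weights approach the static values N / (R_j · M). With real training data the average instead tracks the gradient distribution over recent batches.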

Relevant materials 🍺


Contributors

peteryux


tensorflow-ghm-loss's Issues

Multi-class classification

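Hello, may I ask: suppose the classification branch of my model outputs data of shape (B, N, class_num), with B = batch_size = 1.
How should I modify your code? I'm a beginner and my programming is fairly weak, so I hope you can reply. If it's convenient, could you add me on WeChat/QQ: 1064435762?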
