

Wunaiq commented on August 20, 2024

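I found the same two problems as well. Problem 1 in particular causes inference results to differ when the batch_size differs.
I'm also curious why it was done this way. What would be the difference between simply making w and b learnable parameters here and the current approach?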
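To make the contrast in that question concrete, here is a minimal NumPy sketch of the two options. Everything in it is hypothetical (function names, random data, NumPy in place of PyTorch) except the generated weight and bias shapes, [batch_size, 100] and [batch_size, 1], which are the ones reported for `paras` in this issue. Option (a) keeps w and b as ordinary learnable parameters shared by every image; option (b) receives w and b per image from another branch and applies each row only to its own image.

```python
import numpy as np

rng = np.random.default_rng(0)
B, D = 3, 100
q = rng.standard_normal((B, D))  # features entering the last layer

# (a) Hypothetical static head: one learnable w, b shared by every image,
#     i.e. what nn.Linear(100, 1) would hold as its parameters.
w_static = rng.standard_normal((1, D))
b_static = rng.standard_normal(1)

def static_head(q):
    return q @ w_static.T + b_static  # [B, 1]

# (b) Hypothetical dynamic head: w, b are produced per image by another
#     branch (shapes [B, 100] and [B, 1]) and each image is multiplied
#     only by its own weight row.
w_dyn = rng.standard_normal((B, D))
b_dyn = rng.standard_normal((B, 1))

def per_sample_head(q, w, b):
    return np.einsum("bd,bd->b", q, w)[:, None] + b  # [B, 1]

# Both give one score per image, and image 0's score does not depend on
# which other images are in the batch.
full_static = static_head(q)
full_dyn = per_sample_head(q, w_dyn, b_dyn)
print(full_static.shape, full_dyn.shape)  # (3, 1) (3, 1)
print(np.allclose(full_static[0], static_head(q[:1])[0]))  # True
print(np.allclose(full_dyn[0], per_sample_head(q[:1], w_dyn[:1], b_dyn[:1])[0]))  # True
```

The practical difference is that in (a) every image is scored by the same function once training ends, while in (b) the last layer itself depends on the input image, which is presumably why `paras` is passed into `forward` rather than stored as parameters. In either case, applying each generated weight row only to its own sample keeps inference independent of batch_size.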

from tanet.

woshidandan commented on August 20, 2024

Hello, there are two parts of the model code I don't quite understand. Could you help me take a look?

  1. The forward function in class TargetNet

    def forward(self, x, paras):

        q = self.fc1(x)
        # print(q.shape)
        q = self.bn1(q)
        q = self.relu1(q)
        q = self.drop1(q)

        self.lin = nn.Sequential(TargetFC(paras['res_last_out_w'], paras['res_last_out_b']))
        q = self.lin(q)
        q = self.softmax(q)

        return q

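Here res_last_out_w has shape [batch_size, 100] and res_last_out_b has shape [batch_size, 1], while the input tensor to self.lin has shape [batch_size, 100]. The output tensor of self.lin therefore has shape [batch_size, batch_size], a shape that depends on batch_size. If batch_size is 1, the output of this function always has shape [1, 1], and since a softmax over a single element is always 1, its value is fixed at 1. This amounts to the theme-network branch always outputting a constant 1. Isn't this a problem?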
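A minimal NumPy sketch of the behaviour described above, under one stated assumption: that TargetFC acts like a plain linear layer whose weight matrix is the entire [batch_size, 100] block of generated weights. That assumption is inferred from the [batch_size, batch_size] output shape, not taken from the TANet source; the name `batch_mixing_linear` and the random data are hypothetical. The sketch checks the output shape, the constant 1 for batch_size = 1, and that the first sample's score changes once other samples share its batch.

```python
import numpy as np

def softmax(z, axis=1):
    # numerically stable softmax along `axis`, like nn.Softmax(dim=1)
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def batch_mixing_linear(q, w, b):
    # ASSUMPTION: behaves like a plain linear layer whose weight matrix is
    # the whole [B, 100] batch of generated weights, so each sample is
    # projected onto every sample's weight row: [B, 100] @ [100, B] -> [B, B]
    return q @ w.T + b.reshape(1, -1)  # assumed: bias added per output column

rng = np.random.default_rng(0)
B, D = 4, 100
q = rng.standard_normal((B, D))          # stands in for the output of drop1
w = 0.01 * rng.standard_normal((B, D))   # stands in for paras['res_last_out_w'] (small to keep logits moderate)
b = rng.standard_normal((B, 1))          # stands in for paras['res_last_out_b']

out = softmax(batch_mixing_linear(q, w, b))
print(out.shape)  # (4, 4): the output shape follows the batch size

single = softmax(batch_mixing_linear(q[:1], w[:1], b[:1]))
print(single)  # [[1.]]: a one-element softmax is always 1

# The same first sample gets a different score once other samples share its batch
print(out[0, 0] < 1.0)  # True
```

Under this assumption, each row of `out` is a distribution over the samples in the batch, which is consistent with the observation above that inference results change with batch_size.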

  2. In the Attention function

def Attention(x):
    batch_size, in_channels, h, w = x.size()
    quary = x.view(batch_size, in_channels, -1)
    key = quary
    quary = quary.permute(0, 2, 1)

    sim_map = torch.matmul(quary, key)

    ql2 = torch.norm(quary, dim=2, keepdim=True)
    kl2 = torch.norm(key, dim=1, keepdim=True)
    sim_map = torch.div(sim_map, torch.matmul(ql2, kl2).clamp(min=1e-8))

    return sim_map

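The implementation here seems to differ from what the paper describes. What this code does is divide the similarity map of the values by the similarity map of their norms (i.e., it computes a cosine-similarity map between spatial positions), rather than the conventional attention with V removed, as stated in the paper.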
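The following NumPy transcription of Attention is a sketch for checking that reading: `attention_np` and `cosine_map` are hypothetical names, and NumPy stands in for PyTorch. It confirms that dividing the raw dot-product map by the product of per-position norms gives exactly the cosine similarity between every pair of spatial positions.

```python
import numpy as np

def attention_np(x):
    # NumPy transcription of the Attention function (x: [B, C, H, W])
    batch_size, in_channels, h, w = x.shape
    quary = x.reshape(batch_size, in_channels, -1)        # [B, C, HW]
    key = quary
    quary = quary.transpose(0, 2, 1)                      # [B, HW, C]
    sim_map = quary @ key                                 # [B, HW, HW] raw dot products
    ql2 = np.linalg.norm(quary, axis=2, keepdims=True)    # [B, HW, 1]
    kl2 = np.linalg.norm(key, axis=1, keepdims=True)      # [B, 1, HW]
    return sim_map / np.maximum(ql2 @ kl2, 1e-8)          # divide by norm products

def cosine_map(x):
    # Same quantity written as cosine similarity between spatial positions:
    # L2-normalise each position's C-dim feature vector, then take dot products.
    b, c = x.shape[:2]
    f = x.reshape(b, c, -1).transpose(0, 2, 1)            # [B, HW, C]
    f = f / np.maximum(np.linalg.norm(f, axis=2, keepdims=True), 1e-8)
    return f @ f.transpose(0, 2, 1)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 4, 4))
s = attention_np(x)
print(s.shape)                                             # (2, 16, 16)
print(np.allclose(s, cosine_map(x)))                       # True
print(np.allclose(np.diagonal(s, axis1=1, axis2=2), 1.0))  # True: each position matches itself
```

Unlike softmax attention, every entry lies in [-1, 1], rows do not sum to 1, and no query, key, or value projection is learned.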

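Sorry, I have been busy with other work recently and had not logged in to this account, so I missed this issue.
On the first question: there is no special reason. It was mainly a trade-off we made with deploying the model on mobile devices in mind; we wanted to bring its dimensionality and computational cost down quickly, so we set the shape to 1. On the second question: this is a simplified form of attention, also chosen for mobile optimization. Strictly speaking, it only computes similarity; because IJCAI's page limit is very tight and left no room to explain this, the paper framed it in terms of self-attention.
In addition, we ran experiments and found that changing the shape or using native self-attention does not affect performance much; you are welcome to try it yourself.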
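For anyone who wants to try the native self-attention variant mentioned in the reply, here is one common form, scaled dot-product self-attention over the H×W positions, sketched in NumPy. It is not code from the TANet repository; the projection matrices `wq`, `wk`, `wv` are hypothetical random stand-ins for learned weights.

```python
import numpy as np

def softmax(z, axis=-1):
    # numerically stable softmax along `axis`
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, wq, wk, wv):
    # Standard scaled dot-product self-attention over spatial positions.
    # x: [B, C, H, W]; wq, wk, wv: [C, C] projection matrices (hypothetical)
    b, c, h, w = x.shape
    f = x.reshape(b, c, h * w).transpose(0, 2, 1)           # [B, HW, C] tokens
    q, k, v = f @ wq, f @ wk, f @ wv                         # [B, HW, C] each
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(c))    # [B, HW, HW], rows sum to 1
    out = attn @ v                                           # [B, HW, C]
    return out.transpose(0, 2, 1).reshape(b, c, h, w), attn

rng = np.random.default_rng(0)
B, C, H, W = 2, 8, 4, 4
x = rng.standard_normal((B, C, H, W))
wq, wk, wv = (rng.standard_normal((C, C)) / np.sqrt(C) for _ in range(3))
y, attn = self_attention(x, wq, wk, wv)
print(y.shape, attn.shape)                  # (2, 8, 4, 4) (2, 16, 16)
print(np.allclose(attn.sum(axis=-1), 1.0))  # True
```

Compared with the cosine-similarity map in the repository, this version adds three C×C projections, a softmax, and a second [HW, HW] × [HW, C] product, which is roughly the extra cost the simplified form avoids on mobile.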
