Comments (6)

karpathy commented on September 25, 2024

Yes, this would be relatively easy. You have to annotate the nngraph node (see the docs) of the nn.Linear layer that holds the parameters, so that you can look it up in the nngraph, and then modify its .bias field. You'd have to be careful because I compute all four vectors i, f, o, g in one go as a single vector, so you'd want to raise only the portion of the bias vector that belongs to the forget gate.

I already did this in one fork of char-rnn but didn't find noticeable improvements in training time. But maybe I did it wrong ;) Fun exercise to try for yourself.

from char-rnn.

rfru commented on September 25, 2024

Cool! Thanks for the info. Is that fork available somewhere? Would be great to take a look at it to start.

faradox commented on September 25, 2024

In my humble experience there are noticeable improvements in many cases. In fact, I couldn't get a single LSTM network with more than 2 layers to learn anything unless the forget gate biases were initialized to 1 (but that wasn't with the char-rnn code, so maybe I did it wrong, too). To implement it here, I did:

In model/LSTM.lua:

local in_gate = nn.Sigmoid()(n1)
-- annotate the forget gate so we can manipulate it directly later
local forget_gate = nn.Sigmoid()(n2):annotate{
    name = 'forget', description = 'Forget gate',
}
local out_gate = nn.Sigmoid()(n3)

And in train.lua:

-- initialization
if do_random_init then
    params:uniform(-0.08, 0.08) -- small numbers uniform
    for _,node in ipairs(protos.rnn.forwardnodes) do
        if node:graphNodeName() == "forget" then
            node.bias:fill(1) -- initialize forget gates to 1
        end
    end
end

ffmpbgrnn commented on September 25, 2024

Hi @faradox, I think you should annotate the Linear layer instead, since nn.Sigmoid has no bias of its own. Something like:

local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x)
i2h:annotate{name='i2h_'..L}

-- and then, in train.lua:
for layer_idx = 1, opt.num_layers do
    for _,node in ipairs(protos.rnn.forwardnodes) do
        if node.data.annotations.name == "i2h_"..layer_idx then
            -- gates are packed in the order i, f, o, g, so f is the second block
            node.data.module.bias[{{opt.rnn_size+1, 2*opt.rnn_size}}]:fill(1)
        end
    end
end

Correct me if I am wrong.
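
The index arithmetic in that slice can be checked without Torch. Below is a minimal sketch in plain Lua, assuming the gate order i, f, o, g described earlier in the thread. The helpers gate_range, make_bias and fill_forget_bias are hypothetical names made up for illustration, and a plain Lua table stands in for the nn.Linear bias tensor.

```lua
-- Hypothetical helper: 1-based, inclusive index range of one gate inside
-- a packed bias of length 4 * rnn_size, laid out as i, f, o, g.
local GATE_ORDER = { i = 1, f = 2, o = 3, g = 4 }

local function gate_range(gate, rnn_size)
    local k = assert(GATE_ORDER[gate], "unknown gate: " .. tostring(gate))
    return (k - 1) * rnn_size + 1, k * rnn_size
end

-- Plain Lua table standing in for the bias tensor (assumption: no Torch here).
local function make_bias(rnn_size, value)
    local bias = {}
    for idx = 1, 4 * rnn_size do bias[idx] = value end
    return bias
end

-- Set only the forget-gate block of the packed bias to `value`.
local function fill_forget_bias(bias, rnn_size, value)
    local first, last = gate_range("f", rnn_size)
    for idx = first, last do bias[idx] = value end
    return bias
end

local rnn_size = 3
local bias = fill_forget_bias(make_bias(rnn_size, 0), rnn_size, 1)
print(table.concat(bias, " "))  -- 0 0 0 1 1 1 0 0 0 0 0 0
```

With rnn_size = 128, gate_range("f", 128) gives 129 and 256, which is the same range as {{opt.rnn_size+1, 2*opt.rnn_size}} in the Torch snippet.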

karpathy commented on September 25, 2024

OK, I ran a small experiment and I'm now seeing improvements from initializing the forget gate biases to 1.0. I'm adding this feature to char-rnn, since there is enough evidence that it probably helps and usually doesn't hurt.
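
A rough way to see why a bias of 1 might help, written with the standard LSTM cell update (the notation is an assumption and is not taken from the char-rnn code): with weights drawn from uniform(-0.08, 0.08), the forget gate's pre-activation starts out close to its bias $b_f$, so the gate is initially about $\sigma(b_f)$.

```latex
\begin{align*}
c_t &= f_t \odot c_{t-1} + i_t \odot g_t, \qquad f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
b_f = 0 &:\quad f_t \approx \sigma(0) = 0.5, \qquad 0.5^{10} \approx 0.001 \\
b_f = 1 &:\quad f_t \approx \sigma(1) \approx 0.73, \qquad 0.73^{10} \approx 0.04
\end{align*}
```

So after ten steps, the share of an old cell value that survives through the forget path is roughly forty times larger at initialization. This is only a back-of-the-envelope estimate that ignores the input-dependent terms.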

karpathy commented on September 25, 2024

(And thank you @rfru, @faradox and @ffmpbgrnn for the discussion surrounding this.)
