
Comments (11)

RayXu14 avatar RayXu14 commented on May 17, 2024 4

My code is under adjustment, but I can share a previous version of beam search.
This version is slow and simple, but it still works.

def beam_search(x, sess, g, batch_size=hp.batch_size):
    # Tile each source sentence beam_size times: (batch * beam, max_len).
    inputs = np.reshape(np.transpose(np.array([x] * hp.beam_size), (1, 0, 2)),
                        (hp.beam_size * batch_size, hp.max_len))
    preds = np.zeros((batch_size, hp.beam_size, hp.y_max_len), np.int32)
    prob_product = np.zeros((batch_size, hp.beam_size))  # summed log-probs per beam
    stc_length = np.ones((batch_size, hp.beam_size))     # lengths for normalization

    for j in range(hp.y_max_len):
        # g.preds is tf.nn.top_k(probs, beam_size), so this returns (values, indices).
        _probs, _preds = sess.run(
            g.preds, {g.x: inputs, g.y: np.reshape(preds, (hp.beam_size * batch_size, hp.y_max_len))})
        j_probs = np.reshape(_probs[:, j, :], (batch_size, hp.beam_size, hp.beam_size))
        j_preds = np.reshape(_preds[:, j, :], (batch_size, hp.beam_size, hp.beam_size))
        if j == 0:
            # All beams share the same empty prefix, so take the top k of the first row.
            preds[:, :, j] = j_preds[:, 0, :]
            prob_product += np.log(j_probs[:, 0, :])
        else:
            # 1 while a beam is still producing real words; ids <= end_id are
            # special tokens, so a finished beam stops accumulating score/length.
            add_or_not = (j_preds > hp.end_id).astype(np.int32)
            tmp_stc_length = np.expand_dims(stc_length, axis=-1) + add_or_not
            tmp_stc_length = np.reshape(tmp_stc_length, (batch_size, hp.beam_size * hp.beam_size))

            # Score all beam_size * beam_size candidates by length-normalized log-prob.
            this_probs = np.expand_dims(prob_product, axis=-1) + np.log(j_probs) * add_or_not
            this_probs = np.reshape(this_probs, (batch_size, hp.beam_size * hp.beam_size))
            selected = np.argsort(this_probs / tmp_stc_length, axis=1)[:, -hp.beam_size:]

            tmp_preds = np.concatenate([np.expand_dims(preds, axis=2)] * hp.beam_size, axis=2)
            tmp_preds[:, :, :, j] = j_preds[:, :, :]
            tmp_preds = np.reshape(tmp_preds, (batch_size, hp.beam_size * hp.beam_size, hp.y_max_len))

            # Keep only the selected beams.
            for batch_idx in range(batch_size):
                prob_product[batch_idx] = this_probs[batch_idx, selected[batch_idx]]
                preds[batch_idx] = tmp_preds[batch_idx, selected[batch_idx]]
                stc_length[batch_idx] = tmp_stc_length[batch_idx, selected[batch_idx]]

    # Pick the best beam per sentence by length-normalized score.
    final_selected = np.argmax(prob_product / stc_length, axis=1)
    final_preds = []
    for batch_idx in range(batch_size):
        final_preds.append(preds[batch_idx, final_selected[batch_idx]])

    return final_preds

I treat the y length differently from x, so there is a separate y_max_len in the hyperparameters.
x is the input batch.
g is the model object, of type Graph in train.py.
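For readers who want the idea without the batched reshapes, here is a toy sketch of the same length-normalized beam search over a single sequence. All names and the toy model are illustrative only, not part of the repo:

```python
import numpy as np

def toy_beam_search(log_prob_fn, beam_size, max_len):
    """Minimal length-normalized beam search for one sequence.

    log_prob_fn(prefix) -> 1-D array of log-probabilities over the vocab.
    Keeps `beam_size` prefixes, scored by mean log-probability,
    mirroring the `this_probs / tmp_stc_length` ranking above.
    """
    beams = [([], 0.0)]  # (token list, summed log-prob)
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            logp = log_prob_fn(tokens)
            # Expand each beam only by its top `beam_size` tokens.
            for tok in np.argsort(logp)[-beam_size:]:
                candidates.append((tokens + [int(tok)], score + logp[tok]))
        # Keep the best `beam_size` candidates by length-normalized score.
        candidates.sort(key=lambda c: c[1] / len(c[0]), reverse=True)
        beams = candidates[:beam_size]
    return beams[0][0]

# A deterministic toy model: always prefers token (last_token + 1) % vocab.
def toy_model(tokens, vocab_size=5):
    logits = np.full(vocab_size, -10.0)
    nxt = (tokens[-1] + 1) % vocab_size if tokens else 0
    logits[nxt] = 0.0
    return logits - np.log(np.sum(np.exp(logits)))  # log-softmax

print(toy_beam_search(toy_model, beam_size=2, max_len=4))  # [0, 1, 2, 3]
```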

from transformer.

RayXu14 avatar RayXu14 commented on May 17, 2024 1

I have implemented beam search in my fork. It operates similarly to this transformer, with some modifications.


hugddygff avatar hugddygff commented on May 17, 2024

Can you share the code with beam search? I can't find the fork on your homepage.


hugddygff avatar hugddygff commented on May 17, 2024

Thanks very much. If I have any ideas, I will discuss them with you directly.


chenvega avatar chenvega commented on May 17, 2024

@RayXu14
There is a problem in your code. Do you modify preds in the training Graph?
_probs, _preds = sess.run(g.preds, {g.x: inputs, g.y: np.reshape(preds, (hp.beam_size * batch_size, hp.y_max_len))})


RayXu14 avatar RayXu14 commented on May 17, 2024

Yes:

            if is_training or hp.beam_size == 1:
                self.preds = tf.to_int32(tf.argmax(self.logits, axis=-1))
                self.istarget = tf.to_float(tf.not_equal(self.y, 0))
                self.acc = tf.reduce_sum(tf.to_float(tf.equal(self.preds, self.y)) * self.istarget) / (tf.reduce_sum(self.istarget))

            else:
                print('[WARNING] beam search enabled')
                # top_k needs at least beam_size entries in the vocab dimension.
                assert self.logits.get_shape().as_list()[-1] >= hp.beam_size
                self.probs = tf.nn.softmax(self.logits)
                # tf.nn.top_k returns (values, indices), which beam_search
                # unpacks as (_probs, _preds).
                self.preds = tf.nn.top_k(self.probs, hp.beam_size)
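The pair returned by `tf.nn.top_k` is what lets `sess.run(g.preds, …)` unpack into `_probs, _preds` in the beam search loop. A minimal NumPy sketch of that (values, indices) behavior (a hypothetical helper, not part of the repo):

```python
import numpy as np

def top_k(probs, k):
    """NumPy sketch of tf.nn.top_k over the last axis: returns
    (values, indices), both sorted by descending probability."""
    idx = np.argsort(probs, axis=-1)[..., ::-1][..., :k]
    vals = np.take_along_axis(probs, idx, axis=-1)
    return vals, idx

probs = np.array([[0.1, 0.6, 0.3], [0.5, 0.2, 0.3]])
vals, idx = top_k(probs, 2)
print(idx.tolist())   # [[1, 2], [0, 2]]
print(vals.tolist())  # [[0.6, 0.3], [0.5, 0.3]]
```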


xinyx avatar xinyx commented on May 17, 2024

Could you explain what hp.end_id means?

add_or_not = np.asarray(np.logical_or.reduce([j_preds > hp.end_id]), dtype=np.int)


gitfourteen avatar gitfourteen commented on May 17, 2024

@xinyx I think it is index 3, the id of '</S>' in the vocabulary.
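Assuming the usual vocabulary layout in this codebase (low ids reserved for special tokens, e.g. 0 = `<PAD>`, 1 = `<UNK>`, 2 = `<S>`, 3 = `</S>`), the `j_preds > hp.end_id` comparison marks which predictions are real words, so a beam that has emitted `</S>` stops accumulating log-probability and length. A small illustration with hypothetical values:

```python
import numpy as np

# Hypothetical special-token layout: 0=<PAD>, 1=<UNK>, 2=<S>, 3=</S>.
end_id = 3

j_preds = np.array([[7, 3, 0], [12, 5, 3]])
# 1 where the predicted token is a real word, 0 for </S> or padding:
add_or_not = (j_preds > end_id).astype(np.int32)
print(add_or_not.tolist())  # [[1, 0, 0], [1, 1, 0]]
```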


newdaylqt avatar newdaylqt commented on May 17, 2024

How do you modify is_training after building the graph? In practice I find it necessary to feed is_training into the graph and use tf.cond instead of a Python if-else, but that gets complicated.



lidongxing avatar lidongxing commented on May 17, 2024

Beam search is used at inference time, so why edit train.py rather than test.py?


ZhichaoOuyang avatar ZhichaoOuyang commented on May 17, 2024

Is there a more efficient beam search implementation that could be shared?

