Comments (11)
My code is under adjustment, but I can give you the previous version of my beam search.
This version is slow and simple, but it still works.
def beam_search(x, sess, g, batch_size=hp.batch_size):
    # Tile each source sentence beam_size times:
    # (batch, beam, max_len) -> (beam * batch, max_len)
    inputs = np.reshape(np.transpose(np.array([x] * hp.beam_size), (1, 0, 2)),
                        (hp.beam_size * batch_size, hp.max_len))
    preds = np.zeros((batch_size, hp.beam_size, hp.y_max_len), np.int32)
    prob_product = np.zeros((batch_size, hp.beam_size))  # accumulated log-probs per beam
    stc_length = np.ones((batch_size, hp.beam_size))     # effective length per beam
    for j in range(hp.y_max_len):
        _probs, _preds = sess.run(
            g.preds, {g.x: inputs,
                      g.y: np.reshape(preds, (hp.beam_size * batch_size, hp.y_max_len))})
        j_probs = np.reshape(_probs[:, j, :], (batch_size, hp.beam_size, hp.beam_size))
        j_preds = np.reshape(_preds[:, j, :], (batch_size, hp.beam_size, hp.beam_size))
        if j == 0:
            # All beams start from the same prefix, so expand the first one only.
            preds[:, :, j] = j_preds[:, 0, :]
            prob_product += np.log(j_probs[:, 0, :])
        else:
            # 1 while a hypothesis is still emitting real tokens, 0 after it has ended.
            add_or_not = (j_preds > hp.end_id).astype(np.int32)
            tmp_stc_length = np.expand_dims(stc_length, axis=-1) + add_or_not
            tmp_stc_length = np.reshape(tmp_stc_length, (batch_size, hp.beam_size * hp.beam_size))
            this_probs = np.expand_dims(prob_product, axis=-1) + np.log(j_probs) * add_or_not
            this_probs = np.reshape(this_probs, (batch_size, hp.beam_size * hp.beam_size))
            # Keep the beam_size candidates with the best length-normalized scores.
            selected = np.argsort(this_probs / tmp_stc_length, axis=1)[:, -hp.beam_size:]
            tmp_preds = np.concatenate([np.expand_dims(preds, axis=2)] * hp.beam_size, axis=2)
            tmp_preds[:, :, :, j] = j_preds[:, :, :]
            tmp_preds = np.reshape(tmp_preds,
                                   (batch_size, hp.beam_size * hp.beam_size, hp.y_max_len))
            for batch_idx in range(batch_size):
                prob_product[batch_idx] = this_probs[batch_idx, selected[batch_idx]]
                preds[batch_idx] = tmp_preds[batch_idx, selected[batch_idx]]
                stc_length[batch_idx] = tmp_stc_length[batch_idx, selected[batch_idx]]
    final_selected = np.argmax(prob_product / stc_length, axis=1)
    final_preds = []
    for batch_idx in range(batch_size):
        final_preds.append(preds[batch_idx, final_selected[batch_idx]])
    return final_preds
I treat the length of y differently from that of x, so there is a y_max_len in the hyperparameters.
x is the input batch.
g is the model object, of type Graph in train.py.
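To make the candidate-selection step above concrete, here is a minimal NumPy sketch of the length-normalized ranking over the beam_size * beam_size expansions. The toy numbers are made up for illustration; the variable names mirror the function above.

```python
import numpy as np

beam_size = 2
# Accumulated log-probabilities for beam_size * beam_size candidates (1 batch item).
this_probs = np.array([[-1.0, -4.0, -2.4, -3.0]])
# Effective (non-padded) lengths of the same candidates.
tmp_stc_length = np.array([[2.0, 2.0, 4.0, 2.0]])

# Length-normalized scores; argsort ascends, so the last beam_size indices
# are the best-scoring candidates.
scores = this_probs / tmp_stc_length  # [[-0.5, -2.0, -0.6, -1.5]]
selected = np.argsort(scores, axis=1)[:, -beam_size:]
print(selected)  # [[2 0]]
```

Dividing by the length before ranking is a simple length-normalization heuristic: without it, the beam would systematically prefer shorter hypotheses, since every extra token adds a negative log-probability.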
from transformer.
I have implemented beam search in my fork. It works much like this transformer, apart from some modifications.
Can you share the code with beam search? I can't find the fork on your homepage.
Thanks very much. If I have any ideas, I will discuss them with you directly.
@RayXu14
There is a problem in your code.
Do you modify preds in the training Graph?
_probs, _preds = sess.run(g.preds, {g.x: inputs, g.y: np.reshape(preds, (hp.beam_size * batch_size, hp.y_max_len))})
yes
if is_training or hp.beam_size == 1:
    self.preds = tf.to_int32(tf.argmax(self.logits, axis=-1))
    self.istarget = tf.to_float(tf.not_equal(self.y, 0))
    self.acc = tf.reduce_sum(tf.to_float(tf.equal(self.preds, self.y)) * self.istarget) / (tf.reduce_sum(self.istarget))
else:
    print('[WARNING] beam search enabled')
    assert self.logits.get_shape().as_list()[-1] >= hp.beam_size
    self.probs = tf.nn.softmax(self.logits)
    self.preds = tf.nn.top_k(self.probs, hp.beam_size)
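A note for anyone reusing this branch: tf.nn.top_k returns a (values, indices) pair, which is why sess.run(g.preds, ...) in the beam search unpacks into _probs and _preds. A NumPy sketch of the same operation, with hypothetical toy values rather than real model output:

```python
import numpy as np

def top_k(probs, k):
    """Mimic tf.nn.top_k along the last axis: return (values, indices), best first."""
    idx = np.argsort(probs, axis=-1)[..., ::-1][..., :k]
    vals = np.take_along_axis(probs, idx, axis=-1)
    return vals, idx

probs = np.array([[0.1, 0.6, 0.3]])  # softmax output for one decoding position
values, indices = top_k(probs, k=2)
print(values)   # [[0.6 0.3]]
print(indices)  # [[1 2]]
```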
Could you explain what hp.end_id means?
Could you explain what hp.end_id means?
add_or_not = np.asarray(np.logical_or.reduce([j_preds > hp.end_id]), dtype=np.int)
@xinyx I think it is index 3, which corresponds to '</S>'.
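Assuming the vocabulary convention above, where index 3 is '</S>' and real word ids are greater than 3, the add_or_not mask can be checked in isolation:

```python
import numpy as np

end_id = 3  # id of '</S>'; ids above it are real words, ids at or below it are specials
j_preds = np.array([[5, 3, 0]])  # one beam step: a real word, the end token, padding

# 1 where the hypothesis emitted a real token, 0 once it has ended (or is padding),
# so finished hypotheses stop accumulating log-probability and length.
add_or_not = (j_preds > end_id).astype(np.int32)
print(add_or_not)  # [[1 0 0]]
```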
How do you modify is_training after building the graph? In practice, I find it helpful to feed is_training into the graph and use tf.cond instead of if-else, but that is rather complicated.
Beam search is used at inference time, so why edit train.py rather than test.py?
Is there a more efficient beam search implementation that can be shared?