
bert4rec-vae-pytorch's Introduction

Hey there, I'm Jae-Won! 😆

  • PhD student at UMich CSE, SymbioticLab. Working on energy-efficient software systems for Deep Learning!
  • Commandline enthusiast. Check out my dotfiles!
  • Fingerstyle guitar player. But I rarely get to show off!

bert4rec-vae-pytorch's People

Contributors

jaywonchung, sungmincho

bert4rec-vae-pytorch's Issues

About the prediction

Hi, I used your BERT4Rec model to predict on a sequence, but I don't know how to turn the output tensor into a readable recommendation list. Could you give me some tips? Thanks!
Best regards

    import torch

    # load one sequence from the test set
    seq = dataset['test'][i]
    mask_token = len(dataset['smap']) + 1
    seq = seq + [mask_token]
    seq = seq[-args.bert_max_len:]
    padding_len = args.bert_max_len - len(seq)
    seq = [0] * padding_len + seq
    seq_ = torch.LongTensor([seq]).to(device)

    # recommend 50 items autoregressively
    rec_list = []
    for _ in range(50):
        # feed the 1...T sequence to predict the item at time T+1
        logits = model(seq_)[0]
        # take the prediction at the last (masked) position
        rec = torch.argmax(logits[-1], 0).item()
        rec_list.append(datamap[rec])
        # insert the new item before the trailing mask token, drop the oldest one
        seq.insert(-1, rec)
        seq.pop(0)
        seq_ = torch.LongTensor([seq]).to(device)
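
For what it's worth, one way to make the predictions readable is to invert the item mapping. A minimal sketch, assuming dataset['smap'] maps original item IDs to contiguous model indices (so the datamap used above would be this inverse):

    # assumption: smap is {original_item_id: model_index}; invert it to decode predictions
    inverse_smap = {index: item_id for item_id, index in dataset['smap'].items()}
    datamap = inverse_smap  # with this, rec_list above holds original item IDs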

About rank values in recalls_and_ndcgs_for_ks

Hi!

Thanks for your great repository!
I'm working through this code step by step, and I found something strange in my environment.

In the recalls_and_ndcgs_for_ks function, I inspected the scores and ranks values and found that the score order and the rank order don't match.

Below, ranks says index 13 is the minimum, but the score at index 13 is -2.9711, which is not the minimum score.

The relevant code is around here:

# ranks[0] value
tensor([ 70,  86,   7,  97,  63,  95,  24,  58,  65,  56,  16,   5,  93,   0,
         18,  31,  68,  48,  59,  98,  29,  76,  54,  79,  23,  55,  39,  72,
         36,  89,  38,  44,  78,   4,  61,  53,   8,  74, 100,  22,   3,  91,
         80,  84,  51,  57,  82,  92,  83,  67,  34,  64,  50,  25,  85,  52,
         26,  43,  15,  12,  40,  27,   1,  88,  94,  71,   2,  66,  96,  28,
         19,  42,  20,  60,  49,  87,  77,  47,   9,  69,  41,  99,  17,  46,
         75,  73,  21,  14,  45,  32,  13,  11,  10,  33,  90,  35,   6,  37,
         30,  62,  81])

# scores[0] value
tensor([ 0.5852, -3.4007, -3.5931, -1.8515, -1.1481,  0.5925, -7.0834,  2.5950,
        -1.6262, -4.7560, -6.4868, -6.4362, -2.9711, -6.3241, -5.8845, -2.9454,
         0.7134, -5.0486,  0.5127, -4.3835, -4.5180, -5.8608, -1.8072, -0.6608,
         1.9210, -2.5598, -2.8975, -3.0994, -4.1074, -0.2445, -7.5053,  0.1854,
        -6.3146, -6.4986, -2.3304, -6.8591, -0.8691, -7.4218, -1.0146, -0.7875,
        -3.0707, -4.9039, -4.3940, -2.9124, -1.0281, -6.2175, -5.1383, -4.7329,
        -0.0837, -4.6264, -2.4418, -2.1516, -2.5802, -1.6060, -0.3369, -0.7312,
         0.7645, -2.1919,  0.8558, -0.1193, -4.5226, -1.3819, -7.8501,  2.4229,
        -2.4081,  0.8114, -3.6097, -2.3246,  0.0658, -4.8747,  2.9387, -3.4794,
        -0.7969, -5.7531, -1.7263, -5.4659, -0.3193, -4.7225, -1.1454, -0.3836,
        -1.9473, -8.0547, -2.2138, -2.3139, -2.0216, -2.5738,  2.7595, -4.6759,
        -3.4287, -0.9745, -6.6980, -1.8827, -2.2390,  0.5889, -3.4509,  2.3932,
        -4.0403,  2.5251, -0.1454, -4.9200, -1.7323])

Do you have any idea? Or if I have made a mistake, please let me know.
Thank you.
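
For what it's worth, the dump above looks consistent with ranks being the descending argsort of scores, i.e. item indices ordered best-to-worst, rather than a per-item rank: ranks[0][0] = 70, and scores[0][70] = 2.9387 is in fact the largest score. A minimal sketch of the distinction:

    import torch

    scores = torch.tensor([0.5, -3.4, 2.9, -1.8])
    order = scores.argsort(descending=True)  # item indices, best first: tensor([2, 0, 3, 1])
    rank_of_item = order.argsort()           # rank held by each item:  tensor([1, 3, 0, 2])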

BERT4Rec model seems to be underfit in the reported results

Hi,
I am experimenting with the original implementation of BERT4Rec, but with the same sampling strategy as yours.

When I trained the original implementation for 1 hour, I got the following results:
R@1: 0.341391
R@5: 0.661589
R@10: 0.765728
NDCG@5: 0.512567
NDCG@10: 0.546344

This is well aligned with what you reported in the table.

However, when I gave the model 16 hours to train, I got much better results:
R@1: 0.405960
R@5: 0.714570
R@10: 0.803974
NDCG@5: 0.571801
NDCG@10: 0.600875

I think it would be worth re-evaluating your model with more epochs on the ML-1M dataset.

For the ML-20M dataset, my results are well aligned with yours:
R@1: 0.613104
R@5: 0.887872
R@10: 0.945339
NDCG@5: 0.763871
NDCG@10: 0.782702

Can you please add a license to the repo?

Could you please add an explicit LICENSE file to the repo so that it's clear under what terms the content is provided?

Per GitHub docs on licensing:

Public repositories on GitHub are often used to share open source software. For your repository to truly be open source, you'll need to license it so that others are free to use, change, and distribute the software.

Thanks in advance!

What configuration was used to obtain the presented results for the VAE model?

Dear @jaywonchung,

For the MovieLens-1M dataset, the reported result is NDCG@10 = 0.4049. I am currently unable to reproduce this result with the default configuration template train_vae_search_beta and 'vae_num_hidden': 0; the results I get are closer to 0.37, so I would like to know how you achieved those results.

What configuration did you use to achieve the above-mentioned results, and how did you evaluate?

Thank you.

Try not to restrict dependency versions

The dependencies in requirements.txt all have pinned versions, and some of them are quite out of date. I found that the program runs fine with the most recent versions. How about removing those version constraints?

Question about metrics

In the BERT4Rec paper, the authors use Hit Ratio as an evaluation metric, while your code uses Recall. My question is: are the two metrics equivalent?
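
For what it's worth, under the leave-one-out evaluation used here, where each user has exactly one held-out ground-truth item, Hit Ratio@k and Recall@k coincide: both are 1 if the held-out item appears in the top-k and 0 otherwise. A minimal sketch:

    import torch

    def hr_and_recall_at_k(scores, true_item, k=10):
        # scores: (num_candidates,) predicted scores; true_item: index of the held-out item
        topk = scores.topk(k).indices
        hit = float(true_item in topk)  # HR@k contribution for this user
        recall = hit / 1.0              # Recall@k = hits / num_relevant, and num_relevant == 1
        return hit, recall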

Out of Memory CUDA Error with Increased Data Size in BERT4Rec

I initially tested the functionality with one day's worth of my custom data and successfully trained and evaluated BERT4Rec without any issues.
However, when I tried to input a week's worth of data for additional training, I encountered a CUDA out-of-memory error.

By reducing bert_max_len and the batch size in templates.py, I managed to get it to work, but I do not understand why increasing the data size would cause an out-of-memory error.

Is there a part of the code where the increased data size might be impacting memory usage?
If so, could you point out where this might be occurring?

Thank you for your assistance.
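
One hedged guess at the mechanism: with bert_max_len and the batch size fixed, the per-batch footprint still grows with the item vocabulary, because the logits tensor spans batch_size x bert_max_len x num_items, and a week of data usually contains many more distinct items than a single day. A rough back-of-the-envelope sketch with assumed numbers:

    # hypothetical sizes, for illustration only
    batch_size, bert_max_len, num_items = 128, 100, 50_000
    logits_bytes = batch_size * bert_max_len * (num_items + 1) * 4  # float32 logits
    print(f"{logits_bytes / 1e9:.2f} GB for a single logits tensor")  # ~2.56 GB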

A question about the BertTrainDataset.__getitem__(self, index) function

I want to ask why you assign the label '0' to non-masked tokens when generating an example here:

    for s in seq:
        prob = self.rng.random()
        if prob < self.mask_prob:
            prob /= self.mask_prob

            if prob < 0.8:
                tokens.append(self.mask_token)
            elif prob < 0.9:
                tokens.append(self.rng.randint(1, self.num_items))
            else:
                tokens.append(s)

            labels.append(s)
        else:
            tokens.append(s)
            labels.append(0)  # why give the label 0 here and not the index of s?

Thanks!
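
A hedged note on the usual reason: label 0 marks positions that should not contribute to the loss (0 is also the padding index), and the training loss is typically built with ignore_index=0 so that only masked positions are supervised. A minimal sketch of the pattern, not necessarily the repo's exact trainer code:

    import torch
    import torch.nn as nn

    ce = nn.CrossEntropyLoss(ignore_index=0)   # label 0 (padding/unmasked) is skipped

    logits = torch.randn(2, 5, 101)            # (batch, seq_len, vocab)
    labels = torch.tensor([[0, 0, 7, 0, 3],
                           [0, 9, 0, 0, 0]])   # only nonzero labels are supervised
    loss = ce(logits.view(-1, 101), labels.view(-1))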

Incorrect index mapping

The mapping here from user IDs and item IDs to indices should not start from 0:
[screenshot]
Otherwise, the 0-padding used below will cause problems:
[screenshot]
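
A minimal sketch of the suggested fix, reserving index 0 for padding by starting the mapping at 1 (unique_items is a hypothetical set of raw item IDs):

    # map original IDs to indices 1..N so that 0 stays free for padding
    smap = {item_id: idx + 1 for idx, item_id in enumerate(sorted(unique_items))}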

Is the AverageMeter update correct?

Hi, thank you so much for your code.
I have a question about the way AverageMeter is updated.

It seems that when updating recall or NDCG in AverageMeterSet.update() and AverageMeter.update(),
the instance averages the per-batch averages.

For example, in Trainer's validate() and test():
batch 1's NDCG = (0.1 + 0.2 + 0.3) / 3 = 0.2 if the batch size is 3,
batch 2's NDCG = (0.4 + 0.5) / 2 = 0.45 if the last batch's size is 2,
and the code seems to average these as (0.2 + 0.45) / 2 = 0.325.

But shouldn't it be (0.1 + 0.2 + 0.3 + 0.4 + 0.5) / 5 = 0.3?
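
If the meter really does average the per-batch averages with equal weight, the usual fix is to weight each update by its batch size. A minimal sketch, not the repo's exact class:

    class AverageMeter:
        """Running average weighted by the number of samples."""
        def __init__(self):
            self.sum, self.count = 0.0, 0

        def update(self, val, n=1):
            self.sum += val * n  # val is a batch mean; weight it by batch size n
            self.count += n

        @property
        def avg(self):
            return self.sum / self.count

    meter = AverageMeter()
    meter.update(0.2, n=3)   # batch of 3
    meter.update(0.45, n=2)  # batch of 2
    assert abs(meter.avg - 0.3) < 1e-9  # matches the per-sample average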

Implementing a recommender system

Hello, I'm looking through this repo in order to implement a BERT-based recommender system. May I ask what would be a good way to approach studying it?

The loss is NaN

In some cases, such as when the sequence length is short or the value of mask_prob is small, it can happen that no token in the training sequence gets masked, and the loss then becomes NaN. How can this be avoided? I don't want the loss to be NaN; is adjusting mask_prob the only option?
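
One common workaround, sketched here rather than taken from the repo (the 80/10/10 replacement split is omitted for brevity), is to force at least one position to be masked whenever the random draw masks none, so every training example contributes a finite loss:

    import random

    def mask_sequence(seq, mask_prob, mask_token, rng=random):
        tokens, labels = [], []
        for s in seq:
            if rng.random() < mask_prob:
                tokens.append(mask_token)
                labels.append(s)
            else:
                tokens.append(s)
                labels.append(0)
        if all(label == 0 for label in labels):  # nothing was masked
            i = rng.randrange(len(seq))          # force one random position
            tokens[i], labels[i] = mask_token, seq[i]
        return tokens, labels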

Question

Sorry to bother you, but how can I use the trained model to make predictions for a specific user?

Inference code

I want to run inference with the model, but the result is wrong.

ML-1M results problem

Why do I get an NDCG@10 of only 0.26 on the ML-1M dataset, while the paper reports 0.48? On the ML-20M dataset, my results are similar to those in the paper.
