Comments (3)
I can't reproduce this with GPT-2 in a unit test. I haven't tried using Zephyr-7B beta in the unit test yet.
I won't be able to work on this any longer today, and I don't know when I'll be able to pick it up again.
Here's the unit test:
import numpy as np
import pytest
from unittest.mock import patch

# The fixture and the test below are methods of the test class that also
# provides the classifier_cls fixture, hence the `self` argument.
@pytest.fixture(scope='class')
def model_cacheable(self):
    from transformers import AutoModelForCausalLM
    return AutoModelForCausalLM.from_pretrained('gpt2')

def test_caching_works_shared_label_prefix_without_eos(self, model_cacheable, classifier_cls):
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_cacheable.config._name_or_path)
    tokenizer.add_eos_token = False

    # Wrap generate so that generated sequences never end in EOS:
    # the last token of every sequence is sliced off.
    generate_fn = model_cacheable.generate
    def patched_generate(*args, **kwargs):
        original_result = generate_fn(*args, **kwargs)
        no_eos_result = original_result[:, :-1]
        return no_eos_result

    with patch.object(model_cacheable, 'generate', new=patched_generate):
        clf = classifier_cls(model=model_cacheable, tokenizer=tokenizer)
        X = np.array([["Hey there"], ["No thank you"], ["Whatever"]])
        # Labels '1' and '12' share a prefix
        y = ['0', '1', '12']
        clf.fit(X, y)
        for label_ids in clf.label_ids_:
            assert label_ids[-1] != model_cacheable.config.eos_token_id
        # Predict twice to exercise the caching code path
        clf.predict(X)
        clf.predict(X)
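The key trick in this test is `patch.object(..., new=patched_generate)`: it temporarily replaces `generate` on the model instance with a wrapper that drops the final token of every generated sequence, so the labels end without EOS. Below is a minimal, self-contained sketch of the same technique, assuming a hypothetical `FakeModel` in place of the Hugging Face model; its token ids are made up, with `50256` standing in for GPT-2's EOS id.

```python
from unittest.mock import patch

import numpy as np

EOS_ID = 50256  # GPT-2's EOS token id, used here only as an example value


class FakeModel:
    """Hypothetical stand-in for a causal LM whose generate() always ends with EOS."""

    def generate(self, input_ids):
        batch_size = input_ids.shape[0]
        # Append two made-up "generated" tokens followed by EOS to every row
        generated = np.tile([101, 102, EOS_ID], (batch_size, 1))
        return np.concatenate([input_ids, generated], axis=1)


model = FakeModel()
generate_fn = model.generate  # keep a reference to the original bound method


def patched_generate(*args, **kwargs):
    original_result = generate_fn(*args, **kwargs)
    # Drop the last column, i.e. the trailing EOS token of each sequence
    return original_result[:, :-1]


inputs = np.array([[1, 2], [3, 4]])
with patch.object(model, "generate", new=patched_generate):
    patched = model.generate(inputs)

unpatched = model.generate(inputs)  # the original method is restored on exit
print(patched[:, -1])    # [102 102]
print(unpatched[:, -1])  # [50256 50256]
```

Because `generate` is patched on the instance rather than on the class, only this one model object is affected, and the original method is back in place as soon as the `with` block exits.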
Quick update. This is a test that reproduces the issue:
def test_caching_works_shared_label_prefix_without_eos(self, model_cacheable, classifier_cls):
    clf = classifier_cls('HuggingFaceH4/zephyr-7b-beta')
    X = np.array([["Hey there"], ["No thank you"], ["Whatever"]])
    y = ['0', '1', '12']
    clf.fit(X, y)
    for label_ids in clf.label_ids_:
        # Note: model_cacheable is the GPT-2 fixture, so this compares against
        # GPT-2's EOS token id, which differs from Zephyr's.
        assert label_ids[-1] != model_cacheable.config.eos_token_id
    clf.predict(X)
    clf.predict(X)
The issue is resolved by adding `not label_id or` to the condition in `generate_logits` (skorch/skorch/llm/classifier.py, line 246 in commit 9ab3b2c).
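To illustrate why such a guard helps, assume the tokenizer splits `'12'` into the tokens for `'1'` and `'2'`, so the ids of label `'1'` are a prefix of those of `'12'`. Without a trailing EOS, walking the logits cache for `'1'` can use up all of its label ids while the cache still holds an entry left behind by the longer label, and the next `label_id[0]` lookup then fails on an empty sequence. The following is a simplified, hypothetical sketch of such a cache walk: the function `walk_cache`, the dict-based cache, and all token ids are made up for illustration, and it is not skorch's actual `generate_logits`.

```python
EOS_ID = 0  # hypothetical EOS token id


def walk_cache(cache, prompt, label_id, guard=True):
    """Hypothetical sketch: collect cached logits along prompt + label_id.

    `cache` maps tuples of token ids to (fake) logits. While the current
    input has a cache entry, record it and extend the input by the next
    label token.
    """
    recorded = []
    key = tuple(prompt)
    label_id = list(label_id)
    while key in cache:
        # With guard=True, an exhausted label stops the loop before
        # label_id[0] is evaluated; with guard=False, indexing an empty
        # label raises IndexError.
        if (guard and not label_id) or label_id[0] == EOS_ID:
            break
        recorded.append(cache[key])
        key = key + (label_id[0],)
        label_id = label_id[1:]
    return recorded


# Assumed token ids: prompt -> [5], label '1' -> [7], label '12' -> [7, 8].
# Caching the longer label '12' left entries for both (5,) and (5, 7).
cache = {(5,): "logits@5", (5, 7): "logits@5,7"}

print(walk_cache(cache, [5], [7, 8]))  # ['logits@5', 'logits@5,7']
print(walk_cache(cache, [5], [7]))     # ['logits@5']
try:
    walk_cache(cache, [5], [7], guard=False)
except IndexError as exc:
    print("IndexError:", exc)  # IndexError: list index out of range
```

Putting `not label_id` first matters: `or` short-circuits, so the indexing expression is never evaluated once the label is exhausted.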
I'm currently working on reproducing this issue with GPT-2, so as not to blow up the model zoo we use in the tests too much.
Resolved with #1048.