ygzwqzd / lamda-ssl
30 Semi-Supervised Learning Algorithms
License: MIT License
Can you provide the test data for the examples?
I just installed LAMDA-SSL from GitHub. It installed the newest versions of all packages, including torch==2.0.1 (pip freeze below).
I cannot reproduce the examples that use deep learning.
Assemble and the other non-deep algorithms work fine:
(luan) Atlas:LAMDA-SSL wainer$ python Examples/Assemble_BreastCancer.py
(luan) Atlas:LAMDA-SSL wainer$
but
(luan) Atlas:LAMDA-SSL wainer$ python Examples/FixMatch_BreastCancer.py
Traceback (most recent call last):
File "/Users/wainer/Dropbox/alunos/luan/LAMDA-SSL/Examples/FixMatch_BreastCancer.py", line 64, in <module>
model.fit(X=labeled_X,y=labeled_y,unlabeled_X=unlabeled_X)
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/DeepModelMixin.py", line 326, in fit
self.init_train_dataloader()
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/DeepModelMixin.py", line 243, in init_train_dataloader
self._labeled_dataloader, self._unlabeled_dataloader = self._train_dataloader.init_dataloader(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataloader/TrainDataloader.py", line 344, in init_dataloader
self.labeled_dataloader = self.labeled_dataloader.init_dataloader(dataset=self.labeled_dataset,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataloader/LabeledDataloader.py", line 86, in init_dataloader
self.dataloader= DataLoader(dataset=self.dataset,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 245, in __init__
raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
ValueError: prefetch_factor option could only be specified in multiprocessing.let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None.
I have been altering the obvious things, such as the default prefetch_factor and num_workers, but after two hours of doing this I still hit a problem somewhere. Below is my last attempt, creating Dataloaders with what I believe are appropriate num_workers and prefetch_factor values for the FixMatch_BreastCancer.py code, but I am not sure my modifications are correct. Someone more familiar with the codebase could probably make these changes properly...
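For what it's worth, the constraint behind the ValueError can be reproduced with a plain torch DataLoader, independently of LAMDA-SSL: in torch 2.x an explicit prefetch_factor is rejected whenever num_workers=0, so the options are to omit it (or pass None) in single-process mode, or to enable worker processes. A minimal sketch (plain torch, not the LAMDA-SSL wrapper classes):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.zeros(8, 2))

# An explicit prefetch_factor together with num_workers=0 raises:
try:
    DataLoader(ds, num_workers=0, prefetch_factor=4)
except ValueError as e:
    print("ValueError:", e)

# Workaround 1: omit prefetch_factor entirely in single-process mode.
single_proc = DataLoader(ds, batch_size=4, num_workers=0)

# Workaround 2: enable worker processes so prefetching is meaningful.
multi_proc = DataLoader(ds, batch_size=4, num_workers=2, prefetch_factor=4)

print(len(list(single_proc)))  # 2 batches of 4
```

If LAMDA-SSL's dataloader wrappers forward these arguments to torch's DataLoader, the same rule should apply to whatever values the examples pass through them.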
(luan) Atlas:progs wainer$ python a2.py
Traceback (most recent call last):
File "/Users/wainer/Dropbox/alunos/luan/progs/a2.py", line 82, in <module>
model.fit(X=labeled_X,y=labeled_y,unlabeled_X=unlabeled_X)
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/DeepModelMixin.py", line 335, in fit
self.fit_epoch_loop(valid_X,valid_y)
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/DeepModelMixin.py", line 311, in fit_epoch_loop
self.fit_batch_loop(valid_X,valid_y)
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/DeepModelMixin.py", line 280, in fit_batch_loop
for (lb_idx, lb_X, lb_y), (ulb_idx, ulb_X, _) in zip(self._labeled_dataloader, self._unlabeled_dataloader):
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
data = self._next_data()
^^^^^^^^^^^^^^^^^
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
~~~~~~~~~~~~^^^^^
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataset/LabeledDataset.py", line 217, in __getitem__
Xi, yi = self.apply_transform(Xi, yi)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataset/LabeledDataset.py", line 185, in apply_transform
_X = self._transform(X, item)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataset/LabeledDataset.py", line 130, in _transform
X=self._transform(X,item)
^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataset/LabeledDataset.py", line 132, in _transform
X = transform(X)
^^^^^^^^^^^^
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/Transformer.py", line 18, in __call__
return self.fit_transform(X,y,fit_params=fit_params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/sklearn/utils/_set_output.py", line 140, in wrapped
data_to_wrap = f(self, X, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/Transformer.py", line 30, in fit_transform
return self.fit(X=X,y=y,fit_params=fit_params).transform(X)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/sklearn/utils/_set_output.py", line 140, in wrapped
data_to_wrap = f(self, X, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Transform/ToTensor.py", line 51, in transform
X=torch.Tensor(X)
^^^^^^^^^^^^^^^
TypeError: new(): data must be a sequence (got Image)
I would guess that the problem is with the torch 2.x version, but I am not sure.
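For context, the final TypeError is reproducible without LAMDA-SSL at all: torch.Tensor() does not accept a PIL Image, so any transform pipeline that hands ToTensor a raw Image will fail this way. A minimal sketch of the failure and the usual numpy round-trip that avoids it (assuming Pillow is installed, as in the pip freeze below):

```python
import numpy as np
import torch
from PIL import Image

img = Image.new("L", (4, 4))  # a 4x4 grayscale PIL image

# torch.Tensor(img) raises: "new(): data must be a sequence (got Image)"
try:
    torch.Tensor(img)
except TypeError as e:
    print("TypeError:", e)

# Converting through numpy first works:
t = torch.from_numpy(np.asarray(img, dtype=np.float32))
print(t.shape)  # torch.Size([4, 4])
```

So the question is likely whether an earlier transform in the pipeline used to convert the Image to an array and no longer does under the newer package versions.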
pip freeze:
(luan) Atlas:progs wainer$ pip freeze
certifi==2023.5.7
charset-normalizer==3.1.0
contourpy==1.0.7
cycler==0.11.0
filelock==3.12.0
flake8==6.0.0
fonttools==4.39.4
idna==3.4
Jinja2==3.1.2
joblib==1.2.0
kiwisolver==1.4.4
LAMDA-SSL @ file:///Users/wainer/Dropbox/alunos/luan/LAMDA-SSL
MarkupSafe==2.1.3
matplotlib==3.7.1
mccabe==0.7.0
mpmath==1.3.0
networkx==3.1
numpy==1.24.3
packaging==23.1
pandas==2.0.2
Pillow==9.5.0
psutil==5.9.5
pycodestyle==2.10.0
pyflakes==3.0.1
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2023.3
requests==2.31.0
scikit-learn==1.2.2
scipy==1.10.1
six==1.16.0
sympy==1.12
threadpoolctl==3.1.0
torch==2.0.1
torch-geometric==2.3.1
torchdata==0.6.1
torchtext==0.15.2
torchvision==0.15.2
tqdm==4.65.0
typing_extensions==4.6.3
tzdata==2023.3
urllib3==2.0.3
Hi, thanks for creating this library.
I am trying to use MeanTeacherReg, ICTReg, and PiModelReg with my custom dataset.
When I run MeanTeacherReg, ICTReg, and PiModelReg without changing the hyperparameters in the example code that was fitted to the Boston dataset, the models run, but the predictions all come out as zero.
This means the models did not learn my custom dataset properly.
What are appropriate hyperparameters for MeanTeacherReg / ICTReg / PiModelReg when using a different dataset?
I assume there are some hyperparameters to change when using another dataset, such as:
# sampler
labeled_sampler=RandomSampler(replacement=True,num_samples=64*(2**20))
unlabeled_sampler=RandomSampler(replacement=True)
valid_sampler=SequentialSampler()
test_sampler=SequentialSampler()
#dataloader
labeled_dataloader=LabeledDataLoader(batch_size=64,num_workers=0,drop_last=True)
unlabeled_dataloader=UnlabeledDataLoader(num_workers=0,drop_last=True)
valid_dataloader=UnlabeledDataLoader(batch_size=64,num_workers=0,drop_last=False)
test_dataloader=UnlabeledDataLoader(batch_size=64,num_workers=0,drop_last=False)
model=MeanTeacherReg(lambda_u=0,warmup=0.4,
                     mu=1,weight_decay=5e-4,ema_decay=0.999,
                     epoch=1,num_it_epoch=4000,
                     num_it_total=4000,
                     eval_it=200,device='cpu',
                     labeled_dataset=labeled_dataset,
                     unlabeled_dataset=unlabeled_dataset,
                     valid_dataset=valid_dataset,
                     test_dataset=test_dataset,
                     labeled_sampler=labeled_sampler,
                     unlabeled_sampler=unlabeled_sampler,
                     valid_sampler=valid_sampler,
                     test_sampler=test_sampler,
                     labeled_dataloader=labeled_dataloader,
                     unlabeled_dataloader=unlabeled_dataloader,
                     valid_dataloader=valid_dataloader,
                     test_dataloader=test_dataloader,
                     augmentation=augmentation,network=network,
                     optimizer=optimizer,scheduler=scheduler,
                     evaluation=evaluation,file=file,verbose=True)
The size of my custom dataset,
With these settings, the model did not train on my dataset properly.
Can you provide some example code where the model worked on a dataset other than Boston?
Or, are there any tips for hyperparameter tuning of MeanTeacherReg, ICTReg, and PiModelReg?
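Not an authoritative answer, but one thing worth checking before any hyperparameter tuning: deep regression models often collapse to near-constant (or zero) predictions when the targets are on a very different scale from the dataset the defaults were tuned on. A sketch of standardizing the targets before fitting and inverting afterwards (sklearn only; the SSL regressor call is a placeholder, not the library's API):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
y = 1000.0 + 50.0 * rng.normal(size=100)  # targets on a large scale

y_scaler = StandardScaler()
y_std = y_scaler.fit_transform(y.reshape(-1, 1)).ravel()

# ... fit the SSL regressor on the standardized targets y_std ...
pred_std = np.zeros(10)  # placeholder for the model's predictions

# Invert the scaling to recover predictions in the original units:
pred = y_scaler.inverse_transform(pred_std.reshape(-1, 1)).ravel()
print(round(float(pred[0]), 1))  # ~ the target mean, not zero
```

If the raw predictions are exactly zero even on standardized targets, the issue is more likely the learning rate, warmup, or number of iterations.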
On the Chinese tutorial page, in the [Constrained Seed k-means] section, the last sentence, "使用有标注数据参于聚类过成时聚类器更加可靠,……", contains a typo: "过成" should be "过程" ("process").
It would be useful to also see the performance of each SSL model against its purely supervised backbone trained on only the labeled data.
For example, TSVM vs pure SVM:
import numpy as np
from LAMDA_SSL.Dataset.Tabular.BreastCancer import BreastCancer
dataset = BreastCancer(labeled_size=0.1, stratified=True, shuffle=True)
labeled_X = dataset.labeled_X
labeled_y = dataset.labeled_y
unlabeled_X = dataset.unlabeled_X
unlabeled_y = dataset.unlabeled_y
from sklearn import preprocessing
pre_transform = preprocessing.StandardScaler()
pre_transform.fit(np.vstack([labeled_X, unlabeled_X]))
labeled_X = pre_transform.transform(labeled_X)
unlabeled_X = pre_transform.transform(unlabeled_X)
from LAMDA_SSL.Algorithm.Classification.TSVM import TSVM
# I tried using a range of Cl and Cu, starting from 15 and 0.0001 and then gradually
# upping Cu and decreasing Cl. It didn't seem to make a difference?
model = TSVM(Cl=1, Cu=1, kernel="linear")
model.fit(X=labeled_X, y=labeled_y, unlabeled_X=unlabeled_X)
pred_y = model.predict()
from LAMDA_SSL.Evaluation.Classifier.Accuracy import Accuracy
score = Accuracy().scoring(unlabeled_y, pred_y)
print(f"SSL TSVM score: {score}")
#> SSL TSVM score: 0.9609375
# Compare with pure SVM
from sklearn import svm
model_sl = svm.SVC()
model_sl.fit(labeled_X, labeled_y)
pred_sl = model_sl.predict(unlabeled_X)
score_sl = Accuracy().scoring(unlabeled_y, pred_sl)
print(f"SL SVM score: {score_sl}")
#> SL SVM score: 0.955078125
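A single split like the one above is fairly noisy, so the 0.961 vs 0.955 gap may or may not be meaningful. A sketch of averaging the supervised baseline over several stratified 10% splits (sklearn only, using its built-in breast-cancer data as a stand-in for the LAMDA-SSL wrapper):

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = load_breast_cancer(return_X_y=True)
scores = []
for seed in range(5):
    # 10% labeled, the rest treated as the "unlabeled" evaluation pool
    X_lab, X_unlab, y_lab, y_unlab = train_test_split(
        X, y, train_size=0.1, stratify=y, random_state=seed)
    scaler = StandardScaler().fit(np.vstack([X_lab, X_unlab]))
    clf = SVC().fit(scaler.transform(X_lab), y_lab)
    scores.append(clf.score(scaler.transform(X_unlab), y_unlab))

print(f"SL SVM accuracy over 5 splits: {np.mean(scores):.3f} +/- {np.std(scores):.3f}")
```

Running TSVM over the same five splits would give a comparison where the variance across seeds is visible alongside the mean.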
The lack of colour channels in MNIST (and MNIST-like) datasets means we get errors. It looks like the network implementations aren't set up to handle single-channel inputs? Could you parameterise them so that they can be used with MNIST? It would be ideal to just write:
# This:
model = PiModel(channels=1)
# Or this
model = PiModel(shape=(28,28,1))
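Until something like that exists, a common workaround is to replicate the grayscale channel three times at transform time so single-channel images match RGB-only networks. A minimal sketch:

```python
import torch

x = torch.zeros(1, 28, 28)  # one MNIST-style grayscale image, CHW layout
x3 = x.repeat(3, 1, 1)      # replicate the channel -> (3, 28, 28)
print(x3.shape)             # torch.Size([3, 28, 28])
```

This wastes a little memory but avoids touching the network code at all.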
Thanks for the great work; I need your help.
If I want to solve a three-class problem, which code should I modify? For example, suppose there are three categories in the BreastCancer dataset. When I didn't modify any code, the confusion matrix only showed predictions for the first two classes.
Result/Co_Training_BreastCancer.txt:
accuracy 0.324468085106383
precision 0.2598727091480715
Recall 0.3464646464646464
F1 0.2306878306878307
Confusion_matrix [[0.16666667 0.83333333 0. ]
[0.12727273 0.87272727 0. ]
[0.12727273 0.87272727 0. ]]
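For what it's worth, the all-zero third column in that confusion matrix means the classifier never predicts class 2 at all, which can be checked directly on the predictions (sklearn only; illustrative labels, not the actual BreastCancer output):

```python
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = np.array([0, 1, 2, 0, 1, 2])
y_pred = np.array([0, 1, 1, 0, 1, 1])  # class 2 is never predicted

cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
print(cm)  # third column is all zeros, as in the reported matrix
print("predicted classes:", np.unique(y_pred))
```

If np.unique on the model's predictions shows only two labels, the question becomes whether the underlying base learners are restricted to binary classification.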
Randomly choosing a subset of indexes in Assemble
https://github.com/YGZWQZD/LAMDA-SSL/blob/master/LAMDA_SSL/Algorithm/Classification/Assemble.py#L88
does not necessarily mean that the subset will contain more than one class, e.g. for small datasets.
Is there a workaround for this?
Thank you in advance.
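One possible workaround (a sketch of the idea, not the library's API): redraw the bootstrap sample until it contains more than one class, which only matters for small datasets where a single draw can easily be pure. The helper name here is hypothetical:

```python
import numpy as np

def sample_multiclass_indexes(y, size, rng=None, max_tries=100):
    """Draw `size` indexes with replacement, retrying until >1 class appears."""
    rng = rng or np.random.default_rng(0)
    for _ in range(max_tries):
        idx = rng.choice(len(y), size=size, replace=True)
        if len(np.unique(y[idx])) > 1:
            return idx
    raise RuntimeError("could not draw a multi-class sample")

y = np.array([0, 0, 0, 0, 0, 0, 0, 1])  # tiny, imbalanced toy labels
idx = sample_multiclass_indexes(y, size=4)
print(sorted(np.unique(y[idx])))  # both classes present
```

Stratified sampling would be the cleaner fix, but the retry loop is a minimal change to the existing random-choice logic.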