
lamda-ssl's People

Contributors

wnjxyk, ygzwqzd


lamda-ssl's Issues

I cannot reproduce some of the Examples with deep algorithms (maybe because of PyTorch 2.x?)

I just installed LAMDA-SSL from GitHub. It installed the latest versions of all packages, including torch==2.0.1 (pip freeze below).

I cannot run the Examples that use deep learning.

Assemble and other non-deep algorithms work fine:

(luan) Atlas:LAMDA-SSL wainer$ python Examples/Assemble_BreastCancer.py  
(luan) Atlas:LAMDA-SSL wainer$

but

(luan) Atlas:LAMDA-SSL wainer$ python Examples/FixMatch_BreastCancer.py 
Traceback (most recent call last):
  File "/Users/wainer/Dropbox/alunos/luan/LAMDA-SSL/Examples/FixMatch_BreastCancer.py", line 64, in <module>
    model.fit(X=labeled_X,y=labeled_y,unlabeled_X=unlabeled_X)
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/DeepModelMixin.py", line 326, in fit
    self.init_train_dataloader()
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/DeepModelMixin.py", line 243, in init_train_dataloader
    self._labeled_dataloader, self._unlabeled_dataloader = self._train_dataloader.init_dataloader(
                                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataloader/TrainDataloader.py", line 344, in init_dataloader
    self.labeled_dataloader = self.labeled_dataloader.init_dataloader(dataset=self.labeled_dataset,
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataloader/LabeledDataloader.py", line 86, in init_dataloader
    self.dataloader= DataLoader(dataset=self.dataset,
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 245, in __init__
    raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
ValueError: prefetch_factor option could only be specified in multiprocessing.let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None.
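The error comes from a check in PyTorch 2.x's `DataLoader`: `prefetch_factor` may only be set when `num_workers > 0`. A minimal sketch of one way to respect that constraint, assuming you build the `torch.utils.data.DataLoader` arguments yourself; `dataloader_kwargs` is a hypothetical helper, not part of LAMDA-SSL, and this does not patch LAMDA-SSL's own dataloader classes.

```python
def dataloader_kwargs(batch_size, num_workers=0, prefetch_factor=2, drop_last=False):
    """Build keyword arguments for torch.utils.data.DataLoader.

    Hypothetical helper: prefetch_factor is only forwarded when worker
    processes are used, because PyTorch 2.x raises a ValueError if it is
    set while num_workers == 0.
    """
    kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "drop_last": drop_last,
    }
    if num_workers > 0:
        # Number of batches each worker loads in advance.
        kwargs["prefetch_factor"] = prefetch_factor
    return kwargs


print(dataloader_kwargs(64))                 # no prefetch_factor key
print(dataloader_kwargs(64, num_workers=2))  # includes prefetch_factor
```

Usage would then look like `DataLoader(dataset, **dataloader_kwargs(64, num_workers=0))`, which works whether or not worker processes are enabled.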

I have tried changing the obvious things, such as the default prefetch_factor and num_workers, but after 2 hours I still get an error somewhere. Below is my latest attempt, in which I created Dataloaders with appropriate num_workers and prefetch_factor values for the FixMatch_BreastCancer.py code, though I am not sure my modifications are correct. Someone more familiar with the code can probably make these changes more easily...

(luan) Atlas:progs wainer$ python a2.py
Traceback (most recent call last):
  File "/Users/wainer/Dropbox/alunos/luan/progs/a2.py", line 82, in <module>
    model.fit(X=labeled_X,y=labeled_y,unlabeled_X=unlabeled_X)
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/DeepModelMixin.py", line 335, in fit
    self.fit_epoch_loop(valid_X,valid_y)
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/DeepModelMixin.py", line 311, in fit_epoch_loop
    self.fit_batch_loop(valid_X,valid_y)
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/DeepModelMixin.py", line 280, in fit_batch_loop
    for (lb_idx, lb_X, lb_y), (ulb_idx, ulb_X, _) in zip(self._labeled_dataloader, self._unlabeled_dataloader):
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataset/LabeledDataset.py", line 217, in __getitem__
    Xi, yi = self.apply_transform(Xi, yi)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataset/LabeledDataset.py", line 185, in apply_transform
    _X = self._transform(X, item)
         ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataset/LabeledDataset.py", line 130, in _transform
    X=self._transform(X,item)
      ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Dataset/LabeledDataset.py", line 132, in _transform
    X = transform(X)
        ^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/Transformer.py", line 18, in __call__
    return self.fit_transform(X,y,fit_params=fit_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/sklearn/utils/_set_output.py", line 140, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Base/Transformer.py", line 30, in fit_transform
    return self.fit(X=X,y=y,fit_params=fit_params).transform(X)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/sklearn/utils/_set_output.py", line 140, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/wainer/miniconda3/envs/luan/lib/python3.11/site-packages/LAMDA_SSL/Transform/ToTensor.py", line 51, in transform
    X=torch.Tensor(X)
      ^^^^^^^^^^^^^^^
TypeError: new(): data must be a sequence (got Image)

I suspect the problem is torch 2.x, but I am not sure.

pip freeze:

(luan) Atlas:progs wainer$ pip freeze
certifi==2023.5.7
charset-normalizer==3.1.0
contourpy==1.0.7
cycler==0.11.0
filelock==3.12.0
flake8==6.0.0
fonttools==4.39.4
idna==3.4
Jinja2==3.1.2
joblib==1.2.0
kiwisolver==1.4.4
LAMDA-SSL @ file:///Users/wainer/Dropbox/alunos/luan/LAMDA-SSL
MarkupSafe==2.1.3
matplotlib==3.7.1
mccabe==0.7.0
mpmath==1.3.0
networkx==3.1
numpy==1.24.3
packaging==23.1
pandas==2.0.2
Pillow==9.5.0
psutil==5.9.5
pycodestyle==2.10.0
pyflakes==3.0.1
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2023.3
requests==2.31.0
scikit-learn==1.2.2
scipy==1.10.1
six==1.16.0
sympy==1.12
threadpoolctl==3.1.0
torch==2.0.1
torch-geometric==2.3.1
torchdata==0.6.1
torchtext==0.15.2
torchvision==0.15.2
tqdm==4.65.0
typing_extensions==4.6.3
tzdata==2023.3
urllib3==2.0.3

What are the appropriate hyperparameters for MeanTeacherReg / ICTReg / PiModelReg when using another dataset?

Hi, thanks for creating this library.
I am trying to use MeanTeacherReg, ICTReg, and PiModelReg with my custom dataset.
When I use MeanTeacherReg, ICTReg, and PiModelReg with the hyperparameters from the example code (tuned for the Boston dataset) unchanged, the models run, but all predictions are zero.
This suggests that the models did not learn my custom dataset properly.

What are the appropriate hyperparameters for MeanTeacherReg / ICTReg / PiModelReg when using another dataset?

I think some hyperparameters need to change when using a different dataset, such as:

  1. num_samples
# sampler
labeled_sampler=RandomSampler(replacement=True,num_samples=64*(2**20))
unlabeled_sampler=RandomSampler(replacement=True)
valid_sampler=SequentialSampler()
test_sampler=SequentialSampler()
  2. batch_size
#dataloader
labeled_dataloader=LabeledDataLoader(batch_size=64,num_workers=0,drop_last=True)
unlabeled_dataloader=UnlabeledDataLoader(num_workers=0,drop_last=True)
valid_dataloader=UnlabeledDataLoader(batch_size=64,num_workers=0,drop_last=False)
test_dataloader=UnlabeledDataLoader(batch_size=64,num_workers=0,drop_last=False)
  3. epoch / num_it_epoch / num_it_total / eval_it
model=MeanTeacherReg(lambda_u=0,warmup=0.4,
               mu=1,weight_decay=5e-4,ema_decay=0.999,
               epoch=1,num_it_epoch=4000,
               num_it_total=4000,
               eval_it=200,device='cpu',
               labeled_dataset=labeled_dataset,
               unlabeled_dataset=unlabeled_dataset,
               valid_dataset=valid_dataset,
               test_dataset=test_dataset,
               labeled_sampler=labeled_sampler,
               unlabeled_sampler=unlabeled_sampler,
               valid_sampler=valid_sampler,
               test_sampler=test_sampler,
               labeled_dataloader=labeled_dataloader,
               unlabeled_dataloader=unlabeled_dataloader,
               valid_dataloader=valid_dataloader,
               test_dataloader=test_dataloader,
               augmentation=augmentation,network=network,
               optimizer=optimizer,scheduler=scheduler,
               evaluation=evaluation,file=file,verbose=True)

The shapes of my custom dataset are:

  • labeled_X is (8760, 10)
  • labeled_y is (8760, 1)
  • unlabeled_X is (8760, 10)
  • unlabeled_y is (8760, 1)
  • test_X is (8760, 10)
  • test_y is (8760, 1)

With these settings, the model does not learn my dataset properly.

Can you provide some example code in which you got the model working on a dataset other than Boston?
Or are there any tips for tuning the hyperparameters of MeanTeacherReg, ICTReg, and PiModelReg?
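The sampler, batch-size and iteration settings listed above interact. A rough sketch of the arithmetic, assuming each training iteration draws one labeled batch of `batch_size` samples (the `fit_batch_loop` in the first issue's traceback zips the labeled and unlabeled dataloaders once per iteration) and that `drop_last=True` discards the incomplete final batch. The function names are illustrative, not LAMDA-SSL API.

```python
def batches_available(num_samples, batch_size, drop_last=True):
    """Number of batches a sampler yielding num_samples indices can produce."""
    if drop_last:
        return num_samples // batch_size
    return -(-num_samples // batch_size)  # ceiling division keeps the partial batch


def passes_over_labeled_data(num_it_total, batch_size, n_labeled):
    """Average number of times each labeled sample is seen during training."""
    return num_it_total * batch_size / n_labeled


# Settings copied from the MeanTeacherReg example above.
num_samples = 64 * (2 ** 20)
batch_size = 64
num_it_total = 4000
n_labeled = 8760  # labeled_X has shape (8760, 10)

print(batches_available(num_samples, batch_size))  # 1048576
print(batches_available(n_labeled, batch_size))    # 136
print(round(passes_over_labeled_data(num_it_total, batch_size, n_labeled), 2))  # 29.22
```

Under these assumptions, the labeled sampler supplies far more batches than 4000 iterations consume, so `num_samples` is unlikely to be the limiting setting; each labeled sample would be seen roughly 29 times.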

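Typo

In the [Constrained Seed k-means] section of the Chinese tutorial page, the last sentence, "使用有标注数据参于聚类过成时聚类器更加可靠,……" ("the clusterer is more reliable when labeled data takes part in the clustering process, ..."), contains a typo: "过成" should be "过程" ("process").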

Add benchmark performance

It would be useful to also see the performance of each SSL model against the purely supervised backbone run on the labeled data.

For example, TSVM vs pure SVM:

import numpy as np
from LAMDA_SSL.Dataset.Tabular.BreastCancer import BreastCancer

dataset = BreastCancer(labeled_size=0.1, stratified=True, shuffle=True)
labeled_X = dataset.labeled_X
labeled_y = dataset.labeled_y
unlabeled_X = dataset.unlabeled_X
unlabeled_y = dataset.unlabeled_y

from sklearn import preprocessing

pre_transform = preprocessing.StandardScaler()
pre_transform.fit(np.vstack([labeled_X, unlabeled_X]))
labeled_X = pre_transform.transform(labeled_X)
unlabeled_X = pre_transform.transform(unlabeled_X)

from LAMDA_SSL.Algorithm.Classification.TSVM import TSVM

# I tried a range of Cl and Cu values, starting from 15 and 0.0001, then gradually
# increasing Cu and decreasing Cl. It didn't seem to make a difference.
model = TSVM(Cl=1, Cu=1, kernel="linear")

model.fit(X=labeled_X, y=labeled_y, unlabeled_X=unlabeled_X)
pred_y = model.predict()

from LAMDA_SSL.Evaluation.Classifier.Accuracy import Accuracy

score = Accuracy().scoring(unlabeled_y, pred_y)
print(f"SSL TSVM score: {score}")
#> SSL TSVM score: 0.9609375

# Compare with pure SVM
from sklearn import svm
model_sl = svm.SVC()
model_sl.fit(labeled_X, labeled_y)
pred_sl = model_sl.predict(unlabeled_X)
score_sl = Accuracy().scoring(unlabeled_y, pred_sl)
print(f"SL SVM score: {score_sl}")
#> SL SVM score: 0.955078125
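Both printed scores are exact binary fractions, so the gap between them can be expressed as a number of samples. A small worked check, assuming the unlabeled split contains 512 samples, the smallest size for which both scores are whole counts:

```python
from fractions import Fraction
from math import lcm  # Python 3.9+

# Scores as printed above; both floats are exact binary fractions.
tsvm_score = Fraction(0.9609375)   # 123/128
svm_score = Fraction(0.955078125)  # 489/512

# Smallest evaluation-set size for which both scores are whole counts.
n = lcm(tsvm_score.denominator, svm_score.denominator)
tsvm_correct = tsvm_score * n
svm_correct = svm_score * n

print(n, tsvm_correct, svm_correct, tsvm_correct - svm_correct)  # 512 492 489 3
```

Under that assumption, TSVM classifies 3 more of the unlabeled samples correctly than the supervised SVM.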

Pi Model for MNIST

Because MNIST (and MNIST-like) datasets have only a single colour channel, we get errors. It looks like the network implementations aren't set up to handle single-channel inputs. Could you parameterise them so that they can be used with MNIST? Ideally we could just write:

# This:
model = PiModel(channels=1)

# Or this:
model = PiModel(shape=(28,28,1))
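Until the networks accept a channel argument, one common workaround is to give grayscale images three identical channels so they match the RGB layout the networks appear to expect. A minimal NumPy sketch, assuming the images are held as an `(N, 28, 28)` or `(N, 28, 28, 1)` array in channels-last layout; `to_three_channels` is a hypothetical helper, not a LAMDA-SSL function, and it does not address whether a given network also expects a different spatial size (e.g. 32×32).

```python
import numpy as np


def to_three_channels(images):
    """Replicate a single grayscale channel into three identical channels.

    images: array of shape (N, H, W) or (N, H, W, 1).
    Returns an array of shape (N, H, W, 3) in channels-last layout.
    """
    images = np.asarray(images)
    if images.ndim == 4 and images.shape[-1] == 1:
        images = images[..., 0]  # drop the trailing singleton channel axis
    if images.ndim != 3:
        raise ValueError(f"expected (N, H, W) or (N, H, W, 1), got {images.shape}")
    # Add a channel axis, then copy it three times along that axis.
    return np.repeat(images[..., np.newaxis], 3, axis=-1)


mnist_like = np.zeros((5, 28, 28), dtype=np.uint8)
print(to_three_channels(mnist_like).shape)  # (5, 28, 28, 3)
```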

How to run my three-category tabular data

Thanks for the great work; I need your help.

If I want to solve a three-class problem, which code should I modify? For example, suppose the BreastCancer dataset had three classes. Without modifying any code, I found that the model only predicts the first two classes, as the confusion matrix shows.

Result/Co_Training_BreastCancer.txt:
accuracy 0.324468085106383
precision 0.2598727091480715
Recall 0.3464646464646464
F1 0.2306878306878307
Confusion_matrix [[0.16666667 0.83333333 0. ]
[0.12727273 0.87272727 0. ]
[0.12727273 0.87272727 0. ]]
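The reported Recall can be recovered from the confusion matrix, which confirms how to read it. A small check, assuming the matrix is row-normalised with rows as true classes and columns as predicted classes (each row sums to 1, and the mean of the diagonal reproduces the reported Recall of about 0.3465):

```python
# Row-normalised confusion matrix from Co_Training_BreastCancer.txt:
# row i = true class i, column j = predicted class j.
confusion = [
    [0.16666667, 0.83333333, 0.0],
    [0.12727273, 0.87272727, 0.0],
    [0.12727273, 0.87272727, 0.0],
]

# In a row-normalised matrix, per-class recall is the diagonal entry.
per_class_recall = [confusion[i][i] for i in range(len(confusion))]
macro_recall = sum(per_class_recall) / len(per_class_recall)

# Column sums: an all-zero column means that class is never predicted.
predicted_mass = [sum(row[j] for row in confusion) for j in range(len(confusion))]

print(round(macro_recall, 6))  # 0.346465
print(predicted_mass[2])       # 0.0
```

The all-zero third column shows that no sample is ever assigned to the third class.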
