
dcam's Issues

Unable to run dInceptionModel or dResNetBaseline

Running dInceptionModel or dResNetBaseline fails with the error RuntimeError: Expected 2D (unbatched) or 3D (batched) input to conv1d, but got input of size: [<batch size>, <nr of dimensions>, <nr of dimensions>, <time series length>]. I tried this with the FingerMovements dataset.

Here is a minimal reproducible example following the flow of script_exp.py:

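# pickle, numpy and torch.utils.data are used below; import them explicitly
# instead of relying on the star imports to provide them
import pickle

import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils import data

from CNN_models import *
from DCAM import *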

with open("FingerMovements.pickle",'rb') as f:
    X,y = pickle.load(f)
    
# Convert one UCR/UEA instance (one series per dimension) into a NumPy array of shape (n_dims, length)
def generate_list_instance(x):
    res = []
    for i in range(len(x)):
        res.append(list(x[i]))
    return np.array(res)

dict_label = {}
count = 0
for val in set(y.values):
    dict_label[val] = count
    count += 1

all_class_all = []
all_label = []
for i in range(len(X)):
    all_class_all.append(generate_list_instance(X.values[i]))
    all_label.append(dict_label[y.values[i]])

original_length = len(all_class_all[0][0])
num_classes = len(set(y.values)) 
original_dim = len(all_class_all[0])
nb_instance = len(all_class_all) 

all_class, all_class_test, label, label_test = train_test_split(all_class_all, all_label,stratify=all_label, test_size=1-0.8,random_state=11081994)

# Generate the cube input C for the d-based models (i.e., dCNN, dResNet, and dInceptionTime)
def gen_cube(instance):
    result = []
    for i in range(len(instance)):
        result.append([instance[(i+j)%len(instance)] for j in range(len(instance))])
    return result 

x = np.array([gen_cube(acl) for acl in all_class])
dataset_mat = TSDataset(x,label)
dataloader_cl1 = data.DataLoader(dataset_mat, batch_size=32, shuffle=True)

x = np.array([gen_cube(acl) for acl in all_class_test])
dataset_mat_test = TSDataset(x,label_test)
dataloader_cl1_test = data.DataLoader(dataset_mat_test, batch_size=1, shuffle=True)

# This is dInceptionTime
modelarch = dInceptionModel(num_blocks=3, in_channels=original_dim, out_channels=64,
                           bottleneck_channels=64, kernel_sizes=[10,20,40],
                           use_residuals=True, num_pred_classes=num_classes).to('cuda')

# dResNetBaseline gives the same error
# modelarch = dResNetBaseline(original_dim,mid_channels=128,num_pred_classes=num_classes).to('cuda')

model = ModelCNN(modelarch,'cuda')

model.train(num_epochs=70,dataloader_cl1=dataloader_cl1,dataloader_cl1_test=dataloader_cl1_test)

When I replace modelarch with modelarch = ConvNet2D(original_length,original_dim,original_dim,num_classes).to('cuda'), as in Synthetic_experiment_DCAM.ipynb, training works fine.
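The traceback is consistent with a tensor-rank mismatch. gen_cube turns each (D, T) instance into a (D, D, T) cube, so every DataLoader batch is 4D, (B, D, D, T). PyTorch's conv1d only accepts 2D (unbatched) or 3D (batched) input, while conv2d accepts 3D or 4D, which would explain why a model that passes the cube to a 1D convolution fails whereas ConvNet2D (presumably built on 2D convolutions, judging by its name) trains. The sketch below only checks shapes with NumPy; gen_cube_np, accepts and ACCEPTED_RANKS are hypothetical names, the sizes are illustrative, and it does not identify which layer inside dInceptionModel or dResNetBaseline makes the conv1d call.

```python
import numpy as np

# Input ranks accepted by PyTorch convolutions (torch.nn.Conv1d / Conv2d docs):
#   conv1d: (C_in, L) unbatched or (N, C_in, L) batched        -> 2D or 3D
#   conv2d: (C_in, H, W) unbatched or (N, C_in, H, W) batched  -> 3D or 4D
ACCEPTED_RANKS = {"conv1d": (2, 3), "conv2d": (3, 4)}


def gen_cube_np(instance):
    """Hypothetical NumPy equivalent of gen_cube: slice i holds the dimensions shifted by i."""
    instance = np.asarray(instance)  # shape (D, T)
    n_dims = instance.shape[0]
    # np.roll(instance, -i, axis=0)[j] == instance[(i + j) % D], matching gen_cube
    return np.stack([np.roll(instance, -i, axis=0) for i in range(n_dims)])  # (D, D, T)


def accepts(op, shape):
    """True if PyTorch's `op` ("conv1d" or "conv2d") accepts an input of this rank."""
    return len(shape) in ACCEPTED_RANKS[op]


# Illustrative sizes (hypothetical): batch of 32, D = 4 dimensions, T = 10 time steps.
rng = np.random.default_rng(0)
batch = np.stack([gen_cube_np(rng.normal(size=(4, 10))) for _ in range(32)])
print(batch.shape)                                                      # (32, 4, 4, 10)
print(accepts("conv1d", batch.shape), accepts("conv2d", batch.shape))  # False True
```

The rank check mirrors the wording of the error message itself, so if a shape printed just before the first convolution of dInceptionModel or dResNetBaseline is 4D, that layer is where the mismatch happens.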

ablation/occlusion vs dCAM

Hi, thank you for your paper! I was wondering whether you had any thoughts on why dCAM was not compared with a Grad-CAM channel ablation/occlusion approach (setting one or a combination of input channels to zero). Channel ablation seems like a simple way to extend Grad-CAM to multivariate time series, so I am wondering whether there is a reason dCAM was not compared against it. A recent paper tried ablation for Grad-CAM, but it did not compare with dCAM.
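For reference, the channel ablation/occlusion idea described above can be sketched as follows, assuming a model that returns a scalar class score for one (channels, length) series. This is plain input occlusion, not Grad-CAM itself; channel_ablation and the toy linear score_fn are hypothetical stand-ins for a trained classifier.

```python
import itertools

import numpy as np


def channel_ablation(x, score_fn, group_size=1):
    """Drop in class score when each group of `group_size` channels is zeroed.

    x        : array of shape (n_channels, length), one multivariate series.
    score_fn : callable mapping such an array to a scalar class score
               (hypothetical stand-in for a trained model's logit).
    """
    base = score_fn(x)
    drops = {}
    for group in itertools.combinations(range(x.shape[0]), group_size):
        occluded = np.array(x, dtype=float, copy=True)
        occluded[list(group), :] = 0.0  # occlude the chosen channel(s)
        drops[group] = base - score_fn(occluded)
    return drops


# Toy "model": weighted sum of per-channel means (channel 1 is ignored).
weights = np.array([1.0, 0.0, 2.0])
score_fn = lambda s: float(weights @ s.mean(axis=1))

x = np.ones((3, 4))
print(channel_ablation(x, score_fn))                # {(0,): 1.0, (1,): 0.0, (2,): 2.0}
print(channel_ablation(x, score_fn, group_size=2))  # {(0, 1): 1.0, (0, 2): 3.0, (1, 2): 2.0}
```

Larger drops mark channels the score depends on more. With group_size > 1 the number of forward passes grows combinatorially (C(n_channels, group_size) per series), which is one practical cost of ablating combinations of channels.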

Thanks for your feedback!
