Hi,
Thanks for your code for these methods. I have tried the MS-D-CNN on the LoDoPaB dataset.
Here is my training code:
# --- Setup: imports, command-line options and dataset loading -------------
# All imports come first: the original called get_standard_dataset() before
# importing it, which raised a NameError.
import argparse
import json
import os

import numpy as np
import torch  # noqa: F401  -- not referenced directly in this script
from dival import get_standard_dataset
from dival.measure import PSNR
from dival.util.plot import plot_images

try:
    from FBPMSDReconstructor import FBPMSDNetReconstructor
    MSD_PYTORCH_AVAILABLE = True
except ImportError:
    # Checked again right before the reconstructor is built.
    MSD_PYTORCH_AVAILABLE = False

IMPL = 'astra_cuda'
RESULTS_PATH = '/data0/ct_logs/msdnet'

NOISE_SETTING_DEFAULT = 'gaussian_noise'
NUM_ANGLES_DEFAULT = 50
METHOD_DEFAULT = 'fbpmsdnet'

# Parse the options before touching the data, so that `--help` does not
# trigger a dataset load.
parser = argparse.ArgumentParser()
parser.add_argument('--noise_setting', type=str, default=NOISE_SETTING_DEFAULT,
                    help="'gaussian_noise' or 'scattering'; only used to "
                         "build the run name")
parser.add_argument('--num_angles', type=int, default=NUM_ANGLES_DEFAULT,
                    help='50, 10, 5 or 2; not used by this script')
parser.add_argument('--method', type=str, default=METHOD_DEFAULT,
                    help="'learnedpd', 'fbpunet', 'fbpmsdnet' or 'cinn'; "
                         "only used to build the run name")
options = parser.parse_args()
noise_setting = options.noise_setting
num_angles = options.num_angles
method = options.method
name = 'lodopab_{}_{}'.format(noise_setting, method)

# Load LoDoPaB exactly once. The original loaded it twice and also built a
# ray transform that was immediately overwritten later.
dataset = get_standard_dataset('lodopab', impl=IMPL)
test_data = dataset.get_data_pairs('test', 100)
# Mean / standard-deviation statistics of the FBP reconstructions ('*_fbp',
# presumably the network inputs) and of the ground-truth images ('*_gt').
# Keyed by noise setting, then (presumably) by number of projection angles.
# NOTE(review): not referenced anywhere in this script; presumably meant for
# input/output normalisation of the reconstructor -- confirm.

# (mean_gt, std_gt) shared by the 'noisefree' and 'gaussian_noise' settings.
_GT_STATS_ALL_SLICES = (0.0018248517968347585, 0.0020251920919838714)
# (mean_gt, std_gt) for 'scattering'. These differ from the 'gaussian_noise'
# values because only a subset of slices is used.
_GT_STATS_SCATTERING = (0.002007653630624356, 0.0019931366497635745)


def _stats_entry(mean_fbp, std_fbp, gt_stats):
    """Return one record with keys mean_fbp, std_fbp, mean_gt and std_gt."""
    mean_gt, std_gt = gt_stats
    return {
        'mean_fbp': mean_fbp,
        'std_fbp': std_fbp,
        'mean_gt': mean_gt,
        'std_gt': std_gt,
    }


FBP_DATASET_STATS = {
    'noisefree': {
        2: _stats_entry(0.0020300781237049294, 0.0036974098858769781,
                        _GT_STATS_ALL_SLICES),
        5: _stats_entry(0.0018914765285141003, 0.0027988724415204552,
                        _GT_STATS_ALL_SLICES),
        10: _stats_entry(0.0018791806499857538, 0.0023355593815585413,
                         _GT_STATS_ALL_SLICES),
        50: _stats_entry(0.0018856220845133943, 0.002038545754978578,
                         _GT_STATS_ALL_SLICES),
    },
    'gaussian_noise': {
        2: _stats_entry(0.0020300515246877825, 0.01135122820016111,
                        _GT_STATS_ALL_SLICES),
        5: _stats_entry(0.0018914835384669934, 0.0073404856822226593,
                        _GT_STATS_ALL_SLICES),
        10: _stats_entry(0.0018791781748714272, 0.0053367740312729459,
                         _GT_STATS_ALL_SLICES),
        50: _stats_entry(0.0018856252771456445, 0.0029598508235758759,
                         _GT_STATS_ALL_SLICES),
    },
    'scattering': {
        2: _stats_entry(0.68570249744436962, 1.3499668155231217,
                        _GT_STATS_SCATTERING),
        5: _stats_entry(0.67324839540841908, 0.99012416989800478,
                        _GT_STATS_SCATTERING),
        10: _stats_entry(0.66960775275347806, 0.80318946689776671,
                         _GT_STATS_SCATTERING),
        50: _stats_entry(0.67173917657611049, 0.6794825395874754,
                         _GT_STATS_SCATTERING),
    },
}
# --- Build the FBP + MS-D-Net reconstructor, save its settings and train ---
if not MSD_PYTORCH_AVAILABLE:
    # An `assert` would be stripped under `python -O`, after which the script
    # would die later with an unhelpful NameError.
    raise ImportError(
        "FBPMSDNetReconstructor could not be imported from module "
        "'FBPMSDReconstructor'; make sure it is on the Python path.")

ray_trafo = dataset.ray_trafo

# save_hyper_params() below is given a path inside RESULTS_PATH, so the
# directory must exist before the first write.
os.makedirs(RESULTS_PATH, exist_ok=True)

reconstructor = FBPMSDNetReconstructor(
    ray_trafo,
    hyper_params={
        'depth': 100,
        'width': 1,
        'dilations': (1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
        'lr': 0.001,
        'batch_size': 1,
        'epochs': 50,
        'data_augmentation': True,
        'scheduler': 'none',
    },
    save_best_learned_params_path=os.path.join(RESULTS_PATH, name),
    log_dir=os.path.join(RESULTS_PATH, name),
    num_data_loader_workers=0,
)
reconstructor.save_hyper_params(
    os.path.join(RESULTS_PATH, '{}_hyper_params.json'.format(name)))
print("start training: '{}'".format(name))
print('hyper_params = {}'.format(
    json.dumps(reconstructor.hyper_params, indent=1)))
reconstructor.train(dataset)
# --- Evaluate on the test pairs and plot a few reconstructions -------------
import matplotlib.pyplot as plt

NUM_PLOTTED_SAMPLES = 3

recos = []
gts = []
psnrs = []
for obs, gt in test_data:
    reco = reconstructor.reconstruct(obs)
    recos.append(reco)
    # Keep the very ground truth the PSNR was computed against, so the plot
    # below shows matching pairs.
    gts.append(gt)
    psnrs.append(PSNR(reco, gt))
print('mean psnr: {:f}'.format(np.mean(psnrs)))

# Title the plots with the actual reconstructor class. The original
# hard-coded 'CINNReconstructor', which is not the model trained here.
reco_title = type(reconstructor).__name__
# zip() plus slicing also copes with fewer than NUM_PLOTTED_SAMPLES pairs.
for i, (reco, gt, psnr) in enumerate(zip(recos[:NUM_PLOTTED_SAMPLES],
                                         gts, psnrs)):
    _, ax = plot_images([reco, gt], fig_size=(10, 4))
    ax[0].set_xlabel('PSNR: {:.2f}'.format(psnr))
    ax[0].set_title(reco_title)
    ax[1].set_title('ground truth')
    ax[0].figure.suptitle('test sample {:d}'.format(i))
    plt.show()
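But it raises an error at this line in `odl_fourier_transform.py`; the error is shown below.
The original code at that line was `torch.rfft(x_preproc, 1)`. I changed it to `torch.fft.rfft(x_preproc, dim=-1)` because I use PyTorch 1.9, where `torch.rfft` no longer exists.
(Note: `torch.rfft(x, 1)` returned a real tensor whose last dimension of size 2 holds the real and imaginary parts, whereas `torch.fft.rfft` returns a complex tensor, so the drop-in equivalent is `torch.view_as_real(torch.fft.rfft(x_preproc, dim=-1))`.)
Is this modification or my dataset loading causing the error?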
Could you also provide the MS-D-CNN training script for the LoDoPaB dataset? Many thanks!