
rl-medical's Introduction

RL-Medical

Deep Reinforcement Learning (DRL) agents applied to medical images

Examples

Installation

Dependencies

tensorpack-medical requires:

User installation

pip install -U git+https://github.com/amiralansary/rl-medical.git

Development

New contributors of any experience level are very welcome.

Source code

You can clone the latest version of the source code with the command:

git clone https://github.com/amiralansary/rl-medical.git

Citation

If you use this code in your research, please cite these papers:

@article{alansary2019evaluating,
  title={{Evaluating Reinforcement Learning Agents for Anatomical Landmark Detection}},
  author={Alansary, Amir and Oktay, Ozan and Li, Yuanwei and Le Folgoc, Loic and
          Hou, Benjamin and Vaillant, Ghislain and Kamnitsas, Konstantinos and
          Vlontzos, Athanasios and Glocker, Ben and Kainz, Bernhard and Rueckert, Daniel},
  journal={Medical Image Analysis},
  year={2019},
  publisher={Elsevier}
}

@inproceedings{alansary2018automatic,
  title={Automatic view planning with multi-scale deep reinforcement learning agents},
  author={Alansary, Amir and Le Folgoc, Loic and Vaillant, Ghislain and Oktay, Ozan and Li, Yuanwei and
  Bai, Wenjia and Passerat-Palmbach, Jonathan and Guerrero, Ricardo and Kamnitsas, Konstantinos and Hou, Benjamin and others},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={277--285},
  year={2018},
  organization={Springer}
}

rl-medical's People

Contributors

amiralansary, brdav, crypdick, ghisvail, gml16


rl-medical's Issues

Help

Hello, author. When I run DQN.py in the SingleAgent folder, I get the error "sitk::ERROR: Unable to open "./data/images/ADNI_002_S_0816_MR_MPR__GradWarp__B1_Correction__N3__Scaled_Br_20070217005829488_S18402_I40731_Normalized_to_002_S_0295.nii.gz" for reading." How should I deal with this issue?
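SimpleITK reports this when the path it is given does not point to a readable file, for example because the ADNI image was never downloaded or because the relative path "./data/images/..." is resolved from a different working directory. A quick check is to verify every entry in the filename list before training. The sketch below assumes the list holds one path per line, parsed the same way as the loader line quoted in the FileNotFoundError traceback further down (line.split('\n')[0]); find_missing_files is a hypothetical helper, not part of rl-medical.

```python
import os

def find_missing_files(list_path, base_dir="."):
    """Return the entries of a file list (one path per line) that are not existing files.

    Relative entries such as "./data/images/..." are resolved against base_dir,
    which should be the directory DQN.py is launched from.
    """
    missing = []
    with open(list_path) as handle:
        for line in handle:
            entry = line.split("\n")[0]  # same per-line parsing as the repo's loader
            path = os.path.join(base_dir, entry)  # absolute entries override base_dir
            if not os.path.isfile(path):
                missing.append(path)
    return missing
```

For example, running print(find_missing_files("./data/filenames/image_files.txt")) from the directory containing DQN.py would list every image the environment will fail to open; an empty list means the problem lies elsewhere (for instance, an unreadable or corrupted file).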

Input data format

Hello, I am Sky Kim.

I am fascinated by your project, which finds the mid-sagittal plane in a 3D volume!
However, when I tried to run the source code, I realized that your data is not public.
If I want to run this code with my own input data, how should I format the data and 'list_of_train_filenames.txt'?
Please let me know the sample format of the train and test data.
Thanks in advance.
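The traceback in a later issue shows how the list files are read: [line.split('\n')[0] for line in open(files_list)], i.e. a plain text file with one path per line. A minimal sketch for producing such a list from your own data follows; write_file_list and read_file_list are hypothetical helpers, the glob pattern you pass is up to you, and the assumption that image and landmark lists (image_files.txt and landmark_files.txt in the commands quoted elsewhere on this page) must be in the same order is not confirmed by the source. The format of the landmark files themselves is not covered here.

```python
import glob

def write_file_list(pattern, list_path):
    """Write every file matching pattern to list_path, one path per line.

    Sorting keeps the order reproducible, which matters if an image list and
    a landmark list have to line up entry by entry (an assumption).
    """
    paths = sorted(glob.glob(pattern))
    with open(list_path, "w") as handle:
        handle.writelines(path + "\n" for path in paths)
    return paths

def read_file_list(list_path):
    """Parse a list file the same way as the loader in the repo's traceback."""
    with open(list_path) as handle:
        return [line.split("\n")[0] for line in handle]
```

Reading the file back with read_file_list is a cheap way to confirm the environment will see exactly the paths you intended.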

too slow training speed

My current environment is:
Win7 x64
Nvidia GeForce GTX 1080 (8 GB)
CUDA 9.0
cuDNN 7.0.5
tensorflow-gpu (1.6.0)
tensorpack (0.8.0)
gym (0.12.1)

I used the example data for training:
\tensorpack-medical\examples\LandmarkDetection\DQN\data\filenames\image_files
\tensorpack-medical\examples\LandmarkDetection\DQN\data\filenames\landmark_files

Because of the GPU memory limit, I used the parameter:
BATCH_SIZE = 24

and the following GPU and CPU settings:
mem_fraction = 0.8
# conf = tf.ConfigProto(log_device_placement=True)
conf = tf.ConfigProto()
# conf.allow_soft_placement = True
conf.intra_op_parallelism_threads = 6
conf.inter_op_parallelism_threads = 6
conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
conf.gpu_options.allow_growth = True

To exclude the effect of data loading, I used FakeData:

dataflow = FakeData([[BATCH_SIZE,45,45,45,5],[BATCH_SIZE],[BATCH_SIZE],[BATCH_SIZE]],size=1000,random=False, dtype=['uint8','float32','int8','bool'])

and a minimal training configuration:

return TrainConfig(
    data=QueueInput(dataflow),
    model=Model(),
    callbacks=[],
    # steps_per_epoch=10,
    steps_per_epoch=10,
    max_epoch=1000,
    session_config=conf,
)

The training speed is 28 seconds per iteration.

Even when I reduce the model complexity (by commenting out the Conv3D and MaxPooling3D layers):

with argscope(Conv3D, nl=PReLU.symbolic_function, use_bias=True):
    # core layers of the network
    conv = (LinearWrap(image)
            .Conv3D('conv0', out_channel=32,
                    kernel_shape=[5,5,5], stride=[1,1,1])
            .MaxPooling3D('pool0',16)
            # .Conv3D('conv1', out_channel=32,
            #         kernel_shape=[5,5,5], stride=[1,1,1])
            # .MaxPooling3D('pool1',2)
            # .Conv3D('conv2', out_channel=64,
            #         kernel_shape=[4,4,4], stride=[1,1,1])
            # .MaxPooling3D('pool2',2)
            # .Conv3D('conv3', out_channel=64,
            #         kernel_shape=[3,3,3], stride=[1,1,1])
            )

the training speed is still 22 seconds per iteration.

This is about 100x slower than your reported training speed (around 3-4 it/s using the default large architecture on a GTX 1080).

I would like to know why this happens, and I would appreciate any suggestions for reducing the training time.
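The "100x" estimate can be checked with simple arithmetic: seconds per iteration multiplied by the reference iterations per second gives the slowdown factor. A small sketch, assuming both figures count the same unit of work per iteration; slowdown_factor is a hypothetical helper.

```python
def slowdown_factor(seconds_per_iter, reference_iters_per_sec):
    """How many times slower a run is than a reference throughput."""
    # (seconds / iteration) * (iterations / second) is dimensionless:
    # it is the ratio between the reference speed and the observed speed.
    return seconds_per_iter * reference_iters_per_sec

for observed in (28, 22):  # seconds per iteration reported above
    low, high = slowdown_factor(observed, 3), slowdown_factor(observed, 4)
    print(f"{observed} s/iter is {low}x to {high}x slower than 3-4 it/s")
```

This prints 84x to 112x for the full model and 66x to 88x for the reduced one. Since removing most of the convolution layers barely changes the factor, the network size alone does not seem to explain the gap in this setup.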

Upgrade to latest tensorpack dependency: DQN.py returning a 'NotImplementedError'

Thanks for your project. I ran into some trouble when trying to reproduce your 'LandmarkDetection' example.

  1. When I run the command with your default settings, I get the error below, which I cannot solve. My environment is:
  • python 3.5.2
  • tensorflow 1.13.1 cpu version
  • tensorpack 0.9.4
  • Ubuntu 16.04.6 LTS
# command
python3 DQN.py --task train --algo DQN --files './data/filenames/image_files.txt' './data/filenames/landmark_files.txt'
# error
Traceback (most recent call last):
  File "DQN.py", line 263, in <module>
    launch_train_with_config(config, SimpleTrainer())
  File "/home/ty/.local/lib/python3.5/site-packages/tensorpack/train/interface.py", line 90, in launch_train_with_config
    model.get_input_signature(), input,
  File "/home/ty/.local/lib/python3.5/site-packages/tensorpack/utils/argtools.py", line 200, in wrapper
    value = func(*args, **kwargs)
  File "/home/ty/.local/lib/python3.5/site-packages/tensorpack/graph_builder/model_desc.py", line 92, in get_input_signature
    return [TensorSpec(shape=p.shape, dtype=p.dtype, name=get_op_tensor_name(p.name)[0]) for p in inputs]
  File "/usr/lib/python3.5/contextlib.py", line 77, in __exit__
    self.gen.throw(type, value, traceback)
  File "/home/ty/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 5253, in get_controller
    yield g
  File "/usr/lib/python3.5/contextlib.py", line 77, in __exit__
    self.gen.throw(type, value, traceback)
  File "/home/ty/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 5061, in get_controller
    yield default
  File "/home/ty/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 5253, in get_controller
    yield g
  File "/usr/lib/python3.5/contextlib.py", line 77, in __exit__
    self.gen.throw(type, value, traceback)
  File "/home/ty/.local/lib/python3.5/site-packages/tensorflow/python/eager/context.py", line 415, in _mode
    yield
  File "/home/ty/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 5253, in get_controller
    yield g
  File "/home/ty/.local/lib/python3.5/site-packages/tensorpack/graph_builder/model_desc.py", line 86, in get_input_signature
    inputs = self.inputs()
  File "/home/ty/.local/lib/python3.5/site-packages/tensorpack/graph_builder/model_desc.py", line 116, in inputs
    raise NotImplementedError()
NotImplementedError
  2. According to a tensorpack issue, get_tf_version_number has been deprecated in tensorpack, and you should use get_tf_version_tuple in tensorpack-medical instead.
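The motivation for the tuple-based replacement is easy to see with plain strings. As far as I know, the deprecated function returned the version as a float, while get_tf_version_tuple returns the major and minor parts of tf.__version__ as integers. The stdlib sketch below assumes that behavior; version_tuple is a hypothetical stand-in, and real code should call tensorpack's own function.

```python
def version_tuple(version_string):
    """Return (major, minor) as integers, e.g. '1.13.1' -> (1, 13)."""
    major, minor = version_string.split(".")[:2]
    return int(major), int(minor)

# Why a tuple instead of a float: as floats, 1.13 < 1.2 numerically,
# even though TensorFlow 1.13 is the newer release.
print(float("1.13") < float("1.2"))                      # True (wrong ordering)
print(version_tuple("1.13.1") > version_tuple("1.2.0"))  # True (correct ordering)
```

Tuple comparison is element-wise, so (1, 13) correctly sorts after (1, 2), which is what version checks such as "TensorFlow >= 1.13" need.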

'function' object has no attribute 'symbolic_function'

Hi, when I try to run DQN.py inside examples/LandmarkDetection using the command:
python DQN.py --task play --algo DQN --gpu 0 --load data/models/DQN_multiscale_brain_mri_point_pc_ROI_45_45_45/model-600000 --files './data/filenames/image_files.txt'

I got the following error:

File "DQN.py", line 88, in _get_DQN_prediction
with argscope(Conv3D, nl=PReLU.symbolic_function, use_bias=True):
AttributeError: 'function' object has no attribute 'symbolic_function'

Can you help please?

Models

Hi Amir,

Is it possible to upload the u/s models used (CSP detection)?

LeakyReLU deprecated upstream

When I clone and run your code I get the following:

> python DQN.py --algo DQN --gpu 0
Traceback (most recent call last):
  File "DQN.py", line 29, in <module>
    from tensorpack import (PredictConfig, OfflinePredictor, get_model_loader, logger, TrainConfig, ModelSaver, PeriodicTrigger, ScheduledHyperParamSetter, ObjAttrParam, HumanHyperParamSetter, argscope, RunOp, LinearWrap, FullyConnected, LeakyReLU, PReLU, SimpleTrainer, launch_train_with_config)
ImportError: cannot import name 'LeakyReLU'

It looks like it got deprecated upstream. I solved the issue by removing that import and adding this line:

LeakyReLU = tf.nn.leaky_relu

Improve accuracy

Hi,

First of all, thank you so much for your great and useful code. I'm trying to apply it to my dataset of 37 pancreatic ducts (a small organ) with 2 landmarks. The training error is around 0.5, but the test error is around 16-20. I tried the default parameters and some changes to step_per_episode, batch size, target update frequency, and learning rate. Do you have any suggestions for narrowing the gap between these two errors?

Thanks in advance

How to resume old training

Hi, I'm starting to use your code and I'm having trouble while trying to resume a training session.
I searched for the warning that the code gives:

WRN If you want to resume old training, either use AutoResumeTrainConfig or correctly set the new starting_epoch yourself to avoid inconsistency.

Yet I still don't understand how to do it. It seems like it should be possible, but I did not manage to do it myself. Could you please shed some light on this problem or point me to where I can find an answer?

Thank you for your time in advance.
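The warning asks for starting_epoch to match the checkpoint being restored. Checkpoints are named after the global step (for example model-600000, as in the --load path of another issue), so the epoch to resume from can be derived from that step. A minimal stdlib sketch, assuming checkpoints are saved at epoch boundaries and steps_per_epoch is unchanged between runs; latest_checkpoint_step and resume_starting_epoch are hypothetical helpers.

```python
import re

# TensorFlow checkpoints are written as e.g. "model-600000.index" plus data shards.
CHECKPOINT_INDEX = re.compile(r"^model-(\d+)\.index$")

def latest_checkpoint_step(filenames):
    """Return the largest global step among 'model-<step>.index' files, or None."""
    steps = [int(m.group(1)) for m in map(CHECKPOINT_INDEX.match, filenames) if m]
    return max(steps) if steps else None

def resume_starting_epoch(global_step, steps_per_epoch):
    """Epoch number to pass as starting_epoch when resuming (epochs count from 1)."""
    # A checkpoint saved after k complete epochs has global_step == k * steps_per_epoch,
    # so training should continue at epoch k + 1.
    return global_step // steps_per_epoch + 1
```

The result would then go into TrainConfig as starting_epoch, alongside session_init=get_model_loader(...) pointing at the same checkpoint (get_model_loader is already imported by DQN.py). Alternatively, AutoResumeTrainConfig, named in the warning, is intended to do this bookkeeping itself when it finds an existing log directory.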

dataset

Hello Amir,
Could you tell us how to build a dataset for automatic view planning, since your dataset is not public? Thank you so much!

FileNotFoundError: [Errno 2] No such file or directory: 'list_of_test_filenames.txt'

Hi, I think this is a really cool project and I'm trying to use it in my research. I am trying to reproduce your results but I run into trouble because I don't have the training data:

 ~/bin/tensorpack-medical/examples/LandmarkDetection/DQN   master ●  python DQN.py --algo DQN --gpu 0
WARN: gym.spaces.Box autodetected dtype as <class 'numpy.uint8'>. Please provide explicit dtype.
Traceback (most recent call last):
  File "DQN.py", line 205, in <module>
    screen_dims=IMAGE_SIZE)
  File "/home/ubuntu/bin/tensorpack-medical/examples/LandmarkDetection/DQN/medical.py", line 157, in __init__
    self.files = filesListFetalUSLandmark(directory,files_list)
  File "/home/ubuntu/bin/tensorpack-medical/examples/LandmarkDetection/DQN/sampleTrain.py", line 321, in __init__
    self.files_list = [line.split('\n')[0] for line in open(files_list)]
FileNotFoundError: [Errno 2] No such file or directory: 'list_of_test_filenames.txt'

Where can I get the dataset? I see in your paper that you cite another paper's dataset, but I did not see a link to that dataset anywhere in that paper, either.

Modifying DQN model to accept 3D images

I've had some difficulties modifying your code to work directly on image stacks. Your RL model uses the past few 2D frames as channels and does 3D convolutions on that frame history. Instead, I want my agent to only see the current step (i.e. FRAME_HISTORY = 1) but for the inputs to be single-channel image stacks. I was hoping you could give me some insight.

DQN.py training demo code error

The example code works great for the 'eval' and 'play' tasks, but when I run the training example, I get errors such as "TypeError: step() missing 1 required positional argument: 'isOver'".

Here is the command that I used:
python DQN.py --task train --algo DQN --gpu 0 --files './data/filenames/image_files.txt' './data/filenames/landmark_files.txt'

Any help you could provide is greatly appreciated! I'm really excited about your published results.

About the success ratio for Automatic View Planning

Dear Amir Alansary,

May I ask what success ratio you achieved in training? After training for more than 48 hours (2,875,000 steps), I could only reach a success ratio of about 0.1, and the mean distance between plane parameters was more than 10. I noticed in your paper that you trained the model for 12-24 hours, so I would really like to know your success ratio in training. Could you help me?
Thank you very much.

Best,
Keyu

about ADNI dataset

After I applied for the ADNI dataset, I did not find anatomical landmark labels for the corresponding MRI scans. Could you please tell me whether you annotated the dataset used for training yourselves? If not, could you tell me how to get the label datasets?

A Question about medical.py

Dear Amir Alansary,

Thanks for your work. While reading your code, I came up with a question about the environment "medical.py", at line 460, in the definition of the function getBestLocation():
Why do you use "best_idx = best_qvalues.argmin()" instead of argmax() to get the location with the best Q-value?
I hope you can help me understand it. Thank you very much!

Best,
Keyu
