
medkit-learn's Issues

Example code not working

import medkit as mk
synthetic_dataset = mk.batch_generate(
    domain="Ward",
    environment="CRN",
    policy="LSTM",
    size=1000,
    test_size=200,
    max_length=10,
    scale=True)

This fails with the following error:

Traceback (most recent call last):
  File "<stdin>", line 8, in <module>
  File "/home/gtennenholtz/medkit-learn/medkit/api.py", line 58, in batch_generate
    env = env_dict[environment](dom)
  File "/home/gtennenholtz/medkit-learn/medkit/environments/CounterfactualRNN.py", line 118, in __init__
    self.model = CRN_env(domain)
  File "/home/gtennenholtz/medkit-learn/medkit/environments/CounterfactualRNN.py", line 15, in __init__
    self.lstm_layers = self.hyper["lstm_layers"]
KeyError: 'lstm_layers'
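The KeyError means the hyperparameter dictionary that CRN_env reads into self.hyper has no "lstm_layers" entry. A minimal sketch of checking for absent keys up front, so the failure names every missing hyperparameter at once instead of stopping at the first one; the helper name missing_hyperparams and the dictionary contents are hypothetical, not part of medkit:

```python
def missing_hyperparams(hyper, required):
    """Return the names in `required` that are absent from the `hyper` dict."""
    return [name for name in required if name not in hyper]


# Hypothetical stand-in for the dict CRN_env loads into self.hyper;
# its real contents are not shown in the issue.
hyper = {"hidden_size": 64}

# Reporting every absent key is easier to debug than a bare KeyError.
absent = missing_hyperparams(hyper, ["lstm_layers"])
if absent:
    print("hyperparameters missing:", absent)
```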

Also, running:

env = mk.live_simulate(domain="ICU", environment="SVAE")

gives an error:

Traceback (most recent call last):
  File "<stdin>", line 3, in <module>
  File "/home/gtennenholtz/medkit-learn/medkit/api.py", line 214, in live_simulate
    env = env_dict[environment](dom)
  File "/home/gtennenholtz/medkit-learn/medkit/environments/SequentialVAE.py", line 188, in __init__
    self.load_pretrained()
  File "/home/gtennenholtz/medkit-learn/medkit/bases/base_env.py", line 30, in load_pretrained
    self.model.load_state_dict(torch.load(path))
  File "/home/gtennenholtz/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1224, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SVAE_env:
        Missing key(s) in state_dict: "lstm.weight_ih_l0", "lstm.bias_ih_l0", "lstm.weight_hh_l0", "lstm.bias_hh_l0", "lstm.weight_ih_l1", "lstm.bias_ih_l1", "lstm.weight_hh_l1", "lstm.bias_hh_l1", "lstm.layers.0.cell.ih.weight", "lstm.layers.0.cell.ih.bias", "lstm.layers.0.cell.hh.weight", "lstm.layers.0.cell.hh.bias", "lstm.layers.1.cell.ih.weight", "lstm.layers.1.cell.ih.bias", "lstm.layers.1.cell.hh.weight", "lstm.layers.1.cell.hh.bias".
        Unexpected key(s) in state_dict: "lstm.ih.weight", "lstm.ih.bias", "lstm.hh.weight", "lstm.hh.bias".
        size mismatch for encoder.linear1.weight: copying a param with shape torch.Size([128, 24]) from checkpoint, the shape in current model is torch.Size([128, 37]).
        size mismatch for decoder.series_cont_mean.weight: copying a param with shape torch.Size([23, 128]) from checkpoint, the shape in current model is torch.Size([37, 128]).
        size mismatch for decoder.series_cont_mean.bias: copying a param with shape torch.Size([23]) from checkpoint, the shape in current model is torch.Size([37]).
        size mismatch for decoder.series_cont_lstd.weight: copying a param with shape torch.Size([23, 128]) from checkpoint, the shape in current model is torch.Size([37, 128]).
        size mismatch for decoder.series_cont_lstd.bias: copying a param with shape torch.Size([23]) from checkpoint, the shape in current model is torch.Size([37]).
        size mismatch for decoder.series_bin.weight: copying a param with shape torch.Size([1, 128]) from checkpoint, the shape in current model is torch.Size([0, 128]).
        size mismatch for decoder.series_bin.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([0]).
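The three kinds of message above come from a strict state_dict load: "Missing key(s)" are parameters the current model has but the checkpoint lacks, "Unexpected key(s)" are the reverse, and "size mismatch" means a key exists in both with different shapes. So the saved SVAE checkpoint and the current SVAE_env disagree both on LSTM parameter naming and on feature dimensions (24 vs 37 inputs, 23 vs 37 continuous outputs, 1 vs 0 binary outputs). A minimal pure-Python sketch of that classification, using a few entries copied from the traceback; diff_state_dicts is a hypothetical helper, and with PyTorch installed the inputs could be built as {k: tuple(v.shape) for k, v in sd.items()} from model.state_dict() and torch.load(path):

```python
from typing import Dict, List, Optional, Tuple

Shape = Optional[Tuple[int, ...]]


def diff_state_dicts(model: Dict[str, Shape], checkpoint: Dict[str, Shape]):
    """Classify keys the way a strict state_dict load reports them."""
    missing = sorted(k for k in model if k not in checkpoint)
    unexpected = sorted(k for k in checkpoint if k not in model)
    # (key, checkpoint shape, model shape) for keys present in both with different shapes
    mismatched: List[Tuple[str, Shape, Shape]] = [
        (k, checkpoint[k], model[k])
        for k in sorted(model)
        if k in checkpoint and checkpoint[k] != model[k]
    ]
    return missing, unexpected, mismatched


# Entries copied from the traceback; None marks shapes the traceback does not show.
model = {
    "lstm.weight_ih_l0": None,
    "encoder.linear1.weight": (128, 37),
    "decoder.series_bin.bias": (0,),
}
checkpoint = {
    "lstm.ih.weight": None,
    "encoder.linear1.weight": (128, 24),
    "decoder.series_bin.bias": (1,),
}

missing, unexpected, mismatched = diff_state_dicts(model, checkpoint)
print("missing:", missing)
print("unexpected:", unexpected)
for key, ckpt_shape, model_shape in mismatched:
    print(f"{key}: checkpoint {ckpt_shape} vs model {model_shape}")
```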
