
shb84 commented on September 23, 2024

It does indeed look like a bug on the JENN side. Thank you for finding it. It will be fixed in the next release.

When shuffle is False and there is more than one mini-batch, the indices are generated using numpy.arange, which is what yields the error shown. When shuffle=True, the indices are correctly cast to a list, which is suitable for logical comparison. The easy workaround, for now, is to set shuffle=True.
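The following is a minimal sketch of the kind of failure described above. It is not JENN's code. `make_minibatches` is a hypothetical helper, and the sketch assumes the error comes from comparing or truth-testing a NumPy index array. `==` on a NumPy array works elementwise, so using the result in an `if` statement raises `ValueError` when the batch has more than one element. Calling `.tolist()` on every batch, in both the shuffled and unshuffled branches, makes the comparison return a single bool.

```python
import numpy as np


def make_minibatches(m, batch_size, shuffle=False, seed=None):
    """Hypothetical helper: split sample indices 0..m-1 into mini-batches.

    Every batch is returned as a plain Python list, whether or not the
    indices were shuffled, so later comparisons behave the same in both modes.
    """
    indices = np.arange(m)
    if shuffle:
        np.random.default_rng(seed).shuffle(indices)  # shuffles in place
    return [indices[i:i + batch_size].tolist() for i in range(0, m, batch_size)]


# Why the cast matters: == on an ndarray is elementwise, so using the
# result in an `if` raises ValueError once there is more than one element.
batch_as_array = np.arange(4)
try:
    if batch_as_array == np.arange(4):
        pass
except ValueError as err:
    print("ndarray comparison:", err)

# The same comparison on lists yields a single bool.
print("list comparison:", make_minibatches(4, 4)[0] == [0, 1, 2, 3])
```

The design choice here is to normalize the type at the single point where batches are created, so no downstream code has to care whether shuffling was enabled.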

To answer the other question, the API shown in the user's example is actually that of the upstream JENN library. There is a separate SMT API (which simply maps to the JENN API under the hood). Here is an example from the SMT docs:

```python
import numpy as np
import matplotlib.pyplot as plt

from smt.surrogate_models import GENN

# Test function
def f(x):
    return x * np.sin(x)

def df_dx(x):
    return np.sin(x) + x * np.cos(x)

# Domain
lb = -np.pi
ub = np.pi

# Training data
m = 4
xt = np.linspace(lb, ub, m)
yt = f(xt)
dyt_dxt = df_dx(xt)

# Validation data
xv = lb + np.random.rand(30, 1) * (ub - lb)
yv = f(xv)
dyv_dxv = df_dx(xv)

# Instantiate
genn = GENN()

# Likely the only options a user will interact with
genn.options["hidden_layer_sizes"] = [6, 6]
genn.options["alpha"] = 0.1
genn.options["lambd"] = 0.1
genn.options["gamma"] = 1.0  # 1 = gradient-enhanced on, 0 = gradient-enhanced off
genn.options["num_iterations"] = 500
genn.options["is_backtracking"] = True

# Train
genn.load_data(xt, yt, dyt_dxt)
genn.train()

# Plot comparison
if genn.options["gamma"] == 1.0:
    title = "with gradient enhancement"
else:
    title = "without gradient enhancement"
x = np.arange(lb, ub, 0.01)
y = f(x)
y_pred = genn.predict_values(x)
fig, ax = plt.subplots()
ax.plot(x, y_pred)
ax.plot(x, y, "k--")
ax.plot(xv, yv, "ro")
ax.plot(xt, yt, "k+", mew=3, ms=10)
ax.set(xlabel="x", ylabel="y", title=title)
ax.legend(["Predicted", "True", "Test", "Train"])
plt.show()
```


Paul-Saves commented on September 23, 2024

@shb84, do you know what the problem is?


bpaul4 commented on September 23, 2024

@shb84, thank you for your answer and the example. In my original setup with the old (non-JENN) implementation, I had difficulty using the SMT load_data method to load a dataset with 6 inputs, 2 outputs, and a derivative set containing partials for both outputs: X is (6, 102), Y is (2, 102), and J is (2, 6, 102). The method seemed to expect J to have shape (n_x, n_m).

What is the proper way to have the SMT GENN training consider both sets of derivatives (dy1/dx and dy2/dx for all x)?


shb84 commented on September 23, 2024

For all SMT work, the correct format to use is the one found in the SMT docs.

For clarity, JENN is a separate library that does indeed use a different data format, but you can ignore that as an SMT user. The updated GENN module expects SMT-formatted data only. Refer to the docs for what that format is.
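As a sketch of what the format difference means for the data in the question: SMT's `set_training_values` takes `xt` of shape `(nt, nx)` and `yt` of shape `(nt, ny)`, one sample per row, whereas the arrays in the question store one sample per column (assumed here to follow JENN's convention, as the shapes suggest). `jenn_to_smt_layout` is a hypothetical helper, and its `(m, n_y, n_x)` gradient layout is only a convenient intermediate, not an SMT format; SMT itself takes gradients one input index at a time.

```python
import numpy as np


def jenn_to_smt_layout(X, Y, J):
    """Hypothetical helper: transpose one-sample-per-column arrays to one sample per row.

    Assumed input shapes (as in the question):
      X: (n_x, m), Y: (n_y, m), J: (n_y, n_x, m) with J[i, j, k] = dy_i/dx_j at sample k
    Returned shapes:
      xt: (m, n_x), yt: (m, n_y), dyt_dxt: (m, n_y, n_x)
    """
    xt = np.ascontiguousarray(X.T)
    yt = np.ascontiguousarray(Y.T)
    # Move the sample axis to the front; output and input axes keep their order.
    dyt_dxt = np.ascontiguousarray(np.transpose(J, (2, 0, 1)))
    return xt, yt, dyt_dxt


# Random arrays with the shapes from the question (6 inputs, 2 outputs, 102 samples)
rng = np.random.default_rng(0)
X = rng.random((6, 102))
Y = rng.random((2, 102))
J = rng.random((2, 6, 102))
xt, yt, dyt_dxt = jenn_to_smt_layout(X, Y, J)
print(xt.shape, yt.shape, dyt_dxt.shape)
```

The derivative for input `kx`, in the `(nt, ny)` shape SMT's `set_training_derivatives` documents, is then `dyt_dxt[:, :, kx]`.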

Under the hood, the load_data method is implemented as follows:

```python
def load_data(self, xt, yt, dyt_dxt=None):
    """Load all training data into surrogate model in one step.

    :param model: SurrogateModel object for which to load training data
    :param xt: smt data points at which response is evaluated
    :param yt: response at xt
    :param dyt_dxt: gradient at xt
    """
    m, n_x = (xt.size, 1) if xt.ndim <= 1 else xt.shape
    m, n_y = (yt.size, 1) if yt.ndim <= 1 else yt.shape

    # Reshape arrays
    xt = xt.reshape((m, n_x))
    yt = yt.reshape((m, n_y))

    # Load values
    self.set_training_values(xt, yt)

    # Load partials
    if dyt_dxt is not None:
        dyt_dxt = dyt_dxt.reshape((m, n_x))
        for i in range(n_x):
            self.set_training_derivatives(xt, dyt_dxt[:, i].reshape((m, 1)), i)
```

It is simply a convenience method that feeds the SMT API methods set_training_values and set_training_derivatives in one step. Hence, you can always fall back on those methods to load your data if you keep having trouble.
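Note that the quoted `load_data` reshapes `dyt_dxt` to `(m, n_x)`, which only has room for the gradients of a single output; with 2 outputs the reshape fails. A minimal sketch of falling back on the two setter methods follows, assuming the GENN version in use accepts multi-output training data. `load_with_gradients` and `RecordingModel` are hypothetical names; `RecordingModel` only records calls so the sketch runs without SMT, and with SMT installed a `GENN()` instance would be passed in its place. The gradient layout `(m, n_y, n_x)` is a choice made for this sketch.

```python
import numpy as np


def load_with_gradients(model, xt, yt, dyt_dxt):
    """Hypothetical helper: load values and all output gradients via the SMT setters.

    xt: (m, n_x), yt: (m, n_y), dyt_dxt: (m, n_y, n_x) with
    dyt_dxt[k, i, j] = dy_i/dx_j at sample k (layout chosen for this sketch).
    """
    m, n_x = xt.shape
    n_y = yt.shape[1]
    if dyt_dxt.shape != (m, n_y, n_x):
        raise ValueError(f"expected dyt_dxt of shape {(m, n_y, n_x)}, got {dyt_dxt.shape}")
    model.set_training_values(xt, yt)
    # One call per input index kx, passing d(all outputs)/dx_kx as an (m, n_y) array
    for kx in range(n_x):
        model.set_training_derivatives(xt, dyt_dxt[:, :, kx], kx)


class RecordingModel:
    """Stand-in that records calls, so the sketch runs without SMT installed."""

    def __init__(self):
        self.values = None
        self.derivatives = {}

    def set_training_values(self, xt, yt):
        self.values = (xt, yt)

    def set_training_derivatives(self, xt, dyt_dxt, kx):
        self.derivatives[kx] = dyt_dxt


m, n_x, n_y = 102, 6, 2  # sizes from the question
rng = np.random.default_rng(1)
xt = rng.random((m, n_x))
yt = rng.random((m, n_y))
dyt_dxt = rng.random((m, n_y, n_x))  # m * n_y * n_x = 1224 values

# The reshape in the quoted load_data has room for only m * n_x = 612 values
try:
    dyt_dxt.reshape((m, n_x))
except ValueError as err:
    print("load_data-style reshape fails:", err)

# Loading one input index at a time keeps the gradients of both outputs
model = RecordingModel()
load_with_gradients(model, xt, yt, dyt_dxt)
print(sorted(model.derivatives), model.derivatives[0].shape)
```

After loading through the setters, training would proceed with `genn.train()` as in the docs example.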

Please let me know if you still have issues. Better yet, if you can provide a self-contained example with the data that reproduces the issue, I'd be happy to get it running.


shb84 commented on September 23, 2024

@bpaul4, just circling back to see whether these answers have resolved your issue or whether you still need help.


bpaul4 commented on September 23, 2024

Hi @shb84, thank you for your help. I am studying the documentation and working through my example, and I will let you know if I have further questions. I believe this issue may be closed for now.
