Comments (6)
It does indeed look like a bug on the JENN side. Thank you for finding it. It will be fixed in the next release.
When `shuffle=False` and there is more than one mini-batch, the indices are generated using `numpy.arange`, which is what yields the error shown. When `shuffle=True`, the indices are correctly cast to a list suitable for logical comparison. The easy fix, for now, is to set `shuffle=True`.
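The failure mode is easy to reproduce in isolation. Below is a minimal standalone sketch (not the actual JENN code) of why array-typed indices break a logical comparison while list-typed indices do not:

```python
import numpy as np

m = 10
idx_array = np.arange(m)                   # sketch of the shuffle=False path
idx_list = list(np.random.permutation(m))  # sketch of the shuffle=True path

# A list compares against another list with a single boolean result:
print(idx_list == list(range(m)))  # plain bool

# A numpy array compares element-wise, so using the result in an `if`
# raises "The truth value of an array ... is ambiguous":
try:
    if idx_array == list(range(m)):
        pass
except ValueError as e:
    print("error:", e)
```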
To answer the other question, the API shown in the user's example is actually that of the upstream JENN library. There is a separate SMT API (which simply maps to the JENN API under the hood). Here is an example from the SMT docs:
```python
import numpy as np
import matplotlib.pyplot as plt

from smt.surrogate_models import GENN


# Test function
def f(x):
    import numpy as np  # need to repeat for sphinx_auto_embed
    return x * np.sin(x)


def df_dx(x):
    import numpy as np  # need to repeat for sphinx_auto_embed
    return np.sin(x) + x * np.cos(x)


# Domain
lb = -np.pi
ub = np.pi

# Training data
m = 4
xt = np.linspace(lb, ub, m)
yt = f(xt)
dyt_dxt = df_dx(xt)

# Validation data
xv = lb + np.random.rand(30, 1) * (ub - lb)
yv = f(xv)
dyv_dxv = df_dx(xv)

# Instantiate
genn = GENN()

# Likely the only options a user will interact with
genn.options["hidden_layer_sizes"] = [6, 6]
genn.options["alpha"] = 0.1
genn.options["lambd"] = 0.1
genn.options["gamma"] = 1.0  # 1 = gradient-enhanced on, 0 = gradient-enhanced off
genn.options["num_iterations"] = 500
genn.options["is_backtracking"] = True

# Train
genn.load_data(xt, yt, dyt_dxt)
genn.train()

# Plot comparison
if genn.options["gamma"] == 1.0:
    title = "with gradient enhancement"
else:
    title = "without gradient enhancement"
x = np.arange(lb, ub, 0.01)
y = f(x)
y_pred = genn.predict_values(x)
fig, ax = plt.subplots()
ax.plot(x, y_pred)
ax.plot(x, y, "k--")
ax.plot(xv, yv, "ro")
ax.plot(xt, yt, "k+", mew=3, ms=10)
ax.set(xlabel="x", ylabel="y", title=title)
ax.legend(["Predicted", "True", "Test", "Train"])
plt.show()
```
@shb84 do you know the problem?
@shb84 thank you for your answer, and the example. In my original setup with the old (non-JENN) implementation, I had difficulty using the SMT `load_data` method to load a dataset with 6 inputs, 2 outputs, and a derivative set containing partials for both outputs: X is (6, 102), Y is (2, 102), and J is (2, 6, 102). The method seemed to expect J to have shape (n_x, n_m).

What is the proper way to have SMT GENN training consider both sets of derivatives (dy1/dx and dy2/dx for all x)?
For all SMT work, the correct format to use is the one found in the SMT docs.
For clarity, JENN is a separate library that does indeed use a different data format, but you can ignore that as an SMT user. The updated GENN module expects SMT-formatted data only. Refer to the docs for what that format is.

Under the hood, the `load_data` method is implemented as follows:
```python
def load_data(self, xt, yt, dyt_dxt=None):
    """Load all training data into surrogate model in one step.

    :param model: SurrogateModel object for which to load training data
    :param xt: smt data points at which response is evaluated
    :param yt: response at xt
    :param dyt_dxt: gradient at xt
    """
    m, n_x = (xt.size, 1) if xt.ndim <= 1 else xt.shape
    m, n_y = (yt.size, 1) if yt.ndim <= 1 else yt.shape

    # Reshape arrays
    xt = xt.reshape((m, n_x))
    yt = yt.reshape((m, n_y))

    # Load values
    self.set_training_values(xt, yt)

    # Load partials
    if dyt_dxt is not None:
        dyt_dxt = dyt_dxt.reshape((m, n_x))
        for i in range(n_x):
            self.set_training_derivatives(xt, dyt_dxt[:, i].reshape((m, 1)), i)
```
It is simply a convenience method that feeds the SMT API methods `set_training_values` and `set_training_derivatives` in one step. Hence, you can always fall back on those methods to load your data if you keep having trouble.
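For a multi-output case like the one described, a sketch (untested against SMT, using hypothetical random data with the shapes from the question, and assuming `set_training_derivatives(xt, dyt_dxi, kx)` accepts an (m, n_y) array per input index) would be:

```python
import numpy as np

# Hypothetical data with the shapes from the question:
# X is (6, 102), Y is (2, 102), and J is (2, 6, 102),
# where J[j, i, k] = d y_j / d x_i at sample k.
m, n_x, n_y = 102, 6, 2
X = np.random.rand(n_x, m)
Y = np.random.rand(n_y, m)
J = np.random.rand(n_y, n_x, m)

# SMT expects samples along the first axis:
xt = X.T  # (102, 6)
yt = Y.T  # (102, 2)

# genn.set_training_values(xt, yt)
for i in range(n_x):
    dyt_dxi = J[:, i, :].T  # (102, 2): partials of both outputs w.r.t. x_i
    # genn.set_training_derivatives(xt, dyt_dxi, i)  # one call per input
    assert dyt_dxi.shape == (m, n_y)
```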
Please let me know if you still have issues or, better yet, if you can provide a self-contained example with the data that generates the issue, I'd be happy to get it running.
@bpaul4 Just circling back to see if these answers had resolved your issue or if you still needed help.
Hi @shb84, thank you for your help. I am studying the documentation and working through my example, and I will let you know if I have further questions. I believe this issue may be closed for now.