import jax.numpy as jnp
import numpy as np
N = 1000
P = 3
prob = np.random.uniform(0.01, 0.5, size=P)
h2g = 0.1
X = np.random.binomial(2, p=prob, size=(N, P))
b = np.random.normal(size=(P)) * np.sqrt(h2g / P)
y = X @ b + np.sqrt(1 - h2g) * np.random.normal(size=(N,))
import jaxopt as jopt
jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [11], in <module>
----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:151, in solve_normal_cg(matvec, b, ridge, init, **kwargs)
148 if ridge is not None:
149 _matvec = _make_ridge_matvec(_matvec, ridge=ridge)
--> 151 Ab = _rmatvec(matvec, b)
153 return jax.scipy.sparse.linalg.cg(_matvec, Ab, x0=init, **kwargs)[0]
File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:114, in _rmatvec(matvec, x)
112 def _rmatvec(matvec, x):
113 """Computes A^T x, from matvec(x) = A x, where A is square."""
--> 114 transpose = jax.linear_transpose(matvec, x)
115 return transpose(x)[0]
File ~/miniconda3/lib/python3.9/site-packages/jax/_src/api.py:2211, in linear_transpose(fun, reduce_axes, *primals)
2208 in_dtypes = map(dtypes.dtype, in_avals)
2210 in_pvals = map(pe.PartialVal.unknown, in_avals)
-> 2211 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals,
2212 instantiate=True)
2213 out_avals, _ = unzip2(out_pvals)
2214 out_dtypes = map(dtypes.dtype, out_avals)
File ~/miniconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:505, in trace_to_jaxpr(fun, pvals, instantiate)
503 with core.new_main(JaxprTrace) as main:
504 fun = trace_to_subjaxpr(fun, main, instantiate)
--> 505 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
506 assert not env
507 del main, fun, env
File ~/miniconda3/lib/python3.9/site-packages/jax/linear_util.py:166, in WrappedFun.call_wrapped(self, *args, **kwargs)
163 gen = gen_static_args = out_store = None
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
168 # Some transformations yield from inside context managers, so we have to
169 # interrupt them before reraising the exception. Otherwise they will only
170 # get garbage-collected at some later time, running their cleanup tasks only
171 # after this exception is handled, which can corrupt the global state.
172 while stack:
Input In [11], in <lambda>(x)
----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
File ~/miniconda3/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4196, in dot(a, b, precision)
4194 return lax.mul(a, b)
4195 if _max(a_ndim, b_ndim) <= 2:
-> 4196 return lax.dot(a, b, precision=precision)
4198 if b_ndim == 1:
4199 contract_dims = ((a_ndim - 1,), (0,))
File ~/miniconda3/lib/python3.9/site-packages/jax/_src/lax/lax.py:667, in dot(lhs, rhs, precision, preferred_element_type)
664 return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
665 precision=precision, preferred_element_type=preferred_element_type)
666 else:
--> 667 raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
668 lhs.shape, rhs.shape))
TypeError: Incompatible shapes for dot: got (1000, 3) and (1000,).