
ntga's People

Contributors

lionelmessi6410

ntga's Issues

Some questions about versions

Hi! Thank you for such great code. I'm having some problems running it. Could you please provide your versions of jaxlib, CUDA, and cuDNN? Thanks!

Which jax and jaxline versions did you use?

Hi, nice work, and thanks for sharing the code! Could you provide the details of your jax and jaxline versions? The CUDA and cuDNN versions would also help in debugging the build errors. Thanks for your help!
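Both version questions could be answered on the Python side with a short script like the sketch below. It assumes Python 3.8+ for `importlib.metadata` (older interpreters, such as the Python 3.6 environment in the traceback further down, would need the `importlib_metadata` backport instead). The helper name `installed_versions` and the package list are illustrative choices. CUDA and cuDNN are not Python packages, so their versions have to be reported separately.

```python
from importlib import metadata  # Python 3.8+; use the importlib_metadata backport on older Pythons


def installed_versions(packages=("jax", "jaxlib", "jaxline", "neural-tangents")):
    """Map each distribution name to its installed version, or None if it is absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Print a report that can be pasted into an issue.
for name, version in installed_versions().items():
    print(f"{name}: {version or 'not installed'}")
```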

'ShapedArray' object has no attribute 'val'

Hi, nice work, and thanks for sharing the code. When I ran the code, I encountered the following error:

jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ShapedArray' object has no attribute 'val'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

The detailed output is below:

Loading dataset...
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Building model...
Generating NTGA....
  0%|                                                                                                                                                                      | 0/78 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "generate_attack.py", line 228, in <module>
    main()
  File "generate_attack.py", line 195, in main
    nb_iter=args.nb_iter, clip_min=0, clip_max=1, batch_size=args.batch_size)
  File "/home/yila22/prj/ntga/attacks/projected_gradient_descent.py", line 82, in projected_gradient_descent
    fx_train_0, fx_test_0, eps, norm, clip_min, clip_max, targeted, batch_size)
  File "/home/yila22/prj/ntga/attacks/fast_gradient_method.py", line 66, in fast_gradient_method
    targeted)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 427, in cache_miss
    donated_invars=donated_invars, inline=inline)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/xla.py", line 593, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/xla.py", line 668, in _xla_callable
    fun, abstract_args, pe.debug_info_final(fun, "jit"))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1284, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 829, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 901, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 1997, in _vjp
    flat_fun, primals_flat, reduce_axes=reduce_axes)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 115, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 102, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 505, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "generate_attack.py", line 146, in adv_loss
    ntk_train_train = kernel_fn(x_train, x_train, 'ntk')
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 427, in cache_miss
    donated_invars=donated_invars, inline=inline)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 318, in process_call
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 195, in process_call
    f, in_pvals, app, instantiate=False)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 303, in partial_eval
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1072, in process_call
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2935, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2852, in kernel_fn_x1
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2844, in kernel_fn_kernel
    return _set_shapes(init_fn, kernel, out_kernel)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2780, in _set_shapes
    shape1 = _propagate_shape(init_fn, in_kernel.shape1)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in _propagate_shape
    out_shape = tree_map(lambda x: int(x.val), out_shape)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/tree_util.py", line 168, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/tree_util.py", line 168, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in <lambda>
    out_shape = tree_map(lambda x: int(x.val), out_shape)
jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ShapedArray' object has no attribute 'val'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "generate_attack.py", line 228, in <module>
    main()
  File "generate_attack.py", line 195, in main
    nb_iter=args.nb_iter, clip_min=0, clip_max=1, batch_size=args.batch_size)
  File "/home/yila22/prj/ntga/attacks/projected_gradient_descent.py", line 82, in projected_gradient_descent
    fx_train_0, fx_test_0, eps, norm, clip_min, clip_max, targeted, batch_size)
  File "/home/yila22/prj/ntga/attacks/fast_gradient_method.py", line 66, in fast_gradient_method
    targeted)
  File "generate_attack.py", line 146, in adv_loss
    ntk_train_train = kernel_fn(x_train, x_train, 'ntk')
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2935, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2852, in kernel_fn_x1
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2844, in kernel_fn_kernel
    return _set_shapes(init_fn, kernel, out_kernel)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2780, in _set_shapes
    shape1 = _propagate_shape(init_fn, in_kernel.shape1)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in _propagate_shape
    out_shape = tree_map(lambda x: int(x.val), out_shape)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in <lambda>
    out_shape = tree_map(lambda x: int(x.val), out_shape)
AttributeError: 'ShapedArray' object has no attribute 'val'
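The failing frame, `_propagate_shape` in `neural_tangents/stax.py`, calls `int(x.val)` on an abstract value. In the installed JAX that value is a `ShapedArray`, which carries only a shape and dtype, not a concrete `.val`. This most likely means the installed neural_tangents release expects a different JAX version, so the probable remedy is to install jax, jaxlib, and neural-tangents versions that were released together. The exact versions are not stated on this page. Purely as an illustration of the underlying idea, the sketch below computes output shapes abstractly with the public `jax.eval_shape` API. `eval_shape` returns `ShapeDtypeStruct` objects whose `.shape` is already a tuple of Python ints. The helper `propagate_shape` and the toy `apply_fn` are hypothetical and are not the library's code.

```python
import jax
import jax.numpy as jnp


def propagate_shape(apply_fn, in_shape, dtype=jnp.float32):
    """Abstractly evaluate apply_fn on an input of in_shape without running any FLOPs."""
    spec = jax.ShapeDtypeStruct(in_shape, dtype)
    out = jax.eval_shape(apply_fn, spec)
    # out.shape is a tuple of ints, so no `.val` access on an abstract value is needed.
    return tuple(int(d) for d in out.shape)


# Hypothetical layer: flatten each example, then apply a dense map to 5 outputs.
weights = jnp.ones((12, 5))


def apply_fn(x):
    return x.reshape(x.shape[0], -1) @ weights


print(propagate_shape(apply_fn, (4, 3, 4)))  # (4, 5)
```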

How to extract the learning curve of the mean GP-FNN predictor

Hello,

Thank you for sharing your code. I am looking at line 193 of plot_learning_dynamics.py and want to know how to extract the training and test accuracy of the mean GP-FNN predictor across time steps.

I am also looking at the contour_plot function at line 88 to learn how to use these training and test accuracy arrays, which you define as follows:

train_acc refers to the training accuracies for different weight variances and time steps.
test_acc refers to the test accuracies for different weight variances and time steps.

What do you mean by 'weight variances'? I will certainly look into the weight variances myself, but could you give some guidance in case I am heading in the wrong direction? Any information is appreciated. Many thanks.

Long
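The following is a hedged illustration of what arrays shaped like `train_acc` and `test_acc` could hold. Each row corresponds to one weight variance σ_w², presumably the variance used to initialize the weights, which changes the kernel. Each column corresponds to one training time. The sketch makes three assumptions that are not taken from plot_learning_dynamics.py. First, it assumes a one-hidden-layer ReLU FNN. Second, it uses that network's NNGP kernel K (arc-cosine recursion) as the GP kernel. Third, it evaluates the closed-form mean predictor under MSE gradient flow with learning rate η: f_t(X_train) = (I − e^{−ηKt}) y and f_t(X_test) = K_* K^{−1} (I − e^{−ηKt}) y. All function names, the bias variance `b_var`, and the toy data are hypothetical.

```python
import jax.numpy as jnp
import numpy as np


def relu_nngp_kernel(x1, x2, w_var, b_var):
    """NNGP kernel of a one-hidden-layer ReLU FNN (arc-cosine recursion)."""
    d = x1.shape[1]
    # First-layer (pre-activation) covariances.
    k12 = w_var * (x1 @ x2.T) / d + b_var
    k11 = w_var * jnp.sum(x1 ** 2, axis=1) / d + b_var
    k22 = w_var * jnp.sum(x2 ** 2, axis=1) / d + b_var
    norms = jnp.sqrt(k11[:, None] * k22[None, :])
    cos_t = jnp.clip(k12 / norms, -1.0, 1.0)
    theta = jnp.arccos(cos_t)
    # E[relu(u) relu(v)] for jointly Gaussian (u, v) with the covariances above.
    relu_cov = norms * (jnp.sin(theta) + (jnp.pi - theta) * cos_t) / (2 * jnp.pi)
    return w_var * relu_cov + b_var


def mean_predictions(k_train, k_test_train, y_train, t, lr=1.0):
    """Mean predictor at time t under MSE gradient flow, via an eigendecomposition of K."""
    evals, evecs = jnp.linalg.eigh(k_train)
    evals = jnp.maximum(evals, 0.0)  # drop tiny negative round-off eigenvalues
    proj = evecs.T @ y_train
    fit = -jnp.expm1(-lr * evals * t)  # 1 - exp(-lr * lambda * t)
    f_train = evecs @ (fit[:, None] * proj)
    # Per-eigenvalue K^{-1}(1 - exp(-lr K t)); its limit as lambda -> 0 is lr * t.
    pos = evals > 1e-12
    safe = jnp.where(pos, evals, 1.0)
    coef = jnp.where(pos, -jnp.expm1(-lr * safe * t) / safe, lr * t)
    f_test = k_test_train @ (evecs @ (coef[:, None] * proj))
    return f_train, f_test


def accuracy(f, y):
    return float(jnp.mean(jnp.argmax(f, axis=1) == jnp.argmax(y, axis=1)))


def learning_curves(x_train, y_train, x_test, y_test, w_vars, times, b_var=0.1, lr=1.0):
    """Return train/test accuracy arrays of shape (len(w_vars), len(times))."""
    train_acc = np.zeros((len(w_vars), len(times)))
    test_acc = np.zeros((len(w_vars), len(times)))
    for i, w_var in enumerate(w_vars):
        k_train = relu_nngp_kernel(x_train, x_train, w_var, b_var)
        k_test_train = relu_nngp_kernel(x_test, x_train, w_var, b_var)
        for j, t in enumerate(times):
            f_train, f_test = mean_predictions(k_train, k_test_train, y_train, t, lr)
            train_acc[i, j] = accuracy(f_train, y_train)
            test_acc[i, j] = accuracy(f_test, y_test)
    return train_acc, test_acc


# Toy data (hypothetical): 6 training and 4 test points, 2 one-hot classes.
rng = np.random.default_rng(0)
x_train = rng.normal(size=(6, 4)).astype(np.float32)
x_test = rng.normal(size=(4, 4)).astype(np.float32)
y_train = np.eye(2, dtype=np.float32)[[0, 1, 0, 1, 0, 1]]
y_test = np.eye(2, dtype=np.float32)[[0, 1, 1, 0]]
w_vars = [1.0, 2.0, 4.0]
times = [0.0, 1.0, 1e6]
train_acc, test_acc = learning_curves(x_train, y_train, x_test, y_test, w_vars, times)
print(train_acc.shape)  # (3, 3)
```

Under these assumptions, a single learning curve for one weight variance is the row `train_acc[i]` or `test_acc[i]`, and a contour plot would place `w_vars` on one axis and `times` on the other.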
