
gemma's Introduction

Gemma

Gemma is a family of open-weights large language models (LLMs) by Google DeepMind, based on Gemini research and technology.

This repository contains an inference implementation and examples, based on Flax and JAX.

Learn more about Gemma

Quick start

Installation

  1. To install Gemma, you need Python 3.10 or higher.

  2. Install JAX for CPU, GPU or TPU. Follow the instructions on the JAX website.

  3. Run

python -m venv gemma-demo
. gemma-demo/bin/activate
pip install git+https://github.com/google-deepmind/gemma.git
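After installing, a quick sanity check can confirm that the interpreter meets the Python 3.10 requirement and that JAX is importable. This is a minimal sketch: check_environment is a hypothetical helper, not part of Gemma, and it only lists devices when JAX is already installed.

```python
import importlib.util
import sys


def check_environment(min_version=(3, 10)):
    """Describe whether this interpreter looks ready for Gemma.

    Hypothetical helper: checks the Python version requirement from the
    README and, if JAX is installed, lists the devices JAX can see.
    """
    report = {
        "python_ok": sys.version_info[:2] >= min_version,
        "python_version": ".".join(map(str, sys.version_info[:3])),
        "jax_installed": importlib.util.find_spec("jax") is not None,
        "devices": [],
    }
    if report["jax_installed"]:
        import jax  # Imported lazily so the check also works without JAX.

        report["devices"] = [str(device) for device in jax.devices()]
    return report


if __name__ == "__main__":
    print(check_environment())
```

If devices only lists CPU devices on a machine with a GPU, JAX was most likely installed without accelerator support.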

Downloading the models

The model checkpoints are available through Kaggle at http://kaggle.com/models/google/gemma. Select one of the Flax model variations, click the ⤓ button to download the model archive, then extract the contents to a local directory.

Alternatively, visit the Gemma models on the Hugging Face Hub. To download a model, you can run the following code if you have huggingface_hub installed:

from huggingface_hub import snapshot_download

# Download the weights and tokenizer; returns the path of the local copy.
# Gemma repositories on the Hub are gated: accept the license on the model
# page and log in (e.g. with huggingface-cli login) before downloading.
local_dir = snapshot_download(repo_id="google/gemma-2b-flax")
print(local_dir)

In both cases, the download contains the model weights and the tokenizer. For example, the 2b Flax variation contains:

2b/              # Directory containing model weights
tokenizer.model  # Tokenizer
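Given that layout, the two paths needed later (weights directory and tokenizer) can be located and checked before anything is loaded. A minimal sketch, assuming the layout shown above; find_gemma_files is a hypothetical helper name.

```python
from pathlib import Path


def find_gemma_files(root, variant="2b"):
    """Return (checkpoint_dir, tokenizer_path) inside an extracted download.

    Hypothetical helper assuming a `<variant>/` weights directory and a
    `tokenizer.model` file side by side in `root`. Raises FileNotFoundError
    if either is missing.
    """
    root = Path(root)
    checkpoint_dir = root / variant
    tokenizer_path = root / "tokenizer.model"
    if not checkpoint_dir.is_dir():
        raise FileNotFoundError(f"Missing weights directory: {checkpoint_dir}")
    if not tokenizer_path.is_file():
        raise FileNotFoundError(f"Missing tokenizer: {tokenizer_path}")
    return checkpoint_dir, tokenizer_path
```

For example, find_gemma_files("/path/to/archive/contents") returns the 2b directory and the tokenizer path, failing early with a clear message if the archive was extracted somewhere else.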

Running the unit tests

To run the unit tests, install the optional [test] dependencies (e.g., by running pip install -e .[test] from the root of a local clone of the source tree), then:

pytest .

Note that the tests in sampler_test.py are skipped by default since no tokenizer is distributed with the Gemma sources. To run these tests, download a tokenizer following the instructions above, and update the _VOCAB constant in sampler_test.py with the path to tokenizer.model.

Examples

To run the example sampling script, pass the paths to the weights directory and tokenizer:

python examples/sampling.py \
  --path_checkpoint=/path/to/archive/contents/2b/ \
  --path_tokenizer=/path/to/archive/contents/tokenizer.model
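The same invocation can be scripted from Python. A minimal sketch with a hypothetical helper, build_sampling_command, that assembles the two flags shown above and uses the current interpreter (sys.executable) so the script runs inside the active virtual environment:

```python
import os
import sys


def build_sampling_command(checkpoint_dir, tokenizer_path, script="examples/sampling.py"):
    """Build the argument list for the sampling script (hypothetical helper).

    Only the two flags from the README are used. The checkpoint directory
    gets a trailing separator, matching the README example.
    """
    # os.path.join(path, "") appends a separator only if one is missing.
    checkpoint = os.path.join(os.fspath(checkpoint_dir), "")
    return [
        sys.executable,
        os.fspath(script),
        f"--path_checkpoint={checkpoint}",
        f"--path_tokenizer={os.fspath(tokenizer_path)}",
    ]
```

Passing the result to subprocess.run(..., check=True) from the root of the source tree runs the same command as the shell example.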

There are also several Colab notebook tutorials:

To run these notebooks you will need to download a local copy of the weights and tokenizer (see above), and update the ckpt_path and vocab_path variables with the corresponding paths.

System Requirements

Gemma can run on CPU, GPU, and TPU. For GPU, we recommend 8GB+ of GPU memory for the 2B checkpoint and 24GB+ of GPU memory for the 7B checkpoint.
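As a rough check of these figures, the weights alone take (parameter count) x (bytes per parameter). A minimal sketch, assuming approximate parameter counts of 2.5 billion for the 2B model and 8.5 billion for the 7B model (embeddings included), and 2 bytes per parameter for bfloat16 or 4 for float32. Activations and the attention cache need additional memory on top of this.

```python
def weights_memory_gib(num_params, bytes_per_param):
    """Memory needed just to hold the weights, in GiB (2**30 bytes)."""
    return num_params * bytes_per_param / 2**30


# Approximate parameter counts (assumed, embeddings included).
for name, params in [("2B", 2.5e9), ("7B", 8.5e9)]:
    for dtype, nbytes in [("bfloat16", 2), ("float32", 4)]:
        print(f"{name} {dtype}: {weights_memory_gib(params, nbytes):.1f} GiB")
```

Under these assumptions, the bfloat16 weights come to roughly 4.7 GiB and 15.8 GiB, which leaves headroom within the recommended 8GB and 24GB.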

Contributing

We are open to bug reports, pull requests (PR), and other contributions. Please see CONTRIBUTING.md for details on PRs.

License

Copyright 2024 DeepMind Technologies Limited

This code is licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Disclaimer

This is not an official Google product.

gemma's People

Contributors

alimuldal, ddsh, huxiaoxu2019, molugan, osanseviero


gemma's Issues

Inconsistencies with Gemma RAG using LangChain

Trying RAG with gemma and langchain:

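from langchain import hub
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load and split documents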
loader = WebBaseLoader("https://h3manth.com")
data = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
all_splits = text_splitter.split_documents(data)

# Create vector store
vectorstore = FAISS.from_documents(documents=all_splits, embedding=HuggingFaceEmbeddings())

# Load RAG prompt
prompt = hub.pull("rlm/rag-prompt")

# Create LLM
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=1000,
    temperature=0.1,
    top_p=0.95,
    repetition_penalty=1.15,
    do_sample=True,
)
llm = HuggingFacePipeline(pipeline=pipe)

# Create RetrievalQA chain
qa_chain = RetrievalQA.from_chain_type(
    llm=llm, 
    retriever=vectorstore.as_retriever(), 
    chain_type="stuff",  # Specify chain type
    chain_type_kwargs={"prompt": prompt}
)

# Ask question
question = "List the References"
response = qa_chain({"query": question})

print(response["result"])

It returns either no content or, at times, random text. Is something missing from the pipeline?

Here is a notebook reproducing the same issue.

7b errors on deserialization

On an M1 Mac (macOS), I followed the README and pytest passed, but running

python3.11 examples/sampling.py --path_checkpoint=/Users/garryling/Downloads/archive_test/7b/ --path_tokenizer=/Users/garryling/Downloads/archive_test/tokenizer.model

causes an error:

Traceback (most recent call last):
  File "/Users/garryling/workspace/gemma/examples/sampling.py", line 112, in <module>
    app.run(main)
  File "/Users/garryling/gemma-demo/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/Users/garryling/gemma-demo/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/Users/garryling/workspace/gemma/examples/sampling.py", line 102, in main
    _load_and_sample(
  File "/Users/garryling/workspace/gemma/examples/sampling.py", line 73, in _load_and_sample
    parameters = params_lib.load_and_format_params(path_checkpoint)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/garryling/gemma-demo/lib/python3.11/site-packages/gemma/params.py", line 29, in load_and_format_params
    params = load_params(path)
             ^^^^^^^^^^^^^^^^^
  File "/Users/garryling/gemma-demo/lib/python3.11/site-packages/gemma/params.py", line 40, in load_params
    params = checkpointer.restore(path)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/garryling/gemma-demo/lib/python3.11/site-packages/orbax/checkpoint/checkpointer.py", line 166, in restore
    restored = self._handler.restore(directory, args=ckpt_args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/garryling/gemma-demo/lib/python3.11/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 1073, in restore
    restored_item = asyncio.run(
                    ^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/runners.py", line 190, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py", line 653, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/Users/garryling/gemma-demo/lib/python3.11/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 903, in _maybe_deserialize
    deserialized_batches += await asyncio.gather(*deserialized_batches_ops)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/garryling/gemma-demo/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py", line 1134, in deserialize
    ret = await asyncio.gather(*read_ops)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.6_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/tasks.py", line 694, in _wrap_awaitable
    return (yield from awaitable.__await__())
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: FAILED_PRECONDITION: Error reading "transformer/layer_26/pre_attention_norm.scale/0" in OCDBT database at local file "/Users/garryling/Downloads/archive_test/7b/": Error reading file: /Users/garryling/Downloads/archive_test/7b/ocdbt.process_0/d/400042b385de57e9fbec46e8f38bf642 [OS error: Invalid argument] [source locations='tensorstore/kvstore/file/file_key_value_store.cc:330\ntensorstore/kvstore/kvstore.cc:377']

Unused tokens in gemma tokenizer

I am using the "google/gemma-2b-it" model from Hugging Face. I noticed there are 99 unused tokens (<unused0>, <unused1>, <unused2>, ...) among the first 106 token IDs. Does anyone know their purpose? Just wondering.

gemma 7B model configuration

Hello,

I am inquiring about the model configuration outlined in your technical report.

In the Gemma technical report, Table 1 lists 'd_model' as 3072 for the 7B model.

I understand 'd_model' to represent the 'hidden size', which should be equivalent to 'Num heads * Head size'.
I was confused because 'Num heads * Head size' equals 4096, while 'd_model' is listed as 3072.
Could you clarify the meaning of 'd_model' and provide the correct 'hidden size' for the Gemma 7B model?

Thank you.
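The question can be made concrete with the projection shapes of a generic multi-head attention block: the query/key/value projections map d_model to num_heads * head_dim, and the output projection maps that back to d_model, so the two sizes need not be equal. A minimal illustrative sketch using the 7B values from Table 1 of the report (d_model 3072, 16 heads, head size 256); the helper name is hypothetical, and the repository's own modules may store these weights in a different layout (for example, per head).

```python
def attention_projection_shapes(d_model, num_heads, head_dim):
    """Weight shapes of a generic multi-head attention block (sketch).

    Queries/keys/values: d_model -> num_heads * head_dim.
    Output projection:   num_heads * head_dim -> d_model.
    """
    inner = num_heads * head_dim
    return {
        "q_proj": (d_model, inner),
        "out_proj": (inner, d_model),
    }


shapes = attention_projection_shapes(d_model=3072, num_heads=16, head_dim=256)
print(shapes)  # {'q_proj': (3072, 4096), 'out_proj': (4096, 3072)}
```

In this reading, 3072 is the width of the residual stream (the hidden size), while 4096 is the inner width of attention.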

git and install .[test] issues

I followed the installation guide and ran into some issues worth mentioning; simply following the guide as written did not work for me:

i) You need a basic understanding of installing a Python package from source versus installing it with pip, and of the pyproject.toml file.

ii) Use git clone instead of git+https.......... as suggested by the guide if you are not very adept at installing from git into a virtual environment.
I followed the pip install git+https://github........./genna,git approach. It succeeded, but when I got to the unit tests, I had no idea where the root of the source tree was (relative to an https path) and never got it to work.

The reason is that, in that case, the root of the source tree is the remote git address, and I had no idea how to run pip install -e .[test] from "the root of a source tree located in https...". I tried a variety of approaches, such as

$pip install -e https://..................................../.[test],

I also went to the local /site-packages/gemma directory and ran it there. The installation worked, but then a bunch of errors popped up in the unit tests.

Basically, I had no luck and got stuck.

Finally, I worked through a short tutorial on using .toml files to get familiar with the "install from source" concept, and used the more conventional git clone approach to install into the local environment. Now I know that "the root of the source tree" is simply the folder created by git clone.

iii) Unit tests: jax and jaxlib version 0.4.24 should be used instead of 0.4.25, as confirmed in the following reference.
After the initial install, use pip install --force-reinstall -v "jaxlib==0.4.24", and the same for jax, to change the version. Using plain pip install jaxlib==0.4.24 may still cause problems.

Installing NVIDIA CUDA, JAX, and Gemma is challenging, especially since I am in a WSL (GPU+CPU) environment.
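After a forced reinstall, it is worth confirming that both packages actually report the pinned version. A minimal sketch using only the standard library; installed_versions and matches_pin are hypothetical helper names, and 0.4.24 is the version recommended in the comment above.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib")):
    """Map each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def matches_pin(versions, pin="0.4.24"):
    """True only if every listed package is installed at exactly `pin`."""
    return all(version == pin for version in versions.values())


if __name__ == "__main__":
    found = installed_versions()
    print(found, "matches pin" if matches_pin(found) else "does not match 0.4.24")
```

jax and jaxlib are released in lockstep, so a mismatch between them is itself a sign that the reinstall did not fully take effect.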

MMLU script request

I've tried many evaluation repositories (like lm-eval-harness), but none of them reaches the score reported on the HF page. Could you offer a script like the gsm8k example in this repo? :)

Issue with unit tests on NVIDIA A100 (GPU)

Hi everyone.

I see failures when running the unit tests on an NVIDIA A100 (GPU). Here is the link for more details.

Briefly:

=========================== short test summary info ============================
FAILED opt/gemma/gemma/layers_test.py::EinsumTest::test_rmsnorm0 - AssertionE...
FAILED opt/gemma/gemma/modules_test.py::FeedForwardTest::test_ffw0 - Assertio...
FAILED opt/gemma/gemma/positional_embeddings_test.py::PositionalEmbeddingsTest::test_adds_positional_embeddings0
FAILED opt/gemma/gemma/sampler_test.py::SamplerTest::test_forward_equivalence
================== 4 failed, 12 passed, 2 warnings in 26.55s ===================

The first three are similar to the issues on V100 (#32), but the last one is different:
4. test_forward_equivalence (link). Can you relax the tolerance when running on GPUs?

Error in 4-bit quantization of Gemma-2b

Source: Hugging Face documentation (screenshot).

My code (screenshot).

My error log:

   Title                     Author             Genre       SubGenre           Height  Publisher
0  Fundamentals of Wavelets  Goswami, Jaideva   tech        signal_processing  228     Wiley
1  Data Smart                Foreman, John      tech        data_science       235     Wiley
2  God Created the Integers  Hawking, Stephen   tech        mathematics        197     Penguin
3  Superfreakonomics         Dubner, Stephen    science     economics          179     HarperCollins
4  Orientalism               Said, Edward       nonfiction  history            197     Penguin
Token is valid (permission: read).
Cannot authenticate through git-credential as no helper is defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub.
Run the following command in your terminal in case you want to set the 'store' credential helper as default.
git config --global credential.helper store
Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.
Token has not been saved to git credential helper.
Your token has been saved to C:\Users\HP\.cache\huggingface\token
Login successful
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[25], line 1
----> 1 model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", quantization_config=quantization_config)

File m:\Third Year\Sixth Semester\Projects\RAG_project1\venv\Lib\site-packages\transformers\models\auto\auto_factory.py:561, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
559 elif type(config) in cls._model_mapping.keys():
560 model_class = _get_model_class(config, cls._model_mapping)
--> 561 return model_class.from_pretrained(
562 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
563 )
564 raise ValueError(
565 f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
566 f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
567 )

File m:\Third Year\Sixth Semester\Projects\RAG_project1\venv\Lib\site-packages\transformers\modeling_utils.py:3024, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
3021 hf_quantizer = None
3023 if hf_quantizer is not None:
-> 3024 hf_quantizer.validate_environment(
3025 torch_dtype=torch_dtype, from_tf=from_tf, from_flax=from_flax, device_map=device_map
3026 )
3027 torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
3028 device_map = hf_quantizer.update_device_map(device_map)
...
69 "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make"
70 " sure the weights are in PyTorch format."
71 )

ImportError: Using bitsandbytes 8-bit quantization requires Accelerate: pip install accelerate and the latest version of bitsandbytes: pip install -i https://pypi.org/simple/ bitsandbytes

Note

I have already installed accelerate and bitsandbytes.

One thing still confuses me: the log says that 8-bit quantization requires accelerate and bitsandbytes, but I am doing 4-bit quantization.

Getting 'Killed' message trying to run sampling.py on 2b-it

I'm running on WSL2/Ubuntu on Win11. Deliberately using CPU mode as my GPU is too weak. Using Python 3.10.12.

Here is the output when trying to run sampling.py:

~/gemma$ python3 examples/sampling.py --path_checkpoint=/home/david/gemma/2b-it/ --path_tokenizer=/home/david/gemma/tokenizer.model
Loading the parameters from /home/david/gemma/2b-it/
I0224 16:15:11.469793 140378916851712 checkpointer.py:164] Restoring item from /home/david/gemma/2b-it.
I0224 16:15:27.563535 140378916851712 xla_bridge.py:689] Unable to initialize backend 'cuda':
I0224 16:15:27.563766 140378916851712 xla_bridge.py:689] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0224 16:15:27.568676 140378916851712 xla_bridge.py:689] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
W0224 16:15:27.568863 140378916851712 xla_bridge.py:727] An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
I0224 16:15:27.757100 140378916851712 checkpointer.py:167] Finished restoring checkpoint from /home/david/gemma/2b-it.
Parameters loaded.
Killed

Any idea what could be causing it to blow up during sampling?
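A bare "Killed" with no Python traceback usually means the process received SIGKILL, most often from the Linux out-of-memory (OOM) killer. Under WSL2, the Linux VM only gets a capped share of the host's memory by default, which is configurable through the memory setting in .wslconfig. A minimal sketch for checking how much memory the Linux side actually sees; total_ram_gib and likely_enough_ram are hypothetical helpers, and the 1.5x headroom factor is an assumed rule of thumb, not a Gemma requirement.

```python
import os


def total_ram_gib():
    """Total physical memory visible to this Linux/WSL2 system, in GiB.

    Uses sysconf names available on Linux; returns None where they are not
    supported (e.g. native Windows).
    """
    try:
        page_size = os.sysconf("SC_PAGE_SIZE")
        num_pages = os.sysconf("SC_PHYS_PAGES")
    except (AttributeError, ValueError, OSError):
        return None
    return page_size * num_pages / 2**30


def likely_enough_ram(required_gib, headroom=1.5):
    """Assumed rule of thumb: want `headroom` times the weight footprint."""
    total = total_ram_gib()
    return None if total is None else total >= required_gib * headroom
```

Pass in your own estimate of the checkpoint's weight footprint in GiB. A False result suggests the OOM killer is a likely explanation.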

gemma-2b text-generation inconsistency

A simple prompt asking "What is electroencephalography?" with the setup below:

from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

hf = HuggingFacePipeline.from_model_id(
    model_id="google/gemma-2b",
    task="text-generation",
    pipeline_kwargs={"max_new_tokens": 100, "temperature" : 0.8, "do_sample": True},
)

from langchain.prompts import PromptTemplate

template = """Question: {question}"""
prompt = PromptTemplate.from_template(template)

chain = prompt | hf

question = "What is electroencephalography?"

print(chain.invoke({"question": question}))

Results in:


 Then? What is electro encephalography? Which is electroencephalography? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is? Which is

I tried changing the task as well.

How to set stop words

Hi, when I prompt Gemma to answer questions in the following format:
Q:####
A:#####

Q:####
A:####

Q:#####
A:

I have no idea how to make Gemma stop when it generates "Q:".
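If the generation library in use supports stop sequences or stopping criteria natively, that is preferable because it also saves compute. Otherwise, a library-agnostic fallback is to trim the completion at the first stop sequence after generation. A minimal sketch; truncate_at_stop is a hypothetical helper, and "\nQ:" is assumed as the stop sequence for the Q/A format above.

```python
def truncate_at_stop(text, stop_sequences=("\nQ:",)):
    """Cut generated text at the earliest stop sequence (post-processing).

    This does not stop generation early; it trims whatever the model
    produced after the first occurrence of any stop sequence.
    """
    cut = len(text)
    for stop in stop_sequences:
        index = text.find(stop)
        if index != -1:
            cut = min(cut, index)
    return text[:cut].rstrip()


completion = " Paris.\n\nQ: What is the capital of Italy?\nA: Rome."
print(truncate_at_stop(completion))  # prints " Paris."
```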

Issue when "$ pytest ."

I followed the instructions and ran $ pytest .
Experiment environment: RTX 4090 Titan, Docker image nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04
Here is the result. How can I fix it?

======================================================================== test session starts ========================================================================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.4.0
rootdir: /workspace/gemma
configfile: pyproject.toml
collected 16 items

gemma/layers_test.py ... [ 18%]
gemma/modules_test.py ...F. [ 50%]
gemma/positional_embeddings_test.py F. [ 62%]
gemma/sampler_test.py .. [ 75%]
gemma/transformer_test.py .... [100%]

============================================================================= FAILURES ==============================================================================
_____________________________________________________________________ FeedForwardTest.test_ffw0 _____________________________________________________________________

self = <gemma.modules_test.FeedForwardTest testMethod=test_ffw0>, features = 2, hidden_dim = 3, batch_size = 2, expected_val = [11.72758674, 47.99916]
expected_shape = (2, 1, 2)

@parameterized.parameters(
    dict(
        features=2,
        hidden_dim=3,
        batch_size=2,
        expected_val=[11.72758674, 47.99916],
        expected_shape=(2, 1, 2),
    ),
)
def test_ffw(
    self, features, hidden_dim, batch_size, expected_val, expected_shape
):
  inputs = jnp.arange(1, batch_size+1)[:, None, None]
  inputs = jnp.repeat(inputs, features, axis=-1)
  ffw = modules.FeedForward(features=features, hidden_dim=hidden_dim)
  params = {
      'gating_einsum': jnp.ones((2, features, hidden_dim)),
      'linear': jnp.ones((hidden_dim, features)),
  }

  outputs = ffw.apply({'params': params}, inputs)
  np.testing.assert_array_almost_equal(outputs[:, 0, 0], expected_val)

gemma/modules_test.py:132:


/usr/lib/python3.10/contextlib.py:79: in inner
return func(*args, **kwds)


args = (<function assert_array_almost_equal.<locals>.compare at 0x7fabb5e4f250>, Array([11.727587, 47.999153], dtype=float32), [11.72758674, 47.99916])
kwds = {'err_msg': '', 'header': 'Arrays are not almost equal to 6 decimals', 'precision': 6, 'verbose': True}

@wraps(func)
def inner(*args, **kwds):
    with self._recreate_cm():
      return func(*args, **kwds)

E AssertionError:
E Arrays are not almost equal to 6 decimals
E
E Mismatched elements: 1 / 2 (50%)
E Max absolute difference: 6.86279297e-06
E Max relative difference: 1.42977356e-07
E x: array([11.727587, 47.999153], dtype=float32)
E y: array([11.727587, 47.99916 ])

/usr/lib/python3.10/contextlib.py:79: AssertionError
_____________________________________________________ PositionalEmbeddingsTest.test_adds_positional_embeddings0 _____________________________________________________
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

self = <gemma.positional_embeddings_test.PositionalEmbeddingsTest testMethod=test_adds_positional_embeddings0>, input_embedding_shape = (2, 1, 1, 5), position = 3
max_wavelength = 100, expected = [[[[1.1411201, 1.0299965, 0.0100075, 1.99955, 1.0]]], [[[1.1411201, 1.0299965, 0.0100075, 1.99955, 1.0]]]]

@parameterized.parameters(
    dict(
        input_embedding_shape=(2, 1, 1, 5),
        position=3,
        max_wavelength=100,
        expected=[[[[1.1411201, 1.0299965, 0.0100075, 1.99955, 1.0]]],
                  [[[1.1411201, 1.0299965, 0.0100075, 1.99955, 1.0]]]]
    )
)
def test_adds_positional_embeddings(
    self, input_embedding_shape, position, max_wavelength, expected
):
  outputs = positional_embeddings.add_positional_embedding(
      jnp.ones(input_embedding_shape), position, max_wavelength
  )

gemma/positional_embeddings_test.py:38:


gemma/positional_embeddings.py:42: in add_positional_embedding
return input_embedding + position_embedding
../gemma-demo/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:264: in deferring_binary_op
return binary_op(*args)
../gemma-demo/lib/python3.10/site-packages/jax/_src/numpy/ufuncs.py:99: in fn
x1, x2 = promote_args(numpy_fn.name, x1, x2)
../gemma-demo/lib/python3.10/site-packages/jax/_src/numpy/util.py:381: in promote_args
return promote_shapes(fun_name, *promote_dtypes(*args))
../gemma-demo/lib/python3.10/site-packages/jax/_src/numpy/util.py:249: in promote_shapes
_rank_promotion_warning_or_error(fun_name, shapes)


fun_name = 'add', shapes = [(2, 1, 1, 5), (5,)]

def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):
  if config.numpy_rank_promotion.value == "warn":
    msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
           "Set the jax_numpy_rank_promotion config option to 'allow' to "
           "disable this warning; for more information, see "
           "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
    warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
  elif config.numpy_rank_promotion.value == "raise":
    msg = ("Operands could not be broadcast together for {} on shapes {} "
           "and with the config option jax_numpy_rank_promotion='raise'. "
           "For more information, see "
           "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
    raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))

E ValueError: Operands could not be broadcast together for add on shapes (2, 1, 1, 5) (5,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.

../gemma-demo/lib/python3.10/site-packages/jax/_src/numpy/util.py:267: ValueError
====================================================================== short test summary info ======================================================================
FAILED gemma/modules_test.py::FeedForwardTest::test_ffw0 - AssertionError:
FAILED gemma/positional_embeddings_test.py::PositionalEmbeddingsTest::test_adds_positional_embeddings0 - ValueError: Operands could not be broadcast together for add on shapes (2, 1, 1, 5) (5,) and with the config option jax_numpy_rank_promotion='raise'. For more i...
============================================================== 2 failed, 14 passed in 91.04s (0:01:31) ==============================================================
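The test_ffw0 failure above is a float32 rounding difference. The maximum absolute difference (6.86e-06) exceeds the 1.5e-06 threshold of a 6-decimal check, while the relative difference is only 1.43e-07. A minimal NumPy-only sketch showing how a relative tolerance would accept this result; the values are copied from the failure output, and rtol=1e-6 is an assumed value, not the project's chosen tolerance.

```python
import numpy as np

# Values copied from the test_ffw0 failure (float32 output vs. reference).
actual = np.array([11.727587, 47.999153], dtype=np.float32)
expected = np.array([11.72758674, 47.99916])


def passes_decimal_check(actual, expected, decimal=6):
    """Same check the test uses: absolute error < 1.5 * 10**-decimal."""
    try:
        np.testing.assert_array_almost_equal(actual, expected, decimal=decimal)
        return True
    except AssertionError:
        return False


def passes_relative_check(actual, expected, rtol=1e-6):
    """Relative tolerance, which scales with the magnitude of each value."""
    try:
        np.testing.assert_allclose(actual, expected, rtol=rtol)
        return True
    except AssertionError:
        return False


print(passes_decimal_check(actual, expected))   # False
print(passes_relative_check(actual, expected))  # True
```

A relative tolerance suits this output because the larger value (about 48) loses absolute precision in float32, while the relative error stays near machine epsilon.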

'subprocess-exited-with-error' when installing gemma

Hello, I'm following the instructions provided in your README, and when I run pip install git+https://github.com/google-deepmind/gemma.git, it throws an error that says "subprocess-exited-with-error". Here is the error log:

(env) PS C:\Users\aritram21\Desktop\gemma> gsudo pip install git+https://github.com/google-deepmind/gemma.git
Collecting git+https://github.com/google-deepmind/gemma.git
  Cloning https://github.com/google-deepmind/gemma.git to c:\users\aritram21\appdata\local\temp\pip-req-build-e5hq96ji
  Running command git clone --filter=blob:none --quiet https://github.com/google-deepmind/gemma.git 'C:\Users\aritram21\AppData\Local\Temp\pip-req-build-e5hq96ji'
  Resolved https://github.com/google-deepmind/gemma.git to commit 036083ab16843e09369a0138630687dba96d4d23
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting absl-py<3.0.0,>=2.1.0 (from gemma==1.0.0)
  Using cached absl_py-2.1.0-py3-none-any.whl.metadata (2.3 kB)
Collecting flax<0.8.0,>=0.7.5 (from gemma==1.0.0)
  Using cached flax-0.7.5-py3-none-any.whl.metadata (10 kB)
Collecting sentencepiece<0.2.0,>=0.1.99 (from gemma==1.0.0)
  Using cached sentencepiece-0.1.99.tar.gz (2.6 MB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... error
  error: subprocess-exited-with-error

  × Getting requirements to build wheel did not run successfully.
  │ exit code: 1
  ╰─> [31 lines of output]
      Traceback (most recent call last):
        File "C:\Users\aritram21\Desktop\gemma\env\Lib\site-packages\pip\_vendor\pyproject_hooks\_in_process\_in_process.py", line 353, in <module>
          main()
        File "C:\Users\aritram21\Desktop\gemma\env\Lib\site-packages\pip\_vendor\pyproject_hooks\_in_process\_in_process.py", line 335, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "C:\Users\aritram21\Desktop\gemma\env\Lib\site-packages\pip\_vendor\pyproject_hooks\_in_process\_in_process.py", line 118, in get_requires_for_build_wheel
          return hook(config_settings)
                 ^^^^^^^^^^^^^^^^^^^^^
        File "C:\Users\aritram21\AppData\Local\Temp\pip-build-env-8939h79t\overlay\Lib\site-packages\setuptools\build_meta.py", line 325, in get_requires_for_build_wheel
          return self._get_build_requires(config_settings, requirements=['wheel'])
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "C:\Users\aritram21\AppData\Local\Temp\pip-build-env-8939h79t\overlay\Lib\site-packages\setuptools\build_meta.py", line 295, in _get_build_requires
          self.run_setup()
        File "C:\Users\aritram21\AppData\Local\Temp\pip-build-env-8939h79t\overlay\Lib\site-packages\setuptools\build_meta.py", line 487, in run_setup
          super().run_setup(setup_script=setup_script)
        File "C:\Users\aritram21\AppData\Local\Temp\pip-build-env-8939h79t\overlay\Lib\site-packages\setuptools\build_meta.py", line 311, in run_setup
          exec(code, locals())
        File "<string>", line 126, in <module>
        File "C:\Python312\Lib\subprocess.py", line 408, in check_call
          retcode = call(*popenargs, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "C:\Python312\Lib\subprocess.py", line 389, in call
          with Popen(*popenargs, **kwargs) as p:
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "C:\Python312\Lib\subprocess.py", line 1026, in __init__
          self._execute_child(args, executable, preexec_fn, close_fds,
        File "C:\Python312\Lib\subprocess.py", line 1538, in _execute_child
          hp, ht, pid, tid = _winapi.CreateProcess(executable, args,
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      FileNotFoundError: [WinError 2] The system cannot find the file specified
      [end of output]

  note: This error originates from a subprocess, and is likely not a problem with pip.
error: subprocess-exited-with-error

× Getting requirements to build wheel did not run successfully.
│ exit code: 1
╰─> See above for output.

note: This error originates from a subprocess, and is likely not a problem with pip.

I'm using Windows 10 and Python 3.12.1. I have also installed JAX before running the above command. Any help would be greatly appreciated!
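
The traceback shows the package's `setup.py` (line 126) calling `subprocess.check_call`, and Windows' `CreateProcess` failing with `WinError 2`. That means the executable the build script tried to launch could not be found; which executable it was is not visible in the log. A minimal sketch of the pattern, assuming a hypothetical tool name, that checks `PATH` with `shutil.which` first so the failure names the missing program instead of ending in a bare `FileNotFoundError`:

```python
import shutil
import subprocess


def run_tool(tool, *args):
    """Run an external program, failing with a clear message if it is not on PATH."""
    exe = shutil.which(tool)  # full path to the executable, or None if not found
    if exe is None:
        raise FileNotFoundError(f"required build tool {tool!r} was not found on PATH")
    subprocess.check_call([exe, *args])


# "some-build-tool" is a hypothetical name; the real executable is not shown in the log.
try:
    run_tool("some-build-tool", "--version")
except FileNotFoundError as err:
    print(err)
```

On Windows, `shutil.which` also tries the extensions listed in `PATHEXT`, so `tool` can be given without `.exe`.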

Issue when "Running the unit tests"

When I run `pytest .`, one test fails:
FAILED gemma/positional_embeddings_test.py::PositionalEmbeddingsTest::test_adds_positional_embeddings0 - ValueError: Operands could not be broadcast together for add on shapes (2, 1, 1, 5) (5,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see htt...

Also, when I run sampling.py, it fails and stops at the last step.

Is there any relationship between the two failures, or is there another reason for the sampling failure?
My setup: JAX (CPU build), Windows 10, CUDA 12.1, jax==0.25.0.
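
The positional-embedding failure matches JAX's strict rank-promotion mode named in the error message: with `jax_numpy_rank_promotion='raise'`, adding a `(5,)` array to a `(2, 1, 1, 5)` array is rejected instead of being silently broadcast. A minimal sketch reproducing that behaviour, using stand-in arrays with the shapes from the error message (not the actual test code):

```python
import jax
import jax.numpy as jnp

# Turn on the strict setting named in the error message.
jax.config.update("jax_numpy_rank_promotion", "raise")

x = jnp.ones((2, 1, 1, 5))  # rank-4 array (shape taken from the error message)
emb = jnp.ones((5,))        # rank-1 array

try:
    x + emb                 # needs implicit rank promotion (5,) -> (1, 1, 1, 5)
    implicit_ok = True
except ValueError:
    implicit_ok = False

# Giving both operands the same rank avoids promotion; size-1 broadcasting is still allowed.
y = x + emb.reshape(1, 1, 1, 5)
print(implicit_ok, y.shape)  # False (2, 1, 1, 5)

jax.config.update("jax_numpy_rank_promotion", "allow")  # restore the default
```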

Issue with unit tests on NVIDIA V100 (GPU)

Hi everyone.

I see failures when running the unit tests on an NVIDIA V100 (GPU). Here is the link for more details.

Briefly:

=========================== short test summary info ============================
FAILED opt/gemma/gemma/layers_test.py::EinsumTest::test_rmsnorm0 - AssertionE...
FAILED opt/gemma/gemma/modules_test.py::FeedForwardTest::test_ffw0 - Assertio...
FAILED opt/gemma/gemma/positional_embeddings_test.py::PositionalEmbeddingsTest::test_adds_positional_embeddings0
================== 3 failed, 13 passed, 2 warnings in 35.61s ===================

Some details:
1. test_rmsnorm0 ([link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/9099672951/job/25013672689?pr=590#step:7:348)). This looks like a floating-point precision (epsilon) error. I don't think it's a good idea to compare an expected array of floats with the computed one for exact equality. Could you allow some tolerance between the expected and computed arrays, like `rtol=1e-5, atol=1e-5`?
2. test_ffw0 ([link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/9099672951/job/25013672689?pr=590#step:7:415)) looks similar to the previous one.
3. test_adds_positional_embeddings0 ([link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/9099672951/job/25013672689?pr=590#step:7:486)). IMHO, JAX cannot handle it correctly on GPUs.
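
A minimal sketch of the tolerance-based comparison suggested in item 1, using NumPy's `np.testing.assert_allclose`. The arrays are made-up stand-ins, not the fixtures from the Gemma tests, and the offset only imitates a small backend-dependent rounding difference:

```python
import numpy as np

# Hypothetical reference values, standing in for an expected array hard-coded in a test.
expected = np.array([0.1, 0.2, 0.3], dtype=np.float32)

# Simulate a result that differs only by a tiny float32 rounding error.
computed = expected + np.float32(1e-7)

# An exact element-wise comparison fails on this rounding noise...
exact_match = bool(np.array_equal(computed, expected))

# ...while a tolerance-based check with the proposed rtol/atol passes.
# assert_allclose raises AssertionError if |computed - expected| > atol + rtol * |expected|.
np.testing.assert_allclose(computed, expected, rtol=1e-5, atol=1e-5)

print(exact_match)  # False
```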

Thank you for your help! Hope it's fixable! =)

gemma + MediaPipe

I am using `@mediapipe/tasks-genai/wasm` to generate the response.

Missing context?


Remove pinned version of `flax`

Hi everyone,

Is it necessary for the project to pin the version of flax? Could it use the latest flax, or at least allow versions above or below 0.7.5?

Thank you in advance for the reply, and for a fix if possible =)

Colabs don't seem to work

I cannot get the Colabs to run on https://colab.research.google.com.

I had to replace

!pip install https://github.com/deepmind/gemma

with

!pip install "git+https://github.com/google-deepmind/gemma.git"

as the former repository does not exist.

I am still unable to get compatible package versions for the code to run. Also, Google provides a free TPU tier for Colab, so it would be great if the code could be adapted (or some notes included) to run on TPU as well as GPU.

After fixing the gemma install and updating the JAX installation with:

!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

the code ends up failing with the following stack trace:

---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-6-cb05cf1a7a98> in <cell line: 2>()
      1 import re
----> 2 from gemma import params as params_lib
      3 from gemma import sampler as sampler_lib
      4 from gemma import transformer as transformer_lib
      5 

3 frames
/usr/local/lib/python3.10/dist-packages/gemma/params.py in <module>
     20 import jax
     21 import jax.numpy as jnp
---> 22 import orbax.checkpoint
     23 
     24 Params = Mapping[str, Any]

/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/__init__.py in <module>
     17 import functools
     18 
---> 19 from orbax.checkpoint import checkpoint_utils
     20 from orbax.checkpoint import lazy_utils
     21 from orbax.checkpoint import test_utils

/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpoint_utils.py in <module>
     23 from jax.sharding import Mesh
     24 import numpy as np
---> 25 from orbax.checkpoint import type_handlers
     26 from orbax.checkpoint import utils
     27 

/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/type_handlers.py in <module>
     22 from etils import epath
     23 import jax
---> 24 from jax.experimental.gda_serialization import serialization
     25 from jax.experimental.gda_serialization.serialization import get_tensorstore_spec
     26 import jax.numpy as jnp

ModuleNotFoundError: No module named 'jax.experimental.gda_serialization'

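
The traceback ends in orbax's `type_handlers.py` importing `jax.experimental.gda_serialization`, which the installed JAX does not provide, so the installed orbax and JAX versions appear not to match. A small diagnostic sketch, using only the standard library's `importlib.metadata`, that prints the installed versions of the packages involved; the package list is an assumption based on the imports in the traceback:

```python
from importlib import metadata


def installed_versions(packages):
    """Return {distribution name: installed version, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# PyPI distribution names; orbax's checkpointing library is published as "orbax-checkpoint".
for name, version in installed_versions(["jax", "jaxlib", "flax", "orbax-checkpoint"]).items():
    print(f"{name}: {version or 'not installed'}")
```

Pasting this output into a report makes it easier to see which combination of versions triggers the import error.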
