Grain is a library for reading data for training and evaluating JAX models. It's open source, fast and deterministic.
License: Apache License 2.0
If shard_options is specified in IndexSampler, isn't the dataset being sharded twice?
DataLoader shards the dataset if hasattr(self._sampler, "_shard_options"),
but the sampler will shard it again with ShardLazyDataset(), since that hasn't been disabled.
grain/grain/_src/python/samplers.py
Line 123 in f580317
grain/grain/_src/python/data_loader.py
Line 263 in f580317
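To make the concern concrete, here's a minimal sketch (the counts in the comments are hypothetical, just illustrating what double sharding would look like):

import grain.python as grain

class Range100:  # minimal in-memory source with 100 records
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return idx

source = Range100()
sampler = grain.IndexSampler(
    num_records=len(source),
    num_epochs=1,
    shard_options=grain.ShardOptions(shard_index=0, shard_count=2),
    shuffle=False,
)
loader = grain.DataLoader(data_source=source, sampler=sampler)
# With sharding applied once, this shard should see ~50 of the 100 records;
# if DataLoader sharded again on top of the sampler's ShardLazyDataset,
# it would see only ~25.
print(sum(1 for _ in loader))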
File "C:\Users\moriyantez\PycharmProjects\tf210\lib\site-packages\grain_src\python\lazy_dataset\transformations\shuffle.py", line 18, in
from grain._src.python.experimental.index_shuffle.python import index_shuffle_module as index_shuffle
ImportError: cannot import name 'index_shuffle_module' from 'grain._src.python.experimental.index_shuffle.python' (unknown location)
This seems to happen at shutdown in any data pipeline that has NumPy arrays. Here is the full stacktrace:
INFO:absl:Process 0 exiting.
INFO:absl:Processing complete for process with worker_index 0
INFO:absl:Grain pool is exiting.
INFO:absl:Shutting down multiprocessing system.
INFO:absl:Shutting down multiprocessing system.
Exception ignored in: <function SharedMemoryArray.__del__ at 0x7e3b780a8a60>
Traceback (most recent call last):
File "/home/black/micromamba/envs/trainpi/lib/python3.10/site-packages/grain/_src/python/shared_memory_array.py", line 139, in __del__
AttributeError: 'NoneType' object has no attribute 'mmap'
/home/black/micromamba/envs/trainpi/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
Even if it's not an actual problem, it's a bit annoying because it overwhelms the logging output when you have many workers.
Here's the simplest possible repro:
import grain.python as grain
import numpy as np
import logging

logging.basicConfig(level=logging.INFO)

if __name__ == "__main__":

    class DataSource:
        def __len__(self):
            return 10

        def __getitem__(self, idx):
            return np.zeros(1)

    source = DataSource()
    sampler = grain.IndexSampler(
        num_records=len(source),
        num_epochs=1,
        shard_options=grain.NoSharding(),
        shuffle=False,
    )
    loader = grain.DataLoader(
        data_source=source,
        sampler=sampler,
        worker_count=1,
    )
    for batch in loader:
        pass
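For context, this kind of AttributeError is typical of finalizers that run during interpreter shutdown, when module globals and attributes may already have been cleared. A minimal defensive pattern (a sketch using a made-up SharedBuffer class, not Grain's actual code) looks like:

# Defensive finalizer sketch: guard against attributes/globals that may
# already be gone by the time __del__ runs at interpreter shutdown.
class SharedBuffer:
    def __init__(self, shm):
        self._shm = shm  # e.g. a multiprocessing.shared_memory.SharedMemory

    def __del__(self):
        shm = getattr(self, "_shm", None)
        if shm is None:
            return  # already torn down, nothing to release
        try:
            shm.close()
        except Exception:
            pass  # interpreter may be shutting down; swallow cleanup errors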
Hi, thanks for making this! Is there a simple training example, perhaps with MNIST or another small dataset?
Where can I find good Grain usage examples?
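Until an official example lands, here is a minimal sketch of a Grain input pipeline for a toy dataset (ToyImageSource and Normalize are made up for illustration; they are not part of the library):

import dataclasses
import numpy as np
import grain.python as grain

class ToyImageSource:  # hypothetical in-memory stand-in for MNIST
    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        return {"image": np.zeros((28, 28), np.float32), "label": idx % 10}

@dataclasses.dataclass
class Normalize(grain.MapTransform):
    def map(self, features):
        features["image"] = features["image"] / 255.0
        return features

source = ToyImageSource()
sampler = grain.IndexSampler(
    num_records=len(source),
    num_epochs=1,
    shard_options=grain.NoSharding(),
    shuffle=True,
    seed=0,
)
loader = grain.DataLoader(
    data_source=source,
    sampler=sampler,
    operations=[Normalize(), grain.Batch(32)],
)
for batch in loader:
    ...  # feed batch["image"], batch["label"] into your JAX train step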
I'm trying to use t5x on a GraceHopper machine that has an ARM-based CPU.
T5x depends on grain-nightly and installs it from PyPI.
pip install grain-nightly
works on ARM, but the installed wheel fails at import because it tries to load an .so that was built for x86.
Can the wheel be marked as platform-dependent, so that it isn't found and installed on ARM?
Here is a PR that fixes the same issue in another project:
https://github.com/google/array_record/pull/79/files
I can't test it, as I'm not able to build this project on either x86 or ARM.
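I haven't verified the linked PR's exact approach, but the usual setuptools trick is to declare the distribution non-pure, which forces a platform-specific wheel tag so pip skips the wheel on platforms it wasn't built for. A sketch:

# setup.py sketch: declaring the distribution non-pure forces a
# platform-specific wheel tag (e.g. a manylinux x86_64 tag instead of
# py3-none-any), so pip won't pick the wheel up on ARM.
from setuptools import setup
from setuptools.dist import Distribution


class BinaryDistribution(Distribution):
    def has_ext_modules(self):
        return True


setup(
    # ... existing package metadata ...
    distclass=BinaryDistribution,
)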
Hi. First off, I'd like to say that I'm unsure whether I should post this issue here, in the array_record repo, or in the tensorflow_datasets repo. But my goal is ultimately to use grain in my project, because I really like the idea of deterministic data loading and easy checkpointing of the state, shuffling, etc., and I'm obviously using JAX.
The problem is that I can't seem to load ArrayRecords fast with grain for my data. Using TFRecords with TFDS seems to be a lot faster, which isn't really what I'd expect. I suspect this might be an issue with my dataset consisting of large arrays.
My dataset has around 50,000 samples, where each sample is a NumPy array of shape (100, 500, 99) and float32 dtype. Currently the dataset lives in 50,000 .npy files; I'm testing with a subset of 5,000 of them.
...
# Imports assumed from the elided setup above (the array_record import
# path is my guess at the Python bindings):
import dataclasses
import multiprocessing

import numpy as np
import tensorflow_datasets as tfds
from array_record.python import array_record_module as array_record
from tqdm.contrib.concurrent import process_map

# arbitrarily chose 50 arrays per ArrayRecord because I read online that 1 GB is OK for shard size
num_arrays_shard = 50
filenames = np.array(list(DATA_DIR.iterdir()))  # .npy filenames
num_shards = len(filenames) // num_arrays_shard  # 100 shards for my subset of the dataset
group_size = 1

features = tfds.features.FeaturesDict({
    "arr": tfds.features.Tensor(shape=(100, 500, 99), dtype=np.float32)
})

def _write_arrayrecord_shard(shard: int):
    writer = array_record.ArrayRecordWriter(
        f"{GRAIN_DATA_DIR}/data.array_record-{shard:05d}-of-{num_shards - 1:05d}",
        f"group_size:{group_size}",
    )
    for fname in filenames[shard * num_arrays_shard : shard * num_arrays_shard + num_arrays_shard]:
        _arr = np.load(fname).astype(np.float32)
        tf_example = features.serialize_example({"arr": _arr})
        writer.write(tf_example)
    writer.close()

_ = process_map(_write_arrayrecord_shard, range(num_shards), max_workers=multiprocessing.cpu_count())

import grain.python as grain

ds = grain.ArrayRecordDataSource([str(f) for f in GRAIN_DATA_DIR.iterdir()])

@dataclasses.dataclass
class ParseFeatures(grain.MapTransform):
    def map(self, _features):
        return features.deserialize_example_np(_features)

sampler = grain.SequentialSampler(num_records=len(filenames), shard_options=grain.NoSharding())
loader = grain.DataLoader(
    data_source=ds,
    operations=[ParseFeatures(), grain.Batch(5)],
    sampler=sampler,
    worker_buffer_size=1000,
)
I benchmark the resulting loader with tfds.benchmark(loader, batch_size=5) and I'm getting 3 examples per second, which seems really slow. Manually looping through the DataLoader and timing it is no better (a sketch of that loop is below), so I don't think this is a bug in the benchmark.
Reading each individual NumPy file from the filesystem with numpy.load yields about 140 examples per second.
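The manual timing loop is essentially this (a sketch, assuming the loader and the "arr" feature from the code above):

# Minimal manual timing of the DataLoader, counting examples per second.
import time

start = time.perf_counter()
num_examples = 0
for batch in loader:
    num_examples += batch["arr"].shape[0]
elapsed = time.perf_counter() - start
print(f"{num_examples / elapsed:.1f} examples/sec")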
In an identical setup where I use tf.io.TFRecordWriter in my data-conversion step, load everything as a TF Dataset, and then benchmark it as follows:
ds = ds.batch(5, num_parallel_calls=5)
ds = ds.as_numpy_iterator()
tfds.benchmark(ds, num_iter=990, batch_size=5)
then I get roughly 130 samples per second, which isn't great, but it's at least close to the naive solution of reading directly from disk.
Without the conversion to NumPy / deserialisation it's faster, but still not as fast as I'd expect: around 53 examples per second without the ParseFeatures() operation. Also, I tried setting worker_count in the DataLoader, but I get the error "Processing Failed. Shutting down." Though that is probably worth its own issue.
To summarise: I'm trying to load a few thousand big arrays (each float32, shape=(100, 500, 99)) from ArrayRecord files with Grain, but it's slow: slower than TFRecords with tf.data, and slower than just loading from disk directly.
Am I missing the point of Grain / is it just not a good fit for my use case? Or are some of my settings wrong (shard size / buffer size / serialisation strategy)?
I'm using grain_nightly==0.0.6 and array_record==0.5.0, on Linux with a 1 TB NVMe SSD and a Ryzen 9 7950X CPU with 64 GB of DDR5 RAM.
Hi,
I'm trying pip install grain, but I get the following error:
ERROR: Could not find a version that satisfies the requirement grain (from versions: none)
ERROR: No matching distribution found for grain
Does someone have an idea why this is happening?
Thank you
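For triage, it may help to note the local interpreter and platform, since pip reports this exact error when no published wheel matches them. A quick check (a sketch):

# Print the interpreter and platform to compare against the wheels
# actually published for grain on PyPI.
import platform
import sys

print("python:", sys.version_info[:3])
print("system:", platform.system(), platform.machine())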
There may be a bug in grain.python.RangeDataSource.__getitem__().
Minimal code to reproduce:
import grain.python as pygrain
x = pygrain.RangeDataSource(start=1,stop=10,step=2)
print(list(x))
Outputs:
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-5-88dd2177f112> in <cell line: 3>()
      1 import grain.python as pygrain
      2 x = pygrain.RangeDataSource(start=1,stop=10,step=2)
----> 3 print(list(x))

/usr/local/lib/python3.10/dist-packages/grain/_src/python/data_sources.py in __getitem__(self, record_key)
     92   def __getitem__(self, record_key: SupportsIndex) -> int:
     93     record_key = record_key.__index__()
---> 94     assert record_key >= 0 and record_key < self._len
     95     return self._start + record_key * self._step
     96
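Presumably list() falls back to the old sequence protocol, calling __getitem__ with increasing indices until IndexError is raised; the assert fires first and raises AssertionError instead. A sketch of one possible fix (not the actual patch), replacing the method shown in the traceback:

# Raise IndexError for out-of-range keys so that iteration via the
# sequence protocol terminates cleanly instead of asserting.
def __getitem__(self, record_key: SupportsIndex) -> int:
    record_key = record_key.__index__()
    if not 0 <= record_key < self._len:
        raise IndexError(
            f"record_key {record_key} out of range [0, {self._len})")
    return self._start + record_key * self._step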