@tmbdev The suggested link from the error message illustrates using a worker_init_fn like this:
import math
import torch
# Define a `worker_init_fn` that configures each dataset copy differently
def worker_init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # the dataset copy in this worker process
    overall_start = dataset.start
    overall_end = dataset.end
    # configure the dataset to only process the split workload
    per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
    worker_id = worker_info.id
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)
# Multi-process loading with the custom `worker_init_fn`
# (`ds` is the example IterableDataset defined in the linked docs)
# Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
This is not needed in my case because I am essentially doing this in the wds.Dataset, right?
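In WebDataset the split is over the list of shard URLs rather than a numeric range: the splitter and nodesplitter arguments decide which shards each worker and node reads. A plain-Python sketch of that idea, assuming a simple round-robin assignment (split_shards and shards_for are hypothetical helpers, not webdataset functions, and the library's own splitters may differ in detail):

```python
def split_shards(urls, index, count):
    """Hypothetical helper: every `count`-th shard, starting at `index`."""
    return list(urls)[index::count]


def shards_for(urls, rank, world_size, worker_id, num_workers):
    """Split first across nodes (rank), then across DataLoader workers on that node."""
    node_urls = split_shards(urls, rank, world_size)
    return split_shards(node_urls, worker_id, num_workers)


urls = [f"imagenet-train-{i:06d}.tar" for i in range(8)]
for rank in range(2):
    for worker_id in range(2):
        print(rank, worker_id, shards_for(urls, rank, 2, worker_id, 2))
```

Every shard lands on exactly one (node, worker) pair, which is the property the range-splitting worker_init_fn achieves for numeric ranges.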
The current PyTorch DataLoader is complex and makes exact epochs in the distributed setting tricky. For large datasets, this is actually not much of an issue, since "epochs" aren't that useful to begin with (you need to save/evaluate more often than once-per-epoch anyway). Statistically, there is also little reason for the exactly-once-per-epoch approach commonly used with smaller datasets; the original justification was that SGD was viewed as an approximation to full gradient descent on the entire dataset, but even in that setting, sampling-with-replacement is a perfectly sound training strategy.
The default for WebDataset is to iterate over all shards on all nodes, giving you an epoch size that is num_workers times as large as the single-node epoch size. But this also happens to give you exactly the same number of samples on all nodes, which keeps DataLoader happy. Training loops of this kind look like:
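from itertools import islice
import webdataset as wds
from torch.utils.data import DataLoader
ds = wds.WebDataset(...)
loader = DataLoader(ds, ...)
# train_batch and every_hour stand for your own training/scheduling code
for inputs, outputs in islice(wds.repeatedly(loader), 0, num_batches):
    train_batch(model, inputs, outputs)
    if every_hour():
        ...  # evaluate and save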
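The loop above runs for a fixed number of batches no matter where the loader's own epoch boundaries fall. A minimal plain-Python sketch of the repeatedly/islice combination, assuming repeatedly simply restarts iteration over the loader each time it is exhausted (repeatedly_sketch is a hypothetical stand-in, not the webdataset implementation):

```python
from itertools import islice


def repeatedly_sketch(loader):
    """Yield items from `loader` forever, restarting it whenever it is exhausted."""
    while True:
        empty = True
        for batch in loader:
            empty = False
            yield batch
        if empty:  # avoid spinning forever on an empty loader
            return


loader = ["b0", "b1", "b2"]  # stands in for a DataLoader with 3 batches per pass
steps = list(islice(repeatedly_sketch(loader), 0, 7))
print(steps)  # ['b0', 'b1', 'b2', 'b0', 'b1', 'b2', 'b0']
```

Because islice decides when to stop, the number of training steps no longer depends on len(loader).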
So that is the recommended practice for large datasets. Now for the nitty-gritty, in case you can't use it.
For relatively small datasets and datasets with a lot of prior work reported in epochs (like Imagenet), you may still want to recreate exact epochs with WebDataset. In that case, you shouldn't use ResizedDataset, since that will also change the composition of your epoch. The warning you get about length mismatches is really just that: a warning; you can ignore it. Furthermore, if you don't call len on the loader, you shouldn't get the warning at all.
It's best not to set the length on the WebDataset at all. If you do set it, you need to do so correctly, which requires precomputing the number of samples in each shard and then setting the length of the particular WebDataset on that worker to the total number of samples.
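Computing that per-worker length from a precomputed table of shard sizes could look like the sketch below. It assumes the same round-robin split over a sorted shard list that does not change between epochs; the function name and the small shard table are hypothetical:

```python
def per_worker_length(shard_counts, worker_id, num_workers, rank=0, world_size=1):
    """Total samples one worker sees, assuming a fixed round-robin shard split."""
    urls = sorted(shard_counts)                # deterministic order, no shard shuffling
    node_urls = urls[rank::world_size]         # shards for this node
    mine = node_urls[worker_id::num_workers]   # shards for this worker
    return sum(shard_counts[u] for u in mine)


# Hypothetical table: three full shards and one partial shard.
shard_counts = {
    "train-000000.tar": 1000,
    "train-000001.tar": 1000,
    "train-000002.tar": 1000,
    "train-000003.tar": 167,
}
print([per_worker_length(shard_counts, w, 2) for w in range(2)])  # [2000, 1167]
```

The two workers end up with different lengths here, which is exactly why one partial shard is enough to make per-worker sample counts unequal.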
Note that DataLoader pretty much forces you to use batch_size=None in many cases. This means that samples are not shuffled between workers, which, depending on how your dataset was created and how you split it, can have a significant effect on training. But WebDataset has a good solution: you can apply .unbatched().shuffle(10000).batched(batch_size) to the output of the DataLoader. This even lets you use different batch sizes for the workers and for training.
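Conceptually, .unbatched().shuffle(n).batched(b) flattens the worker batches, mixes samples from different workers in a shuffle buffer, and regroups them at a new batch size. A plain-Python sketch of those three stages (these generators are illustrative stand-ins, not webdataset's implementation; buffer size 5 and batch size 3 are toy values):

```python
import random


def unbatched(batches):
    """Flatten a stream of batches into a stream of samples."""
    for batch in batches:
        yield from batch


def shuffled(samples, bufsize, rng):
    """Approximate shuffle using a fixed-size buffer."""
    buf = []
    for sample in samples:
        buf.append(sample)
        if len(buf) >= bufsize:
            i = rng.randrange(len(buf))
            buf[i], buf[-1] = buf[-1], buf[i]
            yield buf.pop()
    rng.shuffle(buf)  # drain whatever is left at the end
    yield from buf


def batched(samples, batch_size):
    """Group samples into lists of batch_size (the last one may be shorter)."""
    batch = []
    for sample in samples:
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


# Two workers each produced batches of 4 samples; the loader interleaves them.
worker_batches = [[0, 1, 2, 3], [100, 101, 102, 103], [4, 5, 6, 7], [104, 105, 106, 107]]
rng = random.Random(0)
training_batches = list(batched(shuffled(unbatched(worker_batches), 5, rng), 3))
print(training_batches)
```

The worker batch size (4) and the training batch size (3) are independent, and samples from the two workers can end up in the same training batch.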
I realize this needs more documentation. Our multinode cluster is currently moving; I'll try to update the examples once it's up and running again and I can test them.
TODO:
- add DataLoader+Composable+Shorthands wrapper to make rebatching simpler
- add shard_length table option to WebDataset
- have batched(...) recompute length of underlying WebDataset
- update examples
Thanks @tmbdev, this was very helpful. I would like to try recreating epochs first. If that doesn't work out, I will look at other options.
I tried not specifying length=... and also tried length=None in wds.WebDataset(), but received PyTorch errors related to len and could not proceed with training.
- When you say "if you don't call len on the loader, you shouldn't get the warning at all", are you referring to the PyTorch DataLoader or to WebDataset?
  - The len call in the PyTorch DataLoader returns an estimate based on len(dataset) / batch_size when the dataset is an IterableDataset (source code).
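The estimate mentioned above can be sketched as follows. It mirrors my understanding of DataLoader.__len__ for an IterableDataset in recent PyTorch versions, so treat it as an approximation. With batch_size=None, as in the loaders in this thread, len(loader) is simply len(dataset), even though each of the num_workers worker copies yields its own stream; the sizes below are illustrative:

```python
import math


def iterable_loader_len(dataset_len, batch_size, drop_last=False):
    """Estimate len(DataLoader) for an IterableDataset that reports dataset_len."""
    if batch_size is None:  # automatic batching disabled
        return dataset_len
    if drop_last:
        return dataset_len // batch_size
    return math.ceil(dataset_len / batch_size)


print(iterable_loader_len(40036, 128))        # 313
print(iterable_loader_len(40036, 128, True))  # 312
print(iterable_loader_len(312, None))         # 312
```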
I then set length=num_batches and removed ResizedDataset, something like this:
trainsize = 1281167
def make_train_loader(img_dim, shuffle=10000, batch_size=FLAGS.batch_size):
    num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
    epoch_size = trainsize // num_dataset_instances
    # num_batches = (epoch_size + batch_size - 1) // batch_size
    num_batches = epoch_size // batch_size
    image_transform = ...
    dataset = (
        wds.WebDataset("pipe:gsutil cat gs://$BUCKET/shards/imagenet-train-{000000..001281}.tar",
                       splitter=wds.split_by_worker, nodesplitter=my_node_splitter, shardshuffle=True, length=num_batches)
        .shuffle(shuffle)
        .decode("pil")
        .to_tuple("ppm;jpg;jpeg;png", "cls")
        .map_tuple(image_transform, identity)
        .batched(batch_size)
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=None, shuffle=False, num_workers=FLAGS.num_workers)
    return loader
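To make the sizes in make_train_loader concrete, here is the same arithmetic as a standalone function. world_size=8 is an assumed example value, since the thread does not state the number of devices; 1281167, 4 workers and batch size 128 come from this thread. With length=num_batches, each worker copy reports 312, while the loader on each device receives batches from all 4 worker copies, which plausibly explains the length mismatch warning described below.

```python
def per_instance_sizes(trainsize, world_size, num_workers, batch_size):
    """Samples and full batches per WebDataset instance (one per device x worker)."""
    num_dataset_instances = world_size * num_workers
    epoch_size = trainsize // num_dataset_instances
    num_batches = epoch_size // batch_size
    return num_dataset_instances, epoch_size, num_batches


# world_size=8 is hypothetical.
print(per_instance_sizes(1281167, world_size=8, num_workers=4, batch_size=128))
# (32, 40036, 312)
```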
- This works really well for the training and validation loops until the last specified epoch (tried this with epochs=3, 5, 10).
- Average epoch time is ~40 seconds; loss and accuracy are comparable to other configs.
- (For reference, with a similar PyTorch config where the training data is stored on persistent disk, epoch time is ~1:30 min.)
- I still got some of the original length mismatch warnings, but training proceeded (previously, training would hang and not continue).
  - The warning is something like: Length of IterableDataset <webdataset.dataset.Processor...
So, now training loops through the specified epochs (e.g., 10). However, on the last specified epoch, I get a BrokenPipe Error:
2021-03-16 13:49:42 10.164.0.61 [1] Exception ignored in: <_io.TextIOWrapper name='<stdout>' mode='w' encoding='UTF-8'>
2021-03-16 13:49:42 10.164.0.61 [1] BrokenPipeError: [Errno 32] Broken pipe
2021-03-16 13:49:42 10.164.0.61 [1] The above exception was the direct cause of the following exception:
2021-03-16 13:49:42 10.164.0.61 [1]
2021-03-16 13:49:42 10.164.0.61 [1] Traceback (most recent call last):
2021-03-16 13:49:42 10.164.0.61 [1] File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/threading.py", line 916, in _bootstrap_inner
2021-03-16 13:49:42 10.164.0.61 [1] self.run()
2021-03-16 13:49:42 10.164.0.61 [1] File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/threading.py", line 864, in run
2021-03-16 13:49:42 10.164.0.61 [1] self._target(*self._args, **self._kwargs)
2021-03-16 13:49:42 10.164.0.61 [1] File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/parallel_loader.py", line 141, in _loader_worker
2021-03-16 13:49:42 10.164.0.61 [1] _, data = next(data_iter)
2021-03-16 13:49:42 10.164.0.61 [1] File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
2021-03-16 13:49:42 10.164.0.61 [1] data = self._next_data()
2021-03-16 13:49:42 10.164.0.61 [1] File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1068, in _next_data
2021-03-16 13:49:42 10.164.0.61 [1] idx, data = self._get_data()
2021-03-16 13:49:42 10.164.0.61 [1] File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1034, in _get_data
2021-03-16 13:49:42 10.164.0.61 [1] success, data = self._try_get_data()
2021-03-16 13:49:42 10.164.0.61 [1] File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 885, in _try_get_data
2021-03-16 13:49:42 10.164.0.61 [1] raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
2021-03-16 13:49:42 10.164.0.61 [1] RuntimeError: DataLoader worker (pid(s) 50847, 51213, 52225, 54052) exited unexpectedly
- I can confirm I am cycling through each of the (1) train loop, (2) validation loop, and (3) epoch loop.
- The error occurs after the last line of code in the epoch loop.
- My interpretation of the BrokenPipe error message is that the error occurs in a "data read process", which leads me to think either:
  - (1) the processes reading data are trying to read too much data, or
  - (2) the processes writing data are writing too much data.
- I suspect this has to do with my use of wds.Dataset and/or torch.utils.data.DataLoader().
  - If it is related to the length warning, then perhaps I need to compute length=epoch_size instead of the way I have calculated num_batches?
    - This would match the PyTorch docs mentioned above, where DataLoader returns an estimate based on len(dataset) / batch_size.
    - I tried drop_last=True in the DataLoader, but this requires specifying batch_size=batch_size in the DataLoader.
    - If I specify batch_size in the DataLoader, should I still use .unbatched().shuffle(10000).batched(batch_size) in WebDataset?
Thanks again for your help
Attached training output, with BrokenPipe error and training script
torchXLA-webdataset-trial2-metrics-debug-BrokenPipe.txt
test_train_mp_imagenet_wds.txt
The following config:
- length=epoch_size
- batch_size=128
- num_workers=4
results in the following:
- No IterableDataset length mismatch warnings
- Still getting BrokenPipe errors after the last epoch
- Training loss and validation accuracy behaving as expected
- Epoch training time averaging ~40s
def make_train_loader(img_dim, shuffle=10000, batch_size=FLAGS.batch_size):
    num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
    epoch_size = trainsize // num_dataset_instances
    # num_batches = epoch_size // batch_size
    image_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )
    dataset = (
        wds.WebDataset("pipe:gsutil cat gs://$BUCKET/shards/imagenet-train-{000000..001281}.tar",
                       splitter=wds.split_by_worker, nodesplitter=my_node_splitter, shardshuffle=True, length=epoch_size)
        .shuffle(shuffle)
        .decode("pil")
        .to_tuple("ppm;jpg;jpeg;png", "cls")
        .map_tuple(image_transform, identity)
        .batched(batch_size)
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=None, shuffle=False, drop_last=False, num_workers=FLAGS.num_workers)
    return loader
Training loop:
def train_loop_fn(loader, epoch):
    num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
    epoch_size = trainsize // num_dataset_instances
    num_batches = epoch_size // FLAGS.batch_size
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(islice(repeatedly(loader), 0, num_batches)):
        ...
OK, a few things.
First, the BrokenPipeError is just being ignored; it doesn't cause the process to exit:
Exception ignored in: <_io.TextIOWrapper name='<stdout>' mode='w' encoding='UTF-8'>
[1] BrokenPipeError: [Errno 32] Broken pipe
I'm not sure why it warns you about that, but the pipe is indeed broken and that is OK.
Likewise, the length warning is just that, a warning, and doesn't cause the process to exit either.
It looks to me like the actual cause of the problem is that _try_get_data gets a queue.Empty exception.
This might be due to different numbers of samples being present in different workers. The precise way of fixing that is to make sure that all the shards have exactly the same number of samples, and that the number of shards is divisible by the product of the number of workers and the number of jobs. That's a constraint imposed by synchronous distributed SGD, and it applies regardless of which I/O library you're using.
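Those two conditions are easy to check up front. A sketch under stated assumptions: the shard sizes follow from 1,000 samples per shard (mentioned later in this thread) applied to ImageNet's 1,281,167 training images, which gives 1,281 full shards plus one of 167; num_jobs=8 is a hypothetical value, and the function name is made up for illustration:

```python
def shard_layout_problems(shard_counts, num_workers, num_jobs):
    """Return reasons why every worker in every job might not get equal data."""
    problems = []
    if len(set(shard_counts)) > 1:
        problems.append("shards have unequal sample counts")
    instances = num_workers * num_jobs
    if len(shard_counts) % instances != 0:
        problems.append(f"{len(shard_counts)} shards not divisible by {instances} worker instances")
    return problems


# 1,281 full shards of 1,000 samples plus one partial shard of 167.
imagenet = [1000] * 1281 + [167]
print(shard_layout_problems(imagenet, num_workers=4, num_jobs=8))
print(shard_layout_problems(imagenet[:1280], num_workers=4, num_jobs=8))  # []
```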
If you can't do that, you can probably use this (update to the latest version of webdataset, which adds the repeat(...) method):
dataset = (
    wds.WebDataset("pipe:gsutil cat gs://$BUCKET/shards/imagenet-train-{000000..001281}.tar",
                   splitter=wds.split_by_worker, nodesplitter=my_node_splitter, shardshuffle=True, length=epoch_size)
    .shuffle(shuffle)
    .decode("pil")
    .to_tuple("ppm;jpg;jpeg;png", "cls")
    .map_tuple(image_transform, identity)
    .batched(batch_size)
    .repeat(nepochs=9999999)  # NB: repeat here
)
...
for step, (data, target) in enumerate(islice(loader, 0, num_batches)):  # NB: no repeatedly
    ...
This ensures that all the queues are always full until you reach num_batches, so you shouldn't get queue.Empty exceptions.
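A toy model illustrates why repeating each worker's stream helps. It is a simplified round robin over per-worker generators: it ignores DataLoader's prefetch queues and timeouts, and all helper names are hypothetical. It only shows that finite, unevenly sized worker streams can run dry before num_batches, while repeated streams keep every worker supplying batches until islice stops:

```python
from collections import Counter
from itertools import islice


def worker_stream(worker_id, shards, batches_per_shard, repeat):
    """Batches from one worker's shards; loops over them forever if repeat is True."""
    if not shards:
        return
    while True:
        for shard in shards:
            for _ in range(batches_per_shard):
                yield worker_id, shard
        if not repeat:
            return


def interleave(streams):
    """Simplified DataLoader-style round robin that skips exhausted workers."""
    active = list(streams)
    while active:
        for stream in list(active):
            try:
                yield next(stream)
            except StopIteration:
                active.remove(stream)


def run(num_shards, num_workers, batches_per_shard, num_batches, repeat):
    shards = list(range(num_shards))
    streams = [
        worker_stream(w, shards[w::num_workers], batches_per_shard, repeat)
        for w in range(num_workers)
    ]
    taken = list(islice(interleave(streams), num_batches))
    return len(taken), Counter(worker for worker, _ in taken)


# 10 shards over 4 workers -> 3, 3, 2, 2 shards per worker (15, 15, 10, 10 batches).
print(run(10, 4, 5, 60, repeat=False))
print(run(10, 4, 5, 60, repeat=True))
```

Without repetition only 50 of the requested 60 batches arrive, and the last ones come from just two workers; with repetition all 60 arrive, evenly spread across the four workers.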
What confuses me about the code is that you seem to be training fine for several epochs before your code exits. But give it a try and see whether that fixes the problem.
You are correct that the "BrokenPipe Error" is being ignored and not causing the errors at the end (e.g., DataLoader gracefully exiting).
- As I experimented with larger values for num_workers and batch_size, I noticed this error message during training, but training was not interrupted.
In fact, the model is training well and the errors at the end are insignificant for my task. A few things to note:
- I could not get the .repeat(nepochs=...) code to work. It caused the error below, so instead I continued to use
  for step, (data, target) in enumerate(islice(repeatedly(loader), 0, test_steps)):
  and it works fine.
[0] Exception in device=TPU:0: object of type 'Repeatedly' has no len()
[0] Traceback (most recent call last):
[0] File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
[0] _start_fn(index, pf_cfg, fn, args)
[0] File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
[0] fn(gindex, *args)
[0] File "/tmp/thepackage/test_train_mp_wds.py", line 382, in _mp_fn
[0] accuracy = train_imagenet()
[0] File "/tmp/thepackage/test_train_mp_wds.py", line 360, in train_imagenet
[0] train_loop_fn(train_device_loader, epoch)
[0] File "/tmp/thepackage/test_train_mp_wds.py", line 314, in train_loop_fn
[0] for step, (data, target) in enumerate(islice(loader, 0, train_steps)): #wds.repeatedly
[0] File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/parallel_loader.py", line 197, in __iter__
[0] **self._parallel_loader_kwargs)
[0] File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/parallel_loader.py", line 80, in __init__
[0] self._per_device_samples = len(loader) // len(devices)
[0] File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 388, in __len__
[0] length = self._IterableDataset_len_called = len(self.dataset) # type: ignore
[0] TypeError: object of type 'Repeatedly' has no len()
[0] Traceback (most recent call last):
[0] File "/tmp/thepackage/test_train_mp_wds.py", line 390, in <module>
[0] xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores, start_method='fork')
[0] File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 395, in spawn
[0] start_method=start_method)
[0] File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processes
[0] while not context.join():
[0] File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 112, in join
[0] (error_index, exitcode)
[0] Exception: process 0 terminated with exit code 17
- The last shard from the imagenet generate_shards.py script does not have 1,000 samples like the others.
  - I removed this partial file and one other so that my number of shards was divisible by (num_workers * n_jobs), but this did not yield a noticeable improvement.
- I did adjust the calculations for WebDataset(length=...), epoch_size, and epoch_steps.
  - No mismatch warnings.
  - The new calculations make it easy to compare with metrics from other benchmark tests (e.g., samples per epoch, avg. epoch training time, accuracy/loss).
For WebDataset(length=...) in the train_loader, I calculated the number of dataset instances and epoch_size the same as before:
- because I want a separate dataset for each (device, worker) pair (i.e., each dataset instance per device needs to be broken into num_workers instances),
- and because I am splitting by node and worker.
trainsize = 1281167
num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
epoch_size = trainsize // num_dataset_instances
dataset = (
    wds.WebDataset("pipe:gsutil cat gs://$BUCKET/shards/imagenet-train-{000000..001281}.tar",
                   splitter=wds.split_by_worker, nodesplitter=my_node_splitter, shardshuffle=True, length=epoch_size)
    # ... remaining pipeline stages as before
)
I did the same for the val_loader, except that I chose to split only by worker.
testsize = 50000
num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
epoch_test_size = testsize // num_dataset_instances
val_dataset = (
    wds.WebDataset("pipe:gsutil cat gs://$BUCKET/shards/imagenet-val-{000000..000049}.tar",
                   splitter=wds.split_by_worker, nodesplitter=None, shardshuffle=False, length=epoch_test_size)
    # ... remaining pipeline stages as before
)
Then for the train and test loops, I calculated the steps differently from before:
- Instead of the num_batches below, I used train_steps.
- Using the train_steps calculation allows the loops to cycle through most or all of the length=epoch_size stated in the DataLoaders' WebDatasets.
def train_loop_fn(loader, epoch):
    # num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
    # epoch_size = trainsize // num_dataset_instances
    # num_batches = epoch_size // FLAGS.batch_size
    train_steps = trainsize // (FLAGS.batch_size * xm.xrt_world_size())
    model.train()
    for step, (data, target) in enumerate(islice(repeatedly(loader), 0, train_steps)):
        ...
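The two step counts differ because each device's DataLoader draws batches from all num_workers worker copies: train_steps counts batches per device, while the older num_batches counted batches per worker copy, so its per-device equivalent is roughly num_workers times larger. A sketch of that comparison; world_size=8 is a hypothetical value, and the other numbers come from this thread:

```python
def steps_per_device(trainsize, world_size, num_workers, batch_size):
    """Compare per-device train_steps with per-worker num_batches x num_workers."""
    train_steps = trainsize // (batch_size * world_size)
    epoch_size = trainsize // (world_size * num_workers)  # per dataset instance
    num_batches = epoch_size // batch_size                # per dataset instance
    return train_steps, num_batches * num_workers


# world_size=8 is hypothetical.
print(steps_per_device(1281167, world_size=8, num_workers=4, batch_size=128))  # (1251, 1248)
```

The small gap comes from rounding down twice in the per-instance calculation.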
The results so far are exactly what I was looking for.
- Training PyTorch XLA on TPU Pods is well documented for scenarios where training data fits on the VM/PD, but it is not well documented for scenarios where data is so large it needs to be stored in object storage.
- With webdataset streaming data from GCS, the training time is almost identical to configs where data is stored on local disk (~1:30 per epoch).
I'm glad it works and is fast.
I have to add the len method to Repeatedly.
FWIW, torch_xla should not call len(loader); that's really a bug that needs to get fixed, since increasingly, loaders will just not provide a length.
In any case, to help you weigh the options, I've created a gist that illustrates the different tradeoffs a bit better:
https://gist.github.com/tmbdev/ad3cc45c7ff86fcebde585f3b073d721
Here are some sample runs:
$ python3 sim-dist.py
=== rank 0 batches 36 total 360
=== rank 1 batches 26 total 260
=== rank 2 batches 26 total 260
$ python3 sim-dist.py --nbatches 36
=== rank 0 batches 36 total 360
=== rank 1 batches 26 total 260 TOO FEW BATCHES
=== rank 2 batches 26 total 260 TOO FEW BATCHES
$ python3 sim-dist.py --nbatches 26
=== rank 0 batches 26 total 260
=== rank 1 batches 26 total 260
=== rank 2 batches 26 total 260
$ python3 sim-dist.py --nbatches 36 --dsrepeat 2
=== rank 0 batches 36 total 360
=== rank 1 batches 36 total 360
=== rank 2 batches 36 total 360
$ python3 sim-dist.py --nbatches 36 --repeat 2
=== rank 0 batches 36 total 360
=== rank 1 batches 36 total 360
=== rank 2 batches 36 total 360
Other than fixing the len call, is there any remaining issue with WebDataset?
Thank you for that, @tmbdev!
I'm going to test different configurations to better understand what's going on after the last epoch.
As you pointed out, the BrokenPipe error is being ignored and is not causing the "final epoch error". Can you help me understand what this BrokenPipe error is?
- Is this related to my use of wds.WebDataset("pipe:gsutil cat gs://...?
  - In the source code, the multiprocessing DataLoaders refer to their queue process as a pipe, but I don't think these two are the same.
Yes, the "pipe:..." creates a subprocess that is connected to Python via a UNIX pipe. The PyTorch workers read from such pipes. When a worker stops reading from the pipe before completion (e.g., because you stop reading before consuming all samples), the subprocess has nowhere to write its data anymore, and it is therefore killed by a SIGPIPE signal. This is reported as a broken pipe exit status, and Python (sometimes) reports it. This is completely normal and harmless. In fact, the PyTorch source code goes out of its way to ignore this condition.
Here is some simple code that reproduces it:
$ python3
Python 3.8.5 (default, Jan 27 2021, 15:41:15)
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import subprocess
>>> proc = subprocess.Popen("gsutil cat gs://lpr-ocropus4/ia1-000000.tar", stdout=subprocess.PIPE, shell=True)
>>> proc.stdout.read(10)
b'ourfamilya'
>>> proc.stdout.close()
>>> Exception ignored in: <_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
BrokenPipeError: [Errno 32] Broken pipe
Replace the "gsutil cat" with an equivalent "curl" and the error disappears. It's also not reported when you run this particular script noninteractively. I'm not sure where the differences come from.
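One plausible explanation for the gsutil/curl difference, which is an assumption rather than something verified in this thread: gsutil is a Python program, and CPython ignores SIGPIPE at startup, so a write to a closed pipe raises BrokenPipeError and Python prints the error; curl, like most C programs, keeps the default SIGPIPE action and is killed silently. The sketch below (POSIX only) reproduces both behaviors with two Python child processes standing in for gsutil and curl, so no bucket is needed:

```python
import signal
import subprocess
import sys

# Child 1: a plain Python writer. CPython ignores SIGPIPE, so writing to a closed
# pipe raises BrokenPipeError and the error shows up on stderr.
PY_WRITER = "import sys\nwhile True:\n    sys.stdout.write('x' * 1024)\n"

# Child 2: the same writer with default SIGPIPE handling (like most C programs);
# the kernel simply kills it when it writes to the closed pipe.
DEFAULT_SIGPIPE_WRITER = (
    "import signal, sys\n"
    "signal.signal(signal.SIGPIPE, signal.SIG_DFL)\n"
    "while True:\n"
    "    sys.stdout.write('x' * 1024)\n"
)


def read_then_close(code):
    """Read 10 bytes from a child's stdout, stop reading, and report how it ended."""
    proc = subprocess.Popen(
        [sys.executable, "-c", code], stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    head = proc.stdout.read(10)
    proc.stdout.close()       # the reader gives up early, like an abandoned shard
    err = proc.stderr.read()  # returns once the child has exited
    proc.stderr.close()
    proc.wait()
    return head, proc.returncode, err


head, rc, err = read_then_close(PY_WRITER)
print(head, rc, b"BrokenPipeError" in err)  # nonzero exit, error message on stderr

head, rc, err = read_then_close(DEFAULT_SIGPIPE_WRITER)
print(head, rc == -signal.SIGPIPE, err)     # killed silently by SIGPIPE
```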