
Comments (9)

tottenjordan commented on June 7, 2024

@tmbdev, the link suggested in the error message illustrates using a worker_init_fn like this:

import math

import torch

# Define a `worker_init_fn` that configures each dataset copy differently
def worker_init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # the dataset copy in this worker process
    overall_start = dataset.start
    overall_end = dataset.end
    # configure the dataset to only process the split workload
    per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
    worker_id = worker_info.id
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)

# Multi-process loading with the custom `worker_init_fn`
# (`ds` is the example IterableDataset from the PyTorch docs, covering the range [3, 7))
# Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))

This is not needed in my case because I am essentially doing this in wds.Dataset, right?

from webdataset.

tmbdev commented on June 7, 2024

The current PyTorch DataLoader is complex and makes exact epochs in the distributed setting tricky. For large datasets, this is actually not much of an issue, since "epochs" aren't that useful to begin with (you need to save/evaluate more often than once-per-epoch anyway). Statistically, there is also little reason for the exactly-once-per-epoch approach commonly used with smaller datasets; the original justification was that SGD was viewed as an approximation to full gradient descent on the entire dataset, but even in that setting, sampling-with-replacement is a perfectly sound training strategy.

The default for WebDataset is to iterate over all shards on all nodes, giving you an epoch size that is num_workers times as large as the single-node epoch size. But this also happens to give you exactly the same number of samples on all nodes, which keeps DataLoader happy. Training loops of this kind look like:

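from itertools import islice

import webdataset as wds
from torch.utils.data import DataLoader

ds = wds.WebDataset(...)
loader = DataLoader(ds, ...)
# train for a fixed number of batches, independent of epoch boundaries
for inputs, outputs in islice(wds.repeatedly(loader), 0, num_batches):
    train_batch(model, inputs, outputs)
    if every_hour():
        ...  # evaluate and save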
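The loop above relies on helpers that the thread does not define. Below is a minimal, hypothetical stand-in for each in plain Python: `repeatedly` assumes that wds.repeatedly simply starts a new pass over its argument whenever the previous pass ends, and `Every` is an illustrative replacement for `every_hour()` that fires at most once per interval. Neither one is the webdataset implementation.

```python
import time
from itertools import islice
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")


def repeatedly(source: Iterable[T]) -> Iterator[T]:
    """Hypothetical stand-in for wds.repeatedly: iterate over `source` again and again."""
    while True:
        empty = True
        for item in source:  # each pass calls iter(source) anew, e.g. a fresh DataLoader epoch
            empty = False
            yield item
        if empty:  # stop instead of spinning forever on an empty source
            return


class Every:
    """Hypothetical periodic trigger: returns True at most once per `interval` seconds."""

    def __init__(self, interval: float, clock: Callable[[], float] = time.monotonic):
        self.interval = interval
        self.clock = clock
        self.last = clock()

    def __call__(self) -> bool:
        now = self.clock()
        if now - self.last >= self.interval:
            self.last = now
            return True
        return False


every_hour = Every(3600.0)

# A "loader" with only 3 batches still yields exactly num_batches = 7 batches.
batches = list(islice(repeatedly(["b0", "b1", "b2"]), 0, 7))
print(batches)  # ['b0', 'b1', 'b2', 'b0', 'b1', 'b2', 'b0']
```

The point of the design is that training length is set by num_batches, not by where a pass over the data happens to end.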

So the above is the recommended practice for large datasets. Now for the nitty-gritty, if you can't use that.

For relatively small datasets and datasets with a lot of prior work reported in epochs (like Imagenet), you may still want to recreate exact epochs with WebDataset. In that case, you shouldn't use ResizedDataset, since that will also change the composition of your epoch. The warning you get about length mismatches is really just that: a warning; you can ignore it. Furthermore, if you don't call len on the loader, you shouldn't get the warning at all.

It's best not to set the length on the WebDataset at all. If you do set it, you need to do so correctly, which requires precomputing the number of samples in each shard and then setting the length of the particular WebDataset on that worker to the total number of samples.
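As a sketch of that precomputation, the function below totals per-shard sample counts for each dataset instance. It assumes (not confirmed in the thread) that shards are dealt out round-robin across `world_size * num_workers` instances, in the style of `urls[index::n]` slicing; the function name and the shard-size table are hypothetical. If shards are shuffled before they are split, the assignment can change from epoch to epoch, so these totals only stay fixed when all shards have the same size.

```python
from typing import List, Sequence


def instance_lengths(shard_sizes: Sequence[int], world_size: int, num_workers: int) -> List[int]:
    """Number of samples each (node, worker) dataset instance would see per epoch.

    Assumes instance i receives shards i, i + n, i + 2n, ... where n = world_size * num_workers.
    """
    n = world_size * num_workers
    return [sum(shard_sizes[i::n]) for i in range(n)]


# Hypothetical table: three full shards of 1000 samples and one partial shard.
sizes = [1000, 1000, 1000, 167]
lengths = instance_lengths(sizes, world_size=1, num_workers=2)
print(lengths)  # [2000, 1167]
```

Each worker's WebDataset would then get length=lengths[i] for its own index i.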

Note that DataLoader pretty much forces you to use batch_size=None in many cases. This means that samples are not shuffled between workers, which, depending on how your dataset was created and how you split it, can have a noticeable effect on training. But WebDataset has a good solution: you can use .unbatched().shuffle(10000).batched(batch_size) on the output of the DataLoader. This even lets you use different batch sizes for the workers and for training.
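To show what each of those three steps does, here is a sketch that reproduces unbatch, shuffle-buffer and rebatch with plain Python generators. It illustrates the idea only and is not webdataset's implementation; the buffer size, batch sizes and simulated worker output are arbitrary.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def unbatched(batches: Iterable[List[T]]) -> Iterator[T]:
    """Flatten worker-produced batches back into individual samples."""
    for batch in batches:
        yield from batch


def shuffled(samples: Iterable[T], bufsize: int, rng: random.Random) -> Iterator[T]:
    """Shuffle buffer: hold up to `bufsize` samples and emit a random one each step."""
    buf: List[T] = []
    for sample in samples:
        buf.append(sample)
        if len(buf) >= bufsize:
            i = rng.randrange(len(buf))
            buf[i], buf[-1] = buf[-1], buf[i]
            yield buf.pop()
    rng.shuffle(buf)  # drain whatever is left at the end of the stream
    yield from buf


def batched(samples: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Group samples into lists of `batch_size` (the last batch may be shorter)."""
    batch: List[T] = []
    for sample in samples:
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


# Simulated DataLoader output with batch_size=None: 2 workers x 3 batches of 4,
# interleaved, where each batch holds samples from a single worker only.
worker_batches = [[w * 100 + k for k in range(b * 4, b * 4 + 4)] for b in range(3) for w in range(2)]
rebatched = list(batched(shuffled(unbatched(worker_batches), bufsize=8, rng=random.Random(0)), 6))
print([len(b) for b in rebatched])  # [6, 6, 6, 6]
```

After rebatching, a single training batch can mix samples from both workers, and its size (6) is independent of the worker batch size (4).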

I realize this needs more documentation. Our multinode cluster is currently being moved; I'll try to update the examples once it's up and running again and I can test them.

TODO:

  • add DataLoader+Composable+Shorthands wrapper to make rebatching simpler
  • add shard_length table option to WebDataset
  • have batched(...) recompute length of underlying WebDataset
  • update examples


tottenjordan commented on June 7, 2024

Thanks @tmbdev, this was very helpful. I would like to try recreating epochs first. If that doesn't work out, I will look at other options.

I tried not specifying length=... and length=None in wds.WebDataset(), but received PyTorch errors related to len and could not proceed with training.

  • When you say "if you don't call len on the loader, you shouldn't get the warning at all", are you referring to the PyTorch DataLoader or to WebDataset?
    • According to the DataLoader source code, calling len on a PyTorch DataLoader returns an estimate based on len(dataset) / batch_size when the dataset is an IterableDataset.
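A sketch of the length arithmetic involved, assuming PyTorch behaves as in the versions used in this thread: for an IterableDataset, len(loader) is len(dataset), divided by batch_size (rounded up, or down with drop_last=True) when the DataLoader itself batches; the mismatch warning fires once more items have been fetched than len(dataset) reported, and only if len() was called. The helper names are hypothetical mirrors of that logic, and world_size=8 is an assumed value.

```python
import math
from typing import Optional


def dataloader_len(dataset_len: int, batch_size: Optional[int] = None, drop_last: bool = False) -> int:
    """Mirror of DataLoader.__len__ for an IterableDataset (assumed behavior)."""
    if batch_size is None:  # batching is done inside the dataset, as with .batched(...)
        return dataset_len
    if drop_last:
        return dataset_len // batch_size
    return math.ceil(dataset_len / batch_size)


def length_warning(dataset_len: int, num_fetched: int) -> bool:
    """True when more items were fetched than len(dataset) reported."""
    return num_fetched > dataset_len


# Numbers from this thread, with a hypothetical world_size of 8 devices.
trainsize, world_size, num_workers, batch_size = 1281167, 8, 4, 128
epoch_size = trainsize // (world_size * num_workers)  # samples per dataset instance
num_batches = epoch_size // batch_size  # batches per dataset instance
# roughly: every worker on a device contributes about num_batches batches per pass
fetched_per_device = num_workers * num_batches

print(epoch_size, num_batches, fetched_per_device)  # 40036 312 1248
print(length_warning(num_batches, fetched_per_device))  # True: length=num_batches
print(length_warning(epoch_size, fetched_per_device))  # False: length=epoch_size
```

Under these assumptions, length=num_batches is exceeded once the workers' batches are combined, while length=epoch_size is not, which may explain why the warning later disappears with length=epoch_size.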

I set length=num_batches and removed ResizedDataset, with something like this:

import torch
import torch_xla.core.xla_model as xm
import webdataset as wds

trainsize = 1281167  # number of ImageNet training images


def identity(x):
    # passes labels through unchanged
    return x


def make_train_loader(img_dim, shuffle=10000, batch_size=FLAGS.batch_size):
    # one dataset instance per (device, DataLoader worker) pair
    num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
    epoch_size = trainsize // num_dataset_instances
    # num_batches = (epoch_size + batch_size - 1) // batch_size
    num_batches = epoch_size // batch_size
    image_transform = ...

    # FLAGS and my_node_splitter are defined elsewhere in the training script
    dataset = (
        wds.WebDataset(
            "pipe:gsutil cat gs://$BUCKET/shards/imagenet-train-{000000..001281}.tar",
            splitter=wds.split_by_worker, nodesplitter=my_node_splitter,
            shardshuffle=True, length=num_batches)
        .shuffle(shuffle)
        .decode("pil")
        .to_tuple("ppm;jpg;jpeg;png", "cls")
        .map_tuple(image_transform, identity)
        .batched(batch_size)
    )
    # batching happens inside the workers, so the DataLoader must not batch again
    loader = torch.utils.data.DataLoader(dataset, batch_size=None, shuffle=False, num_workers=FLAGS.num_workers)
    return loader

  • This works really well for the training and validation loops until the last specified epoch (tried this on epochs=3, 5, 10).
  • Average epoch time is ~40 seconds; loss and accuracy are comparable to other configs.
    • (For reference, with a similar PyTorch config where the training data is stored on a persistent disk, epoch time is ~1:30 min.)
  • Still got some of the original length mismatch warnings, but training proceeded (previously, training would hang and not continue).
    • The warning is something like: Length of IterableDataset <webdataset.dataset.Processor...

So, now training loops through the specified epochs (e.g., 10). However, on the last specified epoch, I get a BrokenPipe Error:

2021-03-16 13:49:42 10.164.0.61 [1] Exception ignored in: <_io.TextIOWrapper name='<stdout>' mode='w' encoding='UTF-8'>
2021-03-16 13:49:42 10.164.0.61 [1] BrokenPipeError: [Errno 32] Broken pipe

2021-03-16 13:49:42 10.164.0.61 [1] The above exception was the direct cause of the following exception:
2021-03-16 13:49:42 10.164.0.61 [1] 
2021-03-16 13:49:42 10.164.0.61 [1] Traceback (most recent call last):
2021-03-16 13:49:42 10.164.0.61 [1]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/threading.py", line 916, in _bootstrap_inner
2021-03-16 13:49:42 10.164.0.61 [1]     self.run()
2021-03-16 13:49:42 10.164.0.61 [1]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/threading.py", line 864, in run
2021-03-16 13:49:42 10.164.0.61 [1]     self._target(*self._args, **self._kwargs)
2021-03-16 13:49:42 10.164.0.61 [1]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/parallel_loader.py", line 141, in _loader_worker
2021-03-16 13:49:42 10.164.0.61 [1]     _, data = next(data_iter)
2021-03-16 13:49:42 10.164.0.61 [1]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
2021-03-16 13:49:42 10.164.0.61 [1]     data = self._next_data()
2021-03-16 13:49:42 10.164.0.61 [1]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1068, in _next_data
2021-03-16 13:49:42 10.164.0.61 [1]     idx, data = self._get_data()
2021-03-16 13:49:42 10.164.0.61 [1]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1034, in _get_data
2021-03-16 13:49:42 10.164.0.61 [1]     success, data = self._try_get_data()
2021-03-16 13:49:42 10.164.0.61 [1]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 885, in _try_get_data
2021-03-16 13:49:42 10.164.0.61 [1]     raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
2021-03-16 13:49:42 10.164.0.61 [1] RuntimeError: DataLoader worker (pid(s) 50847, 51213, 52225, 54052) exited unexpectedly
  • I can confirm I am cycling through each of the (1) train loop, (2) validation loop, and (3) epoch loop
    • The error occurs after the last line of code in the epoch loop
  • My reading of the BrokenPipe error message is that the error occurs in a "data read process", which leads me to think that either:
    • (1) the processes to read data are trying to read too much data
    • (2) the processes to write data are writing too much data
  • I suspect this has to do with my use of wds.Dataset and/or torch.utils.data.DataLoader()
    • If it is related to the length warning, then perhaps I need to use length=epoch_size instead of num_batches as I have calculated it?
    • This would match the PyTorch docs mentioned above, where DataLoader returns an estimate based on len(dataset) / batch_size
    • I tried drop_last=True in the DataLoader, but this requires specifying batch_size=batch_size in the DataLoader
    • If specifying batch_size in the DataLoader, should I still use .unbatched().shuffle(10000).batched(batch_size) in WebDataset?

Thanks again for your help

Attached: the training output with the BrokenPipe error, and the training script.

torchXLA-webdataset-trial2-metrics-debug-BrokenPipe.txt
test_train_mp_imagenet_wds.txt


tottenjordan commented on June 7, 2024

The following config:

  • length=epoch_size
  • batch_size=128
  • num_workers=4

results in the following:

  • No IterableDataset length mismatch warnings
  • Still getting BrokenPipe errors after the last epoch
  • Training loss and validation accuracy behaving as expected
  • Epoch training time averaging ~40s

import torch
import torch_xla.core.xla_model as xm
import torchvision.transforms as transforms
import webdataset as wds


def make_train_loader(img_dim, shuffle=10000, batch_size=FLAGS.batch_size):
    # FLAGS, trainsize, normalize, identity and my_node_splitter come from elsewhere in the script
    num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
    epoch_size = trainsize // num_dataset_instances  # samples per dataset instance
    # num_batches = epoch_size // batch_size

    image_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )

    dataset = (
        wds.WebDataset(
            "pipe:gsutil cat gs://$BUCKET/shards/imagenet-train-{000000..001281}.tar",
            splitter=wds.split_by_worker, nodesplitter=my_node_splitter,
            shardshuffle=True, length=epoch_size)
        .shuffle(shuffle)
        .decode("pil")
        .to_tuple("ppm;jpg;jpeg;png", "cls")
        .map_tuple(image_transform, identity)
        .batched(batch_size)
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=None, shuffle=False, drop_last=False, num_workers=FLAGS.num_workers)
    return loader

Training loop:

    def train_loop_fn(loader, epoch):
        num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
        epoch_size = trainsize // num_dataset_instances
        num_batches = epoch_size // FLAGS.batch_size
        tracker = xm.RateTracker()
        model.train()
        # islice comes from itertools, repeatedly from webdataset
        for step, (data, target) in enumerate(islice(repeatedly(loader), 0, num_batches)):
            ...  # training step (elided)


tmbdev commented on June 7, 2024

OK, a few things.

First, the BrokenPipeError is just being ignored; it doesn't cause the process to exit:

    Exception ignored in: <_io.TextIOWrapper name='<stdout>' mode='w' encoding='UTF-8'>
    [1] BrokenPipeError: [Errno 32] Broken pipe

I'm not sure why it warns you about that, but the pipe is indeed broken and that is OK.

Likewise, the length warning is just that, a warning, and doesn't cause the process to exit either.

It looks to me like the actual cause of the problem is that _try_get_data gets a queue.Empty exception.

This might be due to different numbers of samples being present in different workers. The precise way of fixing that is to make sure that all the shards have exactly the same number of samples, and that the number of shards is divisible by the product of the number of workers and the number of jobs. That's a constraint imposed by synchronous distributed SGD and is required regardless of what I/O library you're using.
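Those two conditions are easy to check before training. The function below is a hypothetical sketch, not part of webdataset; the ImageNet shard table assumes 1281 shards of 1000 samples plus one partial shard, which is what trainsize = 1281167 and the 000000..001281 shard range imply, and world_size=8 is an assumed value.

```python
from typing import Sequence


def samples_per_instance(shard_sizes: Sequence[int], world_size: int, num_workers: int) -> int:
    """Return the per-instance sample count if the split is exactly even, else raise ValueError.

    Checks both conditions: every shard holds the same number of samples, and the
    shard count is divisible by world_size * num_workers.
    """
    n_instances = world_size * num_workers
    if not shard_sizes:
        raise ValueError("no shards")
    if len(set(shard_sizes)) != 1:
        raise ValueError(f"shards differ in size: {sorted(set(shard_sizes))}")
    if len(shard_sizes) % n_instances != 0:
        raise ValueError(f"{len(shard_sizes)} shards not divisible by {n_instances} instances")
    return len(shard_sizes) // n_instances * shard_sizes[0]


imagenet_train = [1000] * 1281 + [167]  # 1,281,167 samples in 1282 shards
try:
    samples_per_instance(imagenet_train, world_size=8, num_workers=4)
except ValueError as err:
    print(err)  # shards differ in size: [167, 1000]

trimmed = [1000] * 1280  # partial shard and one full shard removed
print(samples_per_instance(trimmed, world_size=8, num_workers=4))  # 40000
```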

If you can't do that, you can probably use this (update to the latest version of webdataset, which adds the repeat(...) method):

dataset = (
    wds.WebDataset("pipe:gsutil cat gs://$BUCKET/shards/imagenet-train-{000000..001281}.tar", 
    splitter=wds.split_by_worker, nodesplitter=my_node_splitter, shardshuffle=True, length=epoch_size)
    .shuffle(shuffle)
    .decode("pil")
    .to_tuple("ppm;jpg;jpeg;png", "cls")
    .map_tuple(image_transform, identity)
    .batched(batch_size)
    .repeat(nepochs=9999999)  # NB: repeat here
    )
    ...
    for step, (data, target) in enumerate(islice(loader, 0, num_batches)):  # NB: no repeatedly
        ...

This ensures that all the queues are always full until you reach num_batches, so you shouldn't get queue.Empty exceptions.

What confuses me about the code is that you seem to be training fine for several epochs before your code exits. But give it a try and see whether that fixes the problem.


tottenjordan commented on June 7, 2024

You are correct that the BrokenPipeError is being ignored and is not what causes the errors at the end (i.e., the DataLoader failing to exit gracefully).

  • As I experimented with larger values for num_workers and batch_size, I noticed this error message during training, but training was not interrupted.

In fact, the model is training well and the errors at the end are insignificant for my task. A few things to note:

  • I could not get the .repeat(nepochs=...) code to work. It caused the error below.
  • Instead, I continued to use for step, (data, target) in enumerate(islice(repeatedly(loader), 0, test_steps)):, and it works fine.
[0] Exception in device=TPU:0: object of type 'Repeatedly' has no len()
[0] Traceback (most recent call last):
[0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
[0]     _start_fn(index, pf_cfg, fn, args)
[0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
[0]     fn(gindex, *args)
[0]   File "/tmp/thepackage/test_train_mp_wds.py", line 382, in _mp_fn
[0]    accuracy = train_imagenet()
[0]   File "/tmp/thepackage/test_train_mp_wds.py", line 360, in train_imagenet
[0]     train_loop_fn(train_device_loader, epoch)
[0]   File "/tmp/thepackage/test_train_mp_wds.py", line 314, in train_loop_fn
[0]     for step, (data, target) in enumerate(islice(loader, 0, train_steps)): #wds.repeatedly
[0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/parallel_loader.py", line 197, in __iter__
[0]     **self._parallel_loader_kwargs)
[0]  File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/parallel_loader.py", line 80, in __init__
[0]     self._per_device_samples = len(loader) // len(devices)
[0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 388, in __len__
[0]     length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore
[0] TypeError: object of type 'Repeatedly' has no len()
[0] Traceback (most recent call last):
[0]  File "/tmp/thepackage/test_train_mp_wds.py", line 390, in <module>
[0]     xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores, start_method='fork')
[0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 395, in spawn
[0]    start_method=start_method)
[0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processes
[0]    while not context.join():
[0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 112, in join
[0]    (error_index, exitcode)
[0] Exception: process 0 terminated with exit code 17
  • The last shard produced by the ImageNet generate_shards.py script does not have 1,000 samples like the others

    • I removed this partial file and one other so that my shard count was divisible by (num_workers * n_jobs), but this did not yield a noticeable improvement
  • I did adjust the calculations for WebDataset(length=..), epoch_size, and epoch_steps.

    • no mismatch warnings
    • New calcs make it easy to compare with metrics from other benchmark tests (e.g., samples per epoch, avg. epoch training time, accuracy/loss).

For WebDataset(length=...) in the train_loader, I calculated num_dataset_instances and epoch_size the same way as before:

  • Because I want a separate dataset for each num_workers-world_size combination (i.e., each dataset instance per device needs to be broken into num_workers instances)
  • And because I am splitting by node and worker

trainsize = 1281167
num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
epoch_size = trainsize // num_dataset_instances

dataset = (
    wds.WebDataset("pipe:gsutil cat gs://$BUCKET/shards/imagenet-train-{000000..001281}.tar",
        splitter=wds.split_by_worker, nodesplitter=my_node_splitter, shardshuffle=True, length=epoch_size)
    # remaining pipeline stages (.shuffle, .decode, ...) not shown
)

I did the same for the val_loader, except I chose to only split by worker.

testsize = 50000
num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
epoch_test_size = testsize // num_dataset_instances

val_dataset = (
    wds.WebDataset("pipe:gsutil cat gs://$BUCKET/shards/imagenet-val-{000000..000049}.tar",
        splitter=wds.split_by_worker, nodesplitter=None, shardshuffle=False, length=epoch_test_size)
    # remaining pipeline stages not shown
)

Then for the train and test loops, I calculated the steps differently from before:

  • instead of num_batches (commented out below), I used train_steps
  • Using the train_steps calc allows the loops to cycle through most or all of the length=epoch_size stated in the DataLoaders' WebDatasets
    def train_loop_fn(loader, epoch):
        # num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
        # epoch_size = trainsize // num_dataset_instances
        # num_batches = epoch_size // FLAGS.batch_size
        train_steps = trainsize // (FLAGS.batch_size * xm.xrt_world_size())
        model.train()
        for step, (data, target) in enumerate(islice(repeatedly(loader), 0, train_steps)):
             ...
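The "most or all" can be checked with a little arithmetic. This sketch uses batch_size=128 and num_workers=4 from the thread and an assumed world_size of 8 devices (not stated in the thread): each device's loader combines num_workers dataset instances with a nominal length of epoch_size samples each, while train_steps batches cover train_steps * batch_size samples.

```python
trainsize = 1281167
batch_size, num_workers = 128, 4
world_size = 8  # assumption: number of TPU cores, not stated in the thread

epoch_size = trainsize // (world_size * num_workers)  # length= per dataset instance
per_device_samples = num_workers * epoch_size  # nominal samples per device per epoch
train_steps = trainsize // (batch_size * world_size)  # steps per device per epoch
consumed = train_steps * batch_size  # samples actually trained on per device

print(train_steps, consumed, per_device_samples)  # 1251 160128 160144
```

Under these assumptions each epoch of train_steps consumes all but 16 of the nominal per-device samples.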

The results so far are exactly what I was looking for.

  • Training PyTorch XLA on TPU Pods is well documented for scenarios where training data fits on the VM/PD, but it is not well documented for scenarios where data is so large it needs to be stored in object storage.
  • With webdataset streaming data from GCS, the training time is almost identical to configs where data is stored on local disk (~1:30 per epoch).


tmbdev commented on June 7, 2024

I'm glad it works and is fast.

I have to add the len method to Repeatedly.

FWIW, torch_xla should not call len(loader); that's really a bug that needs to get fixed, since increasingly, loaders will just not provide a length.
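Until Repeatedly gains a __len__, one possible workaround (a hypothetical sketch, not part of webdataset or torch_xla) is a thin wrapper that delegates iteration to an endless source while reporting a caller-chosen nominal length, which is all that a len(...) call needs. To hand it to a DataLoader it would also have to subclass torch.utils.data.IterableDataset; that is left out here so the sketch runs without PyTorch.

```python
import itertools
from collections.abc import Iterable, Iterator


class WithLength:
    """Hypothetical wrapper: delegate iteration to `source`, but report a fixed length."""

    def __init__(self, source: Iterable, length: int):
        self.source = source
        self.length = length

    def __iter__(self) -> Iterator:
        return iter(self.source)

    def __len__(self) -> int:
        return self.length


# An endless stream (standing in for a Repeatedly object) that still answers len().
stream = WithLength(itertools.cycle(["b0", "b1", "b2"]), length=100)
print(len(stream))  # 100
print(list(itertools.islice(stream, 5)))  # ['b0', 'b1', 'b2', 'b0', 'b1']
```

The reported length is only a hint; the training loop still decides how many batches to take, e.g. with islice.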

In any case, to help you weigh the tradeoffs, I've created a gist that illustrates the different options a bit better:

https://gist.github.com/tmbdev/ad3cc45c7ff86fcebde585f3b073d721

Here are some sample runs:

    $ python3 sim-dist.py
    === rank 0 batches 36 total 360
    === rank 1 batches 26 total 260
    === rank 2 batches 26 total 260
    $ python3 sim-dist.py --nbatches 36
    === rank 0 batches 36 total 360
    === rank 1 batches 26 total 260 TOO FEW BATCHES
    === rank 2 batches 26 total 260 TOO FEW BATCHES
    $ python3 sim-dist.py --nbatches 26
    === rank 0 batches 26 total 260
    === rank 1 batches 26 total 260
    === rank 2 batches 26 total 260
    $ python3 sim-dist.py --nbatches 36 --dsrepeat 2
    === rank 0 batches 36 total 360
    === rank 1 batches 36 total 360
    === rank 2 batches 36 total 360
    $ python3 sim-dist.py --nbatches 36 --repeat 2
    === rank 0 batches 36 total 360
    === rank 1 batches 36 total 360
    === rank 2 batches 36 total 360
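The pattern in these runs can be reproduced with a toy model. The sketch below is not the gist's code; it only assumes what the output shows: three ranks whose data yields 36, 26 and 26 batches of 10 samples, an optional cap of nbatches applied with islice, and an optional repeat of each rank's data. A rank that runs out before reaching nbatches is exactly the case that stalls synchronous training.

```python
from itertools import chain, islice, repeat
from typing import List, Optional


def simulate(rank_batches: List[int], nbatches: Optional[int] = None, repeats: int = 1,
             batch_size: int = 10) -> List[str]:
    """Count the batches each rank delivers, optionally capped at nbatches."""
    lines = []
    for rank, n in enumerate(rank_batches):
        data = chain.from_iterable(repeat(range(n), repeats))  # `repeats` passes over n batches
        if nbatches is not None:
            data = islice(data, nbatches)
        seen = sum(1 for _ in data)
        line = f"=== rank {rank} batches {seen} total {seen * batch_size}"
        if nbatches is not None and seen < nbatches:
            line += " TOO FEW BATCHES"
        lines.append(line)
    return lines


for args in [dict(), dict(nbatches=36), dict(nbatches=26), dict(nbatches=36, repeats=2)]:
    print("\n".join(simulate([36, 26, 26], **args)))
```

Capping at the smallest rank's count (26) or repeating the data and then capping (36) both give every rank the same number of batches.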

Other than fixing the len call, is there any remaining issue with WebDataset?


tottenjordan commented on June 7, 2024

Thank you for that @tmbdev !

I'm going to test different configurations to better understand what's going on after the last epoch.

As you pointed out, the BrokenPipe error is being ignored and is not causing the "final epoch error". Can you help me understand what this BrokenPipe error is?

  • Is this related to my use of wds.WebDataset("pipe:gsutil cat gs://?
  • In the source code, the multiprocessing dataloaders refer to their queue process as a pipe, but I don't think these two are the same.


tmbdev commented on June 7, 2024

Yes, the "pipe:..." creates a subprocess that is connected to Python via a UNIX pipe. The PyTorch workers read from such pipes. When a worker stops reading from the pipe before completion (e.g., because you stop reading before consuming all samples), the subprocess has no place to write its data anymore and is therefore killed by a SIGPIPE signal. This is reported as a broken pipe exit status, and Python (sometimes) reports it. This is completely normal and harmless. In fact, the PyTorch source code goes out of its way to ignore this condition.

Here is some simple code that reproduces it:

    $ python3
    Python 3.8.5 (default, Jan 27 2021, 15:41:15) 
    [GCC 9.3.0] on linux
    Type "help", "copyright", "credits" or "license" for more information.
    >>> import subprocess
    >>> proc = subprocess.Popen("gsutil cat gs://lpr-ocropus4/ia1-000000.tar", stdout=subprocess.PIPE, shell=True)
    >>> proc.stdout.read(10)
    b'ourfamilya'
    >>> proc.stdout.close()
    >>> Exception ignored in: <_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
    BrokenPipeError: [Errno 32] Broken pipe

Replace the "gsutil cat" with an equivalent "curl" and the error disappears. It's also not reported when you run this particular script noninteractively. I'm not sure where the differences come from.
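A possible explanation for the difference (an assumption, not verified in this thread): gsutil is itself a Python program, and Python ignores SIGPIPE by default, so a write to a closed pipe raises BrokenPipeError instead of killing the process, and Python prints the "Exception ignored ... BrokenPipeError" message when it tries to flush stdout while shutting down. A program that keeps the default SIGPIPE action is simply terminated without a message. The sketch below compares two local Python writers, one of which restores the default SIGPIPE action, so it needs neither gsutil nor network access; it assumes a POSIX system.

```python
import signal
import subprocess
import sys

# Child writer that restores the default SIGPIPE action (terminate the process).
DEFAULT_SIGPIPE_WRITER = (
    "import signal, sys\n"
    "signal.signal(signal.SIGPIPE, signal.SIG_DFL)\n"
    "while True:\n"
    "    sys.stdout.buffer.write(b'x' * 65536)\n"
    "    sys.stdout.buffer.flush()\n"
)

# Child writer with Python's own default: SIGPIPE ignored, so writes raise BrokenPipeError.
PYTHON_WRITER = (
    "import sys\n"
    "while True:\n"
    "    sys.stdout.buffer.write(b'x' * 65536)\n"
    "    sys.stdout.buffer.flush()\n"
)


def read_a_little_then_close(child_code: str):
    """Start a writer, read 10 bytes, close our end of the pipe, and report how it ended."""
    proc = subprocess.Popen([sys.executable, "-c", child_code],
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    proc.stdout.read(10)
    proc.stdout.close()  # the writer now has nowhere to send its data
    returncode = proc.wait(timeout=30)
    stderr = proc.stderr.read()
    proc.stderr.close()
    return returncode, stderr


rc, err = read_a_little_then_close(DEFAULT_SIGPIPE_WRITER)
print(rc == -signal.SIGPIPE)  # True: terminated by SIGPIPE
rc, err = read_a_little_then_close(PYTHON_WRITER)
print(rc != 0, b"BrokenPipeError" in err)  # True True
```

In both cases the reader is unaffected; only the way the writer reports its exit differs.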
