

tmbdev commented on June 15, 2024

It's unlikely that there is an actual memory leak, since all the operations in WebDataset are synchronous, and we have run exabytes of data through these pipelines.

Possible sources are:

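  • attempts to decode entire videos into memory (e.g., torchvision.io.read_video applied to long videos rather than to short clips)
  • a num_workers value that is too large
  • a shuffle buffer that is too large, or a shuffle operation in the wrong place in the pipeline (e.g., after decoding, so the buffer holds decoded arrays instead of compressed bytes)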
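A rough back-of-the-envelope sketch of why num_workers and the shuffle buffer size multiply. It assumes each worker process runs its own copy of the pipeline (as PyTorch DataLoader workers do) and therefore holds its own shuffle buffer; the 224x224 RGB image size is an illustrative assumption, not something from this thread.

```python
MIB = 1024 ** 2


def decoded_image_bytes(height: int, width: int, channels: int = 3, itemsize: int = 1) -> int:
    """Bytes held by one decoded image array (itemsize 1 = uint8, 4 = float32)."""
    return height * width * channels * itemsize


def shuffle_buffer_bytes(num_workers: int, buffer_size: int, bytes_per_sample: int) -> int:
    """Rough estimate, assuming every worker process keeps its own full buffer."""
    return num_workers * buffer_size * bytes_per_sample


# Assumed sample size for illustration: a 224x224 RGB image.
uint8_sample = decoded_image_bytes(224, 224)                # 150528 bytes
float32_sample = decoded_image_bytes(224, 224, itemsize=4)  # 602112 bytes

# Shuffle buffer of 1000 *decoded* samples, 4 workers:
print(shuffle_buffer_bytes(4, 1000, uint8_sample) / MIB)    # 574.21875 (MiB, uint8)
print(shuffle_buffer_bytes(4, 1000, float32_sample) / MIB)  # 2296.875 (MiB, float32)

# One float32 batch of 256 such images:
print(256 * float32_sample / MIB)                           # 147.0 (MiB)
```

Shuffling before decoding keeps only the compressed bytes (e.g., JPEG data) in the buffer, which is usually far smaller than the decoded arrays.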

Can you share your input pipeline, a "tar tvf ... | sed 30q" for one of your shards, and your machine configuration?
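If a shard is available locally, Python's standard tarfile module can produce a listing similar to `tar tvf ... | sed 30q` (member size and name only, without permissions or dates). A minimal sketch; list_shard is a hypothetical helper name:

```python
import tarfile
from typing import List


def list_shard(path: str, limit: int = 30) -> List[str]:
    """Return "size  name" lines for the first `limit` members of a tar shard,
    roughly like `tar tvf shard.tar | sed 30q`."""
    lines: List[str] = []
    with tarfile.open(path) as tar:  # default mode "r" auto-detects gzip/bz2/xz
        for member in tar:           # members are read lazily, in archive order
            if len(lines) >= limit:
                break
            lines.append(f"{member.size:>10}  {member.name}")
    return lines
```

For example, `print("\n".join(list_shard("shard.tar")))`, where "shard.tar" stands in for a local shard path. In a well-formed WebDataset shard, files belonging to the same sample (e.g., `.jpg` and `.json`) should appear next to each other with a shared basename.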

from webdataset.

allanbatista commented on June 15, 2024

So the problem is not WebDataset (which is really fast). It produces batches much faster than my GPUs can process them, so too many batches accumulate in memory.

To resolve this, I created PR #18, which limits how many batches are kept in memory.
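The contents of PR #18 are not shown in this thread. As a generic illustration of the same idea (backpressure), here is a minimal sketch built on a bounded queue.Queue: the producer thread blocks once max_pending items are waiting, so batches cannot pile up faster than the training loop consumes them. bounded_prefetch is a hypothetical helper, not part of WebDataset.

```python
import queue
import threading
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel marking the end of the source iterator


def bounded_prefetch(source: Iterable[T], max_pending: int) -> Iterator[T]:
    """Iterate `source` in a background thread, buffering at most
    `max_pending` items. A slow consumer makes `put` block, so the producer
    stalls instead of accumulating batches in RAM."""
    q: "queue.Queue[object]" = queue.Queue(maxsize=max_pending)
    errors: List[BaseException] = []

    def producer() -> None:
        try:
            for item in source:
                q.put(item)  # blocks while max_pending items are waiting
        except Exception as exc:  # forward errors to the consumer
            errors.append(exc)
        finally:
            q.put(_DONE)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = q.get()
        if item is _DONE:
            if errors:
                raise errors[0]
            return
        yield item
```

Usage would look like `for x, y in bounded_prefetch(dataset, max_pending=4): ...`. At most max_pending + 1 items are held by the prefetcher (the queue plus the item the producer is blocked on), in addition to whatever the consumer is currently working on.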

Example of my pipeline (sample tar shard: https://drive.google.com/file/d/18AdaxeWQO_dkA3O1po-9_kwBPb57scTz/view?usp=sharing):

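import io
import json

import numpy as np
import torch
import webdataset as wds
from PIL import Image

batch_size = 256
train_paths = ["gs://bla/bla/bla1", "gs://bla/bla/bla2"]
val_paths = ["gs://bla/bla/bla1", "gs://bla/bla/bla2"]
workers = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")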

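def decode_image(element):
    # Decode raw JPEG bytes into an H x W x 3 uint8 array; converting to RGB
    # keeps grayscale/RGBA images stackable with the rest of the batch.
    img = Image.open(io.BytesIO(element)).convert("RGB")
    return np.array(img)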


def parse(sample):
    # Each sample is a dict keyed by file extension; return (jpeg bytes, label).
    data = json.loads(sample['json'])
    return sample['jpg'], data['subfamilia_idx']


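def norm_imgs(imgs):
    # imgs is float32 NCHW with values in [0, 255]; the ImageNet mean/std
    # below are defined for pixel values in [0, 1], so rescale first.
    imgs /= 255.0
    imgs[:, 0] = (imgs[:, 0] - 0.485) / 0.229
    imgs[:, 1] = (imgs[:, 1] - 0.456) / 0.224
    imgs[:, 2] = (imgs[:, 2] - 0.406) / 0.225
    return imgs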


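# build_aug() is defined elsewhere (not shown in this thread); the call below
# assumes an imgaug-style augmenter invoked as augmentation(images=...).
augmentation = build_aug()


def decode_batch(batch):
    # Decode, augment and normalize a whole batch at once.
    x, y = batch
    x = np.stack([decode_image(img) for img in x], 0)  # N x H x W x C, uint8
    x = augmentation(images=x)
    x = np.rollaxis(x, 3, 1).astype('float32')         # -> N x C x H x W
    x = norm_imgs(x)
    x = torch.from_numpy(x)
    return x, torch.as_tensor(np.asarray(y, dtype=np.int64))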


def decode_batch_val(batch):
    # Same as decode_batch, but without augmentation (validation data).
    x, y = batch
    x = np.stack([decode_image(img) for img in x], 0)
    x = np.rollaxis(x, 3, 1).astype('float32')
    x = norm_imgs(x)
    x = torch.from_numpy(x)
    return x, torch.as_tensor(np.asarray(y, dtype=np.int64))

# Decoding and augmenting whole batches (batched, then map) is much faster than processing samples one by one.
dataset = wds.Dataset(train_paths).map(parse).shuffle(1000).batched(batch_size).map(decode_batch)
dataset_val = wds.Dataset(val_paths).map(parse).batched(batch_size).map(decode_batch_val)

dataloader = wds.MultiDataset(dataset, workers=workers)
dataloader_val = wds.MultiDataset(dataset_val, workers=workers)

for epoch in range(1000):
    for i, (input_var, output_var) in enumerate(dataloader):
        x = input_var.to(device, non_blocking=True)
        y = output_var.to(device, non_blocking=True)
        # do training

    with torch.no_grad():
        for i, (input_var, output_var) in enumerate(dataloader_val):
            x = input_var.to(device, non_blocking=True)
            y = output_var.to(device, non_blocking=True)
            # do eval
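The per-channel normalization in norm_imgs computes (x - mean_c) / std_c for each channel c. A minimal sketch of the same step using NumPy broadcasting, which normalizes a whole uint8 NHWC batch in one expression. normalize_batch is a hypothetical name, and the division by 255 assumes the ImageNet statistics are meant for pixel values in [0, 1].

```python
import numpy as np

# ImageNet per-channel statistics (same values as in norm_imgs), shaped to
# broadcast over an N x C x H x W batch.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1)


def normalize_batch(batch_uint8_nhwc: np.ndarray) -> np.ndarray:
    """uint8 NHWC batch -> normalized, C-contiguous float32 NCHW batch."""
    x = batch_uint8_nhwc.astype(np.float32) / 255.0  # scale to [0, 1]
    x = np.transpose(x, (0, 3, 1, 2))                # NHWC -> NCHW (a view)
    return np.ascontiguousarray((x - MEAN) / STD)    # broadcast over N, H, W
```

The result can be passed directly to torch.from_numpy. Whether this is measurably faster than the three per-channel assignments depends on batch size and hardware; its main benefit is that the statistics live in one place.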


tmbdev commented on June 15, 2024

Ah, thanks, that's helpful.

Keep in mind that you do not need to use MultiDataset with WebDataset; it's an optional class that is simpler internally than DataLoader and gives you more options for shuffling and batching (e.g., you can unbatch, shuffle, and rebatch within it). In most cases, the regular PyTorch DataLoader works just fine.
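To make "unbatch, shuffle, and rebatch" concrete, here is a minimal plain-Python sketch of the three stages over an iterator of batches. These generator functions are illustrative only and are not the actual WebDataset or MultiDataset API; the point is that samples from different incoming batches (e.g., produced by different workers) get mixed before new batches are formed.

```python
import random
from typing import Iterable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def unbatch(batches: Iterable[Sequence[T]]) -> Iterator[T]:
    """Flatten a stream of batches into a stream of samples."""
    for batch in batches:
        yield from batch


def shuffle(samples: Iterable[T], bufsize: int, rng: random.Random) -> Iterator[T]:
    """Buffered shuffle: keep up to `bufsize` samples, emit a random one each step."""
    buf: List[T] = []
    for s in samples:
        buf.append(s)
        if len(buf) >= bufsize:
            i = rng.randrange(len(buf))
            buf[i], buf[-1] = buf[-1], buf[i]  # move the chosen sample to the end
            yield buf.pop()
    rng.shuffle(buf)  # drain what is left at the end of the stream
    yield from buf


def rebatch(samples: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Group samples into lists of `batch_size`; the last batch may be shorter."""
    batch: List[T] = []
    for s in samples:
        batch.append(s)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch
```

For example, `rebatch(shuffle(unbatch(worker_batches), 1000, random.Random(0)), 256)` would turn small per-worker batches into shuffled batches of 256 (worker_batches is a placeholder name). Memory held by the shuffle stage grows with bufsize, which ties back to the shuffle-buffer point earlier in the thread.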

Please also have a look at github.com/nvlabs/tensorcom

