
size-fit-net's Introduction

A Deep Learning System for Predicting Size and Fit in Fashion E-Commerce

An (unofficial) PyTorch implementation of the SizeFitNet (SFNet) architecture proposed in the paper A Deep Learning System for Predicting Size and Fit in Fashion E-Commerce by Sheikh et al. (RecSys'19).

Dataset

The original paper demonstrates experiments on two datasets:

  1. ModCloth
  2. RentTheRunWay

Both datasets are curated from fashion e-commerce websites and provide transaction records containing customer attributes, article attributes and fit (the target variable). fit is a categorical variable with small, fit and large as the possible labels.
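As a minimal sketch of how the three fit labels could be turned into class ids for a classifier: the integer order below and the helper name `encode_fit` are assumptions for illustration, not taken from this repository.

```python
# Hypothetical label order; the repository may use a different mapping.
FIT_LABELS = ["small", "fit", "large"]
LABEL_TO_ID = {label: idx for idx, label in enumerate(FIT_LABELS)}


def encode_fit(labels):
    """Map fit strings to integer class ids, failing loudly on unknown labels."""
    unknown = sorted(set(labels) - set(LABEL_TO_ID))
    if unknown:
        raise ValueError(f"unexpected fit labels: {unknown}")
    return [LABEL_TO_ID[label] for label in labels]


print(encode_fit(["fit", "small", "large"]))
```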

Model Architecture

The model consists of two pathways: one captures user embeddings and user features, the other captures item embeddings and item features. See the following figure taken from the paper:



The representations within each pathway are transformed using residual/skip connections. The authors compare this against an MLP baseline (without skip connections) and report better performance.
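To make the idea of a skip connection concrete, here is a minimal PyTorch sketch of a residual block. The class name `SkipBlock`, the two-layer inner network, the ReLU activations and the layer sizes are all assumptions for illustration. The block used in this repository and in the paper may be structured differently.

```python
import torch
from torch import nn


class SkipBlock(nn.Module):
    """Illustrative residual block: output = ReLU(x + f(x))."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        # f(x): a small feed-forward network that maps back to the input size,
        # so that its output can be added to the input.
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
        )
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The skip connection adds the untransformed input to f(x).
        return self.activation(x + self.net(x))


block = SkipBlock(dim=8, hidden_dim=16)
x = torch.randn(4, 8)
print(block(x).shape)
```

Removing the `x +` term in `forward` turns the block into a plain MLP layer, which corresponds to the kind of baseline the authors compare against.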

Unlike the paper, I combine the user representation (u) and the item representation (v) into a new tensor as follows:

   [u, v, |u-v|, u*v]
  • concatenation of the two representations
  • element-wise product u * v
  • absolute element-wise difference |u-v|

Based on: https://arxiv.org/pdf/1705.02364.pdf

This new representation is fed to the top layer skip connection block.
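The combination [u, v, |u-v|, u*v] described above can be sketched in PyTorch as follows. The function name `combine` is hypothetical, and the sketch assumes u and v have the same shape with features on the last dimension.

```python
import torch


def combine(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Concatenate u, v, |u - v| and u * v along the feature dimension."""
    return torch.cat([u, v, torch.abs(u - v), u * v], dim=-1)


u = torch.tensor([[1.0, 2.0]])
v = torch.tensor([[3.0, -1.0]])
print(combine(u, v))
```

The output has four times the feature size of each input, so the input dimension of the layer that consumes it has to be sized accordingly.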

Instructions

  1. Start by installing the necessary packages (preferably within a Python virtual environment):
   pip install -r requirements.txt
  2. Download the data from here and place it in the data/ directory.

  3. The original data is quite messy, with many missing values (NaNs). As such, I dropped some attributes and imputed others. Refer to the data_exploration.ipynb notebook to understand the data processing steps. The notebook also provides instructions on creating the train/validation/test splits.

  4. Set the data, model and training configurations by modifying the jsonnet files under the configs/ directory.

  5. Train the SFNet model:

   python train.py

The model checkpoints and run configuration will be saved under runs/<experiment_name>. Training also generates TensorBoard plots of the training loss and validation metrics, which are useful for monitoring training progress.

  6. Test the model:
   python test.py --saved_model_path runs/<experiment_name>

Primary Results

The model was quite sensitive to hyperparameters. For certain sets of hyperparameters, the optimization process would diverge and the validation loss would shoot up. The configuration provided under configs/model.jsonnet was what worked best for me.

Learning Curves

Below are some tensorboard graphs for validation metrics.



Performance on Test Set

Precision  Recall  F1-score  Accuracy  AUC
0.682      0.397   0.378     0.691     0.728

Note: The metrics reported here are macro-average values.
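For reference, here is a minimal pure-Python sketch of macro-averaged precision, recall and F1, assuming the usual definition (compute each metric per class, then take the unweighted mean over classes). The function name `macro_scores` and the toy labels are illustrative only. The repository's evaluation code may compute these differently, for example through a library.

```python
def macro_scores(y_true, y_pred, num_classes):
    """Return (precision, recall, f1), each averaged over classes with equal weight."""
    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        pairs = list(zip(y_true, y_pred))
        tp = sum(1 for t, p in pairs if t == c and p == c)
        fp = sum(1 for t, p in pairs if t != c and p == c)
        fn = sum(1 for t, p in pairs if t == c and p != c)
        # Classes with no predictions (or no true examples) score 0 here.
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    return (
        sum(precisions) / num_classes,
        sum(recalls) / num_classes,
        sum(f1s) / num_classes,
    )


# Toy example with three classes (0, 1, 2); not data from this project.
print(macro_scores([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0], num_classes=3))
```

Under this definition, macro F1 is the mean of the per-class F1 scores, so it does not have to equal the harmonic mean of the macro precision and macro recall.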

TODO

Some future work for this repository, along with ideas that could plausibly improve the results:

  • Experiments on RentTheRunWay Dataset
  • L2-regularization on the embeddings
  • Batch Normalization on the feed forward layers
  • Early Stopping
  • Learning Rate Decay
  • Weighted Loss Function to account for the class-imbalance
  • Contrastive Learning (Siamese Networks?) to encourage learning different sub-spaces for positive and negative size-fits
  • Topic Modelling Approaches
  • Modelling Users/Items as Distributions
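As one possible starting point for the weighted loss item above, here is a sketch of inverse-frequency class weights. The weighting scheme, the function name `inverse_frequency_weights` and the toy label counts are assumptions. Nothing here is implemented in this repository.

```python
from collections import Counter


def inverse_frequency_weights(labels, num_classes):
    """Weight each class by total / (num_classes * count), so rare classes weigh more."""
    counts = Counter(labels)
    total = len(labels)
    return [
        total / (num_classes * counts[c]) if counts[c] else 0.0
        for c in range(num_classes)
    ]


# Toy, imbalanced label list (class 1 dominates); not the real class distribution.
labels = [1] * 6 + [0] * 2 + [2] * 2
print(inverse_frequency_weights(labels, num_classes=3))
```

In PyTorch, weights like these could be converted to a tensor and passed as the `weight` argument of `torch.nn.CrossEntropyLoss`.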

Acknowledgements

Thanks to Rishabh Misra for making the datasets used here publicly available on Kaggle. Some ideas for pre-processing the data were borrowed from NeverInAsh.

size-fit-net's People

Contributors

hareeshbahuleyan
