
mito_trace's Introduction

Mito_Trace

Tracking cell lineage using mutations in the mitochondrial genome.

Important Steps

  1. Run cellranger on the 10x FASTQ files
  2. Preprocess: Convert the BAM file into single-cell pileup data
  3. Filtering: Coverage- and quality-based filtering
  4. Variant calling: Filter MT variants using Vireo or mgatk
  5. Merging conditions
  6. Demultiplexing: Assign a donor to each cell (None if unassigned)
  7. Clonal detection: Assign each cell to a clone (None if unassigned)
  8. Clonal enrichment across conditions: Run enrichment analysis across conditions of the same sample/donor

Ways the steps can be run:

The pipeline steps can be run in different orders so that quality and parameters can be properly assessed. Steps [1->2] always run first and [7->8] always run last. In between, the steps run in one of three orders: [3->5->4->6], [3->4->5->6], or [5->6->3->4]; a sketch encoding these orderings follows below.
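
As a rough, hypothetical illustration (not part of the pipeline code), the fixed prefix and suffix and the three allowed middle orderings can be encoded and checked in Python; the step numbers refer to the list above.

# Hypothetical sketch: encode the allowed orderings and validate a chosen plan
ALLOWED_MIDDLE_ORDERS = [
    [3, 5, 4, 6],  # filter -> merge -> variant call -> demultiplex
    [3, 4, 5, 6],  # filter -> variant call -> merge -> demultiplex
    [5, 6, 3, 4],  # merge -> demultiplex -> filter -> variant call
]

def build_step_plan(middle_order):
    """Return the full step ordering, or raise if the middle order is not allowed."""
    if list(middle_order) not in ALLOWED_MIDDLE_ORDERS:
        raise ValueError(f"Unsupported middle-step order: {middle_order}")
    return [1, 2] + list(middle_order) + [7, 8]

print(build_step_plan([3, 5, 4, 6]))  # -> [1, 2, 3, 5, 4, 6, 7, 8]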

mito_trace's People

Contributors

isshamie

Watchers

Nathan Lewis

mito_trace's Issues

Sweep across het, growth rate, coverage, error, cluster size,


# TODO Sweep across het, growth rate, coverage, error, cluster size,

# Assumed imports for this snippet; read_config_file, FullSimulation and
# Simulation are project helpers referenced by the original code.
import glob
import os
import pickle

import pandas as pd


class ParameterSweep:
    # TODO Sweep across het, growth rate, coverage, error, cluster size
    def __init__(self, params_dir):
        self.params_dir = params_dir
        self.files = glob.glob(os.path.join(params_dir, '*.yaml'))
        # Create a dictionary of parsed parameters for each config file
        sweep_params = dict()
        for f in self.files:
            sweep_params[f] = read_config_file(f)
        self.sweep_params = sweep_params

    def run_sweep(self):
        # Run a full simulation for each config file and keep the results
        sweep_results = dict()
        for f in self.sweep_params:
            print(f"Running with file: {f}")
            # FullSimulation expects the config file path, not the parsed dict
            sim = FullSimulation(f)
            sim.run()
            sweep_results[f] = sim
        self.sweep_results = sweep_results
        return sweep_results

    def plot_sensitivity(self, vars=None):
        return

    def plot_ppv(self):
        return

    def cluster_before_after(self):
        return

    def save(self, f_save=None):
        if f_save is None:
            # The original referenced self.params, which ParameterSweep never
            # sets; the first sweep config is used here as a stand-in.
            params = next(iter(self.sweep_params.values()))
            f_save = os.path.join(params['local_outdir'], params['prefix'] + '.p')
        with open(f_save, 'wb') as f:
            pickle.dump(self.__dict__, f, 2)

    def load(self, filename):
        with open(filename, 'rb') as f:
            tmp_dict = pickle.load(f)
        self.__dict__.update(tmp_dict)


# Could each variable be made its own class? That should not interfere with
# running the MCMC, since that format is applied afterwards anyway.
class FullSimulation:
    def __init__(self, params_f):
        # TODO parallel_apply over simulations
        params = read_config_file(params_f)
        self.n_iter = params['num_iterations']
        self.params = params

    def run(self):
        # Parallelize over iterations; parallel_apply requires
        # pandarallel.initialize() to have been called beforehand
        df = pd.Series(index=range(self.n_iter))
        df = df.parallel_apply(self.run_sim, args=(self.params,))
        self.sim = df
        return

    @staticmethod
    def run_sim(x, params):
        s = Simulation(params)
        s.initialize()
        s.grow()
        s.subsample_new(to_delete=True)
        return s

    def flatten_sim(self):
        # Body truncated in the source
        pass
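
For reference, a minimal usage sketch of the classes above might look like the following. The directory path and output filename are placeholders, and initializing pandarallel is an assumption needed for FullSimulation.run().

# Hypothetical usage; assumes a directory of sweep *.yaml configs and that
# pandarallel is installed for the parallel_apply call in FullSimulation.run()
from pandarallel import pandarallel

pandarallel.initialize()

sweep = ParameterSweep("parameter_sweeps/")   # placeholder directory of configs
results = sweep.run_sweep()                   # dict of config file -> FullSimulation
sweep.save("sweep_results.p")                 # pickle the sweep state for later analysis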

Simulation Run

This epic covers creating the simulated data for a variety of feature settings. It needs to account for storage space and runtime, so potentially only summary metrics will be saved; a minimal sketch of that follows.
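
As a minimal sketch of the metrics-only approach: each run could be reduced to a small record before saving. The attribute names n_cells and n_clones are illustrative assumptions, not the project's actual API.

# Hypothetical sketch: keep lightweight per-iteration metrics rather than
# whole Simulation objects, so a large sweep stays manageable on disk.
import pandas as pd

def summarize_sweep(sweep_results, f_save="sweep_metrics.csv"):
    records = []
    for config_file, full_sim in sweep_results.items():
        # full_sim.sim is the Series of Simulation objects built in run()
        for i, sim in full_sim.sim.items():
            records.append({
                "config": config_file,
                "iteration": i,
                "n_cells": getattr(sim, "n_cells", None),    # assumed attribute
                "n_clones": getattr(sim, "n_clones", None),  # assumed attribute
            })
    metrics = pd.DataFrame(records)
    metrics.to_csv(f_save, index=False)
    return metrics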

This should be in mro/common for general use


# TODO: This should be in mro/common for general use

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype json;
filetype bam;
#
# @include "_sort_and_mark_dups_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype bam;
filetype bam.bai;
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype json;
filetype csv;
#
# @include "_peak_caller_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype bedgraph;
filetype pickle;
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype bed;
filetype json;
#
# @include "_basic_sc_atac_counter_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype csv;
filetype json;
filetype bed;
filetype pickle;
filetype h5;
#
# @include "_produce_cell_barcodes_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype csv;
filetype json;
filetype bed;
filetype pickle;
filetype h5;
filetype npy.gz;
#
# @include "_sc_atac_metric_collector_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype bed;
filetype bam;
filetype csv;
filetype json;
filetype h5;
filetype txt;
filetype pickle;
#
# @include "_peak_annotator_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype bed;
filetype tsv;
filetype h5;
filetype gz;
filetype pickle;
#
# @include "_sc_atac_analyzer_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv;
filetype h5;
filetype pickle;
filetype gz;
filetype bed;
filetype csv;
#
# @include "_sc_atac_reporter_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype json;
filetype html;
filetype csv;
filetype h5;
filetype bam;
#
# @include "_atac_cloupe_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype cloupe;
filetype csv;
filetype json;
filetype h5;
filetype bed;
filetype tsv.gz.tbi;
#
# @include "_preflight_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype csv;
filetype bed;
filetype tsv.gz;
filetype tsv.gz.tbi;

#
# @include "_aligner_stages.mro"
#

# SETUP_CHUNKS chunks up the input fastq data into sets of matched R1, R2, SI, and BC fastq files.
# input_mode specifies how FASTQs were generated. There are two modes:
#
# 1. "BCL_PROCESSOR"
#
# FASTQs produced by the 10X BCL_PROCESSOR pipeline. This mode assumes the FASTQ files obey the internal
# naming conventions and the reads have been interleaved into RA FASTQ files.
#
# 2. "ILMN_BCL2FASTQ"
#
# FASTQs produced directly by Illumina BCL2FASTQ v1.8.4. For this mode, BCL2FASTQ must be configured to emit the
# index2 read, rather than using it for dual-index demultiplexing:
#
# configureBclToFastq.pl --no-eamss --use-bases-mask=Y100,I8,Y14,Y100 --input-dir=<basecalls_dir> \
#     --output-dir=<output_dir> --sample-sheet=<sample_sheet.csv>
#
# The sample sheet must be formatted as per the BCL2FASTQ documentation (10 column csv), and must contain a row for
# each sample index used. The sequencer must have been run in dual index mode, with the second index read (used to
# read the 10X barcode) emitted as the R2 output file. The --use-bases-mask argument should be set to the read
# length used.
stage SETUP_CHUNKS(
    in  string   sample_id        "id of the sample",
    in  map[]    sample_def       "list of dictionary specifying input data",
    in  string   input_mode       "configuration of the input fastqs",
    in  map      downsample       "map specifies either subsample_rate (float) or gigabases (int)",
    out map[]    chunks           "map has barcode, barcode_reverse_complement, sample_index, read1, read2, gem_group, and read_group fields",
    out string[] read_groups      "list of strings representing read groups",
    out json     downsample_info  "info about downsampling result",
    src py       "stages/processing/setup_chunks",
)

# Trims adapter sequences from reads and massages fastq output into a fixed format (interleaved R1 file, etc.)
stage TRIM_READS(
    in  map[]  chunks,
    in  string barcode_whitelist,
    in  int    max_read_num,
    in  map    trim_def,
    in  map    adapters,
    out map[]  chunks,
    out json   bc_counts,
    out json   lot_info,
    out json   read_counts,
    src py     "stages/processing/trim_reads",
) split (
    in  map    chunk,
) using (
    volatile = strict,
)

# Aligns the reads to the input reference, producing chunked bam files
stage ALIGN_READS(
    in  map[]  chunks,
    in  string aligner,
    in  string aligner_method,
    in  string reference_path,
    in  string read_group_sample,
    in  int    num_threads,
    out bam[],
    src py     "stages/processing/align_reads",
) split (
    in  map    chunk,
) using (
    # N.B. No index files are generated for the bam
    volatile = strict,
)

#
# @include "_aligner.mro"
#

# Takes input fastqs and chunks them, trims them, and aligns the trimmed reads to a reference
pipeline _ALIGNER(
    in  string sample_id,
    in  string fastq_mode         "configuration of the input fastqs",
    in  map[]  sample_def,
    in  string reference_path     "this is the reference_path",
    in  string barcode_whitelist  "name of barcode whitelist file",
    in  map    trim_def,
    in  map    adapters,
    in  string read_group_sample  "sample header for BAM file",
    in  map    downsample,
    out bam[]  align,
    out map[]  chunks,
    out json   bc_counts,
    out json   lot_info           "gelbead lot detected",
    out json   read_counts        "total # of read pairs before and after adapter trimming",
    out json   downsample_info    "info on downsampling",
)
{
    call SETUP_CHUNKS(
        sample_id  = self.sample_id,
        input_mode = self.fastq_mode,
        sample_def = self.sample_def,
        downsample = self.downsample,
    ) using (
        volatile = true,
    )

    call TRIM_READS(
        chunks            = SETUP_CHUNKS.chunks,
        max_read_num      = 5000000,
        trim_def          = self.trim_def,
        adapters          = self.adapters,
        barcode_whitelist = self.barcode_whitelist,
    ) using (
        volatile = true,
    )

    call ALIGN_READS(
        chunks            = TRIM_READS.chunks,
        aligner           = "bwa",
        aligner_method    = "MEM",
        reference_path    = self.reference_path,
        read_group_sample = self.read_group_sample,
        num_threads       = 4,
    ) using (
        volatile = true,
    )

    return (
        align           = ALIGN_READS,
        chunks          = TRIM_READS.chunks,
        bc_counts       = TRIM_READS.bc_counts,
        lot_info        = TRIM_READS.lot_info,
        read_counts     = TRIM_READS.read_counts,
        downsample_info = SETUP_CHUNKS.downsample_info,
    )
}

#
# @include "_sort_and_mark_dups_stages.mro"
#

# Attaches raw and corrected barcode sequences to the aligned reads
stage ATTACH_BCS(
    in  string barcode_whitelist,
    in  bam[]  align,
    in  map[]  chunks,
    in  bool   paired_end,
    in  bool   exclude_non_bc_reads,
    in  float  bc_confidence_threshold,
    in  json   bc_counts,
    out bam[]  output,
    out int    perfect_read_count,
    src py     "stages/processing/attach_bcs",
) split (
    in  bam    align_chunk,
    in  map    chunk,
) using (
    # N.B. No index files are generated for the bam
    volatile = strict,
)

stage SORT_READS_BY_POS(
    in  bam[] input,
    out bam   tagsorted_bam,
    src py    "stages/processing/sort_reads_by_pos",
) split (
    in  bam   chunk_input,
) using (
    # N.B. No index files are generated for the bam
    volatile = strict,
)

# Marks duplicates in the reads using barcodes and fragment alignments to detect PCR and optical/diffusion duplicates
stage MARK_DUPLICATES(
    in  bam        input,
    in  string     reference_path,
    in  json       raw_barcode_counts,
    in  string     barcode_whitelist,
    out bam        output,
    out bam.bai    index,
    out csv        singlecell_mapping,
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    src py         "stages/processing/mark_duplicates",
) split (
    in  map        lane_map,
    in  string     chunk_start,
    in  string     chunk_end,
    in  int        chunk_num,
) using (
    # N.B. BAM/BED index files are explicitly bound where used
    volatile = strict,
)

#
# @include "_sort_and_mark_dups.mro"
#

# Attaches barcodes to the aligned reads, marks duplicate reads, and produces a barcode-sorted and position-sorted
# output BAM
pipeline _SORT_AND_MARK_DUPS(
    in  bam[]      align,
    in  map[]      chunks,
    in  string     barcode_whitelist,
    in  json       bc_counts,
    in  string     reference_path,
    out bam        possorted_bam        "bam file sorted by position",
    out bam.bai    possorted_bam_index  "position-sorted bam index",
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    out csv        singlecell_mapping,
    out bam[]      read_paired_bam,
)
{
    call ATTACH_BCS(
        align                   = self.align,
        chunks                  = self.chunks,
        paired_end              = true,
        barcode_whitelist       = self.barcode_whitelist,
        exclude_non_bc_reads    = false,
        bc_confidence_threshold = 0.975,
        bc_counts               = self.bc_counts,
    ) using (
        volatile = true,
    )

    call SORT_READS_BY_POS(
        input = ATTACH_BCS.output,
    ) using (
        volatile = true,
    )

    call MARK_DUPLICATES(
        input              = SORT_READS_BY_POS.tagsorted_bam,
        reference_path     = self.reference_path,
        barcode_whitelist  = self.barcode_whitelist,
        raw_barcode_counts = self.bc_counts,
    ) using (
        volatile = true,
    )

    return (
        possorted_bam       = MARK_DUPLICATES.output,
        possorted_bam_index = MARK_DUPLICATES.index,
        singlecell_mapping  = MARK_DUPLICATES.singlecell_mapping,
        fragments           = MARK_DUPLICATES.fragments,
        fragments_index     = MARK_DUPLICATES.fragments_index,
        read_paired_bam     = ATTACH_BCS.output,
    )
}

#
# @include "_peak_caller_stages.mro"
#

stage COUNT_CUT_SITES(
    in  path       reference_path,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    out bedgraph   cut_sites,
    out pickle     count_dict,
    src py         "stages/processing/count_cut_sites",
) split (
    in  string     contig,
) using (
    # N.B. We explicitly bind the index file
    volatile = strict,
)

stage DETECT_PEAKS(
    in  bedgraph cut_sites,
    in  path     reference_path,
    in  pickle   count_dict,
    out bed      peaks,
    out json     peak_metrics,
    src py       "stages/processing/detect_peaks",
) split (
    in  string   contig,
    in  float[]  params,
    in  float    threshold,
) using (
    mem_gb   = 6,
    # N.B. We explicitly bind the index file
    volatile = strict,
)

#
# @include "_peak_caller.mro"
#

pipeline _PEAK_CALLER(
    in  path       reference_path,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    out bedgraph   cut_sites,
    out bed        peaks,
    out json       peak_metrics,
)
{
    call COUNT_CUT_SITES(
        reference_path  = self.reference_path,
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
    )

    call DETECT_PEAKS(
        reference_path = self.reference_path,
        cut_sites      = COUNT_CUT_SITES.cut_sites,
        count_dict     = COUNT_CUT_SITES.count_dict,
    )

    return (
        cut_sites    = COUNT_CUT_SITES.cut_sites,
        peaks        = DETECT_PEAKS.peaks,
        peak_metrics = DETECT_PEAKS.peak_metrics,
    )
}

#
# @include "_basic_sc_atac_counter_stages.mro"
#

stage GENERATE_PEAK_MATRIX(
    in  string reference_path,
    in  tsv.gz fragments,
    in  bed    peaks,
    out h5     raw_matrix,
    out path   raw_matrix_mex,
    src py     "stages/processing/generate_peak_matrix",
) split (
    in  file   barcodes,
) using (
    mem_gb   = 4,
    # N.B. we don't explicitly need the fragment index
    volatile = strict,
)

stage FILTER_PEAK_MATRIX(
    in  h5   raw_matrix,
    in  int  num_analysis_bcs,
    in  int  random_seed,
    in  csv  cell_barcodes,
    out h5   filtered_matrix,
    out path filtered_matrix_mex,
    src py   "stages/processing/filter_peak_matrix",
) split (
) using (
    volatile = strict,
)

#
# @include "_produce_cell_barcodes_stages.mro"
#

stage REMOVE_LOW_TARGETING_BARCODES(
    in  bed        peaks,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    out json       barcode_counts,
    out json       low_targeting_barcodes,
    out json       low_targeting_summary,
    out json       fragment_lengths,
    out json       covered_bases,
    src py         "stages/processing/cell_calling/remove_low_targeting_barcodes",
) split (
    in  string     contig,
    out pickle     fragment_counts,
    out pickle     targeted_counts,
    out int        peak_coverage,
) using (
    mem_gb   = 4,
    volatile = strict,
)

stage REMOVE_GEL_BEAD_DOUBLET_BARCODES(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    in  json       barcode_counts,
    out json       gel_bead_doublet_barcodes,
    out json       gel_bead_doublet_summary,
    out csv        connect_matrix,
    src py         "stages/processing/cell_calling/remove_gel_bead_doublet_barcodes",
) split (
    in  string     contig,
    in  file       valid_barcodes,
) using (
    mem_gb   = 4,
    volatile = strict,
)

stage REMOVE_BARCODE_MULTIPLETS(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    in  string     barcode_whitelist,
    in  json       barcode_counts,
    out json       barcode_multiplets,
    out json       barcode_multiplets_summary,
    src py         "stages/processing/cell_calling/remove_barcode_multiplets",
) split (
    in  string     contig,
    in  string     gem_group,
    out npy.gz     part_a_linkage_matrix,
    out npy.gz     part_b_linkage_matrix,
) using (
    mem_gb   = 4,
    volatile = strict,
)

stage MERGE_EXCLUDED_BARCODES(
    in  json[] barcode_exclusions,
    out json   excluded_barcodes,
    src py     "stages/processing/cell_calling/merge_excluded_barcodes",
)

stage DETECT_CELL_BARCODES(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     barcode_whitelist,
    in  json       excluded_barcodes,
    in  map        force_cells,
    in  string     reference_path,
    in  bed        peaks,
    out csv        cell_barcodes,
    out csv        singlecell,
    out json       cell_calling_summary,
    src py         "stages/processing/cell_calling/detect_cell_barcodes",
) split (
    in  string     contig,
    out pickle     barcode_counts,
    out pickle     targeted_counts,
    out int        fragment_depth,
) using (
    mem_gb   = 4,
    volatile = strict,
)

# TODO: This should be in mro/common for general use
stage MERGE_SUMMARY_METRICS(
    in  json[] summary_jsons,
    out json   merged_summary,
    src py     "stages/processing/cell_calling/merge_summary_metrics",
)

#
# @include "_produce_cell_barcodes.mro"
#

pipeline _PRODUCE_CELL_BARCODES(
    in  bed        peaks,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    in  string     barcode_whitelist,
    in  map        force_cells,
    out csv        cell_barcodes,
    out csv        singlecell,
    out json       cell_calling_summary,
    out json       excluded_barcodes,
    out json       fragment_lengths,
    out json       covered_bases,
)
{
    call REMOVE_LOW_TARGETING_BARCODES(
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
        peaks           = self.peaks,
        reference_path  = self.reference_path,
    )

    call REMOVE_GEL_BEAD_DOUBLET_BARCODES(
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
        reference_path  = self.reference_path,
        barcode_counts  = REMOVE_LOW_TARGETING_BARCODES.barcode_counts,
    )

    call REMOVE_BARCODE_MULTIPLETS(
        fragments         = self.fragments,
        fragments_index   = self.fragments_index,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        barcode_counts    = REMOVE_LOW_TARGETING_BARCODES.barcode_counts,
    )

    call MERGE_EXCLUDED_BARCODES(
        barcode_exclusions = [
            REMOVE_BARCODE_MULTIPLETS.barcode_multiplets,
            REMOVE_GEL_BEAD_DOUBLET_BARCODES.gel_bead_doublet_barcodes,
            REMOVE_LOW_TARGETING_BARCODES.low_targeting_barcodes,
        ],
    )

    call DETECT_CELL_BARCODES(
        fragments         = self.fragments,
        fragments_index   = self.fragments_index,
        barcode_whitelist = self.barcode_whitelist,
        force_cells       = self.force_cells,
        excluded_barcodes = MERGE_EXCLUDED_BARCODES.excluded_barcodes,
        reference_path    = self.reference_path,
        peaks             = self.peaks,
    )

    call MERGE_SUMMARY_METRICS as MERGE_CELL_METRICS(
        summary_jsons = [
            REMOVE_LOW_TARGETING_BARCODES.low_targeting_summary,
            REMOVE_GEL_BEAD_DOUBLET_BARCODES.gel_bead_doublet_summary,
            REMOVE_BARCODE_MULTIPLETS.barcode_multiplets_summary,
            DETECT_CELL_BARCODES.cell_calling_summary,
        ],
    )

    return (
        cell_barcodes        = DETECT_CELL_BARCODES.cell_barcodes,
        excluded_barcodes    = MERGE_EXCLUDED_BARCODES.excluded_barcodes,
        singlecell           = DETECT_CELL_BARCODES.singlecell,
        cell_calling_summary = MERGE_CELL_METRICS.merged_summary,
        fragment_lengths     = REMOVE_LOW_TARGETING_BARCODES.fragment_lengths,
        covered_bases        = REMOVE_LOW_TARGETING_BARCODES.covered_bases,
    )
}

#
# @include "_basic_sc_atac_counter.mro"
#

pipeline _BASIC_SC_ATAC_COUNTER(
    in  string     sample_id,
    in  string     fastq_mode                   "configuration of the input fastqs",
    in  map[]      sample_def,
    in  string     reference_path               "this is the reference_path",
    in  string     barcode_whitelist            "name of barcode whitelist file",
    in  map        trim_def,
    in  map        adapters,
    in  map        downsample,
    in  map        force_cells,
    out bam        possorted_bam                "bam file sorted by position",
    out bam.bai    possorted_bam_index          "position-sorted bam index",
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    out json       lot_info                     "gelbead lot detected",
    out json       read_counts                  "total # of read pairs before and after adapter trimming",
    out json       downsample_info              "info on downsampling",
    out csv        cell_barcodes,
    out json       excluded_barcodes,
    out json       cell_calling_summary,
    out bed        peaks,
    out bedgraph   cut_sites,
    out csv        singlecell_mapping,
    out csv        singlecell_cells,
    out json       peak_metrics,
    out bam[]      read_paired_bam,
    out h5         raw_peak_bc_matrix,
    out path       raw_peak_bc_matrix_mex,
    out h5         filtered_peak_bc_matrix,
    out path       filtered_peak_bc_matrix_mex,
)
{
    call _ALIGNER(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        read_group_sample = self.sample_id,
        trim_def          = self.trim_def,
        adapters          = self.adapters,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        downsample        = self.downsample,
    )

    call _SORT_AND_MARK_DUPS(
        align             = _ALIGNER.align,
        chunks            = _ALIGNER.chunks,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        bc_counts         = _ALIGNER.bc_counts,
    )

    call _PEAK_CALLER(
        fragments       = _SORT_AND_MARK_DUPS.fragments,
        fragments_index = _SORT_AND_MARK_DUPS.fragments_index,
        reference_path  = self.reference_path,
    )

    call _PRODUCE_CELL_BARCODES(
        fragments         = _SORT_AND_MARK_DUPS.fragments,
        fragments_index   = _SORT_AND_MARK_DUPS.fragments_index,
        peaks             = _PEAK_CALLER.peaks,
        force_cells       = self.force_cells,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
    )

    call GENERATE_PEAK_MATRIX(
        reference_path = self.reference_path,
        fragments      = _SORT_AND_MARK_DUPS.fragments,
        peaks          = _PEAK_CALLER.peaks,
    )

    call FILTER_PEAK_MATRIX(
        num_analysis_bcs = null,
        cell_barcodes    = _PRODUCE_CELL_BARCODES.cell_barcodes,
        raw_matrix       = GENERATE_PEAK_MATRIX.raw_matrix,
        random_seed      = null,
    )

    return (
        possorted_bam               = _SORT_AND_MARK_DUPS.possorted_bam,
        possorted_bam_index         = _SORT_AND_MARK_DUPS.possorted_bam_index,
        singlecell_mapping          = _SORT_AND_MARK_DUPS.singlecell_mapping,
        singlecell_cells            = _PRODUCE_CELL_BARCODES.singlecell,
        lot_info                    = _ALIGNER.lot_info,
        read_counts                 = _ALIGNER.read_counts,
        downsample_info             = _ALIGNER.downsample_info,
        cell_barcodes               = _PRODUCE_CELL_BARCODES.cell_barcodes,
        excluded_barcodes           = _PRODUCE_CELL_BARCODES.excluded_barcodes,
        cell_calling_summary        = _PRODUCE_CELL_BARCODES.cell_calling_summary,
        peak_metrics                = _PEAK_CALLER.peak_metrics,
        cut_sites                   = _PEAK_CALLER.cut_sites,
        peaks                       = _PEAK_CALLER.peaks,
        fragments                   = _SORT_AND_MARK_DUPS.fragments,
        fragments_index             = _SORT_AND_MARK_DUPS.fragments_index,
        read_paired_bam             = _SORT_AND_MARK_DUPS.read_paired_bam,
        raw_peak_bc_matrix          = GENERATE_PEAK_MATRIX.raw_matrix,
        raw_peak_bc_matrix_mex      = GENERATE_PEAK_MATRIX.raw_matrix_mex,
        filtered_peak_bc_matrix     = FILTER_PEAK_MATRIX.filtered_matrix,
        filtered_peak_bc_matrix_mex = FILTER_PEAK_MATRIX.filtered_matrix_mex,
    )
}

#
# @include "_sc_atac_metric_collector_stages.mro"
#

stage ESTIMATE_LIBRARY_COMPLEXITY(
    in  json   sequencing_summary,
    in  tsv.gz fragments,
    in  csv    cell_barcodes,
    out json   bulk_complexity,
    out json   complexity_summary,
    out json   singlecell_complexity,
    src py     "stages/metrics/estimate_library_complexity",
) split (
    in  file   barcodes,
) using (
    mem_gb   = 6,
    volatile = strict,
)

stage GENERATE_SEQUENCING_METRICS(
    in  bam[] input,
    out txt   misc_sm,
    out json  summary,
    src py    "stages/metrics/generate_sequencing_metrics",
) split (
    in  bam   chunk_bam,
) using (
    volatile = strict,
)

stage GENERATE_SINGLECELL_TARGETING(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  bed        peaks,
    in  string     reference_path,
    out csv        singlecell,
    out json       summary,
    out csv        tss_relpos,
    out csv        ctcf_relpos,
    src py         "stages/metrics/generate_singlecell_targeting",
) split (
    in  string     contig,
    out int        read_count,
    out pickle     target_counts_by_barcode,
    out pickle     chunk_tss,
    out pickle     chunk_ctcf,
) using (
    mem_gb   = 6,
    volatile = strict,
)

stage MERGE_SINGLECELL_METRICS(
    in  string reference_path,
    in  csv    singlecell_mapping,
    in  csv    singlecell_targets,
    in  csv    singlecell_cells,
    out csv    singlecell,
    out json   summary,
    src py     "stages/metrics/merge_singlecell_metrics",
) using (
    mem_gb   = 8,
    volatile = strict,
)

stage REPORT_INSERT_SIZES(
    in  tsv.gz fragments,
    in  bool   exclude_non_nuclear,
    in  string reference_path,
    out csv    insert_sizes,
    out json   insert_summary,
    src py     "stages/metrics/report_insert_sizes",
) split (
    in  file   barcode,
    out file   total,
) using (
    volatile = strict,
)

stage REPORT_TSS_CTCF(
    in  csv  tss_relpos,
    in  csv  ctcf_relpos,
    out json summary_metrics,
    src py   "stages/metrics/report_tss_ctcf",
) using (
    volatile = strict,
)

#
# @include "_sc_atac_metric_collector.mro"
#

pipeline _SC_ATAC_METRIC_COLLECTOR(
    in  bam[]      read_paired_bam,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  bed        peaks,
    in  string     reference_path         "this is the reference_path",
    in  csv        cell_barcodes,
    in  csv        singlecell_mapping,
    in  csv        singlecell_cells,
    out json       singlecell_results,
    out csv        singlecell,
    out json       enrichment_results,
    out json       basic_summary,
    out json       insert_summary,
    out csv        insert_sizes,
    out json       bulk_complexity,
    out json       singlecell_complexity,
    out json       complexity_summary,
    out csv        tss_relpos,
    out csv        ctcf_relpos,
)
{
    call GENERATE_SINGLECELL_TARGETING(
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
        peaks           = self.peaks,
        reference_path  = self.reference_path,
    )

    call MERGE_SINGLECELL_METRICS(
        reference_path     = self.reference_path,
        singlecell_mapping = self.singlecell_mapping,
        singlecell_cells   = self.singlecell_cells,
        singlecell_targets = GENERATE_SINGLECELL_TARGETING.singlecell,
    )

    call GENERATE_SEQUENCING_METRICS(
        input = self.read_paired_bam,
    )

    call ESTIMATE_LIBRARY_COMPLEXITY(
        sequencing_summary = GENERATE_SEQUENCING_METRICS.summary,
        fragments          = self.fragments,
        cell_barcodes      = self.cell_barcodes,
    )

    call REPORT_INSERT_SIZES(
        fragments           = self.fragments,
        reference_path      = self.reference_path,
        exclude_non_nuclear = true,
    )

    call REPORT_TSS_CTCF(
        tss_relpos  = GENERATE_SINGLECELL_TARGETING.tss_relpos,
        ctcf_relpos = GENERATE_SINGLECELL_TARGETING.ctcf_relpos,
    )

    return (
        ###
        singlecell            = MERGE_SINGLECELL_METRICS.singlecell,
        singlecell_results    = MERGE_SINGLECELL_METRICS.summary,
        ###
        enrichment_results    = REPORT_TSS_CTCF.summary_metrics,
        basic_summary         = GENERATE_SEQUENCING_METRICS.summary,
        insert_summary        = REPORT_INSERT_SIZES.insert_summary,
        insert_sizes          = REPORT_INSERT_SIZES.insert_sizes,
        bulk_complexity       = ESTIMATE_LIBRARY_COMPLEXITY.bulk_complexity,
        singlecell_complexity = ESTIMATE_LIBRARY_COMPLEXITY.singlecell_complexity,
        complexity_summary    = ESTIMATE_LIBRARY_COMPLEXITY.complexity_summary,
        tss_relpos            = GENERATE_SINGLECELL_TARGETING.tss_relpos,
        ctcf_relpos           = GENERATE_SINGLECELL_TARGETING.ctcf_relpos,
    )
}

#
# @include "_peak_annotator_stages.mro"
#

stage ANNOTATE_PEAKS(
    in  bed    peaks,
    in  string reference_path,
    out tsv    peak_annotation,
    src py     "stages/analysis/annotate_peaks",
) split (
    in  int    chunk_start,
    in  int    chunk_end,
) using (
    mem_gb   = 5,
    volatile = strict,
)

stage COMPUTE_GC_DISTRIBUTION(
    in  bed    peaks,
    in  string reference_path,
    out pickle GCdict,
    src py     "stages/analysis/compute_gc_dist",
) split (
) using (
    volatile = strict,
)

stage SCAN_MOTIFS(
    in  pickle globalGCdict,
    in  bed    peaks,
    in  string reference_path,
    in  float  pwm_threshold,
    out bed    peak_motif_hits,
    src py     "stages/analysis/scan_motifs",
) split (
    in  file   GCdict,
) using (
    volatile = strict,
)

stage GENERATE_TF_MATRIX(
    in  path reference_path,
    in  bed  peaks,
    in  bed  peak_motif_hits,
    in  h5   filtered_matrix,
    out h5   filtered_tf_bc_matrix,
    out path filtered_tf_bc_matrix_mex,
    out gz   tf_propZ_matrix,
    src py   "stages/analysis/generate_tf_matrix",
) split (
) using (
    volatile = strict,
)

#
# @include "_peak_annotator.mro"
#

pipeline _PEAK_ANNOTATOR(
    in  string reference_path,
    in  bed    peaks,
    in  h5     filtered_peak_bc_matrix,
    in  float  pwm_threshold,
    out h5     filtered_tf_bc_matrix,
    out path   filtered_tf_bc_matrix_mex,
    out gz     tf_propZ_matrix,
    out tsv    peak_annotation,
)
{
    call ANNOTATE_PEAKS(
        peaks          = self.peaks,
        reference_path = self.reference_path,
    )

    call COMPUTE_GC_DISTRIBUTION(
        peaks          = self.peaks,
        reference_path = self.reference_path,
    )

    call SCAN_MOTIFS(
        globalGCdict   = COMPUTE_GC_DISTRIBUTION.GCdict,
        peaks          = self.peaks,
        reference_path = self.reference_path,
        pwm_threshold  = self.pwm_threshold,
    )

    call GENERATE_TF_MATRIX(
        reference_path  = self.reference_path,
        peaks           = self.peaks,
        filtered_matrix = self.filtered_peak_bc_matrix,
        peak_motif_hits = SCAN_MOTIFS.peak_motif_hits,
    )

    return (
        filtered_tf_bc_matrix     = GENERATE_TF_MATRIX.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex = GENERATE_TF_MATRIX.filtered_tf_bc_matrix_mex,
        tf_propZ_matrix           = GENERATE_TF_MATRIX.tf_propZ_matrix,
        peak_annotation           = ANNOTATE_PEAKS.peak_annotation,
    )
}

#
# @include "_sc_atac_analyzer_stages.mro"
#

stage ANALYZER_PREFLIGHT(
    in  bed      peaks,
    in  h5       filtered_peak_bc_matrix,
    in  string[] factorization,
    in  int      tsne_perplexity,
    in  int      random_seed,
    in  float    tsne_theta,
    in  int      tsne_mom_switch_iter,
    in  int      tsne_stop_lying_iter,
    in  int      tsne_max_dims,
    in  int      tsne_input_pcs,
    in  int      tsne_max_iter,
    in  int      max_clusters,
    in  int      num_components,
    in  int      num_dr_bcs,
    in  int      num_dr_features,
    in  float    neighbor_a,
    in  float    neighbor_b,
    in  int      graphclust_neighbors,
    src py       "stages/preflight/atac_analyzer",
)

stage REDUCE_DIMENSIONS(
    in  h5       filtered_matrix,
    in  string[] factorization,
    in  int      num_dims,
    in  int      num_bcs,
    in  int      num_features,
    in  int      random_seed,
    out path     reduced_data,
    out map      reduction_summary,
    src py       "stages/analysis/reduce_dimensions",
) split (
    in  string   method,
) using (
    volatile = strict,
)

stage CLUSTER_CELLS(
    in  h5       filtered_matrix,
    in  path     reduced_data,
    in  map      reduction_summary,
    in  string[] factorization,
    in  int      minclusters,
    in  int      maxclusters,
    in  int      num_dims,
    in  int      random_seed,
    out path     clustered_data,
    out map      clustering_summary,
    src py       "stages/analysis/cluster_cells",
) split (
    in  int      n_clusters,
) using (
    volatile = strict,
)

stage PROJECT_TSNE(
    in  h5       filtered_matrix,
    in  path     reduced_data,
    in  map      reduction_summary,
    in  int      tsne_perplexity,
    in  int      tsne_max_dims,
    in  int      tsne_input_pcs,
    in  float    tsne_theta,
    in  int      tsne_max_iter,
    in  int      tsne_stop_lying_iter,
    in  int      tsne_mom_switch_iter,
    in  int      random_seed,
    in  string[] factorization,
    out path     tsne,
    out map      tsne_summary,
    src py       "stages/analysis/project_tsne",
) split (
    in  string   method,
    in  int      tsne_dims,
) using (
    volatile = strict,
)

stage RUN_GRAPH_CLUSTERING(
    in  h5       matrix_h5                 "Processed matrix",
    in  string[] factorization,
    in  path     reduced_data,
    in  map      reduction_summary,
    in  int      num_neighbors             "Use this many neighbors",
    in  float    neighbor_a                "Use larger of (a+b*log10(n_cells) neighbors or num_neighbors",
    in  float    neighbor_b                "Use larger of (a+b*log10(n_cells) neighbors or num_neighbors",
    in  int      balltree_leaf_size,
    in  string   similarity_type           "Type of similarity to use (nn or snn)",
    out h5       chunked_neighbors,
    out path     knn_clusters,
    out map      graph_clustering_summary,
    src py       "stages/analysis/run_graph_clustering",
) split (
    in  string   method,
    in  pickle   neighbor_index,
    in  h5       submatrix,
    in  int      row_start,
    in  int      total_rows,
    in  int      k_nearest,
    in  h5       use_bcs,
) using (
    volatile = strict,
)

stage COMBINE_CLUSTERING(
    in  h5   filtered_matrix,
    in  map  clustering_summary,
    in  path clustered_data,
    in  map  graph_clustering_summary,
    in  path knn_clusters,
    out path clustering,
    out map  clustering_summary,
    src py   "stages/analysis/combine_clustering",
) using (
    volatile = strict,
)

stage SUMMARIZE_ANALYSIS(
    in  tsv  peak_annotation,
    in  h5   filtered_peak_bc_matrix,
    in  h5   filtered_tf_bc_matrix,
    in  gz   tf_propZ_matrix,
    in  path reduced_data,
    in  map  reduction_summary,
    in  path clustering,
    in  map  clustering_summary,
    in  path tsne,
    in  map  tsne_summary,
    in  path enrichment_analysis,
    in  map  enrichment_analysis_summary,
    out h5   analysis,
    out path analysis_csv,
    out h5   feature_bc_matrix,
    src py   "stages/analysis/summarize_analysis",
) split (
) using (
    volatile = strict,
)

stage PERFORM_DIFFERENTIAL_ANALYSIS(
    in  bed      peaks,
    in  string   reference_path,
    in  h5       filtered_peak_bc_matrix,
    in  h5       filtered_tf_bc_matrix,
    in  string[] factorization,
    in  path     clustering,
    in  map      clustering_summary,
    out path     enrichment_analysis,
    out map      enrichment_analysis_summary,
    src py       "stages/analysis/perform_differential_analysis",
) split (
    in  string   method,
    in  string   clustering_key,
    in  int      cluster,
    out csv      tmp_diffexp,
) using (
    volatile = strict,
)

#
# @include "_sc_atac_analyzer.mro"
#

pipeline _SC_ATAC_ANALYZER(
    in  string   reference_path,
    in  bed      peaks,
    in  h5       filtered_peak_bc_matrix,
    in  string[] factorization,
    in  int      tsne_perplexity,
    in  int      random_seed,
    in  float    tsne_theta,
    in  int      tsne_mom_switch_iter,
    in  int      tsne_stop_lying_iter,
    in  int      tsne_max_dims,
    in  int      tsne_input_pcs,
    in  int      tsne_max_iter,
    in  int      max_clusters,
    in  int      num_components,
    in  int      num_dr_bcs,
    in  int      num_dr_features,
    in  float    neighbor_a,
    in  float    neighbor_b,
    in  int      graphclust_neighbors,
    out h5       analysis,
    out path     analysis_csv,
    out h5       filtered_tf_bc_matrix,
    out path     filtered_tf_bc_matrix_mex,
    out h5       feature_bc_matrix,
    out tsv      peak_annotation,
)
{
    call ANALYZER_PREFLIGHT(
        peaks                   = self.peaks,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        factorization           = self.factorization,
        tsne_perplexity         = self.tsne_perplexity,
        random_seed             = self.random_seed,
        tsne_theta              = self.tsne_theta,
        tsne_mom_switch_iter    = self.tsne_mom_switch_iter,
        tsne_stop_lying_iter    = self.tsne_stop_lying_iter,
        tsne_max_dims           = self.tsne_max_dims,
        tsne_input_pcs          = self.tsne_input_pcs,
        tsne_max_iter           = self.tsne_max_iter,
        max_clusters            = self.max_clusters,
        num_components          = self.num_components,
        num_dr_bcs              = self.num_dr_bcs,
        num_dr_features         = self.num_dr_features,
        neighbor_a              = self.neighbor_a,
        neighbor_b              = self.neighbor_b,
        graphclust_neighbors    = self.graphclust_neighbors,
    ) using (
        volatile = true,
    )

    call _PEAK_ANNOTATOR(
        reference_path          = self.reference_path,
        peaks                   = self.peaks,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        pwm_threshold           = null,
    )

    call REDUCE_DIMENSIONS(
        filtered_matrix = self.filtered_peak_bc_matrix,
        factorization   = self.factorization,
        num_dims        = self.num_components,
        num_bcs         = self.num_dr_bcs,
        num_features    = self.num_dr_features,
        random_seed     = self.random_seed,
    )

    call CLUSTER_CELLS(
        filtered_matrix   = self.filtered_peak_bc_matrix,
        reduced_data      = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary = REDUCE_DIMENSIONS.reduction_summary,
        factorization     = self.factorization,
        minclusters       = 2,
        maxclusters       = self.max_clusters,
        num_dims          = null,
        random_seed       = self.random_seed,
    )

    call PROJECT_TSNE(
        filtered_matrix      = self.filtered_peak_bc_matrix,
        reduced_data         = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary    = REDUCE_DIMENSIONS.reduction_summary,
        tsne_perplexity      = self.tsne_perplexity,
        tsne_max_dims        = self.tsne_max_dims,
        tsne_input_pcs       = self.tsne_input_pcs,
        tsne_theta           = self.tsne_theta,
        tsne_max_iter        = self.tsne_max_iter,
        tsne_stop_lying_iter = self.tsne_stop_lying_iter,
        tsne_mom_switch_iter = self.tsne_mom_switch_iter,
        random_seed          = self.random_seed,
        factorization        = self.factorization,
    )

    call RUN_GRAPH_CLUSTERING(
        matrix_h5          = self.filtered_peak_bc_matrix,
        factorization      = self.factorization,
        reduced_data       = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary  = REDUCE_DIMENSIONS.reduction_summary,
        num_neighbors      = self.graphclust_neighbors,
        neighbor_a         = self.neighbor_a,
        neighbor_b         = self.neighbor_b,
        balltree_leaf_size = null,
        similarity_type    = "nn",
    )

    call COMBINE_CLUSTERING(
        filtered_matrix          = self.filtered_peak_bc_matrix,
        clustering_summary       = CLUSTER_CELLS.clustering_summary,
        clustered_data           = CLUSTER_CELLS.clustered_data,
        graph_clustering_summary = RUN_GRAPH_CLUSTERING.graph_clustering_summary,
        knn_clusters             = RUN_GRAPH_CLUSTERING.knn_clusters,
    )

    call PERFORM_DIFFERENTIAL_ANALYSIS(
        reference_path          = self.reference_path,
        peaks                   = self.peaks,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        filtered_tf_bc_matrix   = _PEAK_ANNOTATOR.filtered_tf_bc_matrix,
        factorization           = self.factorization,
        clustering              = COMBINE_CLUSTERING.clustering,
        clustering_summary      = COMBINE_CLUSTERING.clustering_summary,
    )

    call SUMMARIZE_ANALYSIS(
        peak_annotation             = _PEAK_ANNOTATOR.peak_annotation,
        filtered_peak_bc_matrix     = self.filtered_peak_bc_matrix,
        filtered_tf_bc_matrix       = _PEAK_ANNOTATOR.filtered_tf_bc_matrix,
        tf_propZ_matrix             = _PEAK_ANNOTATOR.tf_propZ_matrix,
        reduced_data                = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary           = REDUCE_DIMENSIONS.reduction_summary,
        clustering                  = COMBINE_CLUSTERING.clustering,
        clustering_summary          = COMBINE_CLUSTERING.clustering_summary,
        tsne                        = PROJECT_TSNE.tsne,
        tsne_summary                = PROJECT_TSNE.tsne_summary,
        enrichment_analysis         = PERFORM_DIFFERENTIAL_ANALYSIS.enrichment_analysis,
        enrichment_analysis_summary = PERFORM_DIFFERENTIAL_ANALYSIS.enrichment_analysis_summary,
    )

    return (
        analysis                  = SUMMARIZE_ANALYSIS.analysis,
        analysis_csv              = SUMMARIZE_ANALYSIS.analysis_csv,
        filtered_tf_bc_matrix     = _PEAK_ANNOTATOR.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex = _PEAK_ANNOTATOR.filtered_tf_bc_matrix_mex,
        feature_bc_matrix         = SUMMARIZE_ANALYSIS.feature_bc_matrix,
        peak_annotation           = _PEAK_ANNOTATOR.peak_annotation,
    )
}

#
# @include "_sc_atac_reporter_stages.mro"
#

stage SUMMARIZE_REPORTS_SINGLECELL(
    in  string reference_path,
    in  json   complexity_summary,
    in  json   cell_calling_summary,
    in  json   peak_results,
    in  json   basic_results,
    in  json   error_results_summary,
    in  json   insert_summary,
    in  json   singlecell_results,
    in  json   contam_results,
    in  json   downsample_info,
    in  json   enrichment_results,
    out json   analysis_params,
    out json   summary,
    out csv    summary_csv,
    src py     "stages/reporter/summarize_reports_singlecell",
) using (
    mem_gb = 4,
)

stage CREATE_WEBSUMMARY(
    in  string reference_path,
    in  string barcode_whitelist,
    in  json   summary_results,
    in  json   bulk_complexity,
    in  json   singlecell_complexity,
    in  string sample_id,
    in  string sample_desc,
    in  map[]  sample_def,
    in  bool   debug,
    in  csv    singlecell,
    in  csv    insert_sizes,
    in  csv    tss_relpos,
    in  csv    ctcf_relpos,
    in  h5     filtered_peak_bc_matrix,
    in  h5     analysis,
    in  json   excluded_barcodes,
    out html   web_summary,
    src py     "stages/reporter/create_websummary",
) using (
    mem_gb = 16,
)

#
# @include "_sc_atac_reporter.mro"
#

pipeline _SC_ATAC_REPORTER(
    in  string reference_path,
    in  string barcode_whitelist,
    in  json   bulk_complexity,
    in  json   cell_calling_summary,
    in  json   complexity_summary,
    in  json   basic_summary,
    in  json   peak_summary,
    in  json   singlecell_results,
    in  json   insert_summary,
    in  json   downsample_info,
    in  json   singlecell_complexity,
    in  csv    singlecell,
    in  csv    tss_relpos,
    in  csv    ctcf_relpos,
    in  string sample_id,
    in  string sample_desc,
    in  map[]  sample_def,
    in  csv    sc_insert_sizes,
    in  json   enrichment_results,
    in  h5     filtered_peak_bc_matrix,
    in  h5     analysis,
    in  json   excluded_barcodes,
    #
    out json   summary,
    out html   web_summary,
    out csv    summary_csv,
)
{
    call SUMMARIZE_REPORTS_SINGLECELL(
        reference_path        = self.reference_path,
        complexity_summary    = self.complexity_summary,
        cell_calling_summary  = self.cell_calling_summary,
        peak_results          = self.peak_summary,
        basic_results         = self.basic_summary,
        error_results_summary = null,
        insert_summary        = self.insert_summary,
        singlecell_results    = self.singlecell_results,
        contam_results        = null,
        downsample_info       = self.downsample_info,
        enrichment_results    = self.enrichment_results,
    )

    call CREATE_WEBSUMMARY(
        reference_path          = self.reference_path,
        barcode_whitelist       = self.barcode_whitelist,
        singlecell              = self.singlecell,
        tss_relpos              = self.tss_relpos,
        ctcf_relpos             = self.ctcf_relpos,
        sample_id               = self.sample_id,
        sample_desc             = self.sample_desc,
        sample_def              = self.sample_def,
        insert_sizes            = self.sc_insert_sizes,
        summary_results         = SUMMARIZE_REPORTS_SINGLECELL.summary,
        bulk_complexity         = self.bulk_complexity,
        singlecell_complexity   = self.singlecell_complexity,
        analysis                = self.analysis,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        excluded_barcodes       = self.excluded_barcodes,
        debug                   = false,
    )

    return (
        summary     = SUMMARIZE_REPORTS_SINGLECELL.summary,
        web_summary = CREATE_WEBSUMMARY.web_summary,
        summary_csv = SUMMARIZE_REPORTS_SINGLECELL.summary_csv,
    )
}

#
# @include "_atac_cloupe_stages.mro"
#

stage CLOUPE_PREPROCESS(
    in  string     pipestance_type,
    in  string     sample_id,
    in  string     sample_desc,
    in  string     reference_path,
    in  h5         analysis,
    in  h5         feature_barcode_matrix,
    in  bed        peaks,
    in  tsv.gz.tbi fragments_index,
    in  json       metrics_json,
    in  csv        aggregation_csv,
    in  json       gem_group_index_json,
    in  bool       no_secondary_analysis,
    out cloupe     output_for_cloupe,
    out json       gem_group_index_json,
    src py         "stages/cloupe/atac_cloupe_preprocess",
) split (
)

#
# @include "_preflight_stages.mro"
#

stage ATAC_COUNTER_PREFLIGHT(
    in  string   sample_id,
    in  string   fastq_mode,
    in  map[]    sample_def,
    in  string   reference_path,
    in  map      force_cells,
    in  string[] factorization,
    in  map      downsample,
    in  bool     check_executables,
    in  map      trim_def,
    src py       "stages/preflight/atac_counter",
) split (
)

stage ATAC_AGGR_PREFLIGHT(
    in  string   sample_id,
    in  string   reference_path,
    in  csv      aggr_csv,
    in  string   normalization,
    in  string[] factorization,
    in  bool     check_executables,
    src py       "stages/preflight/atac_aggr",
) split (
)

stage ATAC_REANALYZER_PREFLIGHT(
    in  string     sample_id,
    in  string     reference_path,
    in  string     barcode_whitelist,
    in  bed        peaks,
    in  csv        parameters,
    in  map        force_cells,
    in  csv        cell_barcodes,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  csv        aggregation_csv,
    in  bool       check_executables,
    src py         "stages/preflight/atac_reanalyzer",
) split (
)

#
# @include "sc_atac_counter.mro"
#

pipeline SC_ATAC_COUNTER(
    in  string     fastq_mode,
    in  string     sample_id,
    in  map[]      sample_def,
    in  map        downsample,
    in  string     sample_desc,
    in  string     reference_path,
    in  map        trim_def,
    in  string     barcode_whitelist,
    in  map        adapters,
    in  string[]   factorization,
    in  map        force_cells,
    #
    out csv        singlecell,
    out bam        possorted_bam,
    out bam.bai    possorted_bam_index,
    out json       summary,
    out html       web_summary,
    out bed        peaks,
    out h5         raw_peak_bc_matrix,
    out path       raw_peak_bc_matrix_mex,
    out path       analysis_csv,
    out h5         filtered_peak_bc_matrix,
    out path       filtered_peak_bc_matrix_mex,
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    out h5         filtered_tf_bc_matrix,
    out path       filtered_tf_bc_matrix_mex,
    out cloupe     cloupe,
    out csv        summary_csv,
    out tsv        peak_annotation,
)
{
    call ATAC_COUNTER_PREFLIGHT as ATAC_COUNTER_PREFLIGHT_LOCAL(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        reference_path    = self.reference_path,
        force_cells       = self.force_cells,
        factorization     = self.factorization,
        downsample        = self.downsample,
        trim_def          = self.trim_def,
        check_executables = false,
    ) using (
        local     = true,
        preflight = true,
    )

    call ATAC_COUNTER_PREFLIGHT(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        reference_path    = self.reference_path,
        force_cells       = self.force_cells,
        factorization     = self.factorization,
        downsample        = self.downsample,
        trim_def          = self.trim_def,
        check_executables = true,
    ) using (
        preflight = true,
    )

    call _BASIC_SC_ATAC_COUNTER(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        trim_def          = self.trim_def,
        adapters          = self.adapters,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        downsample        = self.downsample,
        force_cells       = self.force_cells,
    )

    call _SC_ATAC_METRIC_COLLECTOR(
        read_paired_bam    = _BASIC_SC_ATAC_COUNTER.read_paired_bam,
        fragments          = _BASIC_SC_ATAC_COUNTER.fragments,
        fragments_index    = _BASIC_SC_ATAC_COUNTER.fragments_index,
        peaks              = _BASIC_SC_ATAC_COUNTER.peaks,
        reference_path     = self.reference_path,
        cell_barcodes      = _BASIC_SC_ATAC_COUNTER.cell_barcodes,
        singlecell_cells   = _BASIC_SC_ATAC_COUNTER.singlecell_cells,
        singlecell_mapping = _BASIC_SC_ATAC_COUNTER.singlecell_mapping,
    )

    call _SC_ATAC_ANALYZER(
        peaks                   = _BASIC_SC_ATAC_COUNTER.peaks,
        filtered_peak_bc_matrix = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        reference_path          = self.reference_path,
        factorization           = self.factorization,
        tsne_perplexity         = 30,
        tsne_max_dims           = null,
        tsne_input_pcs          = null,
        tsne_max_iter           = null,
        tsne_stop_lying_iter    = null,
        tsne_mom_switch_iter    = null,
        tsne_theta              = null,
        random_seed             = null,
        max_clusters            = 10,
        neighbor_a              = null,
        neighbor_b              = null,
        graphclust_neighbors    = null,
        num_components          = 15,
        num_dr_bcs              = null,
        num_dr_features         = null,
    )

    call CLOUPE_PREPROCESS(
        pipestance_type        = "SC_ATAC_COUNTER_CS",
        reference_path         = self.reference_path,
        sample_id              = self.sample_id,
        sample_desc            = self.sample_desc,
        analysis               = _SC_ATAC_ANALYZER.analysis,
        feature_barcode_matrix = _SC_ATAC_ANALYZER.feature_bc_matrix,
        metrics_json           = _SC_ATAC_METRIC_COLLECTOR.basic_summary,
        peaks                  = _BASIC_SC_ATAC_COUNTER.peaks,
        fragments_index        = _BASIC_SC_ATAC_COUNTER.fragments_index,
        aggregation_csv        = null,
        gem_group_index_json   = null,
        no_secondary_analysis  = false,
    )

    call _SC_ATAC_REPORTER(
        reference_path          = self.reference_path,
        barcode_whitelist       = self.barcode_whitelist,
        bulk_complexity         = _SC_ATAC_METRIC_COLLECTOR.bulk_complexity,
        singlecell_complexity   = _SC_ATAC_METRIC_COLLECTOR.singlecell_complexity,
        cell_calling_summary    = _BASIC_SC_ATAC_COUNTER.cell_calling_summary,
        complexity_summary      = _SC_ATAC_METRIC_COLLECTOR.complexity_summary,
        basic_summary           = _SC_ATAC_METRIC_COLLECTOR.basic_summary,
        peak_summary            = _BASIC_SC_ATAC_COUNTER.peak_metrics,
        singlecell_results      = _SC_ATAC_METRIC_COLLECTOR.singlecell_results,
        insert_summary          = _SC_ATAC_METRIC_COLLECTOR.insert_summary,
        downsample_info         = _BASIC_SC_ATAC_COUNTER.downsample_info,
        singlecell              = _SC_ATAC_METRIC_COLLECTOR.singlecell,
        tss_relpos              = _SC_ATAC_METRIC_COLLECTOR.tss_relpos,
        ctcf_relpos             = _SC_ATAC_METRIC_COLLECTOR.ctcf_relpos,
        sample_id               = self.sample_id,
        sample_desc             = self.sample_desc,
        sample_def              = self.sample_def,
        sc_insert_sizes         = _SC_ATAC_METRIC_COLLECTOR.insert_sizes,
        enrichment_results      = _SC_ATAC_METRIC_COLLECTOR.enrichment_results,
        filtered_peak_bc_matrix = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        analysis                = _SC_ATAC_ANALYZER.analysis,
        excluded_barcodes       = _BASIC_SC_ATAC_COUNTER.excluded_barcodes,
    )

    return (
        singlecell                  = _SC_ATAC_METRIC_COLLECTOR.singlecell,
        possorted_bam               = _BASIC_SC_ATAC_COUNTER.possorted_bam,
        possorted_bam_index         = _BASIC_SC_ATAC_COUNTER.possorted_bam_index,
        summary                     = _SC_ATAC_REPORTER.summary,
        web_summary                 = _SC_ATAC_REPORTER.web_summary,
        peaks                       = _BASIC_SC_ATAC_COUNTER.peaks,
        raw_peak_bc_matrix          = _BASIC_SC_ATAC_COUNTER.raw_peak_bc_matrix,
        raw_peak_bc_matrix_mex      = _BASIC_SC_ATAC_COUNTER.raw_peak_bc_matrix_mex,
        analysis_csv                = _SC_ATAC_ANALYZER.analysis_csv,
        filtered_peak_bc_matrix     = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        filtered_peak_bc_matrix_mex = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix_mex,
        fragments                   = _BASIC_SC_ATAC_COUNTER.fragments,
        fragments_index             = _BASIC_SC_ATAC_COUNTER.fragments_index,
        filtered_tf_bc_matrix       = _SC_ATAC_ANALYZER.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex   = _SC_ATAC_ANALYZER.filtered_tf_bc_matrix_mex,
        cloupe                      = CLOUPE_PREPROCESS.output_for_cloupe,
        summary_csv                 = _SC_ATAC_REPORTER.summary_csv,
        peak_annotation             = _SC_ATAC_ANALYZER.peak_annotation,
    )
}

#
# @include "sc_atac_counter_cs.mro"
#

# Customer-facing (CS) pipeline
pipeline SC_ATAC_COUNTER_CS(
    in  string     fastq_mode                   "Input fastq configuration",
    in  string     sample_id,
    in  map[]      sample_def,
    in  map        downsample,
    in  string     sample_desc                  "Sample description",
    in  string     reference_path               "Path to 10X reference package",
    in  string[]   factorization                "Dimensionality reduction method (lsa, plsa, or pca)",
    in  map        force_cells                  "Force cell calling to a fixed number",
    #
    out csv        singlecell                   "Per-barcode fragment counts & metrics",
    out bam        possorted_bam                "Position sorted BAM file"  "possorted_bam.bam",
    out bam.bai    possorted_bam_index          "Position sorted BAM index"  "possorted_bam.bam.bai",
    out json       summary                      "Summary of all data metrics",
    out html       web_summary                  "HTML file summarizing data & analysis",
    out bed        peaks                        "Bed file of all called peak locations",
    out h5         raw_peak_bc_matrix           "Raw peak barcode matrix in hdf5 format",
    out path       raw_peak_bc_matrix_mex       "Raw peak barcode matrix in mex format"  "raw_peak_bc_matrix",
    out path       analysis_csv                 "Directory of analysis files"  "analysis",
    out h5         filtered_peak_bc_matrix      "Filtered peak barcode matrix in hdf5 format",
    out path       filtered_peak_bc_matrix_mex  "Filtered peak barcode matrix in mex format"  "filtered_peak_bc_matrix",
    out tsv.gz     fragments                    "Barcoded and aligned fragment file"  "fragments.tsv.gz",
    out tsv.gz.tbi fragments_index              "Fragment file index"       "fragments.tsv.gz.tbi",
    out h5         filtered_tf_bc_matrix        "Filtered tf barcode matrix in hdf5 format",
    out path       filtered_tf_bc_matrix_mex    "Filtered tf barcode matrix in mex format"  "filtered_tf_bc_matrix",
    out cloupe     cloupe                       "Loupe Cell Browser input file",
    out csv        summary_csv                  "csv summarizing important metrics and values"  "summary.csv",
    out tsv        peak_annotation              "Annotation of peaks with genes",
)
{
    call SC_ATAC_COUNTER(
        fastq_mode        = self.fastq_mode,
        sample_id         = self.sample_id,
        sample_def        = self.sample_def,
        downsample        = self.downsample,
        sample_desc       = self.sample_desc,
        reference_path    = self.reference_path,
        trim_def          = {
            "R1": {
                "3prime": ["MErc"],
            },
            "R2": {
                "3prime": ["MErc"],
            },
            "discard_untrimmed": false,
        },
        barcode_whitelist = "737K-cratac-v1",
        adapters          = {
            "ME": "AGATGTGTATAAGAGACAG",
            "MErc": "CTGTCTCTTATACACATCT",
        },
        factorization     = self.factorization,
        force_cells       = self.force_cells,
    )
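    # Note: the "MErc" adapter trimmed from the 3' ends of R1/R2 above is the
    # reverse complement of the "ME" (mosaic end) sequence defined in the
    # adapters map.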

    return (
        singlecell                  = SC_ATAC_COUNTER.singlecell,
        possorted_bam               = SC_ATAC_COUNTER.possorted_bam,
        possorted_bam_index         = SC_ATAC_COUNTER.possorted_bam_index,
        summary                     = SC_ATAC_COUNTER.summary,
        web_summary                 = SC_ATAC_COUNTER.web_summary,
        peaks                       = SC_ATAC_COUNTER.peaks,
        raw_peak_bc_matrix          = SC_ATAC_COUNTER.raw_peak_bc_matrix,
        raw_peak_bc_matrix_mex      = SC_ATAC_COUNTER.raw_peak_bc_matrix_mex,
        analysis_csv                = SC_ATAC_COUNTER.analysis_csv,
        filtered_peak_bc_matrix     = SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        filtered_peak_bc_matrix_mex = SC_ATAC_COUNTER.filtered_peak_bc_matrix_mex,
        fragments                   = SC_ATAC_COUNTER.fragments,
        fragments_index             = SC_ATAC_COUNTER.fragments_index,
        filtered_tf_bc_matrix       = SC_ATAC_COUNTER.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex   = SC_ATAC_COUNTER.filtered_tf_bc_matrix_mex,
        cloupe                      = SC_ATAC_COUNTER.cloupe,
        summary_csv                 = SC_ATAC_COUNTER.summary_csv,
        peak_annotation             = SC_ATAC_COUNTER.peak_annotation,
    )
}

#
# @include "__PBMC_P.mro"
#

call SC_ATAC_COUNTER_CS(
    fastq_mode     = "ILMN_BCL2FASTQ",
    sample_id      = "PBMC_P",
    sample_def     = [{
        "bc_in_read": 1,
        "bc_length": 16,
        "gem_group": null,
        "lanes": null,
        "library": "LibraryNotSpecified",
        "read_path": "/data/isshamie/dropbox/ATACseq/2020_11_18_Croker/igm-storage2.ucsd.edu/201113_A00953_0185_AHN7TMDSXY",
        "sample_indices": ["any"],
        "sample_names": ["BC_10xATAC_PMBC_P"],
    }],
    reference_path = "/data/isshamie/mito_lineage/data/external/GRCh38_MT_blacklist",
    downsample     = null,
    sample_desc    = "",
    factorization  = ["lsa"],
    force_cells    = null,
)
eb68621b49dd4f77b70f5698ea51b9190ad8b14c

account for mt_pos being a list not an int

account for mt_pos being a list not an int

# TODO account for mt_pos being a list not an int

import numpy as np
from numpy import random
import os
import pandas as pd
from tqdm import tqdm
#from src.config import ROOT_DIR
from sklearn.metrics import roc_curve, average_precision_score, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import glob
from sklearn.cluster import KMeans
from sklearn import metrics
from scipy.spatial.distance import cdist
from pandarallel import pandarallel

from mplh.color_utils import get_colors
from mplh.fig_utils import legend_from_color
from mplh import cluster_help as ch
from src.simulations.utils.config import read_config_file, write_config_file

from dynamicTreeCut import cutreeHybrid
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage
from sklearn.model_selection import ParameterGrid
from src.simulations.utils.config import check_required
from src.simulations.analysis import Analysis as an
from .simulation import Simulation
from src.utils.utils import compare_arbitrary_labels

# Does this ruin running the MCMC? I don't think so, b/c that format is going to be put in after anyway
class FullSimulation:
    """
    Class that simulates cell growth for lineage tracing. Reads in a
    parameter file and runs a certain number of iterations based on
    the num_iterations parameter.

    :ivar n_iter: Number of iterations
    :type n_iter: int
    :ivar num_cells: Number of cells to sequence
    :type num_cells: int

    :ivar sim: Each index is a different iteration of the simulation.
    :type sim: pandas Series

    """
    def __init__(self, params_f):
        params = read_config_file(params_f)
        self.n_iter = params['num_iterations']
        self.num_cells = params['num_cells']
        self.params = params
        if 'n_clust' not in params:
            self.params['n_clust'] = None
        # Store the metrics with this
        self.metrics = dict()

        # Files to save
        #self.outdir = os.path.join(self.params['local_outdir'])
        self.outdir = self.params['local_outdir']
        self.data_outdir = os.path.join(self.params['data_outdir'])
        self.f_save_data = os.path.join(self.data_outdir,
                                   self.params['name'] + '.p')
        self.f_save = os.path.join(self.outdir, self.params['name'] + '.p')

        self.f_save_metrics = self.f_save_data.replace('.p', '.metrics.tsv')
        self.f_save_cluster = self.f_save_data.replace('.p', '.cluster.tsv')
        self.f_save_befaft = self.f_save_data.replace('.p', '.before_after.tsv')
        self.f_save_befaft_cl = self.f_save_data.replace('.p',
                                                      '.before_after_cl.tsv')
        self.f_save_rocs = self.f_save_data.replace('.p', '.rocs.p')
        return
        #for i in self.n_iter:

    def run(self):
        """
        Runs the simulation and stores it in sim attr. Will also pickle
        the objects and save.

        This uses Pandaralel to parallelize the runs.
        :return:
        """
        # Parallelize df
        df = pd.Series(index=range(self.n_iter))
        df = df.apply(self.run_sim, args=(self.params,))

        #pandarallel.initialize(nb_workers=self.params['cpus'])
        #df = df.parallel_apply(self.run_sim, args=(self.params,))

        self.sim = df

        if "save_small" in self.params and self.params["save_small"] > 0:
            curr = self.sim[:self.params["save_small"]]
            for ind, val in curr.items():
                curr_f = self.f_save.replace(".p","") + "_sim" + str(ind)
                val.to_csv(curr_f+".csv")

        return

    @staticmethod
    def run_sim(x, params):
        """Run iteration of simulation.

        For a single iteration, it will initialize, grow, subsample,
        and merge the before stimulus and after stimulus variables.
        It willl also run
        :param x: Placeholder variable
        :param params: The parameter dictionary to use
        :type params: dict
        :return:
        """
        s = Simulation(params)
        s.initialize()
        s.grow()
        s.subsample_new(to_delete=True)
        s.combine_init_growth()
        return s

    def run_metrics(self):
        """
        Get metrics performances and save.
        :return:
        """
        self.sim_performance_dominant(group='both')
        self.stats_before_after()
        self.cluster_befaft()
        self.stats_cluster_befaft_dom()
        self.estimate_growth_rates_from_cluster()
        self.stats_before_after_clust()
        self.kl_divergence()

    def flatten_sim(self):
        ## TODO
        # This will extract out the classes of df
        return

    def sim_performance_dominant(self, group='both'):
        """
        Collects metrics that are averaged over the simulations.
        These focus specifically on the main (dominant) clone and the
        allele frequency of that clone's variant in each cell.

        :param group: {'init', 'growth', 'both'} Which cell population
        (initial, post-growth, or combined) to evaluate.
        :ivar dropout: Fraction of dominant clone cells that have an
        allele frequency of 0 at the lineage variant position.
        :type dropout: list
        :ivar prec_scores: sklearn average precision score based on
        the allele frequencies seen in the dominant clone cells versus
        the non-clone cells.
        :type prec_scores: list
        :ivar rocs: ROC curves for each iteration based on allele
        frequencies.

        :return:
        """
        dropout = []
        rocs = []
        prec_scores = []

        for iter, s in enumerate(self.sim.values):
            # First get the dominant clone, which is indexed as 1
            mt_pos = s.clone_mt_dict[1]
            # TODO account for mt_pos being a list not an int
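            # A hypothetical sketch for the TODO above: if mt_pos is a list of
            # positions rather than a single int, the per-cell AF selected below
            # could be collapsed to one value, e.g.
            #   cell_af = s.cell_af.loc[:, mt_pos].max(axis=1)
            # (taking the max AF across the clone's variants is one option, not
            # the current behavior).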
            if group == 'init':
                clones = s.clone_cell
                cell_af = s.cell_af.loc[:,mt_pos]
            elif group == 'growth':
                clones = s.new_clone_cell
                cell_af = s.new_cell_af.loc[:,mt_pos]
            elif group == 'both':
                #clones = pd.concat((s.clone_cell, s.subsample_new_clone_cell)).reset_index(drop=True)
                #cell_af = pd.concat((s.cell_af.loc[:,mt_pos], s.subsample_new_cell_af.loc[:,mt_pos])).reset_index(drop=True)
                clones = s.combined_clones
                cell_af = s.combined_cell_af.loc[:,mt_pos]
            else:
                raise ValueError('group variable not properly set.')

            y_true = clones.values.copy()
            y_true[y_true != 1] = 0  # Set nondominant clones to 0
            rocs.append(roc_curve(y_true, cell_af))
            prec_scores.append(average_precision_score(y_true, cell_af))
            dropout.append((cell_af[clones==1]==0).sum()/cell_af[clones==1].shape[0])

        self.dropout = dropout
        self.prec_scores = prec_scores
        self.rocs = rocs
        pd.DataFrame([prec_scores, dropout], index=['Precision', 'Dropout']).to_csv(self.f_save_metrics, sep='\t')
        self.metrics['prec_scores'] = prec_scores
        self.metrics['dropout'] = dropout
        self.metrics['rocs'] = rocs
        pickle.dump(rocs, open(self.f_save_rocs, 'wb'))

        return


    def reduce_cells(self, cell_af):
        #self.sim
        return


    def stats_before_after(self, clone_id=1):
        b_a_df = pd.DataFrame(index=np.arange(0,len(self.sim)), columns=["B", "A", "A/B"], dtype=str)
        for iter, s in enumerate(self.sim.values):
            b_clones = s.clone_cell
            a_clones = s.subsample_new_clone_cell
            b_a_df.at[iter, "B"] = (b_clones == clone_id).sum()
            b_a_df.at[iter, "A"] = (a_clones==clone_id).sum()
            b_a_df.at[iter,"A/B"] = (b_a_df.at[iter, "A"]+1)/(b_a_df.at[iter, "B"]+1)

        self.b_a_df = b_a_df
        b_a_df.to_csv(self.f_save_befaft, sep='\t')
        self.metrics['b_a_df'] = b_a_df
        return



    def cluster_befaft(self):
        """
        Loops through the simulations and for each,
        it clusters the cells.

        :ivar cluster_results: Cluster labels for each cell in each
        iteration.
        :type List of tuples, which is a list of
        a tuple, where the tuple is indexed by the cell and the value
        is the cell's cluster label
        """
        cluster_results = []


        for s in tqdm(self.sim.values):
            cluster_results.append(an.cluster_kmeans(s.combined_cell_af,
                                                     n_clust=self.params['n_clust']))
            # Add the cluster results to combined meta
            s.combined_meta["cluster"] = cluster_results[-1]

            # Bring the cluster labels and the clone labels into same
            # name space
            s.combined_meta['cluster_clone'] = compare_arbitrary_labels(s.combined_meta['clone'],
                                                                        s.combined_meta['cluster'])

        self.cluster_results = cluster_results


    def stats_cluster_befaft_dom(self, clone_id=1):
        """
        Confusion matrix for clustering the proper clone cells together.
        :param clone_id: Which clone to get metrics for
        :return:
        """

        b_a_df = pd.DataFrame(index=np.arange(len(self.sim)),
                              columns=["TN", "FP", "FN", "TP"], dtype=int)
        f1_vals = []
        for ind, s in enumerate(self.sim.values):
            y_true = s.combined_clones.copy()  # copy to avoid mutating the simulation's labels
            y_true[y_true!=clone_id] = 0
            y_pred = s.combined_meta['cluster_clone'].copy()
            y_pred[y_pred!=clone_id] = 0

            f1_vals.append(f1_score(y_true, y_pred))
            # y_true, y_pred
            tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
            b_a_df.loc[ind] = [tn, fp, fn, tp]
        self.b_a_df_clust = b_a_df

        self.metrics['Dominant Cluster Confusion'] = b_a_df
        self.metrics['Cluster F1 scores'] = f1_vals
        return



    def estimate_growth_rates_from_known(self):
        all_growth_estimate = []
        all_clone_sizes = []
        for iter, s in enumerate(self.sim.values):
            growth_estimate, clone_sizes, _, _ = an.estimate_growth_rate(s.combined_meta)
            all_growth_estimate.append(growth_estimate)
            all_clone_sizes.append(clone_sizes)

        self.metrics['obs_growth_rates'] = all_growth_estimate
        self.metrics['obs_clone_sizes'] = all_clone_sizes
        return

    def estimate_growth_rates_from_cluster(self):
        all_growth_estimate = []
        all_clone_sizes = []
        all_bef_est = []
        all_aft_est = []
        for iter, s in enumerate(self.sim.values):
            growth_estimate, clone_sizes, aft_est, bef_est = an.estimate_growth_rate(s.combined_meta, clone_col="cluster_clone")
            all_growth_estimate.append(growth_estimate)
            all_bef_est.append(bef_est)
            all_aft_est.append(aft_est)
            all_clone_sizes.append(clone_sizes)
        self.metrics['pred_growth_rates'] = all_growth_estimate
        self.metrics['pred_aft_count'] = all_aft_est
        self.metrics['pred_bef_count'] = all_bef_est
        self.metrics['pred_clone_sizes'] = all_clone_sizes
        return


    def stats_before_after_clust(self, clone_id=1):
        b_a_df = pd.DataFrame(index=np.arange(0,len(self.sim)), columns=["A/B", "A", "B"], dtype=str)
        for iter, s in enumerate(self.sim.values):
            curr_pred_growth = self.metrics['pred_growth_rates'][iter]
            curr_pred_aft = self.metrics['pred_aft_count'][iter]
            #print('curr_pred_aft', curr_pred_aft)
            curr_pred_bef = self.metrics['pred_bef_count'][iter]
            b_a_df.at[iter, "A/B"] = curr_pred_growth.loc[clone_id]
            b_a_df.at[iter, "A"] = curr_pred_aft.loc[clone_id]
            b_a_df.at[iter, "B"] = curr_pred_bef.loc[clone_id]

        b_a_df.to_csv(self.f_save_befaft_cl, sep='\t')
        self.metrics['b_a_clust_df'] = b_a_df
        return

    def kl_divergence(self):
        all_growth_estimate = []
        all_clone_sizes = []
        for iter, s in enumerate(self.sim.values):
            growth_estimate, clone_sizes, _, _ = an.estimate_growth_rate(s.combined_meta)
            all_growth_estimate.append(growth_estimate)
            all_clone_sizes.append(clone_sizes)

        self.metrics['pred_growth_rates'] = all_growth_estimate
        self.metrics['pred_clone_sizes'] = all_clone_sizes
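        # Note: this method currently mirrors estimate_growth_rates_from_known.
        # A hypothetical KL-divergence sketch between before/after clone
        # proportions (scipy.stats is not imported here) could look like:
        #   from scipy.stats import entropy
        #   kl = entropy(before_sizes / before_sizes.sum(),
        #                after_sizes / after_sizes.sum())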


    def save(self, f_save=None):
        if f_save is None:
            f_save = self.f_save_data
        f = open(f_save, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()


    def load(self, f_save=None):
        #filename = self.params['filename']
        if f_save is None:
            f_save = self.f_save
        f = open(f_save, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)


def main():
    return


if "__name__" == "__main__":
    main()
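
# For orientation, a hypothetical usage sketch of FullSimulation (the
# parameter-file path is a placeholder, not a file shipped with the repo):
#   sim = FullSimulation("params/simulation.yaml")
#   sim.run()          # run all iterations
#   sim.run_metrics()  # precision, dropout, clustering and growth-rate metrics
#   sim.save()         # pickle the object to the configured data_outdir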

00ec948784467058e080c2d5373f08ae6eaeb83c

This should be in mro/common for general use

This should be in mro/common for general use

# TODO: This should be in mro/common for general use

#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype json;
filetype bam;
#
# @include "_sort_and_mark_dups_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype bam;
filetype bam.bai;
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype json;
filetype csv;
#
# @include "_peak_caller_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype bedgraph;
filetype pickle;
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype bed;
filetype json;
#
# @include "_basic_sc_atac_counter_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype csv;
filetype json;
filetype bed;
filetype pickle;
filetype h5;
#
# @include "_produce_cell_barcodes_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype csv;
filetype json;
filetype bed;
filetype pickle;
filetype h5;
filetype npy.gz;
#
# @include "_sc_atac_metric_collector_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype bed;
filetype bam;
filetype csv;
filetype json;
filetype h5;
filetype txt;
filetype pickle;
#
# @include "_peak_annotator_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype bed;
filetype tsv;
filetype h5;
filetype gz;
filetype pickle;
#
# @include "_sc_atac_analyzer_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv;
filetype h5;
filetype pickle;
filetype gz;
filetype bed;
filetype csv;
#
# @include "_sc_atac_reporter_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype json;
filetype html;
filetype csv;
filetype h5;
filetype bam;
#
# @include "_atac_cloupe_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype cloupe;
filetype csv;
filetype json;
filetype h5;
filetype bed;
filetype tsv.gz.tbi;
#
# @include "_preflight_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype csv;
filetype bed;
filetype tsv.gz;
filetype tsv.gz.tbi;

#
# @include "_aligner_stages.mro"
#

# SETUP_CHUNKS chunks up the input fastq data into sets of matched R1, R2, SI, and BC fastq files.
# input_mode specifies how FASTQs were generated. There are two modes:
#
# 1. "BCL_PROCESSOR"
#
# FASTQs produced by the 10X BCL_PROCESSOR pipeline. This mode assumes the FASTQ files obey the internal
# naming conventions and the reads have been interleaved into RA FASTQ files.
#
# 2. "ILMN_BCL2FASTQ"
#
# FASTQs produced directly by Illumina BCL2FASTQ v1.8.4. For this mode, BCL2FASTQ must be configured to emit the
# index2 read, rather than using it for dual-index demultiplexing:
#
# configureBclToFastq.pl --no-eamss --use-bases-mask=Y100,I8,Y14,Y100 --input-dir=<basecalls_dir> \
#     --output-dir=<output_dir> --sample-sheet=<sample_sheet.csv>
#
# The sample sheet must be formatted as per the BCL2FASTQ documentation (10 column csv), and must contain a row for
# each sample index used. The sequencer must have been run in dual index mode, with the second index read (used to
# read the 10X barcode) emitted as the R2 output file. The --use-bases-mask argument should be set to the read
# length used.
stage SETUP_CHUNKS(
    in  string   sample_id        "id of the sample",
    in  map[]    sample_def       "list of dictionary specifying input data",
    in  string   input_mode       "configuration of the input fastqs",
    in  map      downsample       "map specifies either subsample_rate (float) or gigabases (int)",
    out map[]    chunks           "map has barcode, barcode_reverse_complement, sample_index, read1, read2, gem_group, and read_group fields",
    out string[] read_groups      "list of strings representing read groups",
    out json     downsample_info  "info about downsampling result",
    src py       "stages/processing/setup_chunks",
)
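# For orientation, an illustrative sample_def entry in "ILMN_BCL2FASTQ" mode
# (values are placeholders that mirror the __PBMC_P.mro call shown earlier in
# this document):
#   { "bc_in_read": 1, "bc_length": 16, "gem_group": null, "lanes": null,
#     "library": "LibraryNotSpecified", "read_path": "/path/to/fastq_dir",
#     "sample_indices": ["any"], "sample_names": ["MY_SAMPLE"] }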

# Trims adapter sequences from reads and massages fastq output into a fixed format (interleaved R1 file, etc.)
stage TRIM_READS(
    in  map[]  chunks,
    in  string barcode_whitelist,
    in  int    max_read_num,
    in  map    trim_def,
    in  map    adapters,
    out map[]  chunks,
    out json   bc_counts,
    out json   lot_info,
    out json   read_counts,
    src py     "stages/processing/trim_reads",
) split (
    in  map    chunk,
) using (
    volatile = strict,
)

# Aligns the reads to the input reference, producing chunked bam files
stage ALIGN_READS(
    in  map[]  chunks,
    in  string aligner,
    in  string aligner_method,
    in  string reference_path,
    in  string read_group_sample,
    in  int    num_threads,
    out bam[],
    src py     "stages/processing/align_reads",
) split (
    in  map    chunk,
) using (
    # N.B. No index files are generated for the bam
    volatile = strict,
)

#
# @include "_aligner.mro"
#

# Takes input fastqs and chunks them, trims them, and aligns the trimmed reads to a reference
pipeline _ALIGNER(
    in  string sample_id,
    in  string fastq_mode         "configuration of the input fastqs",
    in  map[]  sample_def,
    in  string reference_path     "this is the reference_path",
    in  string barcode_whitelist  "name of barcode whitelist file",
    in  map    trim_def,
    in  map    adapters,
    in  string read_group_sample  "sample header for BAM file",
    in  map    downsample,
    out bam[]  align,
    out map[]  chunks,
    out json   bc_counts,
    out json   lot_info           "gelbead lot detected",
    out json   read_counts        "total # of read pairs before and after adapter trimming",
    out json   downsample_info    "info on downsampling",
)
{
    call SETUP_CHUNKS(
        sample_id  = self.sample_id,
        input_mode = self.fastq_mode,
        sample_def = self.sample_def,
        downsample = self.downsample,
    ) using (
        volatile = true,
    )

    call TRIM_READS(
        chunks            = SETUP_CHUNKS.chunks,
        max_read_num      = 5000000,
        trim_def          = self.trim_def,
        adapters          = self.adapters,
        barcode_whitelist = self.barcode_whitelist,
    ) using (
        volatile = true,
    )

    call ALIGN_READS(
        chunks            = TRIM_READS.chunks,
        aligner           = "bwa",
        aligner_method    = "MEM",
        reference_path    = self.reference_path,
        read_group_sample = self.read_group_sample,
        num_threads       = 4,
    ) using (
        volatile = true,
    )

    return (
        align           = ALIGN_READS,
        chunks          = TRIM_READS.chunks,
        bc_counts       = TRIM_READS.bc_counts,
        lot_info        = TRIM_READS.lot_info,
        read_counts     = TRIM_READS.read_counts,
        downsample_info = SETUP_CHUNKS.downsample_info,
    )
}

#
# @include "_sort_and_mark_dups_stages.mro"
#

# Attaches raw and corrected barcode sequences to the aligned reads
stage ATTACH_BCS(
    in  string barcode_whitelist,
    in  bam[]  align,
    in  map[]  chunks,
    in  bool   paired_end,
    in  bool   exclude_non_bc_reads,
    in  float  bc_confidence_threshold,
    in  json   bc_counts,
    out bam[]  output,
    out int    perfect_read_count,
    src py     "stages/processing/attach_bcs",
) split (
    in  bam    align_chunk,
    in  map    chunk,
) using (
    # N.B. No index files are generated for the bam
    volatile = strict,
)

stage SORT_READS_BY_POS(
    in  bam[] input,
    out bam   tagsorted_bam,
    src py    "stages/processing/sort_reads_by_pos",
) split (
    in  bam   chunk_input,
) using (
    # N.B. No index files are generated for the bam
    volatile = strict,
)

# Marks duplicates in the reads using barcodes and fragment alignments to detect PCR and optical/diffusion duplicates
stage MARK_DUPLICATES(
    in  bam        input,
    in  string     reference_path,
    in  json       raw_barcode_counts,
    in  string     barcode_whitelist,
    out bam        output,
    out bam.bai    index,
    out csv        singlecell_mapping,
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    src py         "stages/processing/mark_duplicates",
) split (
    in  map        lane_map,
    in  string     chunk_start,
    in  string     chunk_end,
    in  int        chunk_num,
) using (
    # N.B. BAM/BED index files are explicitly bound where used
    volatile = strict,
)

#
# @include "_sort_and_mark_dups.mro"
#

# Attaches barcodes to the aligned reads, marks duplicate reads, and produces a barcode-sorted and position-sorted
# output BAM
pipeline _SORT_AND_MARK_DUPS(
    in  bam[]      align,
    in  map[]      chunks,
    in  string     barcode_whitelist,
    in  json       bc_counts,
    in  string     reference_path,
    out bam        possorted_bam        "bam file sorted by position",
    out bam.bai    possorted_bam_index  "position-sorted bam index",
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    out csv        singlecell_mapping,
    out bam[]      read_paired_bam,
)
{
    call ATTACH_BCS(
        align                   = self.align,
        chunks                  = self.chunks,
        paired_end              = true,
        barcode_whitelist       = self.barcode_whitelist,
        exclude_non_bc_reads    = false,
        bc_confidence_threshold = 0.975,
        bc_counts               = self.bc_counts,
    ) using (
        volatile = true,
    )

    call SORT_READS_BY_POS(
        input = ATTACH_BCS.output,
    ) using (
        volatile = true,
    )

    call MARK_DUPLICATES(
        input              = SORT_READS_BY_POS.tagsorted_bam,
        reference_path     = self.reference_path,
        barcode_whitelist  = self.barcode_whitelist,
        raw_barcode_counts = self.bc_counts,
    ) using (
        volatile = true,
    )

    return (
        possorted_bam       = MARK_DUPLICATES.output,
        possorted_bam_index = MARK_DUPLICATES.index,
        singlecell_mapping  = MARK_DUPLICATES.singlecell_mapping,
        fragments           = MARK_DUPLICATES.fragments,
        fragments_index     = MARK_DUPLICATES.fragments_index,
        read_paired_bam     = ATTACH_BCS.output,
    )
}

#
# @include "_peak_caller_stages.mro"
#

stage COUNT_CUT_SITES(
    in  path       reference_path,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    out bedgraph   cut_sites,
    out pickle     count_dict,
    src py         "stages/processing/count_cut_sites",
) split (
    in  string     contig,
) using (
    # N.B. We explicitly bind the index file
    volatile = strict,
)

stage DETECT_PEAKS(
    in  bedgraph cut_sites,
    in  path     reference_path,
    in  pickle   count_dict,
    out bed      peaks,
    out json     peak_metrics,
    src py       "stages/processing/detect_peaks",
) split (
    in  string   contig,
    in  float[]  params,
    in  float    threshold,
) using (
    mem_gb   = 6,
    # N.B. We explicitly bind the index file
    volatile = strict,
)

#
# @include "_peak_caller.mro"
#

pipeline _PEAK_CALLER(
    in  path       reference_path,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    out bedgraph   cut_sites,
    out bed        peaks,
    out json       peak_metrics,
)
{
    call COUNT_CUT_SITES(
        reference_path  = self.reference_path,
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
    )

    call DETECT_PEAKS(
        reference_path = self.reference_path,
        cut_sites      = COUNT_CUT_SITES.cut_sites,
        count_dict     = COUNT_CUT_SITES.count_dict,
    )

    return (
        cut_sites    = COUNT_CUT_SITES.cut_sites,
        peaks        = DETECT_PEAKS.peaks,
        peak_metrics = DETECT_PEAKS.peak_metrics,
    )
}

#
# @include "_basic_sc_atac_counter_stages.mro"
#

stage GENERATE_PEAK_MATRIX(
    in  string reference_path,
    in  tsv.gz fragments,
    in  bed    peaks,
    out h5     raw_matrix,
    out path   raw_matrix_mex,
    src py     "stages/processing/generate_peak_matrix",
) split (
    in  file   barcodes,
) using (
    mem_gb   = 4,
    # N.B. we don't explicitly need the fragment index
    volatile = strict,
)

stage FILTER_PEAK_MATRIX(
    in  h5   raw_matrix,
    in  int  num_analysis_bcs,
    in  int  random_seed,
    in  csv  cell_barcodes,
    out h5   filtered_matrix,
    out path filtered_matrix_mex,
    src py   "stages/processing/filter_peak_matrix",
) split (
) using (
    volatile = strict,
)

#
# @include "_produce_cell_barcodes_stages.mro"
#

stage REMOVE_LOW_TARGETING_BARCODES(
    in  bed        peaks,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    out json       barcode_counts,
    out json       low_targeting_barcodes,
    out json       low_targeting_summary,
    out json       fragment_lengths,
    out json       covered_bases,
    src py         "stages/processing/cell_calling/remove_low_targeting_barcodes",
) split (
    in  string     contig,
    out pickle     fragment_counts,
    out pickle     targeted_counts,
    out int        peak_coverage,
) using (
    mem_gb   = 4,
    volatile = strict,
)

stage REMOVE_GEL_BEAD_DOUBLET_BARCODES(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    in  json       barcode_counts,
    out json       gel_bead_doublet_barcodes,
    out json       gel_bead_doublet_summary,
    out csv        connect_matrix,
    src py         "stages/processing/cell_calling/remove_gel_bead_doublet_barcodes",
) split (
    in  string     contig,
    in  file       valid_barcodes,
) using (
    mem_gb   = 4,
    volatile = strict,
)

stage REMOVE_BARCODE_MULTIPLETS(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    in  string     barcode_whitelist,
    in  json       barcode_counts,
    out json       barcode_multiplets,
    out json       barcode_multiplets_summary,
    src py         "stages/processing/cell_calling/remove_barcode_multiplets",
) split (
    in  string     contig,
    in  string     gem_group,
    out npy.gz     part_a_linkage_matrix,
    out npy.gz     part_b_linkage_matrix,
) using (
    mem_gb   = 4,
    volatile = strict,
)

stage MERGE_EXCLUDED_BARCODES(
    in  json[] barcode_exclusions,
    out json   excluded_barcodes,
    src py     "stages/processing/cell_calling/merge_excluded_barcodes",
)

stage DETECT_CELL_BARCODES(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     barcode_whitelist,
    in  json       excluded_barcodes,
    in  map        force_cells,
    in  string     reference_path,
    in  bed        peaks,
    out csv        cell_barcodes,
    out csv        singlecell,
    out json       cell_calling_summary,
    src py         "stages/processing/cell_calling/detect_cell_barcodes",
) split (
    in  string     contig,
    out pickle     barcode_counts,
    out pickle     targeted_counts,
    out int        fragment_depth,
) using (
    mem_gb   = 4,
    volatile = strict,
)

# TODO: This should be in mro/common for general use
stage MERGE_SUMMARY_METRICS(
    in  json[] summary_jsons,
    out json   merged_summary,
    src py     "stages/processing/cell_calling/merge_summary_metrics",
)

#
# @include "_produce_cell_barcodes.mro"
#

pipeline _PRODUCE_CELL_BARCODES(
    in  bed        peaks,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    in  string     barcode_whitelist,
    in  map        force_cells,
    out csv        cell_barcodes,
    out csv        singlecell,
    out json       cell_calling_summary,
    out json       excluded_barcodes,
    out json       fragment_lengths,
    out json       covered_bases,
)
{
    call REMOVE_LOW_TARGETING_BARCODES(
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
        peaks           = self.peaks,
        reference_path  = self.reference_path,
    )

    call REMOVE_GEL_BEAD_DOUBLET_BARCODES(
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
        reference_path  = self.reference_path,
        barcode_counts  = REMOVE_LOW_TARGETING_BARCODES.barcode_counts,
    )

    call REMOVE_BARCODE_MULTIPLETS(
        fragments         = self.fragments,
        fragments_index   = self.fragments_index,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        barcode_counts    = REMOVE_LOW_TARGETING_BARCODES.barcode_counts,
    )

    call MERGE_EXCLUDED_BARCODES(
        barcode_exclusions = [
            REMOVE_BARCODE_MULTIPLETS.barcode_multiplets,
            REMOVE_GEL_BEAD_DOUBLET_BARCODES.gel_bead_doublet_barcodes,
            REMOVE_LOW_TARGETING_BARCODES.low_targeting_barcodes,
        ],
    )

    call DETECT_CELL_BARCODES(
        fragments         = self.fragments,
        fragments_index   = self.fragments_index,
        barcode_whitelist = self.barcode_whitelist,
        force_cells       = self.force_cells,
        excluded_barcodes = MERGE_EXCLUDED_BARCODES.excluded_barcodes,
        reference_path    = self.reference_path,
        peaks             = self.peaks,
    )

    call MERGE_SUMMARY_METRICS as MERGE_CELL_METRICS(
        summary_jsons = [
            REMOVE_LOW_TARGETING_BARCODES.low_targeting_summary,
            REMOVE_GEL_BEAD_DOUBLET_BARCODES.gel_bead_doublet_summary,
            REMOVE_BARCODE_MULTIPLETS.barcode_multiplets_summary,
            DETECT_CELL_BARCODES.cell_calling_summary,
        ],
    )

    return (
        cell_barcodes        = DETECT_CELL_BARCODES.cell_barcodes,
        excluded_barcodes    = MERGE_EXCLUDED_BARCODES.excluded_barcodes,
        singlecell           = DETECT_CELL_BARCODES.singlecell,
        cell_calling_summary = MERGE_CELL_METRICS.merged_summary,
        fragment_lengths     = REMOVE_LOW_TARGETING_BARCODES.fragment_lengths,
        covered_bases        = REMOVE_LOW_TARGETING_BARCODES.covered_bases,
    )
}

#
# @include "_basic_sc_atac_counter.mro"
#

pipeline _BASIC_SC_ATAC_COUNTER(
    in  string     sample_id,
    in  string     fastq_mode                   "configuration of the input fastqs",
    in  map[]      sample_def,
    in  string     reference_path               "this is the reference_path",
    in  string     barcode_whitelist            "name of barcode whitelist file",
    in  map        trim_def,
    in  map        adapters,
    in  map        downsample,
    in  map        force_cells,
    out bam        possorted_bam                "bam file sorted by position",
    out bam.bai    possorted_bam_index          "position-sorted bam index",
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    out json       lot_info                     "gelbead lot detected",
    out json       read_counts                  "total # of read pairs before and after adapter trimming",
    out json       downsample_info              "info on downsampling",
    out csv        cell_barcodes,
    out json       excluded_barcodes,
    out json       cell_calling_summary,
    out bed        peaks,
    out bedgraph   cut_sites,
    out csv        singlecell_mapping,
    out csv        singlecell_cells,
    out json       peak_metrics,
    out bam[]      read_paired_bam,
    out h5         raw_peak_bc_matrix,
    out path       raw_peak_bc_matrix_mex,
    out h5         filtered_peak_bc_matrix,
    out path       filtered_peak_bc_matrix_mex,
)
{
    call _ALIGNER(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        read_group_sample = self.sample_id,
        trim_def          = self.trim_def,
        adapters          = self.adapters,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        downsample        = self.downsample,
    )

    call _SORT_AND_MARK_DUPS(
        align             = _ALIGNER.align,
        chunks            = _ALIGNER.chunks,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        bc_counts         = _ALIGNER.bc_counts,
    )

    call _PEAK_CALLER(
        fragments       = _SORT_AND_MARK_DUPS.fragments,
        fragments_index = _SORT_AND_MARK_DUPS.fragments_index,
        reference_path  = self.reference_path,
    )

    call _PRODUCE_CELL_BARCODES(
        fragments         = _SORT_AND_MARK_DUPS.fragments,
        fragments_index   = _SORT_AND_MARK_DUPS.fragments_index,
        peaks             = _PEAK_CALLER.peaks,
        force_cells       = self.force_cells,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
    )

    call GENERATE_PEAK_MATRIX(
        reference_path = self.reference_path,
        fragments      = _SORT_AND_MARK_DUPS.fragments,
        peaks          = _PEAK_CALLER.peaks,
    )

    call FILTER_PEAK_MATRIX(
        num_analysis_bcs = null,
        cell_barcodes    = _PRODUCE_CELL_BARCODES.cell_barcodes,
        raw_matrix       = GENERATE_PEAK_MATRIX.raw_matrix,
        random_seed      = null,
    )

    return (
        possorted_bam               = _SORT_AND_MARK_DUPS.possorted_bam,
        possorted_bam_index         = _SORT_AND_MARK_DUPS.possorted_bam_index,
        singlecell_mapping          = _SORT_AND_MARK_DUPS.singlecell_mapping,
        singlecell_cells            = _PRODUCE_CELL_BARCODES.singlecell,
        lot_info                    = _ALIGNER.lot_info,
        read_counts                 = _ALIGNER.read_counts,
        downsample_info             = _ALIGNER.downsample_info,
        cell_barcodes               = _PRODUCE_CELL_BARCODES.cell_barcodes,
        excluded_barcodes           = _PRODUCE_CELL_BARCODES.excluded_barcodes,
        cell_calling_summary        = _PRODUCE_CELL_BARCODES.cell_calling_summary,
        peak_metrics                = _PEAK_CALLER.peak_metrics,
        cut_sites                   = _PEAK_CALLER.cut_sites,
        peaks                       = _PEAK_CALLER.peaks,
        fragments                   = _SORT_AND_MARK_DUPS.fragments,
        fragments_index             = _SORT_AND_MARK_DUPS.fragments_index,
        read_paired_bam             = _SORT_AND_MARK_DUPS.read_paired_bam,
        raw_peak_bc_matrix          = GENERATE_PEAK_MATRIX.raw_matrix,
        raw_peak_bc_matrix_mex      = GENERATE_PEAK_MATRIX.raw_matrix_mex,
        filtered_peak_bc_matrix     = FILTER_PEAK_MATRIX.filtered_matrix,
        filtered_peak_bc_matrix_mex = FILTER_PEAK_MATRIX.filtered_matrix_mex,
    )
}

#
# @include "_sc_atac_metric_collector_stages.mro"
#

stage ESTIMATE_LIBRARY_COMPLEXITY(
    in  json   sequencing_summary,
    in  tsv.gz fragments,
    in  csv    cell_barcodes,
    out json   bulk_complexity,
    out json   complexity_summary,
    out json   singlecell_complexity,
    src py     "stages/metrics/estimate_library_complexity",
) split (
    in  file   barcodes,
) using (
    mem_gb   = 6,
    volatile = strict,
)

stage GENERATE_SEQUENCING_METRICS(
    in  bam[] input,
    out txt   misc_sm,
    out json  summary,
    src py    "stages/metrics/generate_sequencing_metrics",
) split (
    in  bam   chunk_bam,
) using (
    volatile = strict,
)

stage GENERATE_SINGLECELL_TARGETING(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  bed        peaks,
    in  string     reference_path,
    out csv        singlecell,
    out json       summary,
    out csv        tss_relpos,
    out csv        ctcf_relpos,
    src py         "stages/metrics/generate_singlecell_targeting",
) split (
    in  string     contig,
    out int        read_count,
    out pickle     target_counts_by_barcode,
    out pickle     chunk_tss,
    out pickle     chunk_ctcf,
) using (
    mem_gb   = 6,
    volatile = strict,
)

stage MERGE_SINGLECELL_METRICS(
    in  string reference_path,
    in  csv    singlecell_mapping,
    in  csv    singlecell_targets,
    in  csv    singlecell_cells,
    out csv    singlecell,
    out json   summary,
    src py     "stages/metrics/merge_singlecell_metrics",
) using (
    mem_gb   = 8,
    volatile = strict,
)

stage REPORT_INSERT_SIZES(
    in  tsv.gz fragments,
    in  bool   exclude_non_nuclear,
    in  string reference_path,
    out csv    insert_sizes,
    out json   insert_summary,
    src py     "stages/metrics/report_insert_sizes",
) split (
    in  file   barcode,
    out file   total,
) using (
    volatile = strict,
)

stage REPORT_TSS_CTCF(
    in  csv  tss_relpos,
    in  csv  ctcf_relpos,
    out json summary_metrics,
    src py   "stages/metrics/report_tss_ctcf",
) using (
    volatile = strict,
)

#
# @include "_sc_atac_metric_collector.mro"
#

pipeline _SC_ATAC_METRIC_COLLECTOR(
    in  bam[]      read_paired_bam,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  bed        peaks,
    in  string     reference_path         "this is the reference_path",
    in  csv        cell_barcodes,
    in  csv        singlecell_mapping,
    in  csv        singlecell_cells,
    out json       singlecell_results,
    out csv        singlecell,
    out json       enrichment_results,
    out json       basic_summary,
    out json       insert_summary,
    out csv        insert_sizes,
    out json       bulk_complexity,
    out json       singlecell_complexity,
    out json       complexity_summary,
    out csv        tss_relpos,
    out csv        ctcf_relpos,
)
{
    call GENERATE_SINGLECELL_TARGETING(
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
        peaks           = self.peaks,
        reference_path  = self.reference_path,
    )

    call MERGE_SINGLECELL_METRICS(
        reference_path     = self.reference_path,
        singlecell_mapping = self.singlecell_mapping,
        singlecell_cells   = self.singlecell_cells,
        singlecell_targets = GENERATE_SINGLECELL_TARGETING.singlecell,
    )

    call GENERATE_SEQUENCING_METRICS(
        input = self.read_paired_bam,
    )

    call ESTIMATE_LIBRARY_COMPLEXITY(
        sequencing_summary = GENERATE_SEQUENCING_METRICS.summary,
        fragments          = self.fragments,
        cell_barcodes      = self.cell_barcodes,
    )

    call REPORT_INSERT_SIZES(
        fragments           = self.fragments,
        reference_path      = self.reference_path,
        exclude_non_nuclear = true,
    )

    call REPORT_TSS_CTCF(
        tss_relpos  = GENERATE_SINGLECELL_TARGETING.tss_relpos,
        ctcf_relpos = GENERATE_SINGLECELL_TARGETING.ctcf_relpos,
    )

    return (
        ###
        singlecell            = MERGE_SINGLECELL_METRICS.singlecell,
        singlecell_results    = MERGE_SINGLECELL_METRICS.summary,
        ###
        enrichment_results    = REPORT_TSS_CTCF.summary_metrics,
        basic_summary         = GENERATE_SEQUENCING_METRICS.summary,
        insert_summary        = REPORT_INSERT_SIZES.insert_summary,
        insert_sizes          = REPORT_INSERT_SIZES.insert_sizes,
        bulk_complexity       = ESTIMATE_LIBRARY_COMPLEXITY.bulk_complexity,
        singlecell_complexity = ESTIMATE_LIBRARY_COMPLEXITY.singlecell_complexity,
        complexity_summary    = ESTIMATE_LIBRARY_COMPLEXITY.complexity_summary,
        tss_relpos            = GENERATE_SINGLECELL_TARGETING.tss_relpos,
        ctcf_relpos           = GENERATE_SINGLECELL_TARGETING.ctcf_relpos,
    )
}

#
# @include "_peak_annotator_stages.mro"
#

stage ANNOTATE_PEAKS(
    in  bed    peaks,
    in  string reference_path,
    out tsv    peak_annotation,
    src py     "stages/analysis/annotate_peaks",
) split (
    in  int    chunk_start,
    in  int    chunk_end,
) using (
    mem_gb   = 5,
    volatile = strict,
)

stage COMPUTE_GC_DISTRIBUTION(
    in  bed    peaks,
    in  string reference_path,
    out pickle GCdict,
    src py     "stages/analysis/compute_gc_dist",
) split (
) using (
    volatile = strict,
)

stage SCAN_MOTIFS(
    in  pickle globalGCdict,
    in  bed    peaks,
    in  string reference_path,
    in  float  pwm_threshold,
    out bed    peak_motif_hits,
    src py     "stages/analysis/scan_motifs",
) split (
    in  file   GCdict,
) using (
    volatile = strict,
)

stage GENERATE_TF_MATRIX(
    in  path reference_path,
    in  bed  peaks,
    in  bed  peak_motif_hits,
    in  h5   filtered_matrix,
    out h5   filtered_tf_bc_matrix,
    out path filtered_tf_bc_matrix_mex,
    out gz   tf_propZ_matrix,
    src py   "stages/analysis/generate_tf_matrix",
) split (
) using (
    volatile = strict,
)

#
# @include "_peak_annotator.mro"
#

pipeline _PEAK_ANNOTATOR(
    in  string reference_path,
    in  bed    peaks,
    in  h5     filtered_peak_bc_matrix,
    in  float  pwm_threshold,
    out h5     filtered_tf_bc_matrix,
    out path   filtered_tf_bc_matrix_mex,
    out gz     tf_propZ_matrix,
    out tsv    peak_annotation,
)
{
    call ANNOTATE_PEAKS(
        peaks          = self.peaks,
        reference_path = self.reference_path,
    )

    call COMPUTE_GC_DISTRIBUTION(
        peaks          = self.peaks,
        reference_path = self.reference_path,
    )

    call SCAN_MOTIFS(
        globalGCdict   = COMPUTE_GC_DISTRIBUTION.GCdict,
        peaks          = self.peaks,
        reference_path = self.reference_path,
        pwm_threshold  = self.pwm_threshold,
    )

    call GENERATE_TF_MATRIX(
        reference_path  = self.reference_path,
        peaks           = self.peaks,
        filtered_matrix = self.filtered_peak_bc_matrix,
        peak_motif_hits = SCAN_MOTIFS.peak_motif_hits,
    )

    return (
        filtered_tf_bc_matrix     = GENERATE_TF_MATRIX.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex = GENERATE_TF_MATRIX.filtered_tf_bc_matrix_mex,
        tf_propZ_matrix           = GENERATE_TF_MATRIX.tf_propZ_matrix,
        peak_annotation           = ANNOTATE_PEAKS.peak_annotation,
    )
}

#
# @include "_sc_atac_analyzer_stages.mro"
#

stage ANALYZER_PREFLIGHT(
    in  bed      peaks,
    in  h5       filtered_peak_bc_matrix,
    in  string[] factorization,
    in  int      tsne_perplexity,
    in  int      random_seed,
    in  float    tsne_theta,
    in  int      tsne_mom_switch_iter,
    in  int      tsne_stop_lying_iter,
    in  int      tsne_max_dims,
    in  int      tsne_input_pcs,
    in  int      tsne_max_iter,
    in  int      max_clusters,
    in  int      num_components,
    in  int      num_dr_bcs,
    in  int      num_dr_features,
    in  float    neighbor_a,
    in  float    neighbor_b,
    in  int      graphclust_neighbors,
    src py       "stages/preflight/atac_analyzer",
)

stage REDUCE_DIMENSIONS(
    in  h5       filtered_matrix,
    in  string[] factorization,
    in  int      num_dims,
    in  int      num_bcs,
    in  int      num_features,
    in  int      random_seed,
    out path     reduced_data,
    out map      reduction_summary,
    src py       "stages/analysis/reduce_dimensions",
) split (
    in  string   method,
) using (
    volatile = strict,
)

stage CLUSTER_CELLS(
    in  h5       filtered_matrix,
    in  path     reduced_data,
    in  map      reduction_summary,
    in  string[] factorization,
    in  int      minclusters,
    in  int      maxclusters,
    in  int      num_dims,
    in  int      random_seed,
    out path     clustered_data,
    out map      clustering_summary,
    src py       "stages/analysis/cluster_cells",
) split (
    in  int      n_clusters,
) using (
    volatile = strict,
)

stage PROJECT_TSNE(
    in  h5       filtered_matrix,
    in  path     reduced_data,
    in  map      reduction_summary,
    in  int      tsne_perplexity,
    in  int      tsne_max_dims,
    in  int      tsne_input_pcs,
    in  float    tsne_theta,
    in  int      tsne_max_iter,
    in  int      tsne_stop_lying_iter,
    in  int      tsne_mom_switch_iter,
    in  int      random_seed,
    in  string[] factorization,
    out path     tsne,
    out map      tsne_summary,
    src py       "stages/analysis/project_tsne",
) split (
    in  string   method,
    in  int      tsne_dims,
) using (
    volatile = strict,
)

stage RUN_GRAPH_CLUSTERING(
    in  h5       matrix_h5                 "Processed matrix",
    in  string[] factorization,
    in  path     reduced_data,
    in  map      reduction_summary,
    in  int      num_neighbors             "Use this many neighbors",
    in  float    neighbor_a                "Use larger of (a+b*log10(n_cells)) neighbors or num_neighbors",
    in  float    neighbor_b                "Use larger of (a+b*log10(n_cells)) neighbors or num_neighbors",
    in  int      balltree_leaf_size,
    in  string   similarity_type           "Type of similarity to use (nn or snn)",
    out h5       chunked_neighbors,
    out path     knn_clusters,
    out map      graph_clustering_summary,
    src py       "stages/analysis/run_graph_clustering",
) split (
    in  string   method,
    in  pickle   neighbor_index,
    in  h5       submatrix,
    in  int      row_start,
    in  int      total_rows,
    in  int      k_nearest,
    in  h5       use_bcs,
) using (
    volatile = strict,
)
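# To make the neighbor-count rule in the parameter descriptions above concrete:
# the number of neighbors used is the larger of
# (neighbor_a + neighbor_b * log10(n_cells)) and num_neighbors. As a hypothetical
# example, with neighbor_a = 5, neighbor_b = 5 and 10,000 cells,
# 5 + 5 * log10(10000) = 25 neighbors would be used whenever num_neighbors is smaller.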

stage COMBINE_CLUSTERING(
    in  h5   filtered_matrix,
    in  map  clustering_summary,
    in  path clustered_data,
    in  map  graph_clustering_summary,
    in  path knn_clusters,
    out path clustering,
    out map  clustering_summary,
    src py   "stages/analysis/combine_clustering",
) using (
    volatile = strict,
)

stage SUMMARIZE_ANALYSIS(
    in  tsv  peak_annotation,
    in  h5   filtered_peak_bc_matrix,
    in  h5   filtered_tf_bc_matrix,
    in  gz   tf_propZ_matrix,
    in  path reduced_data,
    in  map  reduction_summary,
    in  path clustering,
    in  map  clustering_summary,
    in  path tsne,
    in  map  tsne_summary,
    in  path enrichment_analysis,
    in  map  enrichment_analysis_summary,
    out h5   analysis,
    out path analysis_csv,
    out h5   feature_bc_matrix,
    src py   "stages/analysis/summarize_analysis",
) split (
) using (
    volatile = strict,
)

stage PERFORM_DIFFERENTIAL_ANALYSIS(
    in  bed      peaks,
    in  string   reference_path,
    in  h5       filtered_peak_bc_matrix,
    in  h5       filtered_tf_bc_matrix,
    in  string[] factorization,
    in  path     clustering,
    in  map      clustering_summary,
    out path     enrichment_analysis,
    out map      enrichment_analysis_summary,
    src py       "stages/analysis/perform_differential_analysis",
) split (
    in  string   method,
    in  string   clustering_key,
    in  int      cluster,
    out csv      tmp_diffexp,
) using (
    volatile = strict,
)

#
# @include "_sc_atac_analyzer.mro"
#

pipeline _SC_ATAC_ANALYZER(
    in  string   reference_path,
    in  bed      peaks,
    in  h5       filtered_peak_bc_matrix,
    in  string[] factorization,
    in  int      tsne_perplexity,
    in  int      random_seed,
    in  float    tsne_theta,
    in  int      tsne_mom_switch_iter,
    in  int      tsne_stop_lying_iter,
    in  int      tsne_max_dims,
    in  int      tsne_input_pcs,
    in  int      tsne_max_iter,
    in  int      max_clusters,
    in  int      num_components,
    in  int      num_dr_bcs,
    in  int      num_dr_features,
    in  float    neighbor_a,
    in  float    neighbor_b,
    in  int      graphclust_neighbors,
    out h5       analysis,
    out path     analysis_csv,
    out h5       filtered_tf_bc_matrix,
    out path     filtered_tf_bc_matrix_mex,
    out h5       feature_bc_matrix,
    out tsv      peak_annotation,
)
{
    call ANALYZER_PREFLIGHT(
        peaks                   = self.peaks,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        factorization           = self.factorization,
        tsne_perplexity         = self.tsne_perplexity,
        random_seed             = self.random_seed,
        tsne_theta              = self.tsne_theta,
        tsne_mom_switch_iter    = self.tsne_mom_switch_iter,
        tsne_stop_lying_iter    = self.tsne_stop_lying_iter,
        tsne_max_dims           = self.tsne_max_dims,
        tsne_input_pcs          = self.tsne_input_pcs,
        tsne_max_iter           = self.tsne_max_iter,
        max_clusters            = self.max_clusters,
        num_components          = self.num_components,
        num_dr_bcs              = self.num_dr_bcs,
        num_dr_features         = self.num_dr_features,
        neighbor_a              = self.neighbor_a,
        neighbor_b              = self.neighbor_b,
        graphclust_neighbors    = self.graphclust_neighbors,
    ) using (
        volatile = true,
    )

    call _PEAK_ANNOTATOR(
        reference_path          = self.reference_path,
        peaks                   = self.peaks,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        pwm_threshold           = null,
    )

    call REDUCE_DIMENSIONS(
        filtered_matrix = self.filtered_peak_bc_matrix,
        factorization   = self.factorization,
        num_dims        = self.num_components,
        num_bcs         = self.num_dr_bcs,
        num_features    = self.num_dr_features,
        random_seed     = self.random_seed,
    )

    call CLUSTER_CELLS(
        filtered_matrix   = self.filtered_peak_bc_matrix,
        reduced_data      = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary = REDUCE_DIMENSIONS.reduction_summary,
        factorization     = self.factorization,
        minclusters       = 2,
        maxclusters       = self.max_clusters,
        num_dims          = null,
        random_seed       = self.random_seed,
    )

    call PROJECT_TSNE(
        filtered_matrix      = self.filtered_peak_bc_matrix,
        reduced_data         = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary    = REDUCE_DIMENSIONS.reduction_summary,
        tsne_perplexity      = self.tsne_perplexity,
        tsne_max_dims        = self.tsne_max_dims,
        tsne_input_pcs       = self.tsne_input_pcs,
        tsne_theta           = self.tsne_theta,
        tsne_max_iter        = self.tsne_max_iter,
        tsne_stop_lying_iter = self.tsne_stop_lying_iter,
        tsne_mom_switch_iter = self.tsne_mom_switch_iter,
        random_seed          = self.random_seed,
        factorization        = self.factorization,
    )

    call RUN_GRAPH_CLUSTERING(
        matrix_h5          = self.filtered_peak_bc_matrix,
        factorization      = self.factorization,
        reduced_data       = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary  = REDUCE_DIMENSIONS.reduction_summary,
        num_neighbors      = self.graphclust_neighbors,
        neighbor_a         = self.neighbor_a,
        neighbor_b         = self.neighbor_b,
        balltree_leaf_size = null,
        similarity_type    = "nn",
    )

    call COMBINE_CLUSTERING(
        filtered_matrix          = self.filtered_peak_bc_matrix,
        clustering_summary       = CLUSTER_CELLS.clustering_summary,
        clustered_data           = CLUSTER_CELLS.clustered_data,
        graph_clustering_summary = RUN_GRAPH_CLUSTERING.graph_clustering_summary,
        knn_clusters             = RUN_GRAPH_CLUSTERING.knn_clusters,
    )

    call PERFORM_DIFFERENTIAL_ANALYSIS(
        reference_path          = self.reference_path,
        peaks                   = self.peaks,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        filtered_tf_bc_matrix   = _PEAK_ANNOTATOR.filtered_tf_bc_matrix,
        factorization           = self.factorization,
        clustering              = COMBINE_CLUSTERING.clustering,
        clustering_summary      = COMBINE_CLUSTERING.clustering_summary,
    )

    call SUMMARIZE_ANALYSIS(
        peak_annotation             = _PEAK_ANNOTATOR.peak_annotation,
        filtered_peak_bc_matrix     = self.filtered_peak_bc_matrix,
        filtered_tf_bc_matrix       = _PEAK_ANNOTATOR.filtered_tf_bc_matrix,
        tf_propZ_matrix             = _PEAK_ANNOTATOR.tf_propZ_matrix,
        reduced_data                = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary           = REDUCE_DIMENSIONS.reduction_summary,
        clustering                  = COMBINE_CLUSTERING.clustering,
        clustering_summary          = COMBINE_CLUSTERING.clustering_summary,
        tsne                        = PROJECT_TSNE.tsne,
        tsne_summary                = PROJECT_TSNE.tsne_summary,
        enrichment_analysis         = PERFORM_DIFFERENTIAL_ANALYSIS.enrichment_analysis,
        enrichment_analysis_summary = PERFORM_DIFFERENTIAL_ANALYSIS.enrichment_analysis_summary,
    )

    return (
        analysis                  = SUMMARIZE_ANALYSIS.analysis,
        analysis_csv              = SUMMARIZE_ANALYSIS.analysis_csv,
        filtered_tf_bc_matrix     = _PEAK_ANNOTATOR.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex = _PEAK_ANNOTATOR.filtered_tf_bc_matrix_mex,
        feature_bc_matrix         = SUMMARIZE_ANALYSIS.feature_bc_matrix,
        peak_annotation           = _PEAK_ANNOTATOR.peak_annotation,
    )
}

#
# @include "_sc_atac_reporter_stages.mro"
#

stage SUMMARIZE_REPORTS_SINGLECELL(
    in  string reference_path,
    in  json   complexity_summary,
    in  json   cell_calling_summary,
    in  json   peak_results,
    in  json   basic_results,
    in  json   error_results_summary,
    in  json   insert_summary,
    in  json   singlecell_results,
    in  json   contam_results,
    in  json   downsample_info,
    in  json   enrichment_results,
    out json   analysis_params,
    out json   summary,
    out csv    summary_csv,
    src py     "stages/reporter/summarize_reports_singlecell",
) using (
    mem_gb = 4,
)

stage CREATE_WEBSUMMARY(
    in  string reference_path,
    in  string barcode_whitelist,
    in  json   summary_results,
    in  json   bulk_complexity,
    in  json   singlecell_complexity,
    in  string sample_id,
    in  string sample_desc,
    in  map[]  sample_def,
    in  bool   debug,
    in  csv    singlecell,
    in  csv    insert_sizes,
    in  csv    tss_relpos,
    in  csv    ctcf_relpos,
    in  h5     filtered_peak_bc_matrix,
    in  h5     analysis,
    in  json   excluded_barcodes,
    out html   web_summary,
    src py     "stages/reporter/create_websummary",
) using (
    mem_gb = 16,
)

#
# @include "_sc_atac_reporter.mro"
#

pipeline _SC_ATAC_REPORTER(
    in  string reference_path,
    in  string barcode_whitelist,
    in  json   bulk_complexity,
    in  json   cell_calling_summary,
    in  json   complexity_summary,
    in  json   basic_summary,
    in  json   peak_summary,
    in  json   singlecell_results,
    in  json   insert_summary,
    in  json   downsample_info,
    in  json   singlecell_complexity,
    in  csv    singlecell,
    in  csv    tss_relpos,
    in  csv    ctcf_relpos,
    in  string sample_id,
    in  string sample_desc,
    in  map[]  sample_def,
    in  csv    sc_insert_sizes,
    in  json   enrichment_results,
    in  h5     filtered_peak_bc_matrix,
    in  h5     analysis,
    in  json   excluded_barcodes,
    #
    out json   summary,
    out html   web_summary,
    out csv    summary_csv,
)
{
    call SUMMARIZE_REPORTS_SINGLECELL(
        reference_path        = self.reference_path,
        complexity_summary    = self.complexity_summary,
        cell_calling_summary  = self.cell_calling_summary,
        peak_results          = self.peak_summary,
        basic_results         = self.basic_summary,
        error_results_summary = null,
        insert_summary        = self.insert_summary,
        singlecell_results    = self.singlecell_results,
        contam_results        = null,
        downsample_info       = self.downsample_info,
        enrichment_results    = self.enrichment_results,
    )

    call CREATE_WEBSUMMARY(
        reference_path          = self.reference_path,
        barcode_whitelist       = self.barcode_whitelist,
        singlecell              = self.singlecell,
        tss_relpos              = self.tss_relpos,
        ctcf_relpos             = self.ctcf_relpos,
        sample_id               = self.sample_id,
        sample_desc             = self.sample_desc,
        sample_def              = self.sample_def,
        insert_sizes            = self.sc_insert_sizes,
        summary_results         = SUMMARIZE_REPORTS_SINGLECELL.summary,
        bulk_complexity         = self.bulk_complexity,
        singlecell_complexity   = self.singlecell_complexity,
        analysis                = self.analysis,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        excluded_barcodes       = self.excluded_barcodes,
        debug                   = false,
    )

    return (
        summary     = SUMMARIZE_REPORTS_SINGLECELL.summary,
        web_summary = CREATE_WEBSUMMARY.web_summary,
        summary_csv = SUMMARIZE_REPORTS_SINGLECELL.summary_csv,
    )
}

#
# @include "_atac_cloupe_stages.mro"
#

stage CLOUPE_PREPROCESS(
    in  string     pipestance_type,
    in  string     sample_id,
    in  string     sample_desc,
    in  string     reference_path,
    in  h5         analysis,
    in  h5         feature_barcode_matrix,
    in  bed        peaks,
    in  tsv.gz.tbi fragments_index,
    in  json       metrics_json,
    in  csv        aggregation_csv,
    in  json       gem_group_index_json,
    in  bool       no_secondary_analysis,
    out cloupe     output_for_cloupe,
    out json       gem_group_index_json,
    src py         "stages/cloupe/atac_cloupe_preprocess",
) split (
)

#
# @include "_preflight_stages.mro"
#

stage ATAC_COUNTER_PREFLIGHT(
    in  string   sample_id,
    in  string   fastq_mode,
    in  map[]    sample_def,
    in  string   reference_path,
    in  map      force_cells,
    in  string[] factorization,
    in  map      downsample,
    in  bool     check_executables,
    in  map      trim_def,
    src py       "stages/preflight/atac_counter",
) split (
)

stage ATAC_AGGR_PREFLIGHT(
    in  string   sample_id,
    in  string   reference_path,
    in  csv      aggr_csv,
    in  string   normalization,
    in  string[] factorization,
    in  bool     check_executables,
    src py       "stages/preflight/atac_aggr",
) split (
)

stage ATAC_REANALYZER_PREFLIGHT(
    in  string     sample_id,
    in  string     reference_path,
    in  string     barcode_whitelist,
    in  bed        peaks,
    in  csv        parameters,
    in  map        force_cells,
    in  csv        cell_barcodes,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  csv        aggregation_csv,
    in  bool       check_executables,
    src py         "stages/preflight/atac_reanalyzer",
) split (
)

#
# @include "sc_atac_counter.mro"
#

pipeline SC_ATAC_COUNTER(
    in  string     fastq_mode,
    in  string     sample_id,
    in  map[]      sample_def,
    in  map        downsample,
    in  string     sample_desc,
    in  string     reference_path,
    in  map        trim_def,
    in  string     barcode_whitelist,
    in  map        adapters,
    in  string[]   factorization,
    in  map        force_cells,
    #
    out csv        singlecell,
    out bam        possorted_bam,
    out bam.bai    possorted_bam_index,
    out json       summary,
    out html       web_summary,
    out bed        peaks,
    out h5         raw_peak_bc_matrix,
    out path       raw_peak_bc_matrix_mex,
    out path       analysis_csv,
    out h5         filtered_peak_bc_matrix,
    out path       filtered_peak_bc_matrix_mex,
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    out h5         filtered_tf_bc_matrix,
    out path       filtered_tf_bc_matrix_mex,
    out cloupe     cloupe,
    out csv        summary_csv,
    out tsv        peak_annotation,
)
{
    call ATAC_COUNTER_PREFLIGHT as ATAC_COUNTER_PREFLIGHT_LOCAL(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        reference_path    = self.reference_path,
        force_cells       = self.force_cells,
        factorization     = self.factorization,
        downsample        = self.downsample,
        trim_def          = self.trim_def,
        check_executables = false,
    ) using (
        local     = true,
        preflight = true,
    )

    call ATAC_COUNTER_PREFLIGHT(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        reference_path    = self.reference_path,
        force_cells       = self.force_cells,
        factorization     = self.factorization,
        downsample        = self.downsample,
        trim_def          = self.trim_def,
        check_executables = true,
    ) using (
        preflight = true,
    )

    call _BASIC_SC_ATAC_COUNTER(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        trim_def          = self.trim_def,
        adapters          = self.adapters,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        downsample        = self.downsample,
        force_cells       = self.force_cells,
    )

    call _SC_ATAC_METRIC_COLLECTOR(
        read_paired_bam    = _BASIC_SC_ATAC_COUNTER.read_paired_bam,
        fragments          = _BASIC_SC_ATAC_COUNTER.fragments,
        fragments_index    = _BASIC_SC_ATAC_COUNTER.fragments_index,
        peaks              = _BASIC_SC_ATAC_COUNTER.peaks,
        reference_path     = self.reference_path,
        cell_barcodes      = _BASIC_SC_ATAC_COUNTER.cell_barcodes,
        singlecell_cells   = _BASIC_SC_ATAC_COUNTER.singlecell_cells,
        singlecell_mapping = _BASIC_SC_ATAC_COUNTER.singlecell_mapping,
    )

    call _SC_ATAC_ANALYZER(
        peaks                   = _BASIC_SC_ATAC_COUNTER.peaks,
        filtered_peak_bc_matrix = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        reference_path          = self.reference_path,
        factorization           = self.factorization,
        tsne_perplexity         = 30,
        tsne_max_dims           = null,
        tsne_input_pcs          = null,
        tsne_max_iter           = null,
        tsne_stop_lying_iter    = null,
        tsne_mom_switch_iter    = null,
        tsne_theta              = null,
        random_seed             = null,
        max_clusters            = 10,
        neighbor_a              = null,
        neighbor_b              = null,
        graphclust_neighbors    = null,
        num_components          = 15,
        num_dr_bcs              = null,
        num_dr_features         = null,
    )

    call CLOUPE_PREPROCESS(
        pipestance_type        = "SC_ATAC_COUNTER_CS",
        reference_path         = self.reference_path,
        sample_id              = self.sample_id,
        sample_desc            = self.sample_desc,
        analysis               = _SC_ATAC_ANALYZER.analysis,
        feature_barcode_matrix = _SC_ATAC_ANALYZER.feature_bc_matrix,
        metrics_json           = _SC_ATAC_METRIC_COLLECTOR.basic_summary,
        peaks                  = _BASIC_SC_ATAC_COUNTER.peaks,
        fragments_index        = _BASIC_SC_ATAC_COUNTER.fragments_index,
        aggregation_csv        = null,
        gem_group_index_json   = null,
        no_secondary_analysis  = false,
    )

    call _SC_ATAC_REPORTER(
        reference_path          = self.reference_path,
        barcode_whitelist       = self.barcode_whitelist,
        bulk_complexity         = _SC_ATAC_METRIC_COLLECTOR.bulk_complexity,
        singlecell_complexity   = _SC_ATAC_METRIC_COLLECTOR.singlecell_complexity,
        cell_calling_summary    = _BASIC_SC_ATAC_COUNTER.cell_calling_summary,
        complexity_summary      = _SC_ATAC_METRIC_COLLECTOR.complexity_summary,
        basic_summary           = _SC_ATAC_METRIC_COLLECTOR.basic_summary,
        peak_summary            = _BASIC_SC_ATAC_COUNTER.peak_metrics,
        singlecell_results      = _SC_ATAC_METRIC_COLLECTOR.singlecell_results,
        insert_summary          = _SC_ATAC_METRIC_COLLECTOR.insert_summary,
        downsample_info         = _BASIC_SC_ATAC_COUNTER.downsample_info,
        singlecell              = _SC_ATAC_METRIC_COLLECTOR.singlecell,
        tss_relpos              = _SC_ATAC_METRIC_COLLECTOR.tss_relpos,
        ctcf_relpos             = _SC_ATAC_METRIC_COLLECTOR.ctcf_relpos,
        sample_id               = self.sample_id,
        sample_desc             = self.sample_desc,
        sample_def              = self.sample_def,
        sc_insert_sizes         = _SC_ATAC_METRIC_COLLECTOR.insert_sizes,
        enrichment_results      = _SC_ATAC_METRIC_COLLECTOR.enrichment_results,
        filtered_peak_bc_matrix = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        analysis                = _SC_ATAC_ANALYZER.analysis,
        excluded_barcodes       = _BASIC_SC_ATAC_COUNTER.excluded_barcodes,
    )

    return (
        singlecell                  = _SC_ATAC_METRIC_COLLECTOR.singlecell,
        possorted_bam               = _BASIC_SC_ATAC_COUNTER.possorted_bam,
        possorted_bam_index         = _BASIC_SC_ATAC_COUNTER.possorted_bam_index,
        summary                     = _SC_ATAC_REPORTER.summary,
        web_summary                 = _SC_ATAC_REPORTER.web_summary,
        peaks                       = _BASIC_SC_ATAC_COUNTER.peaks,
        raw_peak_bc_matrix          = _BASIC_SC_ATAC_COUNTER.raw_peak_bc_matrix,
        raw_peak_bc_matrix_mex      = _BASIC_SC_ATAC_COUNTER.raw_peak_bc_matrix_mex,
        analysis_csv                = _SC_ATAC_ANALYZER.analysis_csv,
        filtered_peak_bc_matrix     = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        filtered_peak_bc_matrix_mex = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix_mex,
        fragments                   = _BASIC_SC_ATAC_COUNTER.fragments,
        fragments_index             = _BASIC_SC_ATAC_COUNTER.fragments_index,
        filtered_tf_bc_matrix       = _SC_ATAC_ANALYZER.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex   = _SC_ATAC_ANALYZER.filtered_tf_bc_matrix_mex,
        cloupe                      = CLOUPE_PREPROCESS.output_for_cloupe,
        summary_csv                 = _SC_ATAC_REPORTER.summary_csv,
        peak_annotation             = _SC_ATAC_ANALYZER.peak_annotation,
    )
}

#
# @include "sc_atac_counter_cs.mro"
#

# Customer-facing (CS) pipeline
pipeline SC_ATAC_COUNTER_CS(
    in  string     fastq_mode                   "Input fastq configuration",
    in  string     sample_id,
    in  map[]      sample_def,
    in  map        downsample,
    in  string     sample_desc                  "Sample description",
    in  string     reference_path               "Path to 10X reference package",
    in  string[]   factorization                "Dimensionality reduction method (lsa, plsa, or pca)",
    in  map        force_cells                  "Force cell calling to a fixed number",
    #
    out csv        singlecell                   "Per-barcode fragment counts & metrics",
    out bam        possorted_bam                "Position sorted BAM file"  "possorted_bam.bam",
    out bam.bai    possorted_bam_index          "Position sorted BAM index"  "possorted_bam.bam.bai",
    out json       summary                      "Summary of all data metrics",
    out html       web_summary                  "HTML file summarizing data & analysis",
    out bed        peaks                        "Bed file of all called peak locations",
    out h5         raw_peak_bc_matrix           "Raw peak barcode matrix in hdf5 format",
    out path       raw_peak_bc_matrix_mex       "Raw peak barcode matrix in mex format"  "raw_peak_bc_matrix",
    out path       analysis_csv                 "Directory of analysis files"  "analysis",
    out h5         filtered_peak_bc_matrix      "Filtered peak barcode matrix in hdf5 format",
    out path       filtered_peak_bc_matrix_mex  "Filtered peak barcode matrix in mex format"  "filtered_peak_bc_matrix",
    out tsv.gz     fragments                    "Barcoded and aligned fragment file"  "fragments.tsv.gz",
    out tsv.gz.tbi fragments_index              "Fragment file index"       "fragments.tsv.gz.tbi",
    out h5         filtered_tf_bc_matrix        "Filtered tf barcode matrix in hdf5 format",
    out path       filtered_tf_bc_matrix_mex    "Filtered tf barcode matrix in mex format"  "filtered_tf_bc_matrix",
    out cloupe     cloupe                       "Loupe Cell Browser input file",
    out csv        summary_csv                  "csv summarizing important metrics and values"  "summary.csv",
    out tsv        peak_annotation              "Annotation of peaks with genes",
)
{
    call SC_ATAC_COUNTER(
        fastq_mode        = self.fastq_mode,
        sample_id         = self.sample_id,
        sample_def        = self.sample_def,
        downsample        = self.downsample,
        sample_desc       = self.sample_desc,
        reference_path    = self.reference_path,
        trim_def          = {
            "R1": {
                "3prime": ["MErc"],
            },
            "R2": {
                "3prime": ["MErc"],
            },
            "discard_untrimmed": false,
        },
        barcode_whitelist = "737K-cratac-v1",
        adapters          = {
            "ME": "AGATGTGTATAAGAGACAG",
            "MErc": "CTGTCTCTTATACACATCT",
        },
        factorization     = self.factorization,
        force_cells       = self.force_cells,
    )

    return (
        singlecell                  = SC_ATAC_COUNTER.singlecell,
        possorted_bam               = SC_ATAC_COUNTER.possorted_bam,
        possorted_bam_index         = SC_ATAC_COUNTER.possorted_bam_index,
        summary                     = SC_ATAC_COUNTER.summary,
        web_summary                 = SC_ATAC_COUNTER.web_summary,
        peaks                       = SC_ATAC_COUNTER.peaks,
        raw_peak_bc_matrix          = SC_ATAC_COUNTER.raw_peak_bc_matrix,
        raw_peak_bc_matrix_mex      = SC_ATAC_COUNTER.raw_peak_bc_matrix_mex,
        analysis_csv                = SC_ATAC_COUNTER.analysis_csv,
        filtered_peak_bc_matrix     = SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        filtered_peak_bc_matrix_mex = SC_ATAC_COUNTER.filtered_peak_bc_matrix_mex,
        fragments                   = SC_ATAC_COUNTER.fragments,
        fragments_index             = SC_ATAC_COUNTER.fragments_index,
        filtered_tf_bc_matrix       = SC_ATAC_COUNTER.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex   = SC_ATAC_COUNTER.filtered_tf_bc_matrix_mex,
        cloupe                      = SC_ATAC_COUNTER.cloupe,
        summary_csv                 = SC_ATAC_COUNTER.summary_csv,
        peak_annotation             = SC_ATAC_COUNTER.peak_annotation,
    )
}

#
# @include "__PBMC_J.mro"
#

call SC_ATAC_COUNTER_CS(
    fastq_mode     = "ILMN_BCL2FASTQ",
    sample_id      = "PBMC_J",
    sample_def     = [{
        "bc_in_read": 1,
        "bc_length": 16,
        "gem_group": null,
        "lanes": null,
        "library": "LibraryNotSpecified",
        "read_path": "/data/isshamie/dropbox/ATACseq/2020_11_18_Croker/igm-storage2.ucsd.edu/201113_A00953_0185_AHN7TMDSXY",
        "sample_indices": ["any"],
        "sample_names": ["BC_10xATAC_PMBC_J"],
    }],
    reference_path = "/data/isshamie/mito_lineage/data/external/GRCh38_MT_blacklist",
    downsample     = null,
    sample_desc    = "",
    factorization  = ["lsa"],
    force_cells    = null,
)
new file mode 100644
index 0000000..89e54b3
+++ b/Analysis/mtscATAC/2020_11_18_Croker/PBMC_J/_sitecheck

bff70f2a6b3eb94f02e1f1143a2420fe512415e5

account for mt_pos being a list not an int

account for mt_pos being a list not an int

# TODO account for mt_pos being a list not an int

        return

    def sim_performance_dominant(self, group='both'):
        """
        Will average metrics over simulations.
        :param group: {'init', 'growth', 'both'} Indicates which population to compute metrics on (initial cells, post-growth cells, or both).
        :return:
        """
        dropout = []
        rocs = []
        prec_scores = []

        for iter, s in enumerate(self.sim.values):
            # First get the dominant clone, which is indexed as 1
            mt_pos = s.clone_mt_dict[1]
            # TODO account for mt_pos being a list not an int
            if group == 'init':
                clones = s.clone_cell
                cell_af = s.cell_af.loc[:,mt_pos]
            elif group == 'growth':
                clones = s.new_clone_cell
                cell_af = s.new_cell_af.loc[:,mt_pos]
            elif group == 'both':
                #clones = pd.concat((s.clone_cell, s.subsample_new_clone_cell)).reset_index(drop=True)
                #cell_af = pd.concat((s.cell_af.loc[:,mt_pos], s.subsample_new_cell_af.loc[:,mt_pos])).reset_index(drop=True)
                clones = s.combined_clones
                cell_af = s.combined_cell_af.loc[:,mt_pos]
            else:
                raise ValueError('group variable not properly set.')

            y_true = clones.values.copy()
            y_true[y_true != 1] = 0  # Set nondominant clones to 0
            rocs.append(roc_curve(y_true, cell_af))
            prec_scores.append(average_precision_score(y_true, cell_af))
            dropout.append((cell_af[clones==1]==0).sum()/cell_af.shape[0])



        self.dropout = dropout
        self.prec_scores = prec_scores
        self.rocs = rocs
        return


    def reduce_cells(self, cell_af):
        #self.sim
        return


    def cluster_before_after(self):
        cluster_results = []
        print('clustering')
        for s in tqdm(self.sim.values):
            cluster_results.append(s.cluster(s.combined_cell_af))
            print(len(cluster_results[-1]))
        self.cluster_results = cluster_results
        return




    def save(self, f_save=None):
        if f_save is None:
            f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')

17428fea16837769d2c48ca3576b25d92b6e968c

Add the MT clone map so it can contain multiple mutants in lineages

Add the MT clone map so it can contain multiple mutants in lineages

If there is a heteroplasmy table in params, it is a list of mutant heteroplasmy AFs.

If not, will randomly draw based on number of clones

# TODO Add the MT clone map so it can contain multiple mutants in lineages
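For orientation, here is a minimal illustration of the parameter fields that the initialization code below reads; all values are made up, and once this TODO lands the mt_clone_map entries could map a clone to a list of MT positions rather than a single int.

params = {
    "prefix": "example",
    "local_outdir": "out",
    "num_cells": 1000,
    "num_mt_positions": 50,
    "het": [0.3, 0.1],          # heteroplasmy AF of each mutant clone
    "het_err_rate": 0.01,       # background AF in non-clone cells
    "initialize": {
        "clone_sizes": [0.8, 0.15, 0.05],  # non-clone fraction first, then each clone
        "coverage": {"type": "constant", "cov_constant": 20},
        "mt_clone_map": {1: 0, 2: 1},      # optional: clone -> MT position
    },
    "growth": {
        "type": "binomial",
        "binomial": {"time_steps": 5, "rates": [0.1, 0.5, 0.3]},
        "mutant_af_sigma_noise": 0.01,
    },
}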

import numpy as np
from numpy import random
import os
import pandas as pd
import pickle
from src.simulations.utils.config import read_config_file, write_config_file
from src.simulations.utils.config import check_required


class Simulation:
    """Lineage tracing simulation of one sample

    Will initialize cells based on their parameters and grow as well. This
    should be a flexible framework, to add different ways to initialize, grow,
    and metrics to have. Additionally can cluster these results.

    :ivar params: simulation parameters read from the config file
    :type params: dict
    """

    def __init__(self, params_f):
        """
        :param params_f: Parameter yaml file for the specifications
        :type params_f: yaml file or dict
        """
        if isinstance(params_f, str):
            params = read_config_file(params_f)
        else:
            params = params_f

        self.params = params
        check_required(params, ['initialize', 'num_cells', 'num_mt_positions', 'prefix'])
        self.prefix = params['prefix']
        self.num_mt_positions = params['num_mt_positions']
        self.num_cells = params['num_cells']
        if not os.path.exists(params['local_outdir']):
            os.mkdir(params['local_outdir'])


    def initialize(self):
        """ (1) Pre-growth cell population is instantiated.

        Creates a clone-MT dictionary, cell coverage matrix
        (or an int, depending on parameters), and cell-AF matrix.
        :return:
        """
        self.init_clone_dict()
        self.init_cell_coverage()
        self.init_cell_af()
        #self.init_clone_mt()

    #should be external method
    def grow(self):
        """ (2) Growth of cells is run."""
        p = self.params
        type = p["growth"]["type"]
        if  type == "poisson":
            self.grow_poisson(p['growth']['poisson'])
        elif type == "binomial":
            self.grow_binomial(p['growth']['binomial'])
        return

    # Static Method
    @staticmethod
    def clone_counts_to_cell_series(clone_counts):
        """ Generates new cell IDs based on cluster count iterable
        :param clone_counts: Each i'th element is the number of cells in
        cluster i.
        :type clone_counts: iterable
        :return: each index name is a cell ID and each value is the cluster
        the cell belongs to.
        :rtype pd.Series
        """
        clone_counts = np.array(clone_counts)
        num_cells = clone_counts.sum()
        clone_cell = -1 * np.ones(shape=[num_cells, ])

        clone_cell[:clone_counts[0]] = 0
        for ind, val in enumerate(clone_counts[1:]):
            start = clone_counts[:ind + 1].sum()
            end = clone_counts[:ind + 1].sum() + val
            # starts at sum(clone_counts[i-1]) ends at clone_counts[i].sum()
            clone_cell[start:end] = ind + 1

        clone_cell = pd.Series(clone_cell, dtype=int)
        return clone_cell
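    # Illustrative example: clone_counts_to_cell_series([2, 3]) returns
    # pd.Series([0, 0, 1, 1, 1]), i.e. two non-clone cells (label 0) followed
    # by three cells of clone 1.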


    def init_clone_dict(self):
        """1A
        """

        ### Add in potential to overwrite the values
        # Gets the clone dictionary. Should also have clone to mt dict.
        clones = self.params['initialize']['clone_sizes']

        if 'num_cells_population' not in self.params:
            self.num_cells_pop = self.num_cells
        else:
            self.num_cells_pop = self.params['num_cells_population']

        num_cells = self.num_cells_pop

        # Option 1: List of fraction of size of each clone. 0s are nonclone size, listed first
        if type(clones) == list:
            #clone_cell = pd.Series(index=range(num_cells))
            clone_counts = np.random.multinomial(num_cells, clones)
            clone_cell  = self.clone_counts_to_cell_series(clone_counts)
            self.clone_cell_pop = clone_cell

            # Choose subset to be sampled
            self.clone_cell = clone_cell.sample(n=self.num_cells).sort_values()
        # Option 2: 1 clone. ID'd as 1
        elif type(clones) == int:  # One number gives the size of the dominant clone; the rest are non-clone.
            clone_cell = np.zeros(shape=[num_cells, ])
            # The first `clones` cells are the dominant clone (label 1); reversing puts them last
            clone_cell[:clones] = 1
            clone_cell = clone_cell[::-1]
            clone_cell = pd.Series(clone_cell, dtype=int)
            self.clone_cell = clone_cell
            # Also track the full population, as in the list branch (used by grow_binomial)
            self.clone_cell_pop = clone_cell

        # Option 3 To ADD, beta binomial and more complex distributions

        self.num_clones =  len(set(clone_cell.values))-1 # Remove the non-clone
        return clone_cell


    def init_cell_coverage(self):
        """1B

        There are different modes to the coverage, either a constant or
        through a distribution.
        """
        p = self.params['initialize']['coverage']
        type = p['type']

        num_cells = self.num_cells
        num_pos = self.num_mt_positions
        c = np.zeros([num_cells, num_pos])

        if type == 'constant':
            c[:, :] = p['cov_constant']
        elif type == "poisson":
            # Get the number of coverage per cell based on poisson (should be reads)
            mu_cov_per_cell = p['mu_cov_per_cell']
            num_reads_per_cell = random.poisson(lam=mu_cov_per_cell,
                                                size=num_cells)

            # Number of reads at each position, based on the average for each cell
            for i in range(num_cells):
                c[i, :] = random.poisson(num_reads_per_cell[i],
                                         size=num_pos)
        self.cells_mt_coverage = c
        return c


    @staticmethod
    def create_cell_af(clone_df, mt_dict, n_cells, n_mt, num_clones,
                       cov_params, hets, het_err, coverage=None):
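        """Builds the cell-by-MT-position allele-frequency (AF) DataFrame.

        For every clone, binomial alt-allele counts are drawn at that clone's
        MT position(s): cells in the clone use the clone heteroplasmy (hets),
        all other cells use the background error rate (het_err), and counts
        are normalized by coverage to give an AF.
        """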
        cell_af = pd.DataFrame(np.zeros(shape=[n_cells, n_mt]))

        #p = self.params['initialize']

        ## Loop through each clone,
        ## Generate the AF for the clone and non-clones using coverage for each cell
        ## Fill in cell_by_af for that position.
        for ind in range(1, num_clones + 1):
            # Generate AF: (clone_df ==  ind).sum()
            n_dom_cells = (clone_df == ind).sum()
            het = hets[ind - 1]

            curr_mt = mt_dict[ind]

            if cov_params['coverage']['type'] == 'constant':
                c = cov_params['coverage']['cov_constant']

                af_i = random.binomial(c, het, n_dom_cells) / c
                af_j = random.binomial(c, het_err, n_cells - n_dom_cells) / c

                # Update the dom_cells and non_dom for the current MT
                cell_af.loc[
                    np.flatnonzero(clone_df == ind), curr_mt] = af_i
                cell_af.loc[
                    np.flatnonzero(clone_df != ind), curr_mt] = af_j

            # Each cell and position has it's own coverage value, so need to update each
            else:
                if coverage is None:
                    raise("coverage needs to be assigned")
                c = coverage

                # Get the cells coverage for the mt position
                curr_mt_cov = c[:, curr_mt]

                # Get cell indicies for the clones and nonclones
                curr_clone_inds = np.flatnonzero(clone_df == ind)
                curr_nonclone_inds = np.flatnonzero(clone_df != ind)
                for cell in curr_clone_inds:
                    # Draw alt-allele counts for this cell/position and normalize by
                    # its coverage so the value is an allele frequency (as in the
                    # constant-coverage branch above)
                    cov = max(int(curr_mt_cov[cell]), 1)
                    cell_af.loc[cell, curr_mt] = random.binomial(cov, het) / cov
                for cell in curr_nonclone_inds:
                    cov = max(int(curr_mt_cov[cell]), 1)
                    cell_af.loc[cell, curr_mt] = random.binomial(cov, het_err) / cov
        return cell_af


    ##########
    def init_cell_af(self):
        """1C. Initialize the cell-by-mtPos af dataframe. Unless a clone:mt dict was
        provided, the first N MT positions will be the clone AFs. Creates
        self.clone_mt_dict and self.cell_af
        """
        p = self.params['initialize']
        hets = self.params['het']
        clone_df = self.clone_cell
        num_clones = self.num_clones
        n_cells = self.num_cells
        n_mt = self.num_mt_positions

        # Get the MT map
        if 'mt_clone_map' in p and p['mt_clone_map'] is not None:
            self.clone_mt_dict = p['mt_clone_map']
        else:
            # Each clone points to a mt position
            self.clone_mt_dict = dict()
            for i in range(1,num_clones+1):
                self.clone_mt_dict[i] = i

        # TODO Add the MT clone map so it can contain multiple mutants in lineages
        # If there is a heteroplasmy table in params, it is a list of mutant heteroplasmy AFs.
        # If not, will randomly draw based on number of clones
        if type(hets) == list:
            assert(len(hets) == num_clones)

        # Get the cell_af based on MT dictionary and cell coverage
        self.cell_af = self.create_cell_af(clone_df, self.clone_mt_dict,
                                           n_cells, n_mt, num_clones,
                                           self.params['initialize'],
                                           hets,
                                           self.params['het_err_rate'],
                                           coverage=None)
        return


    def init_clone_mt(self):
        p = self.params
        if p["initialize"]['type'] == 'growth':
            ## TODO
            # Create a phylogeny and then get the averages of the mutants
            self.average_clone_mt()
        # If not growth, it should already be there.
        return

    def average_clone_mt(self):
        return

    @staticmethod
    def extract_clone_cells(clone_cell, clone_id):
        """ Returns the numbered indices of the specific clones

        :param clone_cell: Each element is the indexed cell's clone label.
        :type clone_cell: np array or pd.Series

        :param clone_id:
        :type clone_id: int or string
        """
        ids = np.flatnonzero(clone_cell == clone_id)
        return ids

    @staticmethod
    def simulate_expand_cells_af(af, growth_inds, sigma):
        """Given a cell-by-af vector, expand the AF.

        Expanded AF occurs by duplicating cells that grew based on the
        growth_inds vector. It will add standard error to each af based on sigma
        :param af: :param growth: Indices of AF to copy :param sigma: Variance
        to add to AF of child. :return:

        Args:
            af:
            growth_inds:
            sigma:
        """

        new_af = af.iloc[growth_inds].copy() + random.normal(0, sigma, size=af.iloc[growth_inds].shape)
        new_af.index = np.arange(af.index[-1]+1, af.index[-1]+1+new_af.shape[0])
        new_af = pd.concat((af,new_af), axis=0)
        #new_af = np.append(af, np.concatenate(new_af))
        return new_af

    def grow_binomial(self, p):
        """ (2.1.2)
        :param p: contains time_steps and rates
        :type p: dict
        """
        timesteps = p["time_steps"]
        rates = p["rates"]

        num_clones = self.num_clones+1
        new_dict = {}
        for curr_clone in range(num_clones):
            curr_rate = rates[curr_clone]
            ids = self.extract_clone_cells(self.clone_cell_pop, curr_clone)
            num_curr_cells = len(ids)

            for i in range(timesteps):
                # Simulate growth for each clone separately:
                # each cell divides with probability curr_rate this timestep
                num_new_cells = random.binomial(1, curr_rate, size=num_curr_cells).sum()
                num_curr_cells += num_new_cells

            new_dict[curr_clone] = num_curr_cells

        ####TODO
        ## new_lineage_mutants chances. This will see if a mutation will change
        ####TODO
        ## Add death + stimulation rate as well as growth
        # Save the new cell clones df and cell af
        clone_counts = [i for i in new_dict.values()]
        self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)
        # Do not make cell_af, will make this only when subsampled.


        # self.new_cell_af = pd.DataFrame()
        # for clone in range(1, self.num_clones+1):
        #     self.new_cell_af = pd.concat((self.new_cell_af, new_dict[clone]),axis=0).reset_index(drop=True)
        return


    def grow_binomial_old(self, p):
        """ (2.1.1)
        :param p: contains time_steps, rates,
                 and [growth][mutant_af_sigma_noise
        :type dict
        """
        timesteps = p["time_steps"]
        rates = p["rates"]

        sigma = self.params['growth']["mutant_af_sigma_noise"]
        cell_af = self.cell_af

        num_clones = self.num_clones+1
        new_dict = {}
        for curr_clone in range(num_clones):
            curr_rate = rates[curr_clone]
            ids = self.extract_clone_cells(self.clone_cell, curr_clone)
            new_cells = cell_af.iloc[ids].copy()
            for i in range(timesteps):
                # Simulate growth for each clone separately.
                growth_inds = np.flatnonzero(random.binomial(1, curr_rate, size=new_cells.shape[0]))
                #new_ids =
                new_cells = self.simulate_expand_cells_af(new_cells, growth_inds, sigma)

            new_dict[curr_clone] = new_cells
            # Create list of cells
        ####TODO
        ## new_lineage_mutants chances. This will see if a mutation will change
        ####TODO
        ## Add death + stimulation rate as well as growth
        # Save the new cell clones df and cell af
        clone_counts = [i.shape[0] for i in new_dict.values()]
        self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)
        self.new_cell_af = pd.DataFrame(new_dict[0])
        for clone in range(1, self.num_clones+1):
            self.new_cell_af = pd.concat((self.new_cell_af, new_dict[clone]),axis=0).reset_index(drop=True)
        return


    def grow_poisson(self, p):
        # TODO growth of poisson refactor

        # Sample from poisson the growth
        grow_clones = random.poisson(lam=p['clone_growth'],
                                     size=(p['clone_meta']))
        grow_nonclones = random.poisson(lam=p['non_clone_growth'],
                                        size=(p['clone_meta']))
        return


    def subsample_new(self, to_delete=False):
        """(3) Subsample from new cell population and generate cell_af

        :param to_delete: To remove the cells that grew (which takes up
        a lot of RAM).
        :type to_delete: bool
        """
        p = self.params
        if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
            self.subsample_new_clone_cell = self.new_clone_cell.sample(n=self.params['sequence_subsample'])
        else:
            self.subsample_new_clone_cell = self.new_clone_cell.sample(
                n=self.num_cells)

        #print(f'New cell af, {len(self.subsample_new_clone_cell)}')
        # Generate subsample_new_cell_af
        self.subsample_new_cell_af = self.create_cell_af(clone_df=self.subsample_new_clone_cell,
                                                         mt_dict = self.clone_mt_dict,
                                                         n_cells=len(self.subsample_new_clone_cell),
                                                         n_mt=self.num_mt_positions,
                                                         num_clones=self.num_clones,
                                                         cov_params=p['initialize'],
                                                         hets=
                                                             self.params[
                                                                 'het'],
                                                         het_err=self.params['het_err_rate'],
                                                         coverage=None
                                                         )

        if to_delete:
            self.new_cell_af = None
            self.new_clone_cell = None


    def subsample_new_old(self, to_delete=False):
        """(3) Subsample from new cell population

        :param to_delete: To remove the cells that grew (which takes up
        a lot of RAM).
        :type to_delete: bool
        """
        new_cell_af = self.new_cell_af
        p = self.params
        if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.params['sequence_subsample'])
        else:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.num_cells)

        self.subsample_new_clone_cell = self.new_clone_cell.loc[
            self.subsample_new_cell_af.index]

        if to_delete:
            self.new_cell_af = None
            self.new_clone_cell = None


    def combine_init_growth(self):
        """(4) Add the pre- and post- population of cells into a group.

        :return:
        """
        combined_cell_af = pd.concat((self.cell_af, self.subsample_new_cell_af)).reset_index(drop=True)
        combined_clones = pd.concat(
            (self.clone_cell, self.subsample_new_clone_cell)).reset_index(
            drop=True)

        combined_befaft = np.concatenate((np.zeros(shape=[self.cell_af.shape[0],]), np.ones(shape=[self.subsample_new_cell_af.shape[0]])))
        combined_meta = pd.DataFrame({"pre_post": combined_befaft, "clone": combined_clones})
        #combined_meta = pd.Series(combined_meta, name='After Growth', dtype=int)
        assert(combined_meta.shape[0] == self.cell_af.shape[0]+self.subsample_new_cell_af.shape[0])
        assert (combined_cell_af.shape[0] == self.cell_af.shape[0] +
                self.subsample_new_cell_af.shape[0])
        assert(combined_meta.shape[0] == combined_clones.shape[0])
        assert(combined_cell_af.shape[0] == combined_clones.shape[0])
        self.combined_meta = combined_meta
        self.combined_clones = combined_clones
        self.combined_cell_af = combined_cell_af
        return

    def save(self, f_save=None):
        """
        Args:
            f_save:
        """
        if f_save is None:
            f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')
        f = open(f_save, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()

    @staticmethod
    def expand_to_mgatk(curr_mt_af,mt_ref):
        ref = mt_ref[curr_mt_af.name]
        pos = curr_mt_af.name
        return pd.DataFrame({"Ref":ref, "Pos":pos, "Val":curr_mt_af})

    def test_save_to_mgatk_format(self):
        df = pd.DataFrame( [[10,0,1,3,5], [3,0,5,5,0], [6,2,1,1,0]] , columns=np.arange(0,5))
        mt_ref_dict = {0: "A", 1: "G", 2: "C", 3: "C", 4: "T"}
        mt_ref = pd.DataFrame({"Pos": mt_ref_dict.keys(), "Ref": mt_ref_dict})
        return

    def save_to_mgatk_format(self, mt_ref, out_f):
        """Converts into the proper files needed for mgatk. (i.e variant and
        coverage files)

        :return:
        """
        cell_af = self.subsample_new_cell_af
        chars = ["A", "G", "C", "T"]
        def alt_generate(x):
            curr = chars.copy()
            curr.remove(x["Ref"])
            return np.random.choice(curr)
        alt_ref = mt_ref.apply(alt_generate, axis=1)

        # First use the AF and choose an alternative allele
        df_stack = cell_af.stack().reset_index().rename(
            {"level_0": "Cell", "level_1": "MT_pos", 0: "Coverage"},
            axis=1)
        df_stack["Nucleotide"] = df_stack["MT_pos"].apply(
            lambda x: alt_ref[x])

        # Add on the reference allele
        df_stack_ref = cell_af.stack().reset_index().rename(
            {"level_0": "Cell", "level_1": "MT_pos", 0: "Coverage"},
            axis=1)
        df_stack_ref["Coverage"] = 1-df_stack_ref["Coverage"]
        df_stack["Nucleotide"] = df_stack["MT_pos"].apply(
            lambda x: mt_ref[x])

        # Save the NTs.
        # For concordance, split the coverage in two
        df_stack = pd.concat((df_stack, df_stack_ref))
        for ind, val in df_stack.groupby("Nucleotide"):
            # Drop the 0s
            curr = val[val["Coverage"]>0]
            # Save file
            curr_out_f = out_f + "_" + ind + ".txt"
            curr.to_csv(curr_out_f)

        # Save the coverage.
        coverage = self.cells_mt_coverage
        if type(coverage) != int:
            coverage_stack = pd.DataFrame(coverage).stack().reset_index().rename(
                {"level_0": "Cell", "level_1": "MT Position", 0: "Coverage"},
                axis=1)
        else:
            # Coverage is a single int: broadcast it to a cell-by-position frame before stacking
            coverage_stack = pd.DataFrame(
                self.cells_mt_coverage * np.ones(shape=cell_af.shape)).stack().reset_index().rename(
                {"level_0": "Cell", "level_1": "MT Position", 0: "Coverage"},
                axis=1)
        curr_out_f = out_f + "_" + "coverage.txt"
        coverage_stack.to_csv(curr_out_f)
        return

    def load(self):
        filename = self.params['filename']
        f = open(filename, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)

    def compare_before_after(self):
        """Creates a df that contains information on the number of cells from
        each clone before as well as after. :return: df.at[ind, "Dominant
        Before"] = (full_sim.clone_cell == 1).sum() df.at[ind, "Dominant After"]
        = (full_sim.subsample_new_clone_cell == 1).sum()
        """

        return

    def cluster_compare_before_after(self):
        """Compares the performance of clustering on grouping the same clones
        together. :return:
        """
        return


def main():
    return


if "__name__" == "__main__":
    main()
new file mode 100644
index 0000000..e30dc8e
+++ b/src/simulations/simulation_pymc.gv

0eb0a23916e0fb901644270002db1c14be3410fb

Running simulation overflows RAM

  • Use separate f_save files in FullSimulation so the summary metrics are stored apart from the bulky per-cell simulation data (see the sketch after this list)
  • Move the parallel_apply up to the hyperparameter sweep
  • When looping over results in ParameterSweep, load the saved files and then extract the results
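One way to tackle the memory issue could be to pickle only the lightweight summary metrics separately from the per-cell simulation data, so ParameterSweep can load a small metrics file per parameter set instead of the full FullSimulation object. A minimal sketch, assuming the metric attributes that FullSimulation already sets (dropout, prec_scores, rocs, b_a_df); the save_metrics/load_metrics helpers and the "_metrics.p" suffix are hypothetical, not existing code:

import os
import pickle

def save_metrics(full_sim, outdir, prefix):
    # Store only the small summary metrics, not the per-cell simulation data
    metrics = {"dropout": full_sim.dropout,
               "prec_scores": full_sim.prec_scores,
               "rocs": full_sim.rocs,
               "b_a_df": full_sim.b_a_df}
    with open(os.path.join(outdir, prefix + "_metrics.p"), "wb") as f:
        pickle.dump(metrics, f, 2)

def load_metrics(outdir, prefix):
    with open(os.path.join(outdir, prefix + "_metrics.p"), "rb") as f:
        return pickle.load(f)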

account for mt_pos being a list not an int

account for mt_pos being a list not an int

# TODO account for mt_pos being a list not an int

import numpy as np
from numpy import random
import os
import pandas as pd
from tqdm import tqdm
#from src.config import ROOT_DIR
from sklearn.metrics import roc_curve, average_precision_score, confusion_matrix
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import glob
from sklearn.cluster import KMeans
from sklearn import metrics
from scipy.spatial.distance import cdist
from pandarallel import pandarallel

from mplh.color_utils import get_colors
from mplh.fig_utils import legend_from_color
from mplh import cluster_help as ch
from src.simulations.utils.config import read_config_file, write_config_file

from dynamicTreeCut import cutreeHybrid
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage
from sklearn.model_selection import ParameterGrid
from src.simulations.utils.config import check_required

from .simulation import Simulation


# Does this ruin running the MCMC? I don't think so, b/c that format is going to be put in after anyway
class FullSimulation:
    """
    Class that simulates cell growth for lineage tracing. Reads in a
    parameter file and runs a certain number of iterations based on
    the num_iterations parameter.

    :ivar n_iter: Number of iterations
    :type n_iter: int
    :ivar num_cells: Number of cells to sequence
    :type num_cells: int

    :ivar sim: Each index is a different iteration of the simulation.
    :type sim: pandas Series

    """
    def __init__(self, params_f):
        params = read_config_file(params_f)
        self.n_iter = params['num_iterations']
        self.num_cells = params['num_cells']
        self.params = params
        self.f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')
        return
        #for i in self.n_iter:

    def run(self):
        """
        Runs the simulation and stores it in sim attr.

        This uses pandarallel to parallelize the runs.
        :return:
        """
        # Parallelize df
        df = pd.Series(index=range(self.n_iter))
        #df = df.apply(self.run_sim, args=(self.params,))
        pandarallel.initialize(nb_workers=self.params['cpus'])
        df = df.parallel_apply(self.run_sim, args=(self.params,))

        self.sim = df
        #self.cluster_before_after()
        self.sim_performance_dominant(group='both')
        self.stats_before_after()
        return

    @staticmethod
    def run_sim(x, params):
        """
        For a simulation, it will initialize, grow, subsample,
        and merge the before stimulus and after stimulus variables.
        :param x: Placeholder variable
        :param params: The parameter dictionary to use
        :return:
        """
        s = Simulation(params)
        s.initialize()
        s.grow()
        s.subsample_new(to_delete=True)
        s.combine_init_growth()
        return s

    def flatten_sim(self):
        ## TODO
        # This will extract out the classes of df
        return

    def sim_performance_dominant(self, group='both'):
        """
        Will collect metrics that are averaged over the simulations.
        These are specifically for looking at the main, dominant clone,
        and what the allele-frequency of that clone variant
        is for each cell.

        :param group: {'init', 'growth', 'both'} Indicates which population to compute metrics on (initial cells, post-growth cells, or both).
        :ivar dropout: Number of dominant clone cells that have 0 reads
        at the lineage variant position.
        :type dropout: list
        :ivar prec_scores: sklearn average precision score based on
        the allele frequencies seen in the dominant clone cells versus
        the non-clone cells.
        :type prec_scores: list
        :ivar rocs: ROC curves for each iteration based on allele
        frequencies.

        :return:
        """
        dropout = []
        rocs = []
        prec_scores = []


        for iter, s in enumerate(self.sim.values):
            # First get the dominant clone, which is indexed as 1
            mt_pos = s.clone_mt_dict[1]
            # TODO account for mt_pos being a list not an int
            if group == 'init':
                clones = s.clone_cell
                cell_af = s.cell_af.loc[:,mt_pos]
            elif group == 'growth':
                clones = s.new_clone_cell
                cell_af = s.new_cell_af.loc[:,mt_pos]
            elif group == 'both':
                #clones = pd.concat((s.clone_cell, s.subsample_new_clone_cell)).reset_index(drop=True)
                #cell_af = pd.concat((s.cell_af.loc[:,mt_pos], s.subsample_new_cell_af.loc[:,mt_pos])).reset_index(drop=True)
                clones = s.combined_clones
                cell_af = s.combined_cell_af.loc[:,mt_pos]
            else:
                raise ValueError('group variable not properly set.')

            y_true = clones.values.copy()
            y_true[y_true != 1] = 0  # Set nondominant clones to 0
            rocs.append(roc_curve(y_true, cell_af))
            prec_scores.append(average_precision_score(y_true, cell_af))
            dropout.append((cell_af[clones==1]==0).sum()/cell_af.shape[0])

        self.dropout = dropout
        self.prec_scores = prec_scores
        self.rocs = rocs
        return


    def reduce_cells(self, cell_af):
        #self.sim
        return


    def stats_before_after(self, clone_id=1):
        b_a_df = pd.DataFrame(index=np.arange(0,len(self.sim)), columns=["Before", "After", "A/B"], dtype=str)
        for iter, s in enumerate(self.sim.values):
            b_clones = s.clone_cell
            a_clones = s.subsample_new_clone_cell
            b_a_df.at[iter, "Before"] = (b_clones == clone_id).sum()
            b_a_df.at[iter, "After"] = (a_clones==clone_id).sum()
            b_a_df.at[iter,"A/B"] = (b_a_df.at[iter, "After"]/b_a_df.at[iter, "Before"])
        self.b_a_df = b_a_df
        return


    def cluster_before_after(self):
        """
        Loops through the simulations and for each,
        it clusters the cells.

        :ivar cluster_results: Cluster labels for each iteration; each
        element is indexed by cell and gives that cell's cluster label.
        :type cluster_results: list
        """
        cluster_results = []
        print('clustering')
        for s in tqdm(self.sim.values):
            cluster_results.append(s.cluster(s.combined_cell_af))
            print(len(cluster_results[-1]))
        self.cluster_results = cluster_results


    def stats_cluster_before_after(self, clone_id=1):
        """
        Confusion matrix for clustering the proper clone cells together.
        :param clone_id: Which clone to get metrics for
        :return:
        """


        b_a_df = pd.DataFrame(index=np.arange(len(self.sim)),
                              columns=["TN", "FP", "FN", "TP"], dtype=int)
        for ind, s in enumerate(self.sim.values):
            y_true = s.combined_clones.copy()  # copy so the stored clone labels are not mutated
            y_true[y_true != 1] = 0
            y_pred = self.cluster_results[ind]

            # y_true, y_pred
            tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
            b_a_df.loc[ind] = [tn, fp, fn, tp]
        self.b_a_df = b_a_df
        return
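
For reference, a small sketch with hypothetical labels of how confusion_matrix(...).ravel() yields the four counts stored per iteration; it assumes y_pred has already been binarized to clone / non-clone:

import numpy as np
from sklearn.metrics import confusion_matrix

y_true = np.array([1, 1, 1, 0, 0, 0, 0])  # 1 = dominant clone
y_pred = np.array([1, 1, 0, 0, 0, 1, 0])  # binarized cluster assignment
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
print(tn, fp, fn, tp)  # 3 1 1 2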



    def save(self, f_save=None):
        if f_save is None:
            f_save = self.f_save
        f = open(f_save, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()


    def load(self, f_save=None):
        #filename = self.params['filename']
        if f_save is None:
            f_save = self.f_save
        f = open(f_save, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)


def main():
    return


if "__name__" == "__main__":
    main()

67225f62aefbb2d8d2777307132224bda6f57ada

Add the MT clone map so it can contain multiple mutants in lineages

Add the MT clone map so it can contain multiple mutants in lineages

# TODO Add the MT clone map so it can contain multiple mutants in lineages

            for i in range(1,num_clones+1):
                self.clone_mt_dict[i] = i

        # TODO Add the MT clone map so it can contain multiple mutants in lineages

        # If there is a heteroplasmy table in params, it is a list of mutant heteroplasmy AFs.
        # If not, the AFs will be randomly drawn based on the number of clones

38915dc321f20b046fb2d1277c1be89a6fc623ed

From bam to UMIs and cell filter

Can we use the cellranger toolkit to get the cells to filter out?
Need to use bam2fastq since our downloaded fastq is not complete.
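
A hedged sketch of how this step could look, assuming the 10x Genomics bamtofastq utility and cellranger count are installed; every path and sample name below is a placeholder:

import subprocess

# 1) Recover FASTQs from the downloaded position-sorted BAM (10x bamtofastq).
subprocess.run(["bamtofastq", "possorted_genome_bam.bam", "fastq_out/"], check=True)

# 2) Re-run cellranger count so its cell calling provides the barcodes to keep.
subprocess.run([
    "cellranger", "count",
    "--id=sample1",
    "--fastqs=fastq_out/",
    "--transcriptome=/path/to/refdata",
], check=True)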

growth of poisson refactor

growth of poisson refactor

# TODO growth of poisson refactor

import numpy as np
from numpy import random
import os
import pandas as pd
import pickle
from src.simulations.utils.config import read_config_file, write_config_file
from src.simulations.utils.config import check_required


class Simulation:
    """Lineage tracing simulation of one sample

    Will initialize cells based on their parameters and grow as well. This
    should be a flexible framework, to add different ways to initialize, grow,
    and metrics to have. Additionally can cluster these results.

    :ivar params
    :type params: dict
    """

    def __init__(self, params_f):
        """
        :param params_f: Parameter yaml file for the specifications
        :type params_f: yaml file or dict
        """
        if isinstance(params_f, str):
            params = read_config_file(params_f)
        else:
            params = params_f

        self.params = params
        check_required(params, ['initialize', 'num_cells', 'num_mt_positions', 'prefix'])
        self.prefix = params['prefix']
        self.num_mt_positions = params['num_mt_positions']
        self.num_cells = params['num_cells']
        if not os.path.exists(params['local_outdir']):
            os.mkdir(params['local_outdir'])


    def initialize(self):
        """ (1) Pre-growth cell population is instantiated.

        Creates a clone-MT dictionary, cell coverage matrix
        (or an int, depending on parameters), and cell-AF matrix.
        :return:
        """
        self.init_clone_dict()
        self.init_cell_coverage()
        self.init_cell_af()
        #self.init_clone_mt()

    #should be external method
    def grow(self):
        """ (2) Growth of cells is run."""
        p = self.params
        type = p["growth"]["type"]
        if  type == "poisson":
            self.grow_poisson(p['growth']['poisson'])
        elif type == "binomial":
            self.grow_binomial(p['growth']['binomial'])
        return

    # Static Method
    @staticmethod
    def clone_counts_to_cell_series(clone_counts):
        """ Generates new cell IDs based on cluster count iterable
        :param clone_counts: Each i'th element is the number of cells in
        cluster i.
        :type clone_counts: iterable
        :return: each index name is a cell ID and each value is the cluster
        the cell belongs to.
        :rtype: pd.Series
        """
        clone_counts = np.array(clone_counts)
        num_cells = clone_counts.sum()
        clone_cell = -1 * np.ones(shape=[num_cells, ])

        clone_cell[:clone_counts[0]] = 0
        for ind, val in enumerate(clone_counts[1:]):
            start = clone_counts[:ind + 1].sum()
            end = clone_counts[:ind + 1].sum() + val
            # starts at sum(clone_counts[i-1]) ends at clone_counts[i].sum()
            clone_cell[start:end] = ind + 1

        clone_cell = pd.Series(clone_cell, dtype=int)
        return clone_cell
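
For intuition, the same cell-to-cluster mapping can be sketched with np.repeat (hypothetical counts; the method above builds it by slicing instead):

import numpy as np
import pandas as pd

clone_counts = [3, 2, 1]  # 3 non-clone cells, 2 in clone 1, 1 in clone 2
clone_cell = pd.Series(np.repeat(np.arange(len(clone_counts)), clone_counts), dtype=int)
print(clone_cell.tolist())  # [0, 0, 0, 1, 1, 2]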


    def init_clone_dict(self):
        """1A
        """

        ### Add in potential to overwrite the values
        # Gets the clone dictionary. Should also have clone to mt dict.
        clones = self.params['initialize']['clone_sizes']

        if 'num_cells_population' not in self.params:
            self.num_cells_pop = self.num_cells
        else:
            self.num_cells_pop = self.params['num_cells_population']

        num_cells = self.num_cells_pop

        # Option 1: List of fraction of size of each clone. 0s are nonclone size, listed first
        if type(clones) == list:
            #clone_cell = pd.Series(index=range(num_cells))
            clone_counts = np.random.multinomial(num_cells, clones)
            clone_cell  = self.clone_counts_to_cell_series(clone_counts)
            self.clone_cell_pop = clone_cell

            # Choose subset to be sampled
            self.clone_cell = clone_cell.sample(n=self.num_cells).sort_values()
        # Option 2: 1 clone. ID'd as 1
        elif type(clones) == int: #One number for dominant clone. the others are not.
            clone_cell = np.zeros(shape=[num_cells, ])
            # 'clones' here is the number of cells in the dominant clone
            clone_cell[:clones] = 1
            clone_cell = clone_cell[::-1]
            clone_cell = pd.Series(clone_cell, dtype=int)
            self.clone_cell = clone_cell

        # Option 3 To ADD, beta binomial and more complex distributions

        self.num_clones =  len(set(clone_cell.values))-1 # Remove the non-clone
        return clone_cell


    def init_cell_coverage(self):
        """1B

        There are different modes to the coverage, either a constant or
        through a distribution.
        """
        p = self.params['initialize']['coverage']
        type = p['type']

        num_cells = self.num_cells
        num_pos = self.num_mt_positions
        c = np.zeros([num_cells, num_pos])

        if type == 'constant':
            c[:, :] = p['cov_constant']
        elif type == "poisson":
            # Get the number of coverage per cell based on poisson (should be reads)
            mu_cov_per_cell = p['mu_cov_per_cell']
            num_reads_per_cell = random.poisson(lam=mu_cov_per_cell,
                                                size=num_cells)

            # Number of reads at each position, based on the average for each cell
            for i in range(num_cells):
                c[i, :] = random.poisson(num_reads_per_cell[i],
                                         size=num_pos)
        self.cells_mt_coverage = c
        return c
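
A small sketch of the two coverage modes above (toy sizes; cov_constant and mu_cov_per_cell are the same config keys used in the method):

import numpy as np
from numpy import random

num_cells, num_pos = 4, 5

# 'constant': every cell and position gets the same depth
c_const = np.full((num_cells, num_pos), 50)

# 'poisson': draw a mean depth per cell, then per-position depths around that mean
num_reads_per_cell = random.poisson(lam=50, size=num_cells)
c_pois = np.vstack([random.poisson(lam=m, size=num_pos) for m in num_reads_per_cell])
print(c_const.shape, c_pois.shape)  # (4, 5) (4, 5)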


    @staticmethod
    def create_cell_af(clone_df, mt_dict, n_cells, n_mt, num_clones,
                       cov_params, hets, het_err, coverage=None):
        cell_af = pd.DataFrame(np.zeros(shape=[n_cells, n_mt]))

        #p = self.params['initialize']

        ## Loop through each clone,
        ## Generate the AF for the clone and non-clones using coverage for each cell
        ## Fill in cell_by_af for that position.
        for ind in range(1, num_clones + 1):
            # Generate AF: (clone_df ==  ind).sum()
            n_dom_cells = (clone_df == ind).sum()
            het = hets[ind - 1]

            curr_mt = mt_dict[ind]

            if cov_params['coverage']['type'] == 'constant':
                c = cov_params['coverage']['cov_constant']

                af_i = random.binomial(c, het, n_dom_cells) / c
                af_j = random.binomial(c, het_err, n_cells - n_dom_cells) / c

                # Update the dom_cells and non_dom for the current MT
                cell_af.loc[
                    np.flatnonzero(clone_df == ind), curr_mt] = af_i
                cell_af.loc[
                    np.flatnonzero(clone_df != ind), curr_mt] = af_j

            # Each cell and position has its own coverage value, so update each one
            else:
                if coverage is None:
                    raise ValueError("coverage needs to be assigned")
                c = coverage

                # Get the cells' coverage at the mt position
                curr_mt_cov = c[:, curr_mt]

                # Get cell indices for the clones and nonclones
                curr_clone_inds = np.flatnonzero(clone_df == ind)
                curr_nonclone_inds = np.flatnonzero(clone_df != ind)
                for cell in curr_clone_inds:
                    # Draw alt counts at this cell's coverage and convert to an AF
                    cov_cell = max(int(curr_mt_cov[cell]), 1)
                    cell_af.loc[cell, curr_mt] = random.binomial(cov_cell, het) / cov_cell
                for cell in curr_nonclone_inds:
                    cov_cell = max(int(curr_mt_cov[cell]), 1)
                    cell_af.loc[cell, curr_mt] = random.binomial(cov_cell, het_err) / cov_cell
        return cell_af
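
A minimal sketch of the constant-coverage branch: binomial draws at the clone heteroplasmy for clone cells and at the error rate for the rest, divided by coverage to give allele frequencies (toy numbers):

from numpy import random

cov, het, het_err = 50, 0.3, 0.01
n_clone, n_other = 5, 10

af_clone = random.binomial(cov, het, n_clone) / cov      # centered near 0.3
af_other = random.binomial(cov, het_err, n_other) / cov  # mostly near 0
print(af_clone.round(2), af_other.round(2))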


    ##########
    def init_cell_af(self):
        """1C. Initialize the cell-by-mtPos af dataframe. Unless a clone:mt dict was
        provided, the first N MT positions will be the clone AFs. Creates
        self.clone_mt_dict and self.cell_af
        """
        p = self.params['initialize']
        hets = self.params['het']
        clone_df = self.clone_cell
        num_clones = self.num_clones
        n_cells = self.num_cells
        n_mt = self.num_mt_positions

        # Get the MT map
        if 'mt_clone_map' in p and p['mt_clone_map'] is not None:
            self.clone_mt_dict = p['mt_clone_map']
        else:
            # Each clone points to a mt position
            self.clone_mt_dict = dict()
            for i in range(1,num_clones+1):
                self.clone_mt_dict[i] = i

        # TODO Add the MT clone map so it can contain multiple mutants in lineages
        # If there is a heteroplasmy table in params, it is list of mutant heteroplasmy AFs.
        # If not, will randomly draw based on number of clones
        if type(hets) == list:
            assert(len(hets) == num_clones)

        # Get the cell_af based on MT dictionary and cell coverage
        self.cell_af = self.create_cell_af(clone_df, self.clone_mt_dict,
                                           n_cells, n_mt, num_clones,
                                           self.params['initialize'],
                                           hets,
                                           self.params['het_err_rate'],
                                           coverage=None)
        return


    def init_clone_mt(self):
        p = self.params
        if p["initialize"]['type'] == 'growth':
            ## TODO
            # Create a phylogeny and then get the averages of the mutants
            self.average_clone_mt()
        # If not growth, it should already be there.
        return

    def average_clone_mt(self):
        return

    @staticmethod
    def extract_clone_cells(clone_cell, clone_id):
        """ Returns the numbered indices of the specific clones

        :param clone_cell: Each element is the indexed cell's clone label.
        :type clone_cell: np array or pd.Series

        :param clone_id:
        :type clone_id: int or string
        """
        ids = np.flatnonzero(clone_cell == clone_id)
        return ids

    @staticmethod
    def simulate_expand_cells_af(af, growth_inds, sigma):
        """Given a cell-by-af vector, expand the AF.

        Expanded AF occurs by duplicating cells that grew based on the
        growth_inds vector. It will add standard error to each af based on sigma
        :param af: :param growth: Indices of AF to copy :param sigma: Variance
        to add to AF of child. :return:

        Args:
            af:
            growth_inds:
            sigma:
        """

        new_af = af.iloc[growth_inds].copy() + random.normal(0, sigma, size=af.iloc[growth_inds].shape)
        new_af.index = np.arange(af.index[-1]+1, af.index[-1]+1+new_af.shape[0])
        new_af = pd.concat((af,new_af), axis=0)
        #new_af = np.append(af, np.concatenate(new_af))
        return new_af

    def grow_binomial(self, p):
        """ (2.1.2)
        :param p: growth parameters; contains time_steps and rates
        :type p: dict
        """
        timesteps = p["time_steps"]
        rates = p["rates"]

        num_clones = self.num_clones+1
        new_dict = {}
        for curr_clone in range(num_clones):
            curr_rate = rates[curr_clone]
            ids = self.extract_clone_cells(self.clone_cell_pop, curr_clone)
            num_curr_cells = len(ids)

            for i in range(timesteps):
                # Simulate growth for each clone separately:
                # each cell divides with probability curr_rate at this time step.
                num_new_cells = random.binomial(1, curr_rate, size=num_curr_cells).sum()
                num_curr_cells += num_new_cells

            new_dict[curr_clone] = num_curr_cells

        ####TODO
        ## new_lineage_mutants chances. This will see if a mutation will change
        ####TODO
        ## Add death + stimulation rate as well as growth
        # Save the new cell clones df and cell af
        clone_counts = [i for i in new_dict.values()]
        self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)
        # Do not make cell_af, will make this only when subsampled.


        # self.new_cell_af = pd.DataFrame()
        # for clone in range(1, self.num_clones+1):
        #     self.new_cell_af = pd.concat((self.new_cell_af, new_dict[clone]),axis=0).reset_index(drop=True)
        return
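
As a sanity check on the loop above: with per-cell division probability r at each time step, the expected clone size after t steps is n0 * (1 + r)**t. A toy comparison of one simulated trajectory against that expectation:

from numpy import random

n0, rate, timesteps = 100, 0.4, 5
n = n0
for _ in range(timesteps):
    n += random.binomial(1, rate, size=n).sum()  # each cell divides with prob. rate

expected = n0 * (1 + rate) ** timesteps
print(n, round(expected, 1))  # the simulated count fluctuates around ~537.8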


    def grow_binomial_old(self, p):
        """ (2.1.1)
        :param p: contains time_steps, rates,
                  and growth.mutant_af_sigma_noise
        :type p: dict
        """
        timesteps = p["time_steps"]
        rates = p["rates"]

        sigma = self.params['growth']["mutant_af_sigma_noise"]
        cell_af = self.cell_af

        num_clones = self.num_clones+1
        new_dict = {}
        for curr_clone in range(num_clones):
            curr_rate = rates[curr_clone]
            ids = self.extract_clone_cells(self.clone_cell, curr_clone)
            new_cells = cell_af.iloc[ids].copy()
            for i in range(timesteps):
                # Simulate growth for each clone separately.
                growth_inds = np.flatnonzero(random.binomial(1, curr_rate, size=new_cells.shape[0]))
                #new_ids =
                new_cells = self.simulate_expand_cells_af(new_cells, growth_inds, sigma)

            new_dict[curr_clone] = new_cells
            # Create list of cells
        ####TODO
        ## new_lineage_mutants chances. This will see if a mutation will change
        ####TODO
        ## Add death + stimulation rate as well as growth
        # Save the new cell clones df and cell af
        clone_counts = [i.shape[0] for i in new_dict.values()]
        self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)
        self.new_cell_af = pd.DataFrame(new_dict[0])
        for clone in range(1, self.num_clones+1):
            self.new_cell_af = pd.concat((self.new_cell_af, new_dict[clone]),axis=0).reset_index(drop=True)
        return


    def grow_poisson(self, p):
        # TODO growth of poisson refactor

        # Sample from poisson the growth
        grow_clones = random.poisson(lam=p['clone_growth'],
                                     size=(p['clone_meta']))
        grow_nonclones = random.poisson(lam=p['non_clone_growth'],
                                        size=(p['clone_meta']))
        return
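
A hedged sketch of one way the Poisson refactor could be completed, mirroring grow_binomial but drawing the number of new cells per clone from a Poisson whose rate scales with the current clone size; the rate and time-step names mirror the binomial config and are assumptions, not the final design:

from numpy import random

def grow_poisson_sketch(clone_sizes, rates, time_steps):
    """clone_sizes / rates: dicts keyed by clone id (0 = non-clone)."""
    sizes = dict(clone_sizes)
    for _ in range(time_steps):
        for clone_id, n in sizes.items():
            # Expected number of new cells this step scales with the clone's size.
            sizes[clone_id] = n + random.poisson(lam=rates[clone_id] * n)
    return sizes

print(grow_poisson_sketch({0: 900, 1: 100}, {0: 0.05, 1: 0.5}, time_steps=3))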


    def subsample_new(self, to_delete=False):
        """(3) Subsample from new cell population and generate cell_af

        :param to_delete: To remove the cells that grew (which takes up
        a lot of RAM).
        :type to_delete: bool
        """
        p = self.params
        if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
            self.subsample_new_clone_cell = self.new_clone_cell.sample(n=self.params['sequence_subsample'])
        else:
            self.subsample_new_clone_cell = self.new_clone_cell.sample(
                n=self.num_cells)

        #print(f'New cell af, {len(self.subsample_new_clone_cell)}')
        # Generate subsample_new_cell_af
        self.subsample_new_cell_af = self.create_cell_af(clone_df=self.subsample_new_clone_cell,
                                                         mt_dict = self.clone_mt_dict,
                                                         n_cells=len(self.subsample_new_clone_cell),
                                                         n_mt=self.num_mt_positions,
                                                         num_clones=self.num_clones,
                                                         cov_params=p['initialize'],
                                                         hets=
                                                             self.params[
                                                                 'het'],
                                                         het_err=self.params['het_err_rate'],
                                                         coverage=None
                                                         )

        if to_delete:
            self.new_cell_af = None
            self.new_clone_cell = None


    def subsample_new_old(self, to_delete=False):
        """(3) Subsample from new cell population

        :param to_delete: To remove the cells that grew (which takes up
        a lot of RAM).
        :type to_delete: bool
        """
        new_cell_af = self.new_cell_af
        p = self.params
        if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.params['sequence_subsample'])
        else:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.num_cells)

        self.subsample_new_clone_cell = self.new_clone_cell.loc[
            self.subsample_new_cell_af.index]

        if to_delete:
            self.new_cell_af = None
            self.new_clone_cell = None


    def combine_init_growth(self):
        """(4) Add the pre- and post- population of cells into a group.

        :return:
        """
        combined_cell_af = pd.concat((self.cell_af, self.subsample_new_cell_af)).reset_index(drop=True)
        combined_clones = pd.concat(
            (self.clone_cell, self.subsample_new_clone_cell)).reset_index(
            drop=True)

        combined_befaft = np.concatenate((np.zeros(shape=[self.cell_af.shape[0],]), np.ones(shape=[self.subsample_new_cell_af.shape[0]])))
        combined_meta = pd.DataFrame({"pre_post": combined_befaft, "clone": combined_clones})
        #combined_meta = pd.Series(combined_meta, name='After Growth', dtype=int)
        assert(combined_meta.shape[0] == self.cell_af.shape[0]+self.subsample_new_cell_af.shape[0])
        assert (combined_cell_af.shape[0] == self.cell_af.shape[0] +
                self.subsample_new_cell_af.shape[0])
        assert(combined_meta.shape[0] == combined_clones.shape[0])
        assert(combined_cell_af.shape[0] == combined_clones.shape[0])
        self.combined_meta = combined_meta
        self.combined_clones = combined_clones
        self.combined_cell_af = combined_cell_af
        return

    def save(self, f_save=None):
        """
        Args:
            f_save:
        """
        if f_save is None:
            f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')
        f = open(f_save, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()

    @staticmethod
    def expand_to_mgatk(curr_mt_af,mt_ref):
        ref = mt_ref[curr_mt_af.name]
        pos = curr_mt_af.name
        return pd.DataFrame({"Ref":ref, "Pos":pos, "Val":curr_mt_af})

    def test_save_to_mgatk_format(self):
        df = pd.DataFrame( [[10,0,1,3,5], [3,0,5,5,0], [6,2,1,1,0]] , columns=np.arange(0,5))
        mt_ref_dict = {0: "A", 1: "G", 2: "C", 3: "C", 4: "T"}
        mt_ref = pd.DataFrame({"Pos": mt_ref_dict.keys(), "Ref": mt_ref_dict})
        return

    def save_to_mgatk_format(self, mt_ref, out_f):
        """Converts into the proper files needed for mgatk. (i.e variant and
        coverage files)

        :return:
        """
        cell_af = self.subsample_new_cell_af
        chars = ["A", "G", "C", "T"]
        def alt_generate(x):
            curr = chars.copy()
            curr.remove(x["Ref"])
            return np.random.choice(curr)
        alt_ref = mt_ref.apply(alt_generate, axis=1)

        # First use the AF and choose an alternative allele
        df_stack = cell_af.stack().reset_index().rename(
            {"level_0": "Cell", "level_1": "MT_pos", 0: "Coverage"},
            axis=1)
        df_stack["Nucleotide"] = df_stack["MT_pos"].apply(
            lambda x: alt_ref[x])

        # Add on the reference allele
        df_stack_ref = cell_af.stack().reset_index().rename(
            {"level_0": "Cell", "level_1": "MT_pos", 0: "Coverage"},
            axis=1)
        df_stack_ref["Coverage"] = 1 - df_stack_ref["Coverage"]
        # Assumes mt_ref is indexed by MT position with a "Ref" column
        df_stack_ref["Nucleotide"] = df_stack_ref["MT_pos"].apply(
            lambda x: mt_ref.loc[x, "Ref"])

        # Save the NTs.
        # For concordance, split the coverage in two
        df_stack = pd.concat((df_stack, df_stack_ref))
        for ind, val in df_stack.groupby("Nucleotide"):
            # Drop the 0s
            curr = val[val["Coverage"]>0]
            # Save file
            curr_out_f = out_f + "_" + ind + ".txt"
            curr.to_csv(curr_out_f)

        # Save the coverage.
        coverage = self.cells_mt_coverage
        if type(coverage) != int:
            coverage_stack = pd.DataFrame(coverage).stack().reset_index().rename(
                {"level_0": "Cell", "level_1": "MT Position", 0: "Coverage"},
                axis=1)
        else:
            # Broadcast the constant coverage across the cell-by-position shape first
            coverage_stack = pd.DataFrame(
                coverage * np.ones(shape=cell_af.shape)).stack().reset_index().rename(
                {"level_0": "Cell", "level_1": "MT Position", 0: "Coverage"},
                axis=1)
        curr_out_f = out_f + "_" + "coverage.txt"
        coverage_stack.to_csv(curr_out_f)
        return

    def load(self):
        filename = self.params['filename']
        f = open(filename, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)

    def compare_before_after(self):
        """Creates a df that contains information on the number of cells from
        each clone before as well as after. :return: df.at[ind, "Dominant
        Before"] = (full_sim.clone_cell == 1).sum() df.at[ind, "Dominant After"]
        = (full_sim.subsample_new_clone_cell == 1).sum()
        """

        return

    def cluster_compare_before_after(self):
        """Compares the performance of clustering on grouping the same clones
        together. :return:
        """
        return


def main():
    return


if "__name__" == "__main__":
    main()

4f4a80775a739c20fac81aa10a867caa1765522d

Add the MT clone map so it can contain multiple mutants in lineages

Add the MT clone map so it can contain multiple mutants in lineages

If there is a heteroplasmy table in params, it is a list of mutant heteroplasmy AFs.

If not, the AFs will be randomly drawn based on the number of clones.

# TODO Add the MT clone map so it can contain multiple mutants in lineages

import numpy as np
from numpy import random
import os
import pandas as pd
import pickle
from src.simulations.utils.config import read_config_file, write_config_file
from src.simulations.utils.config import check_required


class Simulation:
    """Lineage tracing simulation of one sample

    Will initialize cells based on their parameters and grow as well. This
    should be a flexible framework, to add different ways to initialize, grow,
    and metrics to have. Additionally can cluster these results.

    :ivar params
    :type params: dict
    """

    def __init__(self, params_f):
        """
        :param params_f: Parameter yaml file for the specifications
        :type params_f: yaml file or dict
        """
        if isinstance(params_f, str):
            params = read_config_file(params_f)
        else:
            params = params_f

        self.params = params
        check_required(params, ['initialize', 'num_cells', 'num_mt_positions', 'prefix'])
        self.prefix = params['prefix']
        self.num_mt_positions = params['num_mt_positions']
        self.num_cells = params['num_cells']
        if not os.path.exists(params['local_outdir']):
            os.mkdir(params['local_outdir'])


    def initialize(self):
        """ (1) Pre-growth cell population is instantiated.

        Creates a clone-MT dictionary, cell coverage matrix
        (or an int, depending on parameters), and cell-AF matrix.
        :return:
        """
        self.init_clone_dict()
        self.init_cell_coverage()
        self.init_cell_af()
        #self.init_clone_mt()

    #should be external method
    def grow(self):
        """ (2) Growth of cells is run."""
        p = self.params
        type = p["growth"]["type"]
        if  type == "poisson":
            self.grow_poisson(p['growth']['poisson'])
        elif type == "binomial":
            self.grow_binomial(p['growth']['binomial'])
        return

    # Static Method
    @staticmethod
    def clone_counts_to_cell_series(clone_counts):
        """ Generates new cell IDs based on cluster count iterable
        :param clone_counts: Each i'th element is the number of cells in
        cluster i.
        :type clone_counts: iterable
        :return: each index name is a cell ID and each value is the cluster
        the cell belongs to.
        :rtype: pd.Series
        """
        clone_counts = np.array(clone_counts)
        num_cells = clone_counts.sum()
        clone_cell = -1 * np.ones(shape=[num_cells, ])

        clone_cell[:clone_counts[0]] = 0
        for ind, val in enumerate(clone_counts[1:]):
            start = clone_counts[:ind + 1].sum()
            end = clone_counts[:ind + 1].sum() + val
            # starts at sum(clone_counts[i-1]) ends at clone_counts[i].sum()
            clone_cell[start:end] = ind + 1

        clone_cell = pd.Series(clone_cell, dtype=int)
        return clone_cell


    def init_clone_dict(self):
        """1A
        """

        ### Add in potential to overwrite the values
        # Gets the clone dictionary. Should also have clone to mt dict.
        clones = self.params['initialize']['clone_sizes']

        if 'num_cells_population' not in self.params:
            self.num_cells_pop = self.num_cells
        else:
            self.num_cells_pop = self.params['num_cells_population']

        num_cells = self.num_cells_pop

        # Option 1: List of fraction of size of each clone. 0s are nonclone size, listed first
        if type(clones) == list:
            #clone_cell = pd.Series(index=range(num_cells))
            clone_counts = np.random.multinomial(num_cells, clones)
            clone_cell  = self.clone_counts_to_cell_series(clone_counts)
            self.clone_cell_pop = clone_cell

            # Choose subset to be sampled
            self.clone_cell = clone_cell.sample(n=self.num_cells).sort_values()
        # Option 2: 1 clone. ID'd as 1
        elif type(clones) == int: #One number for dominant clone. the others are not.
            clone_cell = np.zeros(shape=[num_cells, ])
            # 'clones' here is the number of cells in the dominant clone
            clone_cell[:clones] = 1
            clone_cell = clone_cell[::-1]
            clone_cell = pd.Series(clone_cell, dtype=int)
            self.clone_cell = clone_cell

        # Option 3 To ADD, beta binomial and more complex distributions

        self.num_clones =  len(set(clone_cell.values))-1 # Remove the non-clone
        return clone_cell


    def init_cell_coverage(self):
        """1B

        There are different modes to the coverage, either a constant or
        through a distribution.
        """
        p = self.params['initialize']['coverage']
        type = p['type']

        num_cells = self.num_cells
        num_pos = self.num_mt_positions
        c = np.zeros([num_cells, num_pos])

        if type == 'constant':
            c[:, :] = p['cov_constant']
        elif type == "poisson":
            # Get the number of coverage per cell based on poisson (should be reads)
            mu_cov_per_cell = p['mu_cov_per_cell']
            num_reads_per_cell = random.poisson(lam=mu_cov_per_cell,
                                                size=num_cells)

            # Number of reads at each position, based on the average for each cell
            for i in range(num_cells):
                c[i, :] = random.poisson(num_reads_per_cell[i],
                                         size=num_pos)
        self.cells_mt_coverage = c
        return c


    @staticmethod
    def create_cell_af(clone_df, mt_dict, n_cells, n_mt, num_clones,
                       cov_params, hets, het_err, coverage=None):
        cell_af = pd.DataFrame(np.zeros(shape=[n_cells, n_mt]))

        #p = self.params['initialize']

        ## Loop through each clone,
        ## Generate the AF for the clone and non-clones using coverage for each cell
        ## Fill in cell_by_af for that position.
        for ind in range(1, num_clones + 1):
            # Generate AF: (clone_df ==  ind).sum()
            n_dom_cells = (clone_df == ind).sum()
            het = hets[ind - 1]

            curr_mt = mt_dict[ind]

            if cov_params['coverage']['type'] == 'constant':
                c = cov_params['coverage']['cov_constant']

                af_i = random.binomial(c, het, n_dom_cells) / c
                af_j = random.binomial(c, het_err, n_cells - n_dom_cells) / c

                # Update the dom_cells and non_dom for the current MT
                cell_af.loc[
                    np.flatnonzero(clone_df == ind), curr_mt] = af_i
                cell_af.loc[
                    np.flatnonzero(clone_df != ind), curr_mt] = af_j

            # Each cell and position has its own coverage value, so update each one
            else:
                if coverage is None:
                    raise ValueError("coverage needs to be assigned")
                c = coverage

                # Get the cells' coverage at the mt position
                curr_mt_cov = c[:, curr_mt]

                # Get cell indices for the clones and nonclones
                curr_clone_inds = np.flatnonzero(clone_df == ind)
                curr_nonclone_inds = np.flatnonzero(clone_df != ind)
                for cell in curr_clone_inds:
                    # Draw alt counts at this cell's coverage and convert to an AF
                    cov_cell = max(int(curr_mt_cov[cell]), 1)
                    cell_af.loc[cell, curr_mt] = random.binomial(cov_cell, het) / cov_cell
                for cell in curr_nonclone_inds:
                    cov_cell = max(int(curr_mt_cov[cell]), 1)
                    cell_af.loc[cell, curr_mt] = random.binomial(cov_cell, het_err) / cov_cell
        return cell_af


    ##########
    def init_cell_af(self):
        """1C. Initialize the cell-by-mtPos af dataframe. Unless a clone:mt dict was
        provided, the first N MT positions will be the clone AFs. Creates
        self.clone_mt_dict and self.cell_af
        """
        p = self.params['initialize']
        hets = self.params['het']
        clone_df = self.clone_cell
        num_clones = self.num_clones
        n_cells = self.num_cells
        n_mt = self.num_mt_positions

        # Get the MT map
        if 'mt_clone_map' in p and p['mt_clone_map'] is not None:
            self.clone_mt_dict = p['mt_clone_map']
        else:
            # Each clone points to a mt position
            self.clone_mt_dict = dict()
            for i in range(1,num_clones+1):
                self.clone_mt_dict[i] = i

        # TODO Add the MT clone map so it can contain multiple mutants in lineages
        # If there is a heteroplasmy table in params, it is list of mutant heteroplasmy AFs.
        # If not, will randomly draw based on number of clones
        if type(hets) == list:
            assert(len(hets) == num_clones)

        # Get the cell_af based on MT dictionary and cell coverage
        self.cell_af = self.create_cell_af(clone_df, self.clone_mt_dict,
                                           n_cells, n_mt, num_clones,
                                           self.params['initialize'],
                                           hets,
                                           self.params['het_err_rate'],
                                           coverage=None)
        return


    def init_clone_mt(self):
        p = self.params
        if p["initialize"]['type'] == 'growth':
            ## TODO
            # Create a phylogeny and then get the averages of the mutants
            self.average_clone_mt()
        # If not growth, it should already be there.
        return

    def average_clone_mt(self):
        return

    @staticmethod
    def extract_clone_cells(clone_cell, clone_id):
        """ Returns the numbered indices of the specific clones

        :param clone_cell: Each element is the indexed cell's clone label.
        :type clone_cell: np array or pd.Series

        :param clone_id:
        :type clone_id: int or string
        """
        ids = np.flatnonzero(clone_cell == clone_id)
        return ids

    @staticmethod
    def simulate_expand_cells_af(af, growth_inds, sigma):
        """Given a cell-by-af vector, expand the AF.

        Expanded AF occurs by duplicating cells that grew based on the
        growth_inds vector. It will add standard error to each af based on sigma
        :param af: :param growth: Indices of AF to copy :param sigma: Variance
        to add to AF of child. :return:

        Args:
            af:
            growth_inds:
            sigma:
        """

        new_af = af.iloc[growth_inds].copy() + random.normal(0, sigma, size=af.iloc[growth_inds].shape)
        new_af.index = np.arange(af.index[-1]+1, af.index[-1]+1+new_af.shape[0])
        new_af = pd.concat((af,new_af), axis=0)
        #new_af = np.append(af, np.concatenate(new_af))
        return new_af

    def grow_binomial(self, p):
        """ (2.1.2)
        :param p: growth parameters; contains time_steps and rates
        :type p: dict
        """
        timesteps = p["time_steps"]
        rates = p["rates"]

        num_clones = self.num_clones+1
        new_dict = {}
        for curr_clone in range(num_clones):
            curr_rate = rates[curr_clone]
            ids = self.extract_clone_cells(self.clone_cell_pop, curr_clone)
            num_curr_cells = len(ids)

            for i in range(timesteps):
                # Simulate growth for each clone separately:
                # each cell divides with probability curr_rate at this time step.
                num_new_cells = random.binomial(1, curr_rate, size=num_curr_cells).sum()
                num_curr_cells += num_new_cells

            new_dict[curr_clone] = num_curr_cells


        ####TODO
        ## new_lineage_mutants chances. This will see if a mutation will change
        ####TODO
        ## Add death + stimulation rate as well as growth
        # Save the new cell clones df and cell af
        clone_counts = [i for i in new_dict.values()]
        self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)
        # Do not make cell_af, will make this only when subsampled.


        # self.new_cell_af = pd.DataFrame()
        # for clone in range(1, self.num_clones+1):
        #     self.new_cell_af = pd.concat((self.new_cell_af, new_dict[clone]),axis=0).reset_index(drop=True)
        return


    def grow_binomial_old(self, p):
        """ (2.1.1)
        :param p: contains time_steps, rates,
                  and growth.mutant_af_sigma_noise
        :type p: dict
        """
        timesteps = p["time_steps"]
        rates = p["rates"]

        sigma = self.params['growth']["mutant_af_sigma_noise"]
        cell_af = self.cell_af

        num_clones = self.num_clones+1
        new_dict = {}
        for curr_clone in range(num_clones):
            curr_rate = rates[curr_clone]
            ids = self.extract_clone_cells(self.clone_cell, curr_clone)
            new_cells = cell_af.iloc[ids].copy()
            for i in range(timesteps):
                # Simulate growth for each clone separately.
                growth_inds = np.flatnonzero(random.binomial(1, curr_rate, size=new_cells.shape[0]))
                #new_ids =
                new_cells = self.simulate_expand_cells_af(new_cells, growth_inds, sigma)

            new_dict[curr_clone] = new_cells
            # Create list of cells
        ####TODO
        ## new_lineage_mutants chances. This will see if a mutation will change
        ####TODO
        ## Add death + stimulation rate as well as growth
        # Save the new cell clones df and cell af
        clone_counts = [i.shape[0] for i in new_dict.values()]
        self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)
        self.new_cell_af = pd.DataFrame(new_dict[0])
        for clone in range(1, self.num_clones+1):
            self.new_cell_af = pd.concat((self.new_cell_af, new_dict[clone]),axis=0).reset_index(drop=True)
        return


    def grow_poisson(self, p):
        # TODO growth of poisson refactor
        return


    def subsample_new(self, to_delete=False):
        """(3) Subsample from new cell population and generate cell_af

        :param to_delete: To remove the cells that grew (which takes up
        a lot of RAM).
        :type to_delete: bool
        """
        p = self.params
        if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
            self.subsample_new_clone_cell = self.new_clone_cell.sample(n=self.params['sequence_subsample'])
        else:
            self.subsample_new_clone_cell = self.new_clone_cell.sample(
                n=self.num_cells)

        #print(f'New cell af, {len(self.subsample_new_clone_cell)}')
        # Generate subsample_new_cell_af
        self.subsample_new_cell_af = self.create_cell_af(clone_df=self.subsample_new_clone_cell,
                                                         mt_dict = self.clone_mt_dict,
                                                         n_cells=len(self.subsample_new_clone_cell),
                                                         n_mt=self.num_mt_positions,
                                                         num_clones=self.num_clones,
                                                         cov_params=p['initialize'],
                                                         hets=
                                                             self.params[
                                                                 'het'],
                                                         het_err=self.params['het_err_rate'],
                                                         coverage=None
                                                         )

        if to_delete:
            self.new_cell_af = None
            self.new_clone_cell = None


    def subsample_new_old(self, to_delete=False):
        """(3) Subsample from new cell population

        :param to_delete: To remove the cells that grew (which takes up
        a lot of RAM).
        :type to_delete: bool
        """
        new_cell_af = self.new_cell_af
        p = self.params
        if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.params['sequence_subsample'])
        else:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.num_cells)

        self.subsample_new_clone_cell = self.new_clone_cell.loc[
            self.subsample_new_cell_af.index]

        if to_delete:
            self.new_cell_af = None
            self.new_clone_cell = None


    def combine_init_growth(self):
        """(4) Add the pre- and post- population of cells into a group.

        :return:
        """
        combined_cell_af = pd.concat((self.cell_af, self.subsample_new_cell_af)).reset_index(drop=True)
        combined_clones = pd.concat(
            (self.clone_cell, self.subsample_new_clone_cell)).reset_index(
            drop=True)

        combined_befaft = np.concatenate((np.zeros(shape=[self.cell_af.shape[0],]), np.ones(shape=[self.subsample_new_cell_af.shape[0]])))
        combined_meta = pd.DataFrame({"pre_post": combined_befaft, "clone": combined_clones})
        #combined_meta = pd.Series(combined_meta, name='After Growth', dtype=int)
        assert(combined_meta.shape[0] == self.cell_af.shape[0]+self.subsample_new_cell_af.shape[0])
        assert (combined_cell_af.shape[0] == self.cell_af.shape[0] +
                self.subsample_new_cell_af.shape[0])
        assert(combined_meta.shape[0] == combined_clones.shape[0])
        assert(combined_cell_af.shape[0] == combined_clones.shape[0])
        self.combined_meta = combined_meta
        self.combined_clones = combined_clones
        self.combined_cell_af = combined_cell_af
        return

    def save(self, f_save=None):
        """
        Args:
            f_save:
        """
        if f_save is None:
            f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')
        f = open(f_save, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()

    @staticmethod
    def expand_to_mgatk(curr_mt_af,mt_ref):
        ref = mt_ref[curr_mt_af.name]
        pos = curr_mt_af.name
        return pd.DataFrame({"Ref":ref, "Pos":pos, "Val":curr_mt_af})

    def test_save_to_mgatk_format(self):
        df = pd.DataFrame( [[10,0,1,3,5], [3,0,5,5,0], [6,2,1,1,0]] , columns=np.arange(0,5))
        mt_ref_dict = {0: "A", 1: "G", 2: "C", 3: "C", 4: "T"}
        mt_ref = pd.DataFrame({"Pos": mt_ref_dict.keys(), "Ref": mt_ref_dict})
        return

    def save_to_mgatk_format(self, mt_ref, out_f):
        """Converts into the proper files needed for mgatk. (i.e variant and
        coverage files)

        :return:
        """
        cell_af = self.subsample_new_cell_af
        chars = ["A", "G", "C", "T"]
        def alt_generate(x):
            curr = chars.copy()
            curr.remove(x["Ref"])
            return np.random.choice(curr)
        alt_ref = mt_ref.apply(alt_generate, axis=1)

        # First use the AF and choose an alternative allele
        df_stack = cell_af.stack().reset_index().rename(
            {"level_0": "Cell", "level_1": "MT_pos", 0: "Coverage"},
            axis=1)
        df_stack["Nucleotide"] = df_stack["MT_pos"].apply(
            lambda x: alt_ref[x])

        # Add on the reference allele
        df_stack_ref = cell_af.stack().reset_index().rename(
            {"level_0": "Cell", "level_1": "MT_pos", 0: "Coverage"},
            axis=1)
        df_stack_ref["Coverage"] = 1 - df_stack_ref["Coverage"]
        # Assumes mt_ref is indexed by MT position with a "Ref" column
        df_stack_ref["Nucleotide"] = df_stack_ref["MT_pos"].apply(
            lambda x: mt_ref.loc[x, "Ref"])

        # Save the NTs.
        # For concordance, split the coverage in two
        df_stack = pd.concat((df_stack, df_stack_ref))
        for ind, val in df_stack.groupby("Nucleotide"):
            # Drop the 0s
            curr = val[val["Coverage"]>0]
            # Save file
            curr_out_f = out_f + "_" + ind + ".txt"
            curr.to_csv(curr_out_f)

        # Save the coverage.
        coverage = self.cells_mt_coverage
        if type(coverage) != int:
            coverage_stack = pd.DataFrame(coverage).stack().reset_index().rename(
                {"level_0": "Cell", "level_1": "MT Position", 0: "Coverage"},
                axis=1)
        else:
            # Broadcast the constant coverage across the cell-by-position shape first
            coverage_stack = pd.DataFrame(
                coverage * np.ones(shape=cell_af.shape)).stack().reset_index().rename(
                {"level_0": "Cell", "level_1": "MT Position", 0: "Coverage"},
                axis=1)
        curr_out_f = out_f + "_" + "coverage.txt"
        coverage_stack.to_csv(curr_out_f)
        return

    def load(self):
        filename = self.params['filename']
        f = open(filename, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)

    def compare_before_after(self):
        """Creates a df that contains information on the number of cells from
        each clone before as well as after. :return: df.at[ind, "Dominant
        Before"] = (full_sim.clone_cell == 1).sum() df.at[ind, "Dominant After"]
        = (full_sim.subsample_new_clone_cell == 1).sum()
        """

        return

    def cluster_compare_before_after(self):
        """Compares the performance of clustering on grouping the same clones
        together. :return:
        """
        return


def main():
    return


if "__name__" == "__main__":
    main()

3a7261e5e43e9daac35bf6530bcdd0e64d2476bd

Add the MT clone map so it can contain multiple mutants in lineages

Add the MT clone map so it can contain multiple mutants in lineages

# # TODO Add the MT clone map so it can contain multiple mutants in lineages

# from numpy import random
# import os
# import pandas as pd
# import pickle
# from src.simulations.utils.config import read_config_file, write_config_file
# from src.simulations.utils.config import check_required
import numpy as np
import pymc3 as pm
import matplotlib.pyplot as plt

num_cells = 10000
num_mt_positions = 10
clone_dist = [0.10,0.01,.89]
hets = [0.2,0.3] # len(hets) == len(clone_dist)-1
avg_cov = 50
het_err_rate = 0.1


df = np.concatenate((np.random.binomial(10,0.3,(100,4)),
                       np.random.binomial(10,0.6,(90,4))))
clone_id = np.concatenate((np.zeros([100,]), np.ones([90,]))).astype(int)
mt_id = [0,1,2,3]

with pm.Model() as model:

    clone_ids = pm.Multinomial('clone_ids', n=10000, p=clone_dist)

    beta = pm.Beta('beta', alpha=2,beta=2, shape=2)
    #p = pm.Bernoulli('p', 1, beta, shape=2)
    #p = pm.Binomial('p', 1, beta)
    #q = pm.Binomial('q', 1, beta)

    s = pm.Binomial('s', 10, beta[clone_id], observed=df)

    #s = pm.Binomial('s', 10, p, observed=df[:30,0])
    #t = pm.Binomial('t', 10, q, observed=df[30:, 0])

    #s = pm.Binomial('s', 10, beta, shape=(30,4), observed=df[:30])
    #t = pm.Binomial('t', 10, beta, shape=(25, 4), observed=df[30:])

    #vec = pm.math.concatenate((s, t), axis=0)

    # data = pm.Data("data", df)
    # u = pm.Normal('u', vec, observed=data)
    #u = pm.Deterministic('u', vec)

    trace = pm.sample(draws=8000, init='adapt_diag')

print(pm.summary(trace))
dot = pm.model_to_graphviz(model)
dot.render('simulation_pymc.gv')
pm.plot_trace(trace)
plt.savefig('simulation_trace.png')
print('here')
#
# with pm.Model() as model:
#     clone_counts = pm.Multinomial(num_cells, clone_dist)
#     num_clones = len(clone_counts) - 1
#
#     clone_cell = -1 * np.ones(shape=[num_cells, ])
#
#     clone_cell[:clone_counts[0]] = 0
#
#     for ind, val in enumerate(clone_counts[1:]):
#         start = clone_counts[:ind + 1].sum()
#         end = clone_counts[:ind + 1].sum() + val
#         # starts at sum(clone_counts[i-1]) ends at clone_counts[i].sum()
#         clone_cell[start:end] = ind + 1
#
#     c = pm.Poisson('cov', avg_cov, shape=[num_cells, num_mt_positions])
#
#     clone_mt_dict = dict()
#     for i in range(1, num_clones + 1):
#         clone_mt_dict[i] = i
#
#     cell_af = np.zeros([num_cells, num_mt_positions])
#     for ind in range(num_clones):
#         # Generate AF: (clone_df ==  ind).sum()
#         n_dom_cells = clone_counts[ind]
#         het = hets[ind]
#
#         curr_mt = clone_mt_dict[ind]
#
#         af_i = pm.Binomial('af', avg_cov, het, shape=n_dom_cells)
#         af_j = pm.Binomial('het af', avg_cov, het_err_rate, shape=num_cells - n_dom_cells) # / c
#
#
#         # Update the dom_cells and non_dom for the current MT
#         cell_af[np.flatnonzero(clone_df == ind), curr_mt] = af_i
#         cell_af[np.flatnonzero(clone_df != ind), curr_mt] = af_j
#
#     cell_af = pm.Deterministic(y)
#
# pm.model_to_graphviz(model)
#
#
# def init_cell_af(self):
#     """1C. Initialize the cell-by-mtPos af dataframe. Unless a clone:mt dict was
#     provided, the first N MT positions will be the clone AFs. Creates
#     self.clone_mt_dict and self.cell_af
#     """
#     clone_df = self.clone_cell
#     # Output
#     cell_af = pd.DataFrame(np.zeros(shape=[n_cells, n_mt]))
#
#         # Each clone points to a mt position
#         self.clone_mt_dict = dict()
#         for i in range(1, num_clones + 1):
#             self.clone_mt_dict[i] = i
#
#     # TODO Add the MT clone map so it can contain multiple mutants in lineages
#
#     # If there is a heteroplasmy table in params, it is list of mutant heteroplasmy AFs.
#     # If not, will randomly draw based on number of clones
#     if type(hets) == list:
#         assert (len(hets) == num_clones)
#
#         ## Loop through each clone,
#         ## Generate the AF for the clone and non-clones using coverage for each cell
#         ## Fill in cell_by_af for that position.
#         for ind in range(1, num_clones + 1):
#             # Generate AF: (clone_df ==  ind).sum()
#             n_dom_cells = (clone_df == ind).sum()
#             het = hets[ind - 1]
#
#             curr_mt = self.clone_mt_dict[ind]
#
#             if p['coverage']['type'] == 'constant':
#                 c = p['coverage']['cov_constant']
#
#                 af_i = random.binomial(c, het, n_dom_cells) / c
#                 af_j = random.binomial(c, q, n_cells - n_dom_cells) / c
#
#                 # Update the dom_cells and non_dom for the current MT
#                 cell_af.loc[
#                     np.flatnonzero(clone_df == ind), curr_mt] = af_i
#                 cell_af.loc[
#                     np.flatnonzero(clone_df != ind), curr_mt] = af_j
#
#             # Each cell and position has it's own coverage value, so need to update each
#             else:
#                 c = self.cells_mt_coverage
#                 # Get the cells coverage for the mt position
#                 curr_mt_cov = c[:, curr_mt]
#
#                 # Get cell indicies for the clones and nonclones
#                 curr_clone_inds = np.flatnonzero(clone_df == ind)
#                 curr_nonclone_inds = np.flatnonzero(clone_df != ind)
#                 for cell in curr_clone_inds:
#                     # Get one value for curr_mt and cell based on coverage
#                     cell_af.loc[cell, curr_mt] = random.binomial(
#                         curr_mt_cov[cell], het)
#                 for cell in curr_nonclone_inds:
#                     cell_af.loc[cell, curr_mt] = random.binomial(
#                         curr_mt_cov[cell],
#                         q)  # Loop through each coverage  # for c in n_dom_cells:
#
#
# class Simulation:
#     """Lineage tracing simulation of one sample
#
#     Will initialize cells based on their parameters and grow as well. This
#     should be a flexible framework, to add different ways to initialize, grow,
#     and metrics to have. Additionally can cluster these results.
#
#     :ivar params
#     :type params: dict
#     """
#
#     def __init__(self, params_f):
#         """
#         :param params_f: Parameter yaml file for the specifications
#         :type params_f: yaml file or dict
#         """
#         if isinstance(params_f, str):
#             params = read_config_file(params_f)
#         else:
#             params = params_f
#
#         self.params = params
#         check_required(params, ['initialize', 'num_cells', 'num_mt_positions', 'prefix'])
#         self.prefix = params['prefix']
#         self.num_mt_positions = params['num_mt_positions']
#         self.num_cells = params['num_cells']
#         if not os.path.exists(params['local_outdir']):
#             os.mkdir(params['local_outdir'])
#
#
#     def initialize(self):
#         """ (1) Pre-growth cell population is instantiated.
#
#         Creates a clone-MT dictionary, cell coverage matrix
#         (or an int, depending on parameters), and cell-AF matrix.
#         :return:
#         """
#         self.init_clone_dict()
#         self.init_cell_coverage()
#         self.init_cell_af()
#         #self.init_clone_mt()
#
#     #should be external method
#     def grow(self):
#         """ (2) Growth of cells is run."""
#         p = self.params
#         type = p["growth"]["type"]
#         if  type == "poisson":
#             self.grow_poisson(p['growth']['poisson'])
#         elif type == "binomial":
#             self.grow_binomial(p['growth']['binomial'])
#         return
#
#     # Static Method
#     @staticmethod
#     def clone_counts_to_cell_series(clone_counts):
#         """ Generates new cell IDs based on cluster count iterable
#         :param clone_counts: Each i'th element is the number of cells in
#         cluster i.
#         :type clone_counts: iterable
#         :return each index name is a cell ID and each value is which cluster
#         the cell belongs too.
#         :rtype pd.Series
#         """
#         clone_counts = np.array(clone_counts)
#         num_cells = clone_counts.sum()
#         clone_cell = -1 * np.ones(shape=[num_cells, ])
#
#         clone_cell[:clone_counts[0]] = 0
#         for ind, val in enumerate(clone_counts[1:]):
#             start = clone_counts[:ind + 1].sum()
#             end = clone_counts[:ind + 1].sum() + val
#             # starts at sum(clone_counts[i-1]) ends at clone_counts[i].sum()
#             clone_cell[start:end] = ind + 1
#
#         clone_cell = pd.Series(clone_cell, dtype=int)
#         return clone_cell
#
#     def init_clone_dict(self):
#         """1A
#         """
#
#         ### Add in potential to overwrite the values
#         # Gets the clone dictionary. Should also have clone to mt dict.
#         clones = self.params['initialize']['clone_sizes']
#         num_cells = self.num_cells
#
#         # Option 1: List of fraction of size of each clone. 0s are nonclone size, listed first
#         if type(clones) == list:
#             #clone_cell = pd.Series(index=range(num_cells))
#             clone_counts = np.random.multinomial(num_cells, clones)
#             clone_cell  = self.clone_counts_to_cell_series(clone_counts)
#             self.clone_cell = clone_cell
#         # Option 2: 1 clone. ID'd as 1
#         elif type(clones) == int: #One number for dominant clone. the others are not.
#             clone_cell = np.zeros(shape=[num_cells,])
#             clone_cell[:num_cells] = 1
#             clone_cell = clone_cell[::-1]
#             clone_cell =  pd.Series(clone_cell, dtype=int)
#             self.clone_cell = clone_cell
#
#         # Option 3 To ADD, beta binomial and more complex distributions
#
#         self.num_clones =  len(set(clone_cell.values))-1 # Remove the non-clone
#         return clone_cell
#
#
#     def init_cell_coverage(self):
#         """1B
#
#         There are different modes to the coverage, either a constant or
#         through a distribution.
#         """
#         p = self.params['initialize']['coverage']
#         type = p['type']
#
#         num_cells = self.num_cells
#         num_pos = self.num_mt_positions
#         c = np.zeros([num_cells, num_pos])
#
#         if type == 'constant':
#             c[:, :] = p['cov_constant']
#         elif type == "poisson":
#             # Get the number of coverage per cell based on poisson (should be reads)
#             mu_cov_per_cell = p['mu_cov_per_cell']
#             num_reads_per_cell = random.poisson(lam=mu_cov_per_cell,
#                                                 size=num_cells)
#
#             # Number of reads at each position, based on the average for each cell
#             for i in num_cells:
#                 c[i, :] = random.poisson(num_reads_per_cell[i],
#                                          size=num_pos)
#         self.cells_mt_coverage = c
#         return c
#
#
#
#         #####
#         # TODO
#         # Add noise to the other non-lineage positions
#         #####
#         self.cell_af = cell_af
#         return
#
#
#     def init_clone_mt(self):
#         p = self.params
#         if p["initialize"]['type'] == 'growth':
#             ## TODO
#             # Create a phylogeny and then get the averages of the mutants
#             self.average_clone_mt()
#         # If not growth, should aready be there.
#         return
#
#     def average_clone_mt(self):
#         return
#
#     def extract_clone_cells(self, clone_id):
#         """
#         Args:
#             clone_id:
#         """
#         ids = np.flatnonzero(self.clone_cell == clone_id)
#         return ids
#
#     def simulate_expand_cells_af(self, af, growth_inds, sigma):
#         """Given a cell-by-af vector, expand the AF.
#
#         Expanded AF occurs by duplicating cells that grew based on the
#         growth_inds vector. It will add standard error to each af based on sigma.
#
#         Args:
#             af: cell-by-position AF to expand.
#             growth_inds: indices of AF (cells) to copy.
#             sigma: variance of the noise added to each child's AF.
#         """
#
#         new_af = af.iloc[growth_inds].copy() + random.normal(0, sigma, size=af.iloc[growth_inds].shape)
#         new_af.index = np.arange(af.index[-1]+1, af.index[-1]+1+new_af.shape[0])
#         new_af = pd.concat((af,new_af), axis=0)
#         #new_af = np.append(af, np.concatenate(new_af))
#         return new_af
#
#     def grow_binomial(self, p):
#         """ (2.1)
#         Args:
#             p:
#         """
#         timesteps = p["time_steps"]
#         rates = p["rates"]
#
#         sigma = self.params['growth']["mutant_af_sigma_noise"]
#         cell_af = self.cell_af
#         clone_mt_dict = self.clone_mt_dict
#
#         num_clones = self.num_clones+1
#         new_dict = {}
#         for curr_clone in range(num_clones):
#             curr_rate = rates[curr_clone]
#             ids = self.extract_clone_cells(curr_clone)
#             new_cells = cell_af.loc[ids].copy()
#             for i in range(timesteps):
#                 # Simulate growth for each clone separately.
#                 growth_inds = np.flatnonzero(random.binomial(1, curr_rate, size=new_cells.shape[0]))
#                 #new_ids =
#                 new_cells = self.simulate_expand_cells_af(new_cells, growth_inds, sigma)
#
#             new_dict[curr_clone] = new_cells
#             # Create list of cells
#
#         ####TODO
#         ## new_lineage_mutants chances. This will see if a mutation will change
#
#
#         ####TODO
#         ## Add death + stimulation rate as well as growth
#         # Save the new cell clones df and cell af
#         clone_counts = [i.shape[0] for i in new_dict.values()]
#         self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)
#
#         self.new_cell_af = pd.DataFrame(new_dict[0])
#         for clone in range(1, self.num_clones+1):
#             self.new_cell_af = pd.concat((self.new_cell_af, new_dict[clone]),axis=0).reset_index(drop=True)
#         return
#
#
#     def grow_poisson(self):
#         # TODO growth of poisson refactor
#         return
#
#
#     def subsample_new(self, to_delete=False):
#         """(3) Subsample from new cell population
#
#         :param to_delete: To remove the cells that grew (which takes up
#         a lot of RAM).
#         :type to_delete: bool
#         """
#         new_cell_af = self.new_cell_af
#         p = self.params
#         if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
#             self.subsample_new_cell_af = new_cell_af.sample(n=self.params['sequence_subsample'])
#         else:
#             self.subsample_new_cell_af = new_cell_af.sample(n=self.num_cells)
#
#         self.subsample_new_clone_cell = self.new_clone_cell.loc[
#             self.subsample_new_cell_af.index]
#
#         if to_delete:
#             self.new_cell_af = None
#             self.new_clone_cell = None
#
#
#     def combine_init_growth(self):
#         """(4) Add the pre- and post- population of cells into a group.
#
#         :return:
#         """
#         combined_cell_af = self.cell_af.append(self.subsample_new_cell_af).reset_index(drop=True)
#         combined_clones = pd.concat(
#             (self.clone_cell, self.subsample_new_clone_cell)).reset_index(
#             drop=True)
#
#
#         combined_befaft = np.concatenate((np.zeros(shape=[self.cell_af.shape[0],]), np.ones(shape=[self.subsample_new_cell_af.shape[0]])))
#         combined_meta = pd.DataFrame({"pre_post": combined_befaft, "clone": combined_clones})
#         #combined_meta = pd.Series(combined_meta, name='After Growth', dtype=int)
#         assert(combined_meta.shape[0] == self.cell_af.shape[0]+self.subsample_new_cell_af.shape[0])
#         assert (combined_cell_af.shape[0] == self.cell_af.shape[0] +
#                 self.subsample_new_cell_af.shape[0])
#         assert(combined_meta.shape[0] == combined_clones.shape[0])
#         assert(combined_cell_af.shape[0] == combined_clones.shape[0])
#         self.combined_meta = combined_meta
#         self.combined_clones = combined_clones
#         self.combined_cell_af = combined_cell_af
#         return
#
#     def save(self, f_save=None):
#         """
#         Args:
#             f_save:
#         """
#         if f_save is None:
#             f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')
#         f = open(f_save, 'wb')
#         pickle.dump(self.__dict__, f, 2)
#         f.close()
#
#     @staticmethod
#     def expand_to_mgatk(curr_mt_af,mt_ref):
#         ref = mt_ref[curr_mt_af.name]
#         pos = curr_mt_af.name
#         return pd.DataFrame({"Ref":ref, "Pos":pos, "Val":curr_mt_af})
#
#     def test_save_to_mgatk_format(self):
#         df = pd.DataFrame( [[10,0,1,3,5], [3,0,5,5,0], [6,2,1,1,0]] , columns=np.arange(0,5))
#         mt_ref_dict = {0: "A", 1: "G", 2: "C", 3: "C", 4: "T"}
#         mt_ref = pd.DataFrame({"Pos": mt_ref_dict.keys(), "Ref": mt_ref_dict})
#         return
#
#     def save_to_mgatk_format(self, mt_ref, out_f):
#         """Converts into the proper files needed for mgatk. (i.e variant and
#         coverage files)
#
#         :return:
#         """
#         cell_af = self.subsample_new_cell_af
#         chars = ["A", "G", "C", "T"]
#         def alt_generate(x):
#             curr = chars.copy()
#             curr.remove(x["Ref"])
#             return np.random.choice(curr)
#         alt_ref = mt_ref.apply(alt_generate, axis=1)
#
#         # First use the AF and choose an alternative allele
#         df_stack = cell_af.stack().reset_index().rename(
#             {"level_0": "Cell", "level_1": "MT_pos", 0: "Coverage"},
#             axis=1)
#         df_stack["Nucleotide"] = df_stack["MT_pos"].apply(
#             lambda x: alt_ref[x])
#
#         # Add on the reference allele
#         df_stack_ref = cell_af.stack().reset_index().rename(
#             {"level_0": "Cell", "level_1": "MT_pos", 0: "Coverage"},
#             axis=1)
#         df_stack_ref["Coverage"] = 1-df_stack_ref["Coverage"]
#         df_stack_ref["Nucleotide"] = df_stack_ref["MT_pos"].apply(
#             lambda x: mt_ref.loc[x, "Ref"])
#
#         df_stack = pd.concat((df_stack, df_stack_ref))
#         for ind, val in df_stack.groupby("Nucleotide"):
#             # Drop the 0s
#             curr = val[val["Coverage"]>0]
#             # Save file
#             curr_out_f = out_f + "_" + ind + ".txt"
#             curr.to_csv(curr_out_f)
#
#         # Save the coverage.
#         coverage = self.cells_mt_coverage
#         if type(coverage) != int:
#             coverage_stack = pd.DataFrame(coverage).stack().reset_index().rename(
#                 {"level_0": "Cell", "level_1": "MT Position", 0: "Coverage"},
#                 axis=1)
#         else:
#             coverage_stack = pd.DataFrame(
#                 self.cells_mt_coverage * np.ones(shape=cell_af.shape)).stack().reset_index().rename(
#                 {"level_0": "Cell", "level_1": "MT Position", 0: "Coverage"},
#                 axis=1)
#         curr_out_f = out_f + "_" + "coverage.txt"
#         coverage_stack.to_csv(curr_out_f)
#         return
#
#     def load(self):
#         filename = self.params['filename']
#         f = open(filename, 'rb')
#         tmp_dict = pickle.load(f)
#         f.close()
#         self.__dict__.update(tmp_dict)
#
#     def compare_before_after(self):
#         """Creates a df that contains information on the number of cells from
#         each clone before as well as after growth.
#
#         e.g. df.at[ind, "Dominant Before"] = (full_sim.clone_cell == 1).sum()
#              df.at[ind, "Dominant After"] = (full_sim.subsample_new_clone_cell == 1).sum()
#         """
#
#         return
#
#     def cluster_compare_before_after(self):
#         """Compares the performance of clustering on grouping the same clones
#         together.
#         """
#         return
#
#
# def main():
#     return


# if __name__ == "__main__":
#     main()

9e6b706ec5bc1e7d33e9ac1bd2d08ce92e874819

Implement Cassiopeia

[ ] Test implementation
[ ] Run on your simulated data (see the conversion sketch below).
[ ] Test their simulation framework
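
Before Cassiopeia can be run on the simulated data, the cell-by-MT-position AF matrix needs to be turned into a character matrix. A minimal sketch of that conversion is below; the helper name, the 0.1 heteroplasmy threshold, and the cell_/mt_ label prefixes are illustrative assumptions, and the exact Cassiopeia solver call is deliberately left out rather than guessed at.

import pandas as pd

def af_to_character_matrix(cell_af: pd.DataFrame, het_threshold: float = 0.1) -> pd.DataFrame:
    """Binarize a cell x MT-position AF matrix into a 0/1 character matrix."""
    cm = (cell_af > het_threshold).astype(int)
    cm.index = [f"cell_{i}" for i in cm.index]      # cells become character-matrix rows
    cm.columns = [f"mt_{c}" for c in cm.columns]    # MT positions become characters
    return cm

# e.g. cm = af_to_character_matrix(sim.combined_cell_af)
# cm can then be handed to Cassiopeia or to any other tree builder that accepts
# a binary character matrix.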

growth of poisson refactor

growth of poisson refactor

# TODO growth of poisson refactor

        return


    def grow_poisson(self):
        # TODO growth of poisson refactor
        return


    def subsample_new(self, to_delete=False):
        new_cell_af = self.new_cell_af
        p = self.params
        if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.params['sequence_subsample'])
        else:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.num_cells)

        self.subsample_new_cell_clone = self.new_cell_clone.loc[
            self.subsample_new_cell_af.index]

        if to_delete:
            self.new_cell_af = None
            self.new_cell_clone = None


    def save(self, f_save=None):
        if f_save is None:
            f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')
        f = open(f_save, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()


b3be2233cec362031f17bb9cd06e4295397729d6

Add the MT clone map so it can contain multiple mutants in lineages

Add the MT clone map so it can contain multiple mutants in lineages

# TODO Add the MT clone map so it can contain multiple mutants in lineages

            for i in range(1,num_clones+1):
                self.clone_mt_dict[i] = i

        # TODO Add the MT clone map so it can contain multiple mutants in lineages

        # If there is a heteroplasmy table in params, it is list of mutant heteroplasmy AFs.
        # If not, will randomly draw based on number of clones
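
A minimal sketch of what a multi-mutant clone map could look like, replacing the one-position-per-clone loop above. The muts_per_clone argument, the non-overlapping draw, and the standalone-function form are assumptions for illustration.

import numpy as np

def init_clone_mt_dict(num_clones, num_mt_positions, muts_per_clone=2, seed=None):
    """Map each clone id to a list of lineage-defining MT positions."""
    rng = np.random.default_rng(seed)
    # Draw distinct positions for all clones at once so lineages do not share variants.
    positions = rng.choice(num_mt_positions, size=num_clones * muts_per_clone, replace=False)
    clone_mt_dict = {}
    for i in range(1, num_clones + 1):
        start = (i - 1) * muts_per_clone
        clone_mt_dict[i] = list(positions[start:start + muts_per_clone])
    return clone_mt_dict

# init_clone_mt_dict(3, 100, muts_per_clone=2)
# -> e.g. {1: [12, 87], 2: [3, 45], 3: [66, 9]}   (positions will vary)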

19491ec3d3cce30737dfb5db632bd1f9d49d3871

Add the MT clone map that can contain multiple mutants in lineages

Add the MT clone map that can contain multiple mutants in lineages

# TODO Add the MT clone map that can contain multiple mutants in lineages

            for i in range(1,num_clones+1):
                self.clone_mt_dict[i] = i
        # TODO Add the MT clone map that can contain multiple mutants in lineages


        # If there is a heteroplasmy table in params, it is list of mutant heteroplasmy AFs.
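
For the heteroplasmy handling described in the comment above (an explicit per-clone AF list in params, or a random draw when none is given), a minimal sketch of the fallback might look like the following; the "het" key name and the uniform (0.2, 0.8) range are assumptions.

import numpy as np

def init_het_per_clone(params, num_clones, seed=None):
    """One mutant heteroplasmy AF per clone: use the supplied table, else draw."""
    rng = np.random.default_rng(seed)
    if params.get("het") is not None:
        hets = list(params["het"])
        assert len(hets) == num_clones, "need one AF per clone"
    else:
        # No heteroplasmy table given: draw per-clone AFs at random.
        hets = rng.uniform(0.2, 0.8, size=num_clones).tolist()
    return hets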

4f38161d12f40a52e61566e717a72784ab4b2b7d

This should be in mro/common for general use

This should be in mro/common for general use

# TODO: This should be in mro/common for general use
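
A minimal sketch of the intended refactor, for reference while reading the pipeline listing below: the MERGE_SUMMARY_METRICS stage this TODO is attached to (defined further down) would move into a shared file and be pulled in with the same @include mechanism used throughout the listing. The mro/common/_common_stages.mro path is an assumption; the stage body is copied verbatim from the definition below.

# mro/common/_common_stages.mro  (hypothetical shared location)
stage MERGE_SUMMARY_METRICS(
    in  json[] summary_jsons,
    out json   merged_summary,
    src py     "stages/processing/cell_calling/merge_summary_metrics",
)

# Each pipeline .mro that needs the stage would then use:
# @include "common/_common_stages.mro"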

#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype json;
filetype bam;
#
# @include "_sort_and_mark_dups_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype bam;
filetype bam.bai;
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype json;
filetype csv;
#
# @include "_peak_caller_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype bedgraph;
filetype pickle;
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype bed;
filetype json;
#
# @include "_basic_sc_atac_counter_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype csv;
filetype json;
filetype bed;
filetype pickle;
filetype h5;
#
# @include "_produce_cell_barcodes_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype csv;
filetype json;
filetype bed;
filetype pickle;
filetype h5;
filetype npy.gz;
#
# @include "_sc_atac_metric_collector_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype bed;
filetype bam;
filetype csv;
filetype json;
filetype h5;
filetype txt;
filetype pickle;
#
# @include "_peak_annotator_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype bed;
filetype tsv;
filetype h5;
filetype gz;
filetype pickle;
#
# @include "_sc_atac_analyzer_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv;
filetype h5;
filetype pickle;
filetype gz;
filetype bed;
filetype csv;
#
# @include "_sc_atac_reporter_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype json;
filetype html;
filetype csv;
filetype h5;
filetype bam;
#
# @include "_atac_cloupe_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype cloupe;
filetype csv;
filetype json;
filetype h5;
filetype bed;
filetype tsv.gz.tbi;
#
# @include "_preflight_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype csv;
filetype bed;
filetype tsv.gz;
filetype tsv.gz.tbi;

#
# @include "_aligner_stages.mro"
#

# SETUP_CHUNKS chunks up the input fastq data into sets of matched R1, R2, SI, and BC fastq files.
# input_mode specifies how FASTQs were generated. There are two modes:
#
# 1. "BCL_PROCESSOR"
#
# FASTQs produced by the 10X BCL_PROCESSOR pipeline. This mode assumes the FASTQ files obey the internal
# naming conventions and the reads have been interleaved into RA FASTQ files.
#
# 2. "ILMN_BCL2FASTQ"
#
# FASTQs produced directly by Illumina BCL2FASTQ v1.8.4. For this mode, BCL2FASTQ must be configured to emit the
# index2 read, rather than using it for dual-index demultiplexing:
#
# configureBclToFastq.pl --no-eamss --use-bases-mask=Y100,I8,Y14,Y100 --input-dir=<basecalls_dir> \
#     --output-dir=<output_dir> --sample-sheet=<sample_sheet.csv>
#
# The sample sheet must be formatted as per the BCL2FASTQ documentation (10 column csv), and must contain a row for
# each sample index used. The sequencer must have been run in dual index mode, with the second index read (used to
# read the 10X barcode) emitted as the R2 output file. The --use-bases-mask argument should be set to the read
# length used.
stage SETUP_CHUNKS(
    in  string   sample_id        "id of the sample",
    in  map[]    sample_def       "list of dictionary specifying input data",
    in  string   input_mode       "configuration of the input fastqs",
    in  map      downsample       "map specifies either subsample_rate (float) or gigabases (int)",
    out map[]    chunks           "map has barcode, barcode_reverse_complement, sample_index, read1, read2, gem_group, and read_group fields",
    out string[] read_groups      "list of strings representing read groups",
    out json     downsample_info  "info about downsampling result",
    src py       "stages/processing/setup_chunks",
)

# Trims adapter sequences from reads and massages fastq output into a fixed format (interleaved R1 file, etc.)
stage TRIM_READS(
    in  map[]  chunks,
    in  string barcode_whitelist,
    in  int    max_read_num,
    in  map    trim_def,
    in  map    adapters,
    out map[]  chunks,
    out json   bc_counts,
    out json   lot_info,
    out json   read_counts,
    src py     "stages/processing/trim_reads",
) split (
    in  map    chunk,
) using (
    volatile = strict,
)

# Aligns the reads to the input reference, producing chunked bam files
stage ALIGN_READS(
    in  map[]  chunks,
    in  string aligner,
    in  string aligner_method,
    in  string reference_path,
    in  string read_group_sample,
    in  int    num_threads,
    out bam[],
    src py     "stages/processing/align_reads",
) split (
    in  map    chunk,
) using (
    # N.B. No index files are generated for the bam
    volatile = strict,
)

#
# @include "_aligner.mro"
#

# Takes input fastqs and chunks them, trims them, and aligns the trimmed reads to a reference
pipeline _ALIGNER(
    in  string sample_id,
    in  string fastq_mode         "configuration of the input fastqs",
    in  map[]  sample_def,
    in  string reference_path     "this is the reference_path",
    in  string barcode_whitelist  "name of barcode whitelist file",
    in  map    trim_def,
    in  map    adapters,
    in  string read_group_sample  "sample header for BAM file",
    in  map    downsample,
    out bam[]  align,
    out map[]  chunks,
    out json   bc_counts,
    out json   lot_info           "gelbead lot detected",
    out json   read_counts        "total # of read pairs before and after adapter trimming",
    out json   downsample_info    "info on downsampling",
)
{
    call SETUP_CHUNKS(
        sample_id  = self.sample_id,
        input_mode = self.fastq_mode,
        sample_def = self.sample_def,
        downsample = self.downsample,
    ) using (
        volatile = true,
    )

    call TRIM_READS(
        chunks            = SETUP_CHUNKS.chunks,
        max_read_num      = 5000000,
        trim_def          = self.trim_def,
        adapters          = self.adapters,
        barcode_whitelist = self.barcode_whitelist,
    ) using (
        volatile = true,
    )

    call ALIGN_READS(
        chunks            = TRIM_READS.chunks,
        aligner           = "bwa",
        aligner_method    = "MEM",
        reference_path    = self.reference_path,
        read_group_sample = self.read_group_sample,
        num_threads       = 4,
    ) using (
        volatile = true,
    )

    return (
        align           = ALIGN_READS,
        chunks          = TRIM_READS.chunks,
        bc_counts       = TRIM_READS.bc_counts,
        lot_info        = TRIM_READS.lot_info,
        read_counts     = TRIM_READS.read_counts,
        downsample_info = SETUP_CHUNKS.downsample_info,
    )
}

#
# @include "_sort_and_mark_dups_stages.mro"
#

# Attaches raw and corrected barcode sequences to the aligned reads
stage ATTACH_BCS(
    in  string barcode_whitelist,
    in  bam[]  align,
    in  map[]  chunks,
    in  bool   paired_end,
    in  bool   exclude_non_bc_reads,
    in  float  bc_confidence_threshold,
    in  json   bc_counts,
    out bam[]  output,
    out int    perfect_read_count,
    src py     "stages/processing/attach_bcs",
) split (
    in  bam    align_chunk,
    in  map    chunk,
) using (
    # N.B. No index files are generated for the bam
    volatile = strict,
)

stage SORT_READS_BY_POS(
    in  bam[] input,
    out bam   tagsorted_bam,
    src py    "stages/processing/sort_reads_by_pos",
) split (
    in  bam   chunk_input,
) using (
    # N.B. No index files are generated for the bam
    volatile = strict,
)

# Marks duplicates in the reads using barcodes and fragment alignments to detect PCR and optical/diffusion duplicates
stage MARK_DUPLICATES(
    in  bam        input,
    in  string     reference_path,
    in  json       raw_barcode_counts,
    in  string     barcode_whitelist,
    out bam        output,
    out bam.bai    index,
    out csv        singlecell_mapping,
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    src py         "stages/processing/mark_duplicates",
) split (
    in  map        lane_map,
    in  string     chunk_start,
    in  string     chunk_end,
    in  int        chunk_num,
) using (
    # N.B. BAM/BED index files are explicitly bound where used
    volatile = strict,
)

#
# @include "_sort_and_mark_dups.mro"
#

# Attaches barcodes to the aligned reads, marks duplicate reads, and produces a barcode-sorted and position-sorted
# output BAM
pipeline _SORT_AND_MARK_DUPS(
    in  bam[]      align,
    in  map[]      chunks,
    in  string     barcode_whitelist,
    in  json       bc_counts,
    in  string     reference_path,
    out bam        possorted_bam        "bam file sorted by position",
    out bam.bai    possorted_bam_index  "position-sorted bam index",
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    out csv        singlecell_mapping,
    out bam[]      read_paired_bam,
)
{
    call ATTACH_BCS(
        align                   = self.align,
        chunks                  = self.chunks,
        paired_end              = true,
        barcode_whitelist       = self.barcode_whitelist,
        exclude_non_bc_reads    = false,
        bc_confidence_threshold = 0.975,
        bc_counts               = self.bc_counts,
    ) using (
        volatile = true,
    )

    call SORT_READS_BY_POS(
        input = ATTACH_BCS.output,
    ) using (
        volatile = true,
    )

    call MARK_DUPLICATES(
        input              = SORT_READS_BY_POS.tagsorted_bam,
        reference_path     = self.reference_path,
        barcode_whitelist  = self.barcode_whitelist,
        raw_barcode_counts = self.bc_counts,
    ) using (
        volatile = true,
    )

    return (
        possorted_bam       = MARK_DUPLICATES.output,
        possorted_bam_index = MARK_DUPLICATES.index,
        singlecell_mapping  = MARK_DUPLICATES.singlecell_mapping,
        fragments           = MARK_DUPLICATES.fragments,
        fragments_index     = MARK_DUPLICATES.fragments_index,
        read_paired_bam     = ATTACH_BCS.output,
    )
}

#
# @include "_peak_caller_stages.mro"
#

stage COUNT_CUT_SITES(
    in  path       reference_path,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    out bedgraph   cut_sites,
    out pickle     count_dict,
    src py         "stages/processing/count_cut_sites",
) split (
    in  string     contig,
) using (
    # N.B. We explicitly bind the index file
    volatile = strict,
)

stage DETECT_PEAKS(
    in  bedgraph cut_sites,
    in  path     reference_path,
    in  pickle   count_dict,
    out bed      peaks,
    out json     peak_metrics,
    src py       "stages/processing/detect_peaks",
) split (
    in  string   contig,
    in  float[]  params,
    in  float    threshold,
) using (
    mem_gb   = 6,
    # N.B. We explicitly bind the index file
    volatile = strict,
)

#
# @include "_peak_caller.mro"
#

pipeline _PEAK_CALLER(
    in  path       reference_path,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    out bedgraph   cut_sites,
    out bed        peaks,
    out json       peak_metrics,
)
{
    call COUNT_CUT_SITES(
        reference_path  = self.reference_path,
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
    )

    call DETECT_PEAKS(
        reference_path = self.reference_path,
        cut_sites      = COUNT_CUT_SITES.cut_sites,
        count_dict     = COUNT_CUT_SITES.count_dict,
    )

    return (
        cut_sites    = COUNT_CUT_SITES.cut_sites,
        peaks        = DETECT_PEAKS.peaks,
        peak_metrics = DETECT_PEAKS.peak_metrics,
    )
}

#
# @include "_basic_sc_atac_counter_stages.mro"
#

stage GENERATE_PEAK_MATRIX(
    in  string reference_path,
    in  tsv.gz fragments,
    in  bed    peaks,
    out h5     raw_matrix,
    out path   raw_matrix_mex,
    src py     "stages/processing/generate_peak_matrix",
) split (
    in  file   barcodes,
) using (
    mem_gb   = 4,
    # N.B. we don't explicitly need the fragment index
    volatile = strict,
)

stage FILTER_PEAK_MATRIX(
    in  h5   raw_matrix,
    in  int  num_analysis_bcs,
    in  int  random_seed,
    in  csv  cell_barcodes,
    out h5   filtered_matrix,
    out path filtered_matrix_mex,
    src py   "stages/processing/filter_peak_matrix",
) split (
) using (
    volatile = strict,
)

#
# @include "_produce_cell_barcodes_stages.mro"
#

stage REMOVE_LOW_TARGETING_BARCODES(
    in  bed        peaks,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    out json       barcode_counts,
    out json       low_targeting_barcodes,
    out json       low_targeting_summary,
    out json       fragment_lengths,
    out json       covered_bases,
    src py         "stages/processing/cell_calling/remove_low_targeting_barcodes",
) split (
    in  string     contig,
    out pickle     fragment_counts,
    out pickle     targeted_counts,
    out int        peak_coverage,
) using (
    mem_gb   = 4,
    volatile = strict,
)

stage REMOVE_GEL_BEAD_DOUBLET_BARCODES(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    in  json       barcode_counts,
    out json       gel_bead_doublet_barcodes,
    out json       gel_bead_doublet_summary,
    out csv        connect_matrix,
    src py         "stages/processing/cell_calling/remove_gel_bead_doublet_barcodes",
) split (
    in  string     contig,
    in  file       valid_barcodes,
) using (
    mem_gb   = 4,
    volatile = strict,
)

stage REMOVE_BARCODE_MULTIPLETS(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    in  string     barcode_whitelist,
    in  json       barcode_counts,
    out json       barcode_multiplets,
    out json       barcode_multiplets_summary,
    src py         "stages/processing/cell_calling/remove_barcode_multiplets",
) split (
    in  string     contig,
    in  string     gem_group,
    out npy.gz     part_a_linkage_matrix,
    out npy.gz     part_b_linkage_matrix,
) using (
    mem_gb   = 4,
    volatile = strict,
)

stage MERGE_EXCLUDED_BARCODES(
    in  json[] barcode_exclusions,
    out json   excluded_barcodes,
    src py     "stages/processing/cell_calling/merge_excluded_barcodes",
)

stage DETECT_CELL_BARCODES(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     barcode_whitelist,
    in  json       excluded_barcodes,
    in  map        force_cells,
    in  string     reference_path,
    in  bed        peaks,
    out csv        cell_barcodes,
    out csv        singlecell,
    out json       cell_calling_summary,
    src py         "stages/processing/cell_calling/detect_cell_barcodes",
) split (
    in  string     contig,
    out pickle     barcode_counts,
    out pickle     targeted_counts,
    out int        fragment_depth,
) using (
    mem_gb   = 4,
    volatile = strict,
)

# TODO: This should be in mro/common for general use
stage MERGE_SUMMARY_METRICS(
    in  json[] summary_jsons,
    out json   merged_summary,
    src py     "stages/processing/cell_calling/merge_summary_metrics",
)

#
# @include "_produce_cell_barcodes.mro"
#

pipeline _PRODUCE_CELL_BARCODES(
    in  bed        peaks,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    in  string     barcode_whitelist,
    in  map        force_cells,
    out csv        cell_barcodes,
    out csv        singlecell,
    out json       cell_calling_summary,
    out json       excluded_barcodes,
    out json       fragment_lengths,
    out json       covered_bases,
)
{
    call REMOVE_LOW_TARGETING_BARCODES(
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
        peaks           = self.peaks,
        reference_path  = self.reference_path,
    )

    call REMOVE_GEL_BEAD_DOUBLET_BARCODES(
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
        reference_path  = self.reference_path,
        barcode_counts  = REMOVE_LOW_TARGETING_BARCODES.barcode_counts,
    )

    call REMOVE_BARCODE_MULTIPLETS(
        fragments         = self.fragments,
        fragments_index   = self.fragments_index,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        barcode_counts    = REMOVE_LOW_TARGETING_BARCODES.barcode_counts,
    )

    call MERGE_EXCLUDED_BARCODES(
        barcode_exclusions = [
            REMOVE_BARCODE_MULTIPLETS.barcode_multiplets,
            REMOVE_GEL_BEAD_DOUBLET_BARCODES.gel_bead_doublet_barcodes,
            REMOVE_LOW_TARGETING_BARCODES.low_targeting_barcodes,
        ],
    )

    call DETECT_CELL_BARCODES(
        fragments         = self.fragments,
        fragments_index   = self.fragments_index,
        barcode_whitelist = self.barcode_whitelist,
        force_cells       = self.force_cells,
        excluded_barcodes = MERGE_EXCLUDED_BARCODES.excluded_barcodes,
        reference_path    = self.reference_path,
        peaks             = self.peaks,
    )

    call MERGE_SUMMARY_METRICS as MERGE_CELL_METRICS(
        summary_jsons = [
            REMOVE_LOW_TARGETING_BARCODES.low_targeting_summary,
            REMOVE_GEL_BEAD_DOUBLET_BARCODES.gel_bead_doublet_summary,
            REMOVE_BARCODE_MULTIPLETS.barcode_multiplets_summary,
            DETECT_CELL_BARCODES.cell_calling_summary,
        ],
    )

    return (
        cell_barcodes        = DETECT_CELL_BARCODES.cell_barcodes,
        excluded_barcodes    = MERGE_EXCLUDED_BARCODES.excluded_barcodes,
        singlecell           = DETECT_CELL_BARCODES.singlecell,
        cell_calling_summary = MERGE_CELL_METRICS.merged_summary,
        fragment_lengths     = REMOVE_LOW_TARGETING_BARCODES.fragment_lengths,
        covered_bases        = REMOVE_LOW_TARGETING_BARCODES.covered_bases,
    )
}

#
# @include "_basic_sc_atac_counter.mro"
#

pipeline _BASIC_SC_ATAC_COUNTER(
    in  string     sample_id,
    in  string     fastq_mode                   "configuration of the input fastqs",
    in  map[]      sample_def,
    in  string     reference_path               "this is the reference_path",
    in  string     barcode_whitelist            "name of barcode whitelist file",
    in  map        trim_def,
    in  map        adapters,
    in  map        downsample,
    in  map        force_cells,
    out bam        possorted_bam                "bam file sorted by position",
    out bam.bai    possorted_bam_index          "position-sorted bam index",
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    out json       lot_info                     "gelbead lot detected",
    out json       read_counts                  "total # of read pairs before and after adapter trimming",
    out json       downsample_info              "info on downsampling",
    out csv        cell_barcodes,
    out json       excluded_barcodes,
    out json       cell_calling_summary,
    out bed        peaks,
    out bedgraph   cut_sites,
    out csv        singlecell_mapping,
    out csv        singlecell_cells,
    out json       peak_metrics,
    out bam[]      read_paired_bam,
    out h5         raw_peak_bc_matrix,
    out path       raw_peak_bc_matrix_mex,
    out h5         filtered_peak_bc_matrix,
    out path       filtered_peak_bc_matrix_mex,
)
{
    call _ALIGNER(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        read_group_sample = self.sample_id,
        trim_def          = self.trim_def,
        adapters          = self.adapters,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        downsample        = self.downsample,
    )

    call _SORT_AND_MARK_DUPS(
        align             = _ALIGNER.align,
        chunks            = _ALIGNER.chunks,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        bc_counts         = _ALIGNER.bc_counts,
    )

    call _PEAK_CALLER(
        fragments       = _SORT_AND_MARK_DUPS.fragments,
        fragments_index = _SORT_AND_MARK_DUPS.fragments_index,
        reference_path  = self.reference_path,
    )

    call _PRODUCE_CELL_BARCODES(
        fragments         = _SORT_AND_MARK_DUPS.fragments,
        fragments_index   = _SORT_AND_MARK_DUPS.fragments_index,
        peaks             = _PEAK_CALLER.peaks,
        force_cells       = self.force_cells,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
    )

    call GENERATE_PEAK_MATRIX(
        reference_path = self.reference_path,
        fragments      = _SORT_AND_MARK_DUPS.fragments,
        peaks          = _PEAK_CALLER.peaks,
    )

    call FILTER_PEAK_MATRIX(
        num_analysis_bcs = null,
        cell_barcodes    = _PRODUCE_CELL_BARCODES.cell_barcodes,
        raw_matrix       = GENERATE_PEAK_MATRIX.raw_matrix,
        random_seed      = null,
    )

    return (
        possorted_bam               = _SORT_AND_MARK_DUPS.possorted_bam,
        possorted_bam_index         = _SORT_AND_MARK_DUPS.possorted_bam_index,
        singlecell_mapping          = _SORT_AND_MARK_DUPS.singlecell_mapping,
        singlecell_cells            = _PRODUCE_CELL_BARCODES.singlecell,
        lot_info                    = _ALIGNER.lot_info,
        read_counts                 = _ALIGNER.read_counts,
        downsample_info             = _ALIGNER.downsample_info,
        cell_barcodes               = _PRODUCE_CELL_BARCODES.cell_barcodes,
        excluded_barcodes           = _PRODUCE_CELL_BARCODES.excluded_barcodes,
        cell_calling_summary        = _PRODUCE_CELL_BARCODES.cell_calling_summary,
        peak_metrics                = _PEAK_CALLER.peak_metrics,
        cut_sites                   = _PEAK_CALLER.cut_sites,
        peaks                       = _PEAK_CALLER.peaks,
        fragments                   = _SORT_AND_MARK_DUPS.fragments,
        fragments_index             = _SORT_AND_MARK_DUPS.fragments_index,
        read_paired_bam             = _SORT_AND_MARK_DUPS.read_paired_bam,
        raw_peak_bc_matrix          = GENERATE_PEAK_MATRIX.raw_matrix,
        raw_peak_bc_matrix_mex      = GENERATE_PEAK_MATRIX.raw_matrix_mex,
        filtered_peak_bc_matrix     = FILTER_PEAK_MATRIX.filtered_matrix,
        filtered_peak_bc_matrix_mex = FILTER_PEAK_MATRIX.filtered_matrix_mex,
    )
}

#
# @include "_sc_atac_metric_collector_stages.mro"
#

stage ESTIMATE_LIBRARY_COMPLEXITY(
    in  json   sequencing_summary,
    in  tsv.gz fragments,
    in  csv    cell_barcodes,
    out json   bulk_complexity,
    out json   complexity_summary,
    out json   singlecell_complexity,
    src py     "stages/metrics/estimate_library_complexity",
) split (
    in  file   barcodes,
) using (
    mem_gb   = 6,
    volatile = strict,
)

stage GENERATE_SEQUENCING_METRICS(
    in  bam[] input,
    out txt   misc_sm,
    out json  summary,
    src py    "stages/metrics/generate_sequencing_metrics",
) split (
    in  bam   chunk_bam,
) using (
    volatile = strict,
)

stage GENERATE_SINGLECELL_TARGETING(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  bed        peaks,
    in  string     reference_path,
    out csv        singlecell,
    out json       summary,
    out csv        tss_relpos,
    out csv        ctcf_relpos,
    src py         "stages/metrics/generate_singlecell_targeting",
) split (
    in  string     contig,
    out int        read_count,
    out pickle     target_counts_by_barcode,
    out pickle     chunk_tss,
    out pickle     chunk_ctcf,
) using (
    mem_gb   = 6,
    volatile = strict,
)

stage MERGE_SINGLECELL_METRICS(
    in  string reference_path,
    in  csv    singlecell_mapping,
    in  csv    singlecell_targets,
    in  csv    singlecell_cells,
    out csv    singlecell,
    out json   summary,
    src py     "stages/metrics/merge_singlecell_metrics",
) using (
    mem_gb   = 8,
    volatile = strict,
)

stage REPORT_INSERT_SIZES(
    in  tsv.gz fragments,
    in  bool   exclude_non_nuclear,
    in  string reference_path,
    out csv    insert_sizes,
    out json   insert_summary,
    src py     "stages/metrics/report_insert_sizes",
) split (
    in  file   barcode,
    out file   total,
) using (
    volatile = strict,
)

stage REPORT_TSS_CTCF(
    in  csv  tss_relpos,
    in  csv  ctcf_relpos,
    out json summary_metrics,
    src py   "stages/metrics/report_tss_ctcf",
) using (
    volatile = strict,
)

#
# @include "_sc_atac_metric_collector.mro"
#

pipeline _SC_ATAC_METRIC_COLLECTOR(
    in  bam[]      read_paired_bam,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  bed        peaks,
    in  string     reference_path         "this is the reference_path",
    in  csv        cell_barcodes,
    in  csv        singlecell_mapping,
    in  csv        singlecell_cells,
    out json       singlecell_results,
    out csv        singlecell,
    out json       enrichment_results,
    out json       basic_summary,
    out json       insert_summary,
    out csv        insert_sizes,
    out json       bulk_complexity,
    out json       singlecell_complexity,
    out json       complexity_summary,
    out csv        tss_relpos,
    out csv        ctcf_relpos,
)
{
    call GENERATE_SINGLECELL_TARGETING(
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
        peaks           = self.peaks,
        reference_path  = self.reference_path,
    )

    call MERGE_SINGLECELL_METRICS(
        reference_path     = self.reference_path,
        singlecell_mapping = self.singlecell_mapping,
        singlecell_cells   = self.singlecell_cells,
        singlecell_targets = GENERATE_SINGLECELL_TARGETING.singlecell,
    )

    call GENERATE_SEQUENCING_METRICS(
        input = self.read_paired_bam,
    )

    call ESTIMATE_LIBRARY_COMPLEXITY(
        sequencing_summary = GENERATE_SEQUENCING_METRICS.summary,
        fragments          = self.fragments,
        cell_barcodes      = self.cell_barcodes,
    )

    call REPORT_INSERT_SIZES(
        fragments           = self.fragments,
        reference_path      = self.reference_path,
        exclude_non_nuclear = true,
    )

    call REPORT_TSS_CTCF(
        tss_relpos  = GENERATE_SINGLECELL_TARGETING.tss_relpos,
        ctcf_relpos = GENERATE_SINGLECELL_TARGETING.ctcf_relpos,
    )

    return (
        ###
        singlecell            = MERGE_SINGLECELL_METRICS.singlecell,
        singlecell_results    = MERGE_SINGLECELL_METRICS.summary,
        ###
        enrichment_results    = REPORT_TSS_CTCF.summary_metrics,
        basic_summary         = GENERATE_SEQUENCING_METRICS.summary,
        insert_summary        = REPORT_INSERT_SIZES.insert_summary,
        insert_sizes          = REPORT_INSERT_SIZES.insert_sizes,
        bulk_complexity       = ESTIMATE_LIBRARY_COMPLEXITY.bulk_complexity,
        singlecell_complexity = ESTIMATE_LIBRARY_COMPLEXITY.singlecell_complexity,
        complexity_summary    = ESTIMATE_LIBRARY_COMPLEXITY.complexity_summary,
        tss_relpos            = GENERATE_SINGLECELL_TARGETING.tss_relpos,
        ctcf_relpos           = GENERATE_SINGLECELL_TARGETING.ctcf_relpos,
    )
}

#
# @include "_peak_annotator_stages.mro"
#

stage ANNOTATE_PEAKS(
    in  bed    peaks,
    in  string reference_path,
    out tsv    peak_annotation,
    src py     "stages/analysis/annotate_peaks",
) split (
    in  int    chunk_start,
    in  int    chunk_end,
) using (
    mem_gb   = 5,
    volatile = strict,
)

stage COMPUTE_GC_DISTRIBUTION(
    in  bed    peaks,
    in  string reference_path,
    out pickle GCdict,
    src py     "stages/analysis/compute_gc_dist",
) split (
) using (
    volatile = strict,
)

stage SCAN_MOTIFS(
    in  pickle globalGCdict,
    in  bed    peaks,
    in  string reference_path,
    in  float  pwm_threshold,
    out bed    peak_motif_hits,
    src py     "stages/analysis/scan_motifs",
) split (
    in  file   GCdict,
) using (
    volatile = strict,
)

stage GENERATE_TF_MATRIX(
    in  path reference_path,
    in  bed  peaks,
    in  bed  peak_motif_hits,
    in  h5   filtered_matrix,
    out h5   filtered_tf_bc_matrix,
    out path filtered_tf_bc_matrix_mex,
    out gz   tf_propZ_matrix,
    src py   "stages/analysis/generate_tf_matrix",
) split (
) using (
    volatile = strict,
)

#
# @include "_peak_annotator.mro"
#

pipeline _PEAK_ANNOTATOR(
    in  string reference_path,
    in  bed    peaks,
    in  h5     filtered_peak_bc_matrix,
    in  float  pwm_threshold,
    out h5     filtered_tf_bc_matrix,
    out path   filtered_tf_bc_matrix_mex,
    out gz     tf_propZ_matrix,
    out tsv    peak_annotation,
)
{
    call ANNOTATE_PEAKS(
        peaks          = self.peaks,
        reference_path = self.reference_path,
    )

    call COMPUTE_GC_DISTRIBUTION(
        peaks          = self.peaks,
        reference_path = self.reference_path,
    )

    call SCAN_MOTIFS(
        globalGCdict   = COMPUTE_GC_DISTRIBUTION.GCdict,
        peaks          = self.peaks,
        reference_path = self.reference_path,
        pwm_threshold  = self.pwm_threshold,
    )

    call GENERATE_TF_MATRIX(
        reference_path  = self.reference_path,
        peaks           = self.peaks,
        filtered_matrix = self.filtered_peak_bc_matrix,
        peak_motif_hits = SCAN_MOTIFS.peak_motif_hits,
    )

    return (
        filtered_tf_bc_matrix     = GENERATE_TF_MATRIX.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex = GENERATE_TF_MATRIX.filtered_tf_bc_matrix_mex,
        tf_propZ_matrix           = GENERATE_TF_MATRIX.tf_propZ_matrix,
        peak_annotation           = ANNOTATE_PEAKS.peak_annotation,
    )
}

#
# @include "_sc_atac_analyzer_stages.mro"
#

stage ANALYZER_PREFLIGHT(
    in  bed      peaks,
    in  h5       filtered_peak_bc_matrix,
    in  string[] factorization,
    in  int      tsne_perplexity,
    in  int      random_seed,
    in  float    tsne_theta,
    in  int      tsne_mom_switch_iter,
    in  int      tsne_stop_lying_iter,
    in  int      tsne_max_dims,
    in  int      tsne_input_pcs,
    in  int      tsne_max_iter,
    in  int      max_clusters,
    in  int      num_components,
    in  int      num_dr_bcs,
    in  int      num_dr_features,
    in  float    neighbor_a,
    in  float    neighbor_b,
    in  int      graphclust_neighbors,
    src py       "stages/preflight/atac_analyzer",
)

stage REDUCE_DIMENSIONS(
    in  h5       filtered_matrix,
    in  string[] factorization,
    in  int      num_dims,
    in  int      num_bcs,
    in  int      num_features,
    in  int      random_seed,
    out path     reduced_data,
    out map      reduction_summary,
    src py       "stages/analysis/reduce_dimensions",
) split (
    in  string   method,
) using (
    volatile = strict,
)

stage CLUSTER_CELLS(
    in  h5       filtered_matrix,
    in  path     reduced_data,
    in  map      reduction_summary,
    in  string[] factorization,
    in  int      minclusters,
    in  int      maxclusters,
    in  int      num_dims,
    in  int      random_seed,
    out path     clustered_data,
    out map      clustering_summary,
    src py       "stages/analysis/cluster_cells",
) split (
    in  int      n_clusters,
) using (
    volatile = strict,
)

stage PROJECT_TSNE(
    in  h5       filtered_matrix,
    in  path     reduced_data,
    in  map      reduction_summary,
    in  int      tsne_perplexity,
    in  int      tsne_max_dims,
    in  int      tsne_input_pcs,
    in  float    tsne_theta,
    in  int      tsne_max_iter,
    in  int      tsne_stop_lying_iter,
    in  int      tsne_mom_switch_iter,
    in  int      random_seed,
    in  string[] factorization,
    out path     tsne,
    out map      tsne_summary,
    src py       "stages/analysis/project_tsne",
) split (
    in  string   method,
    in  int      tsne_dims,
) using (
    volatile = strict,
)

stage RUN_GRAPH_CLUSTERING(
    in  h5       matrix_h5                 "Processed matrix",
    in  string[] factorization,
    in  path     reduced_data,
    in  map      reduction_summary,
    in  int      num_neighbors             "Use this many neighbors",
    in  float    neighbor_a                "Use larger of (a+b*log10(n_cells) neighbors or num_neighbors",
    in  float    neighbor_b                "Use larger of (a+b*log10(n_cells) neighbors or num_neighbors",
    in  int      balltree_leaf_size,
    in  string   similarity_type           "Type of similarity to use (nn or snn)",
    out h5       chunked_neighbors,
    out path     knn_clusters,
    out map      graph_clustering_summary,
    src py       "stages/analysis/run_graph_clustering",
) split (
    in  string   method,
    in  pickle   neighbor_index,
    in  h5       submatrix,
    in  int      row_start,
    in  int      total_rows,
    in  int      k_nearest,
    in  h5       use_bcs,
) using (
    volatile = strict,
)

stage COMBINE_CLUSTERING(
    in  h5   filtered_matrix,
    in  map  clustering_summary,
    in  path clustered_data,
    in  map  graph_clustering_summary,
    in  path knn_clusters,
    out path clustering,
    out map  clustering_summary,
    src py   "stages/analysis/combine_clustering",
) using (
    volatile = strict,
)

stage SUMMARIZE_ANALYSIS(
    in  tsv  peak_annotation,
    in  h5   filtered_peak_bc_matrix,
    in  h5   filtered_tf_bc_matrix,
    in  gz   tf_propZ_matrix,
    in  path reduced_data,
    in  map  reduction_summary,
    in  path clustering,
    in  map  clustering_summary,
    in  path tsne,
    in  map  tsne_summary,
    in  path enrichment_analysis,
    in  map  enrichment_analysis_summary,
    out h5   analysis,
    out path analysis_csv,
    out h5   feature_bc_matrix,
    src py   "stages/analysis/summarize_analysis",
) split (
) using (
    volatile = strict,
)

stage PERFORM_DIFFERENTIAL_ANALYSIS(
    in  bed      peaks,
    in  string   reference_path,
    in  h5       filtered_peak_bc_matrix,
    in  h5       filtered_tf_bc_matrix,
    in  string[] factorization,
    in  path     clustering,
    in  map      clustering_summary,
    out path     enrichment_analysis,
    out map      enrichment_analysis_summary,
    src py       "stages/analysis/perform_differential_analysis",
) split (
    in  string   method,
    in  string   clustering_key,
    in  int      cluster,
    out csv      tmp_diffexp,
) using (
    volatile = strict,
)

#
# @include "_sc_atac_analyzer.mro"
#

pipeline _SC_ATAC_ANALYZER(
    in  string   reference_path,
    in  bed      peaks,
    in  h5       filtered_peak_bc_matrix,
    in  string[] factorization,
    in  int      tsne_perplexity,
    in  int      random_seed,
    in  float    tsne_theta,
    in  int      tsne_mom_switch_iter,
    in  int      tsne_stop_lying_iter,
    in  int      tsne_max_dims,
    in  int      tsne_input_pcs,
    in  int      tsne_max_iter,
    in  int      max_clusters,
    in  int      num_components,
    in  int      num_dr_bcs,
    in  int      num_dr_features,
    in  float    neighbor_a,
    in  float    neighbor_b,
    in  int      graphclust_neighbors,
    out h5       analysis,
    out path     analysis_csv,
    out h5       filtered_tf_bc_matrix,
    out path     filtered_tf_bc_matrix_mex,
    out h5       feature_bc_matrix,
    out tsv      peak_annotation,
)
{
    call ANALYZER_PREFLIGHT(
        peaks                   = self.peaks,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        factorization           = self.factorization,
        tsne_perplexity         = self.tsne_perplexity,
        random_seed             = self.random_seed,
        tsne_theta              = self.tsne_theta,
        tsne_mom_switch_iter    = self.tsne_mom_switch_iter,
        tsne_stop_lying_iter    = self.tsne_stop_lying_iter,
        tsne_max_dims           = self.tsne_max_dims,
        tsne_input_pcs          = self.tsne_input_pcs,
        tsne_max_iter           = self.tsne_max_iter,
        max_clusters            = self.max_clusters,
        num_components          = self.num_components,
        num_dr_bcs              = self.num_dr_bcs,
        num_dr_features         = self.num_dr_features,
        neighbor_a              = self.neighbor_a,
        neighbor_b              = self.neighbor_b,
        graphclust_neighbors    = self.graphclust_neighbors,
    ) using (
        volatile = true,
    )

    call _PEAK_ANNOTATOR(
        reference_path          = self.reference_path,
        peaks                   = self.peaks,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        pwm_threshold           = null,
    )

    call REDUCE_DIMENSIONS(
        filtered_matrix = self.filtered_peak_bc_matrix,
        factorization   = self.factorization,
        num_dims        = self.num_components,
        num_bcs         = self.num_dr_bcs,
        num_features    = self.num_dr_features,
        random_seed     = self.random_seed,
    )

    call CLUSTER_CELLS(
        filtered_matrix   = self.filtered_peak_bc_matrix,
        reduced_data      = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary = REDUCE_DIMENSIONS.reduction_summary,
        factorization     = self.factorization,
        minclusters       = 2,
        maxclusters       = self.max_clusters,
        num_dims          = null,
        random_seed       = self.random_seed,
    )

    call PROJECT_TSNE(
        filtered_matrix      = self.filtered_peak_bc_matrix,
        reduced_data         = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary    = REDUCE_DIMENSIONS.reduction_summary,
        tsne_perplexity      = self.tsne_perplexity,
        tsne_max_dims        = self.tsne_max_dims,
        tsne_input_pcs       = self.tsne_input_pcs,
        tsne_theta           = self.tsne_theta,
        tsne_max_iter        = self.tsne_max_iter,
        tsne_stop_lying_iter = self.tsne_stop_lying_iter,
        tsne_mom_switch_iter = self.tsne_mom_switch_iter,
        random_seed          = self.random_seed,
        factorization        = self.factorization,
    )

    call RUN_GRAPH_CLUSTERING(
        matrix_h5          = self.filtered_peak_bc_matrix,
        factorization      = self.factorization,
        reduced_data       = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary  = REDUCE_DIMENSIONS.reduction_summary,
        num_neighbors      = self.graphclust_neighbors,
        neighbor_a         = self.neighbor_a,
        neighbor_b         = self.neighbor_b,
        balltree_leaf_size = null,
        similarity_type    = "nn",
    )

    call COMBINE_CLUSTERING(
        filtered_matrix          = self.filtered_peak_bc_matrix,
        clustering_summary       = CLUSTER_CELLS.clustering_summary,
        clustered_data           = CLUSTER_CELLS.clustered_data,
        graph_clustering_summary = RUN_GRAPH_CLUSTERING.graph_clustering_summary,
        knn_clusters             = RUN_GRAPH_CLUSTERING.knn_clusters,
    )

    call PERFORM_DIFFERENTIAL_ANALYSIS(
        reference_path          = self.reference_path,
        peaks                   = self.peaks,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        filtered_tf_bc_matrix   = _PEAK_ANNOTATOR.filtered_tf_bc_matrix,
        factorization           = self.factorization,
        clustering              = COMBINE_CLUSTERING.clustering,
        clustering_summary      = COMBINE_CLUSTERING.clustering_summary,
    )

    call SUMMARIZE_ANALYSIS(
        peak_annotation             = _PEAK_ANNOTATOR.peak_annotation,
        filtered_peak_bc_matrix     = self.filtered_peak_bc_matrix,
        filtered_tf_bc_matrix       = _PEAK_ANNOTATOR.filtered_tf_bc_matrix,
        tf_propZ_matrix             = _PEAK_ANNOTATOR.tf_propZ_matrix,
        reduced_data                = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary           = REDUCE_DIMENSIONS.reduction_summary,
        clustering                  = COMBINE_CLUSTERING.clustering,
        clustering_summary          = COMBINE_CLUSTERING.clustering_summary,
        tsne                        = PROJECT_TSNE.tsne,
        tsne_summary                = PROJECT_TSNE.tsne_summary,
        enrichment_analysis         = PERFORM_DIFFERENTIAL_ANALYSIS.enrichment_analysis,
        enrichment_analysis_summary = PERFORM_DIFFERENTIAL_ANALYSIS.enrichment_analysis_summary,
    )

    return (
        analysis                  = SUMMARIZE_ANALYSIS.analysis,
        analysis_csv              = SUMMARIZE_ANALYSIS.analysis_csv,
        filtered_tf_bc_matrix     = _PEAK_ANNOTATOR.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex = _PEAK_ANNOTATOR.filtered_tf_bc_matrix_mex,
        feature_bc_matrix         = SUMMARIZE_ANALYSIS.feature_bc_matrix,
        peak_annotation           = _PEAK_ANNOTATOR.peak_annotation,
    )
}

#
# @include "_sc_atac_reporter_stages.mro"
#

stage SUMMARIZE_REPORTS_SINGLECELL(
    in  string reference_path,
    in  json   complexity_summary,
    in  json   cell_calling_summary,
    in  json   peak_results,
    in  json   basic_results,
    in  json   error_results_summary,
    in  json   insert_summary,
    in  json   singlecell_results,
    in  json   contam_results,
    in  json   downsample_info,
    in  json   enrichment_results,
    out json   analysis_params,
    out json   summary,
    out csv    summary_csv,
    src py     "stages/reporter/summarize_reports_singlecell",
) using (
    mem_gb = 4,
)

stage CREATE_WEBSUMMARY(
    in  string reference_path,
    in  string barcode_whitelist,
    in  json   summary_results,
    in  json   bulk_complexity,
    in  json   singlecell_complexity,
    in  string sample_id,
    in  string sample_desc,
    in  map[]  sample_def,
    in  bool   debug,
    in  csv    singlecell,
    in  csv    insert_sizes,
    in  csv    tss_relpos,
    in  csv    ctcf_relpos,
    in  h5     filtered_peak_bc_matrix,
    in  h5     analysis,
    in  json   excluded_barcodes,
    out html   web_summary,
    src py     "stages/reporter/create_websummary",
) using (
    mem_gb = 16,
)

#
# @include "_sc_atac_reporter.mro"
#

pipeline _SC_ATAC_REPORTER(
    in  string reference_path,
    in  string barcode_whitelist,
    in  json   bulk_complexity,
    in  json   cell_calling_summary,
    in  json   complexity_summary,
    in  json   basic_summary,
    in  json   peak_summary,
    in  json   singlecell_results,
    in  json   insert_summary,
    in  json   downsample_info,
    in  json   singlecell_complexity,
    in  csv    singlecell,
    in  csv    tss_relpos,
    in  csv    ctcf_relpos,
    in  string sample_id,
    in  string sample_desc,
    in  map[]  sample_def,
    in  csv    sc_insert_sizes,
    in  json   enrichment_results,
    in  h5     filtered_peak_bc_matrix,
    in  h5     analysis,
    in  json   excluded_barcodes,
    #
    out json   summary,
    out html   web_summary,
    out csv    summary_csv,
)
{
    call SUMMARIZE_REPORTS_SINGLECELL(
        reference_path        = self.reference_path,
        complexity_summary    = self.complexity_summary,
        cell_calling_summary  = self.cell_calling_summary,
        peak_results          = self.peak_summary,
        basic_results         = self.basic_summary,
        error_results_summary = null,
        insert_summary        = self.insert_summary,
        singlecell_results    = self.singlecell_results,
        contam_results        = null,
        downsample_info       = self.downsample_info,
        enrichment_results    = self.enrichment_results,
    )

    call CREATE_WEBSUMMARY(
        reference_path          = self.reference_path,
        barcode_whitelist       = self.barcode_whitelist,
        singlecell              = self.singlecell,
        tss_relpos              = self.tss_relpos,
        ctcf_relpos             = self.ctcf_relpos,
        sample_id               = self.sample_id,
        sample_desc             = self.sample_desc,
        sample_def              = self.sample_def,
        insert_sizes            = self.sc_insert_sizes,
        summary_results         = SUMMARIZE_REPORTS_SINGLECELL.summary,
        bulk_complexity         = self.bulk_complexity,
        singlecell_complexity   = self.singlecell_complexity,
        analysis                = self.analysis,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        excluded_barcodes       = self.excluded_barcodes,
        debug                   = false,
    )

    return (
        summary     = SUMMARIZE_REPORTS_SINGLECELL.summary,
        web_summary = CREATE_WEBSUMMARY.web_summary,
        summary_csv = SUMMARIZE_REPORTS_SINGLECELL.summary_csv,
    )
}

#
# @include "_atac_cloupe_stages.mro"
#

stage CLOUPE_PREPROCESS(
    in  string     pipestance_type,
    in  string     sample_id,
    in  string     sample_desc,
    in  string     reference_path,
    in  h5         analysis,
    in  h5         feature_barcode_matrix,
    in  bed        peaks,
    in  tsv.gz.tbi fragments_index,
    in  json       metrics_json,
    in  csv        aggregation_csv,
    in  json       gem_group_index_json,
    in  bool       no_secondary_analysis,
    out cloupe     output_for_cloupe,
    out json       gem_group_index_json,
    src py         "stages/cloupe/atac_cloupe_preprocess",
) split (
)

#
# @include "_preflight_stages.mro"
#

stage ATAC_COUNTER_PREFLIGHT(
    in  string   sample_id,
    in  string   fastq_mode,
    in  map[]    sample_def,
    in  string   reference_path,
    in  map      force_cells,
    in  string[] factorization,
    in  map      downsample,
    in  bool     check_executables,
    in  map      trim_def,
    src py       "stages/preflight/atac_counter",
) split (
)

stage ATAC_AGGR_PREFLIGHT(
    in  string   sample_id,
    in  string   reference_path,
    in  csv      aggr_csv,
    in  string   normalization,
    in  string[] factorization,
    in  bool     check_executables,
    src py       "stages/preflight/atac_aggr",
) split (
)

stage ATAC_REANALYZER_PREFLIGHT(
    in  string     sample_id,
    in  string     reference_path,
    in  string     barcode_whitelist,
    in  bed        peaks,
    in  csv        parameters,
    in  map        force_cells,
    in  csv        cell_barcodes,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  csv        aggregation_csv,
    in  bool       check_executables,
    src py         "stages/preflight/atac_reanalyzer",
) split (
)

#
# @include "sc_atac_counter.mro"
#

pipeline SC_ATAC_COUNTER(
    in  string     fastq_mode,
    in  string     sample_id,
    in  map[]      sample_def,
    in  map        downsample,
    in  string     sample_desc,
    in  string     reference_path,
    in  map        trim_def,
    in  string     barcode_whitelist,
    in  map        adapters,
    in  string[]   factorization,
    in  map        force_cells,
    #
    out csv        singlecell,
    out bam        possorted_bam,
    out bam.bai    possorted_bam_index,
    out json       summary,
    out html       web_summary,
    out bed        peaks,
    out h5         raw_peak_bc_matrix,
    out path       raw_peak_bc_matrix_mex,
    out path       analysis_csv,
    out h5         filtered_peak_bc_matrix,
    out path       filtered_peak_bc_matrix_mex,
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    out h5         filtered_tf_bc_matrix,
    out path       filtered_tf_bc_matrix_mex,
    out cloupe     cloupe,
    out csv        summary_csv,
    out tsv        peak_annotation,
)
{
    call ATAC_COUNTER_PREFLIGHT as ATAC_COUNTER_PREFLIGHT_LOCAL(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        reference_path    = self.reference_path,
        force_cells       = self.force_cells,
        factorization     = self.factorization,
        downsample        = self.downsample,
        trim_def          = self.trim_def,
        check_executables = false,
    ) using (
        local     = true,
        preflight = true,
    )

    call ATAC_COUNTER_PREFLIGHT(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        reference_path    = self.reference_path,
        force_cells       = self.force_cells,
        factorization     = self.factorization,
        downsample        = self.downsample,
        trim_def          = self.trim_def,
        check_executables = true,
    ) using (
        preflight = true,
    )

    call _BASIC_SC_ATAC_COUNTER(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        trim_def          = self.trim_def,
        adapters          = self.adapters,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        downsample        = self.downsample,
        force_cells       = self.force_cells,
    )

    call _SC_ATAC_METRIC_COLLECTOR(
        read_paired_bam    = _BASIC_SC_ATAC_COUNTER.read_paired_bam,
        fragments          = _BASIC_SC_ATAC_COUNTER.fragments,
        fragments_index    = _BASIC_SC_ATAC_COUNTER.fragments_index,
        peaks              = _BASIC_SC_ATAC_COUNTER.peaks,
        reference_path     = self.reference_path,
        cell_barcodes      = _BASIC_SC_ATAC_COUNTER.cell_barcodes,
        singlecell_cells   = _BASIC_SC_ATAC_COUNTER.singlecell_cells,
        singlecell_mapping = _BASIC_SC_ATAC_COUNTER.singlecell_mapping,
    )

    call _SC_ATAC_ANALYZER(
        peaks                   = _BASIC_SC_ATAC_COUNTER.peaks,
        filtered_peak_bc_matrix = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        reference_path          = self.reference_path,
        factorization           = self.factorization,
        tsne_perplexity         = 30,
        tsne_max_dims           = null,
        tsne_input_pcs          = null,
        tsne_max_iter           = null,
        tsne_stop_lying_iter    = null,
        tsne_mom_switch_iter    = null,
        tsne_theta              = null,
        random_seed             = null,
        max_clusters            = 10,
        neighbor_a              = null,
        neighbor_b              = null,
        graphclust_neighbors    = null,
        num_components          = 15,
        num_dr_bcs              = null,
        num_dr_features         = null,
    )

    call CLOUPE_PREPROCESS(
        pipestance_type        = "SC_ATAC_COUNTER_CS",
        reference_path         = self.reference_path,
        sample_id              = self.sample_id,
        sample_desc            = self.sample_desc,
        analysis               = _SC_ATAC_ANALYZER.analysis,
        feature_barcode_matrix = _SC_ATAC_ANALYZER.feature_bc_matrix,
        metrics_json           = _SC_ATAC_METRIC_COLLECTOR.basic_summary,
        peaks                  = _BASIC_SC_ATAC_COUNTER.peaks,
        fragments_index        = _BASIC_SC_ATAC_COUNTER.fragments_index,
        aggregation_csv        = null,
        gem_group_index_json   = null,
        no_secondary_analysis  = false,
    )

    call _SC_ATAC_REPORTER(
        reference_path          = self.reference_path,
        barcode_whitelist       = self.barcode_whitelist,
        bulk_complexity         = _SC_ATAC_METRIC_COLLECTOR.bulk_complexity,
        singlecell_complexity   = _SC_ATAC_METRIC_COLLECTOR.singlecell_complexity,
        cell_calling_summary    = _BASIC_SC_ATAC_COUNTER.cell_calling_summary,
        complexity_summary      = _SC_ATAC_METRIC_COLLECTOR.complexity_summary,
        basic_summary           = _SC_ATAC_METRIC_COLLECTOR.basic_summary,
        peak_summary            = _BASIC_SC_ATAC_COUNTER.peak_metrics,
        singlecell_results      = _SC_ATAC_METRIC_COLLECTOR.singlecell_results,
        insert_summary          = _SC_ATAC_METRIC_COLLECTOR.insert_summary,
        downsample_info         = _BASIC_SC_ATAC_COUNTER.downsample_info,
        singlecell              = _SC_ATAC_METRIC_COLLECTOR.singlecell,
        tss_relpos              = _SC_ATAC_METRIC_COLLECTOR.tss_relpos,
        ctcf_relpos             = _SC_ATAC_METRIC_COLLECTOR.ctcf_relpos,
        sample_id               = self.sample_id,
        sample_desc             = self.sample_desc,
        sample_def              = self.sample_def,
        sc_insert_sizes         = _SC_ATAC_METRIC_COLLECTOR.insert_sizes,
        enrichment_results      = _SC_ATAC_METRIC_COLLECTOR.enrichment_results,
        filtered_peak_bc_matrix = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        analysis                = _SC_ATAC_ANALYZER.analysis,
        excluded_barcodes       = _BASIC_SC_ATAC_COUNTER.excluded_barcodes,
    )

    return (
        singlecell                  = _SC_ATAC_METRIC_COLLECTOR.singlecell,
        possorted_bam               = _BASIC_SC_ATAC_COUNTER.possorted_bam,
        possorted_bam_index         = _BASIC_SC_ATAC_COUNTER.possorted_bam_index,
        summary                     = _SC_ATAC_REPORTER.summary,
        web_summary                 = _SC_ATAC_REPORTER.web_summary,
        peaks                       = _BASIC_SC_ATAC_COUNTER.peaks,
        raw_peak_bc_matrix          = _BASIC_SC_ATAC_COUNTER.raw_peak_bc_matrix,
        raw_peak_bc_matrix_mex      = _BASIC_SC_ATAC_COUNTER.raw_peak_bc_matrix_mex,
        analysis_csv                = _SC_ATAC_ANALYZER.analysis_csv,
        filtered_peak_bc_matrix     = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        filtered_peak_bc_matrix_mex = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix_mex,
        fragments                   = _BASIC_SC_ATAC_COUNTER.fragments,
        fragments_index             = _BASIC_SC_ATAC_COUNTER.fragments_index,
        filtered_tf_bc_matrix       = _SC_ATAC_ANALYZER.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex   = _SC_ATAC_ANALYZER.filtered_tf_bc_matrix_mex,
        cloupe                      = CLOUPE_PREPROCESS.output_for_cloupe,
        summary_csv                 = _SC_ATAC_REPORTER.summary_csv,
        peak_annotation             = _SC_ATAC_ANALYZER.peak_annotation,
    )
}

#
# @include "sc_atac_counter_cs.mro"
#

# Customer-facing (CS) pipeline
pipeline SC_ATAC_COUNTER_CS(
    in  string     fastq_mode                   "Input fastq configuration",
    in  string     sample_id,
    in  map[]      sample_def,
    in  map        downsample,
    in  string     sample_desc                  "Sample description",
    in  string     reference_path               "Path to 10X reference package",
    in  string[]   factorization                "Dimensionality reduction method (lsa, plsa, or pca)",
    in  map        force_cells                  "Force cell calling to a fixed number",
    #
    out csv        singlecell                   "Per-barcode fragment counts & metrics",
    out bam        possorted_bam                "Position sorted BAM file"  "possorted_bam.bam",
    out bam.bai    possorted_bam_index          "Position sorted BAM index"  "possorted_bam.bam.bai",
    out json       summary                      "Summary of all data metrics",
    out html       web_summary                  "HTML file summarizing data & analysis",
    out bed        peaks                        "Bed file of all called peak locations",
    out h5         raw_peak_bc_matrix           "Raw peak barcode matrix in hdf5 format",
    out path       raw_peak_bc_matrix_mex       "Raw peak barcode matrix in mex format"  "raw_peak_bc_matrix",
    out path       analysis_csv                 "Directory of analysis files"  "analysis",
    out h5         filtered_peak_bc_matrix      "Filtered peak barcode matrix in hdf5 format",
    out path       filtered_peak_bc_matrix_mex  "Filtered peak barcode matrix in mex format"  "filtered_peak_bc_matrix",
    out tsv.gz     fragments                    "Barcoded and aligned fragment file"  "fragments.tsv.gz",
    out tsv.gz.tbi fragments_index              "Fragment file index"       "fragments.tsv.gz.tbi",
    out h5         filtered_tf_bc_matrix        "Filtered tf barcode matrix in hdf5 format",
    out path       filtered_tf_bc_matrix_mex    "Filtered tf barcode matrix in mex format"  "filtered_tf_bc_matrix",
    out cloupe     cloupe                       "Loupe Cell Browser input file",
    out csv        summary_csv                  "csv summarizing important metrics and values"  "summary.csv",
    out tsv        peak_annotation              "Annotation of peaks with genes",
)
{
    call SC_ATAC_COUNTER(
        fastq_mode        = self.fastq_mode,
        sample_id         = self.sample_id,
        sample_def        = self.sample_def,
        downsample        = self.downsample,
        sample_desc       = self.sample_desc,
        reference_path    = self.reference_path,
        trim_def          = {
            "R1": {
                "3prime": ["MErc"],
            },
            "R2": {
                "3prime": ["MErc"],
            },
            "discard_untrimmed": false,
        },
        barcode_whitelist = "737K-cratac-v1",
        adapters          = {
            "ME": "AGATGTGTATAAGAGACAG",
            "MErc": "CTGTCTCTTATACACATCT",
        },
        factorization     = self.factorization,
        force_cells       = self.force_cells,
    )

    return (
        singlecell                  = SC_ATAC_COUNTER.singlecell,
        possorted_bam               = SC_ATAC_COUNTER.possorted_bam,
        possorted_bam_index         = SC_ATAC_COUNTER.possorted_bam_index,
        summary                     = SC_ATAC_COUNTER.summary,
        web_summary                 = SC_ATAC_COUNTER.web_summary,
        peaks                       = SC_ATAC_COUNTER.peaks,
        raw_peak_bc_matrix          = SC_ATAC_COUNTER.raw_peak_bc_matrix,
        raw_peak_bc_matrix_mex      = SC_ATAC_COUNTER.raw_peak_bc_matrix_mex,
        analysis_csv                = SC_ATAC_COUNTER.analysis_csv,
        filtered_peak_bc_matrix     = SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        filtered_peak_bc_matrix_mex = SC_ATAC_COUNTER.filtered_peak_bc_matrix_mex,
        fragments                   = SC_ATAC_COUNTER.fragments,
        fragments_index             = SC_ATAC_COUNTER.fragments_index,
        filtered_tf_bc_matrix       = SC_ATAC_COUNTER.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex   = SC_ATAC_COUNTER.filtered_tf_bc_matrix_mex,
        cloupe                      = SC_ATAC_COUNTER.cloupe,
        summary_csv                 = SC_ATAC_COUNTER.summary_csv,
        peak_annotation             = SC_ATAC_COUNTER.peak_annotation,
    )
}

#
# @include "__PBMC_J.mro"
#

call SC_ATAC_COUNTER_CS(
    fastq_mode     = "ILMN_BCL2FASTQ",
    sample_id      = "PBMC_J",
    sample_def     = [{
        "bc_in_read": 1,
        "bc_length": 16,
        "gem_group": null,
        "lanes": null,
        "library": "LibraryNotSpecified",
        "read_path": "/data/isshamie/dropbox/ATACseq/2020_11_18_Croker/igm-storage2.ucsd.edu/201113_A00953_0185_AHN7TMDSXY",
        "sample_indices": ["any"],
        "sample_names": ["BC_10xATAC_PMBC_J"],
    }],
    reference_path = "/data/isshamie/mito_lineage/data/external/GRCh38_MT_blacklist",
    downsample     = null,
    sample_desc    = "",
    factorization  = ["lsa"],
    force_cells    = null,
)
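
The call above only configures the counting pipeline; the outputs it declares (for example possorted_bam.bam and singlecell.csv) are what the MT preprocessing step consumes next. Below is a minimal, hypothetical Python sketch of pulling the called cell barcodes out of singlecell.csv and iterating MT reads from the position-sorted BAM with pysam. The output paths, the is__cell_barcode column, the chrM contig name, and the CB barcode tag are assumptions about the 10x outputs, not something defined in this file.

import pandas as pd
import pysam

# Assumed output locations inside the pipestance (illustrative only)
singlecell = pd.read_csv("PBMC_J/outs/singlecell.csv")
cells = set(singlecell.loc[singlecell["is__cell_barcode"] == 1, "barcode"])  # assumed column names

# Count MT reads per called cell barcode (assumed contig name and CB tag)
bam = pysam.AlignmentFile("PBMC_J/outs/possorted_bam.bam", "rb")
mt_reads_per_cell = {}
for read in bam.fetch("chrM"):
    if read.has_tag("CB") and read.get_tag("CB") in cells:
        bc = read.get_tag("CB")
        mt_reads_per_cell[bc] = mt_reads_per_cell.get(bc, 0) + 1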

6f35b0923972c3a16fe83e1aa2a9d2dc6da1ebea

parallel_apply over simulations

parallel_apply over simulations

# TODO parallel_apply over simulations

class FullSimulation:
    def __init__(self, params):
        # TODO parallel_apply over simulations
        self.n_iter = params['num_iterations']
        self.params = params
        # Parallelize df
        df = pd.DataFrame(index=range(self.n_iter))
        df = df.apply(self.run_sim)
        self.sim = df
        return
        #for i in self.n_iter:
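
A minimal sketch of the TODO above, assuming pandarallel is available: the serial apply over the iteration index can be swapped for parallel_apply, which takes the same args as pandas' apply. The run_one helper and the parameter values here are illustrative stand-ins for the per-iteration simulation call.

import pandas as pd
from pandarallel import pandarallel

pandarallel.initialize(nb_workers=4)

def run_one(i, params):
    # Stand-in for constructing and running one Simulation object
    return {"iteration": i, "num_cells": params["num_cells"]}

params = {"num_iterations": 8, "num_cells": 1000}
results = pd.Series(range(params["num_iterations"])).parallel_apply(run_one, args=(params,))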

f640758f0ab35566a474d6462eb219bde359f6be

Add the MT clone map so it can contain multiple mutants in lineages

Add the MT clone map so it can contain multiple mutants in lineages

# TODO Add the MT clone map so it can contain multiple mutants in lineages

import numpy as np
from numpy import random
import os
import pandas as pd
from tqdm import tqdm
#from src.config import ROOT_DIR
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import glob
from sklearn.cluster import KMeans
from sklearn import metrics
from scipy.spatial.distance import cdist
from pandarallel import pandarallel
pandarallel.initialize(nb_workers=32)

from mplh.color_utils import get_colors
from mplh.fig_utils import legend_from_color
from mplh import cluster_help as ch
from src.simulations.utils.config import read_config_file, write_config_file

from dynamicTreeCut import cutreeHybrid
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage
from sklearn.model_selection import ParameterGrid
from src.simulations.utils.config import check_required


class Simulation:
    """Lineage tracing simulation of one sample

    Initializes cells from the given parameters and then grows them. This is
    meant to be a flexible framework for adding different ways to initialize
    and grow cells and different metrics to collect. It can also cluster the
    results.

    :ivar params
    :type params: dict
    """

    def __init__(self, params_f):
        """
        :param params_f: Parameter yaml file for the specifications
        :type params_f: yaml file or dict
        """
        if isinstance(params_f, str):
            params = read_config_file(params_f)
        else:
            params = params_f

        self.params = params
        check_required(params, ['initialize', 'num_cells', 'num_mt_positions', 'prefix'])
        self.prefix = params['prefix']
        self.num_mt_positions = params['num_mt_positions']
        self.num_cells = params['num_cells']
        if not os.path.exists(params['local_outdir']):
            os.mkdir(params['local_outdir'])


    def initialize(self):
        self.init_clone_dict()
        self.init_cell_coverage()
        self.init_cell_af()
        #self.init_clone_mt()

    #should be external method
    def grow(self):
        p = self.params
        type = p["growth"]["type"]
        if  type == "poisson":
            self.grow_poisson(p['growth']['poisson'])
        elif type == "binomial":
            self.grow_binomial(p['growth']['binomial'])
        return

    # Static Method
    @staticmethod
    def clone_counts_to_cell_series(clone_counts):
        """
        Args:
            clone_counts:
        """
        clone_counts = np.array(clone_counts)
        num_cells = clone_counts.sum()
        clone_cell = -1 * np.ones(shape=[num_cells, ])


        clone_cell[:clone_counts[0]] = 0
        for ind, val in enumerate(clone_counts[1:]):
            start = clone_counts[:ind + 1].sum()
            end = clone_counts[:ind + 1].sum() + val
            # starts at sum(clone_counts[i-1]) ends at clone_counts[i].sum()
            clone_cell[start:end] = ind + 1

        clone_cell = pd.Series(clone_cell, dtype=int)
        return clone_cell

    def init_clone_dict(self):
        ### Add in potential to overwrite the values

        # Gets the clone dictionary. Should also have clone to mt dict.
        clones = self.params['initialize']['clone_sizes']
        num_cells = self.num_cells

        # Option 1: List of fraction of size of each clone. 0s are nonclone size, listed first
        if type(clones) == list:
            #clone_cell = pd.Series(index=range(num_cells))
            clone_counts = np.random.multinomial(num_cells, clones)
            clone_cell  = self.clone_counts_to_cell_series(clone_counts)
            self.clone_cell = clone_cell
        # Option 2: 1 clone. ID'd as 1
        elif type(clones) == int: #One number for dominant clone. the others are not.
            clone_cell = np.zeros(shape=[num_cells,])
            clone_cell[:num_cells] = 1
            clone_cell = clone_cell[::-1]
            clone_cell =  pd.Series(clone_cell, dtype=int)
            self.clone_cell = clone_cell

        # Option 3 To ADD, beta binomial and more complex distributions

        self.num_clones =  len(set(clone_cell.values))-1 # Remove the non-clone
        return clone_cell


    def init_cell_coverage(self):
        """There are different modes to the coverage, either a constant or
        through a distribution. :return:
        """
        p = self.params['initialize']['coverage']
        type = p['type']

        num_cells = self.num_cells
        num_pos = self.num_mt_positions
        c = np.zeros([num_cells, num_pos])

        if type == 'constant':
            c[:, :] = p['cov_constant']
        elif type == "poisson":
            # Get the number of coverage per cell based on poisson (should be reads)
            mu_cov_per_cell = p['mu_cov_per_cell']
            num_reads_per_cell = random.poisson(lam=mu_cov_per_cell,
                                                size=num_cells)

            # Number of reads at each position, based on the average for each cell
            for i in range(num_cells):
                c[i, :] = random.poisson(num_reads_per_cell[i],
                                         size=num_pos)
        self.cells_mt_coverage = c
        return c


    def init_cell_af(self):
        """Initialize the cell-by-mtPos af dataframe. Unless a clone:mt dict was
        provided, the first N MT positions will be the clone AFs. Creates
        self.clone_mt_dict and self.cell_af
        """

        p = self.params['initialize']

        hets = self.params['het']
        q = self.params['het_err_rate']
        clone_df = self.clone_cell
        num_clones = self.num_clones
        n_cells = self.num_cells
        n_mt = self.num_mt_positions

        # Output
        cell_af = pd.DataFrame(np.zeros(shape=[n_cells, n_mt]))


        if 'mt_clone_map' in p and p['mt_clone_map'] is not None:
            self.clone_mt_dict = p['mt_clone_map']
        else:
            # Each clone points to a mt position
            self.clone_mt_dict = dict()
            for i in range(1,num_clones+1):
                self.clone_mt_dict[i] = i

        # TODO Add the MT clone map so it can contain multiple mutants in lineages

        # If there is a heteroplasmy table in params, it is list of mutant heteroplasmy AFs.
        # If not, will randomly draw based on number of clones
        if type(hets) == list:
            assert len(hets) == num_clones, "One heteroplasmy value is needed per clone"

            ## Loop through each clone,
            ## Generate the AF for the clone and non-clones using coverage for each cell
            ## Fill in cell_by_af for that position.
            for ind in range(1, num_clones+1):
                # Generate AF: (clone_df ==  ind).sum()
                n_dom_cells = (clone_df==ind).sum()
                het = hets[ind-1]

                curr_mt = self.clone_mt_dict[ind]


                if p['coverage']['type'] == 'constant':
                    c = p['coverage']['cov_constant']

                    af_i = random.binomial(c, het,
                                           n_dom_cells) / c
                    af_j = random.binomial(c, q,
                                           n_cells - n_dom_cells) / c

                    # Update the dom_cells and non_dom for the current MT
                    cell_af.loc[np.flatnonzero(clone_df == ind), curr_mt] = af_i
                    cell_af.loc[np.flatnonzero(clone_df != ind), curr_mt] = af_j

                # Each cell and position has its own coverage value, so need to update each
                else:
                    c = self.cells_mt_coverage
                    #Get the cells coverage for the mt position
                    curr_mt_cov= c[:, curr_mt]

                    # Get cell indicies for the clones and nonclones
                    curr_clone_inds = np.flatnonzero(clone_df==ind)
                    curr_nonclone_inds = np.flatnonzero(clone_df!=ind)
                    for cell in curr_clone_inds:
                        # Get one value for curr_mt and cell based on coverage
                        cell_af.loc[cell, curr_mt] = random.binomial(curr_mt_cov[cell], het)
                    for cell in curr_nonclone_inds:
                        cell_af.loc[cell, curr_mt] = random.binomial(curr_mt_cov[cell], q)
                # Loop through each coverage
                #for c in n_dom_cells:

        #####
        # TODO
        # Add noise to the other non-lineage positions
        #####
        self.cell_af = cell_af
        return


    def init_clone_mt(self):
        p = self.params
        if p["initialize"]['type'] == 'growth':
            ## TODO
            # Create a phylogeny and then get the averages of the mutants
            self.average_clone_mt()
        # If not growth, should already be there.
        return

    def average_clone_mt(self):
        return

    def extract_clone_cells(self, clone_id):
        """
        Args:
            clone_id:
        """
        ids = np.flatnonzero(self.clone_cell == clone_id)
        return ids

    def simulate_expand_cells_af(self, af, growth_inds, sigma):
        """Given a cell-by-af vector, expand the AF.

        Expanded AF occurs by duplicating cells that grew based on the
        growth_inds vector. It will add standard error to each af based on sigma
        :param af: :param growth: Indices of AF to copy :param sigma: Variance
        to add to AF of child. :return:

        Args:
            af:
            growth_inds:
            sigma:
        """

        new_af = af.iloc[growth_inds].copy() + random.normal(0, sigma, size=af.iloc[growth_inds].shape)
        new_af.index = np.arange(af.index[-1]+1, af.index[-1]+1+new_af.shape[0])
        new_af = pd.concat((af,new_af), axis=0)
        #new_af = np.append(af, np.concatenate(new_af))
        return new_af

    def grow_binomial(self, p):
        """
        Args:
            p:
        """
        timesteps = p["time_steps"]
        rates = p["rates"]

        sigma = self.params['growth']["mutant_af_sigma_noise"]
        cell_af = self.cell_af
        clone_mt_dict = self.clone_mt_dict

        num_clones = self.num_clones+1
        new_dict = {}
        for curr_clone in range(num_clones):
            curr_rate = rates[curr_clone]
            ids = self.extract_clone_cells(curr_clone)
            new_cells = cell_af.loc[ids].copy()
            for i in range(timesteps):
                # Simulate growth for each clone separately.
                growth_inds = np.flatnonzero(random.binomial(1, curr_rate, size=new_cells.shape[0]))
                #new_ids =
                new_cells = self.simulate_expand_cells_af(new_cells, growth_inds, sigma)

            new_dict[curr_clone] = new_cells
            # Create list of cells

        ####TODO
        ## new_lineage_mutants chances. This will see if a mutation will change


        ####TODO
        ## Add death + stimulation rate as well as growth
        # Save the new cell clones df and cell af
        clone_counts = [i.shape[0] for i in new_dict.values()]
        self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)

        self.new_cell_af = pd.DataFrame(new_dict[0])
        for clone in range(1, self.num_clones+1):
            self.new_cell_af = pd.concat((self.new_cell_af, new_dict[clone]),axis=0).reset_index(drop=True)
        return


    def grow_poisson(self, p):
        # TODO growth of poisson refactor
        return


    def subsample_new(self, to_delete=False):
        """
        Args:
            to_delete:
        """
        new_cell_af = self.new_cell_af
        p = self.params
        if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.params['sequence_subsample'])
        else:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.num_cells)

        self.subsample_new_clone_cell = self.new_clone_cell.loc[
            self.subsample_new_cell_af.index]

        if to_delete:
            self.new_cell_af = None
            self.new_clone_cell = None


    def combine_init_growth(self):
        clones = pd.concat(
            (self.clone_cell, self.subsample_new_clone_cell)).reset_index(
            drop=True)
        combined_cell_af = self.cell_af.append(self.subsample_new_cell_af).reset_index(drop=True)

        combined_meta = np.concatenate((np.ones(shape=[self.cell_af.shape[0],]), np.zeros(shape=[self.subsample_new_cell_af.shape[0]])))
        combined_meta = pd.Series(combined_meta, name='After Growth', dtype=int)
        assert(combined_meta.shape[0] == self.cell_af.shape[0]+self.subsample_new_cell_af.shape[0])
        assert (combined_cell_af.shape[0] == self.cell_af.shape[0] +
                self.subsample_new_cell_af.shape[0])
        assert(combined_meta.shape[0] == clones.shape[0])
        assert(combined_cell_af.shape[0] == clones.shape[0])
        self.combined_meta = combined_meta
        self.combined_clones = clones
        self.combined_cell_af = combined_cell_af
        return

    def save(self, f_save=None):
        """
        Args:
            f_save:
        """
        if f_save is None:
            f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')
        f = open(f_save, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()

    def save_to_mgatk_format(self):
        """Converts into the proper files needed for mgatk. (i.e variant and
        coverage files) :return:
        """

    def load(self):
        filename = self.params['filename']
        f = open(filename, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)

    def compare_before_after(self):
        """Creates a df that contains information on the number of cells from
        each clone before as well as after. :return: df.at[ind, "Dominant
        Before"] = (full_sim.clone_cell == 1).sum() df.at[ind, "Dominant After"]
        = (full_sim.subsample_new_clone_cell == 1).sum()
        """

        return

    def cluster_compare_before_after(self):
        """Compares the performance of clustering on grouping the same clones
        together. :return:
        """
        return

    @staticmethod
    def plot_cluster(cell_af, cell_meta=None, mt_meta=None, f_save=None):
        """
        Args:
            cell_af:
            cell_meta:
            mt_meta:
            f_save:
        """
        ch.plot_cluster(cell_af, row_meta=cell_meta, col_meta=mt_meta,
                        fsave=f_save, to_col_clust=False, to_z=True)

    @staticmethod
    def cluster(cell_af):
        """Dynamic tree clustering of the rows of cell_af :param cell_af:
        :return:

        Args:
            cell_af:
        """
        distances = pdist(cell_af, "euclidean")
        link = linkage(distances, "average")
        clusters = cutreeHybrid(link, distances)['labels']
        return clusters

    @staticmethod
    def cluster_kmeans(cell_af):
        """
        Args:
            cell_af:
        """
        distortions = []
        inertias = []
        mapping1 = {}
        mapping2 = {}
        K = range(1, 10)
        for k in K:
            # Build and fit the model once per k
            kmeanModel = KMeans(n_clusters=k).fit(cell_af)

            # Average distance of each cell to its nearest centroid (distortion)
            distortion = sum(
                np.min(cdist(cell_af, kmeanModel.cluster_centers_, 'euclidean'),
                       axis=1)) / cell_af.shape[0]
            distortions.append(distortion)
            inertias.append(kmeanModel.inertia_)

            mapping1[k] = distortion
            mapping2[k] = kmeanModel.inertia_

        # Return the elbow metrics so a caller can inspect them and choose k
        return distortions, inertias


def main():
    return


if "__name__" == "__main__":
    main()
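
For reference, here is a minimal sketch of the parameter dictionary this Simulation class expects, with illustrative values, followed by the typical call sequence. Only the key names come from the code above; the import path and the numbers are assumptions.

from src.simulations.simulation import Simulation  # hypothetical module path

params = {
    "prefix": "toy_run",
    "local_outdir": "sim_out",
    "num_cells": 1000,
    "num_mt_positions": 50,
    "initialize": {
        "clone_sizes": [0.8, 0.1, 0.1],   # non-clone fraction listed first
        "coverage": {"type": "constant", "cov_constant": 20},
        "mt_clone_map": None,
    },
    "het": [0.3, 0.5],                    # one heteroplasmy value per clone
    "het_err_rate": 0.01,
    "growth": {
        "type": "binomial",
        "binomial": {"time_steps": 4, "rates": [0.1, 0.6, 0.6]},
        "mutant_af_sigma_noise": 0.005,
    },
    "sequence_subsample": None,
}

sim = Simulation(params)
sim.initialize()
sim.grow()
sim.subsample_new(to_delete=True)
sim.combine_init_growth()
clusters = Simulation.cluster(sim.combined_cell_af)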

98578799572bfbf6ec94a80e57cac1d4f2a0b6ad

growth of poisson refactor

growth of poisson refactor

# TODO growth of poisson refactor

import numpy as np
from numpy import random
import os
import pandas as pd
import pickle
from src.simulations.utils.config import read_config_file, write_config_file
from src.simulations.utils.config import check_required


class Simulation:
    """Lineage tracing simulation of one sample

    Initializes cells from the given parameters and then grows them. This is
    meant to be a flexible framework for adding different ways to initialize
    and grow cells and different metrics to collect. It can also cluster the
    results.

    :ivar params
    :type params: dict
    """

    def __init__(self, params_f):
        """
        :param params_f: Parameter yaml file for the specifications
        :type params_f: yaml file or dict
        """
        if isinstance(params_f, str):
            params = read_config_file(params_f)
        else:
            params = params_f

        self.params = params
        check_required(params, ['initialize', 'num_cells', 'num_mt_positions', 'prefix'])
        self.prefix = params['prefix']
        self.num_mt_positions = params['num_mt_positions']
        self.num_cells = params['num_cells']
        if not os.path.exists(params['local_outdir']):
            os.mkdir(params['local_outdir'])


    def initialize(self):
        """ (1) Pre-growth cell population is instantiated.

        Creates a clone-MT dictionary, cell coverage matrix
        (or an int, depending on parameters), and cell-AF matrix.
        :return:
        """
        self.init_clone_dict()
        self.init_cell_coverage()
        self.init_cell_af()
        #self.init_clone_mt()

    #should be external method
    def grow(self):
        """ (2) Growth of cells is run."""
        p = self.params
        type = p["growth"]["type"]
        if  type == "poisson":
            self.grow_poisson(p['growth']['poisson'])
        elif type == "binomial":
            self.grow_binomial(p['growth']['binomial'])
        return

    # Static Method
    @staticmethod
    def clone_counts_to_cell_series(clone_counts):
        """ Generates new cell IDs based on cluster count iterable
        :param clone_counts: Each i'th element is the number of cells in
        cluster i.
        :type clone_counts: iterable
        :return: each index name is a cell ID and each value is the cluster
        the cell belongs to.
        :rtype: pd.Series
        """
        clone_counts = np.array(clone_counts)
        num_cells = clone_counts.sum()
        clone_cell = -1 * np.ones(shape=[num_cells, ])

        clone_cell[:clone_counts[0]] = 0
        for ind, val in enumerate(clone_counts[1:]):
            start = clone_counts[:ind + 1].sum()
            end = clone_counts[:ind + 1].sum() + val
            # starts at sum(clone_counts[i-1]) ends at clone_counts[i].sum()
            clone_cell[start:end] = ind + 1

        clone_cell = pd.Series(clone_cell, dtype=int)
        return clone_cell


    def init_clone_dict(self):
        """1A
        """

        ### Add in potential to overwrite the values
        # Gets the clone dictionary. Should also have clone to mt dict.
        clones = self.params['initialize']['clone_sizes']

        if 'num_cells_population' not in self.params:
            self.num_cells_pop = self.num_cells
        else:
            self.num_cells_pop = self.params['num_cells_population']

        num_cells = self.num_cells_pop

        # Option 1: List of fraction of size of each clone. 0s are nonclone size, listed first
        if type(clones) == list:
            #clone_cell = pd.Series(index=range(num_cells))
            clone_counts = np.random.multinomial(num_cells, clones)
            clone_cell  = self.clone_counts_to_cell_series(clone_counts)
            self.clone_cell_pop = clone_cell

            # Choose subset to be sampled
            self.clone_cell = clone_cell.sample(n=self.num_cells).sort_values()
        # Option 2: 1 clone. ID'd as 1
        elif type(clones) == int: #One number for dominant clone. the others are not.
            clone_cell = np.zeros(shape=[num_cells,])
            clone_cell[:num_cells] = 1
            clone_cell = clone_cell[::-1]
            clone_cell =  pd.Series(clone_cell, dtype=int)
            self.clone_cell = clone_cell

        # Option 3 To ADD, beta binomial and more complex distributions

        self.num_clones =  len(set(clone_cell.values))-1 # Remove the non-clone
        return clone_cell


    def init_cell_coverage(self):
        """1B

        There are different modes to the coverage, either a constant or
        through a distribution.
        """
        p = self.params['initialize']['coverage']
        type = p['type']

        num_cells = self.num_cells
        num_pos = self.num_mt_positions
        c = np.zeros([num_cells, num_pos])

        if type == 'constant':
            c[:, :] = p['cov_constant']
        elif type == "poisson":
            # Get the number of coverage per cell based on poisson (should be reads)
            mu_cov_per_cell = p['mu_cov_per_cell']
            num_reads_per_cell = random.poisson(lam=mu_cov_per_cell,
                                                size=num_cells)

            # Number of reads at each position, based on the average for each cell
            for i in range(num_cells):
                c[i, :] = random.poisson(num_reads_per_cell[i],
                                         size=num_pos)
        self.cells_mt_coverage = c
        return c


    @staticmethod
    def create_cell_af(clone_df, mt_dict, n_cells, n_mt, num_clones,
                       cov_params, hets, het_err, coverage=None):
        cell_af = pd.DataFrame(np.zeros(shape=[n_cells, n_mt]))

        #p = self.params['initialize']

        ## Loop through each clone,
        ## Generate the AF for the clone and non-clones using coverage for each cell
        ## Fill in cell_by_af for that position.
        for ind in range(1, num_clones + 1):
            # Generate AF: (clone_df ==  ind).sum()
            n_dom_cells = (clone_df == ind).sum()
            het = hets[ind - 1]

            curr_mt = mt_dict[ind]

            if cov_params['coverage']['type'] == 'constant':
                c = cov_params['coverage']['cov_constant']

                af_i = random.binomial(c, het, n_dom_cells) / c
                af_j = random.binomial(c, het_err, n_cells - n_dom_cells) / c

                # Update the dom_cells and non_dom for the current MT
                cell_af.loc[
                    np.flatnonzero(clone_df == ind), curr_mt] = af_i
                cell_af.loc[
                    np.flatnonzero(clone_df != ind), curr_mt] = af_j

            # Each cell and position has its own coverage value, so need to update each
            else:
                if coverage is None:
                    raise("coverage needs to be assigned")
                c = coverage

                # Get the cells coverage for the mt position
                curr_mt_cov = c[:, curr_mt]

                # Get cell indicies for the clones and nonclones
                curr_clone_inds = np.flatnonzero(clone_df == ind)
                curr_nonclone_inds = np.flatnonzero(clone_df != ind)
                for cell in curr_clone_inds:
                    # Get one value for curr_mt and cell based on coverage
                    cell_af.loc[cell, curr_mt] = random.binomial(
                        curr_mt_cov[cell], het)
                for cell in curr_nonclone_inds:
                    cell_af.loc[cell, curr_mt] = random.binomial(
                        curr_mt_cov[cell], het_err)
        return cell_af


    ##########
    def init_cell_af(self):
        """1C. Initialize the cell-by-mtPos af dataframe. Unless a clone:mt dict was
        provided, the first N MT positions will be the clone AFs. Creates
        self.clone_mt_dict and self.cell_af
        """
        p = self.params['initialize']
        hets = self.params['het']
        clone_df = self.clone_cell
        num_clones = self.num_clones
        n_cells = self.num_cells
        n_mt = self.num_mt_positions

        # Get the MT map
        if 'mt_clone_map' in p and p['mt_clone_map'] is not None:
            self.clone_mt_dict = p['mt_clone_map']
        else:
            # Each clone points to a mt position
            self.clone_mt_dict = dict()
            for i in range(1,num_clones+1):
                self.clone_mt_dict[i] = i

        # TODO Add the MT clone map so it can contain multiple mutants in lineages
        # If there is a heteroplasmy table in params, it is list of mutant heteroplasmy AFs.
        # If not, will randomly draw based on number of clones
        if type(hets) == list:
            assert(len(hets) == num_clones)

        # Get the cell_af based on MT dictionary and cell coverage
        self.cell_af = self.create_cell_af(clone_df, self.clone_mt_dict,
                                           n_cells, n_mt, num_clones,
                                           self.params['initialize'],
                                           hets,
                                           self.params['het_err_rate'],
                                           coverage=None)
        return


    def init_clone_mt(self):
        p = self.params
        if p["initialize"]['type'] == 'growth':
            ## TODO
            # Create a phylogeny and then get the averages of the mutants
            self.average_clone_mt()
        # If not growth, should already be there.
        return

    def average_clone_mt(self):
        return

    @staticmethod
    def extract_clone_cells(clone_cell, clone_id):
        """ Returns the numbered indices of the specific clones

        :param clone_cell: Each element is the indexed cell's clone label.
        :type clone_cell: np array or pd.Series

        :param clone_id:
        :type clone_id: int or string
        """
        ids = np.flatnonzero(clone_cell == clone_id)
        return ids

    @staticmethod
    def simulate_expand_cells_af(af, growth_inds, sigma):
        """Given a cell-by-af vector, expand the AF.

        Expanded AF occurs by duplicating cells that grew based on the
        growth_inds vector. It will add standard error to each af based on sigma
        :param af: :param growth: Indices of AF to copy :param sigma: Variance
        to add to AF of child. :return:

        Args:
            af:
            growth_inds:
            sigma:
        """

        new_af = af.iloc[growth_inds].copy() + random.normal(0, sigma, size=af.iloc[growth_inds].shape)
        new_af.index = np.arange(af.index[-1]+1, af.index[-1]+1+new_af.shape[0])
        new_af = pd.concat((af,new_af), axis=0)
        #new_af = np.append(af, np.concatenate(new_af))
        return new_af

    def grow_binomial(self, p):
        """ (2.1.2)
        :param p: contains time_steps, rates,
        :type dict
        """
        timesteps = p["time_steps"]
        rates = p["rates"]

        num_clones = self.num_clones+1
        new_dict = {}
        for curr_clone in range(num_clones):
            curr_rate = rates[curr_clone]
            ids = self.extract_clone_cells(self.clone_cell_pop, curr_clone)
            num_curr_cells = len(ids)

            for i in range(timesteps):
                # Simulate growth for each clone separately.
                # Number of cells that divide in this timestep
                num_new = random.binomial(1, curr_rate, size=num_curr_cells).sum()
                num_curr_cells += num_new

            new_dict[curr_clone] = num_curr_cells


        ####TODO
        ## new_lineage_mutants chances. This will see if a mutation will change
        ####TODO
        ## Add death + stimulation rate as well as growth
        # Save the new cell clones df and cell af
        clone_counts = [i for i in new_dict.values()]
        self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)
        # Do not make cell_af, will make this only when subsampled.


        # self.new_cell_af = pd.DataFrame()
        # for clone in range(1, self.num_clones+1):
        #     self.new_cell_af = pd.concat((self.new_cell_af, new_dict[clone]),axis=0).reset_index(drop=True)
        return


    def grow_binomial_old(self, p):
        """ (2.1.1)
        :param p: contains time_steps, rates,
                 and [growth][mutant_af_sigma_noise
        :type dict
        """
        timesteps = p["time_steps"]
        rates = p["rates"]

        sigma = self.params['growth']["mutant_af_sigma_noise"]
        cell_af = self.cell_af

        num_clones = self.num_clones+1
        new_dict = {}
        for curr_clone in range(num_clones):
            curr_rate = rates[curr_clone]
            ids = self.extract_clone_cells(self.clone_cell, curr_clone)
            new_cells = cell_af.iloc[ids].copy()
            for i in range(timesteps):
                # Simulate growth for each clone separately.
                growth_inds = np.flatnonzero(random.binomial(1, curr_rate, size=new_cells.shape[0]))
                #new_ids =
                new_cells = self.simulate_expand_cells_af(new_cells, growth_inds, sigma)

            new_dict[curr_clone] = new_cells
            # Create list of cells
        ####TODO
        ## new_lineage_mutants chances. This will see if a mutation will change
        ####TODO
        ## Add death + stimulation rate as well as growth
        # Save the new cell clones df and cell af
        clone_counts = [i.shape[0] for i in new_dict.values()]
        self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)
        self.new_cell_af = pd.DataFrame(new_dict[0])
        for clone in range(1, self.num_clones+1):
            self.new_cell_af = pd.concat((self.new_cell_af, new_dict[clone]),axis=0).reset_index(drop=True)
        return


    def grow_poisson(self, p):
        # TODO growth of poisson refactor
        return


    def subsample_new(self, to_delete=False):
        """(3) Subsample from new cell population and generate cell_af

        :param to_delete: To remove the cells that grew (which takes up
        a lot of RAM).
        :type to_delete: bool
        """
        p = self.params
        if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
            self.subsample_new_clone_cell = self.new_clone_cell.sample(n=self.params['sequence_subsample'])
        else:
            self.subsample_new_clone_cell = self.new_clone_cell.sample(
                n=self.num_cells)

        #print(f'New cell af, {len(self.subsample_new_clone_cell)}')
        # Generate subsample_new_cell_af
        self.subsample_new_cell_af = self.create_cell_af(clone_df=self.subsample_new_clone_cell,
                                                         mt_dict = self.clone_mt_dict,
                                                         n_cells=len(self.subsample_new_clone_cell),
                                                         n_mt=self.num_mt_positions,
                                                         num_clones=self.num_clones,
                                                         cov_params=p['initialize'],
                                                         hets=
                                                             self.params[
                                                                 'het'],
                                                         het_err=self.params['het_err_rate'],
                                                         coverage=None
                                                         )

        if to_delete:
            self.new_cell_af = None
            self.new_clone_cell = None


    def subsample_new_old(self, to_delete=False):
        """(3) Subsample from new cell population

        :param to_delete: To remove the cells that grew (which takes up
        a lot of RAM).
        :type to_delete: bool
        """
        new_cell_af = self.new_cell_af
        p = self.params
        if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.params['sequence_subsample'])
        else:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.num_cells)

        self.subsample_new_clone_cell = self.new_clone_cell.loc[
            self.subsample_new_cell_af.index]

        if to_delete:
            self.new_cell_af = None
            self.new_clone_cell = None


    def combine_init_growth(self):
        """(4) Add the pre- and post- population of cells into a group.

        :return:
        """
        combined_cell_af = self.cell_af.append(self.subsample_new_cell_af).reset_index(drop=True)
        combined_clones = pd.concat(
            (self.clone_cell, self.subsample_new_clone_cell)).reset_index(
            drop=True)

        combined_befaft = np.concatenate((np.zeros(shape=[self.cell_af.shape[0],]), np.ones(shape=[self.subsample_new_cell_af.shape[0]])))
        combined_meta = pd.DataFrame({"pre_post": combined_befaft, "clone": combined_clones})
        #combined_meta = pd.Series(combined_meta, name='After Growth', dtype=int)
        assert(combined_meta.shape[0] == self.cell_af.shape[0]+self.subsample_new_cell_af.shape[0])
        assert (combined_cell_af.shape[0] == self.cell_af.shape[0] +
                self.subsample_new_cell_af.shape[0])
        assert(combined_meta.shape[0] == combined_clones.shape[0])
        assert(combined_cell_af.shape[0] == combined_clones.shape[0])
        self.combined_meta = combined_meta
        self.combined_clones = combined_clones
        self.combined_cell_af = combined_cell_af
        return

    def save(self, f_save=None):
        """
        Args:
            f_save:
        """
        if f_save is None:
            f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')
        f = open(f_save, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()

    @staticmethod
    def expand_to_mgatk(curr_mt_af,mt_ref):
        ref = mt_ref[curr_mt_af.name]
        pos = curr_mt_af.name
        return pd.DataFrame({"Ref":ref, "Pos":pos, "Val":curr_mt_af})

    def test_save_to_mgatk_format(self):
        df = pd.DataFrame( [[10,0,1,3,5], [3,0,5,5,0], [6,2,1,1,0]] , columns=np.arange(0,5))
        mt_ref_dict = {0: "A", 1: "G", 2: "C", 3: "C", 4: "T"}
        mt_ref = pd.DataFrame({"Pos": mt_ref_dict.keys(), "Ref": mt_ref_dict})
        return

    def save_to_mgatk_format(self, mt_ref, out_f):
        """Converts into the proper files needed for mgatk. (i.e variant and
        coverage files)

        :return:
        """
        cell_af = self.subsample_new_cell_af
        chars = ["A", "G", "C", "T"]
        def alt_generate(x):
            curr = chars.copy()
            curr.remove(x["Ref"])
            return np.random.choice(curr)
        alt_ref = mt_ref.apply(alt_generate, axis=1)

        # First use the AF and choose an alternative allele
        df_stack = cell_af.stack().reset_index().rename(
            {"level_0": "Cell", "level_1": "MT_pos", 0: "Coverage"},
            axis=1)
        df_stack["Nucleotide"] = df_stack["MT_pos"].apply(
            lambda x: alt_ref[x])

        # Add on the reference allele
        df_stack_ref = cell_af.stack().reset_index().rename(
            {"level_0": "Cell", "level_1": "MT_pos", 0: "Coverage"},
            axis=1)
        df_stack_ref["Coverage"] = 1-df_stack_ref["Coverage"]
        df_stack["Nucleotide"] = df_stack["MT_pos"].apply(
            lambda x: mt_ref[x])

        # Save the NTs.
        # For concordance, split the coverage in two
        df_stack = pd.concat([df_stack, df_stack_ref])
        for ind, val in df_stack.groupby("Nucleotide"):
            # Drop the 0s
            curr = val[val["Coverage"]>0]
            # Save file
            curr_out_f = out_f + "_" + ind + ".txt"
            curr.to_csv(curr_out_f)

        # Save the coverage.
        coverage = self.cells_mt_coverage
        if type(coverage) != int:
            coverage_stack = pd.DataFrame(coverage).stack().reset_index().rename(
                {"level_0": "Cell", "level_1": "MT Position", 0: "Coverage"},
                axis=1)
        else:
            coverage_stack = pd.DataFrame(
                self.cells_mt_coverage * np.ones(shape=cell_af.shape)).stack().reset_index().rename(
                {"level_0": "Cell", "level_1": "MT Position", 0: "Coverage"},
                axis=1)
        curr_out_f = out_f + "_" + "coverage.txt"
        coverage_stack.to_csv(curr_out_f)
        return

    def load(self):
        filename = self.params['filename']
        f = open(filename, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)

    def compare_before_after(self):
        """Creates a df that contains information on the number of cells from
        each clone before as well as after. :return: df.at[ind, "Dominant
        Before"] = (full_sim.clone_cell == 1).sum() df.at[ind, "Dominant After"]
        = (full_sim.subsample_new_clone_cell == 1).sum()
        """

        return

    def cluster_compare_before_after(self):
        """Compares the performance of clustering on grouping the same clones
        together. :return:
        """
        return


def main():
    return


if "__name__" == "__main__":
    main()
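
The grow_poisson method above is still a stub (see the TODO). One possible sketch of the refactor, assuming the same count-only bookkeeping that grow_binomial uses: each cell in clone k produces a Poisson(rates[k]) number of offspring per timestep. The helper name and the rates below are illustrative and not part of the codebase.

import numpy as np
from numpy import random

def grow_poisson_counts(clone_counts, rates, time_steps):
    """Sketch: track per-clone cell counts only (as grow_binomial does), but
    draw the number of offspring per cell from a Poisson distribution."""
    counts = np.array(clone_counts, dtype=int)
    for _ in range(time_steps):
        for k, rate in enumerate(rates):
            # Offspring produced this timestep by every cell currently in clone k
            counts[k] += random.poisson(lam=rate, size=counts[k]).sum()
    return counts

# e.g. grow_poisson_counts([800, 100, 100], rates=[0.05, 0.4, 0.4], time_steps=4)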

42824dc97cc86a88d3e6199e05772ba35121892d

Add the MT clone map so it can contain multiple mutants in lineages

Add the MT clone map so it can contain multiple mutants in lineages

# TODO Add the MT clone map so it can contain multiple mutants in lineages

import numpy as np
from numpy import random
import os
import pandas as pd
from tqdm import tqdm
#from src.config import ROOT_DIR
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import glob
from sklearn.cluster import KMeans
from sklearn import metrics
from scipy.spatial.distance import cdist
from pandarallel import pandarallel
pandarallel.initialize(nb_workers=32)

from mplh.color_utils import get_colors
from mplh.fig_utils import legend_from_color
from mplh import cluster_help as ch
from src.simulations.utils.config import read_config_file, write_config_file

from dynamicTreeCut import cutreeHybrid
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage
from sklearn.model_selection import ParameterGrid
from src.simulations.utils.config import check_required


class Simulation:
    """
    Lineage tracing simulation. Initializes cells from the given parameters
    and then grows them. This is meant to be a flexible framework for adding
    different ways to initialize and grow cells and different metrics to
    collect. It can also cluster the results.
    """

    def __init__(self, params_f):
        if isinstance(params_f, str):
            params = read_config_file(params_f)
        else:
            params = params_f

        self.params = params
        check_required(params, ['initialize', 'num_cells', 'num_mt_positions', 'prefix'])
        self.prefix = params['prefix']
        self.num_mt_positions = params['num_mt_positions']
        self.num_cells = params['num_cells']
        if not os.path.exists(params['local_outdir']):
            os.mkdir(params['local_outdir'])


    def initialize(self):
        self.init_clone_dict()
        self.init_cell_coverage()
        self.init_cell_af()
        #self.init_clone_mt()

    #should be external method
    def grow(self):
        p = self.params
        type = p["growth"]["type"]
        if  type == "poisson":
            self.grow_poisson(p['growth']['poisson'])
        elif type == "binomial":
            self.grow_binomial(p['growth']['binomial'])
        return

    # Static Method
    @staticmethod
    def clone_counts_to_cell_series(clone_counts):
        clone_counts = np.array(clone_counts)
        num_cells = clone_counts.sum()
        clone_cell = -1 * np.ones(shape=[num_cells, ])


        clone_cell[:clone_counts[0]] = 0
        for ind, val in enumerate(clone_counts[1:]):
            start = clone_counts[:ind + 1].sum()
            end = clone_counts[:ind + 1].sum() + val
            # starts at sum(clone_counts[i-1]) ends at clone_counts[i].sum()
            clone_cell[start:end] = ind + 1

        clone_cell = pd.Series(clone_cell, dtype=int)
        return clone_cell

    def init_clone_dict(self):
        ### Add in potential to overwrite the values

        # Gets the clone dictionary. Should also have clone to mt dict.
        clones = self.params['initialize']['clone_sizes']
        num_cells = self.num_cells

        # Option 1: List of fraction of size of each clone. 0s are nonclone size, listed first
        if type(clones) == list:
            #clone_cell = pd.Series(index=range(num_cells))
            clone_counts = np.random.multinomial(num_cells, clones)
            clone_cell  = self.clone_counts_to_cell_series(clone_counts)
            self.clone_cell = clone_cell
        # Option 2: 1 clone. ID'd as 1
        elif type(clones) == int: #One number for dominant clone. the others are not.
            clone_cell = np.zeros(shape=[num_cells,])
            clone_cell[:num_cells] = 1
            clone_cell = clone_cell[::-1]
            clone_cell =  pd.Series(clone_cell, dtype=int)
            self.clone_cell = clone_cell

        # Option 3 To ADD, beta binomial and more complex distributions

        self.num_clones =  len(set(clone_cell.values))-1 # Remove the non-clone
        return clone_cell


    def init_cell_coverage(self):
        """
        There are different modes to the coverage, either a constant or through a distribution.
        :return:
        """
        p = self.params['initialize']['coverage']
        cov_type = p['type']

        num_cells = self.num_cells
        num_pos = self.num_mt_positions
        c = np.zeros([num_cells, num_pos])

        if cov_type == 'constant':
            c[:, :] = p['cov_constant']
        elif cov_type == "poisson":
            # Draw an average number of reads per cell from a Poisson distribution
            mu_cov_per_cell = p['mu_cov_per_cell']
            num_reads_per_cell = random.poisson(lam=mu_cov_per_cell,
                                                size=num_cells)

            # Number of reads at each position, based on the average for each cell
            for i in range(num_cells):
                c[i, :] = random.poisson(num_reads_per_cell[i],
                                         size=num_pos)
        self.cells_mt_coverage = c
        return c
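    # Hedged examples of 'initialize.coverage' as consumed above (values are illustrative only):
    #   coverage: {type: constant, cov_constant: 50}
    #   coverage: {type: poisson, mu_cov_per_cell: 50}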


    def init_cell_af(self):
        """
        Initialize the cell-by-mtPos af dataframe. Unless a clone:mt dict was provided,
        the first N MT positions will be the clone AFs.
        Creates self.clone_mt_dict and self.cell_af"""

        p = self.params['initialize']

        hets = self.params['het']
        q = self.params['het_err_rate']
        clone_df = self.clone_cell
        num_clones = self.num_clones
        n_cells = self.num_cells
        n_mt = self.num_mt_positions

        # Output
        cell_af = pd.DataFrame(np.zeros(shape=[n_cells, n_mt]))


        if 'mt_clone_map' in p and p['mt_clone_map'] is not None:
            self.clone_mt_dict = p['mt_clone_map']
        else:
            # Each clone points to a mt position
            self.clone_mt_dict = dict()
            for i in range(1,num_clones+1):
                self.clone_mt_dict[i] = i

        # TODO Add the MT clone map so it can contain multiple mutants in lineages

        # If there is a heteroplasmy table in params, it is list of mutant heteroplasmy AFs.
        # If not, will randomly draw based on number of clones
        if type(hets) == list:
            assert len(hets) == num_clones

            ## Loop through each clone,
            ## Generate the AF for the clone and non-clones using coverage for each cell
            ## Fill in cell_by_af for that position.
            for ind in range(1, num_clones+1):
                # Generate AF: (clone_df ==  ind).sum()
                n_dom_cells = (clone_df==ind).sum()
                het = hets[ind-1]

                curr_mt = self.clone_mt_dict[ind]


                if p['coverage']['type'] == 'constant':
                    c = p['coverage']['cov_constant']

                    af_i = random.binomial(c, het,
                                           n_dom_cells) / c
                    af_j = random.binomial(c, q,
                                           n_cells - n_dom_cells) / c

                    # Update the dom_cells and non_dom for the current MT
                    cell_af.loc[np.flatnonzero(clone_df == ind), curr_mt] = af_i
                    cell_af.loc[np.flatnonzero(clone_df != ind), curr_mt] = af_j

                # Each cell and position has its own coverage value, so update each one
                else:
                    c = self.cells_mt_coverage
                    #Get the cells coverage for the mt position
                    curr_mt_cov= c[:, curr_mt]

                    # Get cell indicies for the clones and nonclones
                    curr_clone_inds = np.flatnonzero(clone_df==ind)
                    curr_nonclone_inds = np.flatnonzero(clone_df!=ind)
                    for cell in curr_clone_inds:
                        # Draw counts at curr_mt for this cell, then normalize by its
                        # coverage so the stored value is an allele frequency
                        cov = int(curr_mt_cov[cell])
                        cell_af.loc[cell, curr_mt] = random.binomial(cov, het) / max(cov, 1)
                    for cell in curr_nonclone_inds:
                        cov = int(curr_mt_cov[cell])
                        cell_af.loc[cell, curr_mt] = random.binomial(cov, q) / max(cov, 1)
                # Loop through each coverage
                #for c in n_dom_cells:

        #####
        # TODO
        # Add noise to the other non-lineage positions
        #####
        self.cell_af = cell_af
        return
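    # Worked mini-example (constant coverage assumed): with cov_constant=50, het=[0.2] and
    # het_err_rate=0.01, dominant-clone cells draw Binomial(50, 0.2)/50 (~0.2 on average) at
    # their clone position while all other cells draw Binomial(50, 0.01)/50 (~0.01), so the
    # clone variant separates the two groups.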


    def init_clone_mt(self):
        p = self.params
        if p["initialize"]['type'] == 'growth':
            ## TODO
            # Create a phylogeny and then get the averages of the mutants
            self.average_clone_mt()
        # If not growth, it should already be there.
        return

    def average_clone_mt(self):
        return

    def extract_clone_cells(self, clone_id):
        ids = np.flatnonzero(self.clone_cell == clone_id)
        return ids

    def simulate_expand_cells_af(self, af, growth_inds, sigma):
        """
        Given a cell-by-af vector, expand the AF.

        Expanded AF occurs by duplicating cells that grew based on
        the growth_inds vector. It will add standard error to each
        af based on sigma
        :param af: Cell-by-AF dataframe to expand
        :param growth_inds: Indices of cells (rows of af) to duplicate
        :param sigma: Standard deviation of the noise added to each child's AF.
        :return:
        """

        new_af = af.iloc[growth_inds].copy() + random.normal(0, sigma, size=af.iloc[growth_inds].shape)
        new_af.index = np.arange(af.index[-1]+1, af.index[-1]+1+new_af.shape[0])
        new_af = pd.concat((af,new_af), axis=0)
        #new_af = np.append(af, np.concatenate(new_af))
        return new_af
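    # Minimal usage sketch for the helper above (hypothetical values):
    #   af = pd.DataFrame([[0.2, 0.0], [0.8, 0.1]])
    #   expanded = self.simulate_expand_cells_af(af, growth_inds=[1], sigma=0.01)
    #   # expanded has 3 rows: the two originals plus a noisy copy of row 1.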

    def grow_binomial(self, p):
        timesteps = p["time_steps"]
        rates = p["rates"]

        sigma = self.params['growth']["mutant_af_sigma_noise"]
        cell_af = self.cell_af
        clone_mt_dict = self.clone_mt_dict

        num_clones = self.num_clones+1
        new_dict = {}
        for curr_clone in range(num_clones):
            curr_rate = rates[curr_clone]
            ids = self.extract_clone_cells(curr_clone)
            new_cells = cell_af.loc[ids].copy()
            for i in range(timesteps):
                # Simulate growth for each clone separately.
                growth_inds = np.flatnonzero(random.binomial(1, curr_rate, size=new_cells.shape[0]))
                #new_ids =
                new_cells = self.simulate_expand_cells_af(new_cells, growth_inds, sigma)

            new_dict[curr_clone] = new_cells
            # Create list of cells

        ####TODO
        ## new_lineage_mutants chances. This will see if a mutation will change


        ####TODO
        ## Add death + stimulation rate as well as growth
        # Save the new cell clones df and cell af
        clone_counts = [i.shape[0] for i in new_dict.values()]
        self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)

        self.new_cell_af = pd.DataFrame(new_dict[0])
        for clone in range(1, self.num_clones+1):
            self.new_cell_af = pd.concat((self.new_cell_af, new_dict[clone]),axis=0).reset_index(drop=True)
        return
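    # Hedged example of the 'growth' parameters grow_binomial() expects (keys taken from the
    # code above; values are illustrative only):
    #   growth:
    #     type: binomial
    #     mutant_af_sigma_noise: 0.01
    #     binomial: {time_steps: 10, rates: [0.3, 0.6]}   # one rate per clone, non-clone first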


    def grow_poisson(self, p):
        # TODO growth of poisson refactor
        return


    def subsample_new(self, to_delete=False):
        new_cell_af = self.new_cell_af
        p = self.params
        if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.params['sequence_subsample'])
        else:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.num_cells)

        self.subsample_new_clone_cell = self.new_clone_cell.loc[
            self.subsample_new_cell_af.index]

        if to_delete:
            self.new_cell_af = None
            self.new_clone_cell = None
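    # Hedged config note: if the params include e.g. sequence_subsample: 1000, that many grown
    # cells are kept; otherwise the original num_cells are resampled from the grown population.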


    def combine_init_growth(self):
        clones = pd.concat(
            (self.clone_cell, self.subsample_new_clone_cell)).reset_index(
            drop=True)
        combined_cell_af = pd.concat((self.cell_af, self.subsample_new_cell_af)).reset_index(drop=True)

        # 0 = initial cells, 1 = cells sampled after growth (matches the 'After Growth' label)
        combined_meta = np.concatenate((np.zeros(shape=[self.cell_af.shape[0], ]),
                                        np.ones(shape=[self.subsample_new_cell_af.shape[0], ])))
        combined_meta = pd.Series(combined_meta, name='After Growth', dtype=int)
        assert(combined_meta.shape[0] == self.cell_af.shape[0]+self.subsample_new_cell_af.shape[0])
        assert (combined_cell_af.shape[0] == self.cell_af.shape[0] +
                self.subsample_new_cell_af.shape[0])
        assert(combined_meta.shape[0] == clones.shape[0])
        assert(combined_cell_af.shape[0] == clones.shape[0])
        self.combined_meta = combined_meta
        self.combined_clones = clones
        self.combined_cell_af = combined_cell_af
        return

    def save(self, f_save=None):
        if f_save is None:
            f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')
        f = open(f_save, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()

    def save_to_mgatk_format(self):
        """
        Converts into the proper files needed for mgatk. (i.e variant and coverage files)
        :return:
        """

    def load(self):
        filename = self.params['filename']
        f = open(filename, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)

    def compare_before_after(self):
        """
        Creates a df that contains information on
        the number of cells from each clone before as well as after.
        :return:
        df.at[ind, "Dominant Before"] = (full_sim.clone_cell == 1).sum()
        df.at[ind, "Dominant After"] =  (full_sim.subsample_new_clone_cell == 1).sum()

        """

        return

    def cluster_compare_before_after(self):
        """
        Compares the performance of clustering on grouping the same
        clones together.
        :return:
        """
        return

    @staticmethod
    def plot_cluster(cell_af, cell_meta=None, mt_meta=None, f_save=None):
        ch.plot_cluster(cell_af, row_meta=cell_meta, col_meta=mt_meta,
                        fsave=f_save, to_col_clust=False, to_z=True)

    @staticmethod
    def cluster(cell_af):
        """
        Dynamic tree clustering of the rows of cell_af
        :param cell_af:
        :return:
        """
        distances = pdist(cell_af, "euclidean")
        link = linkage(distances, "average")
        clusters = cutreeHybrid(link, distances)['labels']
        return clusters
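    # Usage sketch (assumes a Simulation instance 's' that has been initialized, grown, and
    # combined; file name is hypothetical):
    #   labels = Simulation.cluster(s.combined_cell_af)   # dynamic-tree-cut labels per cell
    #   Simulation.plot_cluster(s.combined_cell_af, cell_meta=None, f_save='clusters.png')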

    @staticmethod
    def cluster_kmeans(cell_af):
        distortions = []
        inertias = []
        mapping1 = {}
        mapping2 = {}
        K = range(1, 10)
        for k in K:
            # Build and fit the model once for each k
            kmeanModel = KMeans(n_clusters=k).fit(cell_af)

            distortions.append(sum(
                np.min(cdist(cell_af, kmeanModel.cluster_centers_, 'euclidean'),
                       axis=1)) / cell_af.shape[0])
            inertias.append(kmeanModel.inertia_)

            mapping1[k] = distortions[-1]
            mapping2[k] = kmeanModel.inertia_

        return distortions, inertias, mapping1, mapping2




def main():
    return


if "__name__" == "__main__":
    main()
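A minimal end-to-end sketch of how the class above appears meant to be driven. The parameter keys mirror what the code reads; all values here are hypothetical, and a real run would normally load them from a YAML file via read_config_file:

params = {
    'prefix': 'sim_example', 'local_outdir': 'sim_out',
    'num_cells': 1000, 'num_mt_positions': 20,
    'het': [0.2], 'het_err_rate': 0.01,
    'initialize': {
        'clone_sizes': [0.9, 0.1],
        'coverage': {'type': 'constant', 'cov_constant': 50},
    },
    'growth': {
        'type': 'binomial',
        'mutant_af_sigma_noise': 0.01,
        'binomial': {'time_steps': 10, 'rates': [0.3, 0.6]},
    },
}
s = Simulation(params)        # a dict is accepted in place of a YAML path
s.initialize()                # clones, coverage, and cell-by-AF matrix
s.grow()                      # binomial growth per clone
s.subsample_new(to_delete=True)
s.combine_init_growth()       # pre- and post-growth populations merged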

871fbe46ded383b173de84cbe475a0d191c2fa79

growth of poisson refactor

growth of poisson refactor

return

# # TODO growth of poisson refactor

# import os
# import pandas as pd
# import pickle
# from src.simulations.utils.config import read_config_file, write_config_file
# from src.simulations.utils.config import check_required
import numpy as np
import pymc3 as pm
import matplotlib.pyplot as plt

num_cells = 10000
num_mt_positions = 10
clone_dist = [0.10,0.01,.89]
hets = [0.2,0.3] # len(hets) == len(clone_dist)-1
avg_cov = 50
het_err_rate = 0.1


df = np.concatenate((np.random.binomial(10,0.3,(100,4)),
                       np.random.binomial(10,0.6,(90,4))))
clone_id = np.concatenate((np.zeros([100,]), np.ones([90,]))).astype(int)
mt_id = [0,1,2,3]

with pm.Model() as model:

    # The clone assignment is treated as known here (clone_id above); the broken,
    # unused draw below is left commented out.
    # clone_ids = pm.Multinomial('clone_ids', n=10000, p=clone_dist)

    beta = pm.Beta('beta', alpha=2,beta=2, shape=2)
    #p = pm.Bernoulli('p', 1, beta, shape=2)
    #p = pm.Binomial('p', 1, beta)
    #q = pm.Binomial('q', 1, beta)

    # Broadcast each cell's clone-specific beta across the 4 observed positions
    s = pm.Binomial('s', 10, beta[clone_id][:, None], observed=df)

    #s = pm.Binomial('s', 10, p, observed=df[:30,0])
    #t = pm.Binomial('t', 10, q, observed=df[30:, 0])

    #s = pm.Binomial('s', 10, beta, shape=(30,4), observed=df[:30])
    #t = pm.Binomial('t', 10, beta, shape=(25, 4), observed=df[30:])

    #vec = pm.math.concatenate((s, t), axis=0)

    # data = pm.Data("data", df)
    # u = pm.Normal('u', vec, observed=data)
    #u = pm.Deterministic('u', vec)

    trace = pm.sample(draws=8000, init='adapt_diag')

print(pm.summary(trace))
dot = pm.model_to_graphviz(model)
dot.render('simulation_pymc.gv')
pm.plot_trace(trace)
plt.savefig('simulation_trace.png')
print('here')
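# Rough reading of the toy model above (an exploratory sketch, not a finished sampler):
# 'beta' puts a Beta(2, 2) prior on the heteroplasmy of each of the two groups, the clone
# assignment clone_id is treated as known, and each observed count in df (successes out of
# 10 trials) is modeled as Binomial(10, beta[clone_id]).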
#
# with pm.Model() as model:
#     clone_counts = pm.Multinomial(num_cells, clone_dist)
#     num_clones = len(clone_counts) - 1
#
#     clone_cell = -1 * np.ones(shape=[num_cells, ])
#
#     clone_cell[:clone_counts[0]] = 0
#
#     for ind, val in enumerate(clone_counts[1:]):
#         start = clone_counts[:ind + 1].sum()
#         end = clone_counts[:ind + 1].sum() + val
#         # starts at sum(clone_counts[i-1]) ends at clone_counts[i].sum()
#         clone_cell[start:end] = ind + 1
#
#     c = pm.Poisson('cov', avg_cov, shape=[num_cells, num_mt_positions])
#
#     clone_mt_dict = dict()
#     for i in range(1, num_clones + 1):
#         clone_mt_dict[i] = i
#
#     cell_af = np.zeros([num_cells, num_mt_positions])
#     for ind in range(num_clones):
#         # Generate AF: (clone_df ==  ind).sum()
#         n_dom_cells = clone_counts[ind]
#         het = hets[ind]
#
#         curr_mt = clone_mt_dict[ind]
#
#         af_i = pm.Binomial('af', avg_cov, het, shape=n_dom_cells)
#         af_j = pm.Binomial('het af', avg_cov, het_err_rate, shape=num_cells - n_dom_cells) # / c
#
#
#         # Update the dom_cells and non_dom for the current MT
#         cell_af[np.flatnonzero(clone_df == ind), curr_mt] = af_i
#         cell_af[np.flatnonzero(clone_df != ind), curr_mt] = af_j
#
#     cell_af = pm.Deterministic(y)
#
# pm.model_to_graphviz(model)
#
#
# def init_cell_af(self):
#     """1C. Initialize the cell-by-mtPos af dataframe. Unless a clone:mt dict was
#     provided, the first N MT positions will be the clone AFs. Creates
#     self.clone_mt_dict and self.cell_af
#     """
#     clone_df = self.clone_cell
#     # Output
#     cell_af = pd.DataFrame(np.zeros(shape=[n_cells, n_mt]))
#
#         # Each clone points to a mt position
#         self.clone_mt_dict = dict()
#         for i in range(1, num_clones + 1):
#             self.clone_mt_dict[i] = i
#
#     # TODO Add the MT clone map so it can contain multiple mutants in lineages
#
#     # If there is a heteroplasmy table in params, it is list of mutant heteroplasmy AFs.
#     # If not, will randomly draw based on number of clones
#     if type(hets) == list:
#         assert (len(hets) == num_clones)
#
#         ## Loop through each clone,
#         ## Generate the AF for the clone and non-clones using coverage for each cell
#         ## Fill in cell_by_af for that position.
#         for ind in range(1, num_clones + 1):
#             # Generate AF: (clone_df ==  ind).sum()
#             n_dom_cells = (clone_df == ind).sum()
#             het = hets[ind - 1]
#
#             curr_mt = self.clone_mt_dict[ind]
#
#             if p['coverage']['type'] == 'constant':
#                 c = p['coverage']['cov_constant']
#
#                 af_i = random.binomial(c, het, n_dom_cells) / c
#                 af_j = random.binomial(c, q, n_cells - n_dom_cells) / c
#
#                 # Update the dom_cells and non_dom for the current MT
#                 cell_af.loc[
#                     np.flatnonzero(clone_df == ind), curr_mt] = af_i
#                 cell_af.loc[
#                     np.flatnonzero(clone_df != ind), curr_mt] = af_j
#
#             # Each cell and position has it's own coverage value, so need to update each
#             else:
#                 c = self.cells_mt_coverage
#                 # Get the cells coverage for the mt position
#                 curr_mt_cov = c[:, curr_mt]
#
#                 # Get cell indicies for the clones and nonclones
#                 curr_clone_inds = np.flatnonzero(clone_df == ind)
#                 curr_nonclone_inds = np.flatnonzero(clone_df != ind)
#                 for cell in curr_clone_inds:
#                     # Get one value for curr_mt and cell based on coverage
#                     cell_af.loc[cell, curr_mt] = random.binomial(
#                         curr_mt_cov[cell], het)
#                 for cell in curr_nonclone_inds:
#                     cell_af.loc[cell, curr_mt] = random.binomial(
#                         curr_mt_cov[cell],
#                         q)  # Loop through each coverage  # for c in n_dom_cells:
#
#
# class Simulation:
#     """Lineage tracing simulation of one sample
#
#     Will initialize cells based on their parameters and grow as well. This
#     should be a flexible framework, to add different ways to initialize, grow,
#     and metrics to have. Additionally can cluster these results.
#
#     :ivar params
#     :type params: dict
#     """
#
#     def __init__(self, params_f):
#         """
#         :param params_f: Parameter yaml file for the specifications
#         :type params_f: yaml file or dict
#         """
#         if isinstance(params_f, str):
#             params = read_config_file(params_f)
#         else:
#             params = params_f
#
#         self.params = params
#         check_required(params, ['initialize', 'num_cells', 'num_mt_positions', 'prefix'])
#         self.prefix = params['prefix']
#         self.num_mt_positions = params['num_mt_positions']
#         self.num_cells = params['num_cells']
#         if not os.path.exists(params['local_outdir']):
#             os.mkdir(params['local_outdir'])
#
#
#     def initialize(self):
#         """ (1) Pre-growth cell population is instantiated.
#
#         Creates a clone-MT dictionary, cell coverage matrix
#         (or an int, depending on parameters), and cell-AF matrix.
#         :return:
#         """
#         self.init_clone_dict()
#         self.init_cell_coverage()
#         self.init_cell_af()
#         #self.init_clone_mt()
#
#     #should be external method
#     def grow(self):
#         """ (2) Growth of cells is run."""
#         p = self.params
#         type = p["growth"]["type"]
#         if  type == "poisson":
#             self.grow_poisson(p['growth']['poisson'])
#         elif type == "binomial":
#             self.grow_binomial(p['growth']['binomial'])
#         return
#
#     # Static Method
#     @staticmethod
#     def clone_counts_to_cell_series(clone_counts):
#         """ Generates new cell IDs based on cluster count iterable
#         :param clone_counts: Each i'th element is the number of cells in
#         cluster i.
#         :type clone_counts: iterable
#         :return each index name is a cell ID and each value is which cluster
#         the cell belongs too.
#         :rtype pd.Series
#         """
#         clone_counts = np.array(clone_counts)
#         num_cells = clone_counts.sum()
#         clone_cell = -1 * np.ones(shape=[num_cells, ])
#
#         clone_cell[:clone_counts[0]] = 0
#         for ind, val in enumerate(clone_counts[1:]):
#             start = clone_counts[:ind + 1].sum()
#             end = clone_counts[:ind + 1].sum() + val
#             # starts at sum(clone_counts[i-1]) ends at clone_counts[i].sum()
#             clone_cell[start:end] = ind + 1
#
#         clone_cell = pd.Series(clone_cell, dtype=int)
#         return clone_cell
#
#     def init_clone_dict(self):
#         """1A
#         """
#
#         ### Add in potential to overwrite the values
#         # Gets the clone dictionary. Should also have clone to mt dict.
#         clones = self.params['initialize']['clone_sizes']
#         num_cells = self.num_cells
#
#         # Option 1: List of fraction of size of each clone. 0s are nonclone size, listed first
#         if type(clones) == list:
#             #clone_cell = pd.Series(index=range(num_cells))
#             clone_counts = np.random.multinomial(num_cells, clones)
#             clone_cell  = self.clone_counts_to_cell_series(clone_counts)
#             self.clone_cell = clone_cell
#         # Option 2: 1 clone. ID'd as 1
#         elif type(clones) == int: #One number for dominant clone. the others are not.
#             clone_cell = np.zeros(shape=[num_cells,])
#             clone_cell[:num_cells] = 1
#             clone_cell = clone_cell[::-1]
#             clone_cell =  pd.Series(clone_cell, dtype=int)
#             self.clone_cell = clone_cell
#
#         # Option 3 To ADD, beta binomial and more complex distributions
#
#         self.num_clones =  len(set(clone_cell.values))-1 # Remove the non-clone
#         return clone_cell
#
#
#     def init_cell_coverage(self):
#         """1B
#
#         There are different modes to the coverage, either a constant or
#         through a distribution.
#         """
#         p = self.params['initialize']['coverage']
#         type = p['type']
#
#         num_cells = self.num_cells
#         num_pos = self.num_mt_positions
#         c = np.zeros([num_cells, num_pos])
#
#         if type == 'constant':
#             c[:, :] = p['cov_constant']
#         elif type == "poisson":
#             # Get the number of coverage per cell based on poisson (should be reads)
#             mu_cov_per_cell = p['mu_cov_per_cell']
#             num_reads_per_cell = random.poisson(lam=mu_cov_per_cell,
#                                                 size=num_cells)
#
#             # Number of reads at each position, based on the average for each cell
#             for i in num_cells:
#                 c[i, :] = random.poisson(num_reads_per_cell[i],
#                                          size=num_pos)
#         self.cells_mt_coverage = c
#         return c
#
#
#
#         #####
#         # TODO
#         # Add noise to the other non-lineage positions
#         #####
#         self.cell_af = cell_af
#         return
#
#
#     def init_clone_mt(self):
#         p = self.params
#         if p["initialize"]['type'] == 'growth':
#             ## TODO
#             # Create a phylogeny and then get the averages of the mutants
#             self.average_clone_mt()
#         # If not growth, should aready be there.
#         return
#
#     def average_clone_mt(self):
#         return
#
#     def extract_clone_cells(self, clone_id):
#         """
#         Args:
#             clone_id:
#         """
#         ids = np.flatnonzero(self.clone_cell == clone_id)
#         return ids
#
#     def simulate_expand_cells_af(self, af, growth_inds, sigma):
#         """Given a cell-by-af vector, expand the AF.
#
#         Expanded AF occurs by duplicating cells that grew based on the
#         growth_inds vector. It will add standard error to each af based on sigma
#         :param af: :param growth: Indices of AF to copy :param sigma: Variance
#         to add to AF of child. :return:
#
#         Args:
#             af:
#             growth_inds:
#             sigma:
#         """
#
#         new_af = af.iloc[growth_inds].copy() + random.normal(0, sigma, size=af.iloc[growth_inds].shape)
#         new_af.index = np.arange(af.index[-1]+1, af.index[-1]+1+new_af.shape[0])
#         new_af = pd.concat((af,new_af), axis=0)
#         #new_af = np.append(af, np.concatenate(new_af))
#         return new_af
#
#     def grow_binomial(self, p):
#         """ (2.1)
#         Args:
#             p:
#         """
#         timesteps = p["time_steps"]
#         rates = p["rates"]
#
#         sigma = self.params['growth']["mutant_af_sigma_noise"]
#         cell_af = self.cell_af
#         clone_mt_dict = self.clone_mt_dict
#
#         num_clones = self.num_clones+1
#         new_dict = {}
#         for curr_clone in range(num_clones):
#             curr_rate = rates[curr_clone]
#             ids = self.extract_clone_cells(curr_clone)
#             new_cells = cell_af.loc[ids].copy()
#             for i in range(timesteps):
#                 # Simulate growth for each clone separately.
#                 growth_inds = np.flatnonzero(random.binomial(1, curr_rate, size=new_cells.shape[0]))
#                 #new_ids =
#                 new_cells = self.simulate_expand_cells_af(new_cells, growth_inds, sigma)
#
#             new_dict[curr_clone] = new_cells
#             # Create list of cells
#
#         ####TODO
#         ## new_lineage_mutants chances. This will see if a mutation will change
#
#
#         ####TODO
#         ## Add death + stimulation rate as well as growth
#         # Save the new cell clones df and cell af
#         clone_counts = [i.shape[0] for i in new_dict.values()]
#         self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)
#
#         self.new_cell_af = pd.DataFrame(new_dict[0])
#         for clone in range(1, self.num_clones+1):
#             self.new_cell_af = pd.concat((self.new_cell_af, new_dict[clone]),axis=0).reset_index(drop=True)
#         return
#
#
#     def grow_poisson(self):
#         # TODO growth of poisson refactor
#         return
#
#
#     def subsample_new(self, to_delete=False):
#         """(3) Subsample from new cell population
#
#         :param to_delete: To remove the cells that grew (which takes up
#         a lot of RAM).
#         :type to_delete: bool
#         """
#         new_cell_af = self.new_cell_af
#         p = self.params
#         if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
#             self.subsample_new_cell_af = new_cell_af.sample(n=self.params['sequence_subsample'])
#         else:
#             self.subsample_new_cell_af = new_cell_af.sample(n=self.num_cells)
#
#         self.subsample_new_clone_cell = self.new_clone_cell.loc[
#             self.subsample_new_cell_af.index]
#
#         if to_delete:
#             self.new_cell_af = None
#             self.new_clone_cell = None
#
#
#     def combine_init_growth(self):
#         """(4) Add the pre- and post- population of cells into a group.
#
#         :return:
#         """
#         combined_cell_af = self.cell_af.append(self.subsample_new_cell_af).reset_index(drop=True)
#         combined_clones = pd.concat(
#             (self.clone_cell, self.subsample_new_clone_cell)).reset_index(
#             drop=True)
#
#
#         combined_befaft = np.concatenate((np.zeros(shape=[self.cell_af.shape[0],]), np.ones(shape=[self.subsample_new_cell_af.shape[0]])))
#         combined_meta = pd.DataFrame({"pre_post": combined_befaft, "clone": combined_clones})
#         #combined_meta = pd.Series(combined_meta, name='After Growth', dtype=int)
#         assert(combined_meta.shape[0] == self.cell_af.shape[0]+self.subsample_new_cell_af.shape[0])
#         assert (combined_cell_af.shape[0] == self.cell_af.shape[0] +
#                 self.subsample_new_cell_af.shape[0])
#         assert(combined_meta.shape[0] == combined_clones.shape[0])
#         assert(combined_cell_af.shape[0] == combined_clones.shape[0])
#         self.combined_meta = combined_meta
#         self.combined_clones = combined_clones
#         self.combined_cell_af = combined_cell_af
#         return
#
#     def save(self, f_save=None):
#         """
#         Args:
#             f_save:
#         """
#         if f_save is None:
#             f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')
#         f = open(f_save, 'wb')
#         pickle.dump(self.__dict__, f, 2)
#         f.close()
#
#     @staticmethod
#     def expand_to_mgatk(curr_mt_af,mt_ref):
#         ref = mt_ref[curr_mt_af.name]
#         pos = curr_mt_af.name
#         return pd.DataFrame({"Ref":ref, "Pos":pos, "Val":curr_mt_af})
#
#     def test_save_to_mgatk_format(self):
#         df = pd.DataFrame( [[10,0,1,3,5], [3,0,5,5,0], [6,2,1,1,0]] , columns=np.arange(0,5))
#         mt_ref_dict = {0: "A", 1: "G", 2: "C", 3: "C", 4: "T"}
#         mt_ref = pd.DataFrame({"Pos": mt_ref_dict.keys(), "Ref": mt_ref_dict})
#         return
#
#     def save_to_mgatk_format(self, mt_ref, out_f):
#         """Converts into the proper files needed for mgatk. (i.e variant and
#         coverage files)
#
#         :return:
#         """
#         cell_af = self.subsample_new_cell_af
#         chars = ["A", "G", "C", "T"]
#         def alt_generate(x):
#             curr = chars.copy()
#             curr.remove(x["Ref"])
#             return np.random.choice(curr)
#         alt_ref = mt_ref.apply(alt_generate, axis=1)
#
#         # First use the AF and choose an alternative allele
#         df_stack = cell_af.stack().reset_index().rename(
#             {"level_0": "Cell", "level_1": "MT_pos", 0: "Coverage"},
#             axis=1)
#         df_stack["Nucleotide"] = df_stack["MT_pos"].apply(
#             lambda x: alt_ref[x])
#
#         # Add on the reference allele
#         df_stack_ref = cell_af.stack().reset_index().rename(
#             {"level_0": "Cell", "level_1": "MT_pos", 0: "Coverage"},
#             axis=1)
#         df_stack_ref["Coverage"] = 1-df_stack_ref["Coverage"]
#         df_stack["Nucleotide"] = df_stack["MT_pos"].apply(
#             lambda x: mt_ref[x])
#
#         df_stack = pd.concat(df_stack, df_stack_ref)
#         for ind, val in df_stack.groupby("Nucleotide"):
#             # Drop the 0s
#             curr = val[val["Coverage"]>0]
#             # Save file
#             curr_out_f = out_f + "_" + ind + ".txt"
#             curr.to_csv(curr_out_f)
#
#         # Save the coverage.
#         coverage = self.cells_mt_coverage
#         if type(coverage) != int:
#             coverage_stack = pd.DataFrame(coverage).stack().reset_index().rename(
#                 {"level_0": "Cell", "level_1": "MT Position", 0: "Coverage"},
#                 axis=1)
#         else:
#             coverage_stack = pd.DataFrame(self.cells_mt_coverage)*np.ones(shape=cell_af.shape).stack().reset_index().rename(
#                 {"level_0": "Cell", "level_1": "MT Position",  0: "Coverage"},
#                 axis=1)
#         curr_out_f = out_f + "_" + "coverage.txt"
#         coverage_stack.to_csv(curr_out_f)
#         return
#
#     def load(self):
#         filename = self.params['filename']
#         f = open(filename, 'rb')
#         tmp_dict = pickle.load(f)
#         f.close()
#         self.__dict__.update(tmp_dict)
#
#     def compare_before_after(self):
#         """Creates a df that contains information on the number of cells from
#         each clone before as well as after. :return: df.at[ind, "Dominant
#         Before"] = (full_sim.clone_cell == 1).sum() df.at[ind, "Dominant After"]
#         = (full_sim.subsample_new_clone_cell == 1).sum()
#         """
#
#         return
#
#     def cluster_compare_before_after(self):
#         """Compares the performance of clustering on grouping the same clones
#         together. :return:
#         """
#         return
#
#
# def main():
#     return


# if "__name__" == "__main__":
#     main()


8a9cd1cb609d55c77709007a307261946918fc2e

account for mt_pos being a list not an int

account for mt_pos being a list not an int

# TODO account for mt_pos being a list not an int

import numpy as np
from numpy import random
import os
import pandas as pd
from tqdm import tqdm
#from src.config import ROOT_DIR
from sklearn.metrics import roc_curve, average_precision_score, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import glob
from sklearn.cluster import KMeans
from sklearn import metrics
from scipy.spatial.distance import cdist
from pandarallel import pandarallel

from mplh.color_utils import get_colors
from mplh.fig_utils import legend_from_color
from mplh import cluster_help as ch
from src.simulations.utils.config import read_config_file, write_config_file

from dynamicTreeCut import cutreeHybrid
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage
from sklearn.model_selection import ParameterGrid
from src.simulations.utils.config import check_required
from src.simulations.analysis import Analysis as an
from .simulation import Simulation
from src.utils.utils import compare_arbitrary_labels

# Does this ruin running the MCMC? I don't think so, b/c that format is going to be put in after anyway
class FullSimulation:
    """
    Class that simulates cell growth for lineage tracing. Reads in a
    parameter file and runs a certain number of iterations based on
    the num_iterations parameter.

    :ivar n_iter: Number of iterations
    :type n_iter: int
    :ivar num_cells: Number of cells to sequence
    :type num_cells: int

    :ivar sim: Each index is a different iteration of the simulation.
    :type sim: pandas Series

    """
    def __init__(self, params_f):
        params = read_config_file(params_f)
        self.n_iter = params['num_iterations']
        self.num_cells = params['num_cells']
        self.params = params
        if 'n_clust' not in params:
            self.params['n_clust'] = None
        # Store the metrics with this
        self.metrics = dict()

        # Files to save
        #self.outdir = os.path.join(self.params['local_outdir'])
        self.outdir = self.params['local_outdir']
        self.data_outdir = os.path.join(self.params['data_outdir'])
        self.f_save_data = os.path.join(self.data_outdir,
                                   self.params['name'] + '.p')
        self.f_save = os.path.join(self.outdir, self.params['name'] + '.p')

        self.f_save_metrics = self.f_save_data.replace('.p', '.metrics.tsv')
        self.f_save_cluster = self.f_save_data.replace('.p', '.cluster.tsv')
        self.f_save_befaft = self.f_save_data.replace('.p', '.before_after.tsv')
        self.f_save_befaft_cl = self.f_save_data.replace('.p',
                                                      '.before_after_cl.tsv')
        self.f_save_rocs = self.f_save_data.replace('.p', '.rocs.p')
        return
        #for i in self.n_iter:

    def run(self):
        """
        Runs the simulation and stores it in the sim attr. Will also pickle
        the objects and save.

        Pandarallel parallelization is available (commented out below);
        by default a plain pandas apply is used.
        :return:
        """
        # Parallelize df
        df = pd.Series(index=range(self.n_iter))
        df = df.apply(self.run_sim, args=(self.params,))

        #pandarallel.initialize(nb_workers=self.params['cpus'])
        #df = df.parallel_apply(self.run_sim, args=(self.params,))

        self.sim = df

        if "save_small" in self.params and self.params["save_small"] > 0:
            curr = self.sim[:self.params["save_small"]]
            for ind, val in curr.items():
                curr_f = self.f_save.replace(".p","") + "_sim" + str(ind)
                val.to_csv(curr_f+".csv")

        return

    @staticmethod
    def run_sim(x, params):
        """Run iteration of simulation.

        For a single iteration, it will initialize, grow, subsample,
        and merge the before-stimulus and after-stimulus variables.
        :param x: Placeholder variable
        :param params: The parameter dictionary to use
        :type params: dict
        :return:
        """
        s = Simulation(params)
        s.initialize()
        s.grow()
        s.subsample_new(to_delete=True)
        s.combine_init_growth()
        return s

    def run_metrics(self):
        """
        Get metrics performances and save.
        :return:
        """
        self.sim_performance_dominant(group='both')
        self.stats_before_after()
        self.cluster_befaft()
        self.stats_cluster_befaft_dom()
        self.estimate_growth_rates_from_cluster()
        self.stats_before_after_clust()
        self.kl_divergence()
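    # Usage sketch (hypothetical YAML path; parameter keys as read in __init__ above):
    #   full = FullSimulation('parameters/example_sim.yaml')
    #   full.run()          # one Simulation object per iteration, stored in full.sim
    #   full.run_metrics()  # precision/dropout/ROC, before-after counts, clustering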

    def flatten_sim(self):
        ## TODO
        # This will extract out the classes of df
        return

    def sim_performance_dominant(self, group='both'):
        """
        Collects metrics that are averaged over the simulations.
        These focus on the main, dominant clone: specifically, the
        allele frequency of that clone's variant in each cell.

        :param group: {'init', 'growth', 'both'} Which cell population to evaluate on.
        :ivar dropout: Number of dominant clone cells that have 0 reads
        at the lineage variant position.
        :type dropout: list
        :ivar prec_scores: sklearn average precision score based on
        the allele frequencies seen in the dominant clone cells versus
        the non-clone cells.
        :type prec_scores: list
        :ivar rocs: ROC curves for each iteration based on allele
        frequencies.

        :return:
        """
        dropout = []
        rocs = []
        prec_scores = []

        for iter, s in enumerate(self.sim.values):
            # First get the dominant clone , which is indexed as 1
            mt_pos = s.clone_mt_dict[1]
            # TODO account for mt_pos being a list not an int
            if group == 'init':
                clones = s.clone_cell
                cell_af = s.cell_af.loc[:,mt_pos]
            elif group == 'growth':
                clones = s.new_clone_cell
                cell_af = s.new_cell_af.loc[:,mt_pos]
            elif group == 'both':
                #clones = pd.concat((s.clone_cell, s.subsample_new_clone_cell)).reset_index(drop=True)
                #cell_af = pd.concat((s.cell_af.loc[:,mt_pos], s.subsample_new_cell_af.loc[:,mt_pos])).reset_index(drop=True)
                clones = s.combined_clones
                cell_af = s.combined_cell_af.loc[:,mt_pos]
            else:
                raise ValueError('group variable not properly set.')

            y_true = clones.values.copy()
            y_true[y_true != 1] = 0  # Set nondominant clones to 0
            rocs.append(roc_curve(y_true, cell_af))
            prec_scores.append(average_precision_score(y_true, cell_af))
            dropout.append((cell_af[clones==1]==0).sum()/cell_af[clones==1].shape[0])

        self.dropout = dropout
        self.prec_scores = prec_scores
        self.rocs = rocs
        pd.DataFrame([prec_scores, dropout], index=['Precision', 'Dropout']).to_csv(self.f_save_metrics, sep='\t')
        self.metrics['prec_scores'] = prec_scores
        self.metrics['dropout'] = dropout
        self.metrics['rocs'] = rocs
        pickle.dump(rocs, open(self.f_save_rocs, 'wb'))

        return
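    # Illustrative sketch (not part of the original code): how the dominant-clone
    # metrics above behave on a toy example, assuming three clone cells and three
    # non-clone cells with the given allele frequencies at the clone's variant
    # position.
    #
    #   import numpy as np
    #   from sklearn.metrics import roc_curve, average_precision_score
    #   y_true  = np.array([1, 1, 1, 0, 0, 0])               # 1 = dominant clone
    #   cell_af = np.array([0.9, 0.4, 0.0, 0.0, 0.05, 0.0])  # AF at the clone variant
    #   average_precision_score(y_true, cell_af)   # precision score
    #   roc_curve(y_true, cell_af)                 # (fpr, tpr, thresholds)
    #   (cell_af[y_true == 1] == 0).mean()         # dropout = 1/3 in this toy case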


    def reduce_cells(self, cell_af):
        #self.sim
        return


    def stats_before_after(self, clone_id=1):
        b_a_df = pd.DataFrame(index=np.arange(0,len(self.sim)), columns=["B", "A", "A/B"], dtype=str)
        for iter, s in enumerate(self.sim.values):
            b_clones = s.clone_cell
            a_clones = s.subsample_new_clone_cell
            b_a_df.at[iter, "B"] = (b_clones == clone_id).sum()
            b_a_df.at[iter, "A"] = (a_clones==clone_id).sum()
            b_a_df.at[iter,"A/B"] = (b_a_df.at[iter, "A"]+1)/(b_a_df.at[iter, "B"]+1)

        self.b_a_df = b_a_df
        b_a_df.to_csv(self.f_save_befaft, sep='\t')
        self.metrics['b_a_df'] = b_a_df
        return



    def cluster_befaft(self):
        """
        Loops through the simulations and for each,
        it clusters the cells.

        :ivar cluster_results: Cluster labels for each cell in each
        iteration.
        :type cluster_results: list with one entry per iteration, where
        each entry is indexed by cell and holds that cell's cluster
        label
        """
        cluster_results = []


        for s in tqdm(self.sim.values):
            cluster_results.append(an.cluster_kmeans(s.combined_cell_af,
                                                     n_clust=self.params['n_clust']))
            # Add the cluster results to combined meta
            s.combined_meta["cluster"] = cluster_results[-1]

            # Bring the cluster labels and the clone labels into same
            # name space
            s.combined_meta['cluster_clone'] = compare_arbitrary_labels(s.combined_meta['clone'],
                                                                        s.combined_meta['cluster'])

        self.cluster_results = cluster_results
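        # Sketch of the clustering step (assumptions flagged): an.cluster_kmeans
        # is assumed to wrap a k-means clustering of the cell-by-variant
        # allele-frequency matrix, and compare_arbitrary_labels is assumed to map
        # the arbitrary cluster ids onto the clone ids they best overlap.
        # An equivalent stand-alone clustering call with scikit-learn would be:
        #
        #   from sklearn.cluster import KMeans
        #   labels = KMeans(n_clusters=self.params['n_clust']).fit_predict(
        #       s.combined_cell_af)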


    def stats_cluster_befaft_dom(self, clone_id=1):
        """
        Confusion matrix for clustering the proper clone cells together.
        :param clone_id: Which clone to get metrics for
        :return:
        """

        b_a_df = pd.DataFrame(index=np.arange(len(self.sim)),
                              columns=["TN", "FP", "FN", "TP"], dtype=int)
        f1_vals = []
        for ind, s in enumerate(self.sim.values):
            y_true = s.combined_clones.copy()  # copy so the stored clone labels are not modified in place
            y_true[y_true != clone_id] = 0
            y_pred = s.combined_meta['cluster_clone'].copy()
            y_pred[y_pred!=clone_id] = 0

            f1_vals.append(f1_score(y_true, y_pred))
            # y_true, y_pred
            tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
            b_a_df.loc[ind] = [tn, fp, fn, tp]
        self.b_a_df_clust = b_a_df

        self.metrics['Dominant Cluster Confusion'] = b_a_df
        self.metrics['Cluster F1 scores'] = f1_vals
        return



    def estimate_growth_rates_from_known(self):
        all_growth_estimate = []
        all_clone_sizes = []
        for iter, s in enumerate(self.sim.values):
            growth_estimate, clone_sizes, _, _ = an.estimate_growth_rate(s.combined_meta)
            all_growth_estimate.append(growth_estimate)
            all_clone_sizes.append(clone_sizes)

        self.metrics['obs_growth_rates'] = all_growth_estimate
        self.metrics['obs_clone_sizes'] = all_clone_sizes
        return

    def estimate_growth_rates_from_cluster(self):
        all_growth_estimate = []
        all_clone_sizes = []
        all_bef_est = []
        all_aft_est = []
        for iter, s in enumerate(self.sim.values):
            growth_estimate, clone_sizes, aft_est, bef_est = an.estimate_growth_rate(s.combined_meta, clone_col="cluster_clone")
            all_growth_estimate.append(growth_estimate)
            all_bef_est.append(bef_est)
            all_aft_est.append(aft_est)
            all_clone_sizes.append(clone_sizes)
        self.metrics['pred_growth_rates'] = all_growth_estimate
        self.metrics['pred_aft_count'] = all_aft_est
        self.metrics['pred_bef_count'] = all_bef_est
        self.metrics['pred_clone_sizes'] = all_clone_sizes
        return


    def stats_before_after_clust(self, clone_id=1):
        b_a_df = pd.DataFrame(index=np.arange(0,len(self.sim)), columns=["A/B", "A", "B"], dtype=str)
        for iter, s in enumerate(self.sim.values):
            curr_pred_growth = self.metrics['pred_growth_rates'][iter]
            curr_pred_aft = self.metrics['pred_aft_count'][iter]
            #print('curr_pred_aft', curr_pred_aft)
            curr_pred_bef = self.metrics['pred_bef_count'][iter]
            b_a_df.at[iter, "A/B"] = curr_pred_growth.loc[clone_id]
            b_a_df.at[iter, "A"] = curr_pred_aft.loc[clone_id]
            b_a_df.at[iter, "B"] = curr_pred_bef.loc[clone_id]

        b_a_df.to_csv(self.f_save_befaft_cl, sep='\t')
        self.metrics['b_a_clust_df'] = b_a_df
        return

    def kl_divergence(self):
        # NOTE: the KL divergence itself is not computed here yet; this
        # currently re-estimates growth rates and clone sizes from the true
        # clone labels and stores them under the 'pred_*' metric keys.
        all_growth_estimate = []
        all_clone_sizes = []
        for iter, s in enumerate(self.sim.values):
            growth_estimate, clone_sizes, _, _ = an.estimate_growth_rate(s.combined_meta)
            all_growth_estimate.append(growth_estimate)
            all_clone_sizes.append(clone_sizes)

        self.metrics['pred_growth_rates'] = all_growth_estimate
        self.metrics['pred_clone_sizes'] = all_clone_sizes
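        # Sketch of the divergence that could be reported here (an assumption,
        # since the computation is not implemented above): compare predicted and
        # true clone-size distributions per iteration, e.g.
        #
        #   from scipy.stats import entropy
        #   kl = entropy(pred_sizes / pred_sizes.sum(),
        #                true_sizes / true_sizes.sum())
        #
        # where pred_sizes and true_sizes are aligned per-clone cell counts.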


    def save(self, f_save=None):
        if f_save is None:
            f_save = self.f_save_data
        f = open(f_save, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()


    def load(self, f_save=None):
        #filename = self.params['filename']
        if f_save is None:
            f_save = self.f_save
        f = open(f_save, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)


def main():
    return


if "__name__" == "__main__":
    main()
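
# Example usage (a sketch; the YAML keys are assumed to match the ones read by
# this class, e.g. 'num_iterations', 'n_clust', 'local_outdir'):
#
#   sim = FullSimulation('parameters/simulation.yaml')
#   sim.run()
#   sim.run_metrics()
#   sim.save()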

9fa1607f8ce49b4fa95a3f6d762f1037a3052041

This should be in mro/common for general use

This should be in mro/common for general use

# TODO: This should be in mro/common for general use

#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype json;
filetype bam;
#
# @include "_sort_and_mark_dups_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype bam;
filetype bam.bai;
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype json;
filetype csv;
#
# @include "_peak_caller_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype bedgraph;
filetype pickle;
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype bed;
filetype json;
#
# @include "_basic_sc_atac_counter_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype csv;
filetype json;
filetype bed;
filetype pickle;
filetype h5;
#
# @include "_produce_cell_barcodes_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype csv;
filetype json;
filetype bed;
filetype pickle;
filetype h5;
filetype npy.gz;
#
# @include "_sc_atac_metric_collector_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv.gz;
filetype tsv.gz.tbi;
filetype bed;
filetype bam;
filetype csv;
filetype json;
filetype h5;
filetype txt;
filetype pickle;
#
# @include "_peak_annotator_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype bed;
filetype tsv;
filetype h5;
filetype gz;
filetype pickle;
#
# @include "_sc_atac_analyzer_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype tsv;
filetype h5;
filetype pickle;
filetype gz;
filetype bed;
filetype csv;
#
# @include "_sc_atac_reporter_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype json;
filetype html;
filetype csv;
filetype h5;
filetype bam;
#
# @include "_atac_cloupe_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#
filetype cloupe;
filetype csv;
filetype json;
filetype h5;
filetype bed;
filetype tsv.gz.tbi;
#
# @include "_preflight_stages.mro"
#

#
# Copyright (c) 2019 10x Genomics, Inc. All rights reserved.
#

filetype csv;
filetype bed;
filetype tsv.gz;
filetype tsv.gz.tbi;

#
# @include "_aligner_stages.mro"
#

# SETUP_CHUNKS chunks up the input fastq data into sets of matched R1, R2, SI, and BC fastq files.
# input_mode specifies how FASTQs were generated. There are two modes:
#
# 1. "BCL_PROCESSOR"
#
# FASTQs produced by the 10X BCL_PROCESSOR pipeline. This mode assumes the FASTQ files obey the internal
# naming conventions and the reads have been interleaved into RA FASTQ files.
#
# 2. "ILMN_BCL2FASTQ"
#
# FASTQs produced directly by Illumina BCL2FASTQ v1.8.4. For this mode, BCL2FASTQ must be configured to emit the
# index2 read, rather than using it for dual-index demultiplexing:
#
# configureBclToFastq.pl --no-eamss --use-bases-mask=Y100,I8,Y14,Y100 --input-dir=<basecalls_dir> \
#     --output-dir=<output_dir> --sample-sheet=<sample_sheet.csv>
#
# The sample sheet must be formatted as per the BCL2FASTQ documentation (10 column csv), and must contain a row for
# each sample index used. The sequencer must have been run in dual index mode, with the second index read (used to
# read the 10X barcode) emitted as the R2 output file. The --use-bases-mask argument should be set to the read
# length used.
stage SETUP_CHUNKS(
    in  string   sample_id        "id of the sample",
    in  map[]    sample_def       "list of dictionary specifying input data",
    in  string   input_mode       "configuration of the input fastqs",
    in  map      downsample       "map specifies either subsample_rate (float) or gigabases (int)",
    out map[]    chunks           "map has barcode, barcode_reverse_complement, sample_index, read1, read2, gem_group, and read_group fields",
    out string[] read_groups      "list of strings representing read groups",
    out json     downsample_info  "info about downsampling result",
    src py       "stages/processing/setup_chunks",
)
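
# Note (illustrative, not from the original file): the downsample map above takes
# either of the two shapes named in its parameter string, e.g.
#   { "subsample_rate": 0.5 }   to keep roughly half of the reads, or
#   { "gigabases": 20 }         to target about 20 Gb of sequence.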

# Trims adapter sequences from reads and massages fastq output into a fixed format (interleaved R1 file, etc.)
stage TRIM_READS(
    in  map[]  chunks,
    in  string barcode_whitelist,
    in  int    max_read_num,
    in  map    trim_def,
    in  map    adapters,
    out map[]  chunks,
    out json   bc_counts,
    out json   lot_info,
    out json   read_counts,
    src py     "stages/processing/trim_reads",
) split (
    in  map    chunk,
) using (
    volatile = strict,
)

# Aligns the reads to the input reference, producing chunked bam files
stage ALIGN_READS(
    in  map[]  chunks,
    in  string aligner,
    in  string aligner_method,
    in  string reference_path,
    in  string read_group_sample,
    in  int    num_threads,
    out bam[],
    src py     "stages/processing/align_reads",
) split (
    in  map    chunk,
) using (
    # N.B. No index files are generated for the bam
    volatile = strict,
)

#
# @include "_aligner.mro"
#

# Takes input fastqs and chunks them, trims them, and aligns the trimmed reads to a reference
pipeline _ALIGNER(
    in  string sample_id,
    in  string fastq_mode         "configuration of the input fastqs",
    in  map[]  sample_def,
    in  string reference_path     "this is the reference_path",
    in  string barcode_whitelist  "name of barcode whitelist file",
    in  map    trim_def,
    in  map    adapters,
    in  string read_group_sample  "sample header for BAM file",
    in  map    downsample,
    out bam[]  align,
    out map[]  chunks,
    out json   bc_counts,
    out json   lot_info           "gelbead lot detected",
    out json   read_counts        "total # of read pairs before and after adapter trimming",
    out json   downsample_info    "info on downsampling",
)
{
    call SETUP_CHUNKS(
        sample_id  = self.sample_id,
        input_mode = self.fastq_mode,
        sample_def = self.sample_def,
        downsample = self.downsample,
    ) using (
        volatile = true,
    )

    call TRIM_READS(
        chunks            = SETUP_CHUNKS.chunks,
        max_read_num      = 5000000,
        trim_def          = self.trim_def,
        adapters          = self.adapters,
        barcode_whitelist = self.barcode_whitelist,
    ) using (
        volatile = true,
    )

    call ALIGN_READS(
        chunks            = TRIM_READS.chunks,
        aligner           = "bwa",
        aligner_method    = "MEM",
        reference_path    = self.reference_path,
        read_group_sample = self.read_group_sample,
        num_threads       = 4,
    ) using (
        volatile = true,
    )

    return (
        align           = ALIGN_READS,
        chunks          = TRIM_READS.chunks,
        bc_counts       = TRIM_READS.bc_counts,
        lot_info        = TRIM_READS.lot_info,
        read_counts     = TRIM_READS.read_counts,
        downsample_info = SETUP_CHUNKS.downsample_info,
    )
}

#
# @include "_sort_and_mark_dups_stages.mro"
#

# Attaches raw and corrected barcode sequences to the aligned reads
stage ATTACH_BCS(
    in  string barcode_whitelist,
    in  bam[]  align,
    in  map[]  chunks,
    in  bool   paired_end,
    in  bool   exclude_non_bc_reads,
    in  float  bc_confidence_threshold,
    in  json   bc_counts,
    out bam[]  output,
    out int    perfect_read_count,
    src py     "stages/processing/attach_bcs",
) split (
    in  bam    align_chunk,
    in  map    chunk,
) using (
    # N.B. No index files are generated for the bam
    volatile = strict,
)

stage SORT_READS_BY_POS(
    in  bam[] input,
    out bam   tagsorted_bam,
    src py    "stages/processing/sort_reads_by_pos",
) split (
    in  bam   chunk_input,
) using (
    # N.B. No index files are generated for the bam
    volatile = strict,
)

# Marks duplicates in the reads using barcodes and fragment alignments to detect PCR and optical/diffusion duplicates
stage MARK_DUPLICATES(
    in  bam        input,
    in  string     reference_path,
    in  json       raw_barcode_counts,
    in  string     barcode_whitelist,
    out bam        output,
    out bam.bai    index,
    out csv        singlecell_mapping,
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    src py         "stages/processing/mark_duplicates",
) split (
    in  map        lane_map,
    in  string     chunk_start,
    in  string     chunk_end,
    in  int        chunk_num,
) using (
    # N.B. BAM/BED index files are explicitly bound where used
    volatile = strict,
)

#
# @include "_sort_and_mark_dups.mro"
#

# Attaches barcodes to the aligned reads, marks duplicate reads, and produces a barcode-sorted and position-sorted
# output BAM
pipeline _SORT_AND_MARK_DUPS(
    in  bam[]      align,
    in  map[]      chunks,
    in  string     barcode_whitelist,
    in  json       bc_counts,
    in  string     reference_path,
    out bam        possorted_bam        "bam file sorted by position",
    out bam.bai    possorted_bam_index  "position-sorted bam index",
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    out csv        singlecell_mapping,
    out bam[]      read_paired_bam,
)
{
    call ATTACH_BCS(
        align                   = self.align,
        chunks                  = self.chunks,
        paired_end              = true,
        barcode_whitelist       = self.barcode_whitelist,
        exclude_non_bc_reads    = false,
        bc_confidence_threshold = 0.975,
        bc_counts               = self.bc_counts,
    ) using (
        volatile = true,
    )

    call SORT_READS_BY_POS(
        input = ATTACH_BCS.output,
    ) using (
        volatile = true,
    )

    call MARK_DUPLICATES(
        input              = SORT_READS_BY_POS.tagsorted_bam,
        reference_path     = self.reference_path,
        barcode_whitelist  = self.barcode_whitelist,
        raw_barcode_counts = self.bc_counts,
    ) using (
        volatile = true,
    )

    return (
        possorted_bam       = MARK_DUPLICATES.output,
        possorted_bam_index = MARK_DUPLICATES.index,
        singlecell_mapping  = MARK_DUPLICATES.singlecell_mapping,
        fragments           = MARK_DUPLICATES.fragments,
        fragments_index     = MARK_DUPLICATES.fragments_index,
        read_paired_bam     = ATTACH_BCS.output,
    )
}

#
# @include "_peak_caller_stages.mro"
#

stage COUNT_CUT_SITES(
    in  path       reference_path,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    out bedgraph   cut_sites,
    out pickle     count_dict,
    src py         "stages/processing/count_cut_sites",
) split (
    in  string     contig,
) using (
    # N.B. We explicitly bind the index file
    volatile = strict,
)

stage DETECT_PEAKS(
    in  bedgraph cut_sites,
    in  path     reference_path,
    in  pickle   count_dict,
    out bed      peaks,
    out json     peak_metrics,
    src py       "stages/processing/detect_peaks",
) split (
    in  string   contig,
    in  float[]  params,
    in  float    threshold,
) using (
    mem_gb   = 6,
    # N.B. We explicitly bind the index file
    volatile = strict,
)

#
# @include "_peak_caller.mro"
#

pipeline _PEAK_CALLER(
    in  path       reference_path,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    out bedgraph   cut_sites,
    out bed        peaks,
    out json       peak_metrics,
)
{
    call COUNT_CUT_SITES(
        reference_path  = self.reference_path,
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
    )

    call DETECT_PEAKS(
        reference_path = self.reference_path,
        cut_sites      = COUNT_CUT_SITES.cut_sites,
        count_dict     = COUNT_CUT_SITES.count_dict,
    )

    return (
        cut_sites    = COUNT_CUT_SITES.cut_sites,
        peaks        = DETECT_PEAKS.peaks,
        peak_metrics = DETECT_PEAKS.peak_metrics,
    )
}

#
# @include "_basic_sc_atac_counter_stages.mro"
#

stage GENERATE_PEAK_MATRIX(
    in  string reference_path,
    in  tsv.gz fragments,
    in  bed    peaks,
    out h5     raw_matrix,
    out path   raw_matrix_mex,
    src py     "stages/processing/generate_peak_matrix",
) split (
    in  file   barcodes,
) using (
    mem_gb   = 4,
    # N.B. we don't explicitly need the fragment index
    volatile = strict,
)

stage FILTER_PEAK_MATRIX(
    in  h5   raw_matrix,
    in  int  num_analysis_bcs,
    in  int  random_seed,
    in  csv  cell_barcodes,
    out h5   filtered_matrix,
    out path filtered_matrix_mex,
    src py   "stages/processing/filter_peak_matrix",
) split (
) using (
    volatile = strict,
)

#
# @include "_produce_cell_barcodes_stages.mro"
#

stage REMOVE_LOW_TARGETING_BARCODES(
    in  bed        peaks,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    out json       barcode_counts,
    out json       low_targeting_barcodes,
    out json       low_targeting_summary,
    out json       fragment_lengths,
    out json       covered_bases,
    src py         "stages/processing/cell_calling/remove_low_targeting_barcodes",
) split (
    in  string     contig,
    out pickle     fragment_counts,
    out pickle     targeted_counts,
    out int        peak_coverage,
) using (
    mem_gb   = 4,
    volatile = strict,
)

stage REMOVE_GEL_BEAD_DOUBLET_BARCODES(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    in  json       barcode_counts,
    out json       gel_bead_doublet_barcodes,
    out json       gel_bead_doublet_summary,
    out csv        connect_matrix,
    src py         "stages/processing/cell_calling/remove_gel_bead_doublet_barcodes",
) split (
    in  string     contig,
    in  file       valid_barcodes,
) using (
    mem_gb   = 4,
    volatile = strict,
)

stage REMOVE_BARCODE_MULTIPLETS(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    in  string     barcode_whitelist,
    in  json       barcode_counts,
    out json       barcode_multiplets,
    out json       barcode_multiplets_summary,
    src py         "stages/processing/cell_calling/remove_barcode_multiplets",
) split (
    in  string     contig,
    in  string     gem_group,
    out npy.gz     part_a_linkage_matrix,
    out npy.gz     part_b_linkage_matrix,
) using (
    mem_gb   = 4,
    volatile = strict,
)

stage MERGE_EXCLUDED_BARCODES(
    in  json[] barcode_exclusions,
    out json   excluded_barcodes,
    src py     "stages/processing/cell_calling/merge_excluded_barcodes",
)

stage DETECT_CELL_BARCODES(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     barcode_whitelist,
    in  json       excluded_barcodes,
    in  map        force_cells,
    in  string     reference_path,
    in  bed        peaks,
    out csv        cell_barcodes,
    out csv        singlecell,
    out json       cell_calling_summary,
    src py         "stages/processing/cell_calling/detect_cell_barcodes",
) split (
    in  string     contig,
    out pickle     barcode_counts,
    out pickle     targeted_counts,
    out int        fragment_depth,
) using (
    mem_gb   = 4,
    volatile = strict,
)

# TODO: This should be in mro/common for general use
stage MERGE_SUMMARY_METRICS(
    in  json[] summary_jsons,
    out json   merged_summary,
    src py     "stages/processing/cell_calling/merge_summary_metrics",
)

#
# @include "_produce_cell_barcodes.mro"
#

pipeline _PRODUCE_CELL_BARCODES(
    in  bed        peaks,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  string     reference_path,
    in  string     barcode_whitelist,
    in  map        force_cells,
    out csv        cell_barcodes,
    out csv        singlecell,
    out json       cell_calling_summary,
    out json       excluded_barcodes,
    out json       fragment_lengths,
    out json       covered_bases,
)
{
    call REMOVE_LOW_TARGETING_BARCODES(
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
        peaks           = self.peaks,
        reference_path  = self.reference_path,
    )

    call REMOVE_GEL_BEAD_DOUBLET_BARCODES(
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
        reference_path  = self.reference_path,
        barcode_counts  = REMOVE_LOW_TARGETING_BARCODES.barcode_counts,
    )

    call REMOVE_BARCODE_MULTIPLETS(
        fragments         = self.fragments,
        fragments_index   = self.fragments_index,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        barcode_counts    = REMOVE_LOW_TARGETING_BARCODES.barcode_counts,
    )

    call MERGE_EXCLUDED_BARCODES(
        barcode_exclusions = [
            REMOVE_BARCODE_MULTIPLETS.barcode_multiplets,
            REMOVE_GEL_BEAD_DOUBLET_BARCODES.gel_bead_doublet_barcodes,
            REMOVE_LOW_TARGETING_BARCODES.low_targeting_barcodes,
        ],
    )

    call DETECT_CELL_BARCODES(
        fragments         = self.fragments,
        fragments_index   = self.fragments_index,
        barcode_whitelist = self.barcode_whitelist,
        force_cells       = self.force_cells,
        excluded_barcodes = MERGE_EXCLUDED_BARCODES.excluded_barcodes,
        reference_path    = self.reference_path,
        peaks             = self.peaks,
    )

    call MERGE_SUMMARY_METRICS as MERGE_CELL_METRICS(
        summary_jsons = [
            REMOVE_LOW_TARGETING_BARCODES.low_targeting_summary,
            REMOVE_GEL_BEAD_DOUBLET_BARCODES.gel_bead_doublet_summary,
            REMOVE_BARCODE_MULTIPLETS.barcode_multiplets_summary,
            DETECT_CELL_BARCODES.cell_calling_summary,
        ],
    )

    return (
        cell_barcodes        = DETECT_CELL_BARCODES.cell_barcodes,
        excluded_barcodes    = MERGE_EXCLUDED_BARCODES.excluded_barcodes,
        singlecell           = DETECT_CELL_BARCODES.singlecell,
        cell_calling_summary = MERGE_CELL_METRICS.merged_summary,
        fragment_lengths     = REMOVE_LOW_TARGETING_BARCODES.fragment_lengths,
        covered_bases        = REMOVE_LOW_TARGETING_BARCODES.covered_bases,
    )
}

#
# @include "_basic_sc_atac_counter.mro"
#

pipeline _BASIC_SC_ATAC_COUNTER(
    in  string     sample_id,
    in  string     fastq_mode                   "configuration of the input fastqs",
    in  map[]      sample_def,
    in  string     reference_path               "this is the reference_path",
    in  string     barcode_whitelist            "name of barcode whitelist file",
    in  map        trim_def,
    in  map        adapters,
    in  map        downsample,
    in  map        force_cells,
    out bam        possorted_bam                "bam file sorted by position",
    out bam.bai    possorted_bam_index          "position-sorted bam index",
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    out json       lot_info                     "gelbead lot detected",
    out json       read_counts                  "total # of read pairs before and after adapter trimming",
    out json       downsample_info              "info on downsampling",
    out csv        cell_barcodes,
    out json       excluded_barcodes,
    out json       cell_calling_summary,
    out bed        peaks,
    out bedgraph   cut_sites,
    out csv        singlecell_mapping,
    out csv        singlecell_cells,
    out json       peak_metrics,
    out bam[]      read_paired_bam,
    out h5         raw_peak_bc_matrix,
    out path       raw_peak_bc_matrix_mex,
    out h5         filtered_peak_bc_matrix,
    out path       filtered_peak_bc_matrix_mex,
)
{
    call _ALIGNER(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        read_group_sample = self.sample_id,
        trim_def          = self.trim_def,
        adapters          = self.adapters,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        downsample        = self.downsample,
    )

    call _SORT_AND_MARK_DUPS(
        align             = _ALIGNER.align,
        chunks            = _ALIGNER.chunks,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        bc_counts         = _ALIGNER.bc_counts,
    )

    call _PEAK_CALLER(
        fragments       = _SORT_AND_MARK_DUPS.fragments,
        fragments_index = _SORT_AND_MARK_DUPS.fragments_index,
        reference_path  = self.reference_path,
    )

    call _PRODUCE_CELL_BARCODES(
        fragments         = _SORT_AND_MARK_DUPS.fragments,
        fragments_index   = _SORT_AND_MARK_DUPS.fragments_index,
        peaks             = _PEAK_CALLER.peaks,
        force_cells       = self.force_cells,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
    )

    call GENERATE_PEAK_MATRIX(
        reference_path = self.reference_path,
        fragments      = _SORT_AND_MARK_DUPS.fragments,
        peaks          = _PEAK_CALLER.peaks,
    )

    call FILTER_PEAK_MATRIX(
        num_analysis_bcs = null,
        cell_barcodes    = _PRODUCE_CELL_BARCODES.cell_barcodes,
        raw_matrix       = GENERATE_PEAK_MATRIX.raw_matrix,
        random_seed      = null,
    )

    return (
        possorted_bam               = _SORT_AND_MARK_DUPS.possorted_bam,
        possorted_bam_index         = _SORT_AND_MARK_DUPS.possorted_bam_index,
        singlecell_mapping          = _SORT_AND_MARK_DUPS.singlecell_mapping,
        singlecell_cells            = _PRODUCE_CELL_BARCODES.singlecell,
        lot_info                    = _ALIGNER.lot_info,
        read_counts                 = _ALIGNER.read_counts,
        downsample_info             = _ALIGNER.downsample_info,
        cell_barcodes               = _PRODUCE_CELL_BARCODES.cell_barcodes,
        excluded_barcodes           = _PRODUCE_CELL_BARCODES.excluded_barcodes,
        cell_calling_summary        = _PRODUCE_CELL_BARCODES.cell_calling_summary,
        peak_metrics                = _PEAK_CALLER.peak_metrics,
        cut_sites                   = _PEAK_CALLER.cut_sites,
        peaks                       = _PEAK_CALLER.peaks,
        fragments                   = _SORT_AND_MARK_DUPS.fragments,
        fragments_index             = _SORT_AND_MARK_DUPS.fragments_index,
        read_paired_bam             = _SORT_AND_MARK_DUPS.read_paired_bam,
        raw_peak_bc_matrix          = GENERATE_PEAK_MATRIX.raw_matrix,
        raw_peak_bc_matrix_mex      = GENERATE_PEAK_MATRIX.raw_matrix_mex,
        filtered_peak_bc_matrix     = FILTER_PEAK_MATRIX.filtered_matrix,
        filtered_peak_bc_matrix_mex = FILTER_PEAK_MATRIX.filtered_matrix_mex,
    )
}

#
# @include "_sc_atac_metric_collector_stages.mro"
#

stage ESTIMATE_LIBRARY_COMPLEXITY(
    in  json   sequencing_summary,
    in  tsv.gz fragments,
    in  csv    cell_barcodes,
    out json   bulk_complexity,
    out json   complexity_summary,
    out json   singlecell_complexity,
    src py     "stages/metrics/estimate_library_complexity",
) split (
    in  file   barcodes,
) using (
    mem_gb   = 6,
    volatile = strict,
)

stage GENERATE_SEQUENCING_METRICS(
    in  bam[] input,
    out txt   misc_sm,
    out json  summary,
    src py    "stages/metrics/generate_sequencing_metrics",
) split (
    in  bam   chunk_bam,
) using (
    volatile = strict,
)

stage GENERATE_SINGLECELL_TARGETING(
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  bed        peaks,
    in  string     reference_path,
    out csv        singlecell,
    out json       summary,
    out csv        tss_relpos,
    out csv        ctcf_relpos,
    src py         "stages/metrics/generate_singlecell_targeting",
) split (
    in  string     contig,
    out int        read_count,
    out pickle     target_counts_by_barcode,
    out pickle     chunk_tss,
    out pickle     chunk_ctcf,
) using (
    mem_gb   = 6,
    volatile = strict,
)

stage MERGE_SINGLECELL_METRICS(
    in  string reference_path,
    in  csv    singlecell_mapping,
    in  csv    singlecell_targets,
    in  csv    singlecell_cells,
    out csv    singlecell,
    out json   summary,
    src py     "stages/metrics/merge_singlecell_metrics",
) using (
    mem_gb   = 8,
    volatile = strict,
)

stage REPORT_INSERT_SIZES(
    in  tsv.gz fragments,
    in  bool   exclude_non_nuclear,
    in  string reference_path,
    out csv    insert_sizes,
    out json   insert_summary,
    src py     "stages/metrics/report_insert_sizes",
) split (
    in  file   barcode,
    out file   total,
) using (
    volatile = strict,
)

stage REPORT_TSS_CTCF(
    in  csv  tss_relpos,
    in  csv  ctcf_relpos,
    out json summary_metrics,
    src py   "stages/metrics/report_tss_ctcf",
) using (
    volatile = strict,
)

#
# @include "_sc_atac_metric_collector.mro"
#

pipeline _SC_ATAC_METRIC_COLLECTOR(
    in  bam[]      read_paired_bam,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  bed        peaks,
    in  string     reference_path         "this is the reference_path",
    in  csv        cell_barcodes,
    in  csv        singlecell_mapping,
    in  csv        singlecell_cells,
    out json       singlecell_results,
    out csv        singlecell,
    out json       enrichment_results,
    out json       basic_summary,
    out json       insert_summary,
    out csv        insert_sizes,
    out json       bulk_complexity,
    out json       singlecell_complexity,
    out json       complexity_summary,
    out csv        tss_relpos,
    out csv        ctcf_relpos,
)
{
    call GENERATE_SINGLECELL_TARGETING(
        fragments       = self.fragments,
        fragments_index = self.fragments_index,
        peaks           = self.peaks,
        reference_path  = self.reference_path,
    )

    call MERGE_SINGLECELL_METRICS(
        reference_path     = self.reference_path,
        singlecell_mapping = self.singlecell_mapping,
        singlecell_cells   = self.singlecell_cells,
        singlecell_targets = GENERATE_SINGLECELL_TARGETING.singlecell,
    )

    call GENERATE_SEQUENCING_METRICS(
        input = self.read_paired_bam,
    )

    call ESTIMATE_LIBRARY_COMPLEXITY(
        sequencing_summary = GENERATE_SEQUENCING_METRICS.summary,
        fragments          = self.fragments,
        cell_barcodes      = self.cell_barcodes,
    )

    call REPORT_INSERT_SIZES(
        fragments           = self.fragments,
        reference_path      = self.reference_path,
        exclude_non_nuclear = true,
    )

    call REPORT_TSS_CTCF(
        tss_relpos  = GENERATE_SINGLECELL_TARGETING.tss_relpos,
        ctcf_relpos = GENERATE_SINGLECELL_TARGETING.ctcf_relpos,
    )

    return (
        ###
        singlecell            = MERGE_SINGLECELL_METRICS.singlecell,
        singlecell_results    = MERGE_SINGLECELL_METRICS.summary,
        ###
        enrichment_results    = REPORT_TSS_CTCF.summary_metrics,
        basic_summary         = GENERATE_SEQUENCING_METRICS.summary,
        insert_summary        = REPORT_INSERT_SIZES.insert_summary,
        insert_sizes          = REPORT_INSERT_SIZES.insert_sizes,
        bulk_complexity       = ESTIMATE_LIBRARY_COMPLEXITY.bulk_complexity,
        singlecell_complexity = ESTIMATE_LIBRARY_COMPLEXITY.singlecell_complexity,
        complexity_summary    = ESTIMATE_LIBRARY_COMPLEXITY.complexity_summary,
        tss_relpos            = GENERATE_SINGLECELL_TARGETING.tss_relpos,
        ctcf_relpos           = GENERATE_SINGLECELL_TARGETING.ctcf_relpos,
    )
}

#
# @include "_peak_annotator_stages.mro"
#

stage ANNOTATE_PEAKS(
    in  bed    peaks,
    in  string reference_path,
    out tsv    peak_annotation,
    src py     "stages/analysis/annotate_peaks",
) split (
    in  int    chunk_start,
    in  int    chunk_end,
) using (
    mem_gb   = 5,
    volatile = strict,
)

stage COMPUTE_GC_DISTRIBUTION(
    in  bed    peaks,
    in  string reference_path,
    out pickle GCdict,
    src py     "stages/analysis/compute_gc_dist",
) split (
) using (
    volatile = strict,
)

stage SCAN_MOTIFS(
    in  pickle globalGCdict,
    in  bed    peaks,
    in  string reference_path,
    in  float  pwm_threshold,
    out bed    peak_motif_hits,
    src py     "stages/analysis/scan_motifs",
) split (
    in  file   GCdict,
) using (
    volatile = strict,
)

stage GENERATE_TF_MATRIX(
    in  path reference_path,
    in  bed  peaks,
    in  bed  peak_motif_hits,
    in  h5   filtered_matrix,
    out h5   filtered_tf_bc_matrix,
    out path filtered_tf_bc_matrix_mex,
    out gz   tf_propZ_matrix,
    src py   "stages/analysis/generate_tf_matrix",
) split (
) using (
    volatile = strict,
)

#
# @include "_peak_annotator.mro"
#

pipeline _PEAK_ANNOTATOR(
    in  string reference_path,
    in  bed    peaks,
    in  h5     filtered_peak_bc_matrix,
    in  float  pwm_threshold,
    out h5     filtered_tf_bc_matrix,
    out path   filtered_tf_bc_matrix_mex,
    out gz     tf_propZ_matrix,
    out tsv    peak_annotation,
)
{
    call ANNOTATE_PEAKS(
        peaks          = self.peaks,
        reference_path = self.reference_path,
    )

    call COMPUTE_GC_DISTRIBUTION(
        peaks          = self.peaks,
        reference_path = self.reference_path,
    )

    call SCAN_MOTIFS(
        globalGCdict   = COMPUTE_GC_DISTRIBUTION.GCdict,
        peaks          = self.peaks,
        reference_path = self.reference_path,
        pwm_threshold  = self.pwm_threshold,
    )

    call GENERATE_TF_MATRIX(
        reference_path  = self.reference_path,
        peaks           = self.peaks,
        filtered_matrix = self.filtered_peak_bc_matrix,
        peak_motif_hits = SCAN_MOTIFS.peak_motif_hits,
    )

    return (
        filtered_tf_bc_matrix     = GENERATE_TF_MATRIX.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex = GENERATE_TF_MATRIX.filtered_tf_bc_matrix_mex,
        tf_propZ_matrix           = GENERATE_TF_MATRIX.tf_propZ_matrix,
        peak_annotation           = ANNOTATE_PEAKS.peak_annotation,
    )
}

#
# @include "_sc_atac_analyzer_stages.mro"
#

stage ANALYZER_PREFLIGHT(
    in  bed      peaks,
    in  h5       filtered_peak_bc_matrix,
    in  string[] factorization,
    in  int      tsne_perplexity,
    in  int      random_seed,
    in  float    tsne_theta,
    in  int      tsne_mom_switch_iter,
    in  int      tsne_stop_lying_iter,
    in  int      tsne_max_dims,
    in  int      tsne_input_pcs,
    in  int      tsne_max_iter,
    in  int      max_clusters,
    in  int      num_components,
    in  int      num_dr_bcs,
    in  int      num_dr_features,
    in  float    neighbor_a,
    in  float    neighbor_b,
    in  int      graphclust_neighbors,
    src py       "stages/preflight/atac_analyzer",
)

stage REDUCE_DIMENSIONS(
    in  h5       filtered_matrix,
    in  string[] factorization,
    in  int      num_dims,
    in  int      num_bcs,
    in  int      num_features,
    in  int      random_seed,
    out path     reduced_data,
    out map      reduction_summary,
    src py       "stages/analysis/reduce_dimensions",
) split (
    in  string   method,
) using (
    volatile = strict,
)

stage CLUSTER_CELLS(
    in  h5       filtered_matrix,
    in  path     reduced_data,
    in  map      reduction_summary,
    in  string[] factorization,
    in  int      minclusters,
    in  int      maxclusters,
    in  int      num_dims,
    in  int      random_seed,
    out path     clustered_data,
    out map      clustering_summary,
    src py       "stages/analysis/cluster_cells",
) split (
    in  int      n_clusters,
) using (
    volatile = strict,
)

stage PROJECT_TSNE(
    in  h5       filtered_matrix,
    in  path     reduced_data,
    in  map      reduction_summary,
    in  int      tsne_perplexity,
    in  int      tsne_max_dims,
    in  int      tsne_input_pcs,
    in  float    tsne_theta,
    in  int      tsne_max_iter,
    in  int      tsne_stop_lying_iter,
    in  int      tsne_mom_switch_iter,
    in  int      random_seed,
    in  string[] factorization,
    out path     tsne,
    out map      tsne_summary,
    src py       "stages/analysis/project_tsne",
) split (
    in  string   method,
    in  int      tsne_dims,
) using (
    volatile = strict,
)

stage RUN_GRAPH_CLUSTERING(
    in  h5       matrix_h5                 "Processed matrix",
    in  string[] factorization,
    in  path     reduced_data,
    in  map      reduction_summary,
    in  int      num_neighbors             "Use this many neighbors",
    in  float    neighbor_a                "Use larger of (a+b*log10(n_cells)) neighbors or num_neighbors",
    in  float    neighbor_b                "Use larger of (a+b*log10(n_cells)) neighbors or num_neighbors",
    in  int      balltree_leaf_size,
    in  string   similarity_type           "Type of similarity to use (nn or snn)",
    out h5       chunked_neighbors,
    out path     knn_clusters,
    out map      graph_clustering_summary,
    src py       "stages/analysis/run_graph_clustering",
) split (
    in  string   method,
    in  pickle   neighbor_index,
    in  h5       submatrix,
    in  int      row_start,
    in  int      total_rows,
    in  int      k_nearest,
    in  h5       use_bcs,
) using (
    volatile = strict,
)
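
# Illustration of the neighbor rule documented for neighbor_a / neighbor_b above
# (a sketch; n_cells is assumed to be the number of barcodes in the matrix):
#   k = max(num_neighbors, round(neighbor_a + neighbor_b * log10(n_cells)))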

stage COMBINE_CLUSTERING(
    in  h5   filtered_matrix,
    in  map  clustering_summary,
    in  path clustered_data,
    in  map  graph_clustering_summary,
    in  path knn_clusters,
    out path clustering,
    out map  clustering_summary,
    src py   "stages/analysis/combine_clustering",
) using (
    volatile = strict,
)

stage SUMMARIZE_ANALYSIS(
    in  tsv  peak_annotation,
    in  h5   filtered_peak_bc_matrix,
    in  h5   filtered_tf_bc_matrix,
    in  gz   tf_propZ_matrix,
    in  path reduced_data,
    in  map  reduction_summary,
    in  path clustering,
    in  map  clustering_summary,
    in  path tsne,
    in  map  tsne_summary,
    in  path enrichment_analysis,
    in  map  enrichment_analysis_summary,
    out h5   analysis,
    out path analysis_csv,
    out h5   feature_bc_matrix,
    src py   "stages/analysis/summarize_analysis",
) split (
) using (
    volatile = strict,
)

stage PERFORM_DIFFERENTIAL_ANALYSIS(
    in  bed      peaks,
    in  string   reference_path,
    in  h5       filtered_peak_bc_matrix,
    in  h5       filtered_tf_bc_matrix,
    in  string[] factorization,
    in  path     clustering,
    in  map      clustering_summary,
    out path     enrichment_analysis,
    out map      enrichment_analysis_summary,
    src py       "stages/analysis/perform_differential_analysis",
) split (
    in  string   method,
    in  string   clustering_key,
    in  int      cluster,
    out csv      tmp_diffexp,
) using (
    volatile = strict,
)

#
# @include "_sc_atac_analyzer.mro"
#

pipeline _SC_ATAC_ANALYZER(
    in  string   reference_path,
    in  bed      peaks,
    in  h5       filtered_peak_bc_matrix,
    in  string[] factorization,
    in  int      tsne_perplexity,
    in  int      random_seed,
    in  float    tsne_theta,
    in  int      tsne_mom_switch_iter,
    in  int      tsne_stop_lying_iter,
    in  int      tsne_max_dims,
    in  int      tsne_input_pcs,
    in  int      tsne_max_iter,
    in  int      max_clusters,
    in  int      num_components,
    in  int      num_dr_bcs,
    in  int      num_dr_features,
    in  float    neighbor_a,
    in  float    neighbor_b,
    in  int      graphclust_neighbors,
    out h5       analysis,
    out path     analysis_csv,
    out h5       filtered_tf_bc_matrix,
    out path     filtered_tf_bc_matrix_mex,
    out h5       feature_bc_matrix,
    out tsv      peak_annotation,
)
{
    call ANALYZER_PREFLIGHT(
        peaks                   = self.peaks,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        factorization           = self.factorization,
        tsne_perplexity         = self.tsne_perplexity,
        random_seed             = self.random_seed,
        tsne_theta              = self.tsne_theta,
        tsne_mom_switch_iter    = self.tsne_mom_switch_iter,
        tsne_stop_lying_iter    = self.tsne_stop_lying_iter,
        tsne_max_dims           = self.tsne_max_dims,
        tsne_input_pcs          = self.tsne_input_pcs,
        tsne_max_iter           = self.tsne_max_iter,
        max_clusters            = self.max_clusters,
        num_components          = self.num_components,
        num_dr_bcs              = self.num_dr_bcs,
        num_dr_features         = self.num_dr_features,
        neighbor_a              = self.neighbor_a,
        neighbor_b              = self.neighbor_b,
        graphclust_neighbors    = self.graphclust_neighbors,
    ) using (
        volatile = true,
    )

    call _PEAK_ANNOTATOR(
        reference_path          = self.reference_path,
        peaks                   = self.peaks,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        pwm_threshold           = null,
    )

    call REDUCE_DIMENSIONS(
        filtered_matrix = self.filtered_peak_bc_matrix,
        factorization   = self.factorization,
        num_dims        = self.num_components,
        num_bcs         = self.num_dr_bcs,
        num_features    = self.num_dr_features,
        random_seed     = self.random_seed,
    )

    call CLUSTER_CELLS(
        filtered_matrix   = self.filtered_peak_bc_matrix,
        reduced_data      = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary = REDUCE_DIMENSIONS.reduction_summary,
        factorization     = self.factorization,
        minclusters       = 2,
        maxclusters       = self.max_clusters,
        num_dims          = null,
        random_seed       = self.random_seed,
    )

    call PROJECT_TSNE(
        filtered_matrix      = self.filtered_peak_bc_matrix,
        reduced_data         = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary    = REDUCE_DIMENSIONS.reduction_summary,
        tsne_perplexity      = self.tsne_perplexity,
        tsne_max_dims        = self.tsne_max_dims,
        tsne_input_pcs       = self.tsne_input_pcs,
        tsne_theta           = self.tsne_theta,
        tsne_max_iter        = self.tsne_max_iter,
        tsne_stop_lying_iter = self.tsne_stop_lying_iter,
        tsne_mom_switch_iter = self.tsne_mom_switch_iter,
        random_seed          = self.random_seed,
        factorization        = self.factorization,
    )

    call RUN_GRAPH_CLUSTERING(
        matrix_h5          = self.filtered_peak_bc_matrix,
        factorization      = self.factorization,
        reduced_data       = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary  = REDUCE_DIMENSIONS.reduction_summary,
        num_neighbors      = self.graphclust_neighbors,
        neighbor_a         = self.neighbor_a,
        neighbor_b         = self.neighbor_b,
        balltree_leaf_size = null,
        similarity_type    = "nn",
    )

    call COMBINE_CLUSTERING(
        filtered_matrix          = self.filtered_peak_bc_matrix,
        clustering_summary       = CLUSTER_CELLS.clustering_summary,
        clustered_data           = CLUSTER_CELLS.clustered_data,
        graph_clustering_summary = RUN_GRAPH_CLUSTERING.graph_clustering_summary,
        knn_clusters             = RUN_GRAPH_CLUSTERING.knn_clusters,
    )

    call PERFORM_DIFFERENTIAL_ANALYSIS(
        reference_path          = self.reference_path,
        peaks                   = self.peaks,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        filtered_tf_bc_matrix   = _PEAK_ANNOTATOR.filtered_tf_bc_matrix,
        factorization           = self.factorization,
        clustering              = COMBINE_CLUSTERING.clustering,
        clustering_summary      = COMBINE_CLUSTERING.clustering_summary,
    )

    call SUMMARIZE_ANALYSIS(
        peak_annotation             = _PEAK_ANNOTATOR.peak_annotation,
        filtered_peak_bc_matrix     = self.filtered_peak_bc_matrix,
        filtered_tf_bc_matrix       = _PEAK_ANNOTATOR.filtered_tf_bc_matrix,
        tf_propZ_matrix             = _PEAK_ANNOTATOR.tf_propZ_matrix,
        reduced_data                = REDUCE_DIMENSIONS.reduced_data,
        reduction_summary           = REDUCE_DIMENSIONS.reduction_summary,
        clustering                  = COMBINE_CLUSTERING.clustering,
        clustering_summary          = COMBINE_CLUSTERING.clustering_summary,
        tsne                        = PROJECT_TSNE.tsne,
        tsne_summary                = PROJECT_TSNE.tsne_summary,
        enrichment_analysis         = PERFORM_DIFFERENTIAL_ANALYSIS.enrichment_analysis,
        enrichment_analysis_summary = PERFORM_DIFFERENTIAL_ANALYSIS.enrichment_analysis_summary,
    )

    return (
        analysis                  = SUMMARIZE_ANALYSIS.analysis,
        analysis_csv              = SUMMARIZE_ANALYSIS.analysis_csv,
        filtered_tf_bc_matrix     = _PEAK_ANNOTATOR.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex = _PEAK_ANNOTATOR.filtered_tf_bc_matrix_mex,
        feature_bc_matrix         = SUMMARIZE_ANALYSIS.feature_bc_matrix,
        peak_annotation           = _PEAK_ANNOTATOR.peak_annotation,
    )
}

#
# @include "_sc_atac_reporter_stages.mro"
#

stage SUMMARIZE_REPORTS_SINGLECELL(
    in  string reference_path,
    in  json   complexity_summary,
    in  json   cell_calling_summary,
    in  json   peak_results,
    in  json   basic_results,
    in  json   error_results_summary,
    in  json   insert_summary,
    in  json   singlecell_results,
    in  json   contam_results,
    in  json   downsample_info,
    in  json   enrichment_results,
    out json   analysis_params,
    out json   summary,
    out csv    summary_csv,
    src py     "stages/reporter/summarize_reports_singlecell",
) using (
    mem_gb = 4,
)

stage CREATE_WEBSUMMARY(
    in  string reference_path,
    in  string barcode_whitelist,
    in  json   summary_results,
    in  json   bulk_complexity,
    in  json   singlecell_complexity,
    in  string sample_id,
    in  string sample_desc,
    in  map[]  sample_def,
    in  bool   debug,
    in  csv    singlecell,
    in  csv    insert_sizes,
    in  csv    tss_relpos,
    in  csv    ctcf_relpos,
    in  h5     filtered_peak_bc_matrix,
    in  h5     analysis,
    in  json   excluded_barcodes,
    out html   web_summary,
    src py     "stages/reporter/create_websummary",
) using (
    mem_gb = 16,
)

#
# @include "_sc_atac_reporter.mro"
#

pipeline _SC_ATAC_REPORTER(
    in  string reference_path,
    in  string barcode_whitelist,
    in  json   bulk_complexity,
    in  json   cell_calling_summary,
    in  json   complexity_summary,
    in  json   basic_summary,
    in  json   peak_summary,
    in  json   singlecell_results,
    in  json   insert_summary,
    in  json   downsample_info,
    in  json   singlecell_complexity,
    in  csv    singlecell,
    in  csv    tss_relpos,
    in  csv    ctcf_relpos,
    in  string sample_id,
    in  string sample_desc,
    in  map[]  sample_def,
    in  csv    sc_insert_sizes,
    in  json   enrichment_results,
    in  h5     filtered_peak_bc_matrix,
    in  h5     analysis,
    in  json   excluded_barcodes,
    #
    out json   summary,
    out html   web_summary,
    out csv    summary_csv,
)
{
    call SUMMARIZE_REPORTS_SINGLECELL(
        reference_path        = self.reference_path,
        complexity_summary    = self.complexity_summary,
        cell_calling_summary  = self.cell_calling_summary,
        peak_results          = self.peak_summary,
        basic_results         = self.basic_summary,
        error_results_summary = null,
        insert_summary        = self.insert_summary,
        singlecell_results    = self.singlecell_results,
        contam_results        = null,
        downsample_info       = self.downsample_info,
        enrichment_results    = self.enrichment_results,
    )

    call CREATE_WEBSUMMARY(
        reference_path          = self.reference_path,
        barcode_whitelist       = self.barcode_whitelist,
        singlecell              = self.singlecell,
        tss_relpos              = self.tss_relpos,
        ctcf_relpos             = self.ctcf_relpos,
        sample_id               = self.sample_id,
        sample_desc             = self.sample_desc,
        sample_def              = self.sample_def,
        insert_sizes            = self.sc_insert_sizes,
        summary_results         = SUMMARIZE_REPORTS_SINGLECELL.summary,
        bulk_complexity         = self.bulk_complexity,
        singlecell_complexity   = self.singlecell_complexity,
        analysis                = self.analysis,
        filtered_peak_bc_matrix = self.filtered_peak_bc_matrix,
        excluded_barcodes       = self.excluded_barcodes,
        debug                   = false,
    )

    return (
        summary     = SUMMARIZE_REPORTS_SINGLECELL.summary,
        web_summary = CREATE_WEBSUMMARY.web_summary,
        summary_csv = SUMMARIZE_REPORTS_SINGLECELL.summary_csv,
    )
}

#
# @include "_atac_cloupe_stages.mro"
#

stage CLOUPE_PREPROCESS(
    in  string     pipestance_type,
    in  string     sample_id,
    in  string     sample_desc,
    in  string     reference_path,
    in  h5         analysis,
    in  h5         feature_barcode_matrix,
    in  bed        peaks,
    in  tsv.gz.tbi fragments_index,
    in  json       metrics_json,
    in  csv        aggregation_csv,
    in  json       gem_group_index_json,
    in  bool       no_secondary_analysis,
    out cloupe     output_for_cloupe,
    out json       gem_group_index_json,
    src py         "stages/cloupe/atac_cloupe_preprocess",
) split (
)

#
# @include "_preflight_stages.mro"
#

stage ATAC_COUNTER_PREFLIGHT(
    in  string   sample_id,
    in  string   fastq_mode,
    in  map[]    sample_def,
    in  string   reference_path,
    in  map      force_cells,
    in  string[] factorization,
    in  map      downsample,
    in  bool     check_executables,
    in  map      trim_def,
    src py       "stages/preflight/atac_counter",
) split (
)

stage ATAC_AGGR_PREFLIGHT(
    in  string   sample_id,
    in  string   reference_path,
    in  csv      aggr_csv,
    in  string   normalization,
    in  string[] factorization,
    in  bool     check_executables,
    src py       "stages/preflight/atac_aggr",
) split (
)

stage ATAC_REANALYZER_PREFLIGHT(
    in  string     sample_id,
    in  string     reference_path,
    in  string     barcode_whitelist,
    in  bed        peaks,
    in  csv        parameters,
    in  map        force_cells,
    in  csv        cell_barcodes,
    in  tsv.gz     fragments,
    in  tsv.gz.tbi fragments_index,
    in  csv        aggregation_csv,
    in  bool       check_executables,
    src py         "stages/preflight/atac_reanalyzer",
) split (
)

#
# @include "sc_atac_counter.mro"
#

pipeline SC_ATAC_COUNTER(
    in  string     fastq_mode,
    in  string     sample_id,
    in  map[]      sample_def,
    in  map        downsample,
    in  string     sample_desc,
    in  string     reference_path,
    in  map        trim_def,
    in  string     barcode_whitelist,
    in  map        adapters,
    in  string[]   factorization,
    in  map        force_cells,
    #
    out csv        singlecell,
    out bam        possorted_bam,
    out bam.bai    possorted_bam_index,
    out json       summary,
    out html       web_summary,
    out bed        peaks,
    out h5         raw_peak_bc_matrix,
    out path       raw_peak_bc_matrix_mex,
    out path       analysis_csv,
    out h5         filtered_peak_bc_matrix,
    out path       filtered_peak_bc_matrix_mex,
    out tsv.gz     fragments,
    out tsv.gz.tbi fragments_index,
    out h5         filtered_tf_bc_matrix,
    out path       filtered_tf_bc_matrix_mex,
    out cloupe     cloupe,
    out csv        summary_csv,
    out tsv        peak_annotation,
)
{
    call ATAC_COUNTER_PREFLIGHT as ATAC_COUNTER_PREFLIGHT_LOCAL(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        reference_path    = self.reference_path,
        force_cells       = self.force_cells,
        factorization     = self.factorization,
        downsample        = self.downsample,
        trim_def          = self.trim_def,
        check_executables = false,
    ) using (
        local     = true,
        preflight = true,
    )

    call ATAC_COUNTER_PREFLIGHT(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        reference_path    = self.reference_path,
        force_cells       = self.force_cells,
        factorization     = self.factorization,
        downsample        = self.downsample,
        trim_def          = self.trim_def,
        check_executables = true,
    ) using (
        preflight = true,
    )

    call _BASIC_SC_ATAC_COUNTER(
        sample_id         = self.sample_id,
        fastq_mode        = self.fastq_mode,
        sample_def        = self.sample_def,
        trim_def          = self.trim_def,
        adapters          = self.adapters,
        reference_path    = self.reference_path,
        barcode_whitelist = self.barcode_whitelist,
        downsample        = self.downsample,
        force_cells       = self.force_cells,
    )

    call _SC_ATAC_METRIC_COLLECTOR(
        read_paired_bam    = _BASIC_SC_ATAC_COUNTER.read_paired_bam,
        fragments          = _BASIC_SC_ATAC_COUNTER.fragments,
        fragments_index    = _BASIC_SC_ATAC_COUNTER.fragments_index,
        peaks              = _BASIC_SC_ATAC_COUNTER.peaks,
        reference_path     = self.reference_path,
        cell_barcodes      = _BASIC_SC_ATAC_COUNTER.cell_barcodes,
        singlecell_cells   = _BASIC_SC_ATAC_COUNTER.singlecell_cells,
        singlecell_mapping = _BASIC_SC_ATAC_COUNTER.singlecell_mapping,
    )

    call _SC_ATAC_ANALYZER(
        peaks                   = _BASIC_SC_ATAC_COUNTER.peaks,
        filtered_peak_bc_matrix = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        reference_path          = self.reference_path,
        factorization           = self.factorization,
        tsne_perplexity         = 30,
        tsne_max_dims           = null,
        tsne_input_pcs          = null,
        tsne_max_iter           = null,
        tsne_stop_lying_iter    = null,
        tsne_mom_switch_iter    = null,
        tsne_theta              = null,
        random_seed             = null,
        max_clusters            = 10,
        neighbor_a              = null,
        neighbor_b              = null,
        graphclust_neighbors    = null,
        num_components          = 15,
        num_dr_bcs              = null,
        num_dr_features         = null,
    )

    call CLOUPE_PREPROCESS(
        pipestance_type        = "SC_ATAC_COUNTER_CS",
        reference_path         = self.reference_path,
        sample_id              = self.sample_id,
        sample_desc            = self.sample_desc,
        analysis               = _SC_ATAC_ANALYZER.analysis,
        feature_barcode_matrix = _SC_ATAC_ANALYZER.feature_bc_matrix,
        metrics_json           = _SC_ATAC_METRIC_COLLECTOR.basic_summary,
        peaks                  = _BASIC_SC_ATAC_COUNTER.peaks,
        fragments_index        = _BASIC_SC_ATAC_COUNTER.fragments_index,
        aggregation_csv        = null,
        gem_group_index_json   = null,
        no_secondary_analysis  = false,
    )

    call _SC_ATAC_REPORTER(
        reference_path          = self.reference_path,
        barcode_whitelist       = self.barcode_whitelist,
        bulk_complexity         = _SC_ATAC_METRIC_COLLECTOR.bulk_complexity,
        singlecell_complexity   = _SC_ATAC_METRIC_COLLECTOR.singlecell_complexity,
        cell_calling_summary    = _BASIC_SC_ATAC_COUNTER.cell_calling_summary,
        complexity_summary      = _SC_ATAC_METRIC_COLLECTOR.complexity_summary,
        basic_summary           = _SC_ATAC_METRIC_COLLECTOR.basic_summary,
        peak_summary            = _BASIC_SC_ATAC_COUNTER.peak_metrics,
        singlecell_results      = _SC_ATAC_METRIC_COLLECTOR.singlecell_results,
        insert_summary          = _SC_ATAC_METRIC_COLLECTOR.insert_summary,
        downsample_info         = _BASIC_SC_ATAC_COUNTER.downsample_info,
        singlecell              = _SC_ATAC_METRIC_COLLECTOR.singlecell,
        tss_relpos              = _SC_ATAC_METRIC_COLLECTOR.tss_relpos,
        ctcf_relpos             = _SC_ATAC_METRIC_COLLECTOR.ctcf_relpos,
        sample_id               = self.sample_id,
        sample_desc             = self.sample_desc,
        sample_def              = self.sample_def,
        sc_insert_sizes         = _SC_ATAC_METRIC_COLLECTOR.insert_sizes,
        enrichment_results      = _SC_ATAC_METRIC_COLLECTOR.enrichment_results,
        filtered_peak_bc_matrix = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        analysis                = _SC_ATAC_ANALYZER.analysis,
        excluded_barcodes       = _BASIC_SC_ATAC_COUNTER.excluded_barcodes,
    )

    return (
        singlecell                  = _SC_ATAC_METRIC_COLLECTOR.singlecell,
        possorted_bam               = _BASIC_SC_ATAC_COUNTER.possorted_bam,
        possorted_bam_index         = _BASIC_SC_ATAC_COUNTER.possorted_bam_index,
        summary                     = _SC_ATAC_REPORTER.summary,
        web_summary                 = _SC_ATAC_REPORTER.web_summary,
        peaks                       = _BASIC_SC_ATAC_COUNTER.peaks,
        raw_peak_bc_matrix          = _BASIC_SC_ATAC_COUNTER.raw_peak_bc_matrix,
        raw_peak_bc_matrix_mex      = _BASIC_SC_ATAC_COUNTER.raw_peak_bc_matrix_mex,
        analysis_csv                = _SC_ATAC_ANALYZER.analysis_csv,
        filtered_peak_bc_matrix     = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        filtered_peak_bc_matrix_mex = _BASIC_SC_ATAC_COUNTER.filtered_peak_bc_matrix_mex,
        fragments                   = _BASIC_SC_ATAC_COUNTER.fragments,
        fragments_index             = _BASIC_SC_ATAC_COUNTER.fragments_index,
        filtered_tf_bc_matrix       = _SC_ATAC_ANALYZER.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex   = _SC_ATAC_ANALYZER.filtered_tf_bc_matrix_mex,
        cloupe                      = CLOUPE_PREPROCESS.output_for_cloupe,
        summary_csv                 = _SC_ATAC_REPORTER.summary_csv,
        peak_annotation             = _SC_ATAC_ANALYZER.peak_annotation,
    )
}

#
# @include "sc_atac_counter_cs.mro"
#

# Customer-facing (CS) pipeline
pipeline SC_ATAC_COUNTER_CS(
    in  string     fastq_mode                   "Input fastq configuration",
    in  string     sample_id,
    in  map[]      sample_def,
    in  map        downsample,
    in  string     sample_desc                  "Sample description",
    in  string     reference_path               "Path to 10X reference package",
    in  string[]   factorization                "Dimensionality reduction method (lsa, plsa, or pca)",
    in  map        force_cells                  "Force cell calling to a fixed number",
    #
    out csv        singlecell                   "Per-barcode fragment counts & metrics",
    out bam        possorted_bam                "Position sorted BAM file"  "possorted_bam.bam",
    out bam.bai    possorted_bam_index          "Position sorted BAM index"  "possorted_bam.bam.bai",
    out json       summary                      "Summary of all data metrics",
    out html       web_summary                  "HTML file summarizing data & analysis",
    out bed        peaks                        "Bed file of all called peak locations",
    out h5         raw_peak_bc_matrix           "Raw peak barcode matrix in hdf5 format",
    out path       raw_peak_bc_matrix_mex       "Raw peak barcode matrix in mex format"  "raw_peak_bc_matrix",
    out path       analysis_csv                 "Directory of analysis files"  "analysis",
    out h5         filtered_peak_bc_matrix      "Filtered peak barcode matrix in hdf5 format",
    out path       filtered_peak_bc_matrix_mex  "Filtered peak barcode matrix in mex format"  "filtered_peak_bc_matrix",
    out tsv.gz     fragments                    "Barcoded and aligned fragment file"  "fragments.tsv.gz",
    out tsv.gz.tbi fragments_index              "Fragment file index"       "fragments.tsv.gz.tbi",
    out h5         filtered_tf_bc_matrix        "Filtered tf barcode matrix in hdf5 format",
    out path       filtered_tf_bc_matrix_mex    "Filtered tf barcode matrix in mex format"  "filtered_tf_bc_matrix",
    out cloupe     cloupe                       "Loupe Cell Browser input file",
    out csv        summary_csv                  "csv summarizing important metrics and values"  "summary.csv",
    out tsv        peak_annotation              "Annotation of peaks with genes",
)
{
    call SC_ATAC_COUNTER(
        fastq_mode        = self.fastq_mode,
        sample_id         = self.sample_id,
        sample_def        = self.sample_def,
        downsample        = self.downsample,
        sample_desc       = self.sample_desc,
        reference_path    = self.reference_path,
        trim_def          = {
            "R1": {
                "3prime": ["MErc"],
            },
            "R2": {
                "3prime": ["MErc"],
            },
            "discard_untrimmed": false,
        },
        barcode_whitelist = "737K-cratac-v1",
        adapters          = {
            "ME": "AGATGTGTATAAGAGACAG",
            "MErc": "CTGTCTCTTATACACATCT",
        },
        factorization     = self.factorization,
        force_cells       = self.force_cells,
    )

    return (
        singlecell                  = SC_ATAC_COUNTER.singlecell,
        possorted_bam               = SC_ATAC_COUNTER.possorted_bam,
        possorted_bam_index         = SC_ATAC_COUNTER.possorted_bam_index,
        summary                     = SC_ATAC_COUNTER.summary,
        web_summary                 = SC_ATAC_COUNTER.web_summary,
        peaks                       = SC_ATAC_COUNTER.peaks,
        raw_peak_bc_matrix          = SC_ATAC_COUNTER.raw_peak_bc_matrix,
        raw_peak_bc_matrix_mex      = SC_ATAC_COUNTER.raw_peak_bc_matrix_mex,
        analysis_csv                = SC_ATAC_COUNTER.analysis_csv,
        filtered_peak_bc_matrix     = SC_ATAC_COUNTER.filtered_peak_bc_matrix,
        filtered_peak_bc_matrix_mex = SC_ATAC_COUNTER.filtered_peak_bc_matrix_mex,
        fragments                   = SC_ATAC_COUNTER.fragments,
        fragments_index             = SC_ATAC_COUNTER.fragments_index,
        filtered_tf_bc_matrix       = SC_ATAC_COUNTER.filtered_tf_bc_matrix,
        filtered_tf_bc_matrix_mex   = SC_ATAC_COUNTER.filtered_tf_bc_matrix_mex,
        cloupe                      = SC_ATAC_COUNTER.cloupe,
        summary_csv                 = SC_ATAC_COUNTER.summary_csv,
        peak_annotation             = SC_ATAC_COUNTER.peak_annotation,
    )
}

#
# @include "__PBMC_P.mro"
#

call SC_ATAC_COUNTER_CS(
    fastq_mode     = "ILMN_BCL2FASTQ",
    sample_id      = "PBMC_P",
    sample_def     = [{
        "bc_in_read": 1,
        "bc_length": 16,
        "gem_group": null,
        "lanes": null,
        "library": "LibraryNotSpecified",
        "read_path": "/data/isshamie/dropbox/ATACseq/2020_11_18_Croker/igm-storage2.ucsd.edu/201113_A00953_0185_AHN7TMDSXY",
        "sample_indices": ["any"],
        "sample_names": ["BC_10xATAC_PMBC_P"],
    }],
    reference_path = "/data/isshamie/mito_lineage/data/external/GRCh38_MT_blacklist",
    downsample     = null,
    sample_desc    = "",
    factorization  = ["lsa"],
    force_cells    = null,
)
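The __PBMC_P.mro invocation above is normally generated by the 10x launcher rather than written by hand. As a rough sketch, an equivalent run could be kicked off from Python as below, assuming cellranger-atac is on PATH and that its count subcommand takes the usual --id/--reference/--fastqs/--sample flags (the flags are assumptions about the standard CLI, not something defined in this repository):

import subprocess

# Sketch only: paths and the sample name are copied from the call above.
cmd = [
    "cellranger-atac", "count",
    "--id=PBMC_P",
    "--reference=/data/isshamie/mito_lineage/data/external/GRCh38_MT_blacklist",
    "--fastqs=/data/isshamie/dropbox/ATACseq/2020_11_18_Croker/igm-storage2.ucsd.edu/201113_A00953_0185_AHN7TMDSXY",
    "--sample=BC_10xATAC_PMBC_P",
]
subprocess.run(cmd, check=True)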

05b9acd82df078b27bb85950bc781cd1992f995d

account for mt_pos being a list not an int

# TODO account for mt_pos being a list not an int

import numpy as np
from numpy import random
import os
import pandas as pd
from tqdm import tqdm
#from src.config import ROOT_DIR
from sklearn.metrics import roc_curve, average_precision_score, confusion_matrix
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import glob
from sklearn.cluster import KMeans
from sklearn import metrics
from scipy.spatial.distance import cdist
from pandarallel import pandarallel

from mplh.color_utils import get_colors
from mplh.fig_utils import legend_from_color
from mplh import cluster_help as ch
from src.simulations.utils.config import read_config_file, write_config_file

from dynamicTreeCut import cutreeHybrid
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage
from sklearn.model_selection import ParameterGrid
from src.simulations.utils.config import check_required

from .simulation import Simulation


# Does this ruin running the MCMC? I don't think so, b/c that format is going to be put in after anyway
class FullSimulation:
    """
    Class that simulates cell growth for lineage tracing. Reads in a
    parameter file and runs a certain number of iterations based on
    the num_iterations parameter.

    :ivar n_iter: Number of iterations
    :type n_iter: int
    :ivar num_cells: Number of cells to sequence
    :type num_cells: int

    :ivar sim: Each index is a different iteration of the simulation.
    :type sim: pandas Series

    """
    def __init__(self, params_f):
        params = read_config_file(params_f)
        self.n_iter = params['num_iterations']
        self.num_cells = params['num_cells']
        self.params = params

        # Store the metrics with this
        self.metrics = dict()

        # Files to save
        self.outdir = os.path.join(self.params['local_outdir'])
        self.data_outdir = os.path.join(self.params['data_outdir'])
        self.f_save_data = os.path.join(self.data_outdir,
                                   self.params['name'] + '.p')
        self.f_save = os.path.join(self.outdir, self.params['name'] + '.p')
        self.f_save_metrics = self.f_save.replace('.p', '.metrics.tsv')
        self.f_save_cluster = self.f_save.replace('.p', '.cluster.tsv')
        self.f_save_befaft = self.f_save.replace('.p', '.before_after.tsv')
        self.f_save_rocs = self.f_save.replace('.p', '.rocs.p')


        return
        #for i in self.n_iter:

    def run(self):
        """
        Runs the simulation and stores it in sim attr. Will also pickle
        the objects and save.

        This uses Pandaralel to parallelize the runs.
        :return:
        """
        # Parallelize df
        df = pd.Series(index=range(self.n_iter))
        df = df.apply(self.run_sim, args=(self.params,))

        #pandarallel.initialize(nb_workers=self.params['cpus'])
        #df = df.parallel_apply(self.run_sim, args=(self.params,))

        self.sim = df
        return

    @staticmethod
    def run_sim(x, params):
        """Run iteration of simulation.

        For a single iteration, it will initialize, grow, subsample,
        and merge the before-stimulus and after-stimulus populations.

        :param x: Placeholder variable
        :param params: The parameter dictionary to use
        :type params: dict
        :return:
        """
        s = Simulation(params)
        s.initialize()
        s.grow()
        s.subsample_new(to_delete=True)
        s.combine_init_growth()
        return s

    def run_metrics(self):
        """
        Get metrics performances and save.
        :return:
        """
        self.sim_performance_dominant(group='both')
        self.stats_before_after()
        # self.cluster_before_after()


    def flatten_sim(self):
        ## TODO
        # This will extract out the classes of df
        return

    def sim_performance_dominant(self, group='both'):
        """
        Collects metrics that are averaged over the simulations.
        These are specifically for the main, dominant clone, looking at
        the allele frequency of that clone's variant in each cell.

        :param group: {'init', 'growth', 'both'} Which cell population to evaluate
        :ivar dropout: Dominant-clone cells with an allele frequency of 0
        at the lineage variant position, as a fraction of all cells.
        :type dropout: list
        :ivar prec_scores: sklearn average precision score based on
        the allele frequencies seen in the dominant clone cells versus
        the non-clone cells.
        :type prec_scores: list
        :ivar rocs: ROC curves for each iteration based on allele
        frequencies.

        :return:
        """
        dropout = []
        rocs = []
        prec_scores = []


        for iter, s in enumerate(self.sim.values):
            # First get the dominant clone , which is indexed as 1
            mt_pos = s.clone_mt_dict[1]
            # TODO account for mt_pos being a list not an int
            if group == 'init':
                clones = s.clone_cell
                cell_af = s.cell_af.loc[:,mt_pos]
            elif group == 'growth':
                clones = s.new_clone_cell
                cell_af = s.new_cell_af.loc[:,mt_pos]
            elif group == 'both':
                #clones = pd.concat((s.clone_cell, s.subsample_new_clone_cell)).reset_index(drop=True)
                #cell_af = pd.concat((s.cell_af.loc[:,mt_pos], s.subsample_new_cell_af.loc[:,mt_pos])).reset_index(drop=True)
                clones = s.combined_clones
                cell_af = s.combined_cell_af.loc[:,mt_pos]
            else:
                raise ValueError('group variable not properly set.')

            y_true = clones.values.copy()
            y_true[y_true != 1] = 0  # Set nondominant clones to 0
            rocs.append(roc_curve(y_true, cell_af))
            prec_scores.append(average_precision_score(y_true, cell_af))
            dropout.append((cell_af[clones==1]==0).sum()/cell_af.shape[0])

        self.dropout = dropout
        self.prec_scores = prec_scores
        self.rocs = rocs
        pd.DataFrame([prec_scores, dropout], index=['Precision', 'Dropout']).to_csv(self.f_save_metrics, sep='\t')
        self.metrics['prec_scores'] = prec_scores
        self.metrics['dropout'] = dropout
        self.metrics['rocs'] = rocs
        pickle.dump(rocs, open(self.f_save_rocs, 'wb'))

        return
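    # Toy illustration (made-up numbers, not simulator output) of the two
    # sklearn calls used above: dominant-clone cells carry a high AF at the
    # clone variant while non-clone cells sit near the error rate.
    #   y_true  = np.array([1, 1, 1, 1, 0, 0, 0, 0])
    #   cell_af = np.array([0.30, 0.25, 0.0, 0.28, 0.02, 0.0, 0.01, 0.03])
    #   fpr, tpr, thr = roc_curve(y_true, cell_af)
    #   ap = average_precision_score(y_true, cell_af)
    #   dropout = (cell_af[y_true == 1] == 0).sum() / cell_af.shape[0]  # one clone cell has AF 0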


    def reduce_cells(self, cell_af):
        #self.sim
        return


    def stats_before_after(self, clone_id=1):
        b_a_df = pd.DataFrame(index=np.arange(0,len(self.sim)), columns=["Before", "After", "A/B"], dtype=str)
        for iter, s in enumerate(self.sim.values):
            b_clones = s.clone_cell
            a_clones = s.subsample_new_clone_cell
            b_a_df.at[iter, "Before"] = (b_clones == clone_id).sum()
            b_a_df.at[iter, "After"] = (a_clones==clone_id).sum()
            b_a_df.at[iter,"A/B"] = (b_a_df.at[iter, "After"]/b_a_df.at[iter, "Before"])
        self.b_a_df = b_a_df
        b_a_df.to_csv(self.f_save_befaft, sep='\t')
        self.metrics['b_a_df'] = b_a_df
        return


    def cluster_before_after(self):
        """
        Loops through the simulations and for each,
        it clusters the cells.

        :ivar cluster_results: Cluster labels for each cell in each
        iteration.
        :type cluster_results: list with one entry per iteration, where
        each entry is indexed by cell and gives that cell's cluster label
        """
        cluster_results = []
        print('clustering')
        for s in tqdm(self.sim.values):
            cluster_results.append(s.cluster(s.combined_cell_af))
            print(len(cluster_results[-1]))
        self.cluster_results = cluster_results


    def stats_cluster_before_after(self, clone_id=1):
        """
        Confusion matrix for clustering the proper clone cells together.
        :param clone_id: Which clone to get metrics for
        :return:
        """


        b_a_df = pd.DataFrame(index=np.arange(len(self.sim)),
                              columns=["TN", "FP", "FN", "TP"], dtype=int)
        for ind, s in enumerate(self.sim.values):
            y_true = s.combined_clones
            y_true[y_true!=1] = 0
            y_pred = self.cluster_results[ind]

            # y_true, y_pred
            tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
            b_a_df.loc[ind] = [tn, fp, fn, tp]
        self.b_a_df = b_a_df
        return
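    # Note: confusion_matrix(...).ravel() only unpacks into four values when both
    # label vectors are binary, so the cluster labels need to be binarized to
    # clone / not-clone first. Toy illustration:
    #   y_true = [1, 1, 1, 0, 0, 0]
    #   y_pred = [1, 1, 0, 0, 0, 1]
    #   tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()  # -> 2, 1, 1, 2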


    def save(self, f_save=None):
        if f_save is None:
            f_save = self.f_save_data
        f = open(f_save, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()


    def load(self, f_save=None):
        #filename = self.params['filename']
        if f_save is None:
            f_save = self.f_save
        f = open(f_save, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)


def main():
    return


if "__name__" == "__main__":
    main()
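As a usage sketch (the import path and the parameter-file name are illustrative, not part of the repository; the YAML must define the keys read above, e.g. num_iterations, num_cells, name, local_outdir and data_outdir):

from src.simulations import FullSimulation  # hypothetical import path

full_sim = FullSimulation("parameters/simulation.yaml")  # hypothetical params file
full_sim.run()          # run num_iterations growth simulations
full_sim.run_metrics()  # precision/dropout/ROC plus before-vs-after clone counts
full_sim.save()         # pickle the object to <data_outdir>/<name>.p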

d5911a552bd629d435543242f1c84337fb7a3e6a

Add the MT clone map so it can contain multiple mutants in lineages

# # TODO Add the MT clone map so it can contain multiple mutants in lineages

# from numpy import random
# import os
# import pandas as pd
# import pickle
# from src.simulations.utils.config import read_config_file, write_config_file
# from src.simulations.utils.config import check_required
import numpy as np
import pymc3 as pm
import matplotlib.pyplot as plt

num_cells = 10000
num_mt_positions = 10
clone_dist = [0.10,0.01,.89]
hets = [0.2,0.3] # len(hets) == len(clone_dist)-1
avg_cov = 50
het_err_rate = 0.1


df = np.concatenate((np.random.binomial(10,0.3,(100,4)),
                       np.random.binomial(10,0.6,(90,4))))
clone_id = np.concatenate((np.zeros([100,]), np.ones([90,]))).astype(int)
mt_id = [0,1,2,3]

with pm.Model() as model:

    clone_ids = pm.Multinomial('clone_ids', n=10000, p=clone_dist)

    beta = pm.Beta('beta', alpha=2,beta=2, shape=2)
    #p = pm.Bernoulli('p', 1, beta, shape=2)
    #p = pm.Binomial('p', 1, beta)
    #q = pm.Binomial('q', 1, beta)

    s = pm.Binomial('s', 10, beta[clone_id], observed=df)

    #s = pm.Binomial('s', 10, p, observed=df[:30,0])
    #t = pm.Binomial('t', 10, q, observed=df[30:, 0])

    #s = pm.Binomial('s', 10, beta, shape=(30,4), observed=df[:30])
    #t = pm.Binomial('t', 10, beta, shape=(25, 4), observed=df[30:])

    #vec = pm.math.concatenate((s, t), axis=0)

    # data = pm.Data("data", df)
    # u = pm.Normal('u', vec, observed=data)
    #u = pm.Deterministic('u', vec)

    trace = pm.sample(draws=8000, init='adapt_diag')

print(pm.summary(trace))
dot = pm.model_to_graphviz(model)
dot.render('simulation_pymc.gv')
pm.plot_trace(trace)
plt.savefig('simulation_trace.png')
print('here')
#
# with pm.Model() as model:
#     clone_counts = pm.Multinomial(num_cells, clone_dist)
#     num_clones = len(clone_counts) - 1
#
#     clone_cell = -1 * np.ones(shape=[num_cells, ])
#
#     clone_cell[:clone_counts[0]] = 0
#
#     for ind, val in enumerate(clone_counts[1:]):
#         start = clone_counts[:ind + 1].sum()
#         end = clone_counts[:ind + 1].sum() + val
#         # starts at sum(clone_counts[i-1]) ends at clone_counts[i].sum()
#         clone_cell[start:end] = ind + 1
#
#     c = pm.Poisson('cov', avg_cov, shape=[num_cells, num_mt_positions])
#
#     clone_mt_dict = dict()
#     for i in range(1, num_clones + 1):
#         clone_mt_dict[i] = i
#
#     cell_af = np.zeros([num_cells, num_mt_positions])
#     for ind in range(num_clones):
#         # Generate AF: (clone_df ==  ind).sum()
#         n_dom_cells = clone_counts[ind]
#         het = hets[ind]
#
#         curr_mt = clone_mt_dict[ind]
#
#         af_i = pm.Binomial('af', avg_cov, het, shape=n_dom_cells)
#         af_j = pm.Binomial('het af', avg_cov, het_err_rate, shape=num_cells - n_dom_cells) # / c
#
#
#         # Update the dom_cells and non_dom for the current MT
#         cell_af[np.flatnonzero(clone_df == ind), curr_mt] = af_i
#         cell_af[np.flatnonzero(clone_df != ind), curr_mt] = af_j
#
#     cell_af = pm.Deterministic(y)
#
# pm.model_to_graphviz(model)
#
#
# def init_cell_af(self):
#     """1C. Initialize the cell-by-mtPos af dataframe. Unless a clone:mt dict was
#     provided, the first N MT positions will be the clone AFs. Creates
#     self.clone_mt_dict and self.cell_af
#     """
#     clone_df = self.clone_cell
#     # Output
#     cell_af = pd.DataFrame(np.zeros(shape=[n_cells, n_mt]))
#
#         # Each clone points to a mt position
#         self.clone_mt_dict = dict()
#         for i in range(1, num_clones + 1):
#             self.clone_mt_dict[i] = i
#
#     # TODO Add the MT clone map so it can contain multiple mutants in lineages
#
#     # If there is a heteroplasmy table in params, it is list of mutant heteroplasmy AFs.
#     # If not, will randomly draw based on number of clones
#     if type(hets) == list:
#         assert (len(hets) == num_clones)
#
#         ## Loop through each clone,
#         ## Generate the AF for the clone and non-clones using coverage for each cell
#         ## Fill in cell_by_af for that position.
#         for ind in range(1, num_clones + 1):
#             # Generate AF: (clone_df ==  ind).sum()
#             n_dom_cells = (clone_df == ind).sum()
#             het = hets[ind - 1]
#
#             curr_mt = self.clone_mt_dict[ind]
#
#             if p['coverage']['type'] == 'constant':
#                 c = p['coverage']['cov_constant']
#
#                 af_i = random.binomial(c, het, n_dom_cells) / c
#                 af_j = random.binomial(c, q, n_cells - n_dom_cells) / c
#
#                 # Update the dom_cells and non_dom for the current MT
#                 cell_af.loc[
#                     np.flatnonzero(clone_df == ind), curr_mt] = af_i
#                 cell_af.loc[
#                     np.flatnonzero(clone_df != ind), curr_mt] = af_j
#
#             # Each cell and position has it's own coverage value, so need to update each
#             else:
#                 c = self.cells_mt_coverage
#                 # Get the cells coverage for the mt position
#                 curr_mt_cov = c[:, curr_mt]
#
#                 # Get cell indicies for the clones and nonclones
#                 curr_clone_inds = np.flatnonzero(clone_df == ind)
#                 curr_nonclone_inds = np.flatnonzero(clone_df != ind)
#                 for cell in curr_clone_inds:
#                     # Get one value for curr_mt and cell based on coverage
#                     cell_af.loc[cell, curr_mt] = random.binomial(
#                         curr_mt_cov[cell], het)
#                 for cell in curr_nonclone_inds:
#                     cell_af.loc[cell, curr_mt] = random.binomial(
#                         curr_mt_cov[cell],
#                         q)  # Loop through each coverage  # for c in n_dom_cells:
#
#
# class Simulation:
#     """Lineage tracing simulation of one sample
#
#     Will initialize cells based on their parameters and grow as well. This
#     should be a flexible framework, to add different ways to initialize, grow,
#     and metrics to have. Additionally can cluster these results.
#
#     :ivar params
#     :type params: dict
#     """
#
#     def __init__(self, params_f):
#         """
#         :param params_f: Parameter yaml file for the specifications
#         :type params_f: yaml file or dict
#         """
#         if isinstance(params_f, str):
#             params = read_config_file(params_f)
#         else:
#             params = params_f
#
#         self.params = params
#         check_required(params, ['initialize', 'num_cells', 'num_mt_positions', 'prefix'])
#         self.prefix = params['prefix']
#         self.num_mt_positions = params['num_mt_positions']
#         self.num_cells = params['num_cells']
#         if not os.path.exists(params['local_outdir']):
#             os.mkdir(params['local_outdir'])
#
#
#     def initialize(self):
#         """ (1) Pre-growth cell population is instantiated.
#
#         Creates a clone-MT dictionary, cell coverage matrix
#         (or an int, depending on parameters), and cell-AF matrix.
#         :return:
#         """
#         self.init_clone_dict()
#         self.init_cell_coverage()
#         self.init_cell_af()
#         #self.init_clone_mt()
#
#     #should be external method
#     def grow(self):
#         """ (2) Growth of cells is run."""
#         p = self.params
#         type = p["growth"]["type"]
#         if  type == "poisson":
#             self.grow_poisson(p['growth']['poisson'])
#         elif type == "binomial":
#             self.grow_binomial(p['growth']['binomial'])
#         return
#
#     # Static Method
#     @staticmethod
#     def clone_counts_to_cell_series(clone_counts):
#         """ Generates new cell IDs based on cluster count iterable
#         :param clone_counts: Each i'th element is the number of cells in
#         cluster i.
#         :type clone_counts: iterable
#         :return each index name is a cell ID and each value is which cluster
#         the cell belongs too.
#         :rtype pd.Series
#         """
#         clone_counts = np.array(clone_counts)
#         num_cells = clone_counts.sum()
#         clone_cell = -1 * np.ones(shape=[num_cells, ])
#
#         clone_cell[:clone_counts[0]] = 0
#         for ind, val in enumerate(clone_counts[1:]):
#             start = clone_counts[:ind + 1].sum()
#             end = clone_counts[:ind + 1].sum() + val
#             # starts at sum(clone_counts[i-1]) ends at clone_counts[i].sum()
#             clone_cell[start:end] = ind + 1
#
#         clone_cell = pd.Series(clone_cell, dtype=int)
#         return clone_cell
#
#     def init_clone_dict(self):
#         """1A
#         """
#
#         ### Add in potential to overwrite the values
#         # Gets the clone dictionary. Should also have clone to mt dict.
#         clones = self.params['initialize']['clone_sizes']
#         num_cells = self.num_cells
#
#         # Option 1: List of fraction of size of each clone. 0s are nonclone size, listed first
#         if type(clones) == list:
#             #clone_cell = pd.Series(index=range(num_cells))
#             clone_counts = np.random.multinomial(num_cells, clones)
#             clone_cell  = self.clone_counts_to_cell_series(clone_counts)
#             self.clone_cell = clone_cell
#         # Option 2: 1 clone. ID'd as 1
#         elif type(clones) == int: #One number for dominant clone. the others are not.
#             clone_cell = np.zeros(shape=[num_cells,])
#             clone_cell[:num_cells] = 1
#             clone_cell = clone_cell[::-1]
#             clone_cell =  pd.Series(clone_cell, dtype=int)
#             self.clone_cell = clone_cell
#
#         # Option 3 To ADD, beta binomial and more complex distributions
#
#         self.num_clones =  len(set(clone_cell.values))-1 # Remove the non-clone
#         return clone_cell
#
#
#     def init_cell_coverage(self):
#         """1B
#
#         There are different modes to the coverage, either a constant or
#         through a distribution.
#         """
#         p = self.params['initialize']['coverage']
#         type = p['type']
#
#         num_cells = self.num_cells
#         num_pos = self.num_mt_positions
#         c = np.zeros([num_cells, num_pos])
#
#         if type == 'constant':
#             c[:, :] = p['cov_constant']
#         elif type == "poisson":
#             # Get the number of coverage per cell based on poisson (should be reads)
#             mu_cov_per_cell = p['mu_cov_per_cell']
#             num_reads_per_cell = random.poisson(lam=mu_cov_per_cell,
#                                                 size=num_cells)
#
#             # Number of reads at each position, based on the average for each cell
#             for i in num_cells:
#                 c[i, :] = random.poisson(num_reads_per_cell[i],
#                                          size=num_pos)
#         self.cells_mt_coverage = c
#         return c
#
#
#
#         #####
#         # TODO
#         # Add noise to the other non-lineage positions
#         #####
#         self.cell_af = cell_af
#         return
#
#
#     def init_clone_mt(self):
#         p = self.params
#         if p["initialize"]['type'] == 'growth':
#             ## TODO
#             # Create a phylogeny and then get the averages of the mutants
#             self.average_clone_mt()
#         # If not growth, should aready be there.
#         return
#
#     def average_clone_mt(self):
#         return
#
#     def extract_clone_cells(self, clone_id):
#         """
#         Args:
#             clone_id:
#         """
#         ids = np.flatnonzero(self.clone_cell == clone_id)
#         return ids
#
#     def simulate_expand_cells_af(self, af, growth_inds, sigma):
#         """Given a cell-by-af vector, expand the AF.
#
#         Expanded AF occurs by duplicating cells that grew based on the
#         growth_inds vector. It will add standard error to each af based on sigma
#         :param af: :param growth: Indices of AF to copy :param sigma: Variance
#         to add to AF of child. :return:
#
#         Args:
#             af:
#             growth_inds:
#             sigma:
#         """
#
#         new_af = af.iloc[growth_inds].copy() + random.normal(0, sigma, size=af.iloc[growth_inds].shape)
#         new_af.index = np.arange(af.index[-1]+1, af.index[-1]+1+new_af.shape[0])
#         new_af = pd.concat((af,new_af), axis=0)
#         #new_af = np.append(af, np.concatenate(new_af))
#         return new_af
#
#     def grow_binomial(self, p):
#         """ (2.1)
#         Args:
#             p:
#         """
#         timesteps = p["time_steps"]
#         rates = p["rates"]
#
#         sigma = self.params['growth']["mutant_af_sigma_noise"]
#         cell_af = self.cell_af
#         clone_mt_dict = self.clone_mt_dict
#
#         num_clones = self.num_clones+1
#         new_dict = {}
#         for curr_clone in range(num_clones):
#             curr_rate = rates[curr_clone]
#             ids = self.extract_clone_cells(curr_clone)
#             new_cells = cell_af.loc[ids].copy()
#             for i in range(timesteps):
#                 # Simulate growth for each clone separately.
#                 growth_inds = np.flatnonzero(random.binomial(1, curr_rate, size=new_cells.shape[0]))
#                 #new_ids =
#                 new_cells = self.simulate_expand_cells_af(new_cells, growth_inds, sigma)
#
#             new_dict[curr_clone] = new_cells
#             # Create list of cells
#
#         ####TODO
#         ## new_lineage_mutants chances. This will see if a mutation will change
#
#
#         ####TODO
#         ## Add death + stimulation rate as well as growth
#         # Save the new cell clones df and cell af
#         clone_counts = [i.shape[0] for i in new_dict.values()]
#         self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)
#
#         self.new_cell_af = pd.DataFrame(new_dict[0])
#         for clone in range(1, self.num_clones+1):
#             self.new_cell_af = pd.concat((self.new_cell_af, new_dict[clone]),axis=0).reset_index(drop=True)
#         return
#
#
#     def grow_poisson(self):
#         # TODO growth of poisson refactor
#         return
#
#
#     def subsample_new(self, to_delete=False):
#         """(3) Subsample from new cell population
#
#         :param to_delete: To remove the cells that grew (which takes up
#         a lot of RAM).
#         :type to_delete: bool
#         """
#         new_cell_af = self.new_cell_af
#         p = self.params
#         if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
#             self.subsample_new_cell_af = new_cell_af.sample(n=self.params['sequence_subsample'])
#         else:
#             self.subsample_new_cell_af = new_cell_af.sample(n=self.num_cells)
#
#         self.subsample_new_clone_cell = self.new_clone_cell.loc[
#             self.subsample_new_cell_af.index]
#
#         if to_delete:
#             self.new_cell_af = None
#             self.new_clone_cell = None
#
#
#     def combine_init_growth(self):
#         """(4) Add the pre- and post- population of cells into a group.
#
#         :return:
#         """
#         combined_cell_af = self.cell_af.append(self.subsample_new_cell_af).reset_index(drop=True)
#         combined_clones = pd.concat(
#             (self.clone_cell, self.subsample_new_clone_cell)).reset_index(
#             drop=True)
#
#
#         combined_befaft = np.concatenate((np.zeros(shape=[self.cell_af.shape[0],]), np.ones(shape=[self.subsample_new_cell_af.shape[0]])))
#         combined_meta = pd.DataFrame({"pre_post": combined_befaft, "clone": combined_clones})
#         #combined_meta = pd.Series(combined_meta, name='After Growth', dtype=int)
#         assert(combined_meta.shape[0] == self.cell_af.shape[0]+self.subsample_new_cell_af.shape[0])
#         assert (combined_cell_af.shape[0] == self.cell_af.shape[0] +
#                 self.subsample_new_cell_af.shape[0])
#         assert(combined_meta.shape[0] == combined_clones.shape[0])
#         assert(combined_cell_af.shape[0] == combined_clones.shape[0])
#         self.combined_meta = combined_meta
#         self.combined_clones = combined_clones
#         self.combined_cell_af = combined_cell_af
#         return
#
#     def save(self, f_save=None):
#         """
#         Args:
#             f_save:
#         """
#         if f_save is None:
#             f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')
#         f = open(f_save, 'wb')
#         pickle.dump(self.__dict__, f, 2)
#         f.close()
#
#     @staticmethod
#     def expand_to_mgatk(curr_mt_af,mt_ref):
#         ref = mt_ref[curr_mt_af.name]
#         pos = curr_mt_af.name
#         return pd.DataFrame({"Ref":ref, "Pos":pos, "Val":curr_mt_af})
#
#     def test_save_to_mgatk_format(self):
#         df = pd.DataFrame( [[10,0,1,3,5], [3,0,5,5,0], [6,2,1,1,0]] , columns=np.arange(0,5))
#         mt_ref_dict = {0: "A", 1: "G", 2: "C", 3: "C", 4: "T"}
#         mt_ref = pd.DataFrame({"Pos": mt_ref_dict.keys(), "Ref": mt_ref_dict})
#         return
#
#     def save_to_mgatk_format(self, mt_ref, out_f):
#         """Converts into the proper files needed for mgatk. (i.e variant and
#         coverage files)
#
#         :return:
#         """
#         cell_af = self.subsample_new_cell_af
#         chars = ["A", "G", "C", "T"]
#         def alt_generate(x):
#             curr = chars.copy()
#             curr.remove(x["Ref"])
#             return np.random.choice(curr)
#         alt_ref = mt_ref.apply(alt_generate, axis=1)
#
#         # First use the AF and choose an alternative allele
#         df_stack = cell_af.stack().reset_index().rename(
#             {"level_0": "Cell", "level_1": "MT_pos", 0: "Coverage"},
#             axis=1)
#         df_stack["Nucleotide"] = df_stack["MT_pos"].apply(
#             lambda x: alt_ref[x])
#
#         # Add on the reference allele
#         df_stack_ref = cell_af.stack().reset_index().rename(
#             {"level_0": "Cell", "level_1": "MT_pos", 0: "Coverage"},
#             axis=1)
#         df_stack_ref["Coverage"] = 1-df_stack_ref["Coverage"]
#         df_stack["Nucleotide"] = df_stack["MT_pos"].apply(
#             lambda x: mt_ref[x])
#
#         df_stack = pd.concat(df_stack, df_stack_ref)
#         for ind, val in df_stack.groupby("Nucleotide"):
#             # Drop the 0s
#             curr = val[val["Coverage"]>0]
#             # Save file
#             curr_out_f = out_f + "_" + ind + ".txt"
#             curr.to_csv(curr_out_f)
#
#         # Save the coverage.
#         coverage = self.cells_mt_coverage
#         if type(coverage) != int:
#             coverage_stack = pd.DataFrame(coverage).stack().reset_index().rename(
#                 {"level_0": "Cell", "level_1": "MT Position", 0: "Coverage"},
#                 axis=1)
#         else:
#             coverage_stack = pd.DataFrame(self.cells_mt_coverage)*np.ones(shape=cell_af.shape).stack().reset_index().rename(
#                 {"level_0": "Cell", "level_1": "MT Position",  0: "Coverage"},
#                 axis=1)
#         curr_out_f = out_f + "_" + "coverage.txt"
#         coverage_stack.to_csv(curr_out_f)
#         return
#
#     def load(self):
#         filename = self.params['filename']
#         f = open(filename, 'rb')
#         tmp_dict = pickle.load(f)
#         f.close()
#         self.__dict__.update(tmp_dict)
#
#     def compare_before_after(self):
#         """Creates a df that contains information on the number of cells from
#         each clone before as well as after. :return: df.at[ind, "Dominant
#         Before"] = (full_sim.clone_cell == 1).sum() df.at[ind, "Dominant After"]
#         = (full_sim.subsample_new_clone_cell == 1).sum()
#         """
#
#         return
#
#     def cluster_compare_before_after(self):
#         """Compares the performance of clustering on grouping the same clones
#         together. :return:
#         """
#         return
#
#
# def main():
#     return


# if "__name__" == "__main__":
#     main()

fb045a2f8275d49365012088b9a8933f215dd0f2

Kmeans Clustering simulation results

Feature: Kmeans and/or dynamic tree cut and/or hierarchical clustering. Some may take longer than others.

  • Kmeans clustering with a fixed number of clusters
  • Kmeans with the elbow method to choose the number of clusters (a sketch follows below)
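A minimal sketch of the elbow variant (the function name and the second-difference heuristic for picking k are illustrative choices, not code from this repository):

import numpy as np
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans

def kmeans_elbow(cell_af, k_max=10):
    """Fit KMeans for k = 1..k_max and pick k at the point of largest curvature."""
    distortions = []
    models = {}
    for k in range(1, k_max + 1):
        model = KMeans(n_clusters=k, random_state=0).fit(cell_af)
        models[k] = model
        # Mean distance of each cell to its nearest cluster center
        distortions.append(
            np.min(cdist(cell_af, model.cluster_centers_, "euclidean"), axis=1).mean())
    curvature = np.diff(distortions, 2)  # second difference; index 0 corresponds to k = 2
    best_k = int(np.argmax(curvature)) + 2 if len(curvature) else 1
    return best_k, models[best_k].labels_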

growth of poisson refactor

# TODO growth of poisson refactor

import numpy as np
from numpy import random
import os
import pandas as pd
from tqdm import tqdm
#from src.config import ROOT_DIR
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import glob
from sklearn.cluster import KMeans
from sklearn import metrics
from scipy.spatial.distance import cdist
from pandarallel import pandarallel
pandarallel.initialize(nb_workers=32)

from mplh.color_utils import get_colors
from mplh.fig_utils import legend_from_color
from mplh import cluster_help as ch
from src.simulations.utils.config import read_config_file, write_config_file

from dynamicTreeCut import cutreeHybrid
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage
from sklearn.model_selection import ParameterGrid
from src.simulations.utils.config import check_required


class Simulation:
    """Lineage tracing simulation of one sample

    Will initialize cells based on their parameters and grow as well. This
    should be a flexible framework, to add different ways to initialize, grow,
    and metrics to have. Additionally can cluster these results.

    :ivar params
    :type params: dict
    """

    def __init__(self, params_f):
        """
        :param params_f: Parameter yaml file for the specifications
        :type params_f: yaml file or dict
        """
        if isinstance(params_f, str):
            params = read_config_file(params_f)
        else:
            params = params_f

        self.params = params
        check_required(params, ['initialize', 'num_cells', 'num_mt_positions', 'prefix'])
        self.prefix = params['prefix']
        self.num_mt_positions = params['num_mt_positions']
        self.num_cells = params['num_cells']
        if not os.path.exists(params['local_outdir']):
            os.mkdir(params['local_outdir'])


    def initialize(self):
        self.init_clone_dict()
        self.init_cell_coverage()
        self.init_cell_af()
        #self.init_clone_mt()

    #should be external method
    def grow(self):
        p = self.params
        type = p["growth"]["type"]
        if  type == "poisson":
            self.grow_poisson(p['growth']['poisson'])
        elif type == "binomial":
            self.grow_binomial(p['growth']['binomial'])
        return

    # Static Method
    @staticmethod
    def clone_counts_to_cell_series(clone_counts):
        """
        Args:
            clone_counts:
        """
        clone_counts = np.array(clone_counts)
        num_cells = clone_counts.sum()
        clone_cell = -1 * np.ones(shape=[num_cells, ])


        clone_cell[:clone_counts[0]] = 0
        for ind, val in enumerate(clone_counts[1:]):
            start = clone_counts[:ind + 1].sum()
            end = clone_counts[:ind + 1].sum() + val
            # starts at sum(clone_counts[i-1]) ends at clone_counts[i].sum()
            clone_cell[start:end] = ind + 1

        clone_cell = pd.Series(clone_cell, dtype=int)
        return clone_cell
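    # Example: clone_counts = [3, 2, 1] (3 non-clone cells, 2 cells in clone 1,
    # 1 cell in clone 2) yields pd.Series([0, 0, 0, 1, 1, 2]); the index is the
    # cell ID and the value is the clone that cell belongs to.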

    def init_clone_dict(self):
        ### Add in potential to overwrite the values

        # Gets the clone dictionary. Should also have clone to mt dict.
        clones = self.params['initialize']['clone_sizes']
        num_cells = self.num_cells

        # Option 1: List of fraction of size of each clone. 0s are nonclone size, listed first
        if type(clones) == list:
            #clone_cell = pd.Series(index=range(num_cells))
            clone_counts = np.random.multinomial(num_cells, clones)
            clone_cell  = self.clone_counts_to_cell_series(clone_counts)
            self.clone_cell = clone_cell
        # Option 2: 1 clone. ID'd as 1
        elif type(clones) == int: #One number for dominant clone. the others are not.
            clone_cell = np.zeros(shape=[num_cells,])
            clone_cell[:clones] = 1
            clone_cell = clone_cell[::-1]
            clone_cell =  pd.Series(clone_cell, dtype=int)
            self.clone_cell = clone_cell

        # Option 3 To ADD, beta binomial and more complex distributions

        self.num_clones =  len(set(clone_cell.values))-1 # Remove the non-clone
        return clone_cell


    def init_cell_coverage(self):
        """There are different modes to the coverage, either a constant or
        through a distribution. :return:
        """
        p = self.params['initialize']['coverage']
        type = p['type']

        num_cells = self.num_cells
        num_pos = self.num_mt_positions
        c = np.zeros([num_cells, num_pos])

        if type == 'constant':
            c[:, :] = p['cov_constant']
        elif type == "poisson":
            # Get the number of coverage per cell based on poisson (should be reads)
            mu_cov_per_cell = p['mu_cov_per_cell']
            num_reads_per_cell = random.poisson(lam=mu_cov_per_cell,
                                                size=num_cells)

            # Number of reads at each position, based on the average for each cell
            for i in range(num_cells):
                c[i, :] = random.poisson(num_reads_per_cell[i],
                                         size=num_pos)
        self.cells_mt_coverage = c
        return c


    def init_cell_af(self):
        """Initialize the cell-by-mtPos af dataframe. Unless a clone:mt dict was
        provided, the first N MT positions will be the clone AFs. Creates
        self.clone_mt_dict and self.cell_af
        """

        p = self.params['initialize']

        hets = self.params['het']
        q = self.params['het_err_rate']
        clone_df = self.clone_cell
        num_clones = self.num_clones
        n_cells = self.num_cells
        n_mt = self.num_mt_positions

        # Output
        cell_af = pd.DataFrame(np.zeros(shape=[n_cells, n_mt]))


        if 'mt_clone_map' in p and p['mt_clone_map'] is not None:
            self.clone_mt_dict = p['mt_clone_map']
        else:
            # Each clone points to a mt position
            self.clone_mt_dict = dict()
            for i in range(1,num_clones+1):
                self.clone_mt_dict[i] = i

        # TODO Add the MT clone map so it can contain multiple mutants in lineages

        # If there is a heteroplasmy table in params, it is list of mutant heteroplasmy AFs.
        # If not, will randomly draw based on number of clones
        if type(hets) == list:
            if (len(hets) != num_clones):
                print('here')
            assert(len(hets) == num_clones)

            ## Loop through each clone,
            ## Generate the AF for the clone and non-clones using coverage for each cell
            ## Fill in cell_by_af for that position.
            for ind in range(1, num_clones+1):
                # Generate AF: (clone_df ==  ind).sum()
                n_dom_cells = (clone_df==ind).sum()
                het = hets[ind-1]

                curr_mt = self.clone_mt_dict[ind]


                if p['coverage']['type'] == 'constant':
                    c = p['coverage']['cov_constant']

                    af_i = random.binomial(c, het,
                                           n_dom_cells) / c
                    af_j = random.binomial(c, q,
                                           n_cells - n_dom_cells) / c

                    # Update the dom_cells and non_dom for the current MT
                    cell_af.loc[np.flatnonzero(clone_df == ind), curr_mt] = af_i
                    cell_af.loc[np.flatnonzero(clone_df != ind), curr_mt] = af_j

                # Each cell and position has it's own coverage value, so need to update each
                else:
                    c = self.cells_mt_coverage
                    #Get the cells coverage for the mt position
                    curr_mt_cov= c[:, curr_mt]

                    # Get cell indicies for the clones and nonclones
                    curr_clone_inds = np.flatnonzero(clone_df==ind)
                    curr_nonclone_inds = np.flatnonzero(clone_df!=ind)
                    for cell in curr_clone_inds:
                        # Get one value for curr_mt and cell based on coverage
                        cell_af.loc[cell, curr_mt] = random.binomial(curr_mt_cov[cell], het)
                    for cell in curr_nonclone_inds:
                        cell_af.loc[cell, curr_mt] = random.binomial(curr_mt_cov[cell], q)
                # Loop through each coverage
                #for c in n_dom_cells:

        #####
        # TODO
        # Add noise to the other non-lineage positions
        #####
        self.cell_af = cell_af
        return
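        # Worked example for the constant-coverage branch (illustrative numbers):
        # with cov_constant = 50, het = 0.2 and het_err_rate = 0.1, each
        # dominant-clone cell's AF at its clone position is Binomial(50, 0.2) / 50
        # (~0.2 on average), while every other cell at that position draws
        # Binomial(50, 0.1) / 50.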


    def init_clone_mt(self):
        p = self.params
        if p["initialize"]['type'] == 'growth':
            ## TODO
            # Create a phylogeny and then get the averages of the mutants
            self.average_clone_mt()
        # If not growth, should aready be there.
        return

    def average_clone_mt(self):
        return

    def extract_clone_cells(self, clone_id):
        """
        Args:
            clone_id:
        """
        ids = np.flatnonzero(self.clone_cell == clone_id)
        return ids

    def simulate_expand_cells_af(self, af, growth_inds, sigma):
        """Given a cell-by-af vector, expand the AF.

        Expanded AF occurs by duplicating cells that grew based on the
        growth_inds vector. It will add standard error to each af based on sigma
        :param af: :param growth: Indices of AF to copy :param sigma: Variance
        to add to AF of child. :return:

        Args:
            af:
            growth_inds:
            sigma:
        """

        new_af = af.iloc[growth_inds].copy() + random.normal(0, sigma, size=af.iloc[growth_inds].shape)
        new_af.index = np.arange(af.index[-1]+1, af.index[-1]+1+new_af.shape[0])
        new_af = pd.concat((af,new_af), axis=0)
        #new_af = np.append(af, np.concatenate(new_af))
        return new_af

    def grow_binomial(self, p):
        """
        Args:
            p:
        """
        timesteps = p["time_steps"]
        rates = p["rates"]

        sigma = self.params['growth']["mutant_af_sigma_noise"]
        cell_af = self.cell_af
        clone_mt_dict = self.clone_mt_dict

        num_clones = self.num_clones+1
        new_dict = {}
        for curr_clone in range(num_clones):
            curr_rate = rates[curr_clone]
            ids = self.extract_clone_cells(curr_clone)
            new_cells = cell_af.loc[ids].copy()
            for i in range(timesteps):
                # Simulate growth for each clone separately.
                growth_inds = np.flatnonzero(random.binomial(1, curr_rate, size=new_cells.shape[0]))
                #new_ids =
                new_cells = self.simulate_expand_cells_af(new_cells, growth_inds, sigma)

            new_dict[curr_clone] = new_cells
            # Create list of cells

        ####TODO
        ## new_lineage_mutants chances. This will see if a mutation will change


        ####TODO
        ## Add death + stimulation rate as well as growth
        # Save the new cell clones df and cell af
        clone_counts = [i.shape[0] for i in new_dict.values()]
        self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)

        self.new_cell_af = pd.DataFrame(new_dict[0])
        for clone in range(1, self.num_clones+1):
            self.new_cell_af = pd.concat((self.new_cell_af, new_dict[clone]),axis=0).reset_index(drop=True)
        return


    def grow_poisson(self, p):
        # TODO growth of poisson refactor
        return
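        # Sketch of one possible shape for the refactor (illustrative only; the
        # p["time_steps"] / p["rates"] keys mirror grow_binomial and are assumptions):
        #   timesteps = p["time_steps"]; rates = p["rates"]
        #   sigma = self.params['growth']["mutant_af_sigma_noise"]
        #   for curr_clone in range(self.num_clones + 1):
        #       ids = self.extract_clone_cells(curr_clone)
        #       new_cells = self.cell_af.loc[ids].copy()
        #       for _ in range(timesteps):
        #           # each cell leaves a Poisson-distributed number of offspring this step
        #           n_offspring = random.poisson(rates[curr_clone], size=new_cells.shape[0])
        #           growth_inds = np.repeat(np.arange(new_cells.shape[0]), n_offspring)
        #           new_cells = self.simulate_expand_cells_af(new_cells, growth_inds, sigma)
        #   # ...then collect the per-clone populations as grow_binomial does.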


    def subsample_new(self, to_delete=False):
        """
        Args:
            to_delete:
        """
        new_cell_af = self.new_cell_af
        p = self.params
        if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.params['sequence_subsample'])
        else:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.num_cells)

        self.subsample_new_clone_cell = self.new_clone_cell.loc[
            self.subsample_new_cell_af.index]

        if to_delete:
            self.new_cell_af = None
            self.new_clone_cell = None


    def combine_init_growth(self):
        clones = pd.concat(
            (self.clone_cell, self.subsample_new_clone_cell)).reset_index(
            drop=True)
        combined_cell_af = self.cell_af.append(self.subsample_new_cell_af).reset_index(drop=True)

        combined_meta = np.concatenate((np.ones(shape=[self.cell_af.shape[0],]), np.zeros(shape=[self.subsample_new_cell_af.shape[0]])))
        combined_meta = pd.Series(combined_meta, name='After Growth', dtype=int)
        assert(combined_meta.shape[0] == self.cell_af.shape[0]+self.subsample_new_cell_af.shape[0])
        assert (combined_cell_af.shape[0] == self.cell_af.shape[0] +
                self.subsample_new_cell_af.shape[0])
        assert(combined_meta.shape[0] == clones.shape[0])
        assert(combined_cell_af.shape[0] == clones.shape[0])
        self.combined_meta = combined_meta
        self.combined_clones = clones
        self.combined_cell_af = combined_cell_af
        return

    def save(self, f_save=None):
        """
        Args:
            f_save:
        """
        if f_save is None:
            f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')
        f = open(f_save, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()

    def save_to_mgatk_format(self):
        """Converts into the proper files needed for mgatk. (i.e variant and
        coverage files) :return:
        """

    def load(self):
        filename = self.params['filename']
        f = open(filename, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)

    def compare_before_after(self):
        """Creates a df that contains information on the number of cells from
        each clone before as well as after. :return: df.at[ind, "Dominant
        Before"] = (full_sim.clone_cell == 1).sum() df.at[ind, "Dominant After"]
        = (full_sim.subsample_new_clone_cell == 1).sum()
        """

        return

    def cluster_compare_before_after(self):
        """Compares the performance of clustering on grouping the same clones
        together. :return:
        """
        return

    @staticmethod
    def plot_cluster(cell_af, cell_meta=None, mt_meta=None, f_save=None):
        """
        Args:
            cell_af:
            cell_meta:
            mt_meta:
            f_save:
        """
        ch.plot_cluster(cell_af, row_meta=cell_meta, col_meta=mt_meta,
                        fsave=f_save, to_col_clust=False, to_z=True)

    @staticmethod
    def cluster(cell_af):
        """Dynamic tree clustering of the rows of cell_af :param cell_af:
        :return:

        Args:
            cell_af:
        """
        distances = pdist(cell_af, "euclidean")
        link = linkage(distances, "average")
        clusters = cutreeHybrid(link, distances)['labels']
        return clusters

    @staticmethod
    def cluster_kmeans(cell_af):
        """
        Args:
            cell_af:
        """
        distortions = []
        inertias = []
        mapping1 = {}
        mapping2 = {}
        K = range(1, 10)
        for k in K:
            # Building and fitting the model
            kmeanModel = KMeans(n_clusters=k).fit(cell_af)

            distortions.append(sum(
                np.min(cdist(cell_af, kmeanModel.cluster_centers_, 'euclidean'),
                       axis=1)) / cell_af.shape[0])
            inertias.append(kmeanModel.inertia_)

            mapping1[k] = sum(
                np.min(cdist(cell_af, kmeanModel.cluster_centers_, 'euclidean'),
                       axis=1)) / cell_af.shape[0]
            mapping2[k] = kmeanModel.inertia_



def main():
    return


if "__name__" == "__main__":
    main()

5e5968e6d0871b5a90e682f07c8b64f7c3a92b29

growth of poisson refactor

# TODO growth of poisson refactor

import numpy as np
from numpy import random
import os
import pandas as pd
from tqdm import tqdm
#from src.config import ROOT_DIR
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import glob
from sklearn.cluster import KMeans
from sklearn import metrics
from scipy.spatial.distance import cdist
from pandarallel import pandarallel
pandarallel.initialize(nb_workers=32)

from mplh.color_utils import get_colors
from mplh.fig_utils import legend_from_color
from mplh import cluster_help as ch
from src.simulations.utils.config import read_config_file, write_config_file

from dynamicTreeCut import cutreeHybrid
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage
from sklearn.model_selection import ParameterGrid
from src.simulations.utils.config import check_required


class Simulation:
    """
    Lineage-tracing simulation. Initializes cells from the given parameters
    and grows them over time. The framework is meant to be flexible, so new
    ways to initialize, grow, and score the cells can be added, and the
    results can also be clustered.
    """

    def __init__(self, params_f):
        if isinstance(params_f, str):
            params = read_config_file(params_f)
        else:
            params = params_f

        self.params = params
        check_required(params, ['initialize', 'num_cells', 'num_mt_positions', 'prefix'])
        self.prefix = params['prefix']
        self.num_mt_positions = params['num_mt_positions']
        self.num_cells = params['num_cells']
        if not os.path.exists(params['local_outdir']):
            os.mkdir(params['local_outdir'])


    def initialize(self):
        self.init_clone_dict()
        self.init_cell_coverage()
        self.init_cell_af()
        #self.init_clone_mt()
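
    # Hedged, illustrative example of the params dict / YAML this class expects.
    # Only keys that are actually referenced in this file are shown; the values
    # are made up for illustration and are not repo defaults.
    _example_params = {
        'prefix': 'sim_example',
        'local_outdir': 'simulation_out',
        'num_cells': 1000,
        'num_mt_positions': 50,
        'initialize': {
            'type': 'growth',
            'clone_sizes': [0.8, 0.1, 0.1],  # first entry is the non-clone fraction
            'coverage': {'type': 'constant', 'cov_constant': 50},
        },
        'het': [0.3, 0.2],        # one heteroplasmy AF per clone
        'het_err_rate': 0.01,
        'growth': {
            'type': 'binomial',
            'mutant_af_sigma_noise': 0.01,
            'binomial': {'time_steps': 10, 'rates': [0.2, 0.5, 0.5]},
        },
        'sequence_subsample': None,
    }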

    #should be external method
    def grow(self):
        p = self.params
        type = p["growth"]["type"]
        if  type == "poisson":
            self.grow_poisson(p['growth']['poisson'])
        elif type == "binomial":
            self.grow_binomial(p['growth']['binomial'])
        return

    # Static Method
    @staticmethod
    def clone_counts_to_cell_series(clone_counts):
        clone_counts = np.array(clone_counts)
        num_cells = clone_counts.sum()
        clone_cell = -1 * np.ones(shape=[num_cells, ])


        clone_cell[:clone_counts[0]] = 0
        for ind, val in enumerate(clone_counts[1:]):
            start = clone_counts[:ind + 1].sum()
            end = clone_counts[:ind + 1].sum() + val
            # starts at sum(clone_counts[i-1]) ends at clone_counts[i].sum()
            clone_cell[start:end] = ind + 1

        clone_cell = pd.Series(clone_cell, dtype=int)
        return clone_cell
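
    # Illustrative example (not from the repo): clone_counts_to_cell_series([3, 2, 1])
    # returns pd.Series([0, 0, 0, 1, 1, 2]); index 0 is the non-clone population.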

    def init_clone_dict(self):
        ### Add in potential to overwrite the values

        # Gets the clone dictionary. Should also have clone to mt dict.
        clones = self.params['initialize']['clone_sizes']
        num_cells = self.num_cells

        # Option 1: List of fraction of size of each clone. 0s are nonclone size, listed first
        if type(clones) == list:
            #clone_cell = pd.Series(index=range(num_cells))
            clone_counts = np.random.multinomial(num_cells, clones)
            clone_cell  = self.clone_counts_to_cell_series(clone_counts)
            self.clone_cell = clone_cell
        # Option 2: 1 clone. ID'd as 1
        elif type(clones) == int: # A single number: the count of cells in the dominant clone; the rest are non-clone.
            clone_cell = np.zeros(shape=[num_cells,])
            clone_cell[:clones] = 1
            clone_cell = clone_cell[::-1]
            clone_cell =  pd.Series(clone_cell, dtype=int)
            self.clone_cell = clone_cell

        # Option 3 To ADD, beta binomial and more complex distributions

        self.num_clones =  len(set(clone_cell.values))-1 # Remove the non-clone
        return clone_cell


    def init_cell_coverage(self):
        """
        There are different modes to the coverage, either a constant or through a distribution.
        :return:
        """
        p = self.params['initialize']['coverage']
        type = p['type']

        num_cells = self.num_cells
        num_pos = self.num_mt_positions
        c = np.zeros([num_cells, num_pos])

        if type == 'constant':
            c[:, :] = p['cov_constant']
        elif type == "poisson":
            # Get the number of coverage per cell based on poisson (should be reads)
            mu_cov_per_cell = p['mu_cov_per_cell']
            num_reads_per_cell = random.poisson(lam=mu_cov_per_cell,
                                                size=num_cells)

            # Number of reads at each position, based on the average for each cell
            for i in range(num_cells):
                c[i, :] = random.poisson(num_reads_per_cell[i],
                                         size=num_pos)
        self.cells_mt_coverage = c
        return c


    def init_cell_af(self):
        """
        Initialize the cell-by-mtPos af dataframe. Unless a clone:mt dict was provided,
        the first N MT positions will be the clone AFs.
        Creates self.clone_mt_dict and self.cell_af"""

        p = self.params['initialize']

        hets = self.params['het']
        q = self.params['het_err_rate']
        clone_df = self.clone_cell
        num_clones = self.num_clones
        n_cells = self.num_cells
        n_mt = self.num_mt_positions

        # Output
        cell_af = pd.DataFrame(np.zeros(shape=[n_cells, n_mt]))


        if 'mt_clone_map' in p and p['mt_clone_map'] is not None:
            self.clone_mt_dict = p['mt_clone_map']
        else:
            # Each clone points to a mt position
            self.clone_mt_dict = dict()
            for i in range(1,num_clones+1):
                self.clone_mt_dict[i] = i

        # TODO Add the MT clone map so it can contain multiple mutants in lineages

        # If `het` in params is a list, it gives the mutant heteroplasmy AF for each clone.
        # If not, AFs will be drawn randomly based on the number of clones.
        if type(hets) == list:
            assert len(hets) == num_clones, "Need one heteroplasmy AF per clone"

            ## Loop through each clone,
            ## Generate the AF for the clone and non-clones using coverage for each cell
            ## Fill in cell_by_af for that position.
            for ind in range(1, num_clones+1):
                # Generate AF: (clone_df ==  ind).sum()
                n_dom_cells = (clone_df==ind).sum()
                het = hets[ind-1]

                curr_mt = self.clone_mt_dict[ind]


                if p['coverage']['type'] == 'constant':
                    c = p['coverage']['cov_constant']

                    af_i = random.binomial(c, het,
                                           n_dom_cells) / c
                    af_j = random.binomial(c, q,
                                           n_cells - n_dom_cells) / c

                    # Update the dom_cells and non_dom for the current MT
                    cell_af.loc[np.flatnonzero(clone_df == ind), curr_mt] = af_i
                    cell_af.loc[np.flatnonzero(clone_df != ind), curr_mt] = af_j

                # Each cell and position has its own coverage value, so need to update each
                else:
                    c = self.cells_mt_coverage
                    #Get the cells coverage for the mt position
                    curr_mt_cov= c[:, curr_mt]

                    # Get cell indicies for the clones and nonclones
                    curr_clone_inds = np.flatnonzero(clone_df==ind)
                    curr_nonclone_inds = np.flatnonzero(clone_df!=ind)
                    for cell in curr_clone_inds:
                        # Draw an alt-read count for curr_mt in this cell, then convert to an AF
                        cov = max(curr_mt_cov[cell], 1)
                        cell_af.loc[cell, curr_mt] = random.binomial(cov, het) / cov
                    for cell in curr_nonclone_inds:
                        cov = max(curr_mt_cov[cell], 1)
                        cell_af.loc[cell, curr_mt] = random.binomial(cov, q) / cov
                # Loop through each coverage
                #for c in n_dom_cells:

        #####
        # TODO
        # Add noise to the other non-lineage positions
        #####
        self.cell_af = cell_af
        return
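
    def _add_background_noise_sketch(self):
        """A minimal sketch for the TODO above (not the repo's implementation):
        add low-level noise, at the het_err_rate, to every MT position that is
        not tied to a clone. Assumes constant coverage for simplicity."""
        q = self.params['het_err_rate']
        c = self.params['initialize']['coverage'].get('cov_constant', 50)  # 50 is an assumed fallback
        clone_positions = set(self.clone_mt_dict.values())
        for pos in range(self.num_mt_positions):
            if pos in clone_positions:
                continue
            self.cell_af.loc[:, pos] = random.binomial(c, q, self.num_cells) / c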


    def init_clone_mt(self):
        p = self.params
        if p["initialize"]['type'] == 'growth':
            ## TODO
            # Create a phylogeny and then get the averages of the mutants
            self.average_clone_mt()
        # If not growth, should already be there.
        return

    def average_clone_mt(self):
        return

    def extract_clone_cells(self, clone_id):
        ids = np.flatnonzero(self.clone_cell == clone_id)
        return ids

    def simulate_expand_cells_af(self, af, growth_inds, sigma):
        """
        Given a cell-by-af vector, expand the AF.

        Expanded AF occurs by duplicating cells that grew based on
        the growth_inds vector. Gaussian noise with standard deviation sigma
        is added to each child's AF.
        :param af:
        :param growth_inds: Indices of AF to copy
        :param sigma: Standard deviation of the noise added to each child's AF.
        :return:
        """

        new_af = af.iloc[growth_inds].copy() + random.normal(0, sigma, size=af.iloc[growth_inds].shape)
        new_af.index = np.arange(af.index[-1]+1, af.index[-1]+1+new_af.shape[0])
        new_af = pd.concat((af,new_af), axis=0)
        #new_af = np.append(af, np.concatenate(new_af))
        return new_af

    def grow_binomial(self, p):
        timesteps = p["time_steps"]
        rates = p["rates"]

        sigma = self.params['growth']["mutant_af_sigma_noise"]
        cell_af = self.cell_af
        clone_mt_dict = self.clone_mt_dict

        num_clones = self.num_clones+1
        new_dict = {}
        for curr_clone in range(num_clones):
            curr_rate = rates[curr_clone]
            ids = self.extract_clone_cells(curr_clone)
            new_cells = cell_af.loc[ids].copy()
            for i in range(timesteps):
                # Simulate growth for each clone separately.
                growth_inds = np.flatnonzero(random.binomial(1, curr_rate, size=new_cells.shape[0]))
                #new_ids =
                new_cells = self.simulate_expand_cells_af(new_cells, growth_inds, sigma)

            new_dict[curr_clone] = new_cells
            # Create list of cells

        ####TODO
        ## new_lineage_mutants chances. This will see if a mutation will change


        ####TODO
        ## Add death + stimulation rate as well as growth
        # Save the new cell clones df and cell af
        clone_counts = [i.shape[0] for i in new_dict.values()]
        self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)

        self.new_cell_af = pd.DataFrame(new_dict[0])
        for clone in range(1, self.num_clones+1):
            self.new_cell_af = pd.concat((self.new_cell_af, new_dict[clone]),axis=0).reset_index(drop=True)
        return


    def grow_poisson(self, p):
        # TODO growth of poisson refactor; p is the params['growth']['poisson'] block
        return
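
    def _grow_poisson_sketch(self, p):
        """A minimal sketch of the Poisson growth TODO, not the final refactor.
        Assumes p mirrors the binomial config (keys 'time_steps' and 'rates'):
        at each step every cell gains a Poisson(rate)-distributed number of
        children, each a noisy copy of its parent (as in grow_binomial)."""
        timesteps = p["time_steps"]
        rates = p["rates"]
        sigma = self.params['growth']["mutant_af_sigma_noise"]
        new_dict = {}
        for curr_clone in range(self.num_clones + 1):
            ids = self.extract_clone_cells(curr_clone)
            new_cells = self.cell_af.loc[ids].copy()
            for _ in range(timesteps):
                n_children = random.poisson(rates[curr_clone], size=new_cells.shape[0])
                growth_inds = np.repeat(np.arange(new_cells.shape[0]), n_children)
                new_cells = self.simulate_expand_cells_af(new_cells, growth_inds, sigma)
            new_dict[curr_clone] = new_cells
        clone_counts = [df.shape[0] for df in new_dict.values()]
        self.new_clone_cell = self.clone_counts_to_cell_series(clone_counts)
        self.new_cell_af = pd.concat(list(new_dict.values()), axis=0).reset_index(drop=True)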


    def subsample_new(self, to_delete=False):
        new_cell_af = self.new_cell_af
        p = self.params
        if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.params['sequence_subsample'])
        else:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.num_cells)

        self.subsample_new_clone_cell = self.new_clone_cell.loc[
            self.subsample_new_cell_af.index]

        if to_delete:
            self.new_cell_af = None
            self.new_clone_cell = None


    def combine_init_growth(self):
        clones = pd.concat(
            (self.clone_cell, self.subsample_new_clone_cell)).reset_index(
            drop=True)
        combined_cell_af = pd.concat((self.cell_af, self.subsample_new_cell_af)).reset_index(drop=True)

        combined_meta = np.concatenate((np.ones(shape=[self.cell_af.shape[0],]), np.zeros(shape=[self.subsample_new_cell_af.shape[0]])))
        combined_meta = pd.Series(combined_meta, name='After Growth', dtype=int)
        assert(combined_meta.shape[0] == self.cell_af.shape[0]+self.subsample_new_cell_af.shape[0])
        assert (combined_cell_af.shape[0] == self.cell_af.shape[0] +
                self.subsample_new_cell_af.shape[0])
        assert(combined_meta.shape[0] == clones.shape[0])
        assert(combined_cell_af.shape[0] == clones.shape[0])
        self.combined_meta = combined_meta
        self.combined_clones = clones
        self.combined_cell_af = combined_cell_af
        return

    def save(self, f_save=None):
        if f_save is None:
            f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')
        f = open(f_save, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()

    def save_to_mgatk_format(self):
        """
        Converts into the proper files needed for mgatk. (i.e variant and coverage files)
        :return:
        """

    def load(self):
        filename = self.params['filename']
        f = open(filename, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)

    def compare_before_after(self):
        """
        Creates a df that contains information on
        the number of cells from each clone before as well as after.
        :return:
        df.at[ind, "Dominant Before"] = (full_sim.clone_cell == 1).sum()
        df.at[ind, "Dominant After"] =  (full_sim.subsample_new_clone_cell == 1).sum()

        """

        return
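
    def _compare_before_after_sketch(self):
        """A minimal sketch of the comparison described above (an illustration,
        not the repo's method): cells per clone before growth vs. after the
        post-growth subsample, as one DataFrame."""
        before = self.clone_cell.value_counts().rename('Before')
        after = self.subsample_new_clone_cell.value_counts().rename('After')
        return pd.concat((before, after), axis=1).fillna(0).astype(int)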

    def cluster_compare_before_after(self):
        """
        Compares the performance of clustering on grouping the same
        clones together.
        :return:
        """
        return
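
    def _cluster_compare_sketch(self):
        """A minimal sketch of the comparison described above: how well does
        clustering the combined pre/post-growth AF matrix recover the true
        clone labels (adjusted Rand index)? Assumes combine_init_growth ran."""
        cluster_labels = self.cluster(self.combined_cell_af)
        return metrics.adjusted_rand_score(self.combined_clones, cluster_labels)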

    @staticmethod
    def plot_cluster(cell_af, cell_meta=None, mt_meta=None, f_save=None):
        ch.plot_cluster(cell_af, row_meta=cell_meta, col_meta=mt_meta,
                        fsave=f_save, to_col_clust=False, to_z=True)

    @staticmethod
    def cluster(cell_af):
        """
        Dynamic tree clustering of the rows of cell_af
        :param cell_af:
        :return:
        """
        distances = pdist(cell_af, "euclidean")
        link = linkage(distances, "average")
        clusters = cutreeHybrid(link, distances)['labels']
        return clusters

    @staticmethod
    def cluster_kmeans(cell_af):
        distortions = []
        inertias = []
        mapping1 = {}
        mapping2 = {}
        K = range(1, 10)
        for k in K:
            # Building and fitting the model
            kmeanModel = KMeans(n_clusters=k).fit(cell_af)

            distortions.append(sum(
                np.min(cdist(cell_af, kmeanModel.cluster_centers_, 'euclidean'),
                       axis=1)) / cell_af.shape[0])
            inertias.append(kmeanModel.inertia_)

            mapping1[k] = sum(
                np.min(cdist(cell_af, kmeanModel.cluster_centers_, 'euclidean'),
                       axis=1)) / cell_af.shape[0]
            mapping2[k] = kmeanModel.inertia_

        return distortions, inertias, mapping1, mapping2
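
    @staticmethod
    def plot_elbow_sketch(distortions, f_save=None):
        """A small sketch (not in the original code) for choosing k: plot the
        k-means distortions returned by cluster_kmeans as an elbow curve."""
        ks = range(1, len(distortions) + 1)
        plt.figure()
        plt.plot(ks, distortions, 'bx-')
        plt.xlabel('k')
        plt.ylabel('Distortion')
        if f_save is not None:
            plt.savefig(f_save)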




def main():
    return


if "__name__" == "__main__":
    main()

d33839409e2b3ba48b84e057a8abd16fb6ef6f1b

Sweep across het, growth rate, coverage, error, cluster size,

Sweep across het, growth rate, coverage, error, cluster size,

# TODO Sweep across het, growth rate, coverage, error, cluster size,

class ParameterSweep:
    # TODO Sweep across het, growth rate, coverage, error, cluster size,
    def __init__(self, params_dir):
        self.params_dir = params_dir
        self.files = glob.glob(params_dir+'/*.yaml')
        sweep_params = dict()
        # Create a dictionary for each parameter across files
        for f in self.files:
            sweep_params[f] = read_config_file(f)
        self.sweep_params = sweep_params
        return

    def run_sweep(self):
        sweep_results = dict()
        for f in self.sweep_params:
            params = self.sweep_params[f]
            print(f"Running with file: {f}")
            sim = FullSimulation(params)
            sim.run()
            sweep_results[f] = sim
        self.sweep_results = sweep_results
        return sweep_results


    def plot_sensitivity(self, vars=None):
        return

    def plot_ppv(self):
        return

    def cluster_before_after(self):
        return

    def save(self, f_save=None):
        if f_save is None:
            f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')
        f = open(f_save, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()

    def load(self, filename):
        #filename = self.params['filename']
        f = open(filename, 'rb')
        tmp_dict = pickle.load(f)
        f.close()
        self.__dict__.update(tmp_dict)
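

# Hedged usage sketch (the directory path is illustrative, not a repo path):
def run_example_sweep(params_dir='parameters/sweep_example'):
    sweep = ParameterSweep(params_dir)   # reads every .yaml config in the directory
    sweep.run_sweep()                    # runs one FullSimulation per config
    sweep.save(os.path.join(params_dir, 'sweep.p'))
    return sweep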



# I can make each variable a class?
# Does this ruin running the MCMC? I don't think so, b/c that format is going to be put in after anyway
class FullSimulation:
    def __init__(self, params_f):
        # TODO parallel_apply over simulations
        # Accept either a path to a config file or an already-loaded params dict,
        # matching Simulation.__init__ (run_sweep passes the dict directly).
        if isinstance(params_f, str):
            params = read_config_file(params_f)
        else:
            params = params_f
        self.n_iter = params['num_iterations']
        self.params = params

    def run(self):
        # Parallelize df
        df = pd.Series(index=range(self.n_iter))
        df = df.parallel_apply(self.run_sim, args=(self.params,))
        self.sim = df
        return

    @staticmethod
    def run_sim(x, params):
        s = Simulation(params)
        s.initialize()
        s.grow()
        s.subsample_new(to_delete=True)
        return s

    def flatten_sim(self):

9e0da5a83e137e412b6cf789bde18ef066ad6eed

growth of poisson refactor

growth of poisson refactor

# TODO growth of poisson refactor

        return


    def grow_poisson(self):
        # TODO growth of poisson refactor
        return


    def subsample_new(self, to_delete=False):
        new_cell_af = self.new_cell_af
        p = self.params
        if 'sequence_subsample' in p and p['sequence_subsample'] is not None:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.params['sequence_subsample'])
        else:
            self.subsample_new_cell_af = new_cell_af.sample(n=self.num_cells)

        self.subsample_new_cell_clone = self.new_cell_clone.loc[
            self.subsample_new_cell_af.index]

        if to_delete:
            self.new_cell_af = None
            self.new_cell_clone = None


    def save(self, f_save=None):
        if f_save is None:
            f_save = os.path.join(self.params['local_outdir'], self.params['prefix']+'.p')
        f = open(f_save, 'wb')
        pickle.dump(self.__dict__, f, 2)
        f.close()


6f4ee2ad84351adfc26104d16d0bd22539beb875

Implement mgatk

  • Test mgatk on the 10x data and on the mgatk example data.
  • Adjust simulated data to save as mgatk input
  • Adjust simulated data to have forward/reverse strand concordance values as well (Gaussian mixture); see the sketch below.
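
A minimal sketch of the strand-concordance item above, assuming each variant's forward/reverse concordance is drawn from a two-component Gaussian mixture (one high-concordance mode for true variants, one low-concordance mode for artifacts). The weights, means, and standard deviations below are illustrative, not measured values.

import numpy as np

def simulate_strand_concordance(n_variants, weights=(0.8, 0.2),
                                means=(0.95, 0.5), sds=(0.03, 0.15), seed=None):
    """Draw one forward/reverse concordance value per variant from a
    2-component Gaussian mixture, clipped to [0, 1]."""
    rng = np.random.default_rng(seed)
    comp = rng.choice(2, size=n_variants, p=weights)
    conc = rng.normal(np.asarray(means)[comp], np.asarray(sds)[comp])
    return np.clip(conc, 0.0, 1.0)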
