# Source code for SuperSCC.gene_module.gene_module

"""
Gene module for functions specific to gene data analysis.
"""

from langchain_community.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser


import pandas as pd
import numpy as np
from multiprocessing import Pool
from functools import partial
from tqdm import tqdm
import warnings

# pandas emits PerformanceWarning when a DataFrame becomes highly fragmented
# (e.g. after many column insertions); silence it for this module's workflows.
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)


def find_overlap_singlet(data, full_search=False):
    """
    Calculate the overlap (intersection size) between pairs of columns in a DataFrame.

    This is the single-threaded counterpart of the parallel overlap search.

    Args:
        data (pd.DataFrame): DataFrame where each column holds one set of items.
            NaN entries (padding for unequal lengths) are ignored.
        full_search (bool): If True, every column is compared with every column,
            including itself. If False, each column is only compared with the
            columns after it, so each unordered pair appears once and no
            self-comparison is produced.

    Returns:
        pd.DataFrame: One row per comparison with columns 'number' (intersection
        size), 'base_name' and 'compare_name'. The columns are present even when
        no comparison is made (e.g. ``data`` has fewer than two columns and
        ``full_search`` is False), so callers can always index ``result['number']``.
    """
    column_names = list(data.columns)
    # Build each column's item set once (by position) instead of once per pair.
    item_sets = [set(data.iloc[:, pos].dropna()) for pos in range(len(column_names))]

    results = []
    with tqdm(total=len(column_names), desc="Finding overlaps (single-threaded)") as pbar:
        for i, (base_name, base_set) in enumerate(zip(column_names, item_sets)):
            start = 0 if full_search else i + 1
            for compare_name, compare_set in zip(column_names[start:], item_sets[start:]):
                results.append({
                    'number': len(base_set.intersection(compare_set)),
                    'base_name': base_name,
                    'compare_name': compare_name,
                })
            pbar.update(1)

    # Explicit columns keep the schema stable for an empty result.
    return pd.DataFrame(results, columns=['number', 'base_name', 'compare_name'])

def _find_overlap_base_worker(index, data, full_search=False):
    """
    Worker function to calculate pairwise intersection size for one base column.
    
    Args:
        index (int): The column index to use as the base group.
        data (pd.DataFrame): The full DataFrame of gene sets.
        full_search (bool): If False, compares the base column only with subsequent columns.

    Returns:
        pd.DataFrame: A DataFrame of intersection results for the given base column.
    """
    base_name = data.columns[index]
    base_set = set(data[base_name].dropna())
    
    if full_search:
        compare_cols = data.columns
    else:
        # To match R's `seq(index)`, which includes the index itself, we use `index:`
        # But the original R code `! colnames(data) %in% colnames(data)[seq(num)]`
        # actually means comparing with columns *after* the current one.
        compare_cols = data.columns[index:]

    results = []
    for compare_name in compare_cols:
        compare_set = set(data[compare_name].dropna())
        intersection_size = len(base_set.intersection(compare_set))
        results.append({
            'number': intersection_size,
            'base_name': base_name,
            'compare_name': compare_name
        })
    return pd.DataFrame(results)


def _find_overlap_founder_worker(index, data):
    """
    Worker function to calculate intersection size between a 'founder' group and other groups.

    Args:
        index (int): The column index of the comparison group (in data without 'founder').
        data (pd.DataFrame): The DataFrame containing a 'founder' column.

    Returns:
        pd.DataFrame: A one-row DataFrame with the intersection result.
    """
    founder_set = set(data['founder'].dropna())
    base_name = "founder"
    
    compare_cols = data.columns.drop("founder")
    compare_name = compare_cols[index]
    compare_set = set(data[compare_name].dropna())
    
    intersection_size = len(founder_set.intersection(compare_set))
    
    return pd.DataFrame([{
        'number': intersection_size,
        'base_name': base_name,
        'compare_name': compare_name
    }])


# ==============================================================================
# Parallel Intersection Functions
# ==============================================================================

def do_intersection_base(data, parallel_num=8):
    """
    Compute pairwise intersection sizes between all columns, in parallel.

    Each unordered pair of distinct columns appears once; self-comparisons are
    removed.

    Args:
        data (pd.DataFrame): DataFrame where each column is a set of items
            (NaN entries are ignored).
        parallel_num (int | None): Number of worker processes. If falsy (None or 0),
            the single-threaded ``find_overlap_singlet`` is used instead.

    Returns:
        pd.DataFrame: Columns 'number', 'base_name', 'compare_name'. When ``data``
        has fewer than two columns, an empty frame with these columns is returned.
    """
    num_cols = data.shape[1]
    if num_cols < 2:
        # No pair of distinct columns exists: skip the pool entirely and return a
        # frame that still has the expected columns (callers index 'number').
        return pd.DataFrame(
            columns=['number', 'base_name', 'compare_name']
        ).astype({'number': 'int64'})

    if not parallel_num:
        # The single-threaded search compares column i with columns j > i only,
        # which yields exactly the pairs the parallel path keeps after filtering.
        return find_overlap_singlet(data, full_search=False)

    worker_func = partial(_find_overlap_base_worker, data=data, full_search=False)

    with Pool(processes=parallel_num) as pool:
        results = list(tqdm(pool.imap(worker_func, range(num_cols)),
                            total=num_cols,
                            desc="Finding overlaps (parallel)"))

    intersect_data = pd.concat(results, ignore_index=True)
    # Workers compare column i with columns j >= i, so drop the i == j rows.
    is_self = intersect_data['base_name'] == intersect_data['compare_name']
    return intersect_data[~is_self].reset_index(drop=True)

def do_intersection_founder(data, parallel_num=8):
    """
    Compute intersection sizes between the 'founder' column and every other column.

    Args:
        data (pd.DataFrame): DataFrame containing a 'founder' column and other gene sets.
        parallel_num (int | None): Number of worker processes. If falsy (None or 0),
            the comparisons run in the current process, matching the behaviour of
            ``do_intersection_base``.

    Returns:
        pd.DataFrame: Columns 'number', 'base_name' ("founder") and 'compare_name',
        one row per non-founder column. When there is no other column to compare
        against, an empty frame with these columns is returned.
    """
    num_compare_cols = data.shape[1] - 1
    if num_compare_cols < 1:
        # Only the founder is left (or the frame is empty): nothing to compare, and
        # pd.concat([]) would raise. Return an empty frame with the expected schema.
        return pd.DataFrame(
            columns=['number', 'base_name', 'compare_name']
        ).astype({'number': 'int64'})

    worker_func = partial(_find_overlap_founder_worker, data=data)

    if not parallel_num:
        # Pool(processes=0) raises; run the same worker sequentially instead.
        results = [
            worker_func(i)
            for i in tqdm(range(num_compare_cols),
                          total=num_compare_cols,
                          desc="Finding founder overlaps (single-threaded)")
        ]
    else:
        with Pool(processes=parallel_num) as pool:
            results = list(tqdm(pool.imap(worker_func, range(num_compare_cols)),
                                total=num_compare_cols,
                                desc="Finding founder overlaps (parallel)"))

    return pd.concat(results, ignore_index=True)

# ==============================================================================
# Main Gene Module Logic
# ==============================================================================

def _create_ranking_pool(data, col_name):
    """
    Split a column of "GENE/SCORE" strings into a frame with 'genes' and 'scores'.

    Only the first '/' separates gene and score. Scores are kept as the original
    strings so they can be written back verbatim into the founder module.
    """
    pool = data[col_name].dropna().str.split('/', n=1, expand=True)
    # ``expand=True`` yields fewer than two columns when no entry contains '/'
    # (or the column is empty); reindex so both columns always exist.
    pool = pool.reindex(columns=[0, 1])
    pool.columns = ['genes', 'scores']
    return pool


def _sort_by_score(ranking):
    """
    Stable-sort a ('genes', 'scores') frame by descending numeric score.

    Scores are stored as strings, so they are converted with ``pd.to_numeric``
    only for the comparison (a plain string sort would put "9" above "10").
    Unparsable scores become NaN and sort last.
    """
    return ranking.sort_values(
        'scores',
        ascending=False,
        kind="stable",
        key=lambda scores: pd.to_numeric(scores, errors='coerce'),
    )


def _build_founder(data, member1_name, member2_name, genes1, genes2, module_size):
    """
    Build the merged "founder" module for two gene-set columns of ``data``.

    Shared genes are ranked by their best score in either set. If fewer than
    ``module_size`` genes are shared, all of them are kept and the remaining slots
    are filled with the best-scoring genes found in only one of the two sets
    (these fill genes come first in the result). Otherwise the top
    ``module_size`` shared genes are kept.

    Returns:
        list[str]: "GENE/SCORE" strings of the founder module.
    """
    intersection_genes = genes1 & genes2
    diff_genes = genes1 ^ genes2  # union minus intersection

    pool_1 = _create_ranking_pool(data, member1_name)
    pool_2 = _create_ranking_pool(data, member2_name)

    # Rank shared genes; keeping the first duplicate keeps each gene's best score.
    robust_ranking = _sort_by_score(pd.concat([
        pool_1[pool_1['genes'].isin(intersection_genes)],
        pool_2[pool_2['genes'].isin(intersection_genes)],
    ])).drop_duplicates('genes', keep='first')

    if len(intersection_genes) < module_size:
        border_ranking = _sort_by_score(pd.concat([
            pool_1[pool_1['genes'].isin(diff_genes)],
            pool_2[pool_2['genes'].isin(diff_genes)],
        ])).drop_duplicates('genes', keep='first')
        num_needed = module_size - len(intersection_genes)
        ranking = pd.concat([border_ranking.head(num_needed), robust_ranking])
    else:
        # Enough shared genes: no fill needed (a negative head() would otherwise
        # add genes instead of removing them).
        ranking = robust_ranking.head(module_size)

    return [f"{gene}/{score}" for gene, score in zip(ranking['genes'], ranking['scores'])]


def _replace_members_with_founder(data, members, founder):
    """
    Drop the merged member columns and add ``founder`` as the 'founder' column.

    Rows are reset to a 0..n-1 index and extended with NaN rows when the founder
    is longer than the remaining columns, so no founder gene is cut off by index
    alignment; shorter columns are padded with NaN.
    """
    update_data = data.drop(columns=members).reset_index(drop=True)
    n_rows = max(len(update_data), len(founder))
    update_data = update_data.reindex(range(n_rows))
    update_data['founder'] = pd.Series(founder, dtype=object)
    return update_data


def core_get_gene_module(data, intersect_size=10, intersect_group_size=5,
                         parallel_num=8, init_signal=True, module_size=50):
    """
    Repeatedly merge the best-overlapping pair of gene sets into a "founder" module.

    The first pass (``init_signal=True``) searches all column pairs; every later
    pass compares only the current 'founder' column with the remaining columns.
    A pass merges the pair with the largest overlap when at least
    ``intersect_group_size`` pairs share more than ``intersect_size`` genes: the
    two members are dropped and replaced by a new 'founder' column. The same
    thresholds are used in every pass. The process stops as soon as that
    condition fails.

    Args:
        data (pd.DataFrame): One gene set per column, entries formatted as
            "GENE/SCORE" and padded with NaN.
        intersect_size (int): A pair only counts if its intersection is strictly
            larger than this.
        intersect_group_size (int): Minimum number of qualifying pairs needed to merge.
        parallel_num (int | None): Worker processes for the overlap search.
        init_signal (bool): True for an all-pairs first search, False to search
            against an existing 'founder' column.
        module_size (int): Number of genes in a founder module (default 50).

    Returns:
        pd.DataFrame: The gene sets left when merging stops; it contains a
        'founder' column if at least one merge happened.
    """
    # Iterative rather than recursive: each merge removes at least one column, and
    # recursion would hit Python's recursion limit for large collections.
    while True:
        if data.shape[1] < 2:
            # No pair left to compare.
            return data

        # Only the gene name (text before the first '/') is used for overlaps.
        gene_names = data.apply(lambda col: col.str.split('/').str[0], axis=0)

        if init_signal:
            intersect_data = do_intersection_base(data=gene_names, parallel_num=parallel_num)
        else:
            intersect_data = do_intersection_founder(data=gene_names, parallel_num=parallel_num)

        intersect_data = intersect_data[intersect_data['number'] > intersect_size]
        intersect_data = intersect_data.sort_values(
            'number', ascending=False, ignore_index=True, kind="stable"
        )

        if intersect_data.empty or intersect_data.shape[0] < intersect_group_size:
            # Base case: not enough sufficiently large intersections.
            return data

        best = intersect_data.iloc[0]
        member1_name, member2_name = best['base_name'], best['compare_name']

        founder = _build_founder(
            data,
            member1_name,
            member2_name,
            set(gene_names[member1_name].dropna()),
            set(gene_names[member2_name].dropna()),
            module_size,
        )

        data = _replace_members_with_founder(data, [member1_name, member2_name], founder)
        # From now on, everything is compared against the new founder.
        init_signal = False


def get_gene_module(data, intersect_size=10, intersect_group_size=5, parallel_num=8):
    """
    Iteratively extract gene modules from a collection of gene sets.

    Each iteration runs ``core_get_gene_module`` on the gene sets that are still
    available. If it produced a 'founder' column, that column is stored as a
    module, the original gene sets merged into it are recorded, and the search
    restarts on the remaining gene sets. The loop ends when no module is found
    or fewer than two gene sets remain.

    Args:
        data (pd.DataFrame): A DataFrame where each column contains a gene set as
            strings (e.g., "GENE/SCORE"). Columns should be padded with NaN for
            unequal lengths.
        intersect_size (int): A pair of sets only counts if it shares strictly more
            than this many genes.
        intersect_group_size (int): The minimum number of high-intersection pairs
            required to proceed with a merge.
        parallel_num (int): The number of parallel processes to use.

    Returns:
        dict: A dictionary containing:
            - 'gene_module': A list of the identified gene modules (each a list of
              "GENE/SCORE" strings).
            - 'module_members': For each module, the original column names that
              were merged into it, in their original column order.
            - 'remained_gene_sets': For each module, the DataFrame of gene sets that
              were available *before* that module was extracted.
    """
    meta_program_list = []
    meta_program_members = []
    remained_gene_sets = []

    current_data = data.copy()
    iteration = 1

    while True:
        if current_data.shape[1] < 2:
            # Fewer than two gene sets cannot form a pair to merge.
            print("No more modules found satisfying the criteria.")
            break

        output = core_get_gene_module(
            data=current_data,
            intersect_size=intersect_size,
            intersect_group_size=intersect_group_size,
            parallel_num=parallel_num,
            init_signal=True,  # Always start with a base search in the main loop
        )

        # A 'founder' column is only present if a module was created.
        if "founder" not in output.columns:
            print("No more modules found satisfying the criteria.")
            break

        print(f"Finding the gene module {iteration}")

        meta_program_list.append(output['founder'].dropna().tolist())

        # Members consumed by the module are the columns that disappeared; keep
        # them in their original order so the result is deterministic.
        update_data = output.drop(columns="founder")
        remaining = set(update_data.columns)
        consumed_members = [col for col in current_data.columns if col not in remaining]
        meta_program_members.append(consumed_members)
        remained_gene_sets.append(current_data)

        # Update data for the next iteration
        current_data = update_data
        iteration += 1

    return {
        'gene_module': meta_program_list,
        'module_members': meta_program_members,
        'remained_gene_sets': remained_gene_sets,
    }
def compare_gene_modules(module1, module2, api_key, model="deepseek-chat", base_url="https://api.deepseek.com/v1"):
    """
    Compare two gene modules and analyze their similarities and differences.

    Parameters
    ----------
    module1: A list of gene names (strings) representing the first gene module.
    module2: A list of gene names (strings) representing the second gene module.
    api_key: A string for the API key of the LLM provider.
    model: A string for the LLM model name.
    base_url: A string for the base URL used for API requests.

    Returns
    -------
    dict
        'common_genes', 'unique_to_module1', 'unique_to_module2': sorted lists of
        genes shared by / exclusive to each module; 'comparison_analysis': the
        model's text output.
    """
    genes1 = set(module1)
    genes2 = set(module2)

    # Sorted so the returned lists are reproducible (set order is arbitrary).
    common_genes = sorted(genes1 & genes2)
    unique_to_module1 = sorted(genes1 - genes2)
    unique_to_module2 = sorted(genes2 - genes1)

    comparison_template = (
        "You are a bioinformatics expert. Compare these two gene modules: "
        "Module 1: {module1_genes} "
        "Module 2: {module2_genes} "
        "Analyze: "
        "1. Common biological pathways between modules "
        "2. Unique pathways in each module "
        "3. Potential functional relationships between modules "
        "4. Disease associations shared between modules "
        "5. Tissue/cell type specificity differences "
        "Provide your analysis in clear, structured paragraphs."
    )
    comparison_prompt = ChatPromptTemplate.from_template(comparison_template)

    # Named ``llm`` so the ``model`` (name) parameter is not shadowed.
    llm = ChatOpenAI(
        model=model,
        temperature=0.7,
        openai_api_key=api_key,
        openai_api_base=base_url,
    )

    comparison_chain = comparison_prompt | llm | StrOutputParser()
    comparison_result = comparison_chain.invoke(
        {
            "module1_genes": ", ".join(module1),
            "module2_genes": ", ".join(module2),
        }
    )

    return {
        "common_genes": common_genes,
        "unique_to_module1": unique_to_module1,
        "unique_to_module2": unique_to_module2,
        "comparison_analysis": comparison_result,
    }
def analyse_one_gene_module(module_genes, api_key, model="deepseek-chat", base_url="https://api.deepseek.com/v1"):
    """
    Ask an LLM for a functional interpretation of a single gene module.

    Parameters
    ----------
    module_genes: A list of gene names (strings); they are joined with ", " into the prompt.
    api_key: A string for the API key of the LLM provider.
    model: A string for the chat model name passed to ``ChatOpenAI`` (default "deepseek-chat").
    base_url: A string for the base URL used for API requests (default: the DeepSeek endpoint).

    Returns
    -------
    str
        The text produced by the model, as returned by ``StrOutputParser``.
    """
    prompt = ChatPromptTemplate.from_template(
        "You are a bioinformatics expert. Analyze this list of genes and provide "
        "a detailed functional interpretation of the gene module: {module_genes} "
        "Consider: "
        "1. Common biological pathways "
        "2. Cellular processes involved "
        "3. Potential tissue/cell type specificity "
        "4. Disease associations "
        "5. Functional relationships between genes "
        "Provide your analysis in clear, structured paragraphs."
    )

    llm = ChatOpenAI(
        model=model,
        temperature=0.7,
        openai_api_key=api_key,
        openai_api_base=base_url,
    )

    analysis_chain = prompt | llm | StrOutputParser()
    return analysis_chain.invoke({"module_genes": ", ".join(module_genes)})