#* coherences_w_stats:
#*   attr:
#*     fillcolor: '3'
#*   desc: Compute term occurrence counts (individual and together in patients for pairs
#*     of terms in the top N per topic) needed for coherence calculations by data partner
#*     ID and topic.
#*   ext: py
#*   inputs:
#*   - OUTPUT_topic_descriptions
#*   - INPUT_persons_datapartners_concepts
#*
from pyspark.sql.window import Window
import pyspark.sql.functions as F


def coherences_w_stats(OUTPUT_topic_descriptions, INPUT_persons_datapartners_concepts):
    """Count term occurrences and co-occurrences needed for topic coherence.

    For every topic, the top-weighted concepts are paired with each other and,
    per data partner, the hold-out patients having the first concept, the
    second concept, and both concepts are counted.

    Args:
        OUTPUT_topic_descriptions: Spark DataFrame; the columns read here are
            ``topic_name``, ``concept_id`` and ``term_weight``.
        INPUT_persons_datapartners_concepts: Spark DataFrame; the columns read
            here are ``person_id``, ``data_partner_id``, ``concept_id`` and
            ``hold_out``.

    Returns:
        Spark DataFrame with one row per (data_partner_id, topic_name,
        concept_id_1, concept_id_2) and the count columns ``n_concept_1``,
        ``n_concept_2``, ``n_concept_both`` and ``n_patients``.
    """
    top_n_terms = 20
    pair_keys = ["topic_name", "data_partner_id", "concept_id_1", "concept_id_2"]

    # Concepts for each person in the hold_out partition.
    hold_out_concepts = (INPUT_persons_datapartners_concepts
        .filter(F.col('hold_out') == True)  # noqa: E712 - Spark column comparison
        .select('person_id', 'data_partner_id', 'concept_id')
    )

    # Rank terms within each topic by weight. concept_id is a tie-breaker so
    # that the top-N cut is reproducible when several terms share a weight;
    # row_number() over a non-unique ordering is otherwise arbitrary.
    window_topic = (Window
        .partitionBy("topic_name")
        .orderBy(F.col("term_weight").desc(), F.col("concept_id").asc())
    )

    # Top `top_n_terms` terms per topic - small.
    top_topic_concepts = (OUTPUT_topic_descriptions
        .withColumn('group_row', F.row_number().over(window_topic))
        .filter(F.col('group_row') <= top_n_terms)
        .select('topic_name', 'concept_id')
    )

    # Cartesian product of concepts within each topic. The pair matrix is
    # symmetrical, so only one triangle (concept_id_1 < concept_id_2) is kept
    # to avoid redundant computation.
    topic_concept_pairs = (top_topic_concepts
        .withColumnRenamed('concept_id', 'concept_id_1')
        .join(top_topic_concepts, "topic_name")
        .withColumnRenamed('concept_id', 'concept_id_2')
        .filter(F.col('concept_id_1') < F.col('concept_id_2'))
    )

    # Distinct list of the top concepts across all topics.
    top_concepts = top_topic_concepts.select('concept_id').distinct()

    # Restrict person/concept rows to the top concepts and de-duplicate, so
    # each (person, data partner, concept) appears once.
    top_person_concepts = (hold_out_concepts
        .join(F.broadcast(top_concepts), "concept_id")
        .distinct()
    )

    # Patients having concept_id_1 of each pair.
    persons_with_concept_1 = topic_concept_pairs.join(
        top_person_concepts.withColumnRenamed('concept_id', 'concept_id_1'),
        'concept_id_1',
    )

    # Patients having concept_id_2 of each pair.
    persons_with_concept_2 = topic_concept_pairs.join(
        top_person_concepts.withColumnRenamed('concept_id', 'concept_id_2'),
        'concept_id_2',
    )

    # Patients having both concepts of a pair. A plain count() is used instead
    # of countDistinct() (which was very slow); this presumably relies on the
    # de-duplicated person rows above and on each concept appearing once per
    # topic in OUTPUT_topic_descriptions - verify if that input can repeat.
    n_both = (persons_with_concept_1
        .join(persons_with_concept_2, pair_keys + ["person_id"])
        .groupBy(*pair_keys)
        .count()
        .withColumnRenamed("count", "n_concept_both")
    )

    # Patient counts for concept_id_1.
    n_first = (persons_with_concept_1
        .groupBy(*pair_keys)
        .count()
        .withColumnRenamed("count", "n_concept_1")
    )

    # Patient counts for concept_id_2.
    n_second = (persons_with_concept_2
        .groupBy(*pair_keys)
        .count()
        .withColumnRenamed("count", "n_concept_2")
    )

    # Outer join all counts together; a pair missing from one side gets 0.
    coherences_all = (n_first
        .join(n_second, pair_keys, "outer")
        .join(n_both, pair_keys, "outer")
        .na.fill(0, ['n_concept_1', 'n_concept_2', 'n_concept_both'])
    )

    # Number of patients per data_partner_id.
    # NOTE(review): this is derived from the rows already filtered to the top
    # concepts, so it counts hold-out patients having at least one top concept,
    # not every hold-out patient - confirm this is the intended denominator.
    counts_by_data_partner = (top_person_concepts
        .select("data_partner_id", "person_id")
        .distinct()
        .groupBy("data_partner_id")
        .count()
        .withColumnRenamed("count", "n_patients")
    )

    with_counts = (coherences_all
        .join(F.broadcast(counts_by_data_partner), "data_partner_id")
        .withColumn('n_patients', F.col('n_patients').cast('int'))
    )

    # pmi_smooth_one is deliberately not computed here: it has to be computed
    # post-aggregation if we want the mean over topics or over topics + data
    # partners (by summing up patients). The formula previously sketched here:
    #   concept_1_ratio = n_concept_1 / n_patients
    #   concept_2_ratio = n_concept_2 / n_patients
    #   both_ratio      = n_concept_both / n_patients
    #   pmi_smooth_one  = log2((1 + both_ratio) / (concept_1_ratio * concept_2_ratio))
    #                     when concept_1_ratio * concept_2_ratio > 0, otherwise 0
    return with_counts