#* topic_correlations:
#*   attr:
#*     fillcolor: '3'
#*   desc: Compute pearson correlation and Jensen-Shannon divergence for topic pairs.
#*   ext: R
#*   inputs:
#*   - OUTPUT_topic_descriptions
#*

# Compute the Pearson correlation and the Jensen-Shannon divergence of the
# term weights for every ordered pair of topics.
#
# Args:
#   OUTPUT_topic_descriptions: a Spark DataFrame with (at least) the columns
#     topic_name, topic_index, concept_id and term_weight. It is collected
#     into a local data.frame before any computation.
#
# Returns:
#   A data.frame with one row per ordered topic pair (T * T rows for T
#   distinct topic_index values) and the columns topici_index, topicj_index,
#   topici_name, topicj_name, corr and js_div. Rows are ordered by
#   topici_index, then topicj_index.
#
# Notes:
#   * All ordered pairs are produced on purpose, including (i, i) and both
#     (i, j) and (j, i): downstream code wants the full square matrix.
#   * Within each topic the weights are sorted by concept_id and then paired
#     positionally with the other topic's weights. That is only meaningful
#     when every topic carries the same set of concept_ids; a warning is
#     raised when that is not the case.
#   * NOTE(review): term_weight is used as-is for the divergence. This
#     presumably assumes each topic's weights already form a probability
#     distribution - confirm against the producer of this dataset.
topic_correlations <- function(OUTPUT_topic_descriptions) {
  descriptions <- SparkR::collect(
    SparkR::select(
      OUTPUT_topic_descriptions,
      c("topic_name", "topic_index", "concept_id", "term_weight")
    )
  )

  topic_ids <- sort(unique(descriptions$topic_index))
  n_topics <- length(topic_ids)

  # The name of a topic is taken from the first row carrying its index.
  topic_names <- descriptions$topic_name[
    match(topic_ids, descriptions$topic_index)
  ]

  # Slice and sort each topic once, instead of once per pair.
  rows_by_topic <- lapply(topic_ids, function(id) {
    rows <- descriptions[which(descriptions$topic_index == id), , drop = FALSE]
    rows[order(rows$concept_id), , drop = FALSE]
  })
  weights_by_topic <- lapply(rows_by_topic, function(rows) rows$term_weight)
  concepts_by_topic <- lapply(rows_by_topic, function(rows) rows$concept_id)

  same_concepts <- n_topics == 0 ||
    all(vapply(
      concepts_by_topic, identical, logical(1), concepts_by_topic[[1]]
    ))
  if (!same_concepts) {
    warning(
      "Topics do not all share the same concept_id set; weights are paired ",
      "by position after sorting and may be misaligned.",
      call. = FALSE
    )
  }

  # Enumerate every ordered pair, first topic varying slowest.
  n_pairs <- n_topics * n_topics
  pair_i <- rep(seq_len(n_topics), each = n_topics)
  pair_j <- rep(seq_len(n_topics), times = n_topics)

  corr <- numeric(n_pairs)
  js_div <- numeric(n_pairs)
  for (k in seq_len(n_pairs)) {
    p <- weights_by_topic[[pair_i[k]]]
    q <- weights_by_topic[[pair_j[k]]]
    corr[k] <- cor(p, q)
    js_div[k] <- js_divergence(p, q)

    if (k %% 100 == 0) {
      print(paste(
        "Completed row ", k, " of ", n_pairs, " (", k / n_pairs, ") "
      ))
    }
  }

  data.frame(
    topici_index = topic_ids[pair_i],
    topicj_index = topic_ids[pair_j],
    topici_name = topic_names[pair_i],
    topicj_name = topic_names[pair_j],
    corr = corr,
    js_div = js_div,
    stringsAsFactors = FALSE
  )
}

#################################################
## Global imports and functions included below ##
#################################################

library(ggplot2)
library(dplyr)
library(tidyr)
library(stringr)
library(magrittr)

# Serialize `data` with saveRDS() to "data.rds" on the output file system.
# `new.output()` is not defined in this file; it is presumably supplied by
# the hosting platform - TODO confirm.
to_rdata <- function(data) {
  output <- new.output()
  output_fs <- output$fileSystem()
  saveRDS(data, output_fs$get_path("data.rds", "w"))
}

# Read back the object stored in "data.rds" on the file system of `data`.
# `data` must expose a fileSystem() method whose result has get_path().
from_rdata <- function(data) {
  fs <- data$fileSystem()
  path <- fs$get_path("data.rds", "r")
  readRDS(path)
}

# Convert every column of class "integer64" in `df` with as.integer();
# all other columns are returned untouched.
int64fix <- function(df) {
  needs_fix <- vapply(
    df,
    function(column) inherits(column, "integer64"),
    logical(1)
  )
  df[needs_fix] <- lapply(df[needs_fix], as.integer)
  df
}

# Infix string concatenation: a %+% b is paste0(a, b).
`%+%` <- function(a, b) {
  paste0(a, b)
}

# Truncate each element of `charvec` to its first `n` words (split on runs
# of spaces), appending "..." when words were dropped. Always returns a
# character vector with one element per input element.
first_n_words <- function(charvec, n = 6) {
  vapply(
    str_split(charvec, " +"),
    function(words) {
      kept <- head(words, n)
      if (length(words) > length(kept)) {
        kept <- c(kept, "...")
      }
      paste(kept, collapse = " ")
    },
    character(1)
  )
}

# Jensen-Shannon divergence (natural log) between two equally long weight
# vectors `p` and `q`:
#   0.5 * (sum(p * log(p / m)) + sum(q * log(q / m))),  m = 0.5 * (p + q)
# Entries equal to zero contribute 0 to their sum (the usual 0 * log(0) = 0
# convention) rather than NaN, so sparse weight vectors give a finite value.
# https://stackoverflow.com/questions/11226627/jensen-shannon-divergence-in-r
# https://medium.com/datalab-log/measuring-the-statistical-similarity-between-two-samples-using-jensen-shannon-and-kullback-leibler-8d05af514b15
js_divergence <- function(p, q) {
  m <- 0.5 * (p + q)
  kl_to_mixture <- function(x) {
    keep <- x != 0
    sum(x[keep] * log(x[keep] / m[keep]))
  }
  0.5 * (kl_to_mixture(p) + kl_to_mixture(q))
}