#* vectorize_concepts:
#*   attr:
#*     fillcolor: '3'
#*   desc: Now take the input arrays of concept_ids and use the vocabulary to generate
#*     count data (for training and validation patients).
#*   ext: py
#*   inputs:
#*   - LDA_arrayify_input
#*   - gen_vocabulary
#*

# now we can do part 2 of the vectorization process, using the vocab list to generate the term usage vectors
from pyspark.ml.feature import CountVectorizerModel

def vectorize_concepts(gen_vocabulary, LDA_arrayify_input):
    # read in the python list of the vocab
    vocab_list = from_pickle(gen_vocabulary)

    # fortunately it's possible to build a vectorizer from a pre-computed vocab list such as we have
    # setting binary = True would tell it that when counting occurrences we don't really need counts,
    # just presence/absence (better for EHR data? you decide...); here we keep the full counts with binary = False
    cvmodel = CountVectorizerModel.from_vocabulary(vocab_list,
                                                   inputCol = "concept_ids",
                                                   outputCol = "concept_vector",
                                                   binary = False)  # maxDF already incorporated...(?)

    # now we do the 'transform' (vectorization) and return it
    transformed_data = cvmodel.transform(LDA_arrayify_input)
    return transformed_data


#################################################
## Global imports and functions included below ##
#################################################

import pickle

def to_pickle(data):
    # write a python object to this transform's output dataset as a pickle file
    output = Transforms.get_output()
    output_fs = output.filesystem()
    with output_fs.open('data.pickle', 'wb') as f:
        pickle.dump(data, f)

def from_pickle(transform_input):
    # read a previously pickled python object back from an upstream transform's output
    with transform_input.filesystem().open('data.pickle', 'rb') as f:
        data = pickle.load(f)
    return data
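
# --- Illustrative sketch (not part of the pipeline above) ---
# A minimal, self-contained example of what CountVectorizerModel.from_vocabulary()
# produces: each row's array of concept_ids becomes a sparse count vector indexed by
# position in the vocabulary. This assumes a local SparkSession rather than the
# Transforms-managed inputs used above, and the ids below are made up for illustration.
from pyspark.sql import SparkSession
from pyspark.ml.feature import CountVectorizerModel

spark = SparkSession.builder.master("local[1]").getOrCreate()

vocab_list = ["101", "102", "103"]  # stand-in vocabulary of concept ids
toy_df = spark.createDataFrame(
    [("p1", ["101", "101", "103"]),
     ("p2", ["102"])],
    ["person_id", "concept_ids"])

cvmodel = CountVectorizerModel.from_vocabulary(vocab_list,
                                               inputCol = "concept_ids",
                                               outputCol = "concept_vector",
                                               binary = False)

# p1's vector comes out as (3, [0, 2], [2.0, 1.0]): concept "101" twice, "103" once;
# with binary = True those entries would be capped at 1.0 (presence/absence only)
cvmodel.transform(toy_df).show(truncate = False)
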
plot_topic_clouds <- function(top_1000_terms_per_topic, focus_topics = c(1, 2, 3, 4, 5, 6)) {
    library(ggplot2)
    library(dplyr)
    library(tidyr)
    library(stringr)
    library(magrittr)
    library(ggwordcloud)

    # let's adjust the topic names here a bit
    top_1000_terms_per_topic <- top_1000_terms_per_topic
    OUTPUT_topic_descriptions <- top_1000_terms_per_topic

    full_relevance_range <- range(top_1000_terms_per_topic$relevance)
    full_term_weight_range <- range(top_1000_terms_per_topic$term_weight)

    over_representation_filter <- 0.0
    relevance_filter <- -1000000
    color_by <- "Relevance"   # or "Concept Domain ID"
    font_min <- 1.25          # 0.75
    font_max <- 8
    logscale_color <- FALSE

    focus_topics <- as.character(focus_topics)
    if(focus_topics[1] == "0") {
        # if set to 0, use all topics present in the data (should be "1", "2", ... etc)
        focus_topics <- str_extract(OUTPUT_topic_descriptions$topic_name, "\\d+") %>% unique()
    }

    num_topics <- OUTPUT_topic_descriptions$topic_name %>% unique() %>% length()

    # we'll prefilter the terms to the top N per topic to help with speeding up the vis
    # (since only so many fit on a facet, but it seems to scale linearly even beyond the
    # number that fit - should verify); we need to show more per topic when there are fewer topics
    prefilter_threshold <- 3000 / length(focus_topics)

    top_N_terms_per_topic <- OUTPUT_topic_descriptions %>%
        # we could first pre-filter based on the over_representation filter, where
        # over-representation = (prob in topic) / (expected prob in topic)
        # and (expected prob in topic) = 1 / num_topics
        # mutate(term_overrep = term_weight_relative * num_topics) %>%
        # filter(term_overrep >= over_representation_filter) %>%
        filter(relevance >= relevance_filter) %>%
        group_by(topic_name) %>%
        # select the top N per topic (which is more than feasible to plot, and keeping the
        # number down reduces plotting time)
        top_n(n = prefilter_threshold, wt = term_weight) %>%
        filter(str_extract(topic_name, "\\d+") %in% focus_topics) %>%
        arrange(desc(term_weight))

    # we don't want to error out if there's nothing to plot, and since we're varying
    # the number of topics there will be nothing to plot in some cases
    if(nrow(top_N_terms_per_topic) == 0) {
        return(NULL)
    }

    if(color_by == "Over-representation") {
        top_N_terms_per_topic$color_by <- top_N_terms_per_topic$term_overrep
    } else if(color_by == "Relevance") {
        top_N_terms_per_topic$color_by <- top_N_terms_per_topic$relevance
    } else {
        top_N_terms_per_topic$color_by <- top_N_terms_per_topic$domain_id
    }

    top_N_terms_per_topic$topic_name <- factor(top_N_terms_per_topic$topic_name,
                                               levels = top_N_terms_per_topic$topic_name %>%
                                                   unique() %>%
                                                   str_sort(numeric = TRUE),
                                               ordered = TRUE)

    top_N_terms_per_topic$plot_label <- top_N_terms_per_topic$concept_name %>% first_n_words(6)
    top_N_terms_per_topic$plot_label <- paste0(top_N_terms_per_topic$plot_label,
                                               " (", top_N_terms_per_topic$concept_id, ")") %>%
        str_wrap(width = 18)

    print(head(top_N_terms_per_topic))

    p <- ggplot(top_N_terms_per_topic) +
        geom_text_wordcloud(aes(label = plot_label, size = term_weight, color = color_by),
                            rm_outside = TRUE,
                            show.legend = TRUE,
                            lineheight = 0.70,
                            eccentricity = 1) +
        scale_radius(range = c(font_min, font_max), limits = full_term_weight_range) +
        scale_color_continuous(limits = full_relevance_range) +
        #scale_size_area(max_size = font_max) +
        facet_wrap(~topic_name)

    if(color_by == "Over-representation") {
        if(logscale_color) {
            p <- p + scale_color_continuous(name = "Topic over-\nrepresentation", trans = "log10")
        } else {
            p <- p + scale_color_continuous(name = "Topic over-\nrepresentation")
        }
    } else if(color_by == "Relevance") {
        p <- p + scale_color_gradient2(name = "Topic Relevance",
                                       high = "dodgerblue1",
                                       mid = "grey50",
                                       low = "darkorange")
        #p <- p + scale_color_brewer(palette = "Spectral")
    } else {
        p <- p + scale_color_discrete(name = "Domain ID")
    }

    p <- p +
        theme_bw() +
        theme(strip.text.x = element_text(size = 6, margin = margin(0, 0, 0, 0, "cm"))) +
        theme(aspect.ratio = 0.75)

    plot(p)

    library(gridExtra)
    g <- arrangeGrob(p)
    return(to_rdata(g))
}


#################################################
## Global imports and functions included below ##
#################################################

to_rdata <- function(data) {
    # write an R object to this transform's output dataset as an RDS file
    output <- new.output()
    output_fs <- output$fileSystem()
    saveRDS(data, output_fs$get_path("data.rds", 'w'))
}

from_rdata <- function(data) {
    # read a previously saved R object back from an upstream transform's output
    fs <- data$fileSystem()
    path <- fs$get_path("data.rds", 'r')
    rds <- readRDS(path)
    return(rds)
}

int64fix <- function(df) {
    # convert any integer64 columns (e.g. from Spark exports) to plain integers
    is_int64 <- function(x) { "integer64" %in% class(x) }
    needs_fix <- lapply(df, is_int64) %>% unlist()
    df[needs_fix] <- lapply(df[needs_fix], as.integer)
    return(df)
}

`%+%` <- function(a, b) { paste0(a, b) }

first_n_words <- function(charvec, n = 6) {
    # truncate each string to its first n words, appending "..." if anything was dropped
    charvec %>%
        str_split(" +") %>%
        lapply(function(x) {
            sub <- head(x, n)
            ifelse(length(x) > length(sub),
                   paste(c(sub, "..."), collapse = " "),
                   paste(sub, collapse = " "))
        }) %>%
        unlist()
}
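
# --- Illustrative sketch (not part of the pipeline above) ---
# A minimal, hypothetical call of plot_topic_clouds() on a tiny hand-built table
# standing in for top_1000_terms_per_topic. Only the columns the function actually
# uses are included (topic_name, concept_name, concept_id, term_weight, relevance),
# and the concept names/ids are made up. Note that the final to_rdata() step assumes
# the code-workbook environment (new.output()); outside it the word clouds still
# render via plot(p), but the return step would fail.
toy_terms <- data.frame(
    topic_name   = rep(c("Topic 1", "Topic 2"), each = 3),
    concept_name = c("Type 2 diabetes mellitus", "Metformin 500 MG oral tablet", "Hemoglobin A1c measurement",
                     "Essential hypertension", "Lisinopril 10 MG oral tablet", "Blood pressure panel"),
    concept_id   = c(101, 102, 103, 201, 202, 203),
    term_weight  = c(0.09, 0.07, 0.05, 0.08, 0.06, 0.04),
    relevance    = c(2.1, 1.8, 1.2, 2.0, 1.5, 1.1)
)

# one facet per topic; term size scales with term_weight, color with relevance
plot_topic_clouds(toy_terms, focus_topics = c(1, 2))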