Skip to content

Commit

Permalink
Merge pull request #17667 from igfoo/igfoo/conc
Browse files Browse the repository at this point in the history
KE2: Be concurrency-safe (hopefully!) and enable concurrency
  • Loading branch information
igfoo authored Oct 7, 2024
2 parents 8711099 + 3aaeefa commit b46be1b
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 67 deletions.
109 changes: 59 additions & 50 deletions java/kotlin-extractor2/src/main/kotlin/KotlinExtractor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import java.util.concurrent.Executors
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.sync.withPermit
import kotlinx.coroutines.withContext
import org.jetbrains.kotlin.analysis.api.KaAnalysisApiInternals
import org.jetbrains.kotlin.analysis.api.KaSession
Expand Down Expand Up @@ -220,8 +222,13 @@ fun doAnalysis(
val checkTrapIdentical = false // TODO

analyze(sourceModule) {
val maxThreads = 1 // TODO: Default to 8 temporarily to ensure concurrency,
// TODO: Later, default to $CODEQL_THREADS or Runtime.getRuntime().availableProcessors()
val maxThreads = 8 // TODO: Later, default to $CODEQL_THREADS or Runtime.getRuntime().availableProcessors()
// If a Kotlin coroutine yields, then a thread will be freed up
// and start extracting the next file. We want to avoid having
// lots of TRAP files open at once, so we use a semaphore to
// ensure that at most `maxThreads` coroutines have an open TRAP
// file at any one time.
val extractorThreads = Semaphore(maxThreads)
Executors.newFixedThreadPool(maxThreads).asCoroutineDispatcher().use { dispatcher ->

runBlocking {
Expand All @@ -231,58 +238,60 @@ fun doAnalysis(
val dump_psi = System.getenv("CODEQL_EXTRACTOR_JAVA_KOTLIN_DUMP") == "true"
for (psiFile in psiFiles) {
launch {
if (psiFile is KtFile) {
if (dump_psi) {
val showWhitespaces = false
val showRanges = true
loggerBase.info(dtw, DebugUtil.psiToString(psiFile, showWhitespaces, showRanges))
}
val fileExtractionProblems = FileExtractionProblems(invocationExtractionProblems)
try {
val fileDiagnosticTrapWriter = dtw.makeSourceFileTrapWriter(psiFile, true)
fileDiagnosticTrapWriter.writeCompilation_compiling_files(
compilation,
fileNumber,
fileDiagnosticTrapWriter.fileId
)
doFile(
fileNumber,
compression,
/*
OLD: KE1
fileExtractionProblems,
invocationTrapFile,
*/
fileDiagnosticTrapWriter,
loggerBase,
checkTrapIdentical,
trapDir,
srcDir,
psiFile,
extractorThreads.withPermit {
if (psiFile is KtFile) {
if (dump_psi) {
val showWhitespaces = false
val showRanges = true
loggerBase.info(dtw, DebugUtil.psiToString(psiFile, showWhitespaces, showRanges))
}
val fileExtractionProblems = FileExtractionProblems(invocationExtractionProblems)
try {
val fileDiagnosticTrapWriter = dtw.makeSourceFileTrapWriter(psiFile, true)
fileDiagnosticTrapWriter.writeCompilation_compiling_files(
compilation,
fileNumber,
fileDiagnosticTrapWriter.fileId
)
doFile(
fileNumber,
compression,
/*
OLD: KE1
fileExtractionProblems,
invocationTrapFile,
*/
fileDiagnosticTrapWriter,
loggerBase,
checkTrapIdentical,
trapDir,
srcDir,
psiFile,
/*
OLD: KE1
primitiveTypeMapping,
pluginContext,
globalExtensionState
*/
)
fileDiagnosticTrapWriter.writeCompilation_compiling_files_completed(
compilation,
fileNumber,
fileExtractionProblems.extractionResult()
)
// We catch Throwable rather than Exception, as we want to
// continue trying to extract everything else even if we get a
// stack overflow or an assertion failure in one file.
} catch (e: Throwable) {
/*
OLD: KE1
primitiveTypeMapping,
pluginContext,
globalExtensionState
logger.error("Extraction failed while extracting '${psiFile.virtualFilePath}'.", e)
fileExtractionProblems.setNonRecoverableProblem()
*/
)
fileDiagnosticTrapWriter.writeCompilation_compiling_files_completed(
compilation,
fileNumber,
fileExtractionProblems.extractionResult()
)
// We catch Throwable rather than Exception, as we want to
// continue trying to extract everything else even if we get a
// stack overflow or an assertion failure in one file.
} catch (e: Throwable) {
/*
OLD: KE1
logger.error("Extraction failed while extracting '${psiFile.virtualFilePath}'.", e)
fileExtractionProblems.setNonRecoverableProblem()
*/
}
} else {
System.out.println("Warning: Not a KtFile")
}
} else {
System.out.println("Warning: Not a KtFile")
}
}
fileNumber += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ open class KotlinUsesExtractor(
val pkg = f.packageFqName.asString()
val jvmName = getFileClassName(f)
val id = extractFileClass(pkg, jvmName)
if (tw.lm.fileClassLocationsExtracted.add(f)) {
if (tw.lm.markFileClassLocationAsExtracted(f)) {
val fileId = tw.mkFileId(f.virtualFilePath, false)
val locId = tw.getWholeFileLocation(fileId)
tw.writeHasLocation(id, locId)
Expand Down
60 changes: 47 additions & 13 deletions java/kotlin-extractor2/src/main/kotlin/TrapWriter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import com.semmle.extractor.java.PopulateFile
import com.semmle.util.unicode.UTF8Util
import java.io.BufferedWriter
import java.io.File
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
/*
OLD: KE1
import org.jetbrains.kotlin.ir.IrElement
Expand All @@ -28,16 +30,34 @@ import org.jetbrains.kotlin.psi.*
* names, and maintains a mapping from keys (`@"..."`) to labels.
*/
class TrapLabelManager {
/**
* The lock that controls access to the label manager state.
* While we could make a thread-safe MutableMap with
* Collections.synchronizedMap and use getOrPut, that offers no
* guarantee that the `defaultValue` function is only run when it
* is actually needed, which makes it useless for our purposes.
* TODO: We only actually need this for the diagnostic TRAP file. Make it optional?
*/
private val lock = ReentrantLock()

/** The next integer to use as a label name. */
private var nextInt: Int = 100

/** Returns a fresh label. */
fun <T : AnyDbType> getFreshLabel(): Label<T> {
return IntLabel(nextInt++)
lock.withLock {
return IntLabel(nextInt++)
}
}

/** A mapping from a key (`@"..."`) to the label defined to be that key, if any. */
val labelMapping: MutableMap<String, Label<*>> = mutableMapOf<String, Label<*>>()
private val labelMapping: MutableMap<String, Label<*>> = mutableMapOf<String, Label<*>>()

/**
 * Runs [action] on the key-to-label mapping while holding the label
 * manager's lock, and returns whatever [action] returns.
 */
fun <T> withLabelMapping(action: (MutableMap<String, Label<*>>) -> T): T =
    lock.withLock { action(labelMapping) }

/*
OLD: KE1
Expand All @@ -60,13 +80,25 @@ class TrapLabelManager {
* This allows us to keep track of whether we've written the location already in this TRAP file,
* to avoid duplication.
*/
val fileClassLocationsExtracted = HashSet<KtFile>()
private val fileClassLocationsExtracted = HashSet<KtFile>()

/**
 * Records that the file class location for [file] is being extracted.
 *
 * The update is made while holding the label manager's lock.
 *
 * @return `true` if [file] had not been recorded before, meaning the
 *   caller needs to actually write the TRAP for it; `false` if it has
 *   already been done.
 */
fun markFileClassLocationAsExtracted(file: KtFile): Boolean =
    lock.withLock { fileClassLocationsExtracted.add(file) }
}

/**
* A `TrapWriter` is used to write TRAP to a particular TRAP file. There may be multiple
* `TrapWriter`s for the same file, as different instances will have different additional state, but
* they must all share the same `TrapLabelManager` and `BufferedWriter`.
* `BasicLogger`s, `TrapLabelManager` and `BufferedWriter` are threadsafe, so `TrapWriter`s are too.
*/
abstract class TrapWriter(
protected val basicLogger: BasicLogger,
Expand All @@ -85,7 +117,7 @@ abstract class TrapWriter(
TODO: Inline this if it can remain private
*/
private fun <T : AnyDbType> getExistingLabelFor(key: String): Label<T>? {
return lm.labelMapping.get(key)?.cast<T>()
return lm.withLabelMapping { labelMapping -> labelMapping.get(key)?.cast<T>() }
}

/**
Expand All @@ -94,15 +126,17 @@ abstract class TrapWriter(
*/
@JvmOverloads // Needed so Java can call a method with an optional argument
fun <T : AnyDbType> getLabelFor(key: String, initialise: (Label<T>) -> Unit = {}): Label<T> {
val maybeLabel: Label<T>? = getExistingLabelFor(key)
if (maybeLabel == null) {
val label: Label<T> = lm.getFreshLabel()
lm.labelMapping.put(key, label)
writeTrap("$label = $key\n")
initialise(label)
return label
} else {
return maybeLabel
return lm.withLabelMapping { labelMapping ->
val maybeLabel: Label<T>? = getExistingLabelFor(key)
if (maybeLabel == null) {
val label: Label<T> = lm.getFreshLabel()
labelMapping.put(key, label)
writeTrap("$label = $key\n")
initialise(label)
label
} else {
maybeLabel
}
}
}

Expand Down
8 changes: 5 additions & 3 deletions java/kotlin-extractor2/src/main/kotlin/utils/Logger.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.github.codeql

import com.intellij.psi.PsiElement
import java.io.BufferedWriter
import java.io.File
import java.io.FileWriter
import java.io.OutputStreamWriter
Expand Down Expand Up @@ -127,15 +128,16 @@ class LoggerBase(val diagnosticCounter: DiagnosticCounter) : BasicLogger {
verbosity = System.getenv("CODEQL_EXTRACTOR_KOTLIN_VERBOSITY")?.toIntOrNull() ?: 3
}

private val logStream: Writer
// Use BufferedWriter as it is threadsafe
private val logStream: BufferedWriter

init {
val extractorLogDir = System.getenv("CODEQL_EXTRACTOR_JAVA_LOG_DIR")
if (extractorLogDir == null || extractorLogDir == "") {
logStream = OutputStreamWriter(System.out)
logStream = BufferedWriter(OutputStreamWriter(System.out))
} else {
val logFile = File.createTempFile("kotlin-extractor.", ".log", File(extractorLogDir))
logStream = FileWriter(logFile)
logStream = BufferedWriter(FileWriter(logFile))
}
}

Expand Down

0 comments on commit b46be1b

Please sign in to comment.