Working on compression algorithm for judgements

This commit is contained in:
nise.moe 2024-02-23 16:12:10 +01:00
parent 3af2b64300
commit fbd61e0fa1
10 changed files with 229 additions and 65 deletions

View File

@ -33,6 +33,12 @@
<artifactId>spring-boot-starter-security</artifactId>
</dependency>
<dependency>
<groupId>com.aayushatharva.brotli4j</groupId>
<artifactId>brotli4j</artifactId>
<version>1.16.0</version>
</dependency>
<!-- Test containers -->
<dependency>
<groupId>org.testcontainers</groupId>

View File

@ -297,6 +297,11 @@ open class Scores(
*/
val SLIDEREND_RELEASE_STANDARD_DEVIATION_ADJUSTED: TableField<ScoresRecord, Double?> = createField(DSL.name("sliderend_release_standard_deviation_adjusted"), SQLDataType.DOUBLE, this, "")
/**
* The column <code>public.scores.judgements</code>.
*/
val JUDGEMENTS: TableField<ScoresRecord, ByteArray?> = createField(DSL.name("judgements"), SQLDataType.BLOB, this, "")
private constructor(alias: Name, aliased: Table<ScoresRecord>?): this(alias, null, null, aliased, null)
private constructor(alias: Name, aliased: Table<ScoresRecord>?, parameters: Array<Field<*>?>?): this(alias, null, null, aliased, parameters)

View File

@ -201,6 +201,10 @@ open class ScoresRecord private constructor() : UpdatableRecordImpl<ScoresRecord
set(value): Unit = set(44, value)
get(): Double? = get(44) as Double?
open var judgements: ByteArray?
set(value): Unit = set(45, value)
get(): ByteArray? = get(45) as ByteArray?
// -------------------------------------------------------------------------
// Primary key information
// -------------------------------------------------------------------------
@ -210,7 +214,7 @@ open class ScoresRecord private constructor() : UpdatableRecordImpl<ScoresRecord
/**
* Create a detached, initialised ScoresRecord
*/
constructor(id: Int? = null, beatmapId: Int? = null, count_100: Int? = null, count_300: Int? = null, count_50: Int? = null, countMiss: Int? = null, date: LocalDateTime? = null, maxCombo: Int? = null, mods: Int? = null, perfect: Boolean? = null, pp: Double? = null, rank: String? = null, replayAvailable: Boolean? = null, replayId: Long? = null, score: Long? = null, userId: Long? = null, replay: ByteArray? = null, ur: Double? = null, frametime: Double? = null, edgeHits: Int? = null, snaps: Int? = null, isBanned: Boolean? = null, adjustedUr: Double? = null, meanError: Double? = null, errorVariance: Double? = null, errorStandardDeviation: Double? = null, minimumError: Double? = null, maximumError: Double? = null, errorRange: Double? = null, errorCoefficientOfVariation: Double? = null, errorKurtosis: Double? = null, errorSkewness: Double? = null, sentDiscordNotification: Boolean? = null, addedAt: OffsetDateTime? = null, version: Int? = null, keypressesTimes: Array<Double?>? = null, keypressesMedian: Double? = null, keypressesStandardDeviation: Double? = null, sliderendReleaseTimes: Array<Double?>? = null, sliderendReleaseMedian: Double? = null, sliderendReleaseStandardDeviation: Double? = null, keypressesMedianAdjusted: Double? = null, keypressesStandardDeviationAdjusted: Double? = null, sliderendReleaseMedianAdjusted: Double? = null, sliderendReleaseStandardDeviationAdjusted: Double? = null): this() {
constructor(id: Int? = null, beatmapId: Int? = null, count_100: Int? = null, count_300: Int? = null, count_50: Int? = null, countMiss: Int? = null, date: LocalDateTime? = null, maxCombo: Int? = null, mods: Int? = null, perfect: Boolean? = null, pp: Double? = null, rank: String? = null, replayAvailable: Boolean? = null, replayId: Long? = null, score: Long? = null, userId: Long? = null, replay: ByteArray? = null, ur: Double? = null, frametime: Double? = null, edgeHits: Int? = null, snaps: Int? = null, isBanned: Boolean? = null, adjustedUr: Double? = null, meanError: Double? = null, errorVariance: Double? = null, errorStandardDeviation: Double? = null, minimumError: Double? = null, maximumError: Double? = null, errorRange: Double? = null, errorCoefficientOfVariation: Double? = null, errorKurtosis: Double? = null, errorSkewness: Double? = null, sentDiscordNotification: Boolean? = null, addedAt: OffsetDateTime? = null, version: Int? = null, keypressesTimes: Array<Double?>? = null, keypressesMedian: Double? = null, keypressesStandardDeviation: Double? = null, sliderendReleaseTimes: Array<Double?>? = null, sliderendReleaseMedian: Double? = null, sliderendReleaseStandardDeviation: Double? = null, keypressesMedianAdjusted: Double? = null, keypressesStandardDeviationAdjusted: Double? = null, sliderendReleaseMedianAdjusted: Double? = null, sliderendReleaseStandardDeviationAdjusted: Double? = null, judgements: ByteArray? = null): this() {
this.id = id
this.beatmapId = beatmapId
this.count_100 = count_100
@ -256,6 +260,7 @@ open class ScoresRecord private constructor() : UpdatableRecordImpl<ScoresRecord
this.keypressesStandardDeviationAdjusted = keypressesStandardDeviationAdjusted
this.sliderendReleaseMedianAdjusted = sliderendReleaseMedianAdjusted
this.sliderendReleaseStandardDeviationAdjusted = sliderendReleaseStandardDeviationAdjusted
this.judgements = judgements
resetChangedOnNotNull()
}
}

View File

@ -1,14 +1,15 @@
package com.nisemoe.nise.database
import com.nisemoe.generated.tables.records.ScoresJudgementsRecord
import com.nisemoe.generated.tables.records.ScoresRecord
import com.nisemoe.generated.tables.references.*
import com.nisemoe.nise.*
import com.nisemoe.nise.osu.Mod
import com.nisemoe.nise.service.AuthService
import com.nisemoe.nise.service.CompressJudgements
import org.jooq.Condition
import org.jooq.DSLContext
import org.jooq.Record
import org.jooq.Result
import org.jooq.impl.DSL
import org.jooq.impl.DSL.avg
import org.springframework.stereotype.Service
@ -19,7 +20,8 @@ import kotlin.math.roundToInt
class ScoreService(
private val dslContext: DSLContext,
private val beatmapService: BeatmapService,
private val authService: AuthService
private val authService: AuthService,
private val compressJudgements: CompressJudgements
) {
companion object {
@ -367,17 +369,22 @@ class ScoreService(
}
fun getHitDistribution(scoreId: Int): Map<Int, DistributionEntry> {
val judgements = dslContext.selectFrom(SCORES_JUDGEMENTS)
.where(SCORES_JUDGEMENTS.SCORE_ID.eq(scoreId))
.fetchInto(ScoresJudgementsRecord::class.java)
val judgementsRecord = dslContext.select(SCORES.JUDGEMENTS)
.from(SCORES)
.where(SCORES.ID.eq(scoreId))
.fetchOneInto(ScoresRecord::class.java) ?: return emptyMap()
if(judgementsRecord.judgements == null) return emptyMap()
val judgements = compressJudgements.deserialize(judgementsRecord.judgements!!)
val errorDistribution = mutableMapOf<Int, MutableMap<String, Int>>()
var totalHits = 0
judgements.forEach { hit ->
val error = (hit.error!!.roundToInt() / 2) * 2
val error = (hit.error.roundToInt() / 2) * 2
val judgementType = hit.type // Assuming this is how you get the judgement type
errorDistribution.getOrPut(error) { mutableMapOf("Miss" to 0, "300" to 0, "100" to 0, "50" to 0) }
errorDistribution.getOrPut(error) { mutableMapOf("MISS" to 0, "THREE_HUNDRED" to 0, "ONE_HUNDRED" to 0, "FIFTY" to 0) }
.apply {
this[judgementType.toString()] = this.getOrDefault(judgementType.toString(), 0) + 1
}
@ -387,10 +394,10 @@ class ScoreService(
return errorDistribution.mapValues { (_, judgementCounts) ->
judgementCounts.values.sum()
DistributionEntry(
percentageMiss = (judgementCounts.getOrDefault("Miss", 0).toDouble() / totalHits) * 100,
percentage300 = (judgementCounts.getOrDefault("300", 0).toDouble() / totalHits) * 100,
percentage100 = (judgementCounts.getOrDefault("100", 0).toDouble() / totalHits) * 100,
percentage50 = (judgementCounts.getOrDefault("50", 0).toDouble() / totalHits) * 100
percentageMiss = (judgementCounts.getOrDefault("MISS", 0).toDouble() / totalHits) * 100,
percentage300 = (judgementCounts.getOrDefault("THREE_HUNDRED", 0).toDouble() / totalHits) * 100,
percentage100 = (judgementCounts.getOrDefault("ONE_HUNDRED", 0).toDouble() / totalHits) * 100,
percentage50 = (judgementCounts.getOrDefault("FIFTY", 0).toDouble() / totalHits) * 100
)
}
}

View File

@ -2,9 +2,8 @@ package com.nisemoe.nise.scheduler
import com.nisemoe.generated.tables.records.ScoresRecord
import com.nisemoe.generated.tables.references.SCORES
import com.nisemoe.generated.tables.references.SCORES_JUDGEMENTS
import com.nisemoe.nise.integrations.CircleguardService
import com.nisemoe.nise.Format.Companion.fromJudgementType
import com.nisemoe.nise.service.CompressJudgements
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.joinAll
@ -16,23 +15,27 @@ import org.springframework.beans.factory.annotation.Value
import org.springframework.context.annotation.Profile
import org.springframework.scheduling.annotation.Scheduled
import org.springframework.stereotype.Service
import java.time.OffsetDateTime
@Profile("old_scores")
@Service
class FixOldScores(
private val dslContext: DSLContext,
private val circleguardService: CircleguardService
private val circleguardService: CircleguardService,
private val compressJudgements: CompressJudgements
){
companion object {
const val CURRENT_VERSION = 6
}
@Value("\${OLD_SCORES_WORKERS:4}")
private var workers: Int = 4
@Value("\${OLD_SCORES_PAGE_SIZE:5000}")
private var pageSize: Int = 5000
val CURRENT_VERSION = 5
private val logger = LoggerFactory.getLogger(javaClass)
data class Task(val offset: Int, val limit: Int)
@ -139,6 +142,7 @@ class FixOldScores(
.set(SCORES.SLIDEREND_RELEASE_MEDIAN_ADJUSTED, processedReplay.sliderend_release_median_adjusted)
.set(SCORES.SLIDEREND_RELEASE_STANDARD_DEVIATION, processedReplay.sliderend_release_standard_deviation)
.set(SCORES.SLIDEREND_RELEASE_STANDARD_DEVIATION_ADJUSTED, processedReplay.sliderend_release_standard_deviation_adjusted)
.set(SCORES.JUDGEMENTS, compressJudgements.serialize(processedReplay.judgements))
.where(SCORES.REPLAY_ID.eq(score.replayId))
.returningResult(SCORES.ID)
.fetchOne()?.getValue(SCORES.ID)
@ -152,22 +156,6 @@ class FixOldScores(
.set(SCORES.VERSION, CURRENT_VERSION)
.where(SCORES.ID.eq(scoreId))
.execute()
val judgementsExist = dslContext.fetchExists(SCORES_JUDGEMENTS, SCORES_JUDGEMENTS.SCORE_ID.eq(scoreId))
if(!judgementsExist) {
for (judgement in processedReplay.judgements) {
dslContext.insertInto(SCORES_JUDGEMENTS)
.set(SCORES_JUDGEMENTS.TIME, judgement.time)
.set(SCORES_JUDGEMENTS.X, judgement.x)
.set(SCORES_JUDGEMENTS.Y, judgement.y)
.set(SCORES_JUDGEMENTS.TYPE, fromJudgementType(judgement.type))
.set(SCORES_JUDGEMENTS.DISTANCE_EDGE, judgement.distance_edge)
.set(SCORES_JUDGEMENTS.DISTANCE_CENTER, judgement.distance_center)
.set(SCORES_JUDGEMENTS.ERROR, judgement.error)
.set(SCORES_JUDGEMENTS.SCORE_ID, scoreId)
.execute()
}
}
}
}

View File

@ -5,7 +5,6 @@ import com.nisemoe.generated.tables.references.*
import com.nisemoe.konata.Replay
import com.nisemoe.konata.ReplaySetComparison
import com.nisemoe.konata.compareReplaySet
import com.nisemoe.nise.Format.Companion.fromJudgementType
import com.nisemoe.nise.UserQueueDetails
import com.nisemoe.nise.database.ScoreService
import com.nisemoe.nise.database.UserService
@ -16,6 +15,7 @@ import com.nisemoe.nise.osu.Mod
import com.nisemoe.nise.osu.OsuApi
import com.nisemoe.nise.osu.OsuApiModels
import com.nisemoe.nise.service.CacheService
import com.nisemoe.nise.service.CompressJudgements
import com.nisemoe.nise.service.UpdateUserQueueService
import kotlinx.serialization.Serializable
import org.jooq.DSLContext
@ -48,7 +48,8 @@ class ImportScores(
private val scoreService: ScoreService,
private val updateUserQueueService: UpdateUserQueueService,
private val circleguardService: CircleguardService,
private val messagingTemplate: SimpMessagingTemplate
private val messagingTemplate: SimpMessagingTemplate,
private val compressJudgements: CompressJudgements
) : InitializingBean {
private val userToUpdateBucket = mutableListOf<Long>()
@ -65,19 +66,22 @@ class ImportScores(
this.cacheService.deleteVariable("userToUpdateBucket")
}
}
val CURRENT_VERSION = 5
companion object {
const val CURRENT_VERSION = 6
const val SLEEP_AFTER_API_CALL = 500L
const val UPDATE_USER_EVERY_DAYS = 7L
const val UPDATE_BANNED_USERS_EVERY_DAYS = 3L
}
@Value("\${WEBHOOK_URL}")
private lateinit var webhookUrl: String
private val logger = LoggerFactory.getLogger(javaClass)
private final val sleepTimeInMs = 500L
private final val UPDATE_USER_EVERY_DAYS = 7L
private final val UPDATE_BANNED_USERS_EVERY_DAYS = 3L
data class UpdaterStatistics(
var currentBeatmapsetPage: Int = 0,
@ -165,7 +169,7 @@ class ImportScores(
this.logger.info("Recent scores: ${recentUserScores?.size}")
this.logger.info("First place scores: ${firstPlaceUserScores?.size}")
Thread.sleep(this.sleepTimeInMs)
Thread.sleep(SLEEP_AFTER_API_CALL)
if(topUserScores == null || recentUserScores == null || firstPlaceUserScores == null) {
this.logger.error("Failed to fetch top scores for user with id = $userId")
@ -297,7 +301,7 @@ class ImportScores(
.where(SCORES.USER_ID.eq(userId))
.execute()
}
Thread.sleep(this.sleepTimeInMs)
Thread.sleep(SLEEP_AFTER_API_CALL)
}
this.cacheService.setVariable("lastBannedUserCheck", LocalDateTime.now())
@ -313,11 +317,11 @@ class ImportScores(
do {
val searchResults = this.osuApi.searchBeatmapsets(cursor = cursor)
this.statistics.currentBeatmapsetPage++
Thread.sleep(this.sleepTimeInMs)
Thread.sleep(SLEEP_AFTER_API_CALL)
if (searchResults == null) {
this.logger.error("Failed to fetch beatmapsets. Skipping to next page...")
Thread.sleep(this.sleepTimeInMs * 2)
Thread.sleep(SLEEP_AFTER_API_CALL * 2)
return
}
@ -365,11 +369,11 @@ class ImportScores(
}
val beatmapScores = this.osuApi.getTopBeatmapScores(beatmapId = beatmap.id)
Thread.sleep(this.sleepTimeInMs)
Thread.sleep(SLEEP_AFTER_API_CALL)
if (beatmapScores == null) {
this.logger.error("Failed to fetch beatmap scores for beatmapId = ${beatmap.id}")
Thread.sleep(this.sleepTimeInMs * 2)
Thread.sleep(SLEEP_AFTER_API_CALL * 2)
continue
}
@ -394,7 +398,7 @@ class ImportScores(
if(this.userToUpdateBucket.size >= 50) {
val usersBucket = this.osuApi.getUsersBatch(this.userToUpdateBucket.toList())
Thread.sleep(this.sleepTimeInMs)
Thread.sleep(SLEEP_AFTER_API_CALL)
if(usersBucket == null) {
this.logger.error("Failed to fetch users batch.")
continue
@ -648,6 +652,7 @@ class ImportScores(
.set(SCORES.SLIDEREND_RELEASE_MEDIAN_ADJUSTED, processedReplay.sliderend_release_median_adjusted)
.set(SCORES.SLIDEREND_RELEASE_STANDARD_DEVIATION, processedReplay.sliderend_release_standard_deviation)
.set(SCORES.SLIDEREND_RELEASE_STANDARD_DEVIATION_ADJUSTED, processedReplay.sliderend_release_standard_deviation_adjusted)
.set(SCORES.JUDGEMENTS, compressJudgements.serialize(processedReplay.judgements))
.where(SCORES.REPLAY_ID.eq(score.best_id))
.returningResult(SCORES.ID)
.fetchOne()?.getValue(SCORES.ID)
@ -683,19 +688,6 @@ class ImportScores(
this.updateUserQueueService.insertUser(score.user_id)
}
}
for (judgement in processedReplay.judgements) {
dslContext.insertInto(SCORES_JUDGEMENTS)
.set(SCORES_JUDGEMENTS.TIME, judgement.time)
.set(SCORES_JUDGEMENTS.X, judgement.x)
.set(SCORES_JUDGEMENTS.Y, judgement.y)
.set(SCORES_JUDGEMENTS.TYPE, fromJudgementType(judgement.type))
.set(SCORES_JUDGEMENTS.DISTANCE_EDGE, judgement.distance_edge)
.set(SCORES_JUDGEMENTS.DISTANCE_CENTER, judgement.distance_center)
.set(SCORES_JUDGEMENTS.ERROR, judgement.error)
.set(SCORES_JUDGEMENTS.SCORE_ID, scoreId)
.execute()
}
}
}

View File

@ -0,0 +1,78 @@
package com.nisemoe.nise.service
import com.aayushatharva.brotli4j.Brotli4jLoader
import com.aayushatharva.brotli4j.decoder.Decoder
import com.aayushatharva.brotli4j.encoder.Encoder
import com.nisemoe.nise.integrations.CircleguardService
import org.springframework.stereotype.Service
import java.nio.ByteBuffer
import kotlin.math.round
/**
 * Packs replay judgements into a fixed-width binary layout and compresses the result with Brotli.
 *
 * Every judgement occupies [RECORD_SIZE_BYTES] bytes, written in this order:
 * - 2 bytes: time delta to the previously encoded judgement (the first one is relative to 0)
 * - 4 bytes: `x * 1000`
 * - 4 bytes: `y * 1000`
 * - 1 byte:  ordinal of the judgement type
 * - 2 bytes: `distance_center * 100`
 * - 2 bytes: `distance_edge * 100`
 * - 2 bytes: `error`
 *
 * The encoding is lossy: values are rounded to the precision listed above, and 16-bit fields are
 * saturated to the `Short` range instead of silently wrapping around.
 */
@Service
class CompressJudgements {

    /** Brotli encoder settings used by [serialize] (quality level 11). */
    val brotliParameters: Encoder.Parameters = Encoder.Parameters()
        .setQuality(11)

    init {
        Brotli4jLoader.ensureAvailability()
    }

    /**
     * Encodes [judgements] into the binary layout described on the class and compresses it.
     *
     * @param judgements judgements to encode, in the order they should be restored.
     * @return the Brotli-compressed payload; pass it to [deserialize] to read it back.
     */
    fun serialize(judgements: List<CircleguardService.ScoreJudgement>): ByteArray {
        val buffer = ByteBuffer.allocate(judgements.size * RECORD_SIZE_BYTES)

        // Track the time the decoder will reconstruct, not the exact input time. Computing each
        // delta against the reconstructed value keeps rounding (or saturation) of one delta from
        // accumulating into every later timestamp.
        var encodedTime = 0.0
        for (judgement in judgements) {
            val deltaTime = toSaturatedShort(judgement.time - encodedTime)
            encodedTime += deltaTime

            buffer.putShort(deltaTime)
            buffer.putInt(round(judgement.x * 1000).toInt())
            buffer.putInt(round(judgement.y * 1000).toInt())
            buffer.put(judgement.type.ordinal.toByte())
            buffer.putShort(toSaturatedShort(judgement.distance_center * 100))
            buffer.putShort(toSaturatedShort(judgement.distance_edge * 100))
            buffer.putShort(toSaturatedShort(judgement.error))
        }

        return Encoder.compress(buffer.array(), brotliParameters)
    }

    /**
     * Decompresses and decodes a payload produced by [serialize].
     *
     * @param compressedData Brotli-compressed payload.
     * @return the decoded judgements, in encoding order.
     * @throws IllegalArgumentException if the payload cannot be decompressed, its size is not a
     * whole number of records, or it contains an unknown judgement type ordinal.
     */
    fun deserialize(compressedData: ByteArray): List<CircleguardService.ScoreJudgement> {
        val decompressed = Decoder.decompress(compressedData)
        require(decompressed.resultStatus == com.aayushatharva.brotli4j.decoder.DecoderJNI.Status.DONE) {
            "Failed to decompress judgements: decoder status ${decompressed.resultStatus}"
        }
        val data: ByteArray = requireNotNull(decompressed.decompressedData) {
            "Failed to decompress judgements: decoder returned no data"
        }
        require(data.size % RECORD_SIZE_BYTES == 0) {
            "Corrupt judgements payload: ${data.size} bytes is not a multiple of $RECORD_SIZE_BYTES"
        }

        val buffer = ByteBuffer.wrap(data)
        val types = CircleguardService.JudgementType.entries
        var lastTime = 0.0

        return List(data.size / RECORD_SIZE_BYTES) {
            lastTime += buffer.short.toInt()
            val x = buffer.getInt() / 1000.0
            val y = buffer.getInt() / 1000.0
            val typeOrdinal = buffer.get().toInt()
            val type = requireNotNull(types.getOrNull(typeOrdinal)) {
                "Corrupt judgements payload: unknown judgement type ordinal $typeOrdinal"
            }
            val distanceCenter = buffer.short.toInt() / 100.0
            val distanceEdge = buffer.short.toInt() / 100.0
            val error = buffer.short.toInt().toDouble()

            CircleguardService.ScoreJudgement(
                time = lastTime,
                x = x,
                y = y,
                type = type,
                distance_center = distanceCenter,
                distance_edge = distanceEdge,
                error = error
            )
        }
    }

    /**
     * Rounds [value] to the nearest integer and clamps it to the `Short` range, so out-of-range
     * values saturate instead of wrapping around as a plain `toInt().toShort()` would.
     */
    private fun toSaturatedShort(value: Double): Short =
        round(value)
            .coerceIn(Short.MIN_VALUE.toDouble(), Short.MAX_VALUE.toDouble())
            .toInt()
            .toShort()

    private companion object {
        /** Encoded size of one judgement: 2 + 4 + 4 + 1 + 2 + 2 + 2 bytes. */
        const val RECORD_SIZE_BYTES = 2 + 4 + 4 + 1 + 2 + 2 + 2
    }
}

View File

@ -4,7 +4,7 @@ spring.datasource.password=${POSTGRES_PASS:postgres}
spring.datasource.driver-class-name=org.postgresql.Driver
spring.datasource.name=HikariPool-PostgreSQL
spring.flyway.enabled=true
spring.flyway.enabled=${FLYWAY_ENABLED:true}
spring.flyway.schemas=public
# Batching

View File

@ -0,0 +1,2 @@
-- Adds a nullable binary column to public.scores for the judgements payload.
ALTER TABLE public.scores
    ADD COLUMN judgements BYTEA NULL;

View File

@ -0,0 +1,81 @@
package com.nisemoe.nise.scheduler
import com.nisemoe.generated.tables.records.ScoresRecord
import com.nisemoe.generated.tables.references.SCORES
import com.nisemoe.nise.database.UserService
import com.nisemoe.nise.integrations.CircleguardService
import com.nisemoe.nise.osu.OsuApi
import com.nisemoe.nise.osu.TokenService
import com.nisemoe.nise.service.AuthService
import com.nisemoe.nise.service.CacheService
import com.nisemoe.nise.service.CompressJudgements
import kotlinx.serialization.json.Json
import org.jooq.DSLContext
import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test
import org.slf4j.LoggerFactory
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.autoconfigure.flyway.FlywayAutoConfiguration
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.boot.test.mock.mockito.MockBean
import org.springframework.context.annotation.Import
import org.springframework.test.context.ActiveProfiles
/**
 * Round-trip check for [CompressJudgements]: replays read from the database are processed into
 * judgements, serialized, deserialized, and compared field by field with the originals.
 *
 * The class is annotated with [Disabled], so it does not run as part of the regular test suite.
 */
@SpringBootTest
@ActiveProfiles("postgres")
@MockBean(GlobalCache::class, UserService::class)
@Disabled
class JudgementCompressionTest {

    private val logger = LoggerFactory.getLogger(javaClass)

    private val circleguardService = CircleguardService()

    private val compressJudgements = CompressJudgements()

    @Autowired
    private lateinit var dslContext: DSLContext

    /**
     * Serializes and deserializes the judgements of up to 100 stored replays and asserts that
     * every field survives the round trip (exactly, or within the tolerances used below).
     */
    @Test
    fun compressionIntegrityTest() {
        val scores = dslContext.select(SCORES.REPLAY, SCORES.BEATMAP_ID, SCORES.MODS)
            .from(SCORES)
            .where(SCORES.REPLAY.isNotNull)
            .limit(100)
            .fetchInto(ScoresRecord::class.java)

        for (score in scores) {
            val result = score.replay?.let {
                this.circleguardService.processReplay(
                    replayData = it.decodeToString(),
                    beatmapId = score.beatmapId!!,
                    mods = score.mods!!
                ).get()
            }

            if (result == null) {
                this.logger.warn("Failed to process replay for beatmap {} with mods {}", score.beatmapId, score.mods)
                continue
            }

            val compressedData = compressJudgements.serialize(result.judgements)

            this.logger.info("JSON size: {} bytes", String.format("%,d", Json.encodeToString(CircleguardService.ReplayResponse.serializer(), result).length))
            this.logger.info("Compressed (Brotli) size: {} bytes", String.format("%,d", compressedData.size))

            val deserializedData = compressJudgements.deserialize(compressedData)

            // Without this check, dropped judgements would surface as an index error (or not at
            // all when the decoded list is longer) instead of a clear assertion failure.
            assertEquals(result.judgements.size, deserializedData.size, "Judgement count changed after round trip")

            // Pair elements by position; looking each entry up with indexOf() is quadratic and
            // resolves duplicate entries to the first match.
            result.judgements.zip(deserializedData).forEach { (expected, actual) ->
                assertEquals(expected.time, actual.time)
                assertEquals(expected.x, actual.x, 0.01)
                assertEquals(expected.y, actual.y, 0.01)
                assertEquals(expected.type, actual.type)
                assertEquals(expected.distance_center, actual.distance_center, 0.1)
                assertEquals(expected.distance_edge, actual.distance_edge, 0.1)
                assertEquals(expected.error, actual.error)
            }
        }
    }
}