Working on compression algorithm for judgements

This commit is contained in:
nise.moe 2024-02-23 16:12:10 +01:00
parent 3af2b64300
commit fbd61e0fa1
10 changed files with 229 additions and 65 deletions

View File

@ -33,6 +33,12 @@
<artifactId>spring-boot-starter-security</artifactId> <artifactId>spring-boot-starter-security</artifactId>
</dependency> </dependency>
<dependency>
<groupId>com.aayushatharva.brotli4j</groupId>
<artifactId>brotli4j</artifactId>
<version>1.16.0</version>
</dependency>
<!-- Test containers --> <!-- Test containers -->
<dependency> <dependency>
<groupId>org.testcontainers</groupId> <groupId>org.testcontainers</groupId>

View File

@ -297,6 +297,11 @@ open class Scores(
*/ */
val SLIDEREND_RELEASE_STANDARD_DEVIATION_ADJUSTED: TableField<ScoresRecord, Double?> = createField(DSL.name("sliderend_release_standard_deviation_adjusted"), SQLDataType.DOUBLE, this, "") val SLIDEREND_RELEASE_STANDARD_DEVIATION_ADJUSTED: TableField<ScoresRecord, Double?> = createField(DSL.name("sliderend_release_standard_deviation_adjusted"), SQLDataType.DOUBLE, this, "")
/**
* The column <code>public.scores.judgements</code>.
*/
val JUDGEMENTS: TableField<ScoresRecord, ByteArray?> = createField(DSL.name("judgements"), SQLDataType.BLOB, this, "")
private constructor(alias: Name, aliased: Table<ScoresRecord>?): this(alias, null, null, aliased, null) private constructor(alias: Name, aliased: Table<ScoresRecord>?): this(alias, null, null, aliased, null)
private constructor(alias: Name, aliased: Table<ScoresRecord>?, parameters: Array<Field<*>?>?): this(alias, null, null, aliased, parameters) private constructor(alias: Name, aliased: Table<ScoresRecord>?, parameters: Array<Field<*>?>?): this(alias, null, null, aliased, parameters)

View File

@ -201,6 +201,10 @@ open class ScoresRecord private constructor() : UpdatableRecordImpl<ScoresRecord
set(value): Unit = set(44, value) set(value): Unit = set(44, value)
get(): Double? = get(44) as Double? get(): Double? = get(44) as Double?
open var judgements: ByteArray?
set(value): Unit = set(45, value)
get(): ByteArray? = get(45) as ByteArray?
// ------------------------------------------------------------------------- // -------------------------------------------------------------------------
// Primary key information // Primary key information
// ------------------------------------------------------------------------- // -------------------------------------------------------------------------
@ -210,7 +214,7 @@ open class ScoresRecord private constructor() : UpdatableRecordImpl<ScoresRecord
/** /**
* Create a detached, initialised ScoresRecord * Create a detached, initialised ScoresRecord
*/ */
constructor(id: Int? = null, beatmapId: Int? = null, count_100: Int? = null, count_300: Int? = null, count_50: Int? = null, countMiss: Int? = null, date: LocalDateTime? = null, maxCombo: Int? = null, mods: Int? = null, perfect: Boolean? = null, pp: Double? = null, rank: String? = null, replayAvailable: Boolean? = null, replayId: Long? = null, score: Long? = null, userId: Long? = null, replay: ByteArray? = null, ur: Double? = null, frametime: Double? = null, edgeHits: Int? = null, snaps: Int? = null, isBanned: Boolean? = null, adjustedUr: Double? = null, meanError: Double? = null, errorVariance: Double? = null, errorStandardDeviation: Double? = null, minimumError: Double? = null, maximumError: Double? = null, errorRange: Double? = null, errorCoefficientOfVariation: Double? = null, errorKurtosis: Double? = null, errorSkewness: Double? = null, sentDiscordNotification: Boolean? = null, addedAt: OffsetDateTime? = null, version: Int? = null, keypressesTimes: Array<Double?>? = null, keypressesMedian: Double? = null, keypressesStandardDeviation: Double? = null, sliderendReleaseTimes: Array<Double?>? = null, sliderendReleaseMedian: Double? = null, sliderendReleaseStandardDeviation: Double? = null, keypressesMedianAdjusted: Double? = null, keypressesStandardDeviationAdjusted: Double? = null, sliderendReleaseMedianAdjusted: Double? = null, sliderendReleaseStandardDeviationAdjusted: Double? = null): this() { constructor(id: Int? = null, beatmapId: Int? = null, count_100: Int? = null, count_300: Int? = null, count_50: Int? = null, countMiss: Int? = null, date: LocalDateTime? = null, maxCombo: Int? = null, mods: Int? = null, perfect: Boolean? = null, pp: Double? = null, rank: String? = null, replayAvailable: Boolean? = null, replayId: Long? = null, score: Long? = null, userId: Long? = null, replay: ByteArray? = null, ur: Double? = null, frametime: Double? = null, edgeHits: Int? = null, snaps: Int? = null, isBanned: Boolean? = null, adjustedUr: Double? = null, meanError: Double? = null, errorVariance: Double? = null, errorStandardDeviation: Double? = null, minimumError: Double? = null, maximumError: Double? = null, errorRange: Double? = null, errorCoefficientOfVariation: Double? = null, errorKurtosis: Double? = null, errorSkewness: Double? = null, sentDiscordNotification: Boolean? = null, addedAt: OffsetDateTime? = null, version: Int? = null, keypressesTimes: Array<Double?>? = null, keypressesMedian: Double? = null, keypressesStandardDeviation: Double? = null, sliderendReleaseTimes: Array<Double?>? = null, sliderendReleaseMedian: Double? = null, sliderendReleaseStandardDeviation: Double? = null, keypressesMedianAdjusted: Double? = null, keypressesStandardDeviationAdjusted: Double? = null, sliderendReleaseMedianAdjusted: Double? = null, sliderendReleaseStandardDeviationAdjusted: Double? = null, judgements: ByteArray? = null): this() {
this.id = id this.id = id
this.beatmapId = beatmapId this.beatmapId = beatmapId
this.count_100 = count_100 this.count_100 = count_100
@ -256,6 +260,7 @@ open class ScoresRecord private constructor() : UpdatableRecordImpl<ScoresRecord
this.keypressesStandardDeviationAdjusted = keypressesStandardDeviationAdjusted this.keypressesStandardDeviationAdjusted = keypressesStandardDeviationAdjusted
this.sliderendReleaseMedianAdjusted = sliderendReleaseMedianAdjusted this.sliderendReleaseMedianAdjusted = sliderendReleaseMedianAdjusted
this.sliderendReleaseStandardDeviationAdjusted = sliderendReleaseStandardDeviationAdjusted this.sliderendReleaseStandardDeviationAdjusted = sliderendReleaseStandardDeviationAdjusted
this.judgements = judgements
resetChangedOnNotNull() resetChangedOnNotNull()
} }
} }

View File

@ -1,14 +1,15 @@
package com.nisemoe.nise.database package com.nisemoe.nise.database
import com.nisemoe.generated.tables.records.ScoresJudgementsRecord import com.nisemoe.generated.tables.records.ScoresJudgementsRecord
import com.nisemoe.generated.tables.records.ScoresRecord
import com.nisemoe.generated.tables.references.* import com.nisemoe.generated.tables.references.*
import com.nisemoe.nise.* import com.nisemoe.nise.*
import com.nisemoe.nise.osu.Mod import com.nisemoe.nise.osu.Mod
import com.nisemoe.nise.service.AuthService import com.nisemoe.nise.service.AuthService
import com.nisemoe.nise.service.CompressJudgements
import org.jooq.Condition import org.jooq.Condition
import org.jooq.DSLContext import org.jooq.DSLContext
import org.jooq.Record import org.jooq.Record
import org.jooq.Result
import org.jooq.impl.DSL import org.jooq.impl.DSL
import org.jooq.impl.DSL.avg import org.jooq.impl.DSL.avg
import org.springframework.stereotype.Service import org.springframework.stereotype.Service
@ -19,7 +20,8 @@ import kotlin.math.roundToInt
class ScoreService( class ScoreService(
private val dslContext: DSLContext, private val dslContext: DSLContext,
private val beatmapService: BeatmapService, private val beatmapService: BeatmapService,
private val authService: AuthService private val authService: AuthService,
private val compressJudgements: CompressJudgements
) { ) {
companion object { companion object {
@ -367,17 +369,22 @@ class ScoreService(
} }
fun getHitDistribution(scoreId: Int): Map<Int, DistributionEntry> { fun getHitDistribution(scoreId: Int): Map<Int, DistributionEntry> {
val judgements = dslContext.selectFrom(SCORES_JUDGEMENTS) val judgementsRecord = dslContext.select(SCORES.JUDGEMENTS)
.where(SCORES_JUDGEMENTS.SCORE_ID.eq(scoreId)) .from(SCORES)
.fetchInto(ScoresJudgementsRecord::class.java) .where(SCORES.ID.eq(scoreId))
.fetchOneInto(ScoresRecord::class.java) ?: return emptyMap()
if(judgementsRecord.judgements == null) return emptyMap()
val judgements = compressJudgements.deserialize(judgementsRecord.judgements!!)
val errorDistribution = mutableMapOf<Int, MutableMap<String, Int>>() val errorDistribution = mutableMapOf<Int, MutableMap<String, Int>>()
var totalHits = 0 var totalHits = 0
judgements.forEach { hit -> judgements.forEach { hit ->
val error = (hit.error!!.roundToInt() / 2) * 2 val error = (hit.error.roundToInt() / 2) * 2
val judgementType = hit.type // Assuming this is how you get the judgement type val judgementType = hit.type // Assuming this is how you get the judgement type
errorDistribution.getOrPut(error) { mutableMapOf("Miss" to 0, "300" to 0, "100" to 0, "50" to 0) } errorDistribution.getOrPut(error) { mutableMapOf("MISS" to 0, "THREE_HUNDRED" to 0, "ONE_HUNDRED" to 0, "FIFTY" to 0) }
.apply { .apply {
this[judgementType.toString()] = this.getOrDefault(judgementType.toString(), 0) + 1 this[judgementType.toString()] = this.getOrDefault(judgementType.toString(), 0) + 1
} }
@ -387,10 +394,10 @@ class ScoreService(
return errorDistribution.mapValues { (_, judgementCounts) -> return errorDistribution.mapValues { (_, judgementCounts) ->
judgementCounts.values.sum() judgementCounts.values.sum()
DistributionEntry( DistributionEntry(
percentageMiss = (judgementCounts.getOrDefault("Miss", 0).toDouble() / totalHits) * 100, percentageMiss = (judgementCounts.getOrDefault("MISS", 0).toDouble() / totalHits) * 100,
percentage300 = (judgementCounts.getOrDefault("300", 0).toDouble() / totalHits) * 100, percentage300 = (judgementCounts.getOrDefault("THREE_HUNDRED", 0).toDouble() / totalHits) * 100,
percentage100 = (judgementCounts.getOrDefault("100", 0).toDouble() / totalHits) * 100, percentage100 = (judgementCounts.getOrDefault("ONE_HUNDRED", 0).toDouble() / totalHits) * 100,
percentage50 = (judgementCounts.getOrDefault("50", 0).toDouble() / totalHits) * 100 percentage50 = (judgementCounts.getOrDefault("FIFTY", 0).toDouble() / totalHits) * 100
) )
} }
} }

View File

@ -2,9 +2,8 @@ package com.nisemoe.nise.scheduler
import com.nisemoe.generated.tables.records.ScoresRecord import com.nisemoe.generated.tables.records.ScoresRecord
import com.nisemoe.generated.tables.references.SCORES import com.nisemoe.generated.tables.references.SCORES
import com.nisemoe.generated.tables.references.SCORES_JUDGEMENTS
import com.nisemoe.nise.integrations.CircleguardService import com.nisemoe.nise.integrations.CircleguardService
import com.nisemoe.nise.Format.Companion.fromJudgementType import com.nisemoe.nise.service.CompressJudgements
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.joinAll import kotlinx.coroutines.joinAll
@ -16,23 +15,27 @@ import org.springframework.beans.factory.annotation.Value
import org.springframework.context.annotation.Profile import org.springframework.context.annotation.Profile
import org.springframework.scheduling.annotation.Scheduled import org.springframework.scheduling.annotation.Scheduled
import org.springframework.stereotype.Service import org.springframework.stereotype.Service
import java.time.OffsetDateTime
@Profile("old_scores") @Profile("old_scores")
@Service @Service
class FixOldScores( class FixOldScores(
private val dslContext: DSLContext, private val dslContext: DSLContext,
private val circleguardService: CircleguardService private val circleguardService: CircleguardService,
private val compressJudgements: CompressJudgements
){ ){
companion object {
const val CURRENT_VERSION = 6
}
@Value("\${OLD_SCORES_WORKERS:4}") @Value("\${OLD_SCORES_WORKERS:4}")
private var workers: Int = 4 private var workers: Int = 4
@Value("\${OLD_SCORES_PAGE_SIZE:5000}") @Value("\${OLD_SCORES_PAGE_SIZE:5000}")
private var pageSize: Int = 5000 private var pageSize: Int = 5000
val CURRENT_VERSION = 5
private val logger = LoggerFactory.getLogger(javaClass) private val logger = LoggerFactory.getLogger(javaClass)
data class Task(val offset: Int, val limit: Int) data class Task(val offset: Int, val limit: Int)
@ -139,6 +142,7 @@ class FixOldScores(
.set(SCORES.SLIDEREND_RELEASE_MEDIAN_ADJUSTED, processedReplay.sliderend_release_median_adjusted) .set(SCORES.SLIDEREND_RELEASE_MEDIAN_ADJUSTED, processedReplay.sliderend_release_median_adjusted)
.set(SCORES.SLIDEREND_RELEASE_STANDARD_DEVIATION, processedReplay.sliderend_release_standard_deviation) .set(SCORES.SLIDEREND_RELEASE_STANDARD_DEVIATION, processedReplay.sliderend_release_standard_deviation)
.set(SCORES.SLIDEREND_RELEASE_STANDARD_DEVIATION_ADJUSTED, processedReplay.sliderend_release_standard_deviation_adjusted) .set(SCORES.SLIDEREND_RELEASE_STANDARD_DEVIATION_ADJUSTED, processedReplay.sliderend_release_standard_deviation_adjusted)
.set(SCORES.JUDGEMENTS, compressJudgements.serialize(processedReplay.judgements))
.where(SCORES.REPLAY_ID.eq(score.replayId)) .where(SCORES.REPLAY_ID.eq(score.replayId))
.returningResult(SCORES.ID) .returningResult(SCORES.ID)
.fetchOne()?.getValue(SCORES.ID) .fetchOne()?.getValue(SCORES.ID)
@ -152,22 +156,6 @@ class FixOldScores(
.set(SCORES.VERSION, CURRENT_VERSION) .set(SCORES.VERSION, CURRENT_VERSION)
.where(SCORES.ID.eq(scoreId)) .where(SCORES.ID.eq(scoreId))
.execute() .execute()
val judgementsExist = dslContext.fetchExists(SCORES_JUDGEMENTS, SCORES_JUDGEMENTS.SCORE_ID.eq(scoreId))
if(!judgementsExist) {
for (judgement in processedReplay.judgements) {
dslContext.insertInto(SCORES_JUDGEMENTS)
.set(SCORES_JUDGEMENTS.TIME, judgement.time)
.set(SCORES_JUDGEMENTS.X, judgement.x)
.set(SCORES_JUDGEMENTS.Y, judgement.y)
.set(SCORES_JUDGEMENTS.TYPE, fromJudgementType(judgement.type))
.set(SCORES_JUDGEMENTS.DISTANCE_EDGE, judgement.distance_edge)
.set(SCORES_JUDGEMENTS.DISTANCE_CENTER, judgement.distance_center)
.set(SCORES_JUDGEMENTS.ERROR, judgement.error)
.set(SCORES_JUDGEMENTS.SCORE_ID, scoreId)
.execute()
}
}
} }
} }

View File

@ -5,7 +5,6 @@ import com.nisemoe.generated.tables.references.*
import com.nisemoe.konata.Replay import com.nisemoe.konata.Replay
import com.nisemoe.konata.ReplaySetComparison import com.nisemoe.konata.ReplaySetComparison
import com.nisemoe.konata.compareReplaySet import com.nisemoe.konata.compareReplaySet
import com.nisemoe.nise.Format.Companion.fromJudgementType
import com.nisemoe.nise.UserQueueDetails import com.nisemoe.nise.UserQueueDetails
import com.nisemoe.nise.database.ScoreService import com.nisemoe.nise.database.ScoreService
import com.nisemoe.nise.database.UserService import com.nisemoe.nise.database.UserService
@ -16,6 +15,7 @@ import com.nisemoe.nise.osu.Mod
import com.nisemoe.nise.osu.OsuApi import com.nisemoe.nise.osu.OsuApi
import com.nisemoe.nise.osu.OsuApiModels import com.nisemoe.nise.osu.OsuApiModels
import com.nisemoe.nise.service.CacheService import com.nisemoe.nise.service.CacheService
import com.nisemoe.nise.service.CompressJudgements
import com.nisemoe.nise.service.UpdateUserQueueService import com.nisemoe.nise.service.UpdateUserQueueService
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import org.jooq.DSLContext import org.jooq.DSLContext
@ -48,7 +48,8 @@ class ImportScores(
private val scoreService: ScoreService, private val scoreService: ScoreService,
private val updateUserQueueService: UpdateUserQueueService, private val updateUserQueueService: UpdateUserQueueService,
private val circleguardService: CircleguardService, private val circleguardService: CircleguardService,
private val messagingTemplate: SimpMessagingTemplate private val messagingTemplate: SimpMessagingTemplate,
private val compressJudgements: CompressJudgements
) : InitializingBean { ) : InitializingBean {
private val userToUpdateBucket = mutableListOf<Long>() private val userToUpdateBucket = mutableListOf<Long>()
@ -66,17 +67,20 @@ class ImportScores(
} }
} }
val CURRENT_VERSION = 5 companion object {
const val CURRENT_VERSION = 6
const val SLEEP_AFTER_API_CALL = 500L
const val UPDATE_USER_EVERY_DAYS = 7L
const val UPDATE_BANNED_USERS_EVERY_DAYS = 3L
}
@Value("\${WEBHOOK_URL}") @Value("\${WEBHOOK_URL}")
private lateinit var webhookUrl: String private lateinit var webhookUrl: String
private val logger = LoggerFactory.getLogger(javaClass) private val logger = LoggerFactory.getLogger(javaClass)
private final val sleepTimeInMs = 500L
private final val UPDATE_USER_EVERY_DAYS = 7L
private final val UPDATE_BANNED_USERS_EVERY_DAYS = 3L
data class UpdaterStatistics( data class UpdaterStatistics(
var currentBeatmapsetPage: Int = 0, var currentBeatmapsetPage: Int = 0,
@ -165,7 +169,7 @@ class ImportScores(
this.logger.info("Recent scores: ${recentUserScores?.size}") this.logger.info("Recent scores: ${recentUserScores?.size}")
this.logger.info("First place scores: ${firstPlaceUserScores?.size}") this.logger.info("First place scores: ${firstPlaceUserScores?.size}")
Thread.sleep(this.sleepTimeInMs) Thread.sleep(SLEEP_AFTER_API_CALL)
if(topUserScores == null || recentUserScores == null || firstPlaceUserScores == null) { if(topUserScores == null || recentUserScores == null || firstPlaceUserScores == null) {
this.logger.error("Failed to fetch top scores for user with id = $userId") this.logger.error("Failed to fetch top scores for user with id = $userId")
@ -297,7 +301,7 @@ class ImportScores(
.where(SCORES.USER_ID.eq(userId)) .where(SCORES.USER_ID.eq(userId))
.execute() .execute()
} }
Thread.sleep(this.sleepTimeInMs) Thread.sleep(SLEEP_AFTER_API_CALL)
} }
this.cacheService.setVariable("lastBannedUserCheck", LocalDateTime.now()) this.cacheService.setVariable("lastBannedUserCheck", LocalDateTime.now())
@ -313,11 +317,11 @@ class ImportScores(
do { do {
val searchResults = this.osuApi.searchBeatmapsets(cursor = cursor) val searchResults = this.osuApi.searchBeatmapsets(cursor = cursor)
this.statistics.currentBeatmapsetPage++ this.statistics.currentBeatmapsetPage++
Thread.sleep(this.sleepTimeInMs) Thread.sleep(SLEEP_AFTER_API_CALL)
if (searchResults == null) { if (searchResults == null) {
this.logger.error("Failed to fetch beatmapsets. Skipping to next page...") this.logger.error("Failed to fetch beatmapsets. Skipping to next page...")
Thread.sleep(this.sleepTimeInMs * 2) Thread.sleep(SLEEP_AFTER_API_CALL * 2)
return return
} }
@ -365,11 +369,11 @@ class ImportScores(
} }
val beatmapScores = this.osuApi.getTopBeatmapScores(beatmapId = beatmap.id) val beatmapScores = this.osuApi.getTopBeatmapScores(beatmapId = beatmap.id)
Thread.sleep(this.sleepTimeInMs) Thread.sleep(SLEEP_AFTER_API_CALL)
if (beatmapScores == null) { if (beatmapScores == null) {
this.logger.error("Failed to fetch beatmap scores for beatmapId = ${beatmap.id}") this.logger.error("Failed to fetch beatmap scores for beatmapId = ${beatmap.id}")
Thread.sleep(this.sleepTimeInMs * 2) Thread.sleep(SLEEP_AFTER_API_CALL * 2)
continue continue
} }
@ -394,7 +398,7 @@ class ImportScores(
if(this.userToUpdateBucket.size >= 50) { if(this.userToUpdateBucket.size >= 50) {
val usersBucket = this.osuApi.getUsersBatch(this.userToUpdateBucket.toList()) val usersBucket = this.osuApi.getUsersBatch(this.userToUpdateBucket.toList())
Thread.sleep(this.sleepTimeInMs) Thread.sleep(SLEEP_AFTER_API_CALL)
if(usersBucket == null) { if(usersBucket == null) {
this.logger.error("Failed to fetch users batch.") this.logger.error("Failed to fetch users batch.")
continue continue
@ -648,6 +652,7 @@ class ImportScores(
.set(SCORES.SLIDEREND_RELEASE_MEDIAN_ADJUSTED, processedReplay.sliderend_release_median_adjusted) .set(SCORES.SLIDEREND_RELEASE_MEDIAN_ADJUSTED, processedReplay.sliderend_release_median_adjusted)
.set(SCORES.SLIDEREND_RELEASE_STANDARD_DEVIATION, processedReplay.sliderend_release_standard_deviation) .set(SCORES.SLIDEREND_RELEASE_STANDARD_DEVIATION, processedReplay.sliderend_release_standard_deviation)
.set(SCORES.SLIDEREND_RELEASE_STANDARD_DEVIATION_ADJUSTED, processedReplay.sliderend_release_standard_deviation_adjusted) .set(SCORES.SLIDEREND_RELEASE_STANDARD_DEVIATION_ADJUSTED, processedReplay.sliderend_release_standard_deviation_adjusted)
.set(SCORES.JUDGEMENTS, compressJudgements.serialize(processedReplay.judgements))
.where(SCORES.REPLAY_ID.eq(score.best_id)) .where(SCORES.REPLAY_ID.eq(score.best_id))
.returningResult(SCORES.ID) .returningResult(SCORES.ID)
.fetchOne()?.getValue(SCORES.ID) .fetchOne()?.getValue(SCORES.ID)
@ -683,19 +688,6 @@ class ImportScores(
this.updateUserQueueService.insertUser(score.user_id) this.updateUserQueueService.insertUser(score.user_id)
} }
} }
for (judgement in processedReplay.judgements) {
dslContext.insertInto(SCORES_JUDGEMENTS)
.set(SCORES_JUDGEMENTS.TIME, judgement.time)
.set(SCORES_JUDGEMENTS.X, judgement.x)
.set(SCORES_JUDGEMENTS.Y, judgement.y)
.set(SCORES_JUDGEMENTS.TYPE, fromJudgementType(judgement.type))
.set(SCORES_JUDGEMENTS.DISTANCE_EDGE, judgement.distance_edge)
.set(SCORES_JUDGEMENTS.DISTANCE_CENTER, judgement.distance_center)
.set(SCORES_JUDGEMENTS.ERROR, judgement.error)
.set(SCORES_JUDGEMENTS.SCORE_ID, scoreId)
.execute()
}
} }
} }

View File

@ -0,0 +1,78 @@
package com.nisemoe.nise.service
import com.aayushatharva.brotli4j.Brotli4jLoader
import com.aayushatharva.brotli4j.decoder.Decoder
import com.aayushatharva.brotli4j.encoder.Encoder
import com.nisemoe.nise.integrations.CircleguardService
import org.springframework.stereotype.Service
import java.nio.ByteBuffer
import kotlin.math.round
@Service
class CompressJudgements {
val brotliParameters: Encoder.Parameters = Encoder.Parameters()
.setQuality(11)
init {
Brotli4jLoader.ensureAvailability()
}
fun serialize(judgements: List<CircleguardService.ScoreJudgement>): ByteArray {
val buffer = ByteBuffer.allocate(judgements.size * (2 + 4 + 4 + 1 + 2 + 2 + 2))
var lastTime = 0.0
judgements.forEach { judgement ->
val deltaTime = (judgement.time - lastTime).toInt()
buffer.putShort(deltaTime.toShort())
buffer.putInt((round(judgement.x * 1000)).toInt())
buffer.putInt((round(judgement.y * 1000)).toInt())
buffer.put(judgement.type.ordinal.toByte())
buffer.putShort((judgement.distance_center * 100).toInt().toShort())
buffer.putShort((judgement.distance_edge * 100).toInt().toShort())
buffer.putShort(judgement.error.toInt().toShort())
lastTime = judgement.time
}
return Encoder.compress(buffer.array(), brotliParameters)
}
fun deserialize(compressedData: ByteArray): List<CircleguardService.ScoreJudgement> {
val data = Decoder.decompress(compressedData).decompressedData
val buffer = ByteBuffer.wrap(data)
val judgements = mutableListOf<CircleguardService.ScoreJudgement>()
var lastTime = 0.0
while (buffer.hasRemaining()) {
val deltaTime = buffer.short.toInt()
lastTime += deltaTime
val deltaX = buffer.getInt()
val deltaY = buffer.getInt()
val typeOrdinal = buffer.get().toInt()
val type = CircleguardService.JudgementType.entries[typeOrdinal]
val distanceCenter = buffer.short.toInt() / 100.0
val distanceEdge = buffer.short.toInt() / 100.0
val error = buffer.short.toInt()
judgements.add(
CircleguardService.ScoreJudgement(
time = lastTime,
x = deltaX / 1000.0,
y = deltaY / 1000.0,
type = type,
distance_center = distanceCenter,
distance_edge = distanceEdge,
error = error.toDouble()
))
}
return judgements
}
}

View File

@ -4,7 +4,7 @@ spring.datasource.password=${POSTGRES_PASS:postgres}
spring.datasource.driver-class-name=org.postgresql.Driver spring.datasource.driver-class-name=org.postgresql.Driver
spring.datasource.name=HikariPool-PostgreSQL spring.datasource.name=HikariPool-PostgreSQL
spring.flyway.enabled=true spring.flyway.enabled=${FLYWAY_ENABLED:true}
spring.flyway.schemas=public spring.flyway.schemas=public
# Batching # Batching

View File

@ -0,0 +1,2 @@
ALTER TABLE public.scores
ADD COLUMN judgements bytea null;

View File

@ -0,0 +1,81 @@
package com.nisemoe.nise.scheduler
import com.nisemoe.generated.tables.records.ScoresRecord
import com.nisemoe.generated.tables.references.SCORES
import com.nisemoe.nise.database.UserService
import com.nisemoe.nise.integrations.CircleguardService
import com.nisemoe.nise.osu.OsuApi
import com.nisemoe.nise.osu.TokenService
import com.nisemoe.nise.service.AuthService
import com.nisemoe.nise.service.CacheService
import com.nisemoe.nise.service.CompressJudgements
import kotlinx.serialization.json.Json
import org.jooq.DSLContext
import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test
import org.slf4j.LoggerFactory
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.autoconfigure.flyway.FlywayAutoConfiguration
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.boot.test.mock.mockito.MockBean
import org.springframework.context.annotation.Import
import org.springframework.test.context.ActiveProfiles
@SpringBootTest
@ActiveProfiles("postgres")
@MockBean(GlobalCache::class, UserService::class)
@Disabled
class JudgementCompressionTest {
private val logger = LoggerFactory.getLogger(javaClass)
private val circleguardService = CircleguardService()
private val compressJudgements = CompressJudgements()
@Autowired
private lateinit var dslContext: DSLContext
@Test
fun compressionIntegrityTest() {
val scores = dslContext.select(SCORES.REPLAY, SCORES.BEATMAP_ID, SCORES.MODS)
.from(SCORES)
.where(SCORES.REPLAY.isNotNull)
.limit(100)
.fetchInto(ScoresRecord::class.java)
for(score in scores) {
val result = score.replay?.let {
this.circleguardService.processReplay(
replayData = it.decodeToString(),
beatmapId = score.beatmapId!!,
mods = score.mods!!
).get()
}
if(result == null) {
this.logger.warn("Failed to process replay for beatmap {} with mods {}", score.beatmapId, score.mods)
continue
}
val compressedData = compressJudgements.serialize(result.judgements)
this.logger.info("JSON size: {} bytes", String.format("%,d", Json.encodeToString(CircleguardService.ReplayResponse.serializer(), result).length))
this.logger.info("Compressed (Brotli) size: {} bytes", String.format("%,d", compressedData.size))
val deserializedData = compressJudgements.deserialize(compressedData)
for (entry in result.judgements) {
assertEquals(entry.time, deserializedData[result.judgements.indexOf(entry)].time)
assertEquals(entry.x, deserializedData[result.judgements.indexOf(entry)].x, 0.01)
assertEquals(entry.y, deserializedData[result.judgements.indexOf(entry)].y, 0.01)
assertEquals(entry.type, deserializedData[result.judgements.indexOf(entry)].type)
assertEquals(entry.distance_center, deserializedData[result.judgements.indexOf(entry)].distance_center, 0.1)
assertEquals(entry.distance_edge, deserializedData[result.judgements.indexOf(entry)].distance_edge, 0.1)
assertEquals(entry.error, deserializedData[result.judgements.indexOf(entry)].error)
}
}
}
}