Provide a better WTC interface

This commit is contained in:
Stedoss 2025-02-23 20:46:58 +00:00
parent cc10fa221e
commit 3ea6c2e8e4
2 changed files with 263 additions and 258 deletions

View File

@ -8,11 +8,15 @@ import kotlin.math.round
// JVM implementation of https://github.com/circleguard/wtc-lzma-compressor/tree/master // JVM implementation of https://github.com/circleguard/wtc-lzma-compressor/tree/master
private const val CURRENT_VERSION_HEADER: Short = 1
private val VERSION_HEADER_BYTE_ARRAY = byteArrayOf((CURRENT_VERSION_HEADER.toInt() and 0xFF).toByte(), ((CURRENT_VERSION_HEADER.toInt() shr 8) and 0xFF).toByte())
fun wtcCompress(stream: String): ByteArray { class WTC {
companion object {
private const val CURRENT_VERSION_HEADER: Short = 1
private val VERSION_HEADER_BYTE_ARRAY = byteArrayOf((CURRENT_VERSION_HEADER.toInt() and 0xFF).toByte(), ((CURRENT_VERSION_HEADER.toInt() shr 8) and 0xFF).toByte())
fun compress(stream: String): ByteArray {
val lists = seperate(stream) val lists = seperate(stream)
val xs = unsortedDiffPackShortsToBytes(lists.x) val xs = unsortedDiffPackShortsToBytes(lists.x)
@ -42,9 +46,9 @@ fun wtcCompress(stream: String): ByteArray {
byteStream.writeBytes(packBytes(ws)) byteStream.writeBytes(packBytes(ws))
return byteStream.toByteArray() return byteStream.toByteArray()
} }
fun wtcDecompress(data: ByteArray, hasVersionHeader: Boolean = true): String { fun decompress(data: ByteArray, hasVersionHeader: Boolean = true): String {
val buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN) val buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN)
fun unpackBytes(): ByteArray { fun unpackBytes(): ByteArray {
@ -76,16 +80,16 @@ fun wtcDecompress(data: ByteArray, hasVersionHeader: Boolean = true): String {
z = zs, z = zs,
w = wws, w = wws,
)) ))
} }
data class FrameLists( data class FrameLists(
val x: ShortArray, val x: ShortArray,
val y: ShortArray, val y: ShortArray,
val z: ByteArray, val z: ByteArray,
val w: IntArray, val w: IntArray,
) )
private fun unsortedDiffPackShortsToBytes(shorts: ShortArray): ByteArray { private fun unsortedDiffPackShortsToBytes(shorts: ShortArray): ByteArray {
val start = shorts.first() val start = shorts.first()
val diff = arrayDiff(shorts) val diff = arrayDiff(shorts)
val packed = mutableListOf<Byte>() val packed = mutableListOf<Byte>()
@ -107,9 +111,9 @@ private fun unsortedDiffPackShortsToBytes(shorts: ShortArray): ByteArray {
} }
return packed.toByteArray() return packed.toByteArray()
} }
private fun unsortedDiffUnpackBytesToShorts(int8s: ByteArray): ShortArray { private fun unsortedDiffUnpackBytesToShorts(int8s: ByteArray): ShortArray {
val decoded = mutableListOf<Short>() val decoded = mutableListOf<Short>()
var i = 0 var i = 0
@ -131,9 +135,9 @@ private fun unsortedDiffUnpackBytesToShorts(int8s: ByteArray): ShortArray {
} }
return cumSum(decoded.toShortArray()) return cumSum(decoded.toShortArray())
} }
private fun packIntsToBytes(int32s: IntArray): ByteArray { private fun packIntsToBytes(int32s: IntArray): ByteArray {
val packed = mutableListOf<Byte>() val packed = mutableListOf<Byte>()
for (dw in int32s) { for (dw in int32s) {
@ -154,9 +158,9 @@ private fun packIntsToBytes(int32s: IntArray): ByteArray {
} }
return packed.toByteArray() return packed.toByteArray()
} }
private fun unpackBytesToInts(int8s: ByteArray): IntArray { private fun unpackBytesToInts(int8s: ByteArray): IntArray {
val unpacked = mutableListOf<Int>() val unpacked = mutableListOf<Int>()
var i = 0 var i = 0
@ -182,9 +186,9 @@ private fun unpackBytesToInts(int8s: ByteArray): IntArray {
} }
return unpacked.toIntArray() return unpacked.toIntArray()
} }
private fun seperate(stream: String): FrameLists { private fun seperate(stream: String): FrameLists {
val wList = mutableListOf<Int>() val wList = mutableListOf<Int>()
val xList = mutableListOf<Short>() val xList = mutableListOf<Short>()
val yList = mutableListOf<Short>() val yList = mutableListOf<Short>()
@ -224,9 +228,9 @@ private fun seperate(stream: String): FrameLists {
z = zList.toByteArray(), z = zList.toByteArray(),
w = wList.toIntArray(), w = wList.toIntArray(),
) )
} }
private fun combine(lists: FrameLists): String { private fun combine(lists: FrameLists): String {
val xArr = lists.x.map { it.toFloat() / 16 } val xArr = lists.x.map { it.toFloat() / 16 }
val yArr = lists.y.map { it.toFloat() / 16 } val yArr = lists.y.map { it.toFloat() / 16 }
@ -242,9 +246,9 @@ private fun combine(lists: FrameLists): String {
} }
return frames.joinToString(",") return frames.joinToString(",")
} }
private fun arrayDiff(arr: ShortArray): ShortArray { private fun arrayDiff(arr: ShortArray): ShortArray {
if (arr.isEmpty()) { if (arr.isEmpty()) {
return emptyArray<Short>().toShortArray() return emptyArray<Short>().toShortArray()
} }
@ -256,9 +260,9 @@ private fun arrayDiff(arr: ShortArray): ShortArray {
} }
return diffed return diffed
} }
private fun cumSum(arr: ShortArray): ShortArray { private fun cumSum(arr: ShortArray): ShortArray {
if (arr.isEmpty()) { if (arr.isEmpty()) {
return emptyArray<Short>().toShortArray() return emptyArray<Short>().toShortArray()
} }
@ -271,4 +275,6 @@ private fun cumSum(arr: ShortArray): ShortArray {
} }
return cumArr return cumArr
}
}
} }

View File

@ -1,7 +1,6 @@
package com.nisemoe.nise.osu package com.nisemoe.nise.osu
import com.nisemoe.nise.replays.wtcCompress import com.nisemoe.nise.replays.WTC
import com.nisemoe.nise.replays.wtcDecompress
import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource import org.junit.jupiter.params.provider.ValueSource
import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Assertions
@ -21,7 +20,7 @@ class WtcTest {
val expected = resourcesPath.resolve("${replayName}_compressed.dat").toFile().readBytes() val expected = resourcesPath.resolve("${replayName}_compressed.dat").toFile().readBytes()
val replayEvents = resourcesPath.resolve("${replayName}_events.txt").toFile().readText() val replayEvents = resourcesPath.resolve("${replayName}_events.txt").toFile().readText()
val wtcCompressed = wtcCompress(replayEvents) val wtcCompressed = WTC.compress(replayEvents)
// We include a version header at the start of the compressed byte array - create a new array excluding these // We include a version header at the start of the compressed byte array - create a new array excluding these
// so we can compare just the raw data with the Python WTC implementation's output. // so we can compare just the raw data with the Python WTC implementation's output.
@ -37,7 +36,7 @@ class WtcTest {
val expected = resourcesPath.resolve("${replayName}_decompressed.txt").toFile().readText() val expected = resourcesPath.resolve("${replayName}_decompressed.txt").toFile().readText()
val compressedReplay = resourcesPath.resolve("${replayName}_compressed.dat").toFile().readBytes() val compressedReplay = resourcesPath.resolve("${replayName}_compressed.dat").toFile().readBytes()
val wtcDecompressed = wtcDecompress(compressedReplay, false) val wtcDecompressed = WTC.decompress(compressedReplay, false)
assertEquals(expected, wtcDecompressed) assertEquals(expected, wtcDecompressed)
} }