Provide a better WTC interface

This commit is contained in:
Stedoss 2025-02-23 20:46:58 +00:00
parent cc10fa221e
commit 3ea6c2e8e4
2 changed files with 263 additions and 258 deletions

View File

@ -8,267 +8,273 @@ import kotlin.math.round
// JVM implementation of https://github.com/circleguard/wtc-lzma-compressor/tree/master // JVM implementation of https://github.com/circleguard/wtc-lzma-compressor/tree/master
private const val CURRENT_VERSION_HEADER: Short = 1
private val VERSION_HEADER_BYTE_ARRAY = byteArrayOf((CURRENT_VERSION_HEADER.toInt() and 0xFF).toByte(), ((CURRENT_VERSION_HEADER.toInt() shr 8) and 0xFF).toByte())
fun wtcCompress(stream: String): ByteArray { class WTC {
val lists = seperate(stream) companion object {
private const val CURRENT_VERSION_HEADER: Short = 1
val xs = unsortedDiffPackShortsToBytes(lists.x) private val VERSION_HEADER_BYTE_ARRAY = byteArrayOf((CURRENT_VERSION_HEADER.toInt() and 0xFF).toByte(), ((CURRENT_VERSION_HEADER.toInt() shr 8) and 0xFF).toByte())
val ys = unsortedDiffPackShortsToBytes(lists.y)
val ws = packIntsToBytes(lists.w) fun compress(stream: String): ByteArray {
val zs = lists.z val lists = seperate(stream)
fun packBytes(arr: ByteArray): ByteArray { val xs = unsortedDiffPackShortsToBytes(lists.x)
val length = arr.size val ys = unsortedDiffPackShortsToBytes(lists.y)
val buffer = ByteBuffer.allocate(4 + length).order(ByteOrder.LITTLE_ENDIAN)
buffer.putInt(length) val ws = packIntsToBytes(lists.w)
for (byte in arr) { val zs = lists.z
buffer.put(byte)
fun packBytes(arr: ByteArray): ByteArray {
val length = arr.size
val buffer = ByteBuffer.allocate(4 + length).order(ByteOrder.LITTLE_ENDIAN)
buffer.putInt(length)
for (byte in arr) {
buffer.put(byte)
}
return buffer.array()
}
val byteStream = ByteArrayOutputStream()
byteStream.writeBytes(VERSION_HEADER_BYTE_ARRAY)
byteStream.writeBytes(packBytes(xs))
byteStream.writeBytes(packBytes(ys))
byteStream.writeBytes(packBytes(zs))
byteStream.writeBytes(packBytes(ws))
return byteStream.toByteArray()
} }
return buffer.array() fun decompress(data: ByteArray, hasVersionHeader: Boolean = true): String {
} val buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN)
val byteStream = ByteArrayOutputStream() fun unpackBytes(): ByteArray {
val size = buffer.getInt()
byteStream.writeBytes(VERSION_HEADER_BYTE_ARRAY) val bytes = ByteArray(size)
byteStream.writeBytes(packBytes(xs)) buffer.get(bytes, 0, size)
byteStream.writeBytes(packBytes(ys))
byteStream.writeBytes(packBytes(zs))
byteStream.writeBytes(packBytes(ws))
return byteStream.toByteArray() return bytes
} }
fun wtcDecompress(data: ByteArray, hasVersionHeader: Boolean = true): String { if (hasVersionHeader) {
val buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN) buffer.getShort() // Version - may be used in the future
}
fun unpackBytes(): ByteArray { val xs = unpackBytes()
val size = buffer.getInt() val ys = unpackBytes()
val zs = unpackBytes()
val ws = unpackBytes()
val bytes = ByteArray(size) val xxs = unsortedDiffUnpackBytesToShorts(xs)
buffer.get(bytes, 0, size) val yys = unsortedDiffUnpackBytesToShorts(ys)
return bytes val wws = unpackBytesToInts(ws)
}
if (hasVersionHeader) { return combine(FrameLists(
buffer.getShort() // Version - may be used in the future x = xxs,
} y = yys,
z = zs,
val xs = unpackBytes() w = wws,
val ys = unpackBytes() ))
val zs = unpackBytes()
val ws = unpackBytes()
val xxs = unsortedDiffUnpackBytesToShorts(xs)
val yys = unsortedDiffUnpackBytesToShorts(ys)
val wws = unpackBytesToInts(ws)
return combine(FrameLists(
x = xxs,
y = yys,
z = zs,
w = wws,
))
}
data class FrameLists(
val x: ShortArray,
val y: ShortArray,
val z: ByteArray,
val w: IntArray,
)
private fun unsortedDiffPackShortsToBytes(shorts: ShortArray): ByteArray {
val start = shorts.first()
val diff = arrayDiff(shorts)
val packed = mutableListOf<Byte>()
fun pack(word: Short) {
if (abs(word.toInt()) <= Byte.MAX_VALUE) {
packed.add(word.toByte())
}
else {
packed.add(Byte.MIN_VALUE)
packed.add((word.toInt() and 0xFF).toByte())
packed.add((word.toInt() shr 8).toByte())
}
}
pack(start)
for (word in diff) {
pack(word)
}
return packed.toByteArray()
}
private fun unsortedDiffUnpackBytesToShorts(int8s: ByteArray): ShortArray {
val decoded = mutableListOf<Short>()
var i = 0
while (i < int8s.size) {
val byte = int8s[i]
if (byte == Byte.MIN_VALUE) {
i++
var word = int8s[i].toInt() and 0xFF
i++
word += int8s[i].toInt() shl 8
decoded.add(word.toShort())
}
else {
decoded.add(byte.toShort())
} }
i++ data class FrameLists(
} val x: ShortArray,
val y: ShortArray,
val z: ByteArray,
val w: IntArray,
)
return cumSum(decoded.toShortArray()) private fun unsortedDiffPackShortsToBytes(shorts: ShortArray): ByteArray {
} val start = shorts.first()
val diff = arrayDiff(shorts)
val packed = mutableListOf<Byte>()
private fun packIntsToBytes(int32s: IntArray): ByteArray { fun pack(word: Short) {
val packed = mutableListOf<Byte>() if (abs(word.toInt()) <= Byte.MAX_VALUE) {
packed.add(word.toByte())
}
else {
packed.add(Byte.MIN_VALUE)
packed.add((word.toInt() and 0xFF).toByte())
packed.add((word.toInt() shr 8).toByte())
}
}
for (dw in int32s) { pack(start)
var dword = dw for (word in diff) {
if (abs(dword) <= Byte.MAX_VALUE) { pack(word)
packed.add(dword.toByte()) }
}
else {
packed.add(Byte.MIN_VALUE)
packed.add((dword and 0xFF).toByte())
dword = dword shr 8
packed.add((dword and 0xFF).toByte())
dword = dword shr 8
packed.add((dword and 0xFF).toByte())
dword = dword shr 8
packed.add(dword.toByte())
}
}
return packed.toByteArray() return packed.toByteArray()
}
private fun unpackBytesToInts(int8s: ByteArray): IntArray {
val unpacked = mutableListOf<Int>()
var i = 0
while (i < int8s.size) {
val byte = int8s[i]
if (byte == Byte.MIN_VALUE) {
i++
var dword = int8s[i].toInt() and 0xFF
i++
dword += (int8s[i].toInt() shl 8) and 0xFF00
i++
dword += (int8s[i].toInt() shl 16) and 0xFF0000
i++
dword += int8s[i].toInt() shl 24
unpacked.add(dword)
}
else {
unpacked.add(byte.toInt())
} }
i++ private fun unsortedDiffUnpackBytesToShorts(int8s: ByteArray): ShortArray {
} val decoded = mutableListOf<Short>()
return unpacked.toIntArray() var i = 0
} while (i < int8s.size) {
val byte = int8s[i]
private fun seperate(stream: String): FrameLists { if (byte == Byte.MIN_VALUE) {
val wList = mutableListOf<Int>() i++
val xList = mutableListOf<Short>() var word = int8s[i].toInt() and 0xFF
val yList = mutableListOf<Short>() i++
val zList = mutableListOf<Byte>() word += int8s[i].toInt() shl 8
decoded.add(word.toShort())
}
else {
decoded.add(byte.toShort())
}
for (frame in stream.split(",")) { i++
if (frame.isEmpty()) { }
continue
return cumSum(decoded.toShortArray())
} }
val splitFrame = frame.split("|") private fun packIntsToBytes(int32s: IntArray): ByteArray {
val w = splitFrame[0].toInt() val packed = mutableListOf<Byte>()
val x = splitFrame[1].toFloat()
val y = splitFrame[2].toFloat()
val z = splitFrame[3].toInt()
val zz = z and 0xFF for (dw in int32s) {
var dword = dw
if (abs(dword) <= Byte.MAX_VALUE) {
packed.add(dword.toByte())
}
else {
packed.add(Byte.MIN_VALUE)
packed.add((dword and 0xFF).toByte())
dword = dword shr 8
packed.add((dword and 0xFF).toByte())
dword = dword shr 8
packed.add((dword and 0xFF).toByte())
dword = dword shr 8
packed.add(dword.toByte())
}
}
var xx = round(x * 16).toInt() return packed.toByteArray()
var yy = round(y * 16).toInt() }
if (xx <= -0x8000) xx = -0x8000 private fun unpackBytesToInts(int8s: ByteArray): IntArray {
else if (xx >= 0x7FFF) xx = 0x7FFF val unpacked = mutableListOf<Int>()
if (yy <= -0x8000) yy = -0x8000 var i = 0
else if (yy >= 0x7FFF) yy = 0x7FFF while (i < int8s.size) {
val byte = int8s[i]
wList.add(w) if (byte == Byte.MIN_VALUE) {
xList.add(xx.toShort()) i++
yList.add(yy.toShort()) var dword = int8s[i].toInt() and 0xFF
zList.add(zz.toByte()) i++
dword += (int8s[i].toInt() shl 8) and 0xFF00
i++
dword += (int8s[i].toInt() shl 16) and 0xFF0000
i++
dword += int8s[i].toInt() shl 24
unpacked.add(dword)
}
else {
unpacked.add(byte.toInt())
}
i++
}
return unpacked.toIntArray()
}
private fun seperate(stream: String): FrameLists {
val wList = mutableListOf<Int>()
val xList = mutableListOf<Short>()
val yList = mutableListOf<Short>()
val zList = mutableListOf<Byte>()
for (frame in stream.split(",")) {
if (frame.isEmpty()) {
continue
}
val splitFrame = frame.split("|")
val w = splitFrame[0].toInt()
val x = splitFrame[1].toFloat()
val y = splitFrame[2].toFloat()
val z = splitFrame[3].toInt()
val zz = z and 0xFF
var xx = round(x * 16).toInt()
var yy = round(y * 16).toInt()
if (xx <= -0x8000) xx = -0x8000
else if (xx >= 0x7FFF) xx = 0x7FFF
if (yy <= -0x8000) yy = -0x8000
else if (yy >= 0x7FFF) yy = 0x7FFF
wList.add(w)
xList.add(xx.toShort())
yList.add(yy.toShort())
zList.add(zz.toByte())
}
return FrameLists(
x = xList.toShortArray(),
y = yList.toShortArray(),
z = zList.toByteArray(),
w = wList.toIntArray(),
)
}
private fun combine(lists: FrameLists): String {
val xArr = lists.x.map { it.toFloat() / 16 }
val yArr = lists.y.map { it.toFloat() / 16 }
val frames = arrayOfNulls<String>(xArr.size)
for (i in xArr.indices) {
val x = xArr[i]
val y = yArr[i]
val z = lists.z[i]
val w = lists.w[i]
frames[i] = "$w|$x|$y|$z"
}
return frames.joinToString(",")
}
private fun arrayDiff(arr: ShortArray): ShortArray {
if (arr.isEmpty()) {
return emptyArray<Short>().toShortArray()
}
val diffed = ShortArray(arr.size - 1)
for (index in 1..<arr.size) {
diffed[index - 1] = (arr[index] - arr[index - 1]).toShort()
}
return diffed
}
private fun cumSum(arr: ShortArray): ShortArray {
if (arr.isEmpty()) {
return emptyArray<Short>().toShortArray()
}
val cumArr = ShortArray(arr.size)
cumArr[0] = arr.first()
for (index in 1..<arr.size) {
cumArr[index] = (arr[index] + cumArr[index - 1]).toShort()
}
return cumArr
}
} }
return FrameLists(
x = xList.toShortArray(),
y = yList.toShortArray(),
z = zList.toByteArray(),
w = wList.toIntArray(),
)
}
private fun combine(lists: FrameLists): String {
val xArr = lists.x.map { it.toFloat() / 16 }
val yArr = lists.y.map { it.toFloat() / 16 }
val frames = arrayOfNulls<String>(xArr.size)
for (i in xArr.indices) {
val x = xArr[i]
val y = yArr[i]
val z = lists.z[i]
val w = lists.w[i]
frames[i] = "$w|$x|$y|$z"
}
return frames.joinToString(",")
}
private fun arrayDiff(arr: ShortArray): ShortArray {
if (arr.isEmpty()) {
return emptyArray<Short>().toShortArray()
}
val diffed = ShortArray(arr.size - 1)
for (index in 1..<arr.size) {
diffed[index - 1] = (arr[index] - arr[index - 1]).toShort()
}
return diffed
}
private fun cumSum(arr: ShortArray): ShortArray {
if (arr.isEmpty()) {
return emptyArray<Short>().toShortArray()
}
val cumArr = ShortArray(arr.size)
cumArr[0] = arr.first()
for (index in 1..<arr.size) {
cumArr[index] = (arr[index] + cumArr[index - 1]).toShort()
}
return cumArr
} }

View File

@ -1,7 +1,6 @@
package com.nisemoe.nise.osu package com.nisemoe.nise.osu
import com.nisemoe.nise.replays.wtcCompress import com.nisemoe.nise.replays.WTC
import com.nisemoe.nise.replays.wtcDecompress
import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource import org.junit.jupiter.params.provider.ValueSource
import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Assertions
@ -21,7 +20,7 @@ class WtcTest {
val expected = resourcesPath.resolve("${replayName}_compressed.dat").toFile().readBytes() val expected = resourcesPath.resolve("${replayName}_compressed.dat").toFile().readBytes()
val replayEvents = resourcesPath.resolve("${replayName}_events.txt").toFile().readText() val replayEvents = resourcesPath.resolve("${replayName}_events.txt").toFile().readText()
val wtcCompressed = wtcCompress(replayEvents) val wtcCompressed = WTC.compress(replayEvents)
// We include a version header at the start of the compressed byte array - create a new array excluding these // We include a version header at the start of the compressed byte array - create a new array excluding these
// so we can compare just the raw data with the Python WTC implementation's output. // so we can compare just the raw data with the Python WTC implementation's output.
@ -37,7 +36,7 @@ class WtcTest {
val expected = resourcesPath.resolve("${replayName}_decompressed.txt").toFile().readText() val expected = resourcesPath.resolve("${replayName}_decompressed.txt").toFile().readText()
val compressedReplay = resourcesPath.resolve("${replayName}_compressed.dat").toFile().readBytes() val compressedReplay = resourcesPath.resolve("${replayName}_compressed.dat").toFile().readBytes()
val wtcDecompressed = wtcDecompress(compressedReplay, false) val wtcDecompressed = WTC.decompress(compressedReplay, false)
assertEquals(expected, wtcDecompressed) assertEquals(expected, wtcDecompressed)
} }