# nise/nise-circleguard/src/main.py

import base64
import io
import os
from dataclasses import asdict, dataclass
from itertools import combinations
from math import isnan
from typing import Iterable, List

import numpy as np
import scipy.stats
from brparser import BeatmapOsu, Mod, Replay
from circleguard import Circleguard, Hit, ReplayString
from sanic import Request, Sanic, exceptions, json
from slider import Beatmap, Circle, Slider, Spinner

from keypresses import get_kp_sliders
from WriteStreamWrapper import WriteStreamWrapper

# Circleguard instance; OSU_API_KEY must be set in the environment.
cg = Circleguard(os.getenv("OSU_API_KEY"), db_path="./dbs/db.db", slider_dir="./dbs/")
app = Sanic(__name__)


def my_filter_outliers(arr, bias=1.5):
    """
    Returns ``arr`` with outliers removed.

    Parameters
    ----------
    arr: list
        List of numbers to filter outliers from.
    bias: float
        Points in ``arr`` which are more than ``IQR * bias`` away from the first
        or third quartile of ``arr`` will be removed.

    Notes
    -----
    If filtering would remove every point, ``arr`` is returned unchanged, so
    callers always get a non-empty result for non-empty input.
    """
    if not arr:
        return arr
    q3, q1 = np.percentile(arr, [75, 25])
    iqr = q3 - q1
    lower_limit = q1 - (bias * iqr)
    upper_limit = q3 + (bias * iqr)
    arr_without_outliers = [x for x in arr if lower_limit < x < upper_limit]
    return arr if not arr_without_outliers else arr_without_outliers
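
# Example (illustrative values): for arr = [10, 12, 11, 95] and bias = 1.5,
# np.percentile gives q1 = 10.75 and q3 = 32.75, so IQR = 22.0 and the upper
# limit is 65.75; 95 falls outside it and [10, 12, 11] is returned.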


def remove_bom_from_first_line(beatmap_file):
    """Strip a UTF-8 byte order mark (BOM) from the first line of ``beatmap_file``."""
    lines = beatmap_file.splitlines()
    if not lines:  # Nothing to clean; avoid indexing an empty list.
        return beatmap_file
    # Remove the BOM only from the first line.
    lines[0] = lines[0].replace('\ufeff', '')
    # Join the lines back together.
    return '\n'.join(lines)
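
# Example: .osu files saved with a UTF-8 BOM begin with '\ufeffosu file format
# v14'; stripping the BOM lets slider's Beatmap.parse see a clean
# 'osu file format v14' header on the first line.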


# Data class for the /replay request body.
@dataclass
class ReplayRequest:
    replay_data: str
    beatmap_data: str
    mods: int

    @staticmethod
    def from_dict(data):
        try:
            return ReplayRequest(
                replay_data=data['replay_data'],
                beatmap_data=data['beatmap_data'],
                mods=data['mods']
            )
        except (ValueError, KeyError, TypeError) as e:
            raise ValueError(f"Invalid data format: {e}") from e


# Per-object judgement derived from circleguard's Hit instances.
@dataclass
class ScoreJudgement:
    time: float
    x: float
    y: float
    type: str
    distance_center: float
    distance_edge: float
    error: float


# Data class for the /replay response.
@dataclass
class ReplayResponse:
    ur: float
    adjusted_ur: float
    frametime: float
    edge_hits: int
    snaps: int
    mean_error: float
    error_variance: float
    error_standard_deviation: float
    minimum_error: float
    maximum_error: float
    error_range: float
    error_coefficient_of_variation: float
    error_kurtosis: float
    error_skewness: float
    keypresses_times: List[int]
    keypresses_median: float
    keypresses_median_adjusted: float
    keypresses_standard_deviation: float
    keypresses_standard_deviation_adjusted: float
    sliderend_release_times: List[int]
    sliderend_release_median: float
    sliderend_release_median_adjusted: float
    sliderend_release_standard_deviation: float
    sliderend_release_standard_deviation_adjusted: float
    judgements: List[ScoreJudgement]

    def to_dict(self):
        # asdict() recurses into the nested ScoreJudgement dataclasses. NaN is
        # not valid JSON, so top-level NaN floats are replaced with None
        # (serialized as null).
        d = asdict(self)
        for key, value in d.items():
            if isinstance(value, float) and isnan(value):
                d[key] = None
        return d
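
# Illustrative to_dict() behaviour (placeholder scenario, not real output): if
# fewer than two keypresses were recorded, np.std(kp, ddof=1) yields NaN, so a
# field like keypresses_standard_deviation comes back as None and serializes
# as JSON null rather than the invalid literal NaN.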
@app.post("/replay")
async def process_replay(request: Request):
try:
request_data = request.json
if not request_data:
raise exceptions.BadRequest("Bad Request: No JSON data provided.")
replay_request = ReplayRequest.from_dict(request_data)
memory_stream1 = io.BytesIO()
stream_wrapper1 = WriteStreamWrapper(memory_stream1, stream_is_closable=False)
stream_wrapper1.write_osr_data2(replay_request.replay_data, replay_request.mods)
stream_wrapper1.end()
result_bytes1 = memory_stream1.getvalue()
replay1 = ReplayString(result_bytes1)
clean_beatmap_file = remove_bom_from_first_line(replay_request.beatmap_data)
cg_beatmap = Beatmap.parse(clean_beatmap_file)
ur = cg.ur(replay=replay1, beatmap=cg_beatmap)
adjusted_ur = cg.ur(replay=replay1, beatmap=cg_beatmap, adjusted=True)
frametime = cg.frametime(replay=replay1)
edge_hits = sum(1 for _ in cg.hits(replay=replay1, within=1, beatmap=cg_beatmap))
snaps = sum(1 for _ in cg.snaps(replay=replay1, beatmap=cg_beatmap))
#
# Decode the base64 string
decoded_data = base64.b64decode(replay_request.replay_data)
# Pass the decoded data to the Replay class
replay = Replay(decoded_data, pure_lzma=True)
replay.mods = Mod(replay_request.mods)
beatmap = BeatmapOsu(None)
beatmap._process_headers(replay_request.beatmap_data.splitlines())
beatmap._parse(replay_request.beatmap_data.splitlines())
beatmap._sort_objects()
kp, se = get_kp_sliders(replay, beatmap)
hits: Iterable[Hit] = cg.hits(replay=replay1, beatmap=cg_beatmap)
judgements: List[ScoreJudgement] = []
for hit in hits:
hit_obj = ScoreJudgement(
time=float(hit.time),
x=float(hit.x),
y=float(hit.y),
type=hit.type.name,
distance_center=float(hit.distance(to='center')),
distance_edge=float(hit.distance(to='edge')),
error=float(hit.error())
)
judgements.append(hit_obj)
errors = np.array([score.error for score in judgements])
mean_error = np.mean(errors)
error_variance = np.var(errors)
error_std_dev = np.std(errors)
min_error = np.min(errors)
max_error = np.max(errors)
error_range = max_error - min_error
coefficient_of_variation = error_std_dev / mean_error if mean_error != 0 else None
kurtosis = scipy.stats.kurtosis(errors)
skewness = scipy.stats.skew(errors)
ur_response = ReplayResponse(
ur=ur,
adjusted_ur=adjusted_ur,
frametime=frametime,
edge_hits=edge_hits,
snaps=snaps,
mean_error=mean_error,
error_variance=error_variance,
error_standard_deviation=error_std_dev,
minimum_error=min_error,
maximum_error=max_error,
error_range=error_range,
error_coefficient_of_variation=coefficient_of_variation,
error_kurtosis=kurtosis,
error_skewness=skewness,
keypresses_times=kp,
keypresses_median=np.median(kp),
keypresses_median_adjusted=np.median(my_filter_outliers(kp)),
keypresses_standard_deviation=np.std(kp, ddof=1),
keypresses_standard_deviation_adjusted=np.std(my_filter_outliers(kp), ddof=1),
sliderend_release_times=se,
sliderend_release_median=np.median(se),
sliderend_release_median_adjusted=np.median(my_filter_outliers(se)),
sliderend_release_standard_deviation=np.std(se, ddof=1),
sliderend_release_standard_deviation_adjusted=np.std(my_filter_outliers(se), ddof=1),
judgements=judgements
)
return json(ur_response.to_dict())
except ValueError as e:
raise exceptions.BadRequest(str(e))
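
# Illustrative /replay request (values are placeholders, not real replay data):
#
#   POST /replay
#   {"replay_data": "<base64-encoded LZMA replay frames>",
#    "beatmap_data": "osu file format v14 ...",
#    "mods": 64}
#
# The JSON response mirrors ReplayResponse.to_dict(), with NaN statistics
# mapped to null.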


# Pairwise similarity result for two replays.
@dataclass
class ScoreSimilarity:
    replay_id_1: int
    replay_id_2: int
    similarity: float
    correlation: float


# Expected shape of each entry in the "replays" array of a /similarity request.
@dataclass
class ReplayDto:
    replayId: int
    replayMods: int
    replayData: str


@app.post("/similarity")
async def process_similarity(request: Request):
    try:
        request_data = request.json
        if not request_data:
            raise exceptions.BadRequest("Bad Request: No JSON data provided.")
        # Entries arrive as plain dicts shaped like ReplayDto.
        replays: List[dict] = request_data['replays']
        replay_cache = {}
        response = []

        def get_or_create_replay(replay, cache):
            try:
                if replay['replayId'] not in cache:
                    memory_stream = io.BytesIO()
                    stream_wrapper = WriteStreamWrapper(memory_stream, stream_is_closable=False)
                    stream_wrapper.write_osr_data2(replay['replayData'], replay['replayMods'])
                    stream_wrapper.end()
                    result_bytes = memory_stream.getvalue()
                    cache[replay['replayId']] = ReplayString(result_bytes)
                return cache[replay['replayId']]
            except Exception as e:
                print(f"Error building replay {replay.get('replayId')}: {e}", flush=True)
                return None

        for score1, score2 in combinations(replays, 2):
            if score1['replayId'] == score2['replayId']:
                continue
            replay1 = get_or_create_replay(score1, replay_cache)
            replay2 = get_or_create_replay(score2, replay_cache)
            if replay1 is None or replay2 is None:
                print('Error processing replay', flush=True)
                continue
            similarity = cg.similarity(replay1=replay1, replay2=replay2, method='similarity')
            correlation = cg.similarity(replay1=replay1, replay2=replay2, method='correlation')
            new_score_similarity = ScoreSimilarity(
                replay_id_1=score1['replayId'],
                replay_id_2=score2['replayId'],
                similarity=similarity,
                correlation=correlation
            )
            # Dataclass instances are not JSON-serializable; convert to dicts.
            response.append(asdict(new_score_similarity))
        return json({'result': response})
    except (ValueError, KeyError) as e:
        raise exceptions.BadRequest(str(e)) from e
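
# Illustrative /similarity request (IDs and replay data are placeholders):
#
#   POST /similarity
#   {"replays": [{"replayId": 1, "replayMods": 0, "replayData": "<base64 LZMA>"},
#                {"replayId": 2, "replayMods": 0, "replayData": "<base64 LZMA>"}]}
#
# The service is normally launched with the Sanic CLI (e.g. `sanic main:app`);
# the block below is a hypothetical direct-run fallback with assumed host/port.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8000)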