# nise/nise-circleguard/src/main.py
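"""
Flask microservice wrapping circleguard and brparser: accepts base64-encoded
osu! replay data and returns replay-analysis statistics used for cheat
detection (unstable rate, frametime, hit-error distribution, keypress and
slider-end release timings), plus pairwise replay similarity/correlation.
"""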


import base64
import io
import os
from dataclasses import dataclass, asdict
from itertools import combinations
from math import isnan
from typing import List, Iterable
import numpy as np
import scipy.stats
from brparser import Replay, BeatmapOsu, Mod
from circleguard import Circleguard, ReplayString, Hit
from circleguard.utils import filter_outliers
from flask import Flask, request, jsonify, abort
from src.WriteStreamWrapper import WriteStreamWrapper
from src.keypresses import get_kp_sliders
# Circleguard instance: uses the osu! API key from the environment and caches
# downloaded beatmaps and replays under ./dbs.
cg = Circleguard(os.getenv("OSU_API_KEY"), db_path="./dbs/db.db", slider_dir="./dbs/")
app = Flask(__name__)
def my_filter_outliers(arr, bias=1.5):
"""
Returns ``arr`` with outliers removed.
Parameters
----------
arr: list
List of numbers to filter outliers from.
    bias: float
Points in ``arr`` which are more than ``IQR * bias`` away from the first
or third quartile of ``arr`` will be removed.
"""
    if not arr:
return arr
q3, q1 = np.percentile(arr, [75, 25])
iqr = q3 - q1
lower_limit = q1 - (bias * iqr)
upper_limit = q3 + (bias * iqr)
    arr_without_outliers = [x for x in arr if lower_limit < x < upper_limit]
    # If everything was flagged as an outlier, fall back to the unfiltered input.
    return arr if not arr_without_outliers else arr_without_outliers
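# Example with the default bias of 1.5: a single extreme point is dropped,
# while an empty or fully-filtered list is returned unchanged, e.g.
#   my_filter_outliers([10, 11, 12, 500])  ->  [10, 11, 12]
#   my_filter_outliers([])                 ->  []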
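# Data class for the POST /replay request body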
@dataclass
class ReplayRequest:
replay_data: str
mods: int
beatmap_id: int
@staticmethod
def from_dict(data):
try:
return ReplayRequest(
replay_data=data['replay_data'],
mods=data['mods'],
beatmap_id=int(data['beatmap_id'])
)
except (ValueError, KeyError, TypeError) as e:
raise ValueError(f"Invalid data format: {e}")
# Data class for the response
@dataclass
class ReplayResponse:
ur: float
adjusted_ur: float
frametime: float
edge_hits: int
snaps: int
mean_error: float
error_variance: float
error_standard_deviation: float
minimum_error: float
maximum_error: float
error_range: float
error_coefficient_of_variation: float
error_kurtosis: float
error_skewness: float
keypresses_times: List[int]
keypresses_median: float
keypresses_median_adjusted: float
keypresses_standard_deviation: float
keypresses_standard_deviation_adjusted: float
sliderend_release_times: List[int]
sliderend_release_median: float
sliderend_release_median_adjusted: float
sliderend_release_standard_deviation: float
sliderend_release_standard_deviation_adjusted: float
    judgements: List["ScoreJudgement"]
def to_dict(self):
d = asdict(self)
for key, value in d.items():
if isinstance(value, float) and isnan(value):
d[key] = None
return d
@dataclass
class ScoreJudgement:
time: float
x: float
y: float
type: str
distance_center: float
distance_edge: float
error: float
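# POST /replay
# Expects JSON with "replay_data" (base64-encoded replay data), "mods"
# (mod bitmask) and "beatmap_id"; responds with the ReplayResponse fields
# serialized to JSON (NaN values are mapped to null).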
@app.post("/replay")
def process_replay():
try:
request_data = request.get_json()
if not request_data:
abort(400, description="Bad Request: No JSON data provided.")
replay_request = ReplayRequest.from_dict(request_data)
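        # Wrap the raw replay data into a full .osr byte stream (via
        # WriteStreamWrapper, which appears to add the .osr header around the
        # submitted data) so circleguard's ReplayString can parse it.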
memory_stream1 = io.BytesIO()
stream_wrapper1 = WriteStreamWrapper(memory_stream1, stream_is_closable=False)
stream_wrapper1.write_osr_data2(replay_request.replay_data, replay_request.mods)
stream_wrapper1.end()
result_bytes1 = memory_stream1.getvalue()
replay1 = ReplayString(result_bytes1)
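        # Look up the beatmap (downloaded and cached under ./dbs) and compute
        # circleguard's standard metrics: unstable rate, frametime, edge hits
        # and suspicious cursor snaps.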
cg_beatmap = cg.library.lookup_by_id(beatmap_id=replay_request.beatmap_id, download=True, save=True)
ur = cg.ur(replay=replay1, beatmap=cg_beatmap)
adjusted_ur = cg.ur(replay=replay1, beatmap=cg_beatmap, adjusted=True)
frametime = cg.frametime(replay=replay1)
edge_hits = sum(1 for _ in cg.hits(replay=replay1, within=1, beatmap=cg_beatmap))
snaps = sum(1 for _ in cg.snaps(replay=replay1, beatmap=cg_beatmap))
# Decode the base64 string
decoded_data = base64.b64decode(replay_request.replay_data)
# Pass the decoded data to the Replay class
replay = Replay(decoded_data, pure_lzma=True)
replay.mods = Mod(replay_request.mods)
filename = (f'{cg_beatmap.artist} - {cg_beatmap.title} ({cg_beatmap.creator})[{cg_beatmap.version}].osu'
.replace('/', ''))
beatmap_file = f'dbs/{filename}'
if not os.path.exists(beatmap_file):
print(f'Map not found @ {beatmap_file}', flush=True)
            return "Map not found", 400
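        # Parse the locally cached .osu file and extract keypress hold times
        # and slider-end release offsets (the two lists get_kp_sliders
        # presumably returns).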
beatmap = BeatmapOsu(beatmap_file)
kp, se = get_kp_sliders(replay, beatmap)
hits: Iterable[Hit] = cg.hits(replay=replay1, beatmap=cg_beatmap)
judgements: List[ScoreJudgement] = []
for hit in hits:
hit_obj = ScoreJudgement(
time=float(hit.time),
x=float(hit.x),
y=float(hit.y),
type=hit.type.name,
distance_center=float(hit.distance(to='center')),
distance_edge=float(hit.distance(to='edge')),
error=float(hit.error())
)
judgements.append(hit_obj)
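        # Aggregate hit-error statistics across all judged hitobjects.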
errors = np.array([score.error for score in judgements])
mean_error = np.mean(errors)
error_variance = np.var(errors)
error_std_dev = np.std(errors)
min_error = np.min(errors)
max_error = np.max(errors)
error_range = max_error - min_error
coefficient_of_variation = error_std_dev / mean_error if mean_error != 0 else None
kurtosis = scipy.stats.kurtosis(errors)
skewness = scipy.stats.skew(errors)
ur_response = ReplayResponse(
ur=ur,
adjusted_ur=adjusted_ur,
frametime=frametime,
edge_hits=edge_hits,
snaps=snaps,
mean_error=mean_error,
error_variance=error_variance,
error_standard_deviation=error_std_dev,
minimum_error=min_error,
maximum_error=max_error,
error_range=error_range,
error_coefficient_of_variation=coefficient_of_variation,
error_kurtosis=kurtosis,
error_skewness=skewness,
keypresses_times=kp,
keypresses_median=np.median(kp),
keypresses_median_adjusted=np.median(my_filter_outliers(kp)),
keypresses_standard_deviation=np.std(kp, ddof=1),
keypresses_standard_deviation_adjusted=np.std(my_filter_outliers(kp), ddof=1),
sliderend_release_times=se,
sliderend_release_median=np.median(se),
sliderend_release_median_adjusted=np.median(my_filter_outliers(se)),
sliderend_release_standard_deviation=np.std(se, ddof=1),
sliderend_release_standard_deviation_adjusted=np.std(my_filter_outliers(se), ddof=1),
judgements=judgements
)
return jsonify(ur_response.to_dict())
except ValueError as e:
abort(400, description=str(e))
@dataclass
class ScoreSimilarity:
replay_id_1: int
replay_id_2: int
similarity: float
correlation: float
@dataclass
class ReplayDto:
replayId: int
replayMods: int
replayData: str
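# POST /similarity
# Expects JSON: {"replays": [{"replayId": ..., "replayMods": ..., "replayData": ...}, ...]}
# and returns circleguard's similarity and correlation scores for every pair
# of distinct replays.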
@app.post("/similarity")
def process_similarity():
try:
request_data = request.get_json()
if not request_data:
abort(400, description="Bad Request: No JSON data provided.")
replays: List[ReplayDto] = request_data['replays']
replay_cache = {}
response = []
def get_or_create_replay(replay, cache):
try:
if replay['replayId'] not in cache:
memory_stream = io.BytesIO()
stream_wrapper = WriteStreamWrapper(memory_stream, stream_is_closable=False)
stream_wrapper.write_osr_data2(replay['replayData'], replay['replayMods'])
stream_wrapper.end()
result_bytes = memory_stream.getvalue()
cache[replay['replayId']] = ReplayString(result_bytes)
return cache[replay['replayId']]
            except Exception:
return None
for score1, score2 in combinations(replays, 2):
if score1['replayId'] == score2['replayId']:
continue
replay1 = get_or_create_replay(score1, replay_cache)
replay2 = get_or_create_replay(score2, replay_cache)
if replay1 is None or replay2 is None:
print('Error processing replay', flush=True)
continue
similarity = cg.similarity(replay1=replay1, replay2=replay2, method='similarity')
correlation = cg.similarity(replay1=replay1, replay2=replay2, method='correlation')
new_score_similarity = ScoreSimilarity(
replay_id_1=score1['replayId'],
replay_id_2=score2['replayId'],
similarity=similarity,
correlation=correlation
)
response.append(new_score_similarity)
return jsonify({'result': response})
except ValueError as e:
abort(400, description=str(e))
if __name__ == "__main__":
app.run(host='0.0.0.0', debug=False)
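# Example client call (sketch; host, port and field values are placeholders):
#   import requests
#   requests.post("http://localhost:5000/replay", json={
#       "replay_data": "<base64 replay data>", "mods": 0, "beatmap_id": 53,
#   }).json()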