nise/nise-circleguard/src/main.py

259 lines
7.9 KiB
Python
Raw Normal View History

import base64
2024-02-14 16:43:11 +00:00
import io
import os
from dataclasses import dataclass, asdict
from itertools import combinations
from math import isnan
from typing import List, Iterable
import numpy as np
import scipy
import scipy.stats
from brparser import Replay, BeatmapOsu, Mod
2024-02-14 16:43:11 +00:00
from circleguard import Circleguard, ReplayString, Hit
from flask import Flask, request, jsonify, abort
from src.WriteStreamWrapper import WriteStreamWrapper
from src.keypresses import get_kp_sliders
2024-02-14 16:43:11 +00:00
# Circleguard instance shared by every request. The osu! API key is read from
# the OSU_API_KEY environment variable; db_path and slider_dir both point
# under ./dbs/.
cg = Circleguard(
    os.getenv("OSU_API_KEY"),
    db_path="./dbs/db.db",
    slider_dir="./dbs/",
)
# Flask application serving the /replay and /similarity endpoints below.
app = Flask(__name__)
@dataclass
class ReplayRequest:
    """Validated payload of a ``POST /replay`` request."""

    # Replay data as sent by the client; process_replay base64-decodes it.
    replay_data: str
    # Mods value, later passed to the .osr writer and to brparser's ``Mod``.
    mods: int
    # osu! beatmap id used for the circleguard library lookup.
    beatmap_id: int

    @staticmethod
    def from_dict(data):
        """Build a ``ReplayRequest`` from a decoded JSON object.

        Args:
            data: Mapping with ``replay_data``, ``mods`` and ``beatmap_id``
                keys. ``mods`` and ``beatmap_id`` may be ints or integer
                strings; both are coerced to ``int``.

        Returns:
            The validated ``ReplayRequest``.

        Raises:
            ValueError: if ``data`` is not a mapping, a key is missing,
                ``replay_data`` is not a string, or ``mods``/``beatmap_id``
                cannot be converted to ``int``. The original error is chained
                as ``__cause__``.
        """
        try:
            replay_data = data['replay_data']
            if not isinstance(replay_data, str):
                raise TypeError("'replay_data' must be a string")
            return ReplayRequest(
                replay_data=replay_data,
                # Coerce like beatmap_id so downstream code always gets an int.
                mods=int(data['mods']),
                beatmap_id=int(data['beatmap_id']),
            )
        except (ValueError, KeyError, TypeError) as e:
            raise ValueError(f"Invalid data format: {e}") from e
# Data class for the response
@dataclass
class ReplayResponse:
    """Statistics returned by ``POST /replay``.

    Float fields may be NaN, for example a sample standard deviation over
    fewer than two values. ``to_dict`` replaces every NaN with ``None``,
    including NaNs nested inside ``judgements``, so the result serialises to
    valid JSON.
    """

    ur: float
    adjusted_ur: float
    frametime: float
    edge_hits: int
    snaps: int
    mean_error: float
    error_variance: float
    error_standard_deviation: float
    minimum_error: float
    maximum_error: float
    error_range: float
    # None when the mean error is exactly zero (coefficient undefined).
    error_coefficient_of_variation: float
    error_kurtosis: float
    error_skewness: float
    keypresses_times: List[int]
    keypresses_median: float
    keypresses_standard_deviation: float
    sliderend_release_times: List[int]
    sliderend_release_median: float
    sliderend_release_standard_deviation: float
    # Per-hit records built by process_replay (ScoreJudgement instances).
    judgements: List["ScoreJudgement"]

    def to_dict(self):
        """Return the response as a JSON-ready dict, with NaN floats as None."""
        return ReplayResponse._nan_to_none(asdict(self))

    @staticmethod
    def _nan_to_none(value):
        """Recursively replace float NaN with None inside dicts, lists and tuples."""
        if isinstance(value, float):
            # Also covers numpy.float64, which subclasses float.
            return None if isnan(value) else value
        if isinstance(value, dict):
            return {key: ReplayResponse._nan_to_none(item) for key, item in value.items()}
        if isinstance(value, list):
            return [ReplayResponse._nan_to_none(item) for item in value]
        if isinstance(value, tuple):
            return tuple(ReplayResponse._nan_to_none(item) for item in value)
        return value
@dataclass
class ScoreJudgement:
    """One hit judgement of a replay, as listed in the /replay response."""

    # float(hit.time) of the circleguard Hit.
    time: float
    # float(hit.x) / float(hit.y): coordinates of the hit.
    x: float
    y: float
    # Name of the hit's judgement type (hit.type.name).
    type: str
    # hit.distance(to='center') and hit.distance(to='edge').
    distance_center: float
    distance_edge: float
    # hit.error(); the response's error statistics are computed from these.
    error: float
@app.post("/replay")
def process_replay():
    """Analyse one replay against its beatmap and return timing/aim statistics.

    Expects a JSON body accepted by ``ReplayRequest.from_dict``
    (``replay_data``, ``mods``, ``beatmap_id``). The replay is analysed in
    two ways:

    * through circleguard: UR, adjusted UR, frametime, edge hits, snaps and
      the per-hit judgements;
    * through brparser with the beatmap's ``.osu`` file under ``dbs/``:
      keypress and slider-end release times via ``get_kp_sliders``.

    Returns:
        ``jsonify(ReplayResponse.to_dict())``, or the plain-text body
        ``"Map not found"`` with status 400 when the ``.osu`` file is missing.

    Aborts with HTTP 400 when no JSON body is sent or when any step raises
    ``ValueError``. Examples are a malformed request, invalid base64
    (``binascii.Error`` subclasses ``ValueError``) or a replay without hit
    judgements.
    """
    try:
        request_data = request.get_json()
        if not request_data:
            abort(400, description="Bad Request: No JSON data provided.")
        replay_request = ReplayRequest.from_dict(request_data)

        # Wrap the client's replay data in an .osr byte stream for circleguard.
        memory_stream = io.BytesIO()
        stream_wrapper = WriteStreamWrapper(memory_stream, stream_is_closable=False)
        stream_wrapper.write_osr_data2(replay_request.replay_data, replay_request.mods)
        stream_wrapper.end()
        cg_replay = ReplayString(memory_stream.getvalue())

        cg_beatmap = cg.library.lookup_by_id(beatmap_id=replay_request.beatmap_id, download=True, save=True)
        ur = cg.ur(replay=cg_replay, beatmap=cg_beatmap)
        adjusted_ur = cg.ur(replay=cg_replay, beatmap=cg_beatmap, adjusted=True)
        frametime = cg.frametime(replay=cg_replay)
        edge_hits = sum(1 for _ in cg.hits(replay=cg_replay, within=1, beatmap=cg_beatmap))
        snaps = sum(1 for _ in cg.snaps(replay=cg_replay, beatmap=cg_beatmap))

        # brparser parses the base64-decoded payload directly (pure_lzma=True).
        decoded_data = base64.b64decode(replay_request.replay_data)
        replay = Replay(decoded_data, pure_lzma=True)
        replay.mods = Mod(replay_request.mods)

        beatmap_file = f'dbs/{cg_beatmap.artist} - {cg_beatmap.title} ({cg_beatmap.creator})[{cg_beatmap.version}].osu'
        if not os.path.exists(beatmap_file):
            print(f'Map not found @ {beatmap_file}', flush=True)
            # Flask view tuples are (body, status); (status, body) raises a 500.
            return "Map not found", 400
        beatmap = BeatmapOsu(beatmap_file)
        kp, se = get_kp_sliders(replay, beatmap)

        hits: Iterable[Hit] = cg.hits(replay=cg_replay, beatmap=cg_beatmap)
        judgements: List[ScoreJudgement] = [
            ScoreJudgement(
                time=float(hit.time),
                x=float(hit.x),
                y=float(hit.y),
                type=hit.type.name,
                distance_center=float(hit.distance(to='center')),
                distance_edge=float(hit.distance(to='edge')),
                error=float(hit.error()),
            )
            for hit in hits
        ]
        if not judgements:
            # np.min/np.max reject empty arrays; report the cause as a 400.
            raise ValueError("Replay has no hit judgements")

        errors = np.array([judgement.error for judgement in judgements])
        mean_error = np.mean(errors)
        error_std_dev = np.std(errors)
        min_error = np.min(errors)
        max_error = np.max(errors)

        ur_response = ReplayResponse(
            ur=ur,
            adjusted_ur=adjusted_ur,
            frametime=frametime,
            edge_hits=edge_hits,
            snaps=snaps,
            mean_error=mean_error,
            error_variance=np.var(errors),
            error_standard_deviation=error_std_dev,
            minimum_error=min_error,
            maximum_error=max_error,
            error_range=max_error - min_error,
            # Undefined for a zero mean error.
            error_coefficient_of_variation=error_std_dev / mean_error if mean_error != 0 else None,
            error_kurtosis=scipy.stats.kurtosis(errors),
            error_skewness=scipy.stats.skew(errors),
            keypresses_times=kp,
            keypresses_median=np.median(kp),
            # ddof=1: sample std; NaN (-> None in to_dict) for < 2 values.
            keypresses_standard_deviation=np.std(kp, ddof=1),
            sliderend_release_times=se,
            sliderend_release_median=np.median(se),
            sliderend_release_standard_deviation=np.std(se, ddof=1),
            judgements=judgements,
        )
        return jsonify(ur_response.to_dict())
    except ValueError as e:
        abort(400, description=str(e))
@dataclass
class ScoreSimilarity:
    """Comparison result for one pair of replays in the /similarity response."""

    # 'replayId' values of the two compared replays.
    replay_id_1: int
    replay_id_2: int
    # cg.similarity(..., method='similarity') for the pair.
    similarity: float
    # cg.similarity(..., method='correlation') for the pair.
    correlation: float
@dataclass
class ReplayDto:
    """Shape of one entry in the ``replays`` list posted to /similarity.

    The field names are camelCase because they mirror the JSON keys.
    process_similarity indexes the entries as plain dicts
    (``replay['replayId']``) rather than instantiating this class.
    """

    replayId: int
    replayMods: int
    replayData: str
@app.post("/similarity")
def process_similarity():
    """Compute pairwise similarity and correlation between submitted replays.

    Expects a JSON body ``{"replays": [{"replayId": ..., "replayMods": ...,
    "replayData": ...}, ...]}``. Every pair of entries with different
    ``replayId`` values is compared with ``cg.similarity`` using both the
    ``'similarity'`` and the ``'correlation'`` method.

    Replays that fail to convert are logged once, and every pair involving
    them is skipped. NaN scores are emitted as ``null`` so the response is
    valid JSON.

    Returns:
        ``jsonify({'result': [ScoreSimilarity as dict, ...]})``.

    Aborts with HTTP 400 when no JSON body is sent, when ``replays`` is not
    a list of objects with an int/str ``replayId``, or when a ``ValueError``
    is raised.
    """
    try:
        request_data = request.get_json()
        if not request_data:
            abort(400, description="Bad Request: No JSON data provided.")
        replays = request_data.get('replays') if isinstance(request_data, dict) else None
        if not isinstance(replays, list) or not all(
            isinstance(replay, dict) and isinstance(replay.get('replayId'), (int, str))
            for replay in replays
        ):
            raise ValueError("Invalid data format: 'replays' must be a list of objects with a 'replayId'")

        # replayId -> ReplayString, or None when conversion failed. Caching
        # failures means a broken replay is converted and logged only once.
        replay_cache = {}

        def get_or_create_replay(replay):
            replay_id = replay['replayId']
            if replay_id not in replay_cache:
                try:
                    memory_stream = io.BytesIO()
                    stream_wrapper = WriteStreamWrapper(memory_stream, stream_is_closable=False)
                    stream_wrapper.write_osr_data2(replay['replayData'], replay['replayMods'])
                    stream_wrapper.end()
                    replay_cache[replay_id] = ReplayString(memory_stream.getvalue())
                except Exception as e:
                    # Best effort: one unparsable replay must not fail the batch.
                    print(f'Failed to convert replay {replay_id}: {e!r}', flush=True)
                    replay_cache[replay_id] = None
            return replay_cache[replay_id]

        def nan_to_none(value):
            # NaN is not valid JSON; emit null instead.
            return None if isinstance(value, float) and isnan(value) else value

        response = []
        for score1, score2 in combinations(replays, 2):
            if score1['replayId'] == score2['replayId']:
                continue
            replay1 = get_or_create_replay(score1)
            replay2 = get_or_create_replay(score2)
            if replay1 is None or replay2 is None:
                print('Error processing replay', flush=True)
                continue
            similarity = cg.similarity(replay1=replay1, replay2=replay2, method='similarity')
            correlation = cg.similarity(replay1=replay1, replay2=replay2, method='correlation')
            response.append(asdict(ScoreSimilarity(
                replay_id_1=score1['replayId'],
                replay_id_2=score2['replayId'],
                similarity=nan_to_none(similarity),
                correlation=nan_to_none(correlation),
            )))
        return jsonify({'result': response})
    except ValueError as e:
        abort(400, description=str(e))
if __name__ == "__main__":
    # Run Flask's built-in server on all interfaces with debug mode off.
    app.run(debug=False, host='0.0.0.0')