# nise/nise-circleguard/src/main.py

import base64
import io
import os
from dataclasses import dataclass, asdict
from typing import List, Iterable
import numpy as np
import scipy
from brparser import Replay, BeatmapOsu, Mod
from circleguard import Circleguard, ReplayString, Hit
from flask import Flask, request, jsonify, abort
from itertools import combinations
from math import isnan
from slider import Beatmap, Circle, Slider, Spinner
from src.WriteStreamWrapper import WriteStreamWrapper
from src.keypresses import get_kp_sliders
# Circleguard
# Shared Circleguard analysis instance. Reads the osu! API key from the
# OSU_API_KEY environment variable and keeps its replay/slider caches
# under ./dbs/.
cg = Circleguard(os.getenv("OSU_API_KEY"), db_path="./dbs/db.db", slider_dir="./dbs/")
# Flask application serving the /beatmap, /replay and /similarity endpoints.
app = Flask(__name__)
def my_filter_outliers(arr, bias=1.5):
    """Return ``arr`` with IQR-based outliers removed.

    Values outside the open interval ``(Q1 - bias*IQR, Q3 + bias*IQR)``
    are dropped. If filtering would remove every point, the original
    ``arr`` is returned unchanged so downstream statistics never receive
    an empty sample.

    Parameters
    ----------
    arr: list
        List of numbers to filter outliers from.
    bias: float
        Multiplier applied to the IQR when computing the cut-off limits
        (default 1.5, the conventional Tukey fence).
    """
    # Empty/None input: nothing to filter. Use a length check rather than
    # truthiness so sequence types like numpy arrays are also accepted.
    if arr is None or len(arr) == 0:
        return arr
    q3, q1 = np.percentile(arr, [75, 25])
    iqr = q3 - q1
    lower_limit = q1 - (bias * iqr)
    upper_limit = q3 + (bias * iqr)
    # Strict comparison: points exactly on a fence count as outliers.
    arr_without_outliers = [x for x in arr if lower_limit < x < upper_limit]
    # Never return an empty result -- fall back to the unfiltered data.
    return arr if not arr_without_outliers else arr_without_outliers
@dataclass
class BeatmapRequest:
    """Payload of the /beatmap endpoint."""
    beatmap_file: str  # raw .osu file contents
    mods: int          # mod bitmask

    @staticmethod
    def from_dict(data):
        """Build a BeatmapRequest from a JSON dict; raise ValueError on bad input."""
        try:
            beatmap_file = data['beatmap_file']
            mods = data['mods']
        except (ValueError, KeyError, TypeError) as e:
            raise ValueError(f"Invalid data format: {e}")
        return BeatmapRequest(beatmap_file=beatmap_file, mods=mods)
@app.post("/beatmap")
def process_beatmap():
    """Parse a .osu beatmap and return its hit objects as JSON.

    Expects a JSON body matching ``BeatmapRequest``:
    ``{"beatmap_file": <raw .osu text>, "mods": <mod bitmask>}``.
    Difficulty-altering mods (EZ/HR/HT/DT) in the bitmask are applied
    when extracting hit objects. Responds 400 on invalid input.
    """
    try:
        request_data = request.get_json()
        if not request_data:
            abort(400, description="Bad Request: No JSON data provided.")
        beatmap_request = BeatmapRequest.from_dict(request_data)
        cg_beatmap = Beatmap.parse(beatmap_request.beatmap_file)
        circles = []
        sliders = []
        spinners = []

        def map_slider_curve_type(slider: Slider):
            # Translate the slider curve implementation class into a plain
            # string for the JSON payload. Compared via str(type(...)), so
            # subclasses would NOT match; unknown curve types yield None.
            if str(type(slider.curve)) == "<class 'slider.curve.Linear'>":
                return 'Linear'
            if str(type(slider.curve)) == "<class 'slider.curve.Perfect'>":
                return 'Perfect'
            elif str(type(slider.curve)) == "<class 'slider.curve.MultiBezier'>":
                return 'MultiBezier'
            elif str(type(slider.curve)) == "<class 'slider.curve.CatMull'>":
                return 'CatMull'

        # RGB strings for the four combo colours cycled through below.
        combo_colors = {
            1: "255,81,81",  # Combo1
            2: "255,128,64",  # Combo2
            3: "128,64,0",  # Combo3
            4: "212,212,212"  # Combo4
        }
        current_combo = 1  # Number displayed on the current object (1, 2, ...)
        combo_counter = 0  # Index of the active colour in combo_colors (1..4)
        hit_objects = cg_beatmap.hit_objects(
            easy=bool(beatmap_request.mods & Mod.Easy.value),
            hard_rock=bool(beatmap_request.mods & Mod.HardRock.value),
            half_time=bool(beatmap_request.mods & Mod.HalfTime.value),
            double_time=bool(beatmap_request.mods & Mod.DoubleTime.value),
        )
        for obj in hit_objects:
            if obj.new_combo:
                combo_counter += 1  # Advance to the next combo colour
                if combo_counter > 4:
                    combo_counter = 1  # Wrap around after the 4th colour
                current_combo = 1  # A new combo restarts numbering at 1
            else:
                current_combo += 1  # Next number within the same combo
            # NOTE(review): if the first hit object does not set new_combo,
            # combo_counter is still 0 here and this lookup raises KeyError
            # -- confirm the parser always flags the first object.
            combo_color = combo_colors[combo_counter]
            if obj.type_code & Circle.type_code:
                circles.append({
                    "x": obj.position.x,
                    "y": obj.position.y,
                    # slider exposes times as timedelta; convert to ms.
                    "time": obj.time.total_seconds() * 1000,
                    "new_combo": obj.new_combo,
                    "combo_color": combo_color,
                    "current_combo": current_combo
                })
            elif obj.type_code & Slider.type_code:
                slider: Slider = obj
                sliders.append({
                    "x": slider.position.x,
                    "y": slider.position.y,
                    "time": slider.time.total_seconds() * 1000,
                    "end_time": slider.end_time.total_seconds() * 1000,
                    "curve": {
                        'type': map_slider_curve_type(slider),
                        'points': [{'x': p.x, 'y': p.y} for p in slider.curve.points]
                    },
                    "length": slider.length,
                    "new_combo": slider.new_combo,
                    "combo_color": combo_color,
                    "current_combo": current_combo,
                    "repeat": slider.repeat,
                })
            elif obj.type_code & Spinner.type_code:
                spinner: Spinner = obj
                spinners.append({
                    "x": spinner.position.x,
                    "y": spinner.position.y,
                    "time": spinner.time.total_seconds() * 1000,
                    "end_time": spinner.end_time.total_seconds() * 1000,
                    "new_combo": spinner.new_combo,
                    "combo_color": combo_color,
                    "current_combo": current_combo
                })
            # Reset current_combo if new_combo is True
            if obj.new_combo:
                current_combo = 1
        return jsonify(
            {
                "circles": circles,
                "sliders": sliders,
                "spinners": spinners,
                "difficulty": {
                    "hp_drain_rate": cg_beatmap.hp(),
                    "circle_size": cg_beatmap.cs(),
                    # NOTE(review): "overral_difficulty" is a typo for
                    # "overall_difficulty", but it is part of the public
                    # response schema -- coordinate with clients before fixing.
                    "overral_difficulty": cg_beatmap.od(),
                    "approach_rate": cg_beatmap.ar(),
                    "slider_multiplier": cg_beatmap.slider_multiplier,
                    "slider_tick_rate": cg_beatmap.slider_tick_rate
                },
                "audio_lead_in": cg_beatmap.audio_lead_in.total_seconds() * 1000,
            }
        )
    except ValueError as e:
        abort(400, description=str(e))
@dataclass
class ReplayRequest:
    """Payload of the /replay endpoint."""
    replay_data: str   # base64-encoded replay payload
    beatmap_data: str  # raw .osu beatmap text
    mods: int          # mod bitmask

    @staticmethod
    def from_dict(data):
        """Build a ReplayRequest from a JSON dict; raise ValueError on bad input."""
        try:
            replay_data = data['replay_data']
            beatmap_data = data['beatmap_data']
            mods = data['mods']
        except (ValueError, KeyError, TypeError) as e:
            raise ValueError(f"Invalid data format: {e}")
        return ReplayRequest(replay_data=replay_data, beatmap_data=beatmap_data, mods=mods)
@dataclass
class ReplayResponse:
    """Aggregated replay statistics returned by the /replay endpoint."""
    ur: float                   # unstable rate
    adjusted_ur: float          # unstable rate, outlier-adjusted (circleguard)
    frametime: float
    edge_hits: int              # hits within 1px of the circle edge
    snaps: int
    mean_error: float
    error_variance: float
    error_standard_deviation: float
    minimum_error: float
    maximum_error: float
    error_range: float
    error_coefficient_of_variation: float
    error_kurtosis: float
    error_skewness: float
    keypresses_times: List[int]
    keypresses_median: float
    keypresses_median_adjusted: float  # median after outlier filtering
    keypresses_standard_deviation: float
    keypresses_standard_deviation_adjusted: float
    sliderend_release_times: List[int]
    sliderend_release_median: float
    sliderend_release_median_adjusted: float
    sliderend_release_standard_deviation: float
    sliderend_release_standard_deviation_adjusted: float
    judgements: List[Hit]

    def to_dict(self):
        """Return the dataclass as a dict, mapping NaN floats to None
        (NaN is not representable in JSON)."""
        return {
            key: (None if isinstance(value, float) and isnan(value) else value)
            for key, value in asdict(self).items()
        }
@dataclass
class ScoreJudgement:
    """A single judged hit from circleguard, flattened for JSON output."""
    time: float             # hit time (presumably ms -- from circleguard Hit.time)
    x: float                # cursor x position at the hit
    y: float                # cursor y position at the hit
    type: str               # judgement type name (hit.type.name)
    distance_center: float  # cursor distance from the object's centre
    distance_edge: float    # cursor distance from the object's edge
    error: float            # temporal hit error (hit.error())
@app.post("/replay")
def process_replay():
    """Analyse a single replay against its beatmap.

    Expects JSON matching ``ReplayRequest`` (base64 replay data, raw .osu
    text and mod bitmask). Responds with a ``ReplayResponse``: unstable
    rate, frametime, edge hits, snaps, hit-error statistics,
    keypress/slider-release timing statistics and per-object judgements.
    Responds 400 on invalid input.
    """
    try:
        request_data = request.get_json()
        if not request_data:
            abort(400, description="Bad Request: No JSON data provided.")
        replay_request = ReplayRequest.from_dict(request_data)
        # Re-encode the incoming replay payload into an in-memory .osr
        # stream so circleguard can consume it as a ReplayString.
        memory_stream1 = io.BytesIO()
        stream_wrapper1 = WriteStreamWrapper(memory_stream1, stream_is_closable=False)
        stream_wrapper1.write_osr_data2(replay_request.replay_data, replay_request.mods)
        stream_wrapper1.end()
        result_bytes1 = memory_stream1.getvalue()
        replay1 = ReplayString(result_bytes1)
        cg_beatmap = Beatmap.parse(replay_request.beatmap_data)
        # Core circleguard metrics.
        ur = cg.ur(replay=replay1, beatmap=cg_beatmap)
        adjusted_ur = cg.ur(replay=replay1, beatmap=cg_beatmap, adjusted=True)
        frametime = cg.frametime(replay=replay1)
        edge_hits = sum(1 for _ in cg.hits(replay=replay1, within=1, beatmap=cg_beatmap))
        snaps = sum(1 for _ in cg.snaps(replay=replay1, beatmap=cg_beatmap))
        # Decode the base64 string
        decoded_data = base64.b64decode(replay_request.replay_data)
        # Pass the decoded data to the brparser Replay class for
        # keypress / slider-release extraction.
        replay = Replay(decoded_data, pure_lzma=True)
        replay.mods = Mod(replay_request.mods)
        # NOTE(review): _process_headers/_parse/_sort_objects are private
        # brparser APIs -- confirm they are stable across versions.
        beatmap = BeatmapOsu(None)
        beatmap._process_headers(replay_request.beatmap_data.splitlines())
        beatmap._parse(replay_request.beatmap_data.splitlines())
        beatmap._sort_objects()
        kp, se = get_kp_sliders(replay, beatmap)
        # Collect per-hit judgements (time, position, distances, error).
        hits: Iterable[Hit] = cg.hits(replay=replay1, beatmap=cg_beatmap)
        judgements: List[ScoreJudgement] = []
        for hit in hits:
            hit_obj = ScoreJudgement(
                time=float(hit.time),
                x=float(hit.x),
                y=float(hit.y),
                type=hit.type.name,
                distance_center=float(hit.distance(to='center')),
                distance_edge=float(hit.distance(to='edge')),
                error=float(hit.error())
            )
            judgements.append(hit_obj)
        # Aggregate hit-error statistics.
        # NOTE(review): np.min/np.max raise on an empty array -- a replay
        # with zero judged hits would fail here; confirm that cannot happen.
        errors = np.array([score.error for score in judgements])
        mean_error = np.mean(errors)
        error_variance = np.var(errors)
        error_std_dev = np.std(errors)
        min_error = np.min(errors)
        max_error = np.max(errors)
        error_range = max_error - min_error
        coefficient_of_variation = error_std_dev / mean_error if mean_error != 0 else None
        kurtosis = scipy.stats.kurtosis(errors)
        skewness = scipy.stats.skew(errors)
        ur_response = ReplayResponse(
            ur=ur,
            adjusted_ur=adjusted_ur,
            frametime=frametime,
            edge_hits=edge_hits,
            snaps=snaps,
            mean_error=mean_error,
            error_variance=error_variance,
            error_standard_deviation=error_std_dev,
            minimum_error=min_error,
            maximum_error=max_error,
            error_range=error_range,
            error_coefficient_of_variation=coefficient_of_variation,
            error_kurtosis=kurtosis,
            error_skewness=skewness,
            keypresses_times=kp,
            keypresses_median=np.median(kp),
            keypresses_median_adjusted=np.median(my_filter_outliers(kp)),
            # ddof=1: sample (not population) standard deviation.
            keypresses_standard_deviation=np.std(kp, ddof=1),
            keypresses_standard_deviation_adjusted=np.std(my_filter_outliers(kp), ddof=1),
            sliderend_release_times=se,
            sliderend_release_median=np.median(se),
            sliderend_release_median_adjusted=np.median(my_filter_outliers(se)),
            sliderend_release_standard_deviation=np.std(se, ddof=1),
            sliderend_release_standard_deviation_adjusted=np.std(my_filter_outliers(se), ddof=1),
            judgements=judgements
        )
        return jsonify(ur_response.to_dict())
    except ValueError as e:
        abort(400, description=str(e))
@dataclass
class ScoreSimilarity:
    """Pairwise comparison result for two replays (see /similarity)."""
    replay_id_1: int
    replay_id_2: int
    similarity: float   # cg.similarity(..., method='similarity')
    correlation: float  # cg.similarity(..., method='correlation')
@dataclass
class ReplayDto:
    """Shape of each entry in the /similarity request's 'replays' list.

    Field names are camelCase to match the incoming JSON keys.
    """
    replayId: int
    replayMods: int
    replayData: str  # replay payload (same format as ReplayRequest.replay_data)
@app.post("/similarity")
def process_similarity():
    """Compute pairwise similarity/correlation between submitted replays.

    Expects JSON ``{"replays": [{"replayId", "replayMods", "replayData"}, ...]}``
    and returns one ``ScoreSimilarity`` per unordered pair of distinct
    replay ids. Pairs whose replay data cannot be decoded are skipped.
    Responds 400 on invalid input.
    """
    try:
        request_data = request.get_json()
        if not request_data:
            abort(400, description="Bad Request: No JSON data provided.")
        replays: List[ReplayDto] = request_data['replays']
        replay_cache = {}
        response = []

        def get_or_create_replay(replay, cache):
            # Build a ReplayString for this replayId once and cache it;
            # return None when the replay cannot be decoded.
            try:
                replay_id = replay['replayId']
                if replay_id not in cache:
                    memory_stream = io.BytesIO()
                    stream_wrapper = WriteStreamWrapper(memory_stream, stream_is_closable=False)
                    stream_wrapper.write_osr_data2(replay['replayData'], replay['replayMods'])
                    stream_wrapper.end()
                    cache[replay_id] = ReplayString(memory_stream.getvalue())
                return cache[replay_id]
            # Was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt; narrow it and report the cause.
            except Exception as e:
                print(f'Error decoding replay: {e}', flush=True)
                return None

        for score1, score2 in combinations(replays, 2):
            # Guard against duplicate ids appearing in the request payload.
            if score1['replayId'] == score2['replayId']:
                continue
            replay1 = get_or_create_replay(score1, replay_cache)
            replay2 = get_or_create_replay(score2, replay_cache)
            if replay1 is None or replay2 is None:
                print('Error processing replay', flush=True)
                continue
            similarity = cg.similarity(replay1=replay1, replay2=replay2, method='similarity')
            correlation = cg.similarity(replay1=replay1, replay2=replay2, method='correlation')
            new_score_similarity = ScoreSimilarity(
                replay_id_1=score1['replayId'],
                replay_id_2=score2['replayId'],
                similarity=similarity,
                correlation=correlation
            )
            response.append(new_score_similarity)
        return jsonify({'result': response})
    except ValueError as e:
        abort(400, description=str(e))
if __name__ == "__main__":
    # Listen on all interfaces; Flask debug mode explicitly disabled.
    app.run(host='0.0.0.0', debug=False)