Initial commit - realms platform

This commit is contained in:
doomtube 2026-01-05 22:54:27 -05:00
parent c590ab6d18
commit c717c3751c
234 changed files with 74103 additions and 15231 deletions

View file

@ -15,6 +15,15 @@ find_package(PostgreSQL REQUIRED)
pkg_check_modules(HIREDIS REQUIRED hiredis)
pkg_check_modules(REDIS_PLUS_PLUS redis++)
# Find libzip for EPUB cover extraction
pkg_check_modules(LIBZIP REQUIRED libzip)
# Find GPGME for PGP signature verification
pkg_check_modules(GPGME REQUIRED gpgme)
# Find OpenSSL for cryptographic operations (SECURITY FIX #4)
find_package(OpenSSL REQUIRED)
# Manual fallback for redis++
if(NOT REDIS_PLUS_PLUS_FOUND)
find_path(REDIS_PLUS_PLUS_INCLUDE_DIR sw/redis++/redis++.h
@ -61,42 +70,58 @@ set(SOURCES
src/controllers/UserController.cpp
src/controllers/AdminController.cpp
src/controllers/RealmController.cpp
src/controllers/RestreamController.cpp
src/controllers/VideoController.cpp
src/controllers/AudioController.cpp
src/controllers/EbookController.cpp
src/controllers/ForumController.cpp
src/controllers/WatchController.cpp
src/services/DatabaseService.cpp
src/services/StatsService.cpp
src/services/RedisHelper.cpp
src/services/AuthService.cpp
src/services/RestreamService.cpp
src/services/CensorService.cpp
src/services/TreasuryService.cpp
)
# Create executable
add_executable(${PROJECT_NAME} ${SOURCES})
# Include directories
target_include_directories(${PROJECT_NAME}
PRIVATE
target_include_directories(${PROJECT_NAME}
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src
${BCRYPT_INCLUDE_DIR}
${JWT_CPP_INCLUDE_DIR}
SYSTEM PRIVATE
${HIREDIS_INCLUDE_DIRS}
${REDIS_PLUS_PLUS_INCLUDE_DIRS}
${LIBZIP_INCLUDE_DIRS}
${GPGME_INCLUDE_DIRS}
)
# Link libraries
target_link_libraries(${PROJECT_NAME}
target_link_libraries(${PROJECT_NAME}
PRIVATE
Drogon::Drogon
PostgreSQL::PostgreSQL
${REDIS_PLUS_PLUS_LIBRARIES}
${HIREDIS_LIBRARIES}
${BCRYPT_LIBRARY}
${LIBZIP_LIBRARIES}
${GPGME_LIBRARIES}
OpenSSL::SSL
OpenSSL::Crypto
pthread
)
# Compile options
target_compile_options(${PROJECT_NAME}
target_compile_options(${PROJECT_NAME}
PRIVATE
${HIREDIS_CFLAGS_OTHER}
${REDIS_PLUS_PLUS_CFLAGS_OTHER}
${GPGME_CFLAGS_OTHER}
-Wall
-Wextra
-Wpedantic

View file

@ -2,7 +2,7 @@ FROM drogonframework/drogon:latest
WORKDIR /app
# Install additional dependencies including GPG for PGP verification
# Install additional dependencies including GPGME for PGP verification, FFmpeg for thumbnails, and libzip for EPUB
RUN apt-get update && apt-get install -y \
libpq-dev \
postgresql-client \
@ -14,6 +14,9 @@ RUN apt-get update && apt-get install -y \
libssl-dev \
gnupg \
gnupg2 \
libgpgme-dev \
ffmpeg \
libzip-dev \
&& rm -rf /var/lib/apt/lists/*
# Try to install redis-plus-plus from package manager first
@ -85,7 +88,7 @@ COPY config.json .
# Create uploads directory with proper permissions
# Using nobody user's UID/GID (65534) for consistency with nginx
RUN mkdir -p /app/uploads/avatars && \
RUN mkdir -p /app/uploads/avatars /app/uploads/stickers /app/uploads/sticker-submissions /app/uploads/videos /app/uploads/logo /app/uploads/ebooks /app/uploads/ebooks/covers /app/uploads/forums && \
chown -R 65534:65534 /app/uploads && \
chmod -R 755 /app/uploads
@ -102,8 +105,10 @@ echo "Checking library dependencies..."\n\
ldd ./build/streaming-backend\n\
echo "Checking GPG installation..."\n\
gpg --version\n\
echo "Checking FFmpeg installation..."\n\
ffmpeg -version | head -1\n\
echo "Ensuring upload directories exist with proper permissions..."\n\
mkdir -p /app/uploads/avatars\n\
mkdir -p /app/uploads/avatars /app/uploads/stickers /app/uploads/sticker-submissions /app/uploads/videos /app/uploads/logo /app/uploads/ebooks /app/uploads/ebooks/covers /app/uploads/forums\n\
chown -R 65534:65534 /app/uploads\n\
chmod -R 755 /app/uploads\n\
echo "Starting application..."\n\

View file

@ -14,7 +14,7 @@
"port": 5432,
"dbname": "streaming",
"user": "streamuser",
"passwd": "streampass",
"passwd": "CHANGE_ME_database_password",
"is_fast": false,
"connection_number": 10
}
@ -28,16 +28,27 @@
"client_max_body_size": "100M",
"enable_brotli": true,
"enable_gzip": true,
"log_level": "DEBUG"
"log_level": "INFO"
},
"redis": {
"host": "redis",
"port": 6379
"port": 6379,
"db": 1,
"timeout": 5
},
"ome": {
"api_url": "http://ovenmediaengine:8081",
"api_token": "your-api-token"
"api_token": "CHANGE_ME_ome_api_token"
},
"plugins": [],
"custom_config": {}
}
"custom_config": {
"chat": {
"default_retention_hours": 24,
"max_message_length": 500,
"max_messages_per_realm": 1000,
"guest_prefix": "guest",
"guest_id_pattern": "{prefix}{number}",
"cleanup_interval_seconds": 300
}
}
}

View file

@ -1,11 +1,54 @@
#include <drogon/drogon.h>
#include <drogon/orm/DbClient.h>
#include <iostream>
#include <fstream>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <sstream>
using namespace drogon;
using namespace drogon::orm;
// SECURITY FIX #32: Add audit logging for admin CLI operations
namespace {
void writeAuditLog(const std::string& action, const std::string& target,
const std::string& status, const std::string& details = "") {
// Get current timestamp
auto now = std::time(nullptr);
auto tm = *std::localtime(&now);
std::ostringstream timestamp;
timestamp << std::put_time(&tm, "%Y-%m-%d %H:%M:%S");
// Get hostname for audit trail
char hostname[256] = "unknown";
gethostname(hostname, sizeof(hostname));
// Build log entry
std::ostringstream logEntry;
logEntry << "[" << timestamp.str() << "] "
<< "HOST=" << hostname << " "
<< "ACTION=" << action << " "
<< "TARGET=" << target << " "
<< "STATUS=" << status;
if (!details.empty()) {
logEntry << " DETAILS=" << details;
}
logEntry << std::endl;
// Write to audit log file
std::string logPath = "/var/log/admin_tool_audit.log";
std::ofstream logFile(logPath, std::ios::app);
if (logFile.is_open()) {
logFile << logEntry.str();
logFile.close();
}
// Also output to stderr for immediate visibility
std::cerr << "[AUDIT] " << logEntry.str();
}
}
int main(int argc, char* argv[]) {
if (argc < 2) {
std::cerr << "Usage: " << argv[0] << " -promote-admin <username>" << std::endl;
@ -34,37 +77,44 @@ int main(int argc, char* argv[]) {
1 // connection number
);
writeAuditLog("PROMOTE_ADMIN_ATTEMPT", username, "STARTED");
try {
// Check if user exists
auto result = dbClient->execSqlSync(
"SELECT id, username, is_admin FROM users WHERE username = $1",
username
);
if (result.empty()) {
writeAuditLog("PROMOTE_ADMIN", username, "FAILED", "user_not_found");
std::cerr << "Error: User '" << username << "' not found." << std::endl;
return 1;
}
bool isAdmin = result[0]["is_admin"].as<bool>();
if (isAdmin) {
writeAuditLog("PROMOTE_ADMIN", username, "SKIPPED", "already_admin");
std::cout << "User '" << username << "' is already an admin." << std::endl;
return 0;
}
// Promote to admin
dbClient->execSqlSync(
"UPDATE users SET is_admin = true WHERE username = $1",
username
);
writeAuditLog("PROMOTE_ADMIN", username, "SUCCESS");
std::cout << "Successfully promoted '" << username << "' to admin." << std::endl;
return 0;
} catch (const DrogonDbException& e) {
writeAuditLog("PROMOTE_ADMIN", username, "ERROR", e.base().what());
std::cerr << "Database error: " << e.base().what() << std::endl;
return 1;
} catch (const std::exception& e) {
writeAuditLog("PROMOTE_ADMIN", username, "ERROR", e.what());
std::cerr << "Error: " << e.what() << std::endl;
return 1;
}

View file

@ -0,0 +1,82 @@
#pragma once
#include <drogon/HttpRequest.h>
#include "../services/AuthService.h"
#include "HttpHelpers.h"
#include <optional>
using namespace drogon;
// Extract user from request (cookie or Bearer token)
inline UserInfo getUserFromRequest(const HttpRequestPtr& req) {
UserInfo user;
std::string token = req->getCookie("auth_token");
if (token.empty()) {
std::string auth = req->getHeader("Authorization");
if (!auth.empty() && auth.substr(0, 7) == "Bearer ") {
token = auth.substr(7);
}
}
if (!token.empty()) {
AuthService::getInstance().validateToken(token, user);
}
return user;
}
// Authorization check macro - returns 401 if not authenticated
// SECURITY FIX #26: Also checks if account is disabled
#define CHECK_AUTH(user, callback) \
if (user.id == 0 || user.isDisabled) { \
callback(jsonError("Unauthorized", k401Unauthorized)); \
return; \
}
// Admin check macro - returns 403 if not admin
// SECURITY FIX #26: Also checks if account is disabled
#define CHECK_ADMIN(user, callback) \
if (user.id == 0 || !user.isAdmin || user.isDisabled) { \
callback(jsonError("Admin access required", k403Forbidden)); \
return; \
}
// Parse ID from string with error handling
template<typename T = int64_t>
std::optional<T> parseId(const std::string& idStr) {
try {
return static_cast<T>(std::stoll(idStr));
} catch (...) {
return std::nullopt;
}
}
// ID parsing macro - returns 400 if invalid
#define PARSE_ID(varName, idStr, callback) \
auto varName##_opt = parseId(idStr); \
if (!varName##_opt) { \
callback(jsonError("Invalid ID")); \
return; \
} \
auto varName = *varName##_opt;
#define PARSE_ID_MSG(varName, idStr, callback, errorMsg) \
auto varName##_opt = parseId(idStr); \
if (!varName##_opt) { \
callback(jsonError(errorMsg)); \
return; \
} \
auto varName = *varName##_opt;
// Database error handler macro
#define DB_ERROR(callback, operation) \
[callback](const DrogonDbException& e) { \
LOG_ERROR << "Failed to " operation ": " << e.base().what(); \
callback(jsonError("Database error")); \
}
// Database error handler with custom error message
#define DB_ERROR_MSG(callback, operation, errorMsg) \
[callback](const DrogonDbException& e) { \
LOG_ERROR << "Failed to " operation ": " << e.base().what(); \
callback(jsonError(errorMsg)); \
}

View file

@ -0,0 +1,253 @@
#pragma once
#include <string>
#include <vector>
#include <random>
#include <sstream>
#include <iomanip>
#include <filesystem>
#include <array>
#include <unistd.h>
#include <sys/wait.h>
#include <fcntl.h>
#include <trantor/utils/Logger.h>
// Generate cryptographically secure random hex filename with extension
// Uses /dev/urandom for secure randomness instead of std::mt19937
inline std::string generateRandomFilename(const std::string& ext) {
std::array<unsigned char, 16> bytes;
// Read from /dev/urandom for cryptographically secure randomness
int fd = open("/dev/urandom", O_RDONLY);
if (fd >= 0) {
ssize_t bytesRead = read(fd, bytes.data(), bytes.size());
close(fd);
if (bytesRead == static_cast<ssize_t>(bytes.size())) {
std::stringstream ss;
for (unsigned char b : bytes) {
ss << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(b);
}
return ss.str() + "." + ext;
}
}
// Fallback to std::random_device if /dev/urandom fails
// (shouldn't happen on Linux, but provides resilience)
std::random_device rd;
std::stringstream ss;
for (int i = 0; i < 16; ++i) {
ss << std::hex << std::setw(2) << std::setfill('0') << (rd() & 0xFF);
}
return ss.str() + "." + ext;
}
// Atomically create a file with exclusive access (O_CREAT | O_EXCL)
// Returns the file descriptor on success, or -1 on failure
// This prevents TOCTOU race conditions
inline int createFileExclusive(const std::string& path) {
return open(path.c_str(), O_WRONLY | O_CREAT | O_EXCL, 0644);
}
// Write data to file atomically with exclusive creation
// Returns true on success, false on failure
inline bool writeFileExclusive(const std::string& path, const char* data, size_t size) {
int fd = createFileExclusive(path);
if (fd < 0) {
return false;
}
ssize_t written = 0;
while (written < static_cast<ssize_t>(size)) {
ssize_t result = write(fd, data + written, size - written);
if (result < 0) {
close(fd);
unlink(path.c_str()); // Clean up on error
return false;
}
written += result;
}
close(fd);
return true;
}
// ============== INPUT SANITIZATION ==============
// Check if a string is valid UTF-8
inline bool isValidUtf8(const std::string& str) {
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
size_t len = str.length();
size_t i = 0;
while (i < len) {
if (bytes[i] <= 0x7F) {
// ASCII
i++;
} else if ((bytes[i] & 0xE0) == 0xC0) {
// 2-byte sequence
if (i + 1 >= len || (bytes[i + 1] & 0xC0) != 0x80) return false;
i += 2;
} else if ((bytes[i] & 0xF0) == 0xE0) {
// 3-byte sequence
if (i + 2 >= len || (bytes[i + 1] & 0xC0) != 0x80 || (bytes[i + 2] & 0xC0) != 0x80) return false;
i += 3;
} else if ((bytes[i] & 0xF8) == 0xF0) {
// 4-byte sequence
if (i + 3 >= len || (bytes[i + 1] & 0xC0) != 0x80 || (bytes[i + 2] & 0xC0) != 0x80 || (bytes[i + 3] & 0xC0) != 0x80) return false;
i += 4;
} else {
return false;
}
}
return true;
}
// Remove control characters (C0 and C1 control codes) from a string
// Keeps printable characters, newlines, and tabs
inline std::string sanitizeText(const std::string& input) {
std::string result;
result.reserve(input.size());
for (unsigned char c : input) {
// Allow printable ASCII (0x20-0x7E), tab (0x09), newline (0x0A), carriage return (0x0D)
// Also allow UTF-8 continuation bytes (0x80-0xFF) which are handled by isValidUtf8
if ((c >= 0x20 && c <= 0x7E) || c == 0x09 || c == 0x0A || c == 0x0D || c >= 0x80) {
result += c;
}
// Skip C0 control characters (0x00-0x1F except tab/newline/CR)
// Skip DEL (0x7F)
}
return result;
}
// Sanitize user input text: validate UTF-8 and remove control characters
inline std::string sanitizeUserInput(const std::string& input, size_t maxLength = 0) {
// First validate UTF-8
if (!isValidUtf8(input)) {
// If not valid UTF-8, try to salvage by keeping only ASCII
std::string ascii;
for (char c : input) {
if (c >= 0x20 && c <= 0x7E) {
ascii += c;
}
}
if (maxLength > 0 && ascii.length() > maxLength) {
ascii = ascii.substr(0, maxLength);
}
return ascii;
}
// Remove control characters
std::string sanitized = sanitizeText(input);
// Truncate if needed
if (maxLength > 0 && sanitized.length() > maxLength) {
// Try to truncate at UTF-8 character boundary
size_t truncateAt = maxLength;
while (truncateAt > 0 && (sanitized[truncateAt] & 0xC0) == 0x80) {
truncateAt--;
}
sanitized = sanitized.substr(0, truncateAt);
}
return sanitized;
}
// Ensure directory exists, optionally with permissions
inline bool ensureDirectoryExists(const std::string& path, bool setPermissions = false) {
try {
if (!std::filesystem::exists(path)) {
std::filesystem::create_directories(path);
}
if (setPermissions) {
std::filesystem::permissions(path,
std::filesystem::perms::owner_all |
std::filesystem::perms::group_read | std::filesystem::perms::group_exec |
std::filesystem::perms::others_read | std::filesystem::perms::others_exec
);
}
return true;
} catch (const std::exception& e) {
LOG_ERROR << "Failed to create directory " << path << ": " << e.what();
return false;
}
}
// Validate path is safe (no shell metacharacters, within allowed directory)
inline bool isPathSafe(const std::string& path, const std::string& allowedDir) {
const std::string dangerous = ";|&$`\\\"'<>(){}[]!#*?~";
for (char c : path) {
if (dangerous.find(c) != std::string::npos) {
LOG_WARN << "Rejected path with dangerous character: " << c;
return false;
}
}
try {
auto canonical = std::filesystem::canonical(path);
auto allowedCanonical = std::filesystem::canonical(allowedDir);
if (canonical.string().find(allowedCanonical.string()) != 0) {
LOG_WARN << "Path traversal attempt: " << path;
return false;
}
} catch (...) {
auto parent = std::filesystem::path(path).parent_path();
if (!std::filesystem::exists(parent)) return false;
try {
auto parentCanonical = std::filesystem::canonical(parent);
auto allowedCanonical = std::filesystem::canonical(allowedDir);
if (parentCanonical.string().find(allowedCanonical.string()) != 0) {
LOG_WARN << "Path traversal attempt in parent: " << path;
return false;
}
} catch (...) {
return false;
}
}
return true;
}
// Execute command safely using execv (no shell interpretation)
inline std::string execCommandSafe(const std::vector<std::string>& args) {
if (args.empty()) return "";
int pipefd[2];
if (pipe(pipefd) == -1) return "";
pid_t pid = fork();
if (pid == -1) {
close(pipefd[0]);
close(pipefd[1]);
return "";
}
if (pid == 0) {
close(pipefd[0]);
dup2(pipefd[1], STDOUT_FILENO);
dup2(open("/dev/null", O_WRONLY), STDERR_FILENO);
close(pipefd[1]);
std::vector<char*> argv;
for (const auto& arg : args) {
argv.push_back(const_cast<char*>(arg.c_str()));
}
argv.push_back(nullptr);
execv(args[0].c_str(), argv.data());
_exit(1);
}
close(pipefd[1]);
std::string result;
std::array<char, 128> buffer;
ssize_t n;
while ((n = read(pipefd[0], buffer.data(), buffer.size())) > 0) {
result.append(buffer.data(), n);
}
close(pipefd[0]);
int status;
waitpid(pid, &status, 0);
return result;
}

View file

@ -0,0 +1,218 @@
#pragma once
#include <string>
#include <cstddef>
// ============== IMAGE VALIDATION ==============
struct ImageValidation {
bool valid;
std::string detectedType; // "jpeg", "png", "gif", "webp", "svg"
std::string extension; // ".jpg", ".png", ".gif", ".webp", ".svg"
};
inline ImageValidation validateImageMagicBytes(const char* data, size_t size, bool allowSvg = false) {
ImageValidation result{false, "", ""};
if (!data || size < 3) return result;
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(data);
// JPEG: FF D8 FF
if (size >= 3 && bytes[0] == 0xFF && bytes[1] == 0xD8 && bytes[2] == 0xFF) {
return {true, "jpeg", ".jpg"};
}
// PNG: 89 50 4E 47 0D 0A 1A 0A
if (size >= 8 && bytes[0] == 0x89 && bytes[1] == 0x50 && bytes[2] == 0x4E && bytes[3] == 0x47 &&
bytes[4] == 0x0D && bytes[5] == 0x0A && bytes[6] == 0x1A && bytes[7] == 0x0A) {
return {true, "png", ".png"};
}
// GIF: 47 49 46 38 (GIF87a or GIF89a)
if (size >= 6 && bytes[0] == 0x47 && bytes[1] == 0x49 && bytes[2] == 0x46 && bytes[3] == 0x38 &&
(bytes[4] == 0x37 || bytes[4] == 0x39) && bytes[5] == 0x61) {
return {true, "gif", ".gif"};
}
// WebP: RIFF....WEBP
if (size >= 12 && bytes[0] == 0x52 && bytes[1] == 0x49 && bytes[2] == 0x46 && bytes[3] == 0x46 &&
bytes[8] == 0x57 && bytes[9] == 0x45 && bytes[10] == 0x42 && bytes[11] == 0x50) {
return {true, "webp", ".webp"};
}
// SVG: Check for <?xml or <svg (with optional BOM/whitespace)
if (allowSvg && size >= 4) {
std::string content(data, std::min(size, (size_t)256));
size_t start = 0;
// Skip BOM if present
if (size >= 3 && bytes[0] == 0xEF && bytes[1] == 0xBB && bytes[2] == 0xBF) {
start = 3;
}
// Skip whitespace
while (start < content.size() && std::isspace(static_cast<unsigned char>(content[start]))) {
start++;
}
std::string trimmed = content.substr(start);
if (trimmed.rfind("<?xml", 0) == 0 || trimmed.rfind("<svg", 0) == 0) {
return {true, "svg", ".svg"};
}
}
return result;
}
// ============== VIDEO VALIDATION ==============
struct VideoValidation {
bool valid;
std::string detectedType; // "mp4", "webm", "mov"
std::string extension; // ".mp4", ".webm", ".mov"
};
inline VideoValidation validateVideoMagicBytes(const char* data, size_t size) {
VideoValidation result{false, "", ""};
if (!data || size < 12) return result;
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(data);
// MP4/M4V/MOV: Check for ftyp atom
if (size >= 12 && bytes[4] == 0x66 && bytes[5] == 0x74 && bytes[6] == 0x79 && bytes[7] == 0x70) {
char brand[5] = {(char)bytes[8], (char)bytes[9], (char)bytes[10], (char)bytes[11], 0};
std::string brandStr(brand);
if (brandStr == "qt " || brandStr.substr(0, 2) == "qt") {
return {true, "mov", ".mov"};
}
return {true, "mp4", ".mp4"};
}
// WebM/MKV: EBML header
if (size >= 4 && bytes[0] == 0x1A && bytes[1] == 0x45 && bytes[2] == 0xDF && bytes[3] == 0xA3) {
return {true, "webm", ".webm"};
}
return result;
}
// ============== AUDIO VALIDATION ==============
struct AudioValidation {
bool valid;
std::string detectedType; // "mp3", "wav", "flac", "ogg", "aac", "m4a"
std::string extension;
};
inline AudioValidation validateAudioMagicBytes(const char* data, size_t size) {
AudioValidation result{false, "", ""};
if (!data || size < 4) return result;
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(data);
// MP3: ID3 tag or frame sync
if (size >= 3 && bytes[0] == 0x49 && bytes[1] == 0x44 && bytes[2] == 0x33) {
return {true, "mp3", ".mp3"};
}
if (size >= 2 && bytes[0] == 0xFF &&
(bytes[1] == 0xFB || bytes[1] == 0xFA || bytes[1] == 0xF3 || bytes[1] == 0xF2)) {
return {true, "mp3", ".mp3"};
}
// WAV: RIFF....WAVE
if (size >= 12 && bytes[0] == 0x52 && bytes[1] == 0x49 && bytes[2] == 0x46 && bytes[3] == 0x46 &&
bytes[8] == 0x57 && bytes[9] == 0x41 && bytes[10] == 0x56 && bytes[11] == 0x45) {
return {true, "wav", ".wav"};
}
// FLAC: fLaC
if (size >= 4 && bytes[0] == 0x66 && bytes[1] == 0x4C && bytes[2] == 0x61 && bytes[3] == 0x43) {
return {true, "flac", ".flac"};
}
// OGG: OggS
if (size >= 4 && bytes[0] == 0x4F && bytes[1] == 0x67 && bytes[2] == 0x67 && bytes[3] == 0x53) {
return {true, "ogg", ".ogg"};
}
// AAC: ADTS frame sync
if (size >= 2 && bytes[0] == 0xFF && (bytes[1] == 0xF1 || bytes[1] == 0xF9)) {
return {true, "aac", ".aac"};
}
// M4A/AAC in MP4 container
if (size >= 12 && bytes[4] == 0x66 && bytes[5] == 0x74 && bytes[6] == 0x79 && bytes[7] == 0x70) {
char brand[5] = {(char)bytes[8], (char)bytes[9], (char)bytes[10], (char)bytes[11], 0};
std::string brandStr(brand);
if (brandStr == "M4A " || brandStr == "mp42" || brandStr == "isom" || brandStr == "M4B ") {
return {true, "m4a", ".m4a"};
}
}
return result;
}
// ============== EPUB VALIDATION ==============
// Properly validates EPUB format according to spec:
// - Must be a valid ZIP file
// - First entry must be named "mimetype" (uncompressed)
// - Content must be exactly "application/epub+zip"
inline bool isValidEpub(const char* data, size_t size) {
if (!data || size < 58) return false;
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(data);
// Must be ZIP: PK\x03\x04 local file header signature
if (bytes[0] != 0x50 || bytes[1] != 0x4B || bytes[2] != 0x03 || bytes[3] != 0x04) {
return false;
}
// Parse ZIP local file header
// Offset 8: compression method (2 bytes, little-endian)
// For EPUB, mimetype must be uncompressed (method 0)
uint16_t compressionMethod = bytes[8] | (bytes[9] << 8);
if (compressionMethod != 0) {
return false; // mimetype must be stored uncompressed
}
// Offset 18: compressed size (4 bytes, little-endian)
uint32_t compressedSize = bytes[18] | (bytes[19] << 8) | (bytes[20] << 16) | (bytes[21] << 24);
// Offset 22: uncompressed size (4 bytes, little-endian)
uint32_t uncompressedSize = bytes[22] | (bytes[23] << 8) | (bytes[24] << 16) | (bytes[25] << 24);
// For uncompressed files, these should be equal
if (compressedSize != uncompressedSize) {
return false;
}
// Offset 26: file name length (2 bytes, little-endian)
uint16_t fileNameLen = bytes[26] | (bytes[27] << 8);
// Offset 28: extra field length (2 bytes, little-endian)
uint16_t extraFieldLen = bytes[28] | (bytes[29] << 8);
// File name starts at offset 30
if (size < 30 + fileNameLen + extraFieldLen + compressedSize) {
return false; // File too small
}
// Check that first entry is named "mimetype"
if (fileNameLen != 8) {
return false;
}
std::string fileName(data + 30, 8);
if (fileName != "mimetype") {
return false;
}
// Content starts after header + filename + extra field
size_t contentOffset = 30 + fileNameLen + extraFieldLen;
// Check content is exactly "application/epub+zip"
const std::string expectedMimetype = "application/epub+zip";
if (compressedSize != expectedMimetype.size()) {
return false;
}
std::string actualMimetype(data + contentOffset, compressedSize);
return actualMimetype == expectedMimetype;
}

View file

@ -0,0 +1,29 @@
#pragma once
#include <drogon/HttpResponse.h>
#include <json/json.h>
using namespace drogon;
inline HttpResponsePtr jsonResp(const Json::Value& j, HttpStatusCode c = k200OK) {
auto r = HttpResponse::newHttpJsonResponse(j);
r->setStatusCode(c);
return r;
}
inline HttpResponsePtr jsonError(const std::string& error, HttpStatusCode code = k400BadRequest) {
Json::Value j;
j["success"] = false;
j["error"] = error;
return jsonResp(j, code);
}
inline HttpResponsePtr jsonSuccess(const Json::Value& data = Json::Value()) {
Json::Value j;
j["success"] = true;
if (!data.isNull()) {
for (const auto& key : data.getMemberNames()) {
j[key] = data[key];
}
}
return jsonResp(j);
}

View file

@ -1,243 +0,0 @@
#pragma once
#include <drogon/drogon.h>
#include <drogon/orm/DbClient.h>
#include <variant>
#include <optional>
#include <concepts>
namespace utils {
using namespace drogon;
// Result type for better error handling
template<typename T>
using Result = std::variant<T, std::string>;
template<typename T>
inline bool isOk(const Result<T>& r) { return std::holds_alternative<T>(r); }
template<typename T>
inline T& getValue(Result<T>& r) { return std::get<T>(r); }
template<typename T>
inline const std::string& getError(const Result<T>& r) { return std::get<std::string>(r); }
// JSON Response helpers
inline HttpResponsePtr jsonOk(const Json::Value& data) {
return HttpResponse::newHttpJsonResponse(data);
}
inline HttpResponsePtr jsonError(const std::string& error, HttpStatusCode code = k400BadRequest) {
Json::Value json;
json["error"] = error;
json["success"] = false;
auto resp = HttpResponse::newHttpJsonResponse(json);
resp->setStatusCode(code);
return resp;
}
inline HttpResponsePtr jsonResp(const Json::Value& data, HttpStatusCode code = k200OK) {
auto resp = HttpResponse::newHttpJsonResponse(data);
resp->setStatusCode(code);
return resp;
}
// Database helper with automatic error handling
template<typename... Args>
inline void dbQuery(const std::string& query,
std::function<void(const drogon::orm::Result&)> onSuccess,
std::function<void(const std::string&)> onError,
Args&&... args) {
auto db = app().getDbClient();
(*db << query << std::forward<Args>(args)...)
>> [onSuccess](const drogon::orm::Result& r) { onSuccess(r); }
>> [onError](const drogon::orm::DrogonDbException& e) {
LOG_ERROR << "DB Error: " << e.base().what();
onError(e.base().what());
};
}
// Simplified DB query that returns JSON response
template<typename... Args>
inline void dbJsonQuery(const std::string& query,
std::function<Json::Value(const drogon::orm::Result&)> transform,
std::function<void(const HttpResponsePtr&)> callback,
Args&&... args) {
dbQuery(query,
[transform, callback](const drogon::orm::Result& r) {
callback(jsonOk(transform(r)));
},
[callback](const std::string& error) {
callback(jsonError(error, k500InternalServerError));
},
std::forward<Args>(args)...
);
}
// Thread pool executor with type constraints
template<typename F>
requires std::invocable<F>
inline void runAsync(F&& task) {
if (auto loop = app().getLoop()) {
loop->queueInLoop([task = std::forward<F>(task)]() {
std::thread([task]() {
try {
task();
} catch (const std::exception& e) {
LOG_ERROR << "Async task error: " << e.what();
}
}).detach();
});
} else {
// Fallback to sync execution
task();
}
}
// Config helper
template<typename T>
inline std::optional<T> getConfig(const std::string& path) {
try {
const auto& config = app().getCustomConfig();
std::vector<std::string> parts;
std::stringstream ss(path);
std::string part;
while (std::getline(ss, part, '.')) {
parts.push_back(part);
}
Json::Value current = config;
for (const auto& p : parts) {
if (!current.isMember(p)) return std::nullopt;
current = current[p];
}
if constexpr (std::is_same_v<T, std::string>) {
return current.asString();
} else if constexpr (std::is_same_v<T, int>) {
return current.asInt();
} else if constexpr (std::is_same_v<T, bool>) {
return current.asBool();
} else if constexpr (std::is_same_v<T, double>) {
return current.asDouble();
}
} catch (...) {
return std::nullopt;
}
return std::nullopt;
}
// Environment variable helper with fallback
template<typename T = std::string>
inline T getEnv(const std::string& key, const T& defaultValue = T{}) {
const char* val = std::getenv(key.c_str());
if (!val) return defaultValue;
if constexpr (std::is_same_v<T, std::string>) {
return std::string(val);
} else if constexpr (std::is_same_v<T, int>) {
try { return std::stoi(val); } catch (...) { return defaultValue; }
} else if constexpr (std::is_same_v<T, bool>) {
std::string s(val);
return s == "true" || s == "1" || s == "yes";
}
return defaultValue;
}
// Random string generator
inline std::string randomString(size_t length = 32) {
auto bytes = drogon::utils::genRandomString(length);
return drogon::utils::base64Encode(
reinterpret_cast<const unsigned char*>(bytes.data()),
bytes.length()
);
}
// Timer helper
class ScopedTimer {
std::string name_;
std::chrono::steady_clock::time_point start_;
public:
explicit ScopedTimer(const std::string& name)
: name_(name), start_(std::chrono::steady_clock::now()) {}
~ScopedTimer() {
auto duration = std::chrono::steady_clock::now() - start_;
auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(duration).count();
LOG_DEBUG << name_ << " took " << ms << "ms";
}
};
#define TIMED_SCOPE(name) utils::ScopedTimer _timer(name)
// WebSocket broadcast helper
template<typename Container>
inline void wsBroadcast(const Container& connections, const Json::Value& message) {
auto msg = Json::FastWriter().write(message);
for (const auto& conn : connections) {
if (conn->connected()) {
conn->send(msg);
}
}
}
// Rate limiter
class RateLimiter {
std::unordered_map<std::string, std::deque<std::chrono::steady_clock::time_point>> requests_;
std::mutex mutex_;
size_t maxRequests_;
std::chrono::seconds window_;
public:
RateLimiter(size_t maxRequests = 10, std::chrono::seconds window = std::chrono::seconds(60))
: maxRequests_(maxRequests), window_(window) {}
bool allow(const std::string& key) {
std::lock_guard<std::mutex> lock(mutex_);
auto now = std::chrono::steady_clock::now();
auto& timestamps = requests_[key];
// Remove old timestamps
while (!timestamps.empty() && now - timestamps.front() > window_) {
timestamps.pop_front();
}
if (timestamps.size() >= maxRequests_) {
return false;
}
timestamps.push_back(now);
return true;
}
};
// Global rate limiter instance
inline RateLimiter& rateLimiter() {
static RateLimiter limiter;
return limiter;
}
// Validation helpers
inline bool isValidStreamKey(const std::string& key) {
return key.length() == 32 &&
std::all_of(key.begin(), key.end(), [](char c) {
return std::isxdigit(c);
});
}
// JSON conversion helper
template<typename T>
inline Json::Value toJson(const T& obj);
// Specializations for common types
template<>
inline Json::Value toJson(const std::map<std::string, std::string>& m) {
Json::Value json;
for (const auto& [k, v] : m) {
json[k] = v;
}
return json;
}
} // namespace utils

File diff suppressed because it is too large Load diff

View file

@ -12,6 +12,67 @@ public:
ADD_METHOD_TO(AdminController::disconnectStream, "/api/admin/streams/{1}/disconnect", Post);
ADD_METHOD_TO(AdminController::promoteToStreamer, "/api/admin/users/{1}/promote", Post);
ADD_METHOD_TO(AdminController::demoteFromStreamer, "/api/admin/users/{1}/demote", Post);
ADD_METHOD_TO(AdminController::promoteToRestreamer, "/api/admin/users/{1}/promote-restreamer", Post);
ADD_METHOD_TO(AdminController::demoteFromRestreamer, "/api/admin/users/{1}/demote-restreamer", Post);
ADD_METHOD_TO(AdminController::promoteToBot, "/api/admin/users/{1}/promote-bot", Post);
ADD_METHOD_TO(AdminController::demoteFromBot, "/api/admin/users/{1}/demote-bot", Post);
ADD_METHOD_TO(AdminController::getAllBotApiKeys, "/api/admin/bot-keys", Get);
ADD_METHOD_TO(AdminController::deleteBotApiKey, "/api/admin/bot-keys/{1}", Delete);
ADD_METHOD_TO(AdminController::uploadStickers, "/api/admin/stickers/upload", Post);
ADD_METHOD_TO(AdminController::getStickers, "/api/admin/stickers", Get);
ADD_METHOD_TO(AdminController::deleteSticker, "/api/admin/stickers/{1}", Delete);
ADD_METHOD_TO(AdminController::renameSticker, "/api/admin/stickers/{1}/rename", Put);
ADD_METHOD_TO(AdminController::promoteToStickerCreator, "/api/admin/users/{1}/promote-sticker-creator", Post);
ADD_METHOD_TO(AdminController::demoteFromStickerCreator, "/api/admin/users/{1}/demote-sticker-creator", Post);
ADD_METHOD_TO(AdminController::promoteToUploader, "/api/admin/users/{1}/promote-uploader", Post);
ADD_METHOD_TO(AdminController::demoteFromUploader, "/api/admin/users/{1}/demote-uploader", Post);
ADD_METHOD_TO(AdminController::promoteToTexter, "/api/admin/users/{1}/promote-texter", Post);
ADD_METHOD_TO(AdminController::demoteFromTexter, "/api/admin/users/{1}/demote-texter", Post);
ADD_METHOD_TO(AdminController::promoteToWatchCreator, "/api/admin/users/{1}/promote-watch-creator", Post);
ADD_METHOD_TO(AdminController::demoteFromWatchCreator, "/api/admin/users/{1}/demote-watch-creator", Post);
ADD_METHOD_TO(AdminController::promoteToModerator, "/api/admin/users/{1}/promote-moderator", Post);
ADD_METHOD_TO(AdminController::demoteFromModerator, "/api/admin/users/{1}/demote-moderator", Post);
ADD_METHOD_TO(AdminController::getStickerSubmissions, "/api/admin/sticker-submissions", Get);
ADD_METHOD_TO(AdminController::approveStickerSubmission, "/api/admin/sticker-submissions/{1}/approve", Post);
ADD_METHOD_TO(AdminController::denyStickerSubmission, "/api/admin/sticker-submissions/{1}/deny", Post);
ADD_METHOD_TO(AdminController::uploadHonkSound, "/api/admin/honks/upload", Post);
ADD_METHOD_TO(AdminController::getHonkSounds, "/api/admin/honks", Get);
ADD_METHOD_TO(AdminController::deleteHonkSound, "/api/admin/honks/{1}", Delete);
ADD_METHOD_TO(AdminController::setActiveHonkSound, "/api/admin/honks/{1}/activate", Post);
ADD_METHOD_TO(AdminController::getActiveHonkSound, "/api/honk/active", Get);
ADD_METHOD_TO(AdminController::getChatSettings, "/api/admin/settings/chat", Get);
ADD_METHOD_TO(AdminController::updateChatSettings, "/api/admin/settings/chat", Put);
ADD_METHOD_TO(AdminController::getRealms, "/api/admin/realms", Get);
ADD_METHOD_TO(AdminController::deleteRealm, "/api/admin/realms/{1}", Delete);
ADD_METHOD_TO(AdminController::setViewerMultiplier, "/api/admin/realms/{1}/viewer-multiplier", Post);
ADD_METHOD_TO(AdminController::deleteUser, "/api/admin/users/{1}", Delete);
ADD_METHOD_TO(AdminController::disableUser, "/api/admin/users/{1}/disable", Post);
ADD_METHOD_TO(AdminController::enableUser, "/api/admin/users/{1}/enable", Post);
ADD_METHOD_TO(AdminController::uberbanUser, "/api/admin/users/{1}/uberban", Post);
ADD_METHOD_TO(AdminController::incrementReferrals, "/api/admin/users/{1}/increment-referrals", Post);
ADD_METHOD_TO(AdminController::getVideos, "/api/admin/videos", Get);
ADD_METHOD_TO(AdminController::deleteVideo, "/api/admin/videos/{1}", Delete);
ADD_METHOD_TO(AdminController::getAudios, "/api/admin/audios", Get);
ADD_METHOD_TO(AdminController::deleteAudio, "/api/admin/audios/{1}", Delete);
ADD_METHOD_TO(AdminController::getEbooks, "/api/admin/ebooks", Get);
ADD_METHOD_TO(AdminController::deleteEbook, "/api/admin/ebooks/{1}", Delete);
ADD_METHOD_TO(AdminController::getSiteSettings, "/api/admin/settings/site", Get);
ADD_METHOD_TO(AdminController::updateSiteSettings, "/api/admin/settings/site", Put);
ADD_METHOD_TO(AdminController::uploadSiteLogo, "/api/admin/settings/site/logo", Post);
ADD_METHOD_TO(AdminController::getPublicSiteSettings, "/api/settings/site", Get);
ADD_METHOD_TO(AdminController::getCensoredWords, "/api/internal/censored-words", Get);
ADD_METHOD_TO(AdminController::uploadDefaultAvatars, "/api/admin/default-avatars/upload", Post);
ADD_METHOD_TO(AdminController::getDefaultAvatars, "/api/admin/default-avatars", Get);
ADD_METHOD_TO(AdminController::deleteDefaultAvatar, "/api/admin/default-avatars/{1}", Delete);
ADD_METHOD_TO(AdminController::getRandomDefaultAvatar, "/api/default-avatar/random", Get);
ADD_METHOD_TO(AdminController::trackStickerUsage, "/api/internal/stickers/track-usage", Post);
ADD_METHOD_TO(AdminController::getStickerStats, "/api/stats/stickers", Get);
ADD_METHOD_TO(AdminController::downloadAllStickers, "/api/admin/stickers/download-all", Get);
// SSL Certificate Management
ADD_METHOD_TO(AdminController::getSSLSettings, "/api/admin/settings/ssl", Get);
ADD_METHOD_TO(AdminController::updateSSLSettings, "/api/admin/settings/ssl", Put);
ADD_METHOD_TO(AdminController::requestCertificate, "/api/admin/ssl/request", Post);
ADD_METHOD_TO(AdminController::getSSLStatus, "/api/admin/ssl/status", Get);
METHOD_LIST_END
void getUsers(const HttpRequestPtr &req,
@ -32,6 +93,216 @@ public:
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
private:
UserInfo getUserFromRequest(const HttpRequestPtr &req);
void promoteToRestreamer(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void demoteFromRestreamer(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void promoteToBot(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void demoteFromBot(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void getAllBotApiKeys(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void deleteBotApiKey(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &keyId);
void uploadStickers(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getStickers(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void deleteSticker(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &stickerId);
void renameSticker(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &stickerId);
void promoteToStickerCreator(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void demoteFromStickerCreator(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void promoteToUploader(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void demoteFromUploader(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void promoteToTexter(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void demoteFromTexter(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void promoteToWatchCreator(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void demoteFromWatchCreator(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void promoteToModerator(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void demoteFromModerator(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void getStickerSubmissions(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void approveStickerSubmission(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &submissionId);
void denyStickerSubmission(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &submissionId);
void uploadHonkSound(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getHonkSounds(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void deleteHonkSound(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &honkId);
void setActiveHonkSound(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &honkId);
void getActiveHonkSound(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getChatSettings(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void updateChatSettings(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getRealms(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void deleteRealm(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void setViewerMultiplier(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void deleteUser(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void disableUser(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void enableUser(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void uberbanUser(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void incrementReferrals(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void getVideos(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void deleteVideo(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &videoId);
void getAudios(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void deleteAudio(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &audioId);
void getEbooks(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void deleteEbook(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &ebookId);
void getSiteSettings(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void updateSiteSettings(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void uploadSiteLogo(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getPublicSiteSettings(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getCensoredWords(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void uploadDefaultAvatars(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getDefaultAvatars(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void deleteDefaultAvatar(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &avatarId);
void getRandomDefaultAvatar(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void trackStickerUsage(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getStickerStats(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void downloadAllStickers(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
// SSL Certificate Management
void getSSLSettings(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void updateSSLSettings(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void requestCertificate(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getSSLStatus(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
};

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,72 @@
#pragma once
#include <drogon/HttpController.h>
#include "../services/AuthService.h"
using namespace drogon;
class AudioController : public HttpController<AudioController> {
public:
METHOD_LIST_BEGIN
// Public endpoints
ADD_METHOD_TO(AudioController::getAllAudio, "/api/audio", Get);
ADD_METHOD_TO(AudioController::getLatestAudio, "/api/audio/latest", Get);
ADD_METHOD_TO(AudioController::getAudio, "/api/audio/{1}", Get);
ADD_METHOD_TO(AudioController::getRealmAudio, "/api/audio/realm/{1}", Get);
ADD_METHOD_TO(AudioController::getRealmAudioByName, "/api/audio/realm/name/{1}", Get);
ADD_METHOD_TO(AudioController::incrementPlayCount, "/api/audio/{1}/play", Post);
// Authenticated endpoints
ADD_METHOD_TO(AudioController::getMyAudio, "/api/user/audio", Get);
ADD_METHOD_TO(AudioController::uploadAudio, "/api/user/audio", Post);
ADD_METHOD_TO(AudioController::updateAudio, "/api/audio/{1}", Put);
ADD_METHOD_TO(AudioController::deleteAudio, "/api/audio/{1}", Delete);
ADD_METHOD_TO(AudioController::uploadThumbnail, "/api/audio/{1}/thumbnail", Post);
ADD_METHOD_TO(AudioController::deleteThumbnail, "/api/audio/{1}/thumbnail", Delete);
METHOD_LIST_END
// Public audio listing
void getAllAudio(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getLatestAudio(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getAudio(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &audioId);
void getRealmAudio(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void getRealmAudioByName(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmName);
void incrementPlayCount(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &audioId);
// Authenticated audio management
void getMyAudio(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void uploadAudio(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void updateAudio(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &audioId);
void deleteAudio(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &audioId);
void uploadThumbnail(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &audioId);
void deleteThumbnail(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &audioId);
};

View file

@ -0,0 +1,908 @@
#include "EbookController.h"
#include "../services/DatabaseService.h"
#include "../common/HttpHelpers.h"
#include "../common/AuthHelpers.h"
#include "../common/FileUtils.h"
#include "../common/FileValidation.h"
#include <drogon/utils/Utilities.h>
#include <drogon/Cookie.h>
#include <random>
#include <sstream>
#include <iomanip>
#include <fstream>
#include <filesystem>
#include <regex>
#include <cerrno>
// File size limits
static constexpr size_t MAX_EBOOK_SIZE = 100 * 1024 * 1024; // 100MB
static constexpr size_t MAX_COVER_SIZE = 5 * 1024 * 1024; // 5MB
using namespace drogon::orm;
void EbookController::getAllEbooks(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
int page = 1;
int limit = 20;
auto pageParam = req->getParameter("page");
auto limitParam = req->getParameter("limit");
if (!pageParam.empty()) {
try { page = std::stoi(pageParam); } catch (...) {}
}
if (!limitParam.empty()) {
try { limit = std::min(std::stoi(limitParam), 50); } catch (...) {}
}
int offset = (page - 1) * limit;
auto dbClient = app().getDbClient();
*dbClient << "SELECT e.id, e.title, e.description, e.file_path, e.cover_path, "
"e.file_size_bytes, e.chapter_count, e.read_count, e.created_at, e.realm_id, "
"u.id as user_id, u.username, u.avatar_url, "
"r.name as realm_name "
"FROM ebooks e "
"JOIN users u ON e.user_id = u.id "
"JOIN realms r ON e.realm_id = r.id "
"WHERE e.is_public = true AND e.status = 'ready' "
"ORDER BY e.created_at DESC "
"LIMIT $1 OFFSET $2"
<< static_cast<int64_t>(limit) << static_cast<int64_t>(offset)
>> [callback](const Result& r) {
Json::Value resp;
resp["success"] = true;
Json::Value ebooks(Json::arrayValue);
for (const auto& row : r) {
Json::Value ebook;
ebook["id"] = static_cast<Json::Int64>(row["id"].as<int64_t>());
ebook["title"] = row["title"].as<std::string>();
ebook["description"] = row["description"].isNull() ? "" : row["description"].as<std::string>();
ebook["filePath"] = row["file_path"].as<std::string>();
ebook["coverPath"] = row["cover_path"].isNull() ? "" : row["cover_path"].as<std::string>();
ebook["fileSizeBytes"] = static_cast<Json::Int64>(row["file_size_bytes"].as<int64_t>());
ebook["chapterCount"] = row["chapter_count"].isNull() ? 0 : row["chapter_count"].as<int>();
ebook["readCount"] = row["read_count"].as<int>();
ebook["createdAt"] = row["created_at"].as<std::string>();
ebook["userId"] = static_cast<Json::Int64>(row["user_id"].as<int64_t>());
ebook["username"] = row["username"].as<std::string>();
ebook["avatarUrl"] = row["avatar_url"].isNull() ? "" : row["avatar_url"].as<std::string>();
ebook["realmId"] = static_cast<Json::Int64>(row["realm_id"].as<int64_t>());
ebook["realmName"] = row["realm_name"].as<std::string>();
ebooks.append(ebook);
}
resp["ebooks"] = ebooks;
callback(jsonResp(resp));
}
>> DB_ERROR(callback, "get ebooks");
}
void EbookController::getLatestEbooks(const HttpRequestPtr &,
std::function<void(const HttpResponsePtr &)> &&callback) {
auto dbClient = app().getDbClient();
*dbClient << "SELECT e.id, e.title, e.description, e.file_path, e.cover_path, "
"e.file_size_bytes, e.chapter_count, e.read_count, e.created_at, e.realm_id, "
"u.id as user_id, u.username, u.avatar_url, "
"r.name as realm_name "
"FROM ebooks e "
"JOIN users u ON e.user_id = u.id "
"JOIN realms r ON e.realm_id = r.id "
"WHERE e.is_public = true AND e.status = 'ready' "
"ORDER BY e.created_at DESC "
"LIMIT 5"
>> [callback](const Result& r) {
Json::Value resp;
resp["success"] = true;
Json::Value ebooks(Json::arrayValue);
for (const auto& row : r) {
Json::Value ebook;
ebook["id"] = static_cast<Json::Int64>(row["id"].as<int64_t>());
ebook["title"] = row["title"].as<std::string>();
ebook["description"] = row["description"].isNull() ? "" : row["description"].as<std::string>();
ebook["filePath"] = row["file_path"].as<std::string>();
ebook["coverPath"] = row["cover_path"].isNull() ? "" : row["cover_path"].as<std::string>();
ebook["fileSizeBytes"] = static_cast<Json::Int64>(row["file_size_bytes"].as<int64_t>());
ebook["chapterCount"] = row["chapter_count"].isNull() ? 0 : row["chapter_count"].as<int>();
ebook["readCount"] = row["read_count"].as<int>();
ebook["createdAt"] = row["created_at"].as<std::string>();
ebook["userId"] = static_cast<Json::Int64>(row["user_id"].as<int64_t>());
ebook["username"] = row["username"].as<std::string>();
ebook["avatarUrl"] = row["avatar_url"].isNull() ? "" : row["avatar_url"].as<std::string>();
ebook["realmId"] = static_cast<Json::Int64>(row["realm_id"].as<int64_t>());
ebook["realmName"] = row["realm_name"].as<std::string>();
ebooks.append(ebook);
}
resp["ebooks"] = ebooks;
callback(jsonResp(resp));
}
>> DB_ERROR(callback, "get latest ebooks");
}
void EbookController::getEbook(const HttpRequestPtr &,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &ebookId) {
int64_t id;
try {
id = std::stoll(ebookId);
} catch (...) {
callback(jsonError("Invalid ebook ID", k400BadRequest));
return;
}
auto dbClient = app().getDbClient();
*dbClient << "SELECT e.id, e.title, e.description, e.file_path, e.cover_path, "
"e.file_size_bytes, e.chapter_count, e.read_count, e.is_public, e.status, "
"e.created_at, e.realm_id, "
"u.id as user_id, u.username, u.avatar_url, "
"r.name as realm_name "
"FROM ebooks e "
"JOIN users u ON e.user_id = u.id "
"JOIN realms r ON e.realm_id = r.id "
"WHERE e.id = $1 AND e.status = 'ready'"
<< id
>> [callback](const Result& r) {
if (r.empty()) {
callback(jsonError("Ebook not found", k404NotFound));
return;
}
const auto& row = r[0];
if (!row["is_public"].as<bool>()) {
callback(jsonError("Ebook not found", k404NotFound));
return;
}
Json::Value resp;
resp["success"] = true;
auto& ebook = resp["ebook"];
ebook["id"] = static_cast<Json::Int64>(row["id"].as<int64_t>());
ebook["title"] = row["title"].as<std::string>();
ebook["description"] = row["description"].isNull() ? "" : row["description"].as<std::string>();
ebook["filePath"] = row["file_path"].as<std::string>();
ebook["coverPath"] = row["cover_path"].isNull() ? "" : row["cover_path"].as<std::string>();
ebook["fileSizeBytes"] = static_cast<Json::Int64>(row["file_size_bytes"].as<int64_t>());
ebook["chapterCount"] = row["chapter_count"].isNull() ? 0 : row["chapter_count"].as<int>();
ebook["readCount"] = row["read_count"].as<int>();
ebook["createdAt"] = row["created_at"].as<std::string>();
ebook["userId"] = static_cast<Json::Int64>(row["user_id"].as<int64_t>());
ebook["username"] = row["username"].as<std::string>();
ebook["avatarUrl"] = row["avatar_url"].isNull() ? "" : row["avatar_url"].as<std::string>();
ebook["realmId"] = static_cast<Json::Int64>(row["realm_id"].as<int64_t>());
ebook["realmName"] = row["realm_name"].as<std::string>();
callback(jsonResp(resp));
}
>> DB_ERROR(callback, "get ebook");
}
void EbookController::getUserEbooks(const HttpRequestPtr &,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &username) {
auto dbClient = app().getDbClient();
*dbClient << "SELECT e.id, e.title, e.description, e.file_path, e.cover_path, "
"e.file_size_bytes, e.chapter_count, e.read_count, e.created_at, e.realm_id, "
"u.id as user_id, u.username, u.avatar_url, "
"r.name as realm_name "
"FROM ebooks e "
"JOIN users u ON e.user_id = u.id "
"JOIN realms r ON e.realm_id = r.id "
"WHERE u.username = $1 AND e.is_public = true AND e.status = 'ready' "
"ORDER BY e.created_at DESC"
<< username
>> [callback, username](const Result& r) {
Json::Value resp;
resp["success"] = true;
resp["username"] = username;
Json::Value ebooks(Json::arrayValue);
for (const auto& row : r) {
Json::Value ebook;
ebook["id"] = static_cast<Json::Int64>(row["id"].as<int64_t>());
ebook["title"] = row["title"].as<std::string>();
ebook["description"] = row["description"].isNull() ? "" : row["description"].as<std::string>();
ebook["filePath"] = row["file_path"].as<std::string>();
ebook["coverPath"] = row["cover_path"].isNull() ? "" : row["cover_path"].as<std::string>();
ebook["fileSizeBytes"] = static_cast<Json::Int64>(row["file_size_bytes"].as<int64_t>());
ebook["chapterCount"] = row["chapter_count"].isNull() ? 0 : row["chapter_count"].as<int>();
ebook["readCount"] = row["read_count"].as<int>();
ebook["createdAt"] = row["created_at"].as<std::string>();
ebook["realmId"] = static_cast<Json::Int64>(row["realm_id"].as<int64_t>());
ebook["realmName"] = row["realm_name"].as<std::string>();
ebooks.append(ebook);
}
resp["ebooks"] = ebooks;
callback(jsonResp(resp));
}
>> DB_ERROR(callback, "get user ebooks");
}
void EbookController::getRealmEbooks(const HttpRequestPtr &,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId) {
int64_t id;
try {
id = std::stoll(realmId);
} catch (...) {
callback(jsonError("Invalid realm ID", k400BadRequest));
return;
}
auto dbClient = app().getDbClient();
// First get realm info
*dbClient << "SELECT r.id, r.name, r.description, r.realm_type, r.title_color, r.created_at, "
"u.id as user_id, u.username, u.avatar_url "
"FROM realms r "
"JOIN users u ON r.user_id = u.id "
"WHERE r.id = $1 AND r.is_active = true AND r.realm_type = 'ebook'"
<< id
>> [callback, dbClient, id](const Result& realmResult) {
if (realmResult.empty()) {
callback(jsonError("Ebook realm not found", k404NotFound));
return;
}
// Get ebooks for this realm
*dbClient << "SELECT e.id, e.title, e.description, e.file_path, e.cover_path, "
"e.file_size_bytes, e.chapter_count, e.read_count, e.created_at "
"FROM ebooks e "
"WHERE e.realm_id = $1 AND e.is_public = true AND e.status = 'ready' "
"ORDER BY e.created_at DESC LIMIT 100"
<< id
>> [callback, realmResult](const Result& r) {
Json::Value resp;
resp["success"] = true;
// Realm info
auto& realm = resp["realm"];
realm["id"] = static_cast<Json::Int64>(realmResult[0]["id"].as<int64_t>());
realm["name"] = realmResult[0]["name"].as<std::string>();
realm["description"] = realmResult[0]["description"].isNull() ? "" : realmResult[0]["description"].as<std::string>();
realm["titleColor"] = realmResult[0]["title_color"].isNull() ? "#ffffff" : realmResult[0]["title_color"].as<std::string>();
realm["createdAt"] = realmResult[0]["created_at"].as<std::string>();
realm["userId"] = static_cast<Json::Int64>(realmResult[0]["user_id"].as<int64_t>());
realm["username"] = realmResult[0]["username"].as<std::string>();
realm["avatarUrl"] = realmResult[0]["avatar_url"].isNull() ? "" : realmResult[0]["avatar_url"].as<std::string>();
// Ebooks
Json::Value ebooks(Json::arrayValue);
for (const auto& row : r) {
Json::Value ebook;
ebook["id"] = static_cast<Json::Int64>(row["id"].as<int64_t>());
ebook["title"] = row["title"].as<std::string>();
ebook["description"] = row["description"].isNull() ? "" : row["description"].as<std::string>();
ebook["filePath"] = row["file_path"].as<std::string>();
ebook["coverPath"] = row["cover_path"].isNull() ? "" : row["cover_path"].as<std::string>();
ebook["fileSizeBytes"] = static_cast<Json::Int64>(row["file_size_bytes"].as<int64_t>());
ebook["chapterCount"] = row["chapter_count"].isNull() ? 0 : row["chapter_count"].as<int>();
ebook["readCount"] = row["read_count"].as<int>();
ebook["createdAt"] = row["created_at"].as<std::string>();
ebooks.append(ebook);
}
resp["ebooks"] = ebooks;
callback(jsonResp(resp));
}
>> DB_ERROR(callback, "get realm ebooks");
}
>> DB_ERROR(callback, "get realm");
}
void EbookController::incrementReadCount(const HttpRequestPtr &,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &ebookId) {
int64_t id;
try {
id = std::stoll(ebookId);
} catch (...) {
callback(jsonError("Invalid ebook ID", k400BadRequest));
return;
}
auto dbClient = app().getDbClient();
*dbClient << "UPDATE ebooks SET read_count = read_count + 1 "
"WHERE id = $1 AND is_public = true AND status = 'ready' "
"RETURNING read_count"
<< id
>> [callback](const Result& r) {
if (r.empty()) {
callback(jsonError("Ebook not found", k404NotFound));
return;
}
Json::Value resp;
resp["success"] = true;
resp["readCount"] = r[0]["read_count"].as<int>();
callback(jsonResp(resp));
}
>> DB_ERROR(callback, "increment read count");
}
void EbookController::getMyEbooks(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
UserInfo user = getUserFromRequest(req);
if (user.id == 0) {
callback(jsonError("Unauthorized", k401Unauthorized));
return;
}
auto dbClient = app().getDbClient();
*dbClient << "SELECT e.id, e.title, e.description, e.file_path, e.cover_path, "
"e.file_size_bytes, e.chapter_count, e.read_count, e.is_public, e.status, e.created_at, "
"e.realm_id, r.name as realm_name "
"FROM ebooks e "
"JOIN realms r ON e.realm_id = r.id "
"WHERE e.user_id = $1 AND e.status != 'deleted' "
"ORDER BY e.created_at DESC"
<< user.id
>> [callback](const Result& r) {
Json::Value resp;
resp["success"] = true;
Json::Value ebooks(Json::arrayValue);
for (const auto& row : r) {
Json::Value ebook;
ebook["id"] = static_cast<Json::Int64>(row["id"].as<int64_t>());
ebook["title"] = row["title"].as<std::string>();
ebook["description"] = row["description"].isNull() ? "" : row["description"].as<std::string>();
ebook["filePath"] = row["file_path"].as<std::string>();
ebook["coverPath"] = row["cover_path"].isNull() ? "" : row["cover_path"].as<std::string>();
ebook["fileSizeBytes"] = static_cast<Json::Int64>(row["file_size_bytes"].as<int64_t>());
ebook["chapterCount"] = row["chapter_count"].isNull() ? 0 : row["chapter_count"].as<int>();
ebook["readCount"] = row["read_count"].as<int>();
ebook["isPublic"] = row["is_public"].as<bool>();
ebook["status"] = row["status"].as<std::string>();
ebook["createdAt"] = row["created_at"].as<std::string>();
ebook["realmId"] = static_cast<Json::Int64>(row["realm_id"].as<int64_t>());
ebook["realmName"] = row["realm_name"].as<std::string>();
ebooks.append(ebook);
}
resp["ebooks"] = ebooks;
callback(jsonResp(resp));
}
>> DB_ERROR(callback, "get user ebooks");
}
void EbookController::uploadEbook(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
try {
UserInfo user = getUserFromRequest(req);
if (user.id == 0) {
callback(jsonError("Unauthorized", k401Unauthorized));
return;
}
MultiPartParser parser;
parser.parse(req);
// Get realm ID from form data - required
std::string realmIdStr = parser.getParameter<std::string>("realmId");
if (realmIdStr.empty()) {
callback(jsonError("Realm ID is required"));
return;
}
int64_t realmId;
try {
realmId = std::stoll(realmIdStr);
} catch (...) {
callback(jsonError("Invalid realm ID"));
return;
}
// Extract file
if (parser.getFiles().empty()) {
callback(jsonError("No file uploaded"));
return;
}
const auto& file = parser.getFiles()[0];
// Get title from form data - sanitize input
std::string title = sanitizeUserInput(parser.getParameter<std::string>("title"), 255);
if (title.empty()) {
title = "Untitled Ebook";
}
// Get optional description - sanitize input
std::string description = sanitizeUserInput(parser.getParameter<std::string>("description"), 5000);
// Validate file size
size_t fileSize = file.fileLength();
if (fileSize > MAX_EBOOK_SIZE) {
callback(jsonError("File too large (max 100MB)"));
return;
}
if (fileSize == 0) {
callback(jsonError("Empty file uploaded"));
return;
}
// Validate EPUB magic bytes
if (!isValidEpub(file.fileData(), fileSize)) {
LOG_WARN << "Ebook upload rejected: invalid EPUB file";
callback(jsonError("Invalid file. Only EPUB format is allowed."));
return;
}
// Copy file data before async call
std::string fileDataStr(file.fileData(), fileSize);
// Check if user has uploader role and the realm exists and belongs to them
auto dbClient = app().getDbClient();
*dbClient << "SELECT u.is_uploader, r.id as realm_id, r.realm_type "
"FROM users u "
"LEFT JOIN realms r ON r.user_id = u.id AND r.id = $2 "
"WHERE u.id = $1"
<< user.id << realmId
>> [callback, user, dbClient, realmId, title, description, fileDataStr, fileSize](const Result& r) {
if (r.empty() || !r[0]["is_uploader"].as<bool>()) {
callback(jsonError("You don't have permission to upload ebooks", k403Forbidden));
return;
}
// Check if realm exists and belongs to user
if (r[0]["realm_id"].isNull()) {
callback(jsonError("Ebook realm not found or doesn't belong to you", k404NotFound));
return;
}
// Check if it's an ebook realm
std::string realmType = r[0]["realm_type"].isNull() ? "stream" : r[0]["realm_type"].as<std::string>();
if (realmType != "ebook") {
callback(jsonError("Can only upload ebooks to ebook realms", k400BadRequest));
return;
}
// Ensure uploads directory exists
const std::string uploadDir = "/app/uploads/ebooks";
if (!ensureDirectoryExists(uploadDir)) {
callback(jsonError("Failed to create upload directory"));
return;
}
// Generate unique filename and create file atomically
// This prevents TOCTOU race conditions
std::string filename;
std::string fullPath;
int maxAttempts = 10;
for (int attempt = 0; attempt < maxAttempts; ++attempt) {
filename = generateRandomFilename("epub");
fullPath = uploadDir + "/" + filename;
// Try to create file atomically with exclusive access
if (writeFileExclusive(fullPath, fileDataStr.data(), fileSize)) {
break; // Success
}
// If file already exists (EEXIST), try again with new name
if (errno == EEXIST) {
continue;
}
// Other error - fail
LOG_ERROR << "Failed to create file: " << fullPath << " errno: " << errno;
callback(jsonError("Failed to save file"));
return;
}
// Verify file was created
if (!std::filesystem::exists(fullPath)) {
LOG_ERROR << "File was not created after " << maxAttempts << " attempts";
callback(jsonError("Failed to save file"));
return;
}
try {
std::string filePath = "/uploads/ebooks/" + filename;
// Insert ebook record - status is 'ready' (no server-side processing needed)
*dbClient << "INSERT INTO ebooks (user_id, realm_id, title, description, file_path, "
"file_size_bytes, status, is_public) "
"VALUES ($1, $2, $3, $4, $5, $6, 'ready', true) RETURNING id, created_at"
<< user.id << realmId << title << description << filePath
<< static_cast<int64_t>(fileSize)
>> [callback, title, filePath, fileSize, realmId](const Result& r2) {
if (r2.empty()) {
callback(jsonError("Failed to save ebook record"));
return;
}
int64_t ebookId = r2[0]["id"].as<int64_t>();
Json::Value resp;
resp["success"] = true;
resp["ebook"]["id"] = static_cast<Json::Int64>(ebookId);
resp["ebook"]["realmId"] = static_cast<Json::Int64>(realmId);
resp["ebook"]["title"] = title;
resp["ebook"]["filePath"] = filePath;
resp["ebook"]["fileSizeBytes"] = static_cast<Json::Int64>(fileSize);
resp["ebook"]["status"] = "ready";
resp["ebook"]["createdAt"] = r2[0]["created_at"].as<std::string>();
callback(jsonResp(resp));
}
>> [callback, fullPath](const DrogonDbException& e) {
LOG_ERROR << "Failed to insert ebook: " << e.base().what();
// Clean up file on DB error
std::filesystem::remove(fullPath);
callback(jsonError("Failed to save ebook"));
};
} catch (const std::exception& e) {
LOG_ERROR << "Exception saving ebook file: " << e.what();
callback(jsonError("Failed to save file"));
}
}
>> DB_ERROR(callback, "check uploader status");
} catch (const std::exception& e) {
LOG_ERROR << "Exception in uploadEbook: " << e.what();
callback(jsonError("Internal server error"));
}
}
void EbookController::updateEbook(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &ebookId) {
UserInfo user = getUserFromRequest(req);
if (user.id == 0) {
callback(jsonError("Unauthorized", k401Unauthorized));
return;
}
int64_t id;
try {
id = std::stoll(ebookId);
} catch (...) {
callback(jsonError("Invalid ebook ID", k400BadRequest));
return;
}
auto json = req->getJsonObject();
if (!json) {
callback(jsonError("Invalid JSON"));
return;
}
auto dbClient = app().getDbClient();
// Verify ownership
*dbClient << "SELECT id FROM ebooks WHERE id = $1 AND user_id = $2 AND status != 'deleted'"
<< id << user.id
>> [callback, json, dbClient, id](const Result& r) {
if (r.empty()) {
callback(jsonError("Ebook not found or access denied", k404NotFound));
return;
}
std::string title, description;
if (json->isMember("title")) {
title = sanitizeUserInput((*json)["title"].asString(), 255);
}
if (json->isMember("description")) {
description = sanitizeUserInput((*json)["description"].asString(), 5000);
}
if (json->isMember("title") && json->isMember("description")) {
*dbClient << "UPDATE ebooks SET title = $1, description = $2, updated_at = NOW() WHERE id = $3"
<< title << description << id
>> [callback](const Result&) {
Json::Value resp;
resp["success"] = true;
resp["message"] = "Ebook updated successfully";
callback(jsonResp(resp));
}
>> DB_ERROR_MSG(callback, "update ebook", "Failed to update ebook");
} else if (json->isMember("title")) {
*dbClient << "UPDATE ebooks SET title = $1, updated_at = NOW() WHERE id = $2"
<< title << id
>> [callback](const Result&) {
Json::Value resp;
resp["success"] = true;
resp["message"] = "Ebook updated successfully";
callback(jsonResp(resp));
}
>> DB_ERROR_MSG(callback, "update ebook", "Failed to update ebook");
} else if (json->isMember("description")) {
*dbClient << "UPDATE ebooks SET description = $1, updated_at = NOW() WHERE id = $2"
<< description << id
>> [callback](const Result&) {
Json::Value resp;
resp["success"] = true;
resp["message"] = "Ebook updated successfully";
callback(jsonResp(resp));
}
>> DB_ERROR_MSG(callback, "update ebook", "Failed to update ebook");
} else {
Json::Value resp;
resp["success"] = true;
resp["message"] = "No changes to apply";
callback(jsonResp(resp));
}
}
>> DB_ERROR(callback, "verify ebook ownership");
}
void EbookController::deleteEbook(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &ebookId) {
UserInfo user = getUserFromRequest(req);
if (user.id == 0) {
callback(jsonError("Unauthorized", k401Unauthorized));
return;
}
int64_t id;
try {
id = std::stoll(ebookId);
} catch (...) {
callback(jsonError("Invalid ebook ID", k400BadRequest));
return;
}
auto dbClient = app().getDbClient();
// Get file path and verify ownership
*dbClient << "SELECT file_path, cover_path FROM ebooks "
"WHERE id = $1 AND user_id = $2 AND status != 'deleted'"
<< id << user.id
>> [callback, dbClient, id](const Result& r) {
if (r.empty()) {
callback(jsonError("Ebook not found or access denied", k404NotFound));
return;
}
std::string filePath = r[0]["file_path"].as<std::string>();
std::string coverPath = r[0]["cover_path"].isNull() ? "" : r[0]["cover_path"].as<std::string>();
// Soft delete by setting status to 'deleted'
*dbClient << "UPDATE ebooks SET status = 'deleted' WHERE id = $1"
<< id
>> [callback, filePath, coverPath](const Result&) {
// Delete files from disk with path validation
try {
std::string fullEbookPath = "/app" + filePath;
// Validate path is within allowed directory before deletion
if (isPathSafe(fullEbookPath, "/app/uploads/ebooks")) {
if (std::filesystem::exists(fullEbookPath)) {
std::filesystem::remove(fullEbookPath);
}
} else {
LOG_WARN << "Blocked deletion of file outside uploads: " << fullEbookPath;
}
if (!coverPath.empty()) {
std::string fullCoverPath = "/app" + coverPath;
// Validate cover path as well
if (isPathSafe(fullCoverPath, "/app/uploads/ebooks")) {
if (std::filesystem::exists(fullCoverPath)) {
std::filesystem::remove(fullCoverPath);
}
} else {
LOG_WARN << "Blocked deletion of cover outside uploads: " << fullCoverPath;
}
}
} catch (const std::exception& e) {
LOG_WARN << "Failed to delete ebook files: " << e.what();
}
Json::Value resp;
resp["success"] = true;
resp["message"] = "Ebook deleted successfully";
callback(jsonResp(resp));
}
>> DB_ERROR_MSG(callback, "delete ebook", "Failed to delete ebook");
}
>> DB_ERROR(callback, "get ebook for deletion");
}
void EbookController::uploadCover(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &ebookId) {
try {
UserInfo user = getUserFromRequest(req);
if (user.id == 0) {
callback(jsonError("Unauthorized", k401Unauthorized));
return;
}
int64_t id;
try {
id = std::stoll(ebookId);
} catch (...) {
callback(jsonError("Invalid ebook ID", k400BadRequest));
return;
}
MultiPartParser parser;
parser.parse(req);
if (parser.getFiles().empty()) {
callback(jsonError("No file uploaded"));
return;
}
const auto& file = parser.getFiles()[0];
size_t fileSize = file.fileLength();
// Validate cover file size
if (fileSize > MAX_COVER_SIZE) {
callback(jsonError("Cover image too large (max 5MB)"));
return;
}
if (fileSize == 0) {
callback(jsonError("Empty file uploaded"));
return;
}
// Validate image type
auto validation = validateImageMagicBytes(file.fileData(), fileSize);
if (!validation.valid) {
callback(jsonError("Invalid image file. Only JPG, PNG, and WebP are allowed."));
return;
}
std::string fileDataStr(file.fileData(), fileSize);
std::string fileExt = validation.extension;
auto dbClient = app().getDbClient();
// Verify ownership
*dbClient << "SELECT id, cover_path FROM ebooks WHERE id = $1 AND user_id = $2 AND status != 'deleted'"
<< id << user.id
>> [callback, dbClient, id, fileDataStr, fileSize, fileExt](const Result& r) {
if (r.empty()) {
callback(jsonError("Ebook not found or access denied", k404NotFound));
return;
}
std::string oldCoverPath = r[0]["cover_path"].isNull() ? "" : r[0]["cover_path"].as<std::string>();
// Ensure covers directory exists
const std::string coverDir = "/app/uploads/ebooks/covers";
if (!ensureDirectoryExists(coverDir)) {
callback(jsonError("Failed to create upload directory"));
return;
}
// Generate unique filename and create file atomically
std::string filename;
std::string fullPath;
int maxAttempts = 10;
for (int attempt = 0; attempt < maxAttempts; ++attempt) {
filename = generateRandomFilename(fileExt);
fullPath = coverDir + "/" + filename;
if (writeFileExclusive(fullPath, fileDataStr.data(), fileSize)) {
break;
}
if (errno == EEXIST) {
continue;
}
callback(jsonError("Failed to save cover"));
return;
}
if (!std::filesystem::exists(fullPath)) {
callback(jsonError("Failed to save cover"));
return;
}
std::string coverPath = "/uploads/ebooks/covers/" + filename;
// Update database
*dbClient << "UPDATE ebooks SET cover_path = $1 WHERE id = $2"
<< coverPath << id
>> [callback, coverPath, oldCoverPath](const Result&) {
// Delete old cover if exists (with path validation)
if (!oldCoverPath.empty()) {
try {
std::string oldFullPath = "/app" + oldCoverPath;
// Validate path before deletion
if (isPathSafe(oldFullPath, "/app/uploads/ebooks")) {
if (std::filesystem::exists(oldFullPath)) {
std::filesystem::remove(oldFullPath);
}
} else {
LOG_WARN << "Blocked deletion of old cover outside uploads: " << oldFullPath;
}
} catch (...) {}
}
Json::Value resp;
resp["success"] = true;
resp["coverPath"] = coverPath;
callback(jsonResp(resp));
}
>> [callback, fullPath](const DrogonDbException& e) {
LOG_ERROR << "Failed to update cover path: " << e.base().what();
std::filesystem::remove(fullPath);
callback(jsonError("Failed to save cover"));
};
}
>> DB_ERROR(callback, "verify ebook ownership");
} catch (const std::exception& e) {
LOG_ERROR << "Exception in uploadCover: " << e.what();
callback(jsonError("Internal server error"));
}
}
void EbookController::downloadEbook(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &ebookId) {
// Require authentication for downloads
UserInfo user = getUserFromRequest(req);
if (user.id == 0) {
callback(jsonError("Please log in to download ebooks", k401Unauthorized));
return;
}
int64_t id;
try {
id = std::stoll(ebookId);
} catch (...) {
callback(jsonError("Invalid ebook ID", k400BadRequest));
return;
}
auto dbClient = app().getDbClient();
// Get ebook info - only allow download of public, ready ebooks
*dbClient << "SELECT title, file_path FROM ebooks WHERE id = $1 AND is_public = true AND status = 'ready'"
<< id
>> [callback, id](const Result& r) {
if (r.empty()) {
callback(jsonError("Ebook not found", k404NotFound));
return;
}
std::string title = r[0]["title"].as<std::string>();
std::string filePath = r[0]["file_path"].as<std::string>();
std::string fullPath = "/app" + filePath;
// Validate path is within allowed directory
if (!isPathSafe(fullPath, "/app/uploads/ebooks")) {
LOG_WARN << "Blocked access to file outside uploads: " << fullPath;
callback(jsonError("Ebook file not found", k404NotFound));
return;
}
// Check file exists
if (!std::filesystem::exists(fullPath)) {
LOG_ERROR << "Ebook file not found: " << fullPath;
callback(jsonError("Ebook file not found", k404NotFound));
return;
}
// Sanitize title for filename (remove special chars)
std::string safeTitle;
for (char c : title) {
if (std::isalnum(c) || c == ' ' || c == '-' || c == '_') {
safeTitle += c;
}
}
if (safeTitle.empty()) safeTitle = "ebook";
if (safeTitle.length() > 100) safeTitle = safeTitle.substr(0, 100);
// Use Drogon's file response for efficient streaming
auto resp = HttpResponse::newFileResponse(fullPath, "", CT_CUSTOM);
resp->addHeader("Content-Type", "application/epub+zip");
resp->addHeader("Content-Disposition", "attachment; filename=\"" + safeTitle + ".epub\"");
callback(resp);
}
>> DB_ERROR(callback, "download ebook");
}

View file

@ -0,0 +1,72 @@
#pragma once
#include <drogon/HttpController.h>
#include "../services/AuthService.h"
using namespace drogon;
class EbookController : public HttpController<EbookController> {
public:
METHOD_LIST_BEGIN
// Public endpoints
ADD_METHOD_TO(EbookController::getAllEbooks, "/api/ebooks", Get);
ADD_METHOD_TO(EbookController::getLatestEbooks, "/api/ebooks/latest", Get);
ADD_METHOD_TO(EbookController::getEbook, "/api/ebooks/{1}", Get);
ADD_METHOD_TO(EbookController::getUserEbooks, "/api/ebooks/user/{1}", Get);
ADD_METHOD_TO(EbookController::getRealmEbooks, "/api/ebooks/realm/{1}", Get);
ADD_METHOD_TO(EbookController::incrementReadCount, "/api/ebooks/{1}/read", Post);
// Authenticated endpoints
ADD_METHOD_TO(EbookController::getMyEbooks, "/api/user/ebooks", Get);
ADD_METHOD_TO(EbookController::uploadEbook, "/api/user/ebooks", Post);
ADD_METHOD_TO(EbookController::updateEbook, "/api/ebooks/{1}", Put);
ADD_METHOD_TO(EbookController::deleteEbook, "/api/ebooks/{1}", Delete);
ADD_METHOD_TO(EbookController::uploadCover, "/api/ebooks/{1}/cover", Post);
ADD_METHOD_TO(EbookController::downloadEbook, "/api/ebooks/{1}/download", Get);
METHOD_LIST_END
// Public ebook listing
void getAllEbooks(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getLatestEbooks(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getEbook(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &ebookId);
void getUserEbooks(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &username);
void getRealmEbooks(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void incrementReadCount(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &ebookId);
// Authenticated ebook management
void getMyEbooks(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void uploadEbook(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void updateEbook(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &ebookId);
void deleteEbook(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &ebookId);
void uploadCover(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &ebookId);
void downloadEbook(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &ebookId);
};

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,122 @@
#pragma once
#include <drogon/HttpController.h>
#include "../services/AuthService.h"
using namespace drogon;
class ForumController : public HttpController<ForumController> {
public:
METHOD_LIST_BEGIN
// Forum CRUD
ADD_METHOD_TO(ForumController::getForums, "/api/forums", Get);
ADD_METHOD_TO(ForumController::getForum, "/api/forums/{1}", Get);
ADD_METHOD_TO(ForumController::createForum, "/api/forums", Post);
ADD_METHOD_TO(ForumController::updateForum, "/api/forums/{1}", Put);
ADD_METHOD_TO(ForumController::deleteForum, "/api/forums/{1}", Delete);
ADD_METHOD_TO(ForumController::uploadBanner, "/api/forums/{1}/banner", Post);
ADD_METHOD_TO(ForumController::deleteBanner, "/api/forums/{1}/banner", Delete);
ADD_METHOD_TO(ForumController::updateBannerPosition, "/api/forums/{1}/banner/position", Put);
ADD_METHOD_TO(ForumController::updateTitleColor, "/api/forums/{1}/title-color", Put);
// Thread CRUD
ADD_METHOD_TO(ForumController::getThreads, "/api/forums/{1}/threads", Get);
ADD_METHOD_TO(ForumController::getThread, "/api/forums/{1}/threads/{2}", Get);
ADD_METHOD_TO(ForumController::createThread, "/api/forums/{1}/threads", Post);
ADD_METHOD_TO(ForumController::updateThread, "/api/forums/{1}/threads/{2}", Put);
ADD_METHOD_TO(ForumController::deleteThread, "/api/forums/{1}/threads/{2}", Delete);
ADD_METHOD_TO(ForumController::pinThread, "/api/forums/{1}/threads/{2}/pin", Post);
ADD_METHOD_TO(ForumController::lockThread, "/api/forums/{1}/threads/{2}/lock", Post);
// Post CRUD
ADD_METHOD_TO(ForumController::getPosts, "/api/forums/{1}/threads/{2}/posts", Get);
ADD_METHOD_TO(ForumController::createPost, "/api/forums/{1}/threads/{2}/posts", Post);
ADD_METHOD_TO(ForumController::updatePost, "/api/forums/{1}/threads/{2}/posts/{3}", Put);
ADD_METHOD_TO(ForumController::deletePost, "/api/forums/{1}/threads/{2}/posts/{3}", Delete);
// Moderation
ADD_METHOD_TO(ForumController::getBannedUsers, "/api/forums/{1}/bans", Get);
ADD_METHOD_TO(ForumController::banUser, "/api/forums/{1}/bans", Post);
ADD_METHOD_TO(ForumController::unbanUser, "/api/forums/{1}/bans/{2}", Delete);
METHOD_LIST_END
// Forum methods
void getForums(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getForum(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId);
void createForum(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void updateForum(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId);
void deleteForum(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId);
void uploadBanner(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId);
void deleteBanner(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId);
void updateBannerPosition(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId);
void updateTitleColor(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId);
// Thread methods
void getThreads(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId);
void getThread(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId, const std::string &threadId);
void createThread(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId);
void updateThread(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId, const std::string &threadId);
void deleteThread(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId, const std::string &threadId);
void pinThread(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId, const std::string &threadId);
void lockThread(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId, const std::string &threadId);
// Post methods
void getPosts(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId, const std::string &threadId);
void createPost(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId, const std::string &threadId);
void updatePost(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId, const std::string &threadId,
const std::string &postId);
void deletePost(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId, const std::string &threadId,
const std::string &postId);
// Moderation methods
void getBannedUsers(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId);
void banUser(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId);
void unbanUser(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &forumId, const std::string &banId);
private:
bool isForumModerator(int64_t userId, int64_t forumOwnerId, bool isAdmin);
bool isUserBanned(int64_t forumId, int64_t userId);
};

File diff suppressed because it is too large Load diff

View file

@ -15,10 +15,18 @@ public:
ADD_METHOD_TO(RealmController::regenerateRealmKey, "/api/realms/{1}/regenerate-key", Post);
ADD_METHOD_TO(RealmController::getRealmByName, "/api/realms/by-name/{1}", Get);
ADD_METHOD_TO(RealmController::getLiveRealms, "/api/realms/live", Get);
ADD_METHOD_TO(RealmController::getAllRealms, "/api/realms/all", Get);
ADD_METHOD_TO(RealmController::validateRealmKey, "/api/realms/validate/{1}", Get);
ADD_METHOD_TO(RealmController::issueRealmViewerToken, "/api/realms/{1}/viewer-token", Get);
ADD_METHOD_TO(RealmController::getRealmStreamKey, "/api/realms/{1}/stream-key", Get);
ADD_METHOD_TO(RealmController::getRealmStats, "/api/realms/{1}/stats", Get);
ADD_METHOD_TO(RealmController::getPublicUserRealms, "/api/realms/user/{1}", Get);
ADD_METHOD_TO(RealmController::uploadOfflineImage, "/api/realms/{1}/offline-image", Post);
ADD_METHOD_TO(RealmController::deleteOfflineImage, "/api/realms/{1}/offline-image", Delete);
ADD_METHOD_TO(RealmController::getRealmModerators, "/api/realms/{1}/moderators", Get);
ADD_METHOD_TO(RealmController::addRealmModerator, "/api/realms/{1}/moderators", Post);
ADD_METHOD_TO(RealmController::removeRealmModerator, "/api/realms/{1}/moderators/{2}", Delete);
ADD_METHOD_TO(RealmController::updateTitleColor, "/api/realms/{1}/title-color", Put);
METHOD_LIST_END
void getUserRealms(const HttpRequestPtr &req,
@ -49,7 +57,10 @@ public:
void getLiveRealms(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getAllRealms(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void validateRealmKey(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &key);
@ -66,6 +77,32 @@ public:
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
private:
UserInfo getUserFromRequest(const HttpRequestPtr &req);
void getPublicUserRealms(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &username);
void uploadOfflineImage(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void deleteOfflineImage(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void getRealmModerators(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void addRealmModerator(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void removeRealmModerator(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId,
const std::string &moderatorId);
void updateTitleColor(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
};

View file

@ -0,0 +1,413 @@
#include "RestreamController.h"
#include "../services/RestreamService.h"
#include "../common/HttpHelpers.h"
#include "../common/AuthHelpers.h"
using namespace drogon::orm;
void RestreamController::verifyRestreamPermission(const HttpRequestPtr &req, int64_t realmId,
std::function<void(bool authorized, const UserInfo& user)> callback) {
UserInfo user = getUserFromRequest(req);
if (user.id == 0) {
callback(false, user);
return;
}
// Admin can always manage restreams
if (user.isAdmin) {
callback(true, user);
return;
}
// Must have restreamer role
if (!user.isRestreamer) {
callback(false, user);
return;
}
// Check if user owns the realm
auto dbClient = app().getDbClient();
*dbClient << "SELECT user_id FROM realms WHERE id = $1"
<< realmId
>> [callback, user](const Result& r) {
if (r.empty()) {
callback(false, user);
return;
}
int64_t ownerId = r[0]["user_id"].as<int64_t>();
callback(ownerId == user.id, user);
}
>> [callback, user](const DrogonDbException& e) {
LOG_ERROR << "Database error: " << e.base().what();
callback(false, user);
};
}
void RestreamController::getDestinations(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId) {
int64_t id = std::stoll(realmId);
verifyRestreamPermission(req, id, [callback, id](bool authorized, const UserInfo& user) {
if (!authorized) {
callback(jsonError("Unauthorized - requires restreamer role and realm ownership", k403Forbidden));
return;
}
auto dbClient = app().getDbClient();
*dbClient << "SELECT id, name, rtmp_url, stream_key, enabled, is_connected, last_error, "
"last_connected_at, created_at FROM restream_destinations WHERE realm_id = $1 "
"ORDER BY created_at ASC"
<< id
>> [callback](const Result& r) {
Json::Value resp;
resp["success"] = true;
Json::Value destinations(Json::arrayValue);
for (const auto& row : r) {
Json::Value dest;
dest["id"] = static_cast<Json::Int64>(row["id"].as<int64_t>());
dest["name"] = row["name"].as<std::string>();
dest["rtmpUrl"] = row["rtmp_url"].as<std::string>();
dest["streamKey"] = row["stream_key"].as<std::string>();
dest["enabled"] = row["enabled"].as<bool>();
dest["isConnected"] = row["is_connected"].as<bool>();
dest["lastError"] = row["last_error"].isNull() ? "" : row["last_error"].as<std::string>();
dest["lastConnectedAt"] = row["last_connected_at"].isNull() ? "" : row["last_connected_at"].as<std::string>();
dest["createdAt"] = row["created_at"].as<std::string>();
destinations.append(dest);
}
resp["destinations"] = destinations;
callback(jsonResp(resp));
}
>> DB_ERROR_MSG(callback, "get restream destinations", "Failed to get restream destinations");
});
}
void RestreamController::addDestination(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId) {
int64_t id = std::stoll(realmId);
verifyRestreamPermission(req, id, [this, callback, id, req](bool authorized, const UserInfo& user) {
if (!authorized) {
callback(jsonError("Unauthorized - requires restreamer role and realm ownership", k403Forbidden));
return;
}
auto jsonPtr = req->getJsonObject();
if (!jsonPtr) {
callback(jsonError("Invalid JSON body"));
return;
}
const auto& json = *jsonPtr;
std::string name = json.get("name", "").asString();
std::string rtmpUrl = json.get("rtmpUrl", "").asString();
std::string streamKey = json.get("streamKey", "").asString();
bool enabled = json.get("enabled", true).asBool();
// Validate inputs
if (name.empty() || name.length() > 100) {
callback(jsonError("Name is required and must be less than 100 characters"));
return;
}
if (rtmpUrl.empty() || rtmpUrl.length() > 500) {
callback(jsonError("RTMP URL is required and must be less than 500 characters"));
return;
}
// Validate RTMP URL format
if (rtmpUrl.substr(0, 7) != "rtmp://" && rtmpUrl.substr(0, 8) != "rtmps://") {
callback(jsonError("RTMP URL must start with rtmp:// or rtmps://"));
return;
}
if (streamKey.empty() || streamKey.length() > 500) {
callback(jsonError("Stream key is required and must be less than 500 characters"));
return;
}
auto dbClient = app().getDbClient();
// Check current count (max 2 destinations per realm)
*dbClient << "SELECT COUNT(*) as count FROM restream_destinations WHERE realm_id = $1"
<< id
>> [dbClient, callback, id, name, rtmpUrl, streamKey, enabled](const Result& r) {
int64_t count = r[0]["count"].as<int64_t>();
if (count >= 2) {
callback(jsonError("Maximum of 2 restream destinations per realm"));
return;
}
// Insert new destination
*dbClient << "INSERT INTO restream_destinations (realm_id, name, rtmp_url, stream_key, enabled) "
"VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at"
<< id << name << rtmpUrl << streamKey << enabled
>> [callback, name, rtmpUrl, streamKey, enabled](const Result& r) {
if (r.empty()) {
callback(jsonError("Failed to create destination"));
return;
}
Json::Value resp;
resp["success"] = true;
resp["message"] = "Restream destination created";
resp["destination"]["id"] = static_cast<Json::Int64>(r[0]["id"].as<int64_t>());
resp["destination"]["name"] = name;
resp["destination"]["rtmpUrl"] = rtmpUrl;
resp["destination"]["streamKey"] = streamKey;
resp["destination"]["enabled"] = enabled;
resp["destination"]["isConnected"] = false;
resp["destination"]["createdAt"] = r[0]["created_at"].as<std::string>();
callback(jsonResp(resp));
}
>> DB_ERROR_MSG(callback, "create restream destination", "Failed to create destination");
}
>> DB_ERROR(callback, "check destination count");
});
}
void RestreamController::updateDestination(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId,
const std::string &destinationId) {
int64_t rid = std::stoll(realmId);
int64_t did = std::stoll(destinationId);
verifyRestreamPermission(req, rid, [callback, rid, did, req](bool authorized, const UserInfo& user) {
if (!authorized) {
callback(jsonError("Unauthorized", k403Forbidden));
return;
}
auto jsonPtr = req->getJsonObject();
if (!jsonPtr) {
callback(jsonError("Invalid JSON body"));
return;
}
const auto& json = *jsonPtr;
auto dbClient = app().getDbClient();
// Validate fields if provided
if (json.isMember("name")) {
std::string name = json["name"].asString();
if (name.empty() || name.length() > 100) {
callback(jsonError("Name must be 1-100 characters"));
return;
}
}
if (json.isMember("rtmpUrl")) {
std::string rtmpUrl = json["rtmpUrl"].asString();
if (rtmpUrl.empty() || rtmpUrl.length() > 500) {
callback(jsonError("RTMP URL must be 1-500 characters"));
return;
}
if (rtmpUrl.substr(0, 7) != "rtmp://" && rtmpUrl.substr(0, 8) != "rtmps://") {
callback(jsonError("RTMP URL must start with rtmp:// or rtmps://"));
return;
}
}
if (json.isMember("streamKey")) {
std::string streamKey = json["streamKey"].asString();
if (streamKey.empty() || streamKey.length() > 500) {
callback(jsonError("Stream key must be 1-500 characters"));
return;
}
}
bool hasAnyField = json.isMember("name") || json.isMember("rtmpUrl") ||
json.isMember("streamKey") || json.isMember("enabled");
if (!hasAnyField) {
callback(jsonError("No fields to update"));
return;
}
// Only update enabled (most common case)
if (json.isMember("enabled") && !json.isMember("name") &&
!json.isMember("rtmpUrl") && !json.isMember("streamKey")) {
bool newEnabled = json["enabled"].asBool();
// If disabling, stop the push first
if (!newEnabled) {
// Get the realm's stream key to stop the push
*dbClient << "SELECT stream_key FROM realms WHERE id = $1"
<< rid
>> [dbClient, callback, did, rid, newEnabled](const Result& r) {
if (!r.empty()) {
std::string streamKey = r[0]["stream_key"].as<std::string>();
// Stop the push
RestreamService::getInstance().stopPush(streamKey, did, [](bool) {});
}
// Update the database
*dbClient << "UPDATE restream_destinations SET enabled = $1 WHERE id = $2 AND realm_id = $3"
<< newEnabled << did << rid
>> [callback](const Result&) {
Json::Value resp;
resp["success"] = true;
resp["message"] = "Destination updated";
callback(jsonResp(resp));
}
>> DB_ERROR_MSG(callback, "update restream destination", "Failed to update destination");
}
>> DB_ERROR(callback, "get realm");
} else {
// Just enable it
*dbClient << "UPDATE restream_destinations SET enabled = $1 WHERE id = $2 AND realm_id = $3"
<< newEnabled << did << rid
>> [callback](const Result&) {
Json::Value resp;
resp["success"] = true;
resp["message"] = "Destination updated";
callback(jsonResp(resp));
}
>> DB_ERROR_MSG(callback, "update restream destination", "Failed to update destination");
}
}
// Update name only
else if (json.isMember("name") && !json.isMember("enabled") &&
!json.isMember("rtmpUrl") && !json.isMember("streamKey")) {
*dbClient << "UPDATE restream_destinations SET name = $1 WHERE id = $2 AND realm_id = $3"
<< json["name"].asString() << did << rid
>> [callback](const Result&) {
Json::Value resp;
resp["success"] = true;
resp["message"] = "Destination updated";
callback(jsonResp(resp));
}
>> DB_ERROR_MSG(callback, "update restream destination", "Failed to update destination");
}
// Full update with all fields
else {
std::string name = json.get("name", "").asString();
std::string rtmpUrl = json.get("rtmpUrl", "").asString();
std::string streamKey = json.get("streamKey", "").asString();
bool enabled = json.get("enabled", true).asBool();
*dbClient << "UPDATE restream_destinations SET "
"name = $1, rtmp_url = $2, stream_key = $3, enabled = $4 "
"WHERE id = $5 AND realm_id = $6"
<< name << rtmpUrl << streamKey << enabled << did << rid
>> [callback](const Result&) {
Json::Value resp;
resp["success"] = true;
resp["message"] = "Destination updated";
callback(jsonResp(resp));
}
>> DB_ERROR_MSG(callback, "update restream destination", "Failed to update destination");
}
});
}
void RestreamController::deleteDestination(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId,
const std::string &destinationId) {
int64_t rid = std::stoll(realmId);
int64_t did = std::stoll(destinationId);
verifyRestreamPermission(req, rid, [callback, rid, did](bool authorized, const UserInfo& user) {
if (!authorized) {
callback(jsonError("Unauthorized", k403Forbidden));
return;
}
auto dbClient = app().getDbClient();
// First get the realm's stream key to stop any active push
*dbClient << "SELECT r.stream_key FROM realms r "
"JOIN restream_destinations rd ON rd.realm_id = r.id "
"WHERE rd.id = $1 AND rd.realm_id = $2"
<< did << rid
>> [dbClient, callback, did, rid](const Result& r) {
if (!r.empty()) {
std::string streamKey = r[0]["stream_key"].as<std::string>();
// Stop the push if active
RestreamService::getInstance().stopPush(streamKey, did, [](bool) {});
}
// Delete the destination
*dbClient << "DELETE FROM restream_destinations WHERE id = $1 AND realm_id = $2"
<< did << rid
>> [callback](const Result&) {
Json::Value resp;
resp["success"] = true;
resp["message"] = "Destination deleted";
callback(jsonResp(resp));
}
>> DB_ERROR_MSG(callback, "delete restream destination", "Failed to delete destination");
}
>> DB_ERROR(callback, "get stream key for delete");
});
}
void RestreamController::testDestination(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId,
const std::string &destinationId) {
int64_t rid = std::stoll(realmId);
int64_t did = std::stoll(destinationId);
verifyRestreamPermission(req, rid, [callback, rid, did](bool authorized, const UserInfo& user) {
if (!authorized) {
callback(jsonError("Unauthorized", k403Forbidden));
return;
}
auto dbClient = app().getDbClient();
// Get the destination and realm info
*dbClient << "SELECT rd.id, rd.name, rd.rtmp_url, rd.stream_key, rd.enabled, "
"r.stream_key as realm_stream_key, r.is_live "
"FROM restream_destinations rd "
"JOIN realms r ON rd.realm_id = r.id "
"WHERE rd.id = $1 AND rd.realm_id = $2"
<< did << rid
>> [callback, did](const Result& r) {
if (r.empty()) {
callback(jsonError("Destination not found", k404NotFound));
return;
}
bool isLive = r[0]["is_live"].as<bool>();
if (!isLive) {
callback(jsonError("Stream must be live to test restream connection"));
return;
}
RestreamDestination dest;
dest.id = r[0]["id"].as<int64_t>();
dest.name = r[0]["name"].as<std::string>();
dest.rtmpUrl = r[0]["rtmp_url"].as<std::string>();
dest.streamKey = r[0]["stream_key"].as<std::string>();
dest.enabled = r[0]["enabled"].as<bool>();
std::string realmStreamKey = r[0]["realm_stream_key"].as<std::string>();
// Try to start the push
RestreamService::getInstance().startPush(realmStreamKey, dest,
[callback, dest](bool success, const std::string& error) {
Json::Value resp;
resp["success"] = success;
if (success) {
resp["message"] = "Restream connection successful";
resp["isConnected"] = true;
} else {
resp["message"] = "Restream connection failed";
resp["error"] = error;
resp["isConnected"] = false;
}
callback(jsonResp(resp, success ? k200OK : k400BadRequest));
});
}
>> DB_ERROR(callback, "test restream destination");
});
}

View file

@ -0,0 +1,44 @@
#pragma once
#include <drogon/HttpController.h>
#include "../services/AuthService.h"
using namespace drogon;
class RestreamController : public HttpController<RestreamController> {
public:
METHOD_LIST_BEGIN
ADD_METHOD_TO(RestreamController::getDestinations, "/api/realms/{1}/restream", Get);
ADD_METHOD_TO(RestreamController::addDestination, "/api/realms/{1}/restream", Post);
ADD_METHOD_TO(RestreamController::updateDestination, "/api/realms/{1}/restream/{2}", Put);
ADD_METHOD_TO(RestreamController::deleteDestination, "/api/realms/{1}/restream/{2}", Delete);
ADD_METHOD_TO(RestreamController::testDestination, "/api/realms/{1}/restream/{2}/test", Post);
METHOD_LIST_END
void getDestinations(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void addDestination(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void updateDestination(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId,
const std::string &destinationId);
void deleteDestination(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId,
const std::string &destinationId);
void testDestination(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId,
const std::string &destinationId);
private:
// Verify user has restream permission for a realm (owner + restreamer role)
void verifyRestreamPermission(const HttpRequestPtr &req, int64_t realmId,
std::function<void(bool authorized, const UserInfo& user)> callback);
};

View file

@ -4,6 +4,9 @@
#include "../services/RedisHelper.h"
#include "../services/OmeClient.h"
#include "../services/AuthService.h"
#include "../services/RestreamService.h"
#include "../common/HttpHelpers.h"
#include "../common/AuthHelpers.h"
#include <drogon/utils/Utilities.h>
#include <drogon/Cookie.h>
#include <random>
@ -15,24 +18,10 @@ using namespace drogon::orm;
// Helper functions at the top
namespace {
// JSON response helper - saves 6-8 lines per endpoint
HttpResponsePtr jsonResp(const Json::Value& j, HttpStatusCode c = k200OK) {
auto r = HttpResponse::newHttpJsonResponse(j);
r->setStatusCode(c);
return r;
}
HttpResponsePtr jsonOk(const Json::Value& data) {
return jsonResp(data);
}
HttpResponsePtr jsonError(const std::string& error, HttpStatusCode code = k400BadRequest) {
Json::Value j;
j["success"] = false;
j["error"] = error;
return jsonResp(j, code);
}
// Quick JSON builder for common patterns
Json::Value json(std::initializer_list<std::pair<const char*, Json::Value>> items) {
Json::Value j;
@ -41,19 +30,6 @@ namespace {
}
return j;
}
UserInfo getUserFromRequest(const HttpRequestPtr &req) {
UserInfo user;
std::string auth = req->getHeader("Authorization");
if (auth.empty() || auth.substr(0, 7) != "Bearer ") {
return user;
}
std::string token = auth.substr(7);
AuthService::getInstance().validateToken(token, user);
return user;
}
}
// Static member definitions
@ -122,10 +98,7 @@ void StreamController::disconnectStream(const HttpRequestPtr &req,
}
});
}
>> [callback](const DrogonDbException& e) {
LOG_ERROR << "Database error: " << e.base().what();
callback(jsonError("Database error"));
};
>> DB_ERROR(callback, "disconnect stream");
}
void StreamController::getStreamStats(const HttpRequestPtr &,
@ -252,7 +225,8 @@ void StreamController::heartbeat(const HttpRequestPtr &req,
return;
}
services::RedisHelper::instance().expireAsync("viewer_token:" + token, 30,
// Refresh token TTL to 5 minutes on heartbeat
services::RedisHelper::instance().expireAsync("viewer_token:" + token, 300,
[callback](bool success) {
if (!success) {
callback(jsonResp({}, k500InternalServerError));
@ -285,29 +259,29 @@ void StreamWebSocketController::handleNewMessage(const WebSocketConnectionPtr&,
void StreamWebSocketController::handleNewConnection(const HttpRequestPtr &req,
const WebSocketConnectionPtr& wsConnPtr) {
LOG_INFO << "New WebSocket connection established";
// Allow anonymous connections for receiving public broadcasts (stream_live/stream_offline)
// These are used by the home page to get instant updates
std::lock_guard<std::mutex> lock(connectionsMutex_);
connections_.insert(wsConnPtr);
auto token = req->getCookie("viewer_token");
if (token.empty()) {
LOG_WARN << "WebSocket connection without viewer token";
wsConnPtr->shutdown();
return;
}
RedisHelper::getKeyAsync("viewer_token:" + token,
[wsConnPtr, token](const std::string& streamKey) {
if (streamKey.empty()) {
LOG_WARN << "Invalid viewer token";
wsConnPtr->shutdown();
return;
if (!token.empty()) {
// If viewer token is provided, validate and track it
RedisHelper::getKeyAsync("viewer_token:" + token,
[wsConnPtr, token](const std::string& streamKey) {
if (!streamKey.empty()) {
std::lock_guard<std::mutex> lock(connectionsMutex_);
tokenConnections_[token].insert(wsConnPtr);
LOG_INFO << "WebSocket authenticated for stream: " << streamKey;
} else {
LOG_DEBUG << "WebSocket with invalid/expired viewer token - treating as anonymous";
}
}
std::lock_guard<std::mutex> lock(connectionsMutex_);
tokenConnections_[token].insert(wsConnPtr);
connections_.insert(wsConnPtr);
LOG_INFO << "WebSocket authenticated for stream: " << streamKey;
}
);
);
} else {
LOG_DEBUG << "Anonymous WebSocket connection (no viewer token)";
}
}
void StreamWebSocketController::handleConnectionClosed(const WebSocketConnectionPtr& wsConnPtr) {
@ -360,11 +334,248 @@ void StreamWebSocketController::broadcastKeyUpdate(const std::string& userId, co
void StreamWebSocketController::broadcastStatsUpdate(const Json::Value& stats) {
std::string jsonStr = Json::FastWriter().write(stats);
std::lock_guard<std::mutex> lock(connectionsMutex_);
for (const auto& conn : connections_) {
if (conn->connected()) {
conn->send(jsonStr);
}
}
}
// OvenMediaEngine Webhook Handlers
void StreamController::handleOmeWebhook(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
auto jsonPtr = req->getJsonObject();
if (!jsonPtr) {
LOG_WARN << "OME webhook received with invalid JSON";
callback(jsonError("Invalid JSON", k400BadRequest));
return;
}
const auto& payload = *jsonPtr;
std::string eventType = payload.get("eventType", "").asString();
LOG_INFO << "OME Webhook received: " << eventType;
LOG_DEBUG << "OME Webhook payload: " << payload.toStyledString();
// Extract stream information
std::string streamName;
if (payload.isMember("stream") && payload["stream"].isMember("name")) {
streamName = payload["stream"]["name"].asString();
} else if (payload.isMember("streamName")) {
streamName = payload["streamName"].asString();
}
if (streamName.empty()) {
LOG_WARN << "OME webhook missing stream name";
callback(jsonOk(json({{"success", true}, {"message", "Acknowledged"}})));
return;
}
auto dbClient = app().getDbClient();
if (eventType == "streamCreated" || eventType == "stream.created" || eventType == "publish") {
// Stream started - mark realm as live immediately
LOG_INFO << "Stream started via webhook: " << streamName;
*dbClient << "UPDATE realms SET is_live = true, viewer_count = 0, "
"updated_at = CURRENT_TIMESTAMP WHERE stream_key = $1 RETURNING id"
<< streamName
>> [streamName](const Result& r) {
LOG_INFO << "Realm marked as live via webhook: " << streamName;
// Broadcast to WebSocket clients
Json::Value msg;
msg["type"] = "stream_live";
msg["stream_key"] = streamName;
msg["is_live"] = true;
StreamWebSocketController::broadcastStatsUpdate(msg);
// Trigger immediate stats fetch
StatsService::getInstance().updateStreamStats(streamName);
// Pre-warm thumbnail cache so it's ready when users see the stream
// This makes an async request to generate the thumbnail in the background
auto client = HttpClient::newHttpClient("http://localhost:8088");
auto req = HttpRequest::newHttpRequest();
req->setPath("/thumb/" + streamName + ".webp");
req->setMethod(drogon::Get);
client->sendRequest(req, [streamName](ReqResult result, const HttpResponsePtr& response) {
if (result == ReqResult::Ok && response && response->statusCode() == k200OK) {
LOG_INFO << "Thumbnail pre-warmed for stream: " << streamName;
} else {
LOG_DEBUG << "Thumbnail pre-warm pending for: " << streamName << " (stream may still be initializing)";
}
}, 10.0); // 10 second timeout for thumbnail generation
// Start restream destinations if realm has any
if (!r.empty()) {
int64_t realmId = r[0]["id"].as<int64_t>();
RestreamService::getInstance().startAllDestinations(streamName, realmId);
}
}
>> [streamName](const DrogonDbException& e) {
LOG_ERROR << "Failed to mark realm live via webhook: " << e.base().what();
};
}
else if (eventType == "streamDeleted" || eventType == "stream.deleted" || eventType == "unpublish") {
// Stream ended - mark realm as offline immediately
LOG_INFO << "Stream ended via webhook: " << streamName;
*dbClient << "UPDATE realms SET is_live = false, viewer_count = 0, "
"updated_at = CURRENT_TIMESTAMP WHERE stream_key = $1 RETURNING id"
<< streamName
>> [streamName](const Result& r) {
LOG_INFO << "Realm marked as offline via webhook: " << streamName;
// Broadcast to WebSocket clients
Json::Value msg;
msg["type"] = "stream_offline";
msg["stream_key"] = streamName;
msg["is_live"] = false;
StreamWebSocketController::broadcastStatsUpdate(msg);
// Stop all restream destinations
if (!r.empty()) {
int64_t realmId = r[0]["id"].as<int64_t>();
RestreamService::getInstance().stopAllDestinations(streamName, realmId);
}
}
>> [streamName](const DrogonDbException& e) {
LOG_ERROR << "Failed to mark realm offline via webhook: " << e.base().what();
};
}
else if (eventType == "sessionCreated" || eventType == "viewer.connected") {
// Viewer connected
LOG_INFO << "Viewer connected to stream: " << streamName;
StatsService::getInstance().updateStreamStats(streamName);
}
else if (eventType == "sessionDeleted" || eventType == "viewer.disconnected") {
// Viewer disconnected
LOG_INFO << "Viewer disconnected from stream: " << streamName;
StatsService::getInstance().updateStreamStats(streamName);
}
// Always respond with success to acknowledge the webhook
callback(jsonOk(json({{"success", true}, {"message", "Webhook processed"}})));
}
void StreamController::handleOmeAdmission(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
// Admission webhook - validates if a stream is allowed to publish/play
// OME sends: { "client": {...}, "request": { "direction", "protocol", "status", "url", ... } }
auto jsonPtr = req->getJsonObject();
if (!jsonPtr) {
LOG_WARN << "OME admission webhook received with invalid JSON";
callback(jsonError("Invalid JSON", k400BadRequest));
return;
}
const auto& payload = *jsonPtr;
LOG_INFO << "OME Admission webhook: " << payload.toStyledString();
// Check if this is a "closing" status - just acknowledge it
if (payload.isMember("request") && payload["request"].isMember("status")) {
std::string status = payload["request"]["status"].asString();
if (status == "closing") {
LOG_INFO << "OME admission closing notification";
Json::Value response;
callback(jsonOk(response)); // Empty response for closing
return;
}
}
// Extract stream key from URL: rtmp://host:port/app/STREAM_KEY or similar
std::string streamKey;
if (payload.isMember("request") && payload["request"].isMember("url")) {
std::string url = payload["request"]["url"].asString();
// URL format: scheme://host[:port]/app/stream_key[/file][?query]
// Find the stream key after /app/
size_t appPos = url.find("/app/");
if (appPos != std::string::npos) {
std::string afterApp = url.substr(appPos + 5); // Skip "/app/"
// Remove any trailing path or query string
size_t endPos = afterApp.find_first_of("/?");
if (endPos != std::string::npos) {
streamKey = afterApp.substr(0, endPos);
} else {
streamKey = afterApp;
}
}
LOG_INFO << "Extracted stream key from URL: " << streamKey << " (URL: " << url << ")";
}
if (streamKey.empty()) {
LOG_WARN << "OME admission webhook: could not extract stream key, allowing by default";
Json::Value response;
response["allowed"] = true;
callback(jsonOk(response));
return;
}
// Check direction - only validate "incoming" (publish) requests
std::string direction;
if (payload.isMember("request") && payload["request"].isMember("direction")) {
direction = payload["request"]["direction"].asString();
}
if (direction == "outgoing") {
// Playback request - allow all for now (could add viewer auth later)
LOG_INFO << "Allowing outgoing (playback) request for: " << streamKey;
Json::Value response;
response["allowed"] = true;
callback(jsonOk(response));
return;
}
// Validate stream key against database for incoming (publish) requests
auto dbClient = app().getDbClient();
*dbClient << "SELECT id FROM realms WHERE stream_key = $1 AND is_active = true"
<< streamKey
>> [callback, streamKey](const Result& r) {
Json::Value response;
if (!r.empty()) {
LOG_INFO << "Stream key validated for admission: " << streamKey;
response["allowed"] = true;
// Mark stream as live immediately when publishing is approved
int64_t realmId = r[0]["id"].as<int64_t>();
auto db = app().getDbClient();
*db << "UPDATE realms SET is_live = true, viewer_count = 0, "
"updated_at = CURRENT_TIMESTAMP WHERE id = $1"
<< realmId
>> [streamKey, realmId](const Result&) {
LOG_INFO << "Realm marked live on admission: " << streamKey;
// Broadcast to WebSocket clients
Json::Value msg;
msg["type"] = "stream_live";
msg["stream_key"] = streamKey;
msg["is_live"] = true;
StreamWebSocketController::broadcastStatsUpdate(msg);
// Trigger stats fetch
StatsService::getInstance().updateStreamStats(streamKey);
// Start restream destinations
RestreamService::getInstance().startAllDestinations(streamKey, realmId);
}
>> [streamKey](const DrogonDbException& e) {
LOG_ERROR << "Failed to mark realm live on admission: " << e.base().what();
};
} else {
LOG_WARN << "Invalid stream key rejected: " << streamKey;
response["allowed"] = false;
response["reason"] = "Invalid or inactive stream key";
}
callback(jsonOk(response));
}
>> [callback, streamKey](const DrogonDbException& e) {
LOG_ERROR << "Database error during admission check: " << e.base().what();
// Allow on DB error to prevent blocking legitimate streams
Json::Value response;
response["allowed"] = true;
callback(jsonOk(response));
};
}

View file

@ -18,33 +18,43 @@ public:
ADD_METHOD_TO(StreamController::getActiveStreams, "/api/stream/active", Get);
ADD_METHOD_TO(StreamController::issueViewerToken, "/api/stream/token/{1}", Get);
ADD_METHOD_TO(StreamController::heartbeat, "/api/stream/heartbeat/{1}", Post);
// OvenMediaEngine webhook endpoints
ADD_METHOD_TO(StreamController::handleOmeWebhook, "/api/webhook/ome", Post);
ADD_METHOD_TO(StreamController::handleOmeAdmission, "/api/webhook/ome/admission", Post);
METHOD_LIST_END
void health(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void validateStreamKey(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &key);
void disconnectStream(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &streamId);
void getStreamStats(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &streamKey);
void getActiveStreams(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void issueViewerToken(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &streamKey);
void heartbeat(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &streamKey);
// OvenMediaEngine webhook handlers
void handleOmeWebhook(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void handleOmeAdmission(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
};
class StreamWebSocketController : public WebSocketController<StreamWebSocketController> {

File diff suppressed because it is too large Load diff

View file

@ -13,16 +13,40 @@ public:
ADD_METHOD_TO(UserController::pgpChallenge, "/api/auth/pgp-challenge", Post);
ADD_METHOD_TO(UserController::pgpVerify, "/api/auth/pgp-verify", Post);
ADD_METHOD_TO(UserController::getCurrentUser, "/api/user/me", Get);
ADD_METHOD_TO(UserController::getToken, "/api/user/token", Get);
ADD_METHOD_TO(UserController::updateProfile, "/api/user/profile", Put);
ADD_METHOD_TO(UserController::updatePassword, "/api/user/password", Put);
ADD_METHOD_TO(UserController::togglePgpOnly, "/api/user/pgp-only", Put);
ADD_METHOD_TO(UserController::addPgpKey, "/api/user/pgp-key", Post);
ADD_METHOD_TO(UserController::getPgpKeys, "/api/user/pgp-keys", Get);
ADD_METHOD_TO(UserController::uploadAvatar, "/api/user/avatar", Post);
ADD_METHOD_TO(UserController::uploadBanner, "/api/user/banner", Post);
ADD_METHOD_TO(UserController::getProfile, "/api/users/{1}", Get);
ADD_METHOD_TO(UserController::getUserPgpKeys, "/api/users/{1}/pgp-keys", Get);
ADD_METHOD_TO(UserController::updateColor, "/api/user/color", Put);
ADD_METHOD_TO(UserController::getAvailableColors, "/api/colors/available", Get);
ADD_METHOD_TO(UserController::getBotApiKeys, "/api/user/bot-keys", Get);
ADD_METHOD_TO(UserController::createBotApiKey, "/api/user/bot-keys", Post);
ADD_METHOD_TO(UserController::deleteBotApiKey, "/api/user/bot-keys/{1}", Delete);
ADD_METHOD_TO(UserController::validateBotApiKey, "/api/internal/validate-bot-key", Post);
ADD_METHOD_TO(UserController::processPendingUberban, "/api/internal/user/{1}/process-pending-uberban", Post);
ADD_METHOD_TO(UserController::submitSticker, "/api/stickers/submit", Post);
ADD_METHOD_TO(UserController::getMySubmissions, "/api/stickers/my-submissions", Get);
ADD_METHOD_TO(UserController::uploadGraffiti, "/api/user/graffiti", Post);
ADD_METHOD_TO(UserController::deleteGraffiti, "/api/user/graffiti", Delete);
// Übercoin endpoints
ADD_METHOD_TO(UserController::sendUbercoin, "/api/ubercoin/send", Post);
ADD_METHOD_TO(UserController::previewUbercoin, "/api/ubercoin/preview", Post);
ADD_METHOD_TO(UserController::getTreasury, "/api/ubercoin/treasury", Get);
// Treasury cron endpoints (admin-only, called by scheduled tasks)
ADD_METHOD_TO(UserController::treasuryApplyGrowth, "/api/ubercoin/cron/growth", Post);
ADD_METHOD_TO(UserController::treasuryDistribute, "/api/ubercoin/cron/distribute", Post);
// Referral code endpoints
ADD_METHOD_TO(UserController::getReferralCodes, "/api/user/referral-codes", Get);
ADD_METHOD_TO(UserController::purchaseReferralCode, "/api/user/referral-codes/purchase", Post);
ADD_METHOD_TO(UserController::validateReferralCode, "/api/auth/validate-referral", Post);
ADD_METHOD_TO(UserController::registerWithReferral, "/api/auth/register-referral", Post);
ADD_METHOD_TO(UserController::getReferralSettings, "/api/settings/referral", Get);
METHOD_LIST_END
void register_(const HttpRequestPtr &req,
@ -42,7 +66,10 @@ public:
void getCurrentUser(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getToken(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void updateProfile(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
@ -60,7 +87,10 @@ public:
void uploadAvatar(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void uploadBanner(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getProfile(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &username);
@ -75,6 +105,76 @@ public:
void getAvailableColors(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getBotApiKeys(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void createBotApiKey(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void deleteBotApiKey(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &keyId);
void validateBotApiKey(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void processPendingUberban(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &userId);
void submitSticker(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getMySubmissions(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void uploadGraffiti(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void deleteGraffiti(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
// Übercoin methods
void sendUbercoin(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void previewUbercoin(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getTreasury(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
// Treasury cron methods (admin-only)
void treasuryApplyGrowth(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void treasuryDistribute(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
// Referral code methods
void getReferralCodes(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void purchaseReferralCode(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void validateReferralCode(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void registerWithReferral(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getReferralSettings(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
private:
UserInfo getUserFromRequest(const HttpRequestPtr &req);
// Übercoin helper: Calculate burn rate based on account age
// Formula: max(1, 99 * e^(-account_age_days / 180))
double calculateBurnRate(int accountAgeDays);
// Referral code helper: Generate random alphanumeric code
std::string generateReferralCode(int length = 12);
// Übercoin helper: Calculate account age in days from created_at timestamp
int calculateAccountAgeDays(const std::string& createdAt);
};

View file

@ -0,0 +1,926 @@
#include "VideoController.h"
#include "../services/DatabaseService.h"
#include "../services/RedisHelper.h" // SECURITY FIX #13: Redis for view rate limiting
#include "../common/HttpHelpers.h"
#include "../common/AuthHelpers.h"
#include "../common/FileUtils.h"
#include "../common/FileValidation.h"
#include <drogon/utils/Utilities.h>
#include <drogon/Cookie.h>
#include <random>
#include <sstream>
#include <iomanip>
#include <fstream>
#include <filesystem>
#include <cstdlib>
#include <array>
#include <thread>
#include <unistd.h>
#include <sys/wait.h>
#include <fcntl.h>
using namespace drogon::orm;
namespace {
// Video metadata extracted from a single ffprobe call
struct VideoMetadata {
int duration = 0;
int width = 0;
int height = 0;
int bitrate = 0;
std::string videoCodec;
std::string audioCodec;
};
// Get all video metadata with a single ffprobe call (5x faster than separate calls)
VideoMetadata getVideoMetadata(const std::string& videoPath) {
VideoMetadata meta;
if (!isPathSafe(videoPath, "/app/uploads")) {
LOG_ERROR << "Unsafe video path rejected: " << videoPath;
return meta;
}
std::vector<std::string> args = {
"/usr/bin/ffprobe", "-v", "error",
"-show_format", "-show_streams",
"-of", "json",
videoPath
};
std::string output = execCommandSafe(args);
if (output.empty()) {
return meta;
}
try {
Json::Value root;
Json::CharReaderBuilder builder;
std::string errors;
std::istringstream stream(output);
if (!Json::parseFromStream(builder, stream, &root, &errors)) {
LOG_ERROR << "Failed to parse ffprobe JSON: " << errors;
return meta;
}
// Extract format info (duration, bitrate)
if (root.isMember("format")) {
const auto& format = root["format"];
if (format.isMember("duration")) {
meta.duration = static_cast<int>(std::stof(format["duration"].asString()));
}
if (format.isMember("bit_rate")) {
try {
meta.bitrate = std::stoi(format["bit_rate"].asString());
} catch (...) {}
}
}
// Extract stream info (video/audio codecs, dimensions)
if (root.isMember("streams") && root["streams"].isArray()) {
for (const auto& stream : root["streams"]) {
std::string codecType = stream.get("codec_type", "").asString();
if (codecType == "video" && meta.videoCodec.empty()) {
meta.videoCodec = stream.get("codec_name", "").asString();
meta.width = stream.get("width", 0).asInt();
meta.height = stream.get("height", 0).asInt();
} else if (codecType == "audio" && meta.audioCodec.empty()) {
meta.audioCodec = stream.get("codec_name", "").asString();
}
}
}
} catch (const std::exception& e) {
LOG_ERROR << "Exception parsing video metadata: " << e.what();
}
return meta;
}
// Video JSON detail levels for API responses
enum class VideoJsonLevel {
Minimal, // 9 fields - for realm video lists
Basic, // 11 fields - for user video lists (+ realmId, realmName)
Standard, // 14 fields - for public lists (+ userId, username, avatarUrl)
Extended // 20 fields - for single video detail (+ technical metadata)
};
// Build video JSON object from database row (reduces code duplication)
Json::Value buildVideoJson(const drogon::orm::Row& row, VideoJsonLevel level) {
Json::Value video;
// Core fields (all levels)
video["id"] = static_cast<Json::Int64>(row["id"].as<int64_t>());
video["title"] = row["title"].as<std::string>();
video["description"] = row["description"].isNull() ? "" : row["description"].as<std::string>();
video["filePath"] = row["file_path"].as<std::string>();
video["thumbnailPath"] = row["thumbnail_path"].isNull() ? "" : row["thumbnail_path"].as<std::string>();
video["previewPath"] = row["preview_path"].isNull() ? "" : row["preview_path"].as<std::string>();
video["durationSeconds"] = row["duration_seconds"].as<int>();
video["viewCount"] = row["view_count"].as<int>();
video["createdAt"] = row["created_at"].as<std::string>();
if (level == VideoJsonLevel::Minimal) return video;
// Basic+ fields (realm info)
video["realmId"] = static_cast<Json::Int64>(row["realm_id"].as<int64_t>());
video["realmName"] = row["realm_name"].as<std::string>();
if (level == VideoJsonLevel::Basic) return video;
// Standard+ fields (user info)
video["userId"] = static_cast<Json::Int64>(row["user_id"].as<int64_t>());
video["username"] = row["username"].as<std::string>();
video["avatarUrl"] = row["avatar_url"].isNull() ? "" : row["avatar_url"].as<std::string>();
if (level == VideoJsonLevel::Standard) return video;
// Extended fields (technical metadata)
video["fileSizeBytes"] = static_cast<Json::Int64>(row["file_size_bytes"].as<int64_t>());
video["width"] = row["width"].isNull() ? 0 : row["width"].as<int>();
video["height"] = row["height"].isNull() ? 0 : row["height"].as<int>();
video["bitrate"] = row["bitrate"].isNull() ? 0 : row["bitrate"].as<int>();
video["videoCodec"] = row["video_codec"].isNull() ? "" : row["video_codec"].as<std::string>();
video["audioCodec"] = row["audio_codec"].isNull() ? "" : row["audio_codec"].as<std::string>();
return video;
}
// Generate static WebP thumbnail (safe version - no shell)
bool generateThumbnail(const std::string& videoPath, const std::string& thumbnailPath, int seekSeconds = 2) {
if (!isPathSafe(videoPath, "/app/uploads") || !isPathSafe(thumbnailPath, "/app/uploads")) {
LOG_ERROR << "Unsafe path rejected for thumbnail generation";
return false;
}
pid_t pid = fork();
if (pid == -1) return false;
if (pid == 0) {
// Child process - redirect stderr to /dev/null
int devnull = open("/dev/null", O_WRONLY);
dup2(devnull, STDERR_FILENO);
dup2(devnull, STDOUT_FILENO);
close(devnull);
std::string seekStr = std::to_string(seekSeconds);
execl("/usr/bin/ffmpeg", "ffmpeg",
"-y", "-ss", seekStr.c_str(),
"-i", videoPath.c_str(),
"-vframes", "1",
"-vf", "scale=320:-1",
"-c:v", "libwebp",
"-quality", "80",
thumbnailPath.c_str(),
nullptr);
_exit(1);
}
int status;
waitpid(pid, &status, 0);
return WIFEXITED(status) && WEXITSTATUS(status) == 0 && std::filesystem::exists(thumbnailPath);
}
// Generate animated WebP preview (safe version - no shell)
bool generateAnimatedPreview(const std::string& videoPath, const std::string& previewPath, int seekSeconds = 2, int duration = 3) {
if (!isPathSafe(videoPath, "/app/uploads") || !isPathSafe(previewPath, "/app/uploads")) {
LOG_ERROR << "Unsafe path rejected for preview generation";
return false;
}
pid_t pid = fork();
if (pid == -1) return false;
if (pid == 0) {
// Child process
int devnull = open("/dev/null", O_WRONLY);
dup2(devnull, STDERR_FILENO);
dup2(devnull, STDOUT_FILENO);
close(devnull);
std::string seekStr = std::to_string(seekSeconds);
std::string durationStr = std::to_string(duration);
execl("/usr/bin/ffmpeg", "ffmpeg",
"-y", "-ss", seekStr.c_str(),
"-t", durationStr.c_str(),
"-i", videoPath.c_str(),
"-vf", "scale=320:-1,fps=10",
"-loop", "0",
"-c:v", "libwebp",
"-quality", "60",
previewPath.c_str(),
nullptr);
_exit(1);
}
int status;
waitpid(pid, &status, 0);
return WIFEXITED(status) && WEXITSTATUS(status) == 0 && std::filesystem::exists(previewPath);
}
// Process video thumbnails asynchronously
void processVideoThumbnails(int64_t videoId, const std::string& videoFullPath, const std::string& uploadsDir) {
// Run thumbnail generation in a separate thread
std::thread([videoId, videoFullPath, uploadsDir]() {
try {
// Get all video metadata with a single ffprobe call (5x faster)
VideoMetadata meta = getVideoMetadata(videoFullPath);
// Calculate seek position (10% into video, min 1s, max 30s)
int seekPos = std::max(1, std::min(30, meta.duration / 10));
// Generate filenames
std::string baseName = std::filesystem::path(videoFullPath).stem().string();
std::string thumbnailFilename = baseName + "_thumb.webp";
std::string previewFilename = baseName + "_preview.webp";
std::string thumbnailFullPath = uploadsDir + "/" + thumbnailFilename;
std::string previewFullPath = uploadsDir + "/" + previewFilename;
// Generate thumbnails
bool thumbOk = generateThumbnail(videoFullPath, thumbnailFullPath, seekPos);
bool previewOk = generateAnimatedPreview(videoFullPath, previewFullPath, seekPos, 3);
// Update database
std::string thumbnailPath = thumbOk ? "/uploads/videos/" + thumbnailFilename : "";
std::string previewPath = previewOk ? "/uploads/videos/" + previewFilename : "";
auto dbClient = app().getDbClient();
*dbClient << "UPDATE videos SET thumbnail_path = $1, preview_path = $2, "
"duration_seconds = $3, width = $4, height = $5, "
"bitrate = $6, video_codec = $7, audio_codec = $8 "
"WHERE id = $9"
<< thumbnailPath << previewPath << meta.duration << meta.width << meta.height
<< meta.bitrate << meta.videoCodec << meta.audioCodec << videoId
>> [videoId](const Result&) {
LOG_INFO << "Video " << videoId << " thumbnails and metadata generated successfully";
}
>> [videoId](const DrogonDbException& e) {
LOG_ERROR << "Failed to update video " << videoId << " metadata: " << e.base().what();
};
} catch (const std::exception& e) {
LOG_ERROR << "Exception processing thumbnails for video " << videoId << ": " << e.what();
}
}).detach();
}
}
void VideoController::getAllVideos(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
// Get pagination parameters
int page = 1;
int limit = 20;
auto pageParam = req->getParameter("page");
auto limitParam = req->getParameter("limit");
if (!pageParam.empty()) {
try { page = std::stoi(pageParam); } catch (...) {}
}
if (!limitParam.empty()) {
try { limit = std::min(std::stoi(limitParam), 50); } catch (...) {}
}
int offset = (page - 1) * limit;
auto dbClient = app().getDbClient();
*dbClient << "SELECT v.id, v.title, v.description, v.file_path, v.thumbnail_path, v.preview_path, "
"v.duration_seconds, v.view_count, v.created_at, v.realm_id, "
"u.id as user_id, u.username, u.avatar_url, "
"r.name as realm_name "
"FROM videos v "
"JOIN users u ON v.user_id = u.id "
"JOIN realms r ON v.realm_id = r.id "
"WHERE v.is_public = true AND v.status = 'ready' "
"ORDER BY v.created_at DESC "
"LIMIT $1 OFFSET $2"
<< static_cast<int64_t>(limit) << static_cast<int64_t>(offset)
>> [callback](const Result& r) {
Json::Value resp;
resp["success"] = true;
Json::Value videos(Json::arrayValue);
for (const auto& row : r) {
videos.append(buildVideoJson(row, VideoJsonLevel::Standard));
}
resp["videos"] = videos;
callback(jsonResp(resp));
}
>> DB_ERROR(callback, "get videos");
}
void VideoController::getLatestVideos(const HttpRequestPtr &,
std::function<void(const HttpResponsePtr &)> &&callback) {
auto dbClient = app().getDbClient();
*dbClient << "SELECT v.id, v.title, v.description, v.file_path, v.thumbnail_path, v.preview_path, "
"v.duration_seconds, v.view_count, v.created_at, v.realm_id, "
"u.id as user_id, u.username, u.avatar_url, "
"r.name as realm_name "
"FROM videos v "
"JOIN users u ON v.user_id = u.id "
"JOIN realms r ON v.realm_id = r.id "
"WHERE v.is_public = true AND v.status = 'ready' "
"ORDER BY v.created_at DESC "
"LIMIT 5"
>> [callback](const Result& r) {
Json::Value resp;
resp["success"] = true;
Json::Value videos(Json::arrayValue);
for (const auto& row : r) {
videos.append(buildVideoJson(row, VideoJsonLevel::Standard));
}
resp["videos"] = videos;
callback(jsonResp(resp));
}
>> DB_ERROR(callback, "get latest videos");
}
void VideoController::getVideo(const HttpRequestPtr &,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &videoId) {
int64_t id;
try {
id = std::stoll(videoId);
} catch (...) {
callback(jsonError("Invalid video ID", k400BadRequest));
return;
}
auto dbClient = app().getDbClient();
*dbClient << "SELECT v.id, v.title, v.description, v.file_path, v.thumbnail_path, v.preview_path, "
"v.duration_seconds, v.file_size_bytes, v.width, v.height, "
"v.bitrate, v.video_codec, v.audio_codec, "
"v.view_count, v.is_public, v.status, v.created_at, v.realm_id, "
"u.id as user_id, u.username, u.avatar_url, "
"r.name as realm_name "
"FROM videos v "
"JOIN users u ON v.user_id = u.id "
"JOIN realms r ON v.realm_id = r.id "
"WHERE v.id = $1 AND v.status = 'ready'"
<< id
>> [callback](const Result& r) {
if (r.empty()) {
callback(jsonError("Video not found", k404NotFound));
return;
}
const auto& row = r[0];
// Check if video is public
if (!row["is_public"].as<bool>()) {
callback(jsonError("Video not found", k404NotFound));
return;
}
Json::Value resp;
resp["success"] = true;
resp["video"] = buildVideoJson(row, VideoJsonLevel::Extended);
callback(jsonResp(resp));
}
>> DB_ERROR(callback, "get video");
}
void VideoController::getUserVideos(const HttpRequestPtr &,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &username) {
auto dbClient = app().getDbClient();
*dbClient << "SELECT v.id, v.title, v.description, v.file_path, v.thumbnail_path, v.preview_path, "
"v.duration_seconds, v.view_count, v.created_at, v.realm_id, "
"u.id as user_id, u.username, u.avatar_url, "
"r.name as realm_name "
"FROM videos v "
"JOIN users u ON v.user_id = u.id "
"JOIN realms r ON v.realm_id = r.id "
"WHERE u.username = $1 AND v.is_public = true AND v.status = 'ready' "
"ORDER BY v.created_at DESC"
<< username
>> [callback, username](const Result& r) {
Json::Value resp;
resp["success"] = true;
resp["username"] = username;
Json::Value videos(Json::arrayValue);
for (const auto& row : r) {
videos.append(buildVideoJson(row, VideoJsonLevel::Standard));
}
resp["videos"] = videos;
callback(jsonResp(resp));
}
>> DB_ERROR(callback, "get user videos");
}
void VideoController::getRealmVideos(const HttpRequestPtr &,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId) {
int64_t id;
try {
id = std::stoll(realmId);
} catch (...) {
callback(jsonError("Invalid realm ID", k400BadRequest));
return;
}
auto dbClient = app().getDbClient();
// First get realm info
*dbClient << "SELECT r.id, r.name, r.description, r.realm_type, r.title_color, r.created_at, "
"u.id as user_id, u.username, u.avatar_url "
"FROM realms r "
"JOIN users u ON r.user_id = u.id "
"WHERE r.id = $1 AND r.is_active = true AND r.realm_type = 'video'"
<< id
>> [callback, dbClient, id](const Result& realmResult) {
if (realmResult.empty()) {
callback(jsonError("Video realm not found", k404NotFound));
return;
}
// Get videos for this realm
*dbClient << "SELECT v.id, v.title, v.description, v.file_path, v.thumbnail_path, v.preview_path, "
"v.duration_seconds, v.view_count, v.created_at "
"FROM videos v "
"WHERE v.realm_id = $1 AND v.is_public = true AND v.status = 'ready' "
"ORDER BY v.created_at DESC"
<< id
>> [callback, realmResult](const Result& r) {
Json::Value resp;
resp["success"] = true;
// Realm info
auto& realm = resp["realm"];
realm["id"] = static_cast<Json::Int64>(realmResult[0]["id"].as<int64_t>());
realm["name"] = realmResult[0]["name"].as<std::string>();
realm["description"] = realmResult[0]["description"].isNull() ? "" : realmResult[0]["description"].as<std::string>();
realm["titleColor"] = realmResult[0]["title_color"].isNull() ? "#ffffff" : realmResult[0]["title_color"].as<std::string>();
realm["createdAt"] = realmResult[0]["created_at"].as<std::string>();
realm["userId"] = static_cast<Json::Int64>(realmResult[0]["user_id"].as<int64_t>());
realm["username"] = realmResult[0]["username"].as<std::string>();
realm["avatarUrl"] = realmResult[0]["avatar_url"].isNull() ? "" : realmResult[0]["avatar_url"].as<std::string>();
// Videos (Minimal level - no realm/user info since it's implied)
Json::Value videos(Json::arrayValue);
for (const auto& row : r) {
videos.append(buildVideoJson(row, VideoJsonLevel::Minimal));
}
resp["videos"] = videos;
callback(jsonResp(resp));
}
>> DB_ERROR(callback, "get realm videos");
}
>> DB_ERROR(callback, "get realm");
}
void VideoController::incrementViewCount(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &videoId) {
int64_t id;
try {
id = std::stoll(videoId);
} catch (...) {
callback(jsonError("Invalid video ID", k400BadRequest));
return;
}
// SECURITY FIX #13: Rate limit view count increments per IP per video
// Prevents artificial view count inflation
std::string clientIp = req->getPeerAddr().toIp();
std::string rateKey = "view_limit:" + std::to_string(id) + ":" + clientIp;
// Check if this IP has already viewed this video recently (5 minute window)
RedisHelper::getKeyAsync(rateKey, [callback, id, rateKey, clientIp](const std::string& exists) {
if (!exists.empty()) {
// Already counted recently - return success but don't increment
Json::Value resp;
resp["success"] = true;
resp["message"] = "View already counted";
callback(jsonResp(resp));
return;
}
// Set rate limit key first (TTL 300 seconds = 5 minutes)
RedisHelper::storeKeyAsync(rateKey, "1", 300, [callback, id](bool stored) {
if (!stored) {
LOG_WARN << "Failed to set view rate limit key, allowing view anyway";
}
// Increment view count in database
auto dbClient = app().getDbClient();
*dbClient << "UPDATE videos SET view_count = view_count + 1 "
"WHERE id = $1 AND is_public = true AND status = 'ready' "
"RETURNING view_count"
<< id
>> [callback](const Result& r) {
if (r.empty()) {
callback(jsonError("Video not found", k404NotFound));
return;
}
Json::Value resp;
resp["success"] = true;
resp["viewCount"] = r[0]["view_count"].as<int>();
callback(jsonResp(resp));
}
>> DB_ERROR(callback, "increment view count");
});
});
}
void VideoController::getMyVideos(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
UserInfo user = getUserFromRequest(req);
if (user.id == 0) {
callback(jsonError("Unauthorized", k401Unauthorized));
return;
}
auto dbClient = app().getDbClient();
*dbClient << "SELECT v.id, v.title, v.description, v.file_path, v.thumbnail_path, v.preview_path, "
"v.duration_seconds, v.file_size_bytes, v.view_count, v.is_public, v.status, v.created_at, "
"v.realm_id, r.name as realm_name "
"FROM videos v "
"JOIN realms r ON v.realm_id = r.id "
"WHERE v.user_id = $1 AND v.status != 'deleted' "
"ORDER BY v.created_at DESC"
<< user.id
>> [callback](const Result& r) {
Json::Value resp;
resp["success"] = true;
Json::Value videos(Json::arrayValue);
for (const auto& row : r) {
Json::Value video;
video["id"] = static_cast<Json::Int64>(row["id"].as<int64_t>());
video["title"] = row["title"].as<std::string>();
video["description"] = row["description"].isNull() ? "" : row["description"].as<std::string>();
video["filePath"] = row["file_path"].as<std::string>();
video["thumbnailPath"] = row["thumbnail_path"].isNull() ? "" : row["thumbnail_path"].as<std::string>();
video["previewPath"] = row["preview_path"].isNull() ? "" : row["preview_path"].as<std::string>();
video["durationSeconds"] = row["duration_seconds"].as<int>();
video["fileSizeBytes"] = static_cast<Json::Int64>(row["file_size_bytes"].as<int64_t>());
video["viewCount"] = row["view_count"].as<int>();
video["isPublic"] = row["is_public"].as<bool>();
video["status"] = row["status"].as<std::string>();
video["createdAt"] = row["created_at"].as<std::string>();
video["realmId"] = static_cast<Json::Int64>(row["realm_id"].as<int64_t>());
video["realmName"] = row["realm_name"].as<std::string>();
videos.append(video);
}
resp["videos"] = videos;
callback(jsonResp(resp));
}
>> DB_ERROR(callback, "get user videos");
}
void VideoController::uploadVideo(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
try {
UserInfo user = getUserFromRequest(req);
if (user.id == 0) {
callback(jsonError("Unauthorized", k401Unauthorized));
return;
}
MultiPartParser parser;
parser.parse(req);
// Get realm ID from form data - required
std::string realmIdStr = parser.getParameter<std::string>("realmId");
if (realmIdStr.empty()) {
callback(jsonError("Realm ID is required"));
return;
}
int64_t realmId;
try {
realmId = std::stoll(realmIdStr);
} catch (...) {
callback(jsonError("Invalid realm ID"));
return;
}
// Extract all data from parser before async call
if (parser.getFiles().empty()) {
callback(jsonError("No file uploaded"));
return;
}
const auto& file = parser.getFiles()[0];
// Get title from form data
std::string title = parser.getParameter<std::string>("title");
if (title.empty()) {
title = "Untitled Video";
}
if (title.length() > 255) {
title = title.substr(0, 255);
}
// Get optional description
std::string description = parser.getParameter<std::string>("description");
if (description.length() > 5000) {
description = description.substr(0, 5000);
}
// Validate file size (500MB max)
const size_t maxSize = 500 * 1024 * 1024;
size_t fileSize = file.fileLength();
if (fileSize > maxSize) {
callback(jsonError("File too large (max 500MB)"));
return;
}
if (fileSize == 0) {
callback(jsonError("Empty file uploaded"));
return;
}
// Validate video magic bytes
auto validation = validateVideoMagicBytes(file.fileData(), fileSize);
if (!validation.valid) {
LOG_WARN << "Video upload rejected: invalid video magic bytes";
callback(jsonError("Invalid video file. Only MP4, WebM, and MOV are allowed."));
return;
}
std::string fileExt = validation.extension.substr(1);
// Write file to disk IMMEDIATELY (before async DB calls) to avoid holding 500MB in memory
const std::string uploadDir = "/app/uploads/videos";
if (!ensureDirectoryExists(uploadDir)) {
callback(jsonError("Failed to create upload directory"));
return;
}
// Generate unique filename
std::string filename = generateRandomFilename(fileExt);
std::string fullPath = uploadDir + "/" + filename;
// Ensure file doesn't exist
while (std::filesystem::exists(fullPath)) {
filename = generateRandomFilename(fileExt);
fullPath = uploadDir + "/" + filename;
}
// Write directly from Drogon buffer (no memory copy)
try {
std::ofstream ofs(fullPath, std::ios::binary);
if (!ofs) {
LOG_ERROR << "Failed to create file: " << fullPath;
callback(jsonError("Failed to save file"));
return;
}
ofs.write(file.fileData(), fileSize);
ofs.close();
if (!std::filesystem::exists(fullPath)) {
LOG_ERROR << "File was not created: " << fullPath;
callback(jsonError("Failed to save file"));
return;
}
} catch (const std::exception& e) {
LOG_ERROR << "Exception saving video file: " << e.what();
callback(jsonError("Failed to save file"));
return;
}
std::string filePath = "/uploads/videos/" + filename;
// Check if user has uploader role and the realm exists and belongs to them
auto dbClient = app().getDbClient();
*dbClient << "SELECT u.is_uploader, r.id as realm_id, r.realm_type "
"FROM users u "
"LEFT JOIN realms r ON r.user_id = u.id AND r.id = $2 "
"WHERE u.id = $1"
<< user.id << realmId
>> [callback, user, dbClient, realmId, title, description, fullPath, filePath, fileSize, uploadDir](const Result& r) {
if (r.empty() || !r[0]["is_uploader"].as<bool>()) {
std::filesystem::remove(fullPath); // Clean up file on permission failure
callback(jsonError("You don't have permission to upload videos", k403Forbidden));
return;
}
// Check if realm exists and belongs to user
if (r[0]["realm_id"].isNull()) {
std::filesystem::remove(fullPath); // Clean up file
callback(jsonError("Video realm not found or doesn't belong to you", k404NotFound));
return;
}
// Check if it's a video realm
std::string realmType = r[0]["realm_type"].isNull() ? "stream" : r[0]["realm_type"].as<std::string>();
if (realmType != "video") {
std::filesystem::remove(fullPath); // Clean up file
callback(jsonError("Can only upload videos to video realms", k400BadRequest));
return;
}
// Insert video record - status is 'ready' for now (no processing)
*dbClient << "INSERT INTO videos (user_id, realm_id, title, description, file_path, "
"file_size_bytes, status, is_public, duration_seconds) "
"VALUES ($1, $2, $3, $4, $5, $6, 'ready', true, 0) RETURNING id, created_at"
<< user.id << realmId << title << description << filePath
<< static_cast<int64_t>(fileSize)
>> [callback, title, filePath, fileSize, realmId, fullPath, uploadDir](const Result& r2) {
if (r2.empty()) {
std::filesystem::remove(fullPath); // Clean up file
callback(jsonError("Failed to save video record"));
return;
}
int64_t videoId = r2[0]["id"].as<int64_t>();
// Start async thumbnail generation
processVideoThumbnails(videoId, fullPath, uploadDir);
Json::Value resp;
resp["success"] = true;
resp["video"]["id"] = static_cast<Json::Int64>(videoId);
resp["video"]["realmId"] = static_cast<Json::Int64>(realmId);
resp["video"]["title"] = title;
resp["video"]["filePath"] = filePath;
resp["video"]["fileSizeBytes"] = static_cast<Json::Int64>(fileSize);
resp["video"]["status"] = "ready";
resp["video"]["createdAt"] = r2[0]["created_at"].as<std::string>();
callback(jsonResp(resp));
}
>> [callback, fullPath](const DrogonDbException& e) {
LOG_ERROR << "Failed to insert video: " << e.base().what();
// Clean up file on DB error
std::filesystem::remove(fullPath);
callback(jsonError("Failed to save video"));
};
}
>> [callback, fullPath](const DrogonDbException& e) {
LOG_ERROR << "Failed to check uploader status: " << e.base().what();
std::filesystem::remove(fullPath); // Clean up file on DB error
callback(jsonError("Database error"));
};
} catch (const std::exception& e) {
LOG_ERROR << "Exception in uploadVideo: " << e.what();
callback(jsonError("Internal server error"));
}
}
void VideoController::updateVideo(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &videoId) {
UserInfo user = getUserFromRequest(req);
if (user.id == 0) {
callback(jsonError("Unauthorized", k401Unauthorized));
return;
}
int64_t id;
try {
id = std::stoll(videoId);
} catch (...) {
callback(jsonError("Invalid video ID", k400BadRequest));
return;
}
auto json = req->getJsonObject();
if (!json) {
callback(jsonError("Invalid JSON"));
return;
}
auto dbClient = app().getDbClient();
// Verify ownership
*dbClient << "SELECT id FROM videos WHERE id = $1 AND user_id = $2 AND status != 'deleted'"
<< id << user.id
>> [callback, json, dbClient, id](const Result& r) {
if (r.empty()) {
callback(jsonError("Video not found or access denied", k404NotFound));
return;
}
std::string title, description;
if (json->isMember("title")) {
title = (*json)["title"].asString();
if (title.length() > 255) title = title.substr(0, 255);
}
if (json->isMember("description")) {
description = (*json)["description"].asString();
if (description.length() > 5000) description = description.substr(0, 5000);
}
if (json->isMember("title") && json->isMember("description")) {
*dbClient << "UPDATE videos SET title = $1, description = $2 WHERE id = $3"
<< title << description << id
>> [callback](const Result&) {
Json::Value resp;
resp["success"] = true;
resp["message"] = "Video updated successfully";
callback(jsonResp(resp));
}
>> DB_ERROR_MSG(callback, "update video", "Failed to update video");
} else if (json->isMember("title")) {
*dbClient << "UPDATE videos SET title = $1 WHERE id = $2"
<< title << id
>> [callback](const Result&) {
Json::Value resp;
resp["success"] = true;
resp["message"] = "Video updated successfully";
callback(jsonResp(resp));
}
>> DB_ERROR_MSG(callback, "update video", "Failed to update video");
} else if (json->isMember("description")) {
*dbClient << "UPDATE videos SET description = $1 WHERE id = $2"
<< description << id
>> [callback](const Result&) {
Json::Value resp;
resp["success"] = true;
resp["message"] = "Video updated successfully";
callback(jsonResp(resp));
}
>> DB_ERROR_MSG(callback, "update video", "Failed to update video");
} else {
Json::Value resp;
resp["success"] = true;
resp["message"] = "No changes to apply";
callback(jsonResp(resp));
}
}
>> DB_ERROR(callback, "verify video ownership");
}
void VideoController::deleteVideo(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &videoId) {
UserInfo user = getUserFromRequest(req);
if (user.id == 0) {
callback(jsonError("Unauthorized", k401Unauthorized));
return;
}
int64_t id;
try {
id = std::stoll(videoId);
} catch (...) {
callback(jsonError("Invalid video ID", k400BadRequest));
return;
}
auto dbClient = app().getDbClient();
// Get file paths and verify ownership
*dbClient << "SELECT file_path, thumbnail_path, preview_path FROM videos "
"WHERE id = $1 AND user_id = $2 AND status != 'deleted'"
<< id << user.id
>> [callback, dbClient, id](const Result& r) {
if (r.empty()) {
callback(jsonError("Video not found or access denied", k404NotFound));
return;
}
std::string filePath = r[0]["file_path"].as<std::string>();
std::string thumbnailPath = r[0]["thumbnail_path"].isNull() ? "" : r[0]["thumbnail_path"].as<std::string>();
std::string previewPath = r[0]["preview_path"].isNull() ? "" : r[0]["preview_path"].as<std::string>();
// Soft delete by setting status to 'deleted'
*dbClient << "UPDATE videos SET status = 'deleted' WHERE id = $1"
<< id
>> [callback, filePath, thumbnailPath, previewPath](const Result&) {
// Delete files from disk
try {
std::string fullVideoPath = "/app" + filePath;
if (std::filesystem::exists(fullVideoPath)) {
std::filesystem::remove(fullVideoPath);
}
if (!thumbnailPath.empty()) {
std::string fullThumbPath = "/app" + thumbnailPath;
if (std::filesystem::exists(fullThumbPath)) {
std::filesystem::remove(fullThumbPath);
}
}
if (!previewPath.empty()) {
std::string fullPreviewPath = "/app" + previewPath;
if (std::filesystem::exists(fullPreviewPath)) {
std::filesystem::remove(fullPreviewPath);
}
}
} catch (const std::exception& e) {
LOG_WARN << "Failed to delete video files: " << e.what();
}
Json::Value resp;
resp["success"] = true;
resp["message"] = "Video deleted successfully";
callback(jsonResp(resp));
}
>> DB_ERROR_MSG(callback, "delete video", "Failed to delete video");
}
>> DB_ERROR(callback, "get video for deletion");
}

View file

@ -0,0 +1,62 @@
#pragma once
#include <drogon/HttpController.h>
#include "../services/AuthService.h"
using namespace drogon;
class VideoController : public HttpController<VideoController> {
public:
METHOD_LIST_BEGIN
// Public endpoints
ADD_METHOD_TO(VideoController::getAllVideos, "/api/videos", Get);
ADD_METHOD_TO(VideoController::getLatestVideos, "/api/videos/latest", Get);
ADD_METHOD_TO(VideoController::getVideo, "/api/videos/{1}", Get);
ADD_METHOD_TO(VideoController::getUserVideos, "/api/videos/user/{1}", Get);
ADD_METHOD_TO(VideoController::getRealmVideos, "/api/videos/realm/{1}", Get);
ADD_METHOD_TO(VideoController::incrementViewCount, "/api/videos/{1}/view", Post);
// Authenticated endpoints
ADD_METHOD_TO(VideoController::getMyVideos, "/api/user/videos", Get);
ADD_METHOD_TO(VideoController::uploadVideo, "/api/user/videos", Post);
ADD_METHOD_TO(VideoController::updateVideo, "/api/videos/{1}", Put);
ADD_METHOD_TO(VideoController::deleteVideo, "/api/videos/{1}", Delete);
METHOD_LIST_END
// Public video listing
void getAllVideos(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getLatestVideos(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void getVideo(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &videoId);
void getUserVideos(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &username);
void getRealmVideos(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void incrementViewCount(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &videoId);
// Authenticated video management
void getMyVideos(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void uploadVideo(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
void updateVideo(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &videoId);
void deleteVideo(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &videoId);
};

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,118 @@
#pragma once
#include <drogon/HttpController.h>
#include <drogon/orm/DbClient.h>
#include "../services/AuthService.h"
using namespace drogon;
using namespace drogon::orm;
class WatchController : public HttpController<WatchController> {
public:
METHOD_LIST_BEGIN
// List watch rooms
ADD_METHOD_TO(WatchController::getWatchRooms, "/api/watch/rooms", Get);
// Playlist management
ADD_METHOD_TO(WatchController::getPlaylist, "/api/watch/{1}/playlist", Get);
ADD_METHOD_TO(WatchController::addToPlaylist, "/api/watch/{1}/playlist", Post);
ADD_METHOD_TO(WatchController::removeFromPlaylist, "/api/watch/{1}/playlist/{2}", Delete);
ADD_METHOD_TO(WatchController::reorderPlaylist, "/api/watch/{1}/playlist/reorder", Put);
ADD_METHOD_TO(WatchController::toggleLock, "/api/watch/{1}/playlist/{2}/lock", Put);
// Playback control
ADD_METHOD_TO(WatchController::getRoomState, "/api/watch/{1}/state", Get);
ADD_METHOD_TO(WatchController::playVideo, "/api/watch/{1}/play", Post);
ADD_METHOD_TO(WatchController::pauseVideo, "/api/watch/{1}/pause", Post);
ADD_METHOD_TO(WatchController::seekVideo, "/api/watch/{1}/seek", Post);
ADD_METHOD_TO(WatchController::skipVideo, "/api/watch/{1}/skip", Post);
ADD_METHOD_TO(WatchController::nextVideo, "/api/watch/{1}/next", Post);
// Settings
ADD_METHOD_TO(WatchController::updateSettings, "/api/watch/{1}/settings", Put);
// Duration update (called by chat-service when player reports duration)
ADD_METHOD_TO(WatchController::updateDuration, "/api/watch/{1}/duration", Post);
METHOD_LIST_END
// List watch rooms
void getWatchRooms(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
// Playlist management
void getPlaylist(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void addToPlaylist(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void removeFromPlaylist(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId,
const std::string &itemId);
void reorderPlaylist(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void toggleLock(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId,
const std::string &itemId);
// Playback control
void getRoomState(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void playVideo(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void pauseVideo(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void seekVideo(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void skipVideo(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
void nextVideo(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
// Settings
void updateSettings(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
// Duration update (called by chat-service when player reports duration)
void updateDuration(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const std::string &realmId);
private:
bool canControlPlaylist(const UserInfo& user, int64_t realmId, int64_t ownerId,
const std::string& mode, const std::string& whitelist);
bool canControlPlayback(const UserInfo& user, int64_t ownerId);
std::string extractYouTubeVideoId(const std::string& url);
// Helper to add video to playlist (reduces callback nesting)
void addVideoToPlaylist(
std::function<void(const HttpResponsePtr &)> callback,
const DbClientPtr& dbClient,
int64_t realmId,
const UserInfo& user,
const std::string& videoId,
const std::string& title,
int durationSeconds,
const std::string& thumbnailUrl,
const std::string& username,
const std::string& fingerprint,
int64_t ownerId);
};

View file

@ -7,6 +7,8 @@
#include "services/DatabaseService.h"
#include "services/StatsService.h"
#include "services/AuthService.h"
#include "services/CensorService.h"
#include "services/TreasuryService.h"
#include <exception>
#include <csignal>
#include <sys/stat.h>
@ -36,8 +38,8 @@ int main() {
// Initialize StatsService BEFORE registering callbacks
LOG_INFO << "Initializing StatsService...";
StatsService::getInstance().initialize();
// Register a pre-routing advice to handle CORS
// Register a pre-routing advice to handle CORS and CSRF protection
app().registerPreRoutingAdvice([](const HttpRequestPtr &req,
AdviceCallback &&acb,
AdviceChainCallback &&accb) {
@ -45,13 +47,13 @@ int main() {
if (req->getMethod() == Options) {
auto resp = HttpResponse::newHttpResponse();
resp->setStatusCode(k204NoContent);
// Get origin from request
std::string origin = req->getHeader("Origin");
if (origin.empty()) {
origin = "*";
}
resp->addHeader("Access-Control-Allow-Origin", origin);
resp->addHeader("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS");
resp->addHeader("Access-Control-Allow-Headers", "Content-Type, Authorization");
@ -60,6 +62,35 @@ int main() {
acb(resp);
return;
}
// SECURITY FIX #18: CSRF protection for state-changing requests
// Require Origin or Referer header for POST/PUT/DELETE requests
if (req->getMethod() == Post || req->getMethod() == Put || req->getMethod() == Delete) {
std::string origin = req->getHeader("Origin");
std::string referer = req->getHeader("Referer");
std::string path = req->getPath();
// Skip CSRF check for API endpoints that use Bearer token auth
// (Bearer tokens are not automatically sent by browsers, so CSRF is not a concern)
std::string authHeader = req->getHeader("Authorization");
bool hasBearerToken = !authHeader.empty() && authHeader.substr(0, 7) == "Bearer ";
// Skip CSRF check for internal endpoints (server-to-server calls)
bool isInternalEndpoint = path.find("/api/webhook/") == 0 ||
path.find("/api/internal/") == 0;
// If not using Bearer auth and not an internal endpoint, require Origin or Referer header
if (!hasBearerToken && !isInternalEndpoint && origin.empty() && referer.empty()) {
LOG_WARN << "CSRF protection: Blocked request without Origin/Referer to "
<< req->getPath() << " from " << req->getPeerAddr().toIpPort();
auto resp = HttpResponse::newHttpResponse();
resp->setStatusCode(k403Forbidden);
resp->setBody("Missing Origin or Referer header");
acb(resp);
return;
}
}
accb();
});
@ -81,15 +112,39 @@ int main() {
// Register beginning advice to start the stats timer
app().registerBeginningAdvice([]() {
LOG_INFO << "Application started successfully";
// Clean up stuck audio processing jobs on startup
LOG_INFO << "Cleaning up stuck audio processing jobs...";
auto dbClient = app().getDbClient();
*dbClient << "UPDATE audio_files SET status = 'failed' "
"WHERE status = 'processing' AND created_at < NOW() - INTERVAL '5 minutes'"
>> [](const drogon::orm::Result& r) {
if (r.affectedRows() > 0) {
LOG_INFO << "Marked " << r.affectedRows() << " stuck audio jobs as failed";
}
}
>> [](const drogon::orm::DrogonDbException& e) {
LOG_WARN << "Failed to clean up stuck audio jobs: " << e.base().what();
};
// Start the stats polling timer
LOG_INFO << "Starting stats polling...";
StatsService::getInstance().startPolling();
// Load censored words from database
LOG_INFO << "Loading censored words...";
CensorService::getInstance().loadCensoredWords();
// Start treasury scheduler (hourly check for growth/distribution)
LOG_INFO << "Starting treasury scheduler...";
TreasuryService::getInstance().initialize();
TreasuryService::getInstance().startScheduler();
});
app().setTermSignalHandler([]() {
LOG_INFO << "Received termination signal, shutting down...";
StatsService::getInstance().shutdown();
TreasuryService::getInstance().shutdown();
app().quit();
});

View file

@ -1,15 +0,0 @@
#pragma once
#include <string>
#include <chrono>
struct Realm {
int64_t id;
int64_t userId;
std::string name;
std::string streamKey;
bool isActive;
bool isLive;
int64_t viewerCount;
std::chrono::system_clock::time_point createdAt;
std::chrono::system_clock::time_point updatedAt;
};

View file

@ -1,12 +0,0 @@
#pragma once
#include <string>
#include <chrono>
struct StreamKey {
int64_t id;
int64_t user_id;
std::string key;
bool is_active;
std::chrono::system_clock::time_point created_at;
std::chrono::system_clock::time_point updated_at;
};

View file

@ -9,6 +9,8 @@
#include <fstream>
#include <sstream>
#include <cstdlib>
#include <gpgme.h>
#include <filesystem>
using namespace drogon;
using namespace drogon::orm;
@ -32,153 +34,140 @@ bool AuthService::validatePassword(const std::string& password, std::string& err
return true;
}
// Helper function to execute GPG commands
std::string executeGpgCommand(const std::string& command) {
std::array<char, 128> buffer;
std::string result;
FILE* pipe = popen(command.c_str(), "r");
if (!pipe) {
LOG_ERROR << "Failed to execute GPG command: " << command;
return "";
}
while (fgets(buffer.data(), buffer.size(), pipe) != nullptr) {
result += buffer.data();
}
int exitCode = pclose(pipe);
// Exit code is returned as status << 8, so we need to extract the actual exit code
int actualExitCode = WEXITSTATUS(exitCode);
if (actualExitCode != 0) {
LOG_ERROR << "GPG command failed with exit code: " << actualExitCode
<< " for command: " << command
<< " output: " << result;
// Don't return empty string immediately - sometimes GPG returns non-zero but still works
}
return result;
// GPGME-based PGP signature verification (no shell commands for security)
namespace {
// RAII wrapper for gpgme_data_t (named to avoid conflict with deprecated GpgmeData typedef)
class GpgmeDataWrapper {
public:
GpgmeDataWrapper() : data_(nullptr) {}
~GpgmeDataWrapper() { if (data_) gpgme_data_release(data_); }
gpgme_data_t* ptr() { return &data_; }
gpgme_data_t get() { return data_; }
bool valid() const { return data_ != nullptr; }
private:
gpgme_data_t data_;
};
// RAII wrapper for gpgme_ctx_t
class GpgmeContextWrapper {
public:
GpgmeContextWrapper() : ctx_(nullptr) {}
~GpgmeContextWrapper() { if (ctx_) gpgme_release(ctx_); }
gpgme_ctx_t* ptr() { return &ctx_; }
gpgme_ctx_t get() { return ctx_; }
bool valid() const { return ctx_ != nullptr; }
private:
gpgme_ctx_t ctx_;
};
}
// Server-side PGP signature verification
bool verifyPgpSignature(const std::string& message, const std::string& signature, const std::string& publicKey) {
try {
// Create temporary directory for GPG operations
std::string tmpDir = "/tmp/pgp_verify_" + drogon::utils::genRandomString(8);
std::string mkdirCmd = "mkdir -p " + tmpDir;
if (system(mkdirCmd.c_str()) != 0) {
LOG_ERROR << "Failed to create temporary directory: " << tmpDir;
// Initialize GPGME
gpgme_check_version(nullptr);
// Create temporary directory for isolated keyring using filesystem
std::string tmpDir = "/tmp/pgp_verify_" + drogon::utils::genRandomString(16);
std::filesystem::create_directories(tmpDir);
std::filesystem::permissions(tmpDir, std::filesystem::perms::owner_all);
// Set GNUPGHOME environment for this context
std::string gnupgHome = tmpDir;
// Create GPGME context
GpgmeContextWrapper ctx;
gpgme_error_t err = gpgme_new(ctx.ptr());
if (err) {
LOG_ERROR << "Failed to create GPGME context: " << gpgme_strerror(err);
std::filesystem::remove_all(tmpDir);
return false;
}
// Create GPG home directory
std::string keyringDir = tmpDir + "/gnupg";
std::string mkdirGpgCmd = "mkdir -p " + keyringDir + " && chmod 700 " + keyringDir;
if (system(mkdirGpgCmd.c_str()) != 0) {
LOG_ERROR << "Failed to create GPG home directory: " << keyringDir;
std::string cleanupCmd = "rm -rf " + tmpDir;
system(cleanupCmd.c_str());
// Set the engine info to use our temporary directory
err = gpgme_ctx_set_engine_info(ctx.get(), GPGME_PROTOCOL_OpenPGP, nullptr, gnupgHome.c_str());
if (err) {
LOG_ERROR << "Failed to set GPGME engine info: " << gpgme_strerror(err);
std::filesystem::remove_all(tmpDir);
return false;
}
// Write files
std::string messageFile = tmpDir + "/message.txt";
std::string sigFile = tmpDir + "/signature.asc";
std::string pubkeyFile = tmpDir + "/pubkey.asc";
// Write message file
std::ofstream msgOut(messageFile);
if (!msgOut) {
LOG_ERROR << "Failed to create message file: " << messageFile;
std::string cleanupCmd = "rm -rf " + tmpDir;
system(cleanupCmd.c_str());
// Set protocol
gpgme_set_protocol(ctx.get(), GPGME_PROTOCOL_OpenPGP);
// Import the public key
GpgmeDataWrapper keyData;
err = gpgme_data_new_from_mem(keyData.ptr(), publicKey.c_str(), publicKey.size(), 1);
if (err) {
LOG_ERROR << "Failed to create key data: " << gpgme_strerror(err);
std::filesystem::remove_all(tmpDir);
return false;
}
msgOut << message;
msgOut.close();
// Write signature file
std::ofstream sigOut(sigFile);
if (!sigOut) {
LOG_ERROR << "Failed to create signature file: " << sigFile;
std::string cleanupCmd = "rm -rf " + tmpDir;
system(cleanupCmd.c_str());
err = gpgme_op_import(ctx.get(), keyData.get());
if (err) {
LOG_ERROR << "Failed to import public key: " << gpgme_strerror(err);
std::filesystem::remove_all(tmpDir);
return false;
}
sigOut << signature;
sigOut.close();
// Write public key file
std::ofstream keyOut(pubkeyFile);
if (!keyOut) {
LOG_ERROR << "Failed to create public key file: " << pubkeyFile;
std::string cleanupCmd = "rm -rf " + tmpDir;
system(cleanupCmd.c_str());
gpgme_import_result_t importResult = gpgme_op_import_result(ctx.get());
if (!importResult || (importResult->imported == 0 && importResult->unchanged == 0)) {
LOG_ERROR << "No keys were imported";
std::filesystem::remove_all(tmpDir);
return false;
}
keyOut << publicKey;
keyOut.close();
// Initialize GPG (create trustdb if needed)
std::string initCmd = "GNUPGHOME=" + keyringDir + " gpg --batch --yes --list-keys 2>&1";
executeGpgCommand(initCmd); // This will create the trustdb if it doesn't exist
// Import the public key to the temporary keyring
// Use --trust-model always to avoid trust issues
std::string importCmd = "GNUPGHOME=" + keyringDir +
" gpg --batch --yes --trust-model always --import " + pubkeyFile + " 2>&1";
std::string importResult = executeGpgCommand(importCmd);
LOG_DEBUG << "GPG import result: " << importResult;
// Check if import was successful (be more lenient with the check)
bool importSuccess = (importResult.find("imported") != std::string::npos) ||
(importResult.find("unchanged") != std::string::npos) ||
(importResult.find("processed: 1") != std::string::npos) ||
(importResult.find("public key") != std::string::npos);
if (!importSuccess) {
LOG_ERROR << "Failed to import public key. Import output: " << importResult;
// Try to get more information about what went wrong
std::string debugCmd = "GNUPGHOME=" + keyringDir + " gpg --list-keys 2>&1";
std::string debugResult = executeGpgCommand(debugCmd);
LOG_ERROR << "GPG keyring state: " << debugResult;
// Cleanup
std::string cleanupCmd = "rm -rf " + tmpDir;
system(cleanupCmd.c_str());
LOG_DEBUG << "GPGME imported " << importResult->imported << " keys, "
<< importResult->unchanged << " unchanged";
// Create data objects for signature and message
GpgmeDataWrapper sigData;
err = gpgme_data_new_from_mem(sigData.ptr(), signature.c_str(), signature.size(), 1);
if (err) {
LOG_ERROR << "Failed to create signature data: " << gpgme_strerror(err);
std::filesystem::remove_all(tmpDir);
return false;
}
GpgmeDataWrapper msgData;
err = gpgme_data_new_from_mem(msgData.ptr(), message.c_str(), message.size(), 1);
if (err) {
LOG_ERROR << "Failed to create message data: " << gpgme_strerror(err);
std::filesystem::remove_all(tmpDir);
return false;
}
// Verify the signature
// Use --trust-model always to avoid trust issues
std::string verifyCmd = "GNUPGHOME=" + keyringDir +
" gpg --batch --yes --trust-model always --verify " +
sigFile + " " + messageFile + " 2>&1";
std::string verifyResult = executeGpgCommand(verifyCmd);
LOG_DEBUG << "GPG verify result: " << verifyResult;
// Check if verification succeeded (check both English and potential localized messages)
bool verified = (verifyResult.find("Good signature") != std::string::npos) ||
(verifyResult.find("gpg: Good signature") != std::string::npos) ||
(verifyResult.find("Signature made") != std::string::npos &&
verifyResult.find("BAD signature") == std::string::npos);
if (!verified) {
LOG_WARN << "Signature verification failed. Verify output: " << verifyResult;
} else {
LOG_INFO << "Signature verification successful for challenge";
err = gpgme_op_verify(ctx.get(), sigData.get(), msgData.get(), nullptr);
if (err) {
LOG_WARN << "Signature verification failed: " << gpgme_strerror(err);
std::filesystem::remove_all(tmpDir);
return false;
}
// Cleanup temporary files
std::string cleanupCmd = "rm -rf " + tmpDir;
system(cleanupCmd.c_str());
// Check verification result
gpgme_verify_result_t verifyResult = gpgme_op_verify_result(ctx.get());
if (!verifyResult || !verifyResult->signatures) {
LOG_WARN << "No signatures found in verification result";
std::filesystem::remove_all(tmpDir);
return false;
}
// Check if signature is valid
gpgme_signature_t sig = verifyResult->signatures;
bool verified = (sig->status == GPG_ERR_NO_ERROR);
if (verified) {
LOG_INFO << "Signature verification successful for challenge";
} else {
LOG_WARN << "Signature verification failed: " << gpgme_strerror(sig->status);
}
// Cleanup temporary directory
std::filesystem::remove_all(tmpDir);
return verified;
} catch (const std::exception& e) {
LOG_ERROR << "Exception during signature verification: " << e.what();
return false;
@ -314,8 +303,8 @@ void AuthService::registerUser(const std::string& username, const std::string& p
return;
}
if (!std::regex_match(username, std::regex("^[a-zA-Z0-9_]+$"))) {
callback(false, "Username can only contain letters, numbers, and underscores", 0);
if (!std::regex_match(username, std::regex("^[a-zA-Z][a-zA-Z0-9_]*$"))) {
callback(false, "Username must start with a letter and contain only letters, numbers, and underscores", 0);
return;
}
@ -454,7 +443,7 @@ void AuthService::loginUser(const std::string& username, const std::string& pass
return;
}
*dbClient << "SELECT id, username, password_hash, is_admin, is_streamer, is_pgp_only, bio, avatar_url, pgp_only_enabled_at, user_color "
*dbClient << "SELECT id, username, password_hash, is_admin, is_moderator, is_streamer, is_restreamer, is_bot, is_texter, is_pgp_only, is_disabled, bio, avatar_url, banner_url, banner_position, banner_zoom, banner_position_x, graffiti_url, pgp_only_enabled_at, user_color "
"FROM users WHERE username = $1 LIMIT 1"
<< username
>> [password, callback, this](const Result& r) {
@ -463,18 +452,25 @@ void AuthService::loginUser(const std::string& username, const std::string& pass
callback(false, "", UserInfo{});
return;
}
// Check if account is disabled
bool isDisabled = r[0]["is_disabled"].isNull() ? false : r[0]["is_disabled"].as<bool>();
if (isDisabled) {
callback(false, "Account disabled", UserInfo{});
return;
}
// Check if PGP-only is enabled BEFORE password validation
bool isPgpOnly = r[0]["is_pgp_only"].isNull() ? false : r[0]["is_pgp_only"].as<bool>();
if (isPgpOnly) {
// Return a specific error for PGP-only accounts
callback(false, "PGP-only login enabled for this account", UserInfo{});
return;
}
std::string hash = r[0]["password_hash"].as<std::string>();
bool valid = false;
try {
valid = BCrypt::validatePassword(password, hash);
@ -483,23 +479,32 @@ void AuthService::loginUser(const std::string& username, const std::string& pass
callback(false, "", UserInfo{});
return;
}
if (!valid) {
callback(false, "", UserInfo{});
return;
}
UserInfo user;
user.id = r[0]["id"].as<int64_t>();
user.username = r[0]["username"].as<std::string>();
user.isAdmin = r[0]["is_admin"].isNull() ? false : r[0]["is_admin"].as<bool>();
user.isModerator = r[0]["is_moderator"].isNull() ? false : r[0]["is_moderator"].as<bool>();
user.isStreamer = r[0]["is_streamer"].isNull() ? false : r[0]["is_streamer"].as<bool>();
user.isRestreamer = r[0]["is_restreamer"].isNull() ? false : r[0]["is_restreamer"].as<bool>();
user.isBot = r[0]["is_bot"].isNull() ? false : r[0]["is_bot"].as<bool>();
user.isTexter = r[0]["is_texter"].isNull() ? false : r[0]["is_texter"].as<bool>();
user.isPgpOnly = isPgpOnly;
user.bio = r[0]["bio"].isNull() ? "" : r[0]["bio"].as<std::string>();
user.avatarUrl = r[0]["avatar_url"].isNull() ? "" : r[0]["avatar_url"].as<std::string>();
user.bannerUrl = r[0]["banner_url"].isNull() ? "" : r[0]["banner_url"].as<std::string>();
user.bannerPosition = r[0]["banner_position"].isNull() ? 50 : r[0]["banner_position"].as<int>();
user.bannerZoom = r[0]["banner_zoom"].isNull() ? 100 : r[0]["banner_zoom"].as<int>();
user.bannerPositionX = r[0]["banner_position_x"].isNull() ? 50 : r[0]["banner_position_x"].as<int>();
user.graffitiUrl = r[0]["graffiti_url"].isNull() ? "" : r[0]["graffiti_url"].as<std::string>();
user.pgpOnlyEnabledAt = r[0]["pgp_only_enabled_at"].isNull() ? "" : r[0]["pgp_only_enabled_at"].as<std::string>();
user.colorCode = r[0]["user_color"].isNull() ? "#561D5E" : r[0]["user_color"].as<std::string>();
std::string token = generateToken(user);
callback(true, token, user);
} catch (const std::exception& e) {
@ -593,8 +598,8 @@ void AuthService::verifyPgpLogin(const std::string& username, const std::string&
return;
}
*dbClient << "SELECT pk.public_key, u.id, u.username, u.is_admin, u.is_streamer, "
"u.is_pgp_only, u.bio, u.avatar_url, u.pgp_only_enabled_at, u.user_color "
*dbClient << "SELECT pk.public_key, u.id, u.username, u.is_admin, u.is_moderator, u.is_streamer, u.is_restreamer, u.is_bot, u.is_texter, "
"u.is_pgp_only, u.is_disabled, u.bio, u.avatar_url, u.banner_url, u.banner_position, u.banner_zoom, u.banner_position_x, u.graffiti_url, u.pgp_only_enabled_at, u.user_color "
"FROM pgp_keys pk JOIN users u ON pk.user_id = u.id "
"WHERE u.username = $1 ORDER BY pk.created_at DESC LIMIT 1"
<< username
@ -605,31 +610,47 @@ void AuthService::verifyPgpLogin(const std::string& username, const std::string&
callback(false, "", UserInfo{});
return;
}
// Check if account is disabled
bool isDisabled = r[0]["is_disabled"].isNull() ? false : r[0]["is_disabled"].as<bool>();
if (isDisabled) {
callback(false, "Account disabled", UserInfo{});
return;
}
std::string publicKey = r[0]["public_key"].as<std::string>();
// CRITICAL: Server-side signature verification
bool signatureValid = verifyPgpSignature(challenge, signature, publicKey);
if (!signatureValid) {
LOG_WARN << "Invalid PGP signature for user";
callback(false, "Invalid signature", UserInfo{});
return;
}
LOG_INFO << "PGP signature verified successfully for user";
UserInfo user;
user.id = r[0]["id"].as<int64_t>();
user.username = r[0]["username"].as<std::string>();
user.isAdmin = r[0]["is_admin"].isNull() ? false : r[0]["is_admin"].as<bool>();
user.isModerator = r[0]["is_moderator"].isNull() ? false : r[0]["is_moderator"].as<bool>();
user.isStreamer = r[0]["is_streamer"].isNull() ? false : r[0]["is_streamer"].as<bool>();
user.isRestreamer = r[0]["is_restreamer"].isNull() ? false : r[0]["is_restreamer"].as<bool>();
user.isBot = r[0]["is_bot"].isNull() ? false : r[0]["is_bot"].as<bool>();
user.isTexter = r[0]["is_texter"].isNull() ? false : r[0]["is_texter"].as<bool>();
user.isPgpOnly = r[0]["is_pgp_only"].isNull() ? false : r[0]["is_pgp_only"].as<bool>();
user.bio = r[0]["bio"].isNull() ? "" : r[0]["bio"].as<std::string>();
user.avatarUrl = r[0]["avatar_url"].isNull() ? "" : r[0]["avatar_url"].as<std::string>();
user.bannerUrl = r[0]["banner_url"].isNull() ? "" : r[0]["banner_url"].as<std::string>();
user.bannerPosition = r[0]["banner_position"].isNull() ? 50 : r[0]["banner_position"].as<int>();
user.bannerZoom = r[0]["banner_zoom"].isNull() ? 100 : r[0]["banner_zoom"].as<int>();
user.bannerPositionX = r[0]["banner_position_x"].isNull() ? 50 : r[0]["banner_position_x"].as<int>();
user.graffitiUrl = r[0]["graffiti_url"].isNull() ? "" : r[0]["graffiti_url"].as<std::string>();
user.pgpOnlyEnabledAt = r[0]["pgp_only_enabled_at"].isNull() ? "" : r[0]["pgp_only_enabled_at"].as<std::string>();
user.colorCode = r[0]["user_color"].isNull() ? "#561D5E" : r[0]["user_color"].as<std::string>();
std::string token = generateToken(user);
callback(true, token, user);
} catch (const std::exception& e) {
@ -653,25 +674,70 @@ void AuthService::verifyPgpLogin(const std::string& username, const std::string&
}
}
// SECURITY FIX #5: Validate JWT secret has minimum length and entropy
void AuthService::validateAndLoadJwtSecret() {
if (!jwtSecret_.empty()) {
return; // Already loaded and validated
}
const char* envSecret = std::getenv("JWT_SECRET");
if (!envSecret || strlen(envSecret) == 0) {
throw std::runtime_error("JWT_SECRET environment variable is not set");
}
size_t secretLen = strlen(envSecret);
// Require at least 32 characters (256 bits) for HS256
if (secretLen < 32) {
throw std::runtime_error("JWT_SECRET must be at least 32 characters (256 bits) for security");
}
// Basic entropy check - ensure not all same character
bool hasVariety = false;
for (size_t i = 1; i < secretLen && !hasVariety; ++i) {
if (envSecret[i] != envSecret[0]) {
hasVariety = true;
}
}
if (!hasVariety) {
throw std::runtime_error("JWT_SECRET has insufficient entropy - all characters are the same");
}
// Check for common weak secrets
std::string secretLower = envSecret;
std::transform(secretLower.begin(), secretLower.end(), secretLower.begin(), ::tolower);
if (secretLower.find("secret") != std::string::npos ||
secretLower.find("password") != std::string::npos ||
secretLower.find("123456") != std::string::npos) {
LOG_WARN << "JWT_SECRET appears to contain common weak patterns - consider using a stronger secret";
}
jwtSecret_ = std::string(envSecret);
LOG_INFO << "JWT secret loaded and validated (" << secretLen << " characters)";
}
std::string AuthService::generateToken(const UserInfo& user) {
try {
if (jwtSecret_.empty()) {
const char* envSecret = std::getenv("JWT_SECRET");
jwtSecret_ = envSecret ? std::string(envSecret) : "your-jwt-secret";
}
validateAndLoadJwtSecret();
// SECURITY FIX: Reduced JWT expiry from 24h to 1h to limit token exposure window
auto token = jwt::create()
.set_issuer("streaming-app")
.set_type("JWS")
.set_type("JWT")
.set_issued_at(std::chrono::system_clock::now())
.set_expires_at(std::chrono::system_clock::now() + std::chrono::hours(24))
.set_expires_at(std::chrono::system_clock::now() + std::chrono::hours(1))
.set_payload_claim("user_id", jwt::claim(std::to_string(user.id)))
.set_payload_claim("username", jwt::claim(user.username))
.set_payload_claim("is_admin", jwt::claim(std::to_string(user.isAdmin)))
.set_payload_claim("is_moderator", jwt::claim(std::to_string(user.isModerator)))
.set_payload_claim("is_streamer", jwt::claim(std::to_string(user.isStreamer)))
.set_payload_claim("is_restreamer", jwt::claim(std::to_string(user.isRestreamer)))
.set_payload_claim("is_disabled", jwt::claim(std::to_string(user.isDisabled))) // SECURITY FIX #26
.set_payload_claim("token_version", jwt::claim(std::to_string(user.tokenVersion))) // SECURITY FIX #10
.set_payload_claim("color_code", jwt::claim(
user.colorCode.empty() ? "#561D5E" : user.colorCode
)) // Ensure color is never empty
.set_payload_claim("avatar_url", jwt::claim(user.avatarUrl))
.sign(jwt::algorithm::hs256{jwtSecret_});
return token;
@ -683,11 +749,8 @@ std::string AuthService::generateToken(const UserInfo& user) {
bool AuthService::validateToken(const std::string& token, UserInfo& userInfo) {
try {
if (jwtSecret_.empty()) {
const char* envSecret = std::getenv("JWT_SECRET");
jwtSecret_ = envSecret ? std::string(envSecret) : "your-jwt-secret";
}
validateAndLoadJwtSecret();
auto decoded = jwt::decode(token);
auto verifier = jwt::verify()
@ -699,9 +762,21 @@ bool AuthService::validateToken(const std::string& token, UserInfo& userInfo) {
userInfo.id = std::stoll(decoded.get_payload_claim("user_id").as_string());
userInfo.username = decoded.get_payload_claim("username").as_string();
userInfo.isAdmin = decoded.get_payload_claim("is_admin").as_string() == "1";
userInfo.isStreamer = decoded.has_payload_claim("is_streamer") ?
userInfo.isModerator = decoded.has_payload_claim("is_moderator") ?
decoded.get_payload_claim("is_moderator").as_string() == "1" : false;
userInfo.isStreamer = decoded.has_payload_claim("is_streamer") ?
decoded.get_payload_claim("is_streamer").as_string() == "1" : false;
userInfo.isRestreamer = decoded.has_payload_claim("is_restreamer") ?
decoded.get_payload_claim("is_restreamer").as_string() == "1" : false;
// SECURITY FIX #26: Extract disabled status
userInfo.isDisabled = decoded.has_payload_claim("is_disabled") ?
decoded.get_payload_claim("is_disabled").as_string() == "1" : false;
// SECURITY FIX #10: Extract token version for revocation check
userInfo.tokenVersion = decoded.has_payload_claim("token_version") ?
std::stoi(decoded.get_payload_claim("token_version").as_string()) : 1;
// Get color from token if available, otherwise will need to fetch from DB
if (decoded.has_payload_claim("color_code")) {
userInfo.colorCode = decoded.get_payload_claim("color_code").as_string();
@ -709,7 +784,13 @@ bool AuthService::validateToken(const std::string& token, UserInfo& userInfo) {
// For older tokens without color, default value
userInfo.colorCode = "#561D5E";
}
// SECURITY FIX #26: Reject tokens from disabled accounts
if (userInfo.isDisabled) {
LOG_DEBUG << "Token rejected - user account is disabled: " << userInfo.username;
return false;
}
return true;
} catch (const std::exception& e) {
LOG_DEBUG << "Token validation failed: " << e.what();
@ -717,6 +798,23 @@ bool AuthService::validateToken(const std::string& token, UserInfo& userInfo) {
}
}
// Chat service compatibility method
std::optional<UserClaims> AuthService::verifyToken(const std::string& token) {
UserInfo userInfo;
if (validateToken(token, userInfo)) {
UserClaims claims;
claims.userId = std::to_string(userInfo.id);
claims.username = userInfo.username;
claims.userColor = userInfo.colorCode;
claims.isAdmin = userInfo.isAdmin;
claims.isModerator = userInfo.isModerator;
claims.isStreamer = userInfo.isStreamer;
claims.isRestreamer = userInfo.isRestreamer;
return claims;
}
return std::nullopt;
}
void AuthService::updatePassword(int64_t userId, const std::string& oldPassword,
const std::string& newPassword,
std::function<void(bool, const std::string&)> callback) {
@ -771,9 +869,11 @@ void AuthService::updatePassword(int64_t userId, const std::string& oldPassword,
return;
}
*dbClient << "UPDATE users SET password_hash = $1 WHERE id = $2"
// SECURITY FIX #10: Increment token_version to invalidate all existing tokens
*dbClient << "UPDATE users SET password_hash = $1, token_version = COALESCE(token_version, 0) + 1 WHERE id = $2"
<< newHash << userId
>> [callback](const Result&) {
>> [callback, userId](const Result&) {
LOG_INFO << "Password updated and token_version incremented for user " << userId;
callback(true, "");
}
>> [callback](const DrogonDbException& e) {
@ -804,7 +904,7 @@ void AuthService::fetchUserInfo(int64_t userId, std::function<void(bool, const U
return;
}
*dbClient << "SELECT id, username, is_admin, is_streamer, is_pgp_only, bio, avatar_url, pgp_only_enabled_at, user_color "
*dbClient << "SELECT id, username, is_admin, is_moderator, is_streamer, is_restreamer, is_bot, is_texter, is_pgp_only, bio, avatar_url, banner_url, banner_position, banner_zoom, banner_position_x, graffiti_url, pgp_only_enabled_at, user_color "
"FROM users WHERE id = $1 LIMIT 1"
<< userId
>> [callback](const Result& r) {
@ -813,18 +913,27 @@ void AuthService::fetchUserInfo(int64_t userId, std::function<void(bool, const U
callback(false, UserInfo{});
return;
}
UserInfo user;
user.id = r[0]["id"].as<int64_t>();
user.username = r[0]["username"].as<std::string>();
user.isAdmin = r[0]["is_admin"].isNull() ? false : r[0]["is_admin"].as<bool>();
user.isModerator = r[0]["is_moderator"].isNull() ? false : r[0]["is_moderator"].as<bool>();
user.isStreamer = r[0]["is_streamer"].isNull() ? false : r[0]["is_streamer"].as<bool>();
user.isRestreamer = r[0]["is_restreamer"].isNull() ? false : r[0]["is_restreamer"].as<bool>();
user.isBot = r[0]["is_bot"].isNull() ? false : r[0]["is_bot"].as<bool>();
user.isTexter = r[0]["is_texter"].isNull() ? false : r[0]["is_texter"].as<bool>();
user.isPgpOnly = r[0]["is_pgp_only"].isNull() ? false : r[0]["is_pgp_only"].as<bool>();
user.bio = r[0]["bio"].isNull() ? "" : r[0]["bio"].as<std::string>();
user.avatarUrl = r[0]["avatar_url"].isNull() ? "" : r[0]["avatar_url"].as<std::string>();
user.bannerUrl = r[0]["banner_url"].isNull() ? "" : r[0]["banner_url"].as<std::string>();
user.bannerPosition = r[0]["banner_position"].isNull() ? 50 : r[0]["banner_position"].as<int>();
user.bannerZoom = r[0]["banner_zoom"].isNull() ? 100 : r[0]["banner_zoom"].as<int>();
user.bannerPositionX = r[0]["banner_position_x"].isNull() ? 50 : r[0]["banner_position_x"].as<int>();
user.graffitiUrl = r[0]["graffiti_url"].isNull() ? "" : r[0]["graffiti_url"].as<std::string>();
user.pgpOnlyEnabledAt = r[0]["pgp_only_enabled_at"].isNull() ? "" : r[0]["pgp_only_enabled_at"].as<std::string>();
user.colorCode = r[0]["user_color"].isNull() ? "#561D5E" : r[0]["user_color"].as<std::string>();
callback(true, user);
} catch (const std::exception& e) {
LOG_ERROR << "Exception processing user data: " << e.what();

View file

@ -2,6 +2,7 @@
#include <string>
#include <functional>
#include <memory>
#include <optional>
#include <jwt-cpp/jwt.h>
#include <bcrypt/BCrypt.hpp>
@ -9,12 +10,38 @@ struct UserInfo {
int64_t id = 0;
std::string username;
bool isAdmin = false;
bool isModerator = false; // Site-wide moderator role
bool isStreamer = false;
bool isRestreamer = false;
bool isBot = false;
bool isTexter = false;
bool isPgpOnly = false;
bool isDisabled = false; // SECURITY FIX #26: Track disabled status
std::string bio;
std::string avatarUrl;
std::string bannerUrl;
int bannerPosition = 50; // Y position percentage (0-100) for object-position
int bannerZoom = 100; // Zoom percentage (100-200)
int bannerPositionX = 50; // X position percentage (0-100) for object-position
std::string graffitiUrl;
std::string pgpOnlyEnabledAt;
std::string colorCode;
double ubercoinBalance = 0.0; // Übercoin balance (3 decimal places)
std::string createdAt; // Account creation date (for burn rate calculation)
int tokenVersion = 1; // SECURITY FIX #10: Token version for revocation
};
// Chat service compatibility struct
struct UserClaims {
std::string userId;
std::string username;
std::string userColor;
bool isAdmin;
bool isModerator; // Site-wide moderator role
bool isStreamer;
bool isRestreamer;
UserClaims() : isAdmin(false), isModerator(false), isStreamer(false), isRestreamer(false) {}
};
class AuthService {
@ -40,7 +67,10 @@ public:
std::string generateToken(const UserInfo& user);
bool validateToken(const std::string& token, UserInfo& userInfo);
// Chat service compatibility method
std::optional<UserClaims> verifyToken(const std::string& token);
// New method to fetch complete user info including color
void fetchUserInfo(int64_t userId, std::function<void(bool, const UserInfo&)> callback);
@ -56,6 +86,7 @@ public:
private:
AuthService() = default;
std::string jwtSecret_;
bool validatePassword(const std::string& password, std::string& error);
void validateAndLoadJwtSecret(); // SECURITY FIX #5
};

View file

@ -0,0 +1,164 @@
#include "CensorService.h"
#include <drogon/drogon.h>
#include <sstream>
#include <algorithm>
#include <cctype>
using namespace drogon;
using namespace drogon::orm;
// Maximum length for a single censored word (ReDoS prevention)
static constexpr size_t MAX_WORD_LENGTH = 100;
// Maximum number of censored words
static constexpr size_t MAX_WORD_COUNT = 500;
void CensorService::loadCensoredWords(std::function<void(bool)> callback) {
auto dbClient = app().getDbClient();
*dbClient << "SELECT setting_value FROM site_settings WHERE setting_key = 'censored_words'"
>> [this, callback](const Result& r) {
// Build new patterns in temporary variables
std::vector<std::string> newWords;
std::optional<std::regex> newPattern;
if (!r.empty() && !r[0]["setting_value"].isNull()) {
std::string wordsStr = r[0]["setting_value"].as<std::string>();
// Parse comma-separated words
std::stringstream ss(wordsStr);
std::string word;
while (std::getline(ss, word, ',') && newWords.size() < MAX_WORD_COUNT) {
// Trim whitespace
size_t start = word.find_first_not_of(" \t\r\n");
size_t end = word.find_last_not_of(" \t\r\n");
if (start != std::string::npos && end != std::string::npos) {
word = word.substr(start, end - start + 1);
// Skip empty words and words exceeding max length (ReDoS prevention)
if (!word.empty() && word.length() <= MAX_WORD_LENGTH) {
newWords.push_back(word);
} else if (word.length() > MAX_WORD_LENGTH) {
LOG_WARN << "Skipping censored word exceeding " << MAX_WORD_LENGTH << " chars";
}
}
}
// Build combined pattern
newPattern = buildCombinedPattern(newWords);
}
// Atomic swap under lock
{
std::unique_lock<std::shared_mutex> lock(mutex_);
censoredWords_ = std::move(newWords);
combinedPattern_ = std::move(newPattern);
}
LOG_INFO << "Loaded " << censoredWords_.size() << " censored words";
if (callback) {
callback(true);
}
}
>> [callback](const DrogonDbException& e) {
LOG_ERROR << "Failed to load censored words: " << e.base().what();
if (callback) {
callback(false);
}
};
}
std::optional<std::regex> CensorService::buildCombinedPattern(const std::vector<std::string>& words) {
if (words.empty()) {
return std::nullopt;
}
try {
// Build combined pattern: \b(word1|word2|word3)\b
std::string pattern = "\\b(";
bool first = true;
for (const auto& word : words) {
if (!first) {
pattern += "|";
}
first = false;
// Escape special regex characters
for (char c : word) {
if (c == '.' || c == '^' || c == '$' || c == '*' || c == '+' ||
c == '?' || c == '(' || c == ')' || c == '[' || c == ']' ||
c == '{' || c == '}' || c == '|' || c == '\\') {
pattern += '\\';
}
pattern += c;
}
}
pattern += ")\\b";
return std::regex(pattern, std::regex_constants::icase);
} catch (const std::regex_error& e) {
LOG_ERROR << "Failed to build combined censored pattern: " << e.what();
return std::nullopt;
}
}
std::string CensorService::censor(const std::string& text) const {
if (text.empty()) {
return text;
}
std::shared_lock<std::shared_mutex> lock(mutex_);
if (!combinedPattern_) {
return text;
}
std::string result;
try {
// Replace censored words with asterisks
std::sregex_iterator begin(text.begin(), text.end(), *combinedPattern_);
std::sregex_iterator end;
size_t lastPos = 0;
for (std::sregex_iterator it = begin; it != end; ++it) {
const std::smatch& match = *it;
// Append text before match
result += text.substr(lastPos, match.position() - lastPos);
// Replace match with asterisks of same length
result += std::string(match.length(), '*');
lastPos = match.position() + match.length();
}
// Append remaining text
result += text.substr(lastPos);
} catch (const std::regex_error& e) {
LOG_ERROR << "Regex replace error: " << e.what();
return text;
}
return result;
}
bool CensorService::containsCensoredWords(const std::string& text) const {
if (text.empty()) {
return false;
}
std::shared_lock<std::shared_mutex> lock(mutex_);
if (!combinedPattern_) {
return false;
}
try {
return std::regex_search(text, *combinedPattern_);
} catch (const std::regex_error& e) {
LOG_ERROR << "Regex search error: " << e.what();
return false;
}
}
std::vector<std::string> CensorService::getCensoredWords() const {
std::shared_lock<std::shared_mutex> lock(mutex_);
return censoredWords_;
}

View file

@ -0,0 +1,37 @@
#pragma once
#include <string>
#include <vector>
#include <shared_mutex>
#include <regex>
#include <functional>
#include <optional>
class CensorService {
public:
static CensorService& getInstance() {
static CensorService instance;
return instance;
}
// Load censored words from database
void loadCensoredWords(std::function<void(bool)> callback = nullptr);
// Censor text by replacing censored words with asterisks (case-insensitive)
std::string censor(const std::string& text) const;
// Check if text contains any censored words
bool containsCensoredWords(const std::string& text) const;
// Get the list of censored words (for debugging/admin)
std::vector<std::string> getCensoredWords() const;
private:
CensorService() = default;
mutable std::shared_mutex mutex_;
std::vector<std::string> censoredWords_;
std::optional<std::regex> combinedPattern_; // Single combined pattern for efficiency
// Build a single combined regex pattern from all words
std::optional<std::regex> buildCombinedPattern(const std::vector<std::string>& words);
};

View file

@ -1,76 +0,0 @@
#pragma once
#include <drogon/drogon.h>
namespace middleware {
class CorsMiddleware {
public:
struct Config {
std::vector<std::string> allowOrigins = {"*"};
std::vector<std::string> allowMethods = {"GET", "POST", "PUT", "DELETE", "OPTIONS"};
std::vector<std::string> allowHeaders = {"Content-Type", "Authorization"};
bool allowCredentials = true;
int maxAge = 86400;
};
static void enable(const Config& config = {}) {
using namespace drogon;
auto cfg = std::make_shared<Config>(config);
auto addHeaders = [cfg](const HttpResponsePtr &resp, const HttpRequestPtr &req) {
std::string origin = req->getHeader("Origin");
// Check if origin is allowed
bool allowed = false;
for (const auto& allowedOrigin : cfg->allowOrigins) {
if (allowedOrigin == "*" || allowedOrigin == origin) {
allowed = true;
break;
}
}
if (allowed) {
resp->addHeader("Access-Control-Allow-Origin", origin.empty() ? "*" : origin);
resp->addHeader("Access-Control-Allow-Methods", joinStrings(cfg->allowMethods, ", "));
resp->addHeader("Access-Control-Allow-Headers", joinStrings(cfg->allowHeaders, ", "));
if (cfg->allowCredentials) {
resp->addHeader("Access-Control-Allow-Credentials", "true");
}
}
};
// Handle preflight requests
app().registerPreRoutingAdvice([cfg, addHeaders](const HttpRequestPtr &req,
AdviceCallback &&acb,
AdviceChainCallback &&accb) {
if (req->getMethod() == Options) {
auto resp = HttpResponse::newHttpResponse();
resp->setStatusCode(k204NoContent);
addHeaders(resp, req);
resp->addHeader("Access-Control-Max-Age", std::to_string(cfg->maxAge));
acb(resp);
return;
}
accb();
});
// Add CORS headers to all responses
app().registerPostHandlingAdvice([addHeaders](const HttpRequestPtr &req,
const HttpResponsePtr &resp) {
addHeaders(resp, req);
});
}
private:
static std::string joinStrings(const std::vector<std::string>& strings, const std::string& delimiter) {
std::string result;
for (size_t i = 0; i < strings.size(); ++i) {
result += strings[i];
if (i < strings.size() - 1) result += delimiter;
}
return result;
}
};
} // namespace middleware

View file

@ -1,9 +1,9 @@
#include "DatabaseService.h"
#include "../services/RedisHelper.h"
#include <drogon/orm/DbClient.h>
#include <random>
#include <sstream>
#include <iomanip>
#include <openssl/rand.h>
using namespace drogon;
using namespace drogon::orm;
@ -12,22 +12,25 @@ namespace {
void storeKeyInRedis(const std::string& streamKey) {
// Store the stream key in Redis for validation (24 hour TTL)
bool stored = RedisHelper::storeKey("stream_key:" + streamKey, "1", 86400);
if (stored) {
LOG_INFO << "Stored stream key in Redis: " << streamKey;
} else {
LOG_ERROR << "Failed to store key in Redis: " << streamKey;
}
}
// SECURITY FIX: Use cryptographically secure random bytes instead of mt19937
std::string generateStreamKey() {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(0, 255);
unsigned char bytes[32]; // 32 bytes = 64 hex characters
if (RAND_bytes(bytes, sizeof(bytes)) != 1) {
LOG_ERROR << "Failed to generate cryptographically secure random bytes";
throw std::runtime_error("Failed to generate secure stream key");
}
std::stringstream ss;
for (int i = 0; i < 16; ++i) {
ss << std::hex << std::setw(2) << std::setfill('0') << dis(gen);
for (size_t i = 0; i < sizeof(bytes); ++i) {
ss << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(bytes[i]);
}
return ss.str();
}

View file

@ -28,16 +28,17 @@ void RedisHelper::ensureConnected() {
sw::redis::ConnectionOptions opts;
opts.host = getRedisHost();
opts.port = getRedisPort();
opts.db = getRedisDb();
const char* envPass = std::getenv("REDIS_PASS");
if (envPass && strlen(envPass) > 0) {
opts.password = envPass;
}
opts.socket_timeout = std::chrono::milliseconds(1000);
opts.connect_timeout = std::chrono::milliseconds(1000);
LOG_INFO << "Connecting to Redis at " << opts.host << ":" << opts.port;
LOG_INFO << "Connecting to Redis at " << opts.host << ":" << opts.port << " db=" << opts.db;
_redis = std::make_unique<sw::redis::Redis>(opts);
_redis->ping();
@ -71,17 +72,35 @@ int RedisHelper::getRedisPort() const {
return std::stoi(envPort);
} catch (...) {}
}
try {
const auto& config = drogon::app().getCustomConfig();
if (config.isMember("redis") && config["redis"].isMember("port")) {
return config["redis"]["port"].asInt();
}
} catch (...) {}
return 6379;
}
int RedisHelper::getRedisDb() const {
const char* envDb = std::getenv("REDIS_DB");
if (envDb) {
try {
return std::stoi(envDb);
} catch (...) {}
}
try {
const auto& config = drogon::app().getCustomConfig();
if (config.isMember("redis") && config["redis"].isMember("db")) {
return config["redis"]["db"].asInt();
}
} catch (...) {}
return 0; // Default to db 0
}
void RedisHelper::executeInThreadPool(std::function<void()> task) {
auto loop = drogon::app().getLoop();
if (!loop) {
@ -212,19 +231,20 @@ void RedisHelper::expireAsync(const std::string &key,
// Sync versions for compatibility
std::unique_ptr<sw::redis::Redis> RedisHelper::getConnection() {
ensureConnected();
sw::redis::ConnectionOptions opts;
opts.host = getRedisHost();
opts.port = getRedisPort();
opts.db = getRedisDb();
const char* envPass = std::getenv("REDIS_PASS");
if (envPass && strlen(envPass) > 0) {
opts.password = envPass;
}
opts.socket_timeout = std::chrono::milliseconds(200);
opts.connect_timeout = std::chrono::milliseconds(200);
return std::make_unique<sw::redis::Redis>(opts);
}

View file

@ -121,6 +121,7 @@ private:
void executeInThreadPool(std::function<void()> task);
std::string getRedisHost() const;
int getRedisPort() const;
int getRedisDb() const;
std::unique_ptr<sw::redis::Redis> _redis;
bool _initialized;

View file

@ -0,0 +1,397 @@
#include "RestreamService.h"
#include <drogon/drogon.h>
#include <drogon/orm/DbClient.h>
#include <memory>
using namespace drogon;
using namespace drogon::orm;
// execCurl removed - using Drogon HttpClient instead for security
std::string RestreamService::getBaseUrl() {
const char* envUrl = std::getenv("OME_API_URL");
if (envUrl) {
return std::string(envUrl);
}
return "http://ovenmediaengine:8081";
}
std::string RestreamService::getApiToken() {
const char* envToken = std::getenv("OME_API_TOKEN");
if (!envToken || strlen(envToken) == 0) {
throw std::runtime_error("OME_API_TOKEN environment variable is not set");
}
return std::string(envToken);
}
HttpClientPtr RestreamService::getClient() {
return HttpClient::newHttpClient(getBaseUrl());
}
HttpRequestPtr RestreamService::createRequest(HttpMethod method, const std::string& path) {
auto request = HttpRequest::newHttpRequest();
request->setMethod(method);
request->setPath(path);
const auto token = getApiToken();
const auto b64 = drogon::utils::base64Encode(token);
request->addHeader("Authorization", std::string("Basic ") + b64);
return request;
}
HttpRequestPtr RestreamService::createJsonRequest(HttpMethod method, const std::string& path,
const Json::Value& body) {
auto request = HttpRequest::newHttpJsonRequest(body);
request->setMethod(method);
request->setPath(path);
const auto token = getApiToken();
const auto b64 = drogon::utils::base64Encode(token);
request->addHeader("Authorization", std::string("Basic ") + b64);
return request;
}
std::string RestreamService::generatePushId(const std::string& streamKey, int64_t destinationId) {
return "restream_" + streamKey + "_" + std::to_string(destinationId);
}
void RestreamService::updateDestinationStatus(int64_t destinationId, bool isConnected, const std::string& error) {
auto dbClient = app().getDbClient();
if (isConnected) {
*dbClient << "UPDATE restream_destinations SET is_connected = true, last_error = NULL, "
"last_connected_at = CURRENT_TIMESTAMP WHERE id = $1"
<< destinationId
>> [destinationId](const Result&) {
LOG_INFO << "Restream destination " << destinationId << " connected";
}
>> [destinationId](const DrogonDbException& e) {
LOG_ERROR << "Failed to update restream destination " << destinationId
<< ": " << e.base().what();
};
} else {
*dbClient << "UPDATE restream_destinations SET is_connected = false, last_error = $1 WHERE id = $2"
<< error << destinationId
>> [destinationId](const Result&) {
LOG_INFO << "Restream destination " << destinationId << " disconnected";
}
>> [destinationId](const DrogonDbException& e) {
LOG_ERROR << "Failed to update restream destination " << destinationId
<< ": " << e.base().what();
};
}
}
void RestreamService::startPush(const std::string& sourceStreamKey, const RestreamDestination& dest,
std::function<void(bool, const std::string&)> callback) {
// Build the full destination URL with stream key
std::string fullUrl = dest.rtmpUrl;
if (!fullUrl.empty() && fullUrl.back() != '/') {
fullUrl += '/';
}
fullUrl += dest.streamKey;
std::string pushId = generatePushId(sourceStreamKey, dest.id);
auto destId = dest.id;
LOG_INFO << "Starting RTMP push for stream " << sourceStreamKey
<< " to " << dest.name << " (" << dest.rtmpUrl << ")";
// Build JSON body
Json::Value body;
body["id"] = pushId;
body["stream"]["name"] = sourceStreamKey;
body["protocol"] = "rtmp";
body["url"] = fullUrl;
// Use Drogon HttpClient instead of curl for security
auto request = createJsonRequest(drogon::Post, "/v1/vhosts/default/apps/app:startPush", body);
LOG_INFO << "Sending HTTP request for push start";
getClient()->sendRequest(request,
[this, callback, pushId, sourceStreamKey, destId](ReqResult result, const HttpResponsePtr& response) {
if (result != ReqResult::Ok || !response) {
std::string error = "Failed to connect to OME API";
updateDestinationStatus(destId, false, error);
callback(false, error);
LOG_ERROR << "Failed to start RTMP push: " << error;
return;
}
auto json = response->getJsonObject();
if (json) {
int statusCode = (*json).get("statusCode", 0).asInt();
std::string message = (*json).get("message", "").asString();
// 200 = success, 400 with "Duplicate ID" = already running (treat as success)
bool isSuccess = (statusCode == 200);
bool isDuplicate = (statusCode == 400 && message.find("Duplicate") != std::string::npos);
if (isSuccess || isDuplicate) {
// Track the active push
{
std::lock_guard<std::mutex> lock(pushMutex_);
activePushes_[sourceStreamKey][destId] = pushId;
}
updateDestinationStatus(destId, true, "");
callback(true, "");
if (isDuplicate) {
LOG_INFO << "RTMP push already active (duplicate ID): " << pushId;
} else {
LOG_INFO << "RTMP push started successfully: " << pushId;
}
return;
} else {
std::string error = (*json).get("message", "Unknown error").asString();
updateDestinationStatus(destId, false, error);
callback(false, error);
LOG_ERROR << "Failed to start RTMP push: " << error;
return;
}
}
std::string error = "Invalid response from OME API";
updateDestinationStatus(destId, false, error);
callback(false, error);
LOG_ERROR << "Failed to start RTMP push: " << error;
});
}
void RestreamService::stopPush(const std::string& sourceStreamKey, int64_t destinationId,
std::function<void(bool)> callback) {
std::string pushId;
{
std::lock_guard<std::mutex> lock(pushMutex_);
auto streamIt = activePushes_.find(sourceStreamKey);
if (streamIt != activePushes_.end()) {
auto destIt = streamIt->second.find(destinationId);
if (destIt != streamIt->second.end()) {
pushId = destIt->second;
}
}
}
// If not tracked in memory, generate the push ID anyway and try to stop it
// This handles cases where server restarted but push is still active on OME
if (pushId.empty()) {
pushId = generatePushId(sourceStreamKey, destinationId);
}
LOG_INFO << "Stopping RTMP push: " << pushId;
// Build JSON body
Json::Value body;
body["id"] = pushId;
// Use Drogon HttpClient instead of curl for security
auto request = createJsonRequest(drogon::Post, "/v1/vhosts/default/apps/app:stopPush", body);
LOG_INFO << "Sending HTTP request for push stop";
getClient()->sendRequest(request,
[this, callback, pushId, sourceStreamKey, destinationId](ReqResult result, const HttpResponsePtr& response) {
// Remove from tracking regardless of result
{
std::lock_guard<std::mutex> lock(pushMutex_);
auto streamIt = activePushes_.find(sourceStreamKey);
if (streamIt != activePushes_.end()) {
streamIt->second.erase(destinationId);
if (streamIt->second.empty()) {
activePushes_.erase(streamIt);
}
}
}
updateDestinationStatus(destinationId, false, "");
if (result == ReqResult::Ok && response) {
auto json = response->getJsonObject();
if (json) {
int statusCode = (*json).get("statusCode", 0).asInt();
if (statusCode == 200 || statusCode == 404) {
callback(true);
LOG_INFO << "RTMP push stopped: " << pushId;
return;
}
}
}
// Even if API call failed, we've removed from tracking
callback(true);
LOG_WARN << "RTMP push stop may have failed, but removed from tracking: " << pushId;
});
}
void RestreamService::stopAllPushes(const std::string& sourceStreamKey,
std::function<void(bool)> callback) {
std::vector<int64_t> destinationIds;
{
std::lock_guard<std::mutex> lock(pushMutex_);
auto streamIt = activePushes_.find(sourceStreamKey);
if (streamIt != activePushes_.end()) {
for (const auto& [destId, pushId] : streamIt->second) {
destinationIds.push_back(destId);
}
}
}
if (destinationIds.empty()) {
callback(true);
return;
}
// Stop each push
auto remaining = std::make_shared<std::atomic<int>>(destinationIds.size());
auto allSuccess = std::make_shared<std::atomic<bool>>(true);
for (int64_t destId : destinationIds) {
stopPush(sourceStreamKey, destId, [remaining, allSuccess, callback](bool success) {
if (!success) {
allSuccess->store(false);
}
if (--(*remaining) == 0) {
callback(allSuccess->load());
}
});
}
}
void RestreamService::getPushStatus(const std::string& sourceStreamKey, int64_t destinationId,
std::function<void(bool, bool isConnected, const std::string& error)> callback) {
std::string pushId;
{
std::lock_guard<std::mutex> lock(pushMutex_);
auto streamIt = activePushes_.find(sourceStreamKey);
if (streamIt != activePushes_.end()) {
auto destIt = streamIt->second.find(destinationId);
if (destIt != streamIt->second.end()) {
pushId = destIt->second;
}
}
}
if (pushId.empty()) {
callback(true, false, "Not connected");
return;
}
// OME API: GET /v1/vhosts/{vhost}/apps/{app}/push
std::string path = "/v1/vhosts/default/apps/app/push";
auto request = createRequest(Get, path);
getClient()->sendRequest(request,
[callback, pushId](ReqResult result, const HttpResponsePtr& response) {
if (result == ReqResult::Ok && response && response->getStatusCode() == k200OK) {
try {
auto json = *response->getJsonObject();
// Look for our push in the response
if (json.isMember("response") && json["response"].isArray()) {
for (const auto& push : json["response"]) {
if (push["id"].asString() == pushId) {
std::string state = push.get("state", "unknown").asString();
bool connected = (state == "started" || state == "connected");
std::string error = push.get("error", "").asString();
callback(true, connected, error);
return;
}
}
}
callback(true, false, "Push not found");
} catch (const std::exception& e) {
callback(false, false, e.what());
}
} else {
callback(false, false, "Failed to get push status");
}
});
}
void RestreamService::startAllDestinations(const std::string& streamKey, int64_t realmId) {
LOG_INFO << "Starting all restream destinations for realm " << realmId;
auto dbClient = app().getDbClient();
*dbClient << "SELECT id, realm_id, name, rtmp_url, stream_key, enabled "
"FROM restream_destinations WHERE realm_id = $1 AND enabled = true"
<< realmId
>> [this, streamKey](const Result& r) {
for (const auto& row : r) {
RestreamDestination dest;
dest.id = row["id"].as<int64_t>();
dest.realmId = row["realm_id"].as<int64_t>();
dest.name = row["name"].as<std::string>();
dest.rtmpUrl = row["rtmp_url"].as<std::string>();
dest.streamKey = row["stream_key"].as<std::string>();
dest.enabled = row["enabled"].as<bool>();
startPush(streamKey, dest, [dest](bool success, const std::string& error) {
if (!success) {
LOG_ERROR << "Failed to start restream to " << dest.name << ": " << error;
}
});
}
}
>> [realmId](const DrogonDbException& e) {
LOG_ERROR << "Failed to fetch restream destinations for realm " << realmId
<< ": " << e.base().what();
};
}
void RestreamService::stopAllDestinations(const std::string& streamKey, int64_t realmId) {
LOG_INFO << "Stopping all restream destinations for realm " << realmId;
stopAllPushes(streamKey, [realmId](bool success) {
if (!success) {
LOG_WARN << "Some restream pushes may not have stopped cleanly for realm " << realmId;
}
});
// Also update all destinations in DB as disconnected
auto dbClient = app().getDbClient();
*dbClient << "UPDATE restream_destinations SET is_connected = false WHERE realm_id = $1"
<< realmId
>> [](const Result&) {}
>> [realmId](const DrogonDbException& e) {
LOG_ERROR << "Failed to update restream destinations for realm " << realmId
<< ": " << e.base().what();
};
}
void RestreamService::attemptReconnections(const std::string& streamKey, int64_t realmId) {
// Get all enabled but disconnected destinations and try to reconnect
auto dbClient = app().getDbClient();
*dbClient << "SELECT id, realm_id, name, rtmp_url, stream_key, enabled, is_connected "
"FROM restream_destinations "
"WHERE realm_id = $1 AND enabled = true AND is_connected = false"
<< realmId
>> [this, streamKey](const Result& r) {
for (const auto& row : r) {
RestreamDestination dest;
dest.id = row["id"].as<int64_t>();
dest.realmId = row["realm_id"].as<int64_t>();
dest.name = row["name"].as<std::string>();
dest.rtmpUrl = row["rtmp_url"].as<std::string>();
dest.streamKey = row["stream_key"].as<std::string>();
dest.enabled = row["enabled"].as<bool>();
LOG_INFO << "Attempting to reconnect restream destination: " << dest.name;
startPush(streamKey, dest, [dest](bool success, const std::string& error) {
if (success) {
LOG_INFO << "Reconnected restream to " << dest.name;
} else {
LOG_WARN << "Reconnection failed for " << dest.name << ": " << error;
}
});
}
}
>> [realmId](const DrogonDbException& e) {
LOG_ERROR << "Failed to fetch disconnected restream destinations for realm " << realmId
<< ": " << e.base().what();
};
}

View file

@ -0,0 +1,75 @@
#pragma once
#include <drogon/HttpClient.h>
#include <drogon/utils/Utilities.h>
#include <functional>
#include <string>
#include <unordered_map>
#include <mutex>
#include <memory>
struct RestreamDestination {
int64_t id;
int64_t realmId;
std::string name;
std::string rtmpUrl;
std::string streamKey;
bool enabled;
bool isConnected;
std::string lastError;
};
class RestreamService {
public:
static RestreamService& getInstance() {
static RestreamService instance;
return instance;
}
// Start pushing stream to a destination
void startPush(const std::string& sourceStreamKey, const RestreamDestination& dest,
std::function<void(bool, const std::string&)> callback);
// Stop pushing stream to a destination
void stopPush(const std::string& sourceStreamKey, int64_t destinationId,
std::function<void(bool)> callback);
// Stop all pushes for a stream
void stopAllPushes(const std::string& sourceStreamKey,
std::function<void(bool)> callback);
// Get push status for a destination
void getPushStatus(const std::string& sourceStreamKey, int64_t destinationId,
std::function<void(bool, bool isConnected, const std::string& error)> callback);
// Start all enabled destinations for a realm when stream goes live
void startAllDestinations(const std::string& streamKey, int64_t realmId);
// Stop all destinations for a realm when stream goes offline
void stopAllDestinations(const std::string& streamKey, int64_t realmId);
// Attempt reconnection for failed destinations (called periodically)
void attemptReconnections(const std::string& streamKey, int64_t realmId);
private:
RestreamService() = default;
~RestreamService() = default;
RestreamService(const RestreamService&) = delete;
RestreamService& operator=(const RestreamService&) = delete;
std::string getBaseUrl();
std::string getApiToken();
drogon::HttpClientPtr getClient();
drogon::HttpRequestPtr createRequest(drogon::HttpMethod method, const std::string& path);
drogon::HttpRequestPtr createJsonRequest(drogon::HttpMethod method, const std::string& path,
const Json::Value& body);
// Generate a unique push ID for tracking
std::string generatePushId(const std::string& streamKey, int64_t destinationId);
// Update destination status in database
void updateDestinationStatus(int64_t destinationId, bool isConnected, const std::string& error);
// Track active pushes: streamKey -> (destinationId -> pushId)
std::unordered_map<std::string, std::unordered_map<int64_t, std::string>> activePushes_;
std::mutex pushMutex_;
};

View file

@ -2,6 +2,7 @@
#include "../controllers/StreamController.h"
#include "../services/RedisHelper.h"
#include "../services/OmeClient.h"
#include "../services/RestreamService.h"
#include <drogon/HttpClient.h>
#include <drogon/utils/Utilities.h>
#include <set>
@ -116,19 +117,25 @@ void StatsService::pollOmeStats() {
// Update each active stream
for (const auto& streamKey : activeStreamKeys) {
LOG_INFO << "Processing active stream: " << streamKey;
// IMMEDIATELY update database to mark as live
// IMMEDIATELY update database to mark as live and get realm ID
auto dbClient = app().getDbClient();
*dbClient << "UPDATE realms SET is_live = true, viewer_count = 0, "
"updated_at = CURRENT_TIMESTAMP WHERE stream_key = $1"
"updated_at = CURRENT_TIMESTAMP WHERE stream_key = $1 RETURNING id"
<< streamKey
>> [streamKey](const orm::Result&) {
>> [streamKey](const orm::Result& r) {
LOG_INFO << "Successfully marked realm as live: " << streamKey;
// Attempt reconnection for any disconnected restream destinations
if (!r.empty()) {
int64_t realmId = r[0]["id"].as<int64_t>();
RestreamService::getInstance().attemptReconnections(streamKey, realmId);
}
}
>> [streamKey](const orm::DrogonDbException& e) {
LOG_ERROR << "Failed to update realm live status: " << e.base().what();
};
// Then update detailed stats
updateStreamStats(streamKey);
}
@ -167,10 +174,12 @@ void StatsService::updateStreamStats(const std::string& streamKey) {
fetchStatsFromOme(streamKey, [this, streamKey](bool success, const StreamStats& stats) {
if (success) {
StreamStats updatedStats = stats;
updatedStats.uniqueViewers = getUniqueViewerCount(streamKey);
// Only count viewer tokens when stream is actually live
// Offline streams should show 0 viewers (tokens may linger for 5 min after disconnect)
updatedStats.uniqueViewers = stats.isLive ? getUniqueViewerCount(streamKey) : 0;
storeStatsInRedis(streamKey, updatedStats);
// Update realm in database
updateRealmLiveStatus(streamKey, updatedStats);
@ -267,31 +276,31 @@ void StatsService::fetchStatsFromOme(const std::string& streamKey,
hasInput = true;
const auto& input = data["input"];
// Get bitrate from input tracks
// Get bitrate from input tracks (OME returns bytes/sec, convert to bits/sec)
if (input.isMember("tracks") && input["tracks"].isArray()) {
for (const auto& track : input["tracks"]) {
if (track["type"].asString() == "video" && track.isMember("bitrate")) {
stats.bitrate = track["bitrate"].asDouble();
stats.bitrate = track["bitrate"].asDouble() * 8; // Convert bytes/sec to bits/sec
}
}
}
}
// Alternative: Check lastThroughputIn
// Alternative: Check lastThroughputIn (OME returns bytes/sec, convert to bits/sec)
if (!hasInput && data.isMember("lastThroughputIn")) {
double throughput = data["lastThroughputIn"].asDouble();
if (throughput > 0) {
hasInput = true;
stats.bitrate = throughput;
stats.bitrate = throughput * 8; // Convert bytes/sec to bits/sec
}
}
// Alternative: Check avgThroughputIn
// Alternative: Check avgThroughputIn (OME returns bytes/sec, convert to bits/sec)
if (!hasInput && data.isMember("avgThroughputIn")) {
double avgThroughput = data["avgThroughputIn"].asDouble();
if (avgThroughput > 0) {
hasInput = true;
stats.bitrate = avgThroughput;
stats.bitrate = avgThroughput * 8; // Convert bytes/sec to bits/sec
}
}
@ -479,8 +488,8 @@ void StatsService::getStreamStats(const std::string& streamKey,
fetchStatsFromOme(streamKey, [this, callback, streamKey](bool success, const StreamStats& stats) {
if (success) {
StreamStats updatedStats = stats;
// Set uniqueViewers on cache miss
updatedStats.uniqueViewers = getUniqueViewerCount(streamKey);
// Only count viewer tokens when stream is actually live
updatedStats.uniqueViewers = stats.isLive ? getUniqueViewerCount(streamKey) : 0;
callback(true, updatedStats);
} else {
callback(false, stats);
@ -504,7 +513,7 @@ void StatsService::getStreamStats(const std::string& streamKey,
stats.resolution = json["resolution"].asString();
stats.fps = json["fps"].asDouble();
stats.isLive = json["is_live"].asBool();
// Parse protocol connections
if (json.isMember("protocol_connections")) {
const auto& pc = json["protocol_connections"];
@ -513,19 +522,42 @@ void StatsService::getStreamStats(const std::string& streamKey,
stats.protocolConnections.llhls = pc["llhls"].asInt64();
stats.protocolConnections.dash = pc["dash"].asInt64();
}
stats.lastUpdated = std::chrono::system_clock::time_point(
std::chrono::seconds(json["last_updated"].asInt64())
);
callback(true, stats);
// Verify is_live from database (source of truth from webhooks)
// This prevents stale cache from overriding the webhook-updated DB state
auto dbClient = app().getDbClient();
*dbClient << "SELECT is_live FROM realms WHERE stream_key = $1"
<< streamKey
>> [callback, stats](const orm::Result& r) mutable {
if (!r.empty()) {
bool dbIsLive = r[0]["is_live"].as<bool>();
// If database says live but cache says offline, trust database
// (webhooks update DB immediately, cache may be stale)
if (dbIsLive && !stats.isLive) {
LOG_DEBUG << "Overriding stale cache: DB says live, cache says offline";
stats.isLive = true;
}
}
callback(true, stats);
}
>> [callback, stats](const orm::DrogonDbException& e) {
LOG_ERROR << "Failed to verify is_live from DB: " << e.base().what();
// Fall back to cached value on DB error
callback(true, stats);
};
LOG_DEBUG << "Retrieved cached stats for " << streamKey;
return; // Callback handled async
} else {
// Fallback to fresh fetch if cached data is corrupted
fetchStatsFromOme(streamKey, [this, callback, streamKey](bool success, const StreamStats& stats) {
if (success) {
StreamStats updatedStats = stats;
updatedStats.uniqueViewers = getUniqueViewerCount(streamKey);
// Only count viewer tokens when stream is actually live
updatedStats.uniqueViewers = stats.isLive ? getUniqueViewerCount(streamKey) : 0;
callback(true, updatedStats);
} else {
callback(false, stats);
@ -538,7 +570,8 @@ void StatsService::getStreamStats(const std::string& streamKey,
fetchStatsFromOme(streamKey, [this, callback, streamKey](bool success, const StreamStats& stats) {
if (success) {
StreamStats updatedStats = stats;
updatedStats.uniqueViewers = getUniqueViewerCount(streamKey);
// Only count viewer tokens when stream is actually live
updatedStats.uniqueViewers = stats.isLive ? getUniqueViewerCount(streamKey) : 0;
callback(true, updatedStats);
} else {
callback(false, stats);

View file

@ -69,5 +69,7 @@ private:
std::atomic<bool> running_{false};
std::optional<trantor::TimerId> timerId_;
std::chrono::seconds pollInterval_{2}; // Poll every 2 seconds
// Poll every 5 seconds for near-instant stats updates
// Real-time updates also come via OME webhooks (see StreamController::handleOmeWebhook)
std::chrono::seconds pollInterval_{5};
};

View file

@ -0,0 +1,281 @@
#include "TreasuryService.h"
#include <ctime>
#include <cmath>
#include <iomanip>
#include <sstream>
using namespace drogon;
TreasuryService::~TreasuryService() {
shutdown();
}
void TreasuryService::initialize() {
LOG_INFO << "Initializing Treasury Service...";
running_ = true;
}
void TreasuryService::startScheduler() {
if (!running_) {
LOG_WARN << "Treasury service not initialized, cannot start scheduler";
return;
}
LOG_INFO << "Starting treasury scheduler...";
if (auto loop = drogon::app().getLoop()) {
try {
// Do an immediate check on startup (catches up on missed tasks)
checkAndRunTasks();
// Then set up the hourly timer
timerId_ = loop->runEvery(
checkInterval_.count(),
[this]() {
if (!running_) return;
try {
checkAndRunTasks();
} catch (const std::exception& e) {
LOG_ERROR << "Error in treasury scheduler: " << e.what();
}
}
);
LOG_INFO << "Treasury scheduler started with " << checkInterval_.count() << "s interval";
} catch (const std::exception& e) {
LOG_ERROR << "Failed to create treasury timer: " << e.what();
}
} else {
LOG_ERROR << "Event loop not available for treasury scheduler";
}
}
void TreasuryService::shutdown() {
LOG_INFO << "Shutting down Treasury Service...";
running_ = false;
if (timerId_.has_value()) {
if (auto loop = drogon::app().getLoop()) {
loop->invalidateTimer(timerId_.value());
}
timerId_.reset();
}
}
void TreasuryService::checkAndRunTasks() {
LOG_INFO << "Treasury scheduler: checking for pending tasks...";
auto dbClient = app().getDbClient();
// Get current time info
std::time_t now = std::time(nullptr);
std::tm* localTime = std::localtime(&now);
int dayOfWeek = localTime->tm_wday; // 0 = Sunday, 1 = Monday, ..., 6 = Saturday
// Get treasury timestamps
*dbClient << "SELECT last_growth_at, last_distribution_at FROM ubercoin_treasury WHERE id = 1"
>> [this, dayOfWeek, localTime](const orm::Result& r) {
if (r.empty()) {
LOG_WARN << "Treasury record not found";
return;
}
bool needsGrowth = false;
bool needsDistribution = false;
std::time_t now = std::time(nullptr);
std::tm* todayStart = std::localtime(&now);
todayStart->tm_hour = 0;
todayStart->tm_min = 0;
todayStart->tm_sec = 0;
std::time_t todayStartTime = std::mktime(todayStart);
// Check if growth is needed (Mon-Sat, once per day)
if (dayOfWeek >= 1 && dayOfWeek <= 6) { // Monday to Saturday
if (r[0]["last_growth_at"].isNull()) {
needsGrowth = true;
} else {
std::string lastGrowthStr = r[0]["last_growth_at"].as<std::string>();
std::tm lastGrowthTm = {};
std::istringstream ss(lastGrowthStr);
ss >> std::get_time(&lastGrowthTm, "%Y-%m-%d %H:%M:%S");
std::time_t lastGrowthTime = std::mktime(&lastGrowthTm);
// If last growth was before today, we need to apply growth
if (lastGrowthTime < todayStartTime) {
needsGrowth = true;
}
}
}
// Check if distribution is needed (Sunday, once per week)
if (dayOfWeek == 0) { // Sunday
if (r[0]["last_distribution_at"].isNull()) {
needsDistribution = true;
} else {
std::string lastDistStr = r[0]["last_distribution_at"].as<std::string>();
std::tm lastDistTm = {};
std::istringstream ss(lastDistStr);
ss >> std::get_time(&lastDistTm, "%Y-%m-%d %H:%M:%S");
std::time_t lastDistTime = std::mktime(&lastDistTm);
// If last distribution was before today, we need to distribute
if (lastDistTime < todayStartTime) {
needsDistribution = true;
}
}
}
if (needsGrowth) {
LOG_INFO << "Treasury scheduler: applying daily growth";
this->applyDailyGrowth();
}
if (needsDistribution) {
LOG_INFO << "Treasury scheduler: distributing to users";
this->distributeToUsers();
}
if (!needsGrowth && !needsDistribution) {
LOG_INFO << "Treasury scheduler: no tasks needed at this time";
}
}
>> [](const orm::DrogonDbException& e) {
LOG_ERROR << "Treasury scheduler: failed to check timestamps: " << e.base().what();
};
}
void TreasuryService::applyDailyGrowth() {
auto dbClient = app().getDbClient();
// Apply 3.3% growth to treasury balance
*dbClient << "UPDATE ubercoin_treasury SET balance = balance * 1.033, last_growth_at = NOW() WHERE id = 1 RETURNING balance"
>> [](const orm::Result& r) {
double newBalance = 0.0;
if (!r.empty()) {
newBalance = r[0]["balance"].as<double>();
}
LOG_INFO << "Treasury growth applied (3.3%). New balance: " << newBalance;
}
>> [](const orm::DrogonDbException& e) {
LOG_ERROR << "Failed to apply treasury growth: " << e.base().what();
};
}
void TreasuryService::distributeToUsers() {
auto dbClient = app().getDbClient();
// Get treasury balance
*dbClient << "SELECT balance FROM ubercoin_treasury WHERE id = 1"
>> [this, dbClient](const orm::Result& r) {
if (r.empty()) {
LOG_ERROR << "Treasury not found for distribution";
return;
}
double treasuryBalance = r[0]["balance"].as<double>();
if (treasuryBalance <= 0) {
LOG_INFO << "Treasury empty, nothing to distribute";
return;
}
// Get all users with their created_at for burn rate calculation
*dbClient << "SELECT id, created_at, ubercoin_balance FROM users"
>> [this, dbClient, treasuryBalance](const orm::Result& users) {
if (users.empty()) {
LOG_INFO << "No users to distribute to";
return;
}
int64_t userCount = users.size();
double sharePerUser = treasuryBalance / static_cast<double>(userCount);
double totalDistributed = 0.0;
double totalDestroyed = 0.0;
// Calculate distributions for each user
for (const auto& row : users) {
int64_t userId = row["id"].as<int64_t>();
std::string createdAt = row["created_at"].as<std::string>();
double currentBalance = row["ubercoin_balance"].isNull() ? 0.0 : row["ubercoin_balance"].as<double>();
int accountAgeDays = this->calculateAccountAgeDays(createdAt);
double burnRate = this->calculateBurnRate(accountAgeDays);
// Calculate received amount (after burn) - ceiling for user benefit
double receivedAmount = sharePerUser * (100.0 - burnRate) / 100.0;
receivedAmount = std::ceil(receivedAmount * 1000.0) / 1000.0;
double destroyedAmount = sharePerUser - receivedAmount;
double newBalance = currentBalance + receivedAmount;
// Update user balance
*dbClient << "UPDATE users SET ubercoin_balance = $1 WHERE id = $2"
<< newBalance << userId
>> [](const orm::Result&) {}
>> [userId](const orm::DrogonDbException& e) {
LOG_ERROR << "Failed to update user " << userId << " balance in distribution: " << e.base().what();
};
totalDistributed += receivedAmount;
totalDestroyed += destroyedAmount;
}
// Reset treasury balance to 0 and update total_destroyed
*dbClient << "UPDATE ubercoin_treasury SET balance = 0, total_destroyed = total_destroyed + $1, last_distribution_at = NOW() WHERE id = 1"
<< totalDestroyed
>> [totalDistributed, totalDestroyed, userCount](const orm::Result&) {
LOG_INFO << "Treasury distributed successfully. Users: " << userCount
<< ", Distributed: " << totalDistributed
<< ", Destroyed: " << totalDestroyed;
}
>> [](const orm::DrogonDbException& e) {
LOG_ERROR << "Failed to reset treasury: " << e.base().what();
};
}
>> [](const orm::DrogonDbException& e) {
LOG_ERROR << "Failed to get users for distribution: " << e.base().what();
};
}
>> [](const orm::DrogonDbException& e) {
LOG_ERROR << "Failed to get treasury balance: " << e.base().what();
};
}
double TreasuryService::calculateBurnRate(int accountAgeDays) {
double burnRate = 99.0 * std::exp(-static_cast<double>(accountAgeDays) / 180.0);
return std::max(1.0, burnRate);
}
int TreasuryService::calculateAccountAgeDays(const std::string& createdAt) {
try {
// Parse ISO 8601 timestamp (e.g., "2025-01-15T10:30:00+00:00" or "2025-01-15 10:30:00")
std::tm tm = {};
std::istringstream ss(createdAt);
// Try ISO 8601 format first
ss >> std::get_time(&tm, "%Y-%m-%dT%H:%M:%S");
if (ss.fail()) {
// Try space-separated format
ss.clear();
ss.str(createdAt);
ss >> std::get_time(&tm, "%Y-%m-%d %H:%M:%S");
}
if (ss.fail()) {
LOG_WARN << "Failed to parse created_at timestamp: " << createdAt;
return 0;
}
std::time_t createdTime = std::mktime(&tm);
std::time_t now = std::time(nullptr);
// Calculate difference in days
double diffSeconds = std::difftime(now, createdTime);
return static_cast<int>(diffSeconds / (60 * 60 * 24));
} catch (const std::exception& e) {
LOG_ERROR << "Error calculating account age: " << e.what();
return 0;
}
}

View file

@ -0,0 +1,40 @@
#pragma once
#include <drogon/drogon.h>
#include <trantor/net/EventLoop.h>
#include <atomic>
#include <optional>
#include <chrono>
class TreasuryService {
public:
static TreasuryService& getInstance() {
static TreasuryService instance;
return instance;
}
void initialize();
void startScheduler();
void shutdown();
// Manual triggers (for testing/admin)
void applyDailyGrowth();
void distributeToUsers();
private:
TreasuryService() = default;
~TreasuryService();
TreasuryService(const TreasuryService&) = delete;
TreasuryService& operator=(const TreasuryService&) = delete;
void checkAndRunTasks();
// Burn rate calculation helpers
double calculateBurnRate(int accountAgeDays);
int calculateAccountAgeDays(const std::string& createdAt);
std::atomic<bool> running_{false};
std::optional<trantor::TimerId> timerId_;
// Check every hour (3600 seconds)
std::chrono::seconds checkInterval_{3600};
};