diff --git a/src/connection.h b/src/connection.h index 6456b95..2e2da2d 100644 --- a/src/connection.h +++ b/src/connection.h @@ -11,5 +11,5 @@ struct BaseConnection { bool Open(); bool Close(); bool Write(const void* data, size_t length); - bool Read(void* data, size_t& length); + bool Read(void* data, size_t length); }; diff --git a/src/connection_win.cpp b/src/connection_win.cpp index afa4212..2decdc1 100644 --- a/src/connection_win.cpp +++ b/src/connection_win.cpp @@ -59,11 +59,7 @@ bool BaseConnection::Close() bool BaseConnection::Write(const void* data, size_t length) { auto self = reinterpret_cast(this); - BOOL success = ::WriteFile(self->pipe, data, length, nullptr, nullptr); - if (!success) { - self->Close(); - } - return success; + return ::WriteFile(self->pipe, data, length, nullptr, nullptr) == TRUE; } bool BaseConnection::Read(void* data, size_t length) diff --git a/src/discord-rpc.cpp b/src/discord-rpc.cpp index 9e09432..22df12a 100644 --- a/src/discord-rpc.cpp +++ b/src/discord-rpc.cpp @@ -5,14 +5,57 @@ #include "rapidjson/document.h" #include +#include +#include +#include +#include -static RpcConnection* MyConnection = nullptr; +static RpcConnection* Connection{nullptr}; static char ApplicationId[64]{}; static DiscordEventHandlers Handlers{}; -static bool WasJustConnected = false; -static bool WasJustDisconnected = false; +static std::atomic_bool WasJustConnected{false}; +static std::atomic_bool WasJustDisconnected{false}; static int LastErrorCode = 0; static char LastErrorMessage[256]; +static std::atomic_bool KeepRunning{true}; +static std::mutex WaitForIOMutex; +static std::condition_variable WaitForIOActivity; +static std::thread IoThread; + +void Discord_UpdateConnection() +{ + if (!Connection->IsOpen()) { + Connection->Open(); + } + else { + // reads + rapidjson::Document message; + while (Connection->Read(message)) { + // todo: do something... + printf("Hey, I got a message\n"); + } + } +} + +void DiscordRpcIo() +{ + printf("Discord io thread start\n"); + const std::chrono::duration maxWait{500LL}; + + while (KeepRunning.load()) { + Discord_UpdateConnection(); + + std::unique_lock lock(WaitForIOMutex); + WaitForIOActivity.wait_for(lock, maxWait); + } + Connection->Close(); + printf("Discord io thread stop\n"); +} + +void SignalIOActivity() +{ + WaitForIOActivity.notify_all(); +} extern "C" void Discord_Initialize(const char* applicationId, DiscordEventHandlers* handlers) { @@ -23,66 +66,49 @@ extern "C" void Discord_Initialize(const char* applicationId, DiscordEventHandle Handlers = {}; } - MyConnection = RpcConnection::Create(applicationId); - MyConnection->onConnect = []() { WasJustConnected = true; }; - MyConnection->onDisconnect = [](int err, const char* message) { + Connection = RpcConnection::Create(applicationId); + Connection->onConnect = []() { + WasJustConnected.exchange(true); + }; + Connection->onDisconnect = [](int err, const char* message) { LastErrorCode = err; StringCopy(LastErrorMessage, message, sizeof(LastErrorMessage)); - WasJustDisconnected = true; + WasJustDisconnected.exchange(true); }; - MyConnection->Open(); + + IoThread = std::thread(DiscordRpcIo); } extern "C" void Discord_Shutdown() { + Connection->onConnect = nullptr; + Connection->onDisconnect = nullptr; Handlers = {}; - MyConnection->onConnect = nullptr; - MyConnection->onDisconnect = nullptr; - MyConnection->Close(); - RpcConnection::Destroy(MyConnection); + KeepRunning.exchange(false); + SignalIOActivity(); + if (IoThread.joinable()) { + IoThread.join(); + } + RpcConnection::Destroy(Connection); } extern "C" void Discord_UpdatePresence(const DiscordRichPresence* presence) { - auto frame = MyConnection->GetNextFrame(); - frame->opcode = OPCODE::FRAME; - char* jsonWrite = frame->message; + char jsonBuffer[16 * 1024]; + char* jsonWrite = jsonBuffer; JsonWriteRichPresenceObj(jsonWrite, presence); - frame->length = jsonWrite - frame->message; - MyConnection->WriteFrame(frame); + size_t length = jsonWrite - jsonBuffer; + Connection->Write(jsonBuffer, length); + SignalIOActivity(); } extern "C" void Discord_Update() { - while (auto frame = MyConnection->Read()) { - rapidjson::Document d; - if (frame->length > 0) { - d.ParseInsitu(frame->message); - } - - switch (frame->opcode) { - case OPCODE::HANDSHAKE: - // does this happen? - break; - case OPCODE::CLOSE: - LastErrorCode = d["code"].GetInt(); - StringCopy(LastErrorMessage, d["code"].GetString(), sizeof(LastErrorMessage)); - MyConnection->Close(); - break; - case OPCODE::FRAME: - // todo - break; - } - } - - // fire callbacks - if (WasJustDisconnected && Handlers.disconnected) { - WasJustDisconnected = false; + if (WasJustDisconnected.exchange(false) && Handlers.disconnected) { Handlers.disconnected(LastErrorCode, LastErrorMessage); } - if (WasJustConnected && Handlers.ready) { - WasJustConnected = false; + if (WasJustConnected.exchange(false) && Handlers.ready) { Handlers.ready(); } } diff --git a/src/rpc_connection.cpp b/src/rpc_connection.cpp index df292c5..84c023a 100644 --- a/src/rpc_connection.cpp +++ b/src/rpc_connection.cpp @@ -1,6 +1,18 @@ #include "rpc_connection.h" +#include "yolojson.h" -RpcConnection Instance; +#include + +static const int RpcVersion = 1; +static RpcConnection Instance; +static const size_t SendQueueSize = 4; +static RpcConnection::MessageFrame SendQueue[SendQueueSize]; +static std::atomic_uint SendQueueNext = 0; + +static RpcConnection::MessageFrame* NextSendFrame() { + auto index = (SendQueueNext++) % SendQueueSize; + return &SendQueue[index]; +} /*static*/ RpcConnection* RpcConnection::Create(const char* applicationId) { @@ -11,6 +23,110 @@ RpcConnection Instance; /*static*/ void RpcConnection::Destroy(RpcConnection*& c) { + c->Close(); BaseConnection::Destroy(c->connection); } +void RpcConnection::Open() +{ + if (state == State::Connected) { + return; + } + + if (state == State::Disconnected) { + if (connection->Open()) { + state = State::Connecting; + } + else { + return; + } + } + + auto handshakeFrame = NextSendFrame(); + handshakeFrame->opcode = Opcode::Handshake; + char* json = handshakeFrame->message; + JsonWriteHandshakeObj(json, RpcVersion, appId); + handshakeFrame->length = json - handshakeFrame->message; + + if (connection->Write(handshakeFrame, sizeof(MessageFrameHeader) + handshakeFrame->length)) { + state = State::Connected; + if (onConnect) { + onConnect(); + } + } +} + +void RpcConnection::Close() +{ + if (onDisconnect && state == State::Connected) { + onDisconnect(lastErrorCode, lastErrorMessage); + } + connection->Close(); + state = State::Disconnected; +} + +void RpcConnection::Write(const void* data, size_t length) +{ + auto frame = NextSendFrame(); + frame->opcode = Opcode::Frame; + memcpy(frame->message, data, length); + frame->length = length; + if (!connection->Write(frame, sizeof(MessageFrameHeader) + length)) { + Close(); + } +} + +bool RpcConnection::Read(rapidjson::Document& message) +{ + if (state != State::Connected) { + return false; + } + MessageFrame readFrame; + for (;;) { + bool didRead = connection->Read(&readFrame, sizeof(MessageFrameHeader)); + if (!didRead) { + return false; + } + + if (readFrame.length > 0) { + didRead = connection->Read(readFrame.message, readFrame.length); + if (!didRead) { + lastErrorCode = -2; + StringCopy(lastErrorMessage, "Partial data in frame"); + Close(); + return false; + } + readFrame.message[readFrame.length] = 0; + message.ParseInsitu(readFrame.message); + } + + switch (readFrame.opcode) { + case Opcode::Close: + { + lastErrorCode = message["code"].GetInt(); + const auto& m = message["message"]; + StringCopy(lastErrorMessage, m.GetString(), sizeof(lastErrorMessage)); + Close(); + return false; + } + case Opcode::Frame: + return true; + case Opcode::Ping: + { + MessageFrameHeader frame{ Opcode::Pong, 0 }; + if (!connection->Write(&frame, sizeof(MessageFrameHeader))) { + Close(); + } + break; + } + case Opcode::Pong: + break; + default: + // something bad happened + lastErrorCode = -1; + StringCopy(lastErrorMessage, "Bad ipc frame"); + Close(); + return false; + } + } +} diff --git a/src/rpc_connection.h b/src/rpc_connection.h index 0013d4f..23cc5da 100644 --- a/src/rpc_connection.h +++ b/src/rpc_connection.h @@ -1,30 +1,49 @@ #pragma once #include "connection.h" +#include "rapidjson/document.h" struct RpcConnection { enum class Opcode : uint32_t { Handshake = 0, Frame = 1, Close = 2, + Ping = 3, + Pong = 4, }; - struct MessageFrame { + struct MessageFrameHeader { Opcode opcode; uint32_t length; - char message[64 * 1024 - 8]; + }; + + struct MessageFrame : public MessageFrameHeader { + char message[64 * 1024 - sizeof(MessageFrameHeader)]; + }; + + enum class State : uint32_t { + Disconnected, + Connecting, + Connected, }; BaseConnection* connection{nullptr}; + State state{State::Disconnected}; void (*onConnect)(){nullptr}; void (*onDisconnect)(int errorCode, const char* message){nullptr}; char appId[64]{}; + int lastErrorCode{0}; + char lastErrorMessage[256]{}; static RpcConnection* Create(const char* applicationId); static void Destroy(RpcConnection*&); + inline bool IsOpen() const { + return state == State::Connected; + } + void Open(); void Close(); void Write(const void* data, size_t length); - bool Read(void* data, size_t& length); -}; \ No newline at end of file + bool Read(rapidjson::Document& message); +};