diff --git a/src/core/hle/service/nwm/nwm_uds.cpp b/src/core/hle/service/nwm/nwm_uds.cpp index 0aa63cc1e..d2baf249c 100644 --- a/src/core/hle/service/nwm/nwm_uds.cpp +++ b/src/core/hle/service/nwm/nwm_uds.cpp @@ -47,8 +47,17 @@ static NodeList node_info; // Node information about our own system. static NodeInfo current_node; -// Mapping of bind node ids to their respective events. -static std::unordered_map> bind_node_events; +struct BindNodeData { + u32 bind_node_id; ///< Id of the bind node associated with this data. + u8 channel; ///< Channel that this bind node was bound to. + u16 network_node_id; ///< Node id this bind node is associated with, only packets from this + /// network node will be received. + Kernel::SharedPtr event; ///< Receive event for this bind node. + std::deque> received_packets; ///< List of packets received on this channel. +}; + +// Mapping of data channels to their internal data. +static std::unordered_map channel_data; // The WiFi network channel that the network is currently on. // Since we're not actually interacting with physical radio waves, this is just a dummy value. @@ -75,6 +84,9 @@ constexpr size_t MaxBeaconFrames = 15; // List of the last beacons received from the network. static std::list received_beacons; +// Network node id used when a SecureData packet is addressed to every connected node. +constexpr u16 BroadcastNetworkNodeId = 0xFFFF; + /** * Returns a list of received 802.11 beacon frames from the specified sender since the last call. */ @@ -143,7 +155,7 @@ void HandleAssociationResponseFrame(const Network::WifiPacket& packet) { "Could not join network"); { std::lock_guard lock(connection_status_mutex); - ASSERT(connection_status.status == static_cast(NetworkStatus::Connecting)); + ASSERT(connection_status.status == static_cast(NetworkStatus::NotConnected)); } // Send the EAPoL-Start packet to the server. @@ -159,6 +171,7 @@ void HandleAssociationResponseFrame(const Network::WifiPacket& packet) { } static void HandleEAPoLPacket(const Network::WifiPacket& packet) { + std::lock_guard hle_lock(HLE::g_hle_lock); std::lock_guard lock(connection_status_mutex); if (GetEAPoLFrameType(packet.data) == EAPoLStartMagic) { @@ -205,7 +218,6 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) { SendPacket(eapol_logoff); // TODO(B3N30): Broadcast updated node list // The 3ds does this presumably to support spectators. - std::lock_guard lock(HLE::g_hle_lock); connection_status_event->Signal(); } else { if (connection_status.status != static_cast(NetworkStatus::NotConnected)) { @@ -242,6 +254,58 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) { } } +static void HandleSecureDataPacket(const Network::WifiPacket& packet) { + auto secure_data = ParseSecureDataHeader(packet.data); + std::lock_guard hle_lock(HLE::g_hle_lock); + std::lock_guard lock(connection_status_mutex); + + if (secure_data.src_node_id == connection_status.network_node_id) { + // Ignore packets that came from ourselves. + return; + } + + if (secure_data.dest_node_id != connection_status.network_node_id && + secure_data.dest_node_id != BroadcastNetworkNodeId) { + // The packet wasn't addressed to us, we can only act as a router if we're the host. + // However, we might have received this packet due to a broadcast from the host, in that + // case just ignore it. + ASSERT_MSG(packet.destination_address == Network::BroadcastMac || + connection_status.status == static_cast(NetworkStatus::ConnectedAsHost), + "Can't be a router if we're not a host"); + + if (connection_status.status == static_cast(NetworkStatus::ConnectedAsHost) && + secure_data.dest_node_id != BroadcastNetworkNodeId) { + // Broadcast the packet so the right receiver can get it. + // TODO(B3N30): Is there a flag that makes this kind of routing be unicast instead of + // multicast? Perhaps this is a way to allow spectators to see some of the packets. + Network::WifiPacket out_packet = packet; + out_packet.destination_address = Network::BroadcastMac; + SendPacket(out_packet); + } + return; + } + + // The packet is addressed to us (or to everyone using the broadcast node id), handle it. + // TODO(B3N30): We don't currently send nor handle management frames. + ASSERT(!secure_data.is_management); + + // TODO(B3N30): Allow more than one bind node per channel. + auto channel_info = channel_data.find(secure_data.data_channel); + // Ignore packets from channels we're not interested in. + if (channel_info == channel_data.end()) + return; + + if (channel_info->second.network_node_id != BroadcastNetworkNodeId && + channel_info->second.network_node_id != secure_data.src_node_id) + return; + + // Add the received packet to the data queue. + channel_info->second.received_packets.emplace_back(packet.data); + + // Signal the data event. We can do this directly because we locked g_hle_lock + channel_info->second.event->Signal(); +} + /* * Start a connection sequence with an UDS server. The sequence starts by sending an 802.11 * authentication frame with SEQ1. @@ -329,7 +393,7 @@ static void HandleDataFrame(const Network::WifiPacket& packet) { HandleEAPoLPacket(packet); break; case EtherType::SecureData: - // TODO(B3N30): Handle SecureData packets + HandleSecureDataPacket(packet); break; } } @@ -557,8 +621,6 @@ static void Bind(Interface* self) { u8 data_channel = rp.Pop(); u16 network_node_id = rp.Pop(); - // TODO(Subv): Store the data channel and verify it when receiving data frames. - LOG_DEBUG(Service_NWM, "called"); if (data_channel == 0) { @@ -569,13 +631,15 @@ static void Bind(Interface* self) { } // Create a new event for this bind node. - // TODO(Subv): Signal this event when new data is received on this data channel. auto event = Kernel::Event::Create(Kernel::ResetType::OneShot, "NWM::BindNodeEvent" + std::to_string(bind_node_id)); - bind_node_events[bind_node_id] = event; + std::lock_guard lock(connection_status_mutex); + + ASSERT(channel_data.find(data_channel) == channel_data.end()); + // TODO(B3N30): Support more than one bind node per channel. + channel_data[data_channel] = {bind_node_id, data_channel, network_node_id, event}; IPC::RequestBuilder rb = rp.MakeBuilder(1, 2); - rb.Push(RESULT_SUCCESS); rb.PushCopyHandles(Kernel::g_handle_table.Create(event).Unwrap()); } @@ -722,31 +786,25 @@ static void SendTo(Interface* self) { size_t desc_size; const VAddr input_address = rp.PopStaticBuffer(&desc_size, false); - ASSERT(desc_size == data_size); + ASSERT(desc_size >= data_size); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); - u16 network_node_id; - - { - std::lock_guard lock(connection_status_mutex); - if (connection_status.status != static_cast(NetworkStatus::ConnectedAsClient) && - connection_status.status != static_cast(NetworkStatus::ConnectedAsHost)) { - rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS, - ErrorSummary::InvalidState, ErrorLevel::Status)); - return; - } - - if (dest_node_id == connection_status.network_node_id) { - rb.Push(ResultCode(ErrorDescription::NotFound, ErrorModule::UDS, - ErrorSummary::WrongArgument, ErrorLevel::Status)); - return; - } - - network_node_id = connection_status.network_node_id; + std::lock_guard lock(connection_status_mutex); + if (connection_status.status != static_cast(NetworkStatus::ConnectedAsClient) && + connection_status.status != static_cast(NetworkStatus::ConnectedAsHost)) { + rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS, + ErrorSummary::InvalidState, ErrorLevel::Status)); + return; } - // TODO(Subv): Do something with the flags. + if (dest_node_id == connection_status.network_node_id) { + rb.Push(ResultCode(ErrorDescription::NotFound, ErrorModule::UDS, + ErrorSummary::WrongArgument, ErrorLevel::Status)); + return; + } + + // TODO(B3N30): Do something with the flags. constexpr size_t MaxSize = 0x5C6; if (data_size > MaxSize) { @@ -758,20 +816,107 @@ static void SendTo(Interface* self) { std::vector data(data_size); Memory::ReadBlock(input_address, data.data(), data.size()); - // TODO(Subv): Increment the sequence number after each sent packet. + // TODO(B3N30): Increment the sequence number after each sent packet. u16 sequence_number = 0; - std::vector data_payload = - GenerateDataPayload(data, data_channel, dest_node_id, network_node_id, sequence_number); + std::vector data_payload = GenerateDataPayload( + data, data_channel, dest_node_id, connection_status.network_node_id, sequence_number); - // TODO(Subv): Retrieve the MAC address of the dest_node_id and our own to encrypt + // TODO(B3N30): Retrieve the MAC address of the dest_node_id and our own to encrypt // and encapsulate the payload. - // TODO(Subv): Send the frame. + Network::WifiPacket packet; + // Data frames are sent to the host, who then decides where to route it to. If we're the host, + // just directly broadcast the frame. + if (connection_status.status == static_cast(NetworkStatus::ConnectedAsHost)) + packet.destination_address = Network::BroadcastMac; + else + packet.destination_address = network_info.host_mac_address; + packet.channel = network_channel; + packet.data = std::move(data_payload); + packet.type = Network::WifiPacket::PacketType::Data; + + SendPacket(packet); rb.Push(RESULT_SUCCESS); +} - LOG_WARNING(Service_NWM, "(STUB) called dest_node_id=%u size=%u flags=%u channel=%u", - static_cast(dest_node_id), data_size, flags, static_cast(data_channel)); +/** + * NWM_UDS::PullPacket service function. + * Receives a data frame from the specified bind node id + * Inputs: + * 0 : Command header. + * 1 : Bind node id. + * 2 : Max out buff size >> 2. + * 3 : Max out buff size. + * 64 : Output buffer descriptor + * 65 : Output buffer address + * Outputs: + * 0 : Return header + * 1 : Result of function, 0 on success, otherwise error code + * 2 : Received data size + * 3 : u16 Source network node id + * 4 : Buffer descriptor + * 5 : Buffer address + */ +static void PullPacket(Interface* self) { + IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0x14, 3, 0); + + u32 bind_node_id = rp.Pop(); + u32 max_out_buff_size_aligned = rp.Pop(); + u32 max_out_buff_size = rp.Pop(); + + size_t desc_size; + const VAddr output_address = rp.PeekStaticBuffer(0, &desc_size); + ASSERT(desc_size == max_out_buff_size); + + std::lock_guard lock(connection_status_mutex); + + auto channel = + std::find_if(channel_data.begin(), channel_data.end(), [bind_node_id](const auto& data) { + return data.second.bind_node_id == bind_node_id; + }); + + if (channel == channel_data.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + // TODO(B3N30): Find the right error code + rb.Push(-1); + return; + } + + if (channel->second.received_packets.empty()) { + Memory::ZeroBlock(output_address, desc_size); + IPC::RequestBuilder rb = rp.MakeBuilder(3, 2); + rb.Push(RESULT_SUCCESS); + rb.Push(0); + rb.Push(0); + rb.PushStaticBuffer(output_address, desc_size, 0); + return; + } + + const auto& next_packet = channel->second.received_packets.front(); + + auto secure_data = ParseSecureDataHeader(next_packet); + auto data_size = secure_data.GetActualDataSize(); + + if (data_size > max_out_buff_size) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(0xE10113E9); + return; + } + + IPC::RequestBuilder rb = rp.MakeBuilder(3, 2); + Memory::ZeroBlock(output_address, desc_size); + // Write the actual data. + Memory::WriteBlock(output_address, + next_packet.data() + sizeof(LLCHeader) + sizeof(SecureDataHeader), + data_size); + + rb.Push(RESULT_SUCCESS); + rb.Push(data_size); + rb.Push(secure_data.src_node_id); + rb.PushStaticBuffer(output_address, desc_size, 0); + + channel->second.received_packets.pop_front(); } /** @@ -993,7 +1138,7 @@ const Interface::FunctionInfo FunctionTable[] = { {0x00110040, nullptr, "GetApplicationData"}, {0x00120100, Bind, "Bind"}, {0x00130040, nullptr, "Unbind"}, - {0x001400C0, nullptr, "PullPacket"}, + {0x001400C0, PullPacket, "PullPacket"}, {0x00150080, nullptr, "SetMaxSendDelay"}, {0x00170182, SendTo, "SendTo"}, {0x001A0000, GetChannel, "GetChannel"}, @@ -1018,7 +1163,7 @@ NWM_UDS::NWM_UDS() { NWM_UDS::~NWM_UDS() { network_info = {}; - bind_node_events.clear(); + channel_data.clear(); connection_status_event = nullptr; recv_buffer_memory = nullptr; diff --git a/src/core/hle/service/nwm/nwm_uds.h b/src/core/hle/service/nwm/nwm_uds.h index f1caaf974..5508959fc 100644 --- a/src/core/hle/service/nwm/nwm_uds.h +++ b/src/core/hle/service/nwm/nwm_uds.h @@ -42,7 +42,6 @@ using NodeList = std::vector; enum class NetworkStatus { NotConnected = 3, ConnectedAsHost = 6, - Connecting = 7, ConnectedAsClient = 9, ConnectedAsSpectator = 10, }; diff --git a/src/core/hle/service/nwm/uds_data.cpp b/src/core/hle/service/nwm/uds_data.cpp index 4b389710f..8f0743819 100644 --- a/src/core/hle/service/nwm/uds_data.cpp +++ b/src/core/hle/service/nwm/uds_data.cpp @@ -275,6 +275,15 @@ std::vector GenerateDataPayload(const std::vector& data, u8 channel, u16 return buffer; } +SecureDataHeader ParseSecureDataHeader(const std::vector& data) { + SecureDataHeader header; + + // Skip the LLC header + std::memcpy(&header, data.data() + sizeof(LLCHeader), sizeof(header)); + + return header; +} + std::vector GenerateEAPoLStartFrame(u16 association_id, const NodeInfo& node_info) { EAPoLStartPacket eapol_start{}; eapol_start.association_id = association_id; diff --git a/src/core/hle/service/nwm/uds_data.h b/src/core/hle/service/nwm/uds_data.h index 76bccb1bf..4161025a9 100644 --- a/src/core/hle/service/nwm/uds_data.h +++ b/src/core/hle/service/nwm/uds_data.h @@ -51,6 +51,10 @@ struct SecureDataHeader { u16_be sequence_number; u16_be dest_node_id; u16_be src_node_id; + + u32 GetActualDataSize() { + return protocol_size - sizeof(SecureDataHeader); + } }; static_assert(sizeof(SecureDataHeader) == 14, "SecureDataHeader has the wrong size"); @@ -118,6 +122,11 @@ static_assert(sizeof(EAPoLLogoffPacket) == 0x298, "EAPoLLogoffPacket has the wro std::vector GenerateDataPayload(const std::vector& data, u8 channel, u16 dest_node, u16 src_node, u16 sequence_number); +/* + * Returns the SecureDataHeader stored in an 802.11 data frame. + */ +SecureDataHeader ParseSecureDataHeader(const std::vector& data); + /* * Generates an unencrypted 802.11 data frame body with the EAPoL-Start format for UDS * communication.