Service/UDS: Updated BeginHostingNetwork

This commit is contained in:
B3n30 2017-10-04 09:18:21 +02:00
parent f6d16c3f87
commit ed9db735a2
3 changed files with 104 additions and 30 deletions

View file

@ -4,6 +4,7 @@
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <atomic>
#include <cstring> #include <cstring>
#include <list> #include <list>
#include <mutex> #include <mutex>
@ -27,6 +28,12 @@
namespace Service { namespace Service {
namespace NWM { namespace NWM {
namespace ErrCodes {
enum {
NotInitialized = 2,
};
} // namespace ErrCodes
// Event that is signaled every time the connection status changes. // Event that is signaled every time the connection status changes.
static Kernel::SharedPtr<Kernel::Event> connection_status_event; static Kernel::SharedPtr<Kernel::Event> connection_status_event;
@ -37,6 +44,8 @@ static Kernel::SharedPtr<Kernel::SharedMemory> recv_buffer_memory;
// Connection status of this 3DS. // Connection status of this 3DS.
static ConnectionStatus connection_status{}; static ConnectionStatus connection_status{};
static std::atomic<bool> initialized(false);
/* Node information about the current network. /* Node information about the current network.
* The amount of elements in this vector is always the maximum number * The amount of elements in this vector is always the maximum number
* of nodes specified in the network configuration. * of nodes specified in the network configuration.
@ -155,7 +164,7 @@ void HandleAssociationResponseFrame(const Network::WifiPacket& packet) {
"Could not join network"); "Could not join network");
{ {
std::lock_guard<std::mutex> lock(connection_status_mutex); std::lock_guard<std::mutex> lock(connection_status_mutex);
ASSERT(connection_status.status == static_cast<u32>(NetworkStatus::NotConnected)); ASSERT(connection_status.status == static_cast<u32>(NetworkStatus::Connecting));
} }
// Send the EAPoL-Start packet to the server. // Send the EAPoL-Start packet to the server.
@ -171,8 +180,9 @@ void HandleAssociationResponseFrame(const Network::WifiPacket& packet) {
} }
static void HandleEAPoLPacket(const Network::WifiPacket& packet) { static void HandleEAPoLPacket(const Network::WifiPacket& packet) {
std::lock_guard<std::recursive_mutex> hle_lock(HLE::g_hle_lock); std::unique_lock<std::recursive_mutex> hle_lock(HLE::g_hle_lock, std::defer_lock);
std::lock_guard<std::mutex> lock(connection_status_mutex); std::unique_lock<std::mutex> lock(connection_status_mutex, std::defer_lock);
std::lock(hle_lock, lock);
if (GetEAPoLFrameType(packet.data) == EAPoLStartMagic) { if (GetEAPoLFrameType(packet.data) == EAPoLStartMagic) {
if (connection_status.status != static_cast<u32>(NetworkStatus::ConnectedAsHost)) { if (connection_status.status != static_cast<u32>(NetworkStatus::ConnectedAsHost)) {
@ -220,7 +230,7 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) {
// The 3ds does this presumably to support spectators. // The 3ds does this presumably to support spectators.
connection_status_event->Signal(); connection_status_event->Signal();
} else { } else {
if (connection_status.status != static_cast<u32>(NetworkStatus::NotConnected)) { if (connection_status.status != static_cast<u32>(NetworkStatus::Connecting)) {
LOG_DEBUG(Service_NWM, "Connection sequence aborted, because connection status is %u", LOG_DEBUG(Service_NWM, "Connection sequence aborted, because connection status is %u",
connection_status.status); connection_status.status);
return; return;
@ -249,15 +259,15 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) {
// Some games require ConnectToNetwork to block, for now it doesn't // Some games require ConnectToNetwork to block, for now it doesn't
// If blocking is implemented this lock needs to be changed, // If blocking is implemented this lock needs to be changed,
// otherwise it might cause deadlocks // otherwise it might cause deadlocks
std::lock_guard<std::recursive_mutex> lock(HLE::g_hle_lock);
connection_status_event->Signal(); connection_status_event->Signal();
} }
} }
static void HandleSecureDataPacket(const Network::WifiPacket& packet) { static void HandleSecureDataPacket(const Network::WifiPacket& packet) {
auto secure_data = ParseSecureDataHeader(packet.data); auto secure_data = ParseSecureDataHeader(packet.data);
std::lock_guard<std::recursive_mutex> hle_lock(HLE::g_hle_lock); std::unique_lock<std::recursive_mutex> hle_lock(HLE::g_hle_lock, std::defer_lock);
std::lock_guard<std::mutex> lock(connection_status_mutex); std::unique_lock<std::mutex> lock(connection_status_mutex, std::defer_lock);
std::lock(hle_lock, lock);
if (secure_data.src_node_id == connection_status.network_node_id) { if (secure_data.src_node_id == connection_status.network_node_id) {
// Ignore packets that came from ourselves. // Ignore packets that came from ourselves.
@ -315,7 +325,7 @@ void StartConnectionSequence(const MacAddress& server) {
WifiPacket auth_request; WifiPacket auth_request;
{ {
std::lock_guard<std::mutex> lock(connection_status_mutex); std::lock_guard<std::mutex> lock(connection_status_mutex);
ASSERT(connection_status.status == static_cast<u32>(NetworkStatus::NotConnected)); connection_status.status = static_cast<u32>(NetworkStatus::Connecting);
// TODO(Subv): Handle timeout. // TODO(Subv): Handle timeout.
@ -546,6 +556,8 @@ static void InitializeWithVersion(Interface* self) {
recv_buffer_memory = Kernel::g_handle_table.Get<Kernel::SharedMemory>(sharedmem_handle); recv_buffer_memory = Kernel::g_handle_table.Get<Kernel::SharedMemory>(sharedmem_handle);
initialized = true;
ASSERT_MSG(recv_buffer_memory->size == sharedmem_size, "Invalid shared memory size."); ASSERT_MSG(recv_buffer_memory->size == sharedmem_size, "Invalid shared memory size.");
{ {
@ -614,8 +626,12 @@ static void GetNodeInformation(Interface* self) {
IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0xD, 1, 0); IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0xD, 1, 0);
u16 network_node_id = rp.Pop<u16>(); u16 network_node_id = rp.Pop<u16>();
IPC::RequestBuilder rb = rp.MakeBuilder(11, 0); if (!initialized) {
rb.Push(RESULT_SUCCESS); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
rb.Push(ResultCode(ErrorDescription::NotInitialized, ErrorModule::UDS,
ErrorSummary::StatusChanged, ErrorLevel::Status));
return;
}
{ {
std::lock_guard<std::mutex> lock(connection_status_mutex); std::lock_guard<std::mutex> lock(connection_status_mutex);
@ -623,7 +639,15 @@ static void GetNodeInformation(Interface* self) {
[network_node_id](const NodeInfo& node) { [network_node_id](const NodeInfo& node) {
return node.network_node_id == network_node_id; return node.network_node_id == network_node_id;
}); });
ASSERT(itr != node_info.end()); if (itr == node_info.end()) {
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
rb.Push(ResultCode(ErrorDescription::NotFound, ErrorModule::UDS,
ErrorSummary::WrongArgument, ErrorLevel::Status));
return;
}
IPC::RequestBuilder rb = rp.MakeBuilder(11, 0);
rb.Push(RESULT_SUCCESS);
rb.PushRaw<NodeInfo>(*itr); rb.PushRaw<NodeInfo>(*itr);
} }
LOG_DEBUG(Service_NWM, "called"); LOG_DEBUG(Service_NWM, "called");
@ -653,13 +677,29 @@ static void Bind(Interface* self) {
LOG_DEBUG(Service_NWM, "called"); LOG_DEBUG(Service_NWM, "called");
if (data_channel == 0) { if (data_channel == 0 || bind_node_id == 0) {
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS, rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS,
ErrorSummary::WrongArgument, ErrorLevel::Usage)); ErrorSummary::WrongArgument, ErrorLevel::Usage));
return; return;
} }
constexpr size_t MaxBindNodes = 16;
if (channel_data.size() >= MaxBindNodes) {
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
rb.Push(ResultCode(ErrorDescription::OutOfMemory, ErrorModule::UDS,
ErrorSummary::OutOfResource, ErrorLevel::Status));
return;
}
constexpr u32 MinRecvBufferSize = 0x5F4;
if (recv_buffer_size < MinRecvBufferSize) {
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
rb.Push(ResultCode(ErrorDescription::TooLarge, ErrorModule::UDS,
ErrorSummary::WrongArgument, ErrorLevel::Usage));
return;
}
// Create a new event for this bind node. // Create a new event for this bind node.
auto event = Kernel::Event::Create(Kernel::ResetType::OneShot, auto event = Kernel::Event::Create(Kernel::ResetType::OneShot,
"NWM::BindNodeEvent" + std::to_string(bind_node_id)); "NWM::BindNodeEvent" + std::to_string(bind_node_id));
@ -687,6 +727,12 @@ static void Unbind(Interface* self) {
IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0x12, 1, 0); IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0x12, 1, 0);
u32 bind_node_id = rp.Pop<u32>(); u32 bind_node_id = rp.Pop<u32>();
if (bind_node_id == 0) {
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS,
ErrorSummary::WrongArgument, ErrorLevel::Usage));
return;
}
std::lock_guard<std::mutex> lock(connection_status_mutex); std::lock_guard<std::mutex> lock(connection_status_mutex);
@ -699,8 +745,13 @@ static void Unbind(Interface* self) {
channel_data.erase(itr); channel_data.erase(itr);
} }
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); IPC::RequestBuilder rb = rp.MakeBuilder(5, 0);
rb.Push(RESULT_SUCCESS); rb.Push(RESULT_SUCCESS);
rb.Push(bind_node_id);
// TODO(B3N30): Find out what the other return values are
rb.Push<u32>(0);
rb.Push<u32>(0);
rb.Push<u32>(0);
} }
/** /**
@ -729,13 +780,14 @@ static void BeginHostingNetwork(Interface* self) {
LOG_DEBUG(Service_NWM, "called"); LOG_DEBUG(Service_NWM, "called");
{
std::lock_guard<std::mutex> lock(connection_status_mutex);
Memory::ReadBlock(network_info_address, &network_info, sizeof(NetworkInfo)); Memory::ReadBlock(network_info_address, &network_info, sizeof(NetworkInfo));
// The real UDS module throws a fatal error if this assert fails. // The real UDS module throws a fatal error if this assert fails.
ASSERT_MSG(network_info.max_nodes > 1, "Trying to host a network of only one member."); ASSERT_MSG(network_info.max_nodes > 1, "Trying to host a network of only one member.");
{
std::lock_guard<std::mutex> lock(connection_status_mutex);
connection_status.status = static_cast<u32>(NetworkStatus::ConnectedAsHost); connection_status.status = static_cast<u32>(NetworkStatus::ConnectedAsHost);
// Ensure the application data size is less than the maximum value. // Ensure the application data size is less than the maximum value.
@ -749,11 +801,13 @@ static void BeginHostingNetwork(Interface* self) {
connection_status.max_nodes = network_info.max_nodes; connection_status.max_nodes = network_info.max_nodes;
// Resize the nodes list to hold max_nodes. // Resize the nodes list to hold max_nodes.
node_info.clear();
node_info.resize(network_info.max_nodes); node_info.resize(network_info.max_nodes);
// There's currently only one node in the network (the host). // There's currently only one node in the network (the host).
connection_status.total_nodes = 1; connection_status.total_nodes = 1;
network_info.total_nodes = 1; network_info.total_nodes = 1;
// The host is always the first node // The host is always the first node
connection_status.network_node_id = 1; connection_status.network_node_id = 1;
current_node.network_node_id = 1; current_node.network_node_id = 1;
@ -762,12 +816,22 @@ static void BeginHostingNetwork(Interface* self) {
connection_status.node_bitmask |= 1; connection_status.node_bitmask |= 1;
// Notify the application that the first node was set. // Notify the application that the first node was set.
connection_status.changed_nodes |= 1; connection_status.changed_nodes |= 1;
node_info[0] = current_node;
if (auto room_member = Network::GetRoomMember().lock()) {
if (room_member->IsConnected()) {
network_info.host_mac_address = room_member->GetMacAddress();
} else {
network_info.host_mac_address = {{0x0, 0x0, 0x0, 0x0, 0x0, 0x0}};
} }
}
node_info[0] = current_node;
// If the game has a preferred channel, use that instead. // If the game has a preferred channel, use that instead.
if (network_info.channel != 0) if (network_info.channel != 0)
network_channel = network_info.channel; network_channel = network_info.channel;
else
network_info.channel = DefaultNetworkChannel;
}
connection_status_event->Signal(); connection_status_event->Signal();
@ -775,8 +839,7 @@ static void BeginHostingNetwork(Interface* self) {
CoreTiming::ScheduleEvent(msToCycles(DefaultBeaconInterval * MillisecondsPerTU), CoreTiming::ScheduleEvent(msToCycles(DefaultBeaconInterval * MillisecondsPerTU),
beacon_broadcast_event, 0); beacon_broadcast_event, 0);
LOG_WARNING(Service_NWM, LOG_DEBUG(Service_NWM, "An UDS network has been created.");
"An UDS network has been created, but broadcasting it is unimplemented.");
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
rb.Push(RESULT_SUCCESS); rb.Push(RESULT_SUCCESS);
@ -929,6 +992,14 @@ static void PullPacket(Interface* self) {
ASSERT(desc_size == max_out_buff_size); ASSERT(desc_size == max_out_buff_size);
std::lock_guard<std::mutex> lock(connection_status_mutex); std::lock_guard<std::mutex> lock(connection_status_mutex);
if (connection_status.status != static_cast<u32>(NetworkStatus::ConnectedAsHost) &&
connection_status.status != static_cast<u32>(NetworkStatus::ConnectedAsClient) &&
connection_status.status != static_cast<u32>(NetworkStatus::ConnectedAsSpectator)) {
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS,
ErrorSummary::InvalidState, ErrorLevel::Status));
return;
}
auto channel = auto channel =
std::find_if(channel_data.begin(), channel_data.end(), [bind_node_id](const auto& data) { std::find_if(channel_data.begin(), channel_data.end(), [bind_node_id](const auto& data) {
@ -937,8 +1008,8 @@ static void PullPacket(Interface* self) {
if (channel == channel_data.end()) { if (channel == channel_data.end()) {
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
// TODO(B3N30): Find the right error code rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS,
rb.Push<u32>(-1); ErrorSummary::WrongArgument, ErrorLevel::Usage));
return; return;
} }
@ -959,7 +1030,8 @@ static void PullPacket(Interface* self) {
if (data_size > max_out_buff_size) { if (data_size > max_out_buff_size) {
IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
rb.Push<u32>(0xE10113E9); rb.Push(ResultCode(ErrorDescription::TooLarge, ErrorModule::UDS,
ErrorSummary::WrongArgument, ErrorLevel::Usage));
return; return;
} }
@ -1225,6 +1297,7 @@ NWM_UDS::~NWM_UDS() {
channel_data.clear(); channel_data.clear();
connection_status_event = nullptr; connection_status_event = nullptr;
recv_buffer_memory = nullptr; recv_buffer_memory = nullptr;
initialized = false;
{ {
std::lock_guard<std::mutex> lock(connection_status_mutex); std::lock_guard<std::mutex> lock(connection_status_mutex);

View file

@ -32,7 +32,7 @@ struct NodeInfo {
std::array<u16_le, 10> username; std::array<u16_le, 10> username;
INSERT_PADDING_BYTES(4); INSERT_PADDING_BYTES(4);
u16_le network_node_id; u16_le network_node_id;
std::array<u8, 6> address; INSERT_PADDING_BYTES(6);
}; };
static_assert(sizeof(NodeInfo) == 40, "NodeInfo has incorrect size."); static_assert(sizeof(NodeInfo) == 40, "NodeInfo has incorrect size.");
@ -42,6 +42,7 @@ using NodeList = std::vector<NodeInfo>;
enum class NetworkStatus { enum class NetworkStatus {
NotConnected = 3, NotConnected = 3,
ConnectedAsHost = 6, ConnectedAsHost = 6,
Connecting = 7,
ConnectedAsClient = 9, ConnectedAsClient = 9,
ConnectedAsSpectator = 10, ConnectedAsSpectator = 10,
}; };

View file

@ -52,7 +52,7 @@ struct SecureDataHeader {
u16_be dest_node_id; u16_be dest_node_id;
u16_be src_node_id; u16_be src_node_id;
u32 GetActualDataSize() { u32 GetActualDataSize() const {
return protocol_size - sizeof(SecureDataHeader); return protocol_size - sizeof(SecureDataHeader);
} }
}; };