1
0
mirror of https://github.com/RPCS3/rpcs3.git synced 2024-11-21 18:22:33 +01:00

sys_net: improvements

sys_net: implement reusable addr for p2p sockets
sys_net: implement getpeername for P2PS
sys_net: sockets inherit NBIO from their parent
This commit is contained in:
RipleyTom 2022-10-21 00:47:27 +02:00 committed by Megamouse
parent 5f1aafb961
commit 9b121a6414
10 changed files with 275 additions and 118 deletions

View File

@ -1272,7 +1272,7 @@ error_code sceNpBasicAddFriend(vm::cptr<SceNpId> contact, vm::cptr<char> body, s
error_code sceNpBasicGetFriendListEntryCount(vm::ptr<u32> count) error_code sceNpBasicGetFriendListEntryCount(vm::ptr<u32> count)
{ {
sceNp.warning("sceNpBasicGetFriendListEntryCount(count=*0x%x)", count); sceNp.trace("sceNpBasicGetFriendListEntryCount(count=*0x%x)", count);
auto& nph = g_fxo->get<named_thread<np::np_handler>>(); auto& nph = g_fxo->get<named_thread<np::np_handler>>();
@ -1299,7 +1299,7 @@ error_code sceNpBasicGetFriendListEntryCount(vm::ptr<u32> count)
error_code sceNpBasicGetFriendListEntry(u32 index, vm::ptr<SceNpId> npid) error_code sceNpBasicGetFriendListEntry(u32 index, vm::ptr<SceNpId> npid)
{ {
sceNp.warning("sceNpBasicGetFriendListEntry(index=%d, npid=*0x%x)", index, npid); sceNp.trace("sceNpBasicGetFriendListEntry(index=%d, npid=*0x%x)", index, npid);
auto& nph = g_fxo->get<named_thread<np::np_handler>>(); auto& nph = g_fxo->get<named_thread<np::np_handler>>();

View File

@ -48,6 +48,12 @@ public:
sys_net_linger linger; sys_net_linger linger;
}; };
struct sockopt_cache
{
sockopt_data data{};
s32 len = 0;
};
public: public:
SAVESTATE_INIT_POS(7); // Dependency on RPCN SAVESTATE_INIT_POS(7); // Dependency on RPCN
@ -114,7 +120,7 @@ protected:
lv2_socket(utils::serial&, bool); lv2_socket(utils::serial&, bool);
shared_mutex mutex; shared_mutex mutex;
u32 lv2_id = 0; s32 lv2_id = 0;
socket_type socket = 0; socket_type socket = 0;

View File

@ -107,6 +107,9 @@ std::tuple<bool, s32, std::shared_ptr<lv2_socket>, sys_net_sockaddr> lv2_socket_
auto newsock = std::make_shared<lv2_socket_native>(family, type, protocol); auto newsock = std::make_shared<lv2_socket_native>(family, type, protocol);
newsock->set_socket(native_socket, family, type, protocol); newsock->set_socket(native_socket, family, type, protocol);
// Sockets inherit non blocking behaviour from their parent
newsock->so_nbio = so_nbio;
sys_net_sockaddr ps3_addr = native_addr_to_sys_net_addr(native_addr); sys_net_sockaddr ps3_addr = native_addr_to_sys_net_addr(native_addr);
return {true, 0, std::move(newsock), ps3_addr}; return {true, 0, std::move(newsock), ps3_addr};

View File

@ -14,7 +14,7 @@ lv2_socket_p2p::lv2_socket_p2p(lv2_socket_family family, lv2_socket_type type, l
lv2_socket_p2p::lv2_socket_p2p(utils::serial& ar, lv2_socket_type type) lv2_socket_p2p::lv2_socket_p2p(utils::serial& ar, lv2_socket_type type)
: lv2_socket(ar, type) : lv2_socket(ar, type)
{ {
ar(port, vport); ar(port, vport, bound_addr);
std::deque<std::pair<sys_net_sockaddr_in_p2p, std::vector<u8>>> data_dequeue{ar}; std::deque<std::pair<sys_net_sockaddr_in_p2p, std::vector<u8>>> data_dequeue{ar};
@ -29,7 +29,7 @@ lv2_socket_p2p::lv2_socket_p2p(utils::serial& ar, lv2_socket_type type)
void lv2_socket_p2p::save(utils::serial& ar) void lv2_socket_p2p::save(utils::serial& ar)
{ {
static_cast<lv2_socket*>(this)->save(ar, true); static_cast<lv2_socket*>(this)->save(ar, true);
ar(port, vport); ar(port, vport, bound_addr);
std::deque<std::pair<sys_net_sockaddr_in_p2p, std::vector<u8>>> data_dequeue; std::deque<std::pair<sys_net_sockaddr_in_p2p, std::vector<u8>>> data_dequeue;
@ -140,15 +140,22 @@ s32 lv2_socket_p2p::bind(const sys_net_sockaddr& addr)
p2p_vport++; p2p_vport++;
} }
} }
else else if (pport.bound_p2p_vports.contains(p2p_vport))
{ {
if (pport.bound_p2p_vports.contains(p2p_vport)) // Check that all other sockets are SO_REUSEADDR or SO_REUSEPORT
auto& bound_sockets = ::at32(pport.bound_p2p_vports, p2p_vport);
if (!sys_net_helpers::all_reusable(bound_sockets))
{ {
return -SYS_NET_EADDRINUSE; return -SYS_NET_EADDRINUSE;
} }
}
pport.bound_p2p_vports.insert(std::make_pair(p2p_vport, lv2_id)); bound_sockets.insert(lv2_id);
}
else
{
std::set<s32> bound_ports{lv2_id};
pport.bound_p2p_vports.insert(std::make_pair(p2p_vport, std::move(bound_ports)));
}
} }
} }
@ -157,7 +164,7 @@ s32 lv2_socket_p2p::bind(const sys_net_sockaddr& addr)
port = p2p_port; port = p2p_port;
vport = p2p_vport; vport = p2p_vport;
socket = real_socket; socket = real_socket;
last_bound_addr = addr; bound_addr = psa_in_p2p->sin_addr;
} }
return CELL_OK; return CELL_OK;
@ -185,15 +192,26 @@ std::pair<s32, sys_net_sockaddr> lv2_socket_p2p::getsockname()
return {CELL_OK, sn_addr}; return {CELL_OK, sn_addr};
} }
std::tuple<s32, lv2_socket::sockopt_data, u32> lv2_socket_p2p::getsockopt([[maybe_unused]] s32 level, [[maybe_unused]] s32 optname, [[maybe_unused]] u32 len) std::tuple<s32, lv2_socket::sockopt_data, u32> lv2_socket_p2p::getsockopt(s32 level, s32 optname, u32 len)
{ {
// TODO std::lock_guard lock(mutex);
const u64 key = (static_cast<u64>(level) << 32) | static_cast<u64>(optname);
if (!sockopts.contains(key))
{
sys_net.error("Unhandled getsockopt(level=%d, optname=%d, len=%d)", level, optname, len);
return {}; return {};
} }
const auto& cache = ::at32(sockopts, key);
return {CELL_OK, cache.data, cache.len};
}
s32 lv2_socket_p2p::setsockopt(s32 level, s32 optname, const std::vector<u8>& optval) s32 lv2_socket_p2p::setsockopt(s32 level, s32 optname, const std::vector<u8>& optval)
{ {
// TODO std::lock_guard lock(mutex);
int native_int = *reinterpret_cast<const be_t<s32>*>(optval.data()); int native_int = *reinterpret_cast<const be_t<s32>*>(optval.data());
if (level == SYS_NET_SOL_SOCKET && optname == SYS_NET_SO_NBIO) if (level == SYS_NET_SOL_SOCKET && optname == SYS_NET_SO_NBIO)
@ -201,7 +219,14 @@ s32 lv2_socket_p2p::setsockopt(s32 level, s32 optname, const std::vector<u8>& op
so_nbio = native_int; so_nbio = native_int;
} }
return {}; const u64 key = (static_cast<u64>(level) << 32) | static_cast<u64>(optname);
sockopt_cache cache{};
memcpy(&cache.data._int, optval.data(), optval.size());
cache.len = optval.size();
sockopts[key] = std::move(cache);
return CELL_OK;
} }
std::optional<std::tuple<s32, std::vector<u8>, sys_net_sockaddr>> lv2_socket_p2p::recvfrom(s32 flags, u32 len, bool is_lock) std::optional<std::tuple<s32, std::vector<u8>, sys_net_sockaddr>> lv2_socket_p2p::recvfrom(s32 flags, u32 len, bool is_lock)
@ -258,10 +283,12 @@ std::optional<s32> lv2_socket_p2p::sendto(s32 flags, const std::vector<u8>& buf,
inet_ntop(AF_INET, &native_addr.sin_addr, ip_str, sizeof(ip_str)); inet_ntop(AF_INET, &native_addr.sin_addr, ip_str, sizeof(ip_str));
sys_net.trace("[P2P] Sending a packet to %s:%d:%d", ip_str, p2p_port, p2p_vport); sys_net.trace("[P2P] Sending a packet to %s:%d:%d", ip_str, p2p_port, p2p_vport);
std::vector<u8> p2p_data(buf.size() + sizeof(u16)); std::vector<u8> p2p_data(buf.size() + VPORT_P2P_HEADER_SIZE);
const le_t<u16> p2p_vport_le = p2p_vport; const le_t<u16> p2p_vport_le = p2p_vport;
const le_t<u16> p2p_flags_le = P2P_FLAG_P2P;
memcpy(p2p_data.data(), &p2p_vport_le, sizeof(u16)); memcpy(p2p_data.data(), &p2p_vport_le, sizeof(u16));
memcpy(p2p_data.data() + sizeof(u16), buf.data(), buf.size()); memcpy(p2p_data.data() + sizeof(u16), &p2p_flags_le, sizeof(u16));
memcpy(p2p_data.data() + VPORT_P2P_HEADER_SIZE, buf.data(), buf.size());
int native_flags = 0; int native_flags = 0;
if (flags & SYS_NET_MSG_WAITALL) if (flags & SYS_NET_MSG_WAITALL)
@ -307,10 +334,21 @@ void lv2_socket_p2p::close()
auto& p2p_port = ::at32(nc.list_p2p_ports, port); auto& p2p_port = ::at32(nc.list_p2p_ports, port);
{ {
std::lock_guard lock(p2p_port.bound_p2p_vports_mutex); std::lock_guard lock(p2p_port.bound_p2p_vports_mutex);
if (!p2p_port.bound_p2p_vports.contains(vport))
{
return;
}
auto& bound_sockets = ::at32(p2p_port.bound_p2p_vports, vport);
bound_sockets.erase(lv2_id);
if (bound_sockets.empty())
{
p2p_port.bound_p2p_vports.erase(vport); p2p_port.bound_p2p_vports.erase(vport);
} }
} }
} }
}
s32 lv2_socket_p2p::shutdown([[maybe_unused]] s32 how) s32 lv2_socket_p2p::shutdown([[maybe_unused]] s32 how)
{ {

View File

@ -41,4 +41,6 @@ protected:
u32 bound_addr = 0; u32 bound_addr = 0;
// Queue containing received packets from network_thread for SYS_NET_SOCK_DGRAM_P2P sockets // Queue containing received packets from network_thread for SYS_NET_SOCK_DGRAM_P2P sockets
std::queue<std::pair<sys_net_sockaddr_in_p2p, std::vector<u8>>> data{}; std::queue<std::pair<sys_net_sockaddr_in_p2p, std::vector<u8>>> data{};
// List of sock options
std::map<u64, sockopt_cache> sockopts;
}; };

View File

@ -184,7 +184,7 @@ void initialize_tcp_timeout_monitor()
g_fxo->need<named_thread<tcp_timeout_monitor>>(); g_fxo->need<named_thread<tcp_timeout_monitor>>();
} }
u16 u2s_tcp_checksum(const u16* buffer, usz size) u16 u2s_tcp_checksum(const le_t<u16>* buffer, usz size)
{ {
u32 cksum = 0; u32 cksum = 0;
while (size > 1) while (size > 1)
@ -202,20 +202,22 @@ u16 u2s_tcp_checksum(const u16* buffer, usz size)
std::vector<u8> generate_u2s_packet(const p2ps_encapsulated_tcp& header, const u8* data, const u32 datasize) std::vector<u8> generate_u2s_packet(const p2ps_encapsulated_tcp& header, const u8* data, const u32 datasize)
{ {
const u32 packet_size = (sizeof(u16) + sizeof(p2ps_encapsulated_tcp) + datasize); const u32 packet_size = (VPORT_P2P_HEADER_SIZE + sizeof(p2ps_encapsulated_tcp) + datasize);
ensure(packet_size < 65535); // packet size shouldn't be bigger than possible UDP payload ensure(packet_size < 65535); // packet size shouldn't be bigger than possible UDP payload
std::vector<u8> packet(packet_size); std::vector<u8> packet(packet_size);
u8* packet_data = packet.data(); u8* packet_data = packet.data();
le_t<u16> dst_port_le = +header.dst_port; le_t<u16> dst_port_le = +header.dst_port;
le_t<u16> p2p_flags_le = P2P_FLAG_P2PS;
memcpy(packet_data, &dst_port_le, sizeof(u16)); memcpy(packet_data, &dst_port_le, sizeof(u16));
memcpy(packet_data + sizeof(u16), &header, sizeof(p2ps_encapsulated_tcp)); memcpy(packet_data + sizeof(u16), &p2p_flags_le, sizeof(u16));
memcpy(packet_data + VPORT_P2P_HEADER_SIZE, &header, sizeof(p2ps_encapsulated_tcp));
if (datasize) if (datasize)
memcpy(packet_data + sizeof(u16) + sizeof(p2ps_encapsulated_tcp), data, datasize); memcpy(packet_data + VPORT_P2P_HEADER_SIZE + sizeof(p2ps_encapsulated_tcp), data, datasize);
auto* hdr_ptr = reinterpret_cast<p2ps_encapsulated_tcp*>(packet_data + sizeof(u16)); auto* hdr_ptr = reinterpret_cast<p2ps_encapsulated_tcp*>(packet_data + VPORT_P2P_HEADER_SIZE);
hdr_ptr->checksum = 0; hdr_ptr->checksum = 0;
hdr_ptr->checksum = u2s_tcp_checksum(utils::bless<u16>(hdr_ptr), sizeof(p2ps_encapsulated_tcp) + datasize); hdr_ptr->checksum = u2s_tcp_checksum(utils::bless<le_t<u16>>(hdr_ptr), sizeof(p2ps_encapsulated_tcp) + datasize);
return packet; return packet;
} }
@ -225,7 +227,7 @@ lv2_socket_p2ps::lv2_socket_p2ps(lv2_socket_family family, lv2_socket_type type,
{ {
} }
lv2_socket_p2ps::lv2_socket_p2ps(socket_type socket, u16 port, u16 vport, u32 op_addr, u16 op_port, u16 op_vport, u64 cur_seq, u64 data_beg_seq) lv2_socket_p2ps::lv2_socket_p2ps(socket_type socket, u16 port, u16 vport, u32 op_addr, u16 op_port, u16 op_vport, u64 cur_seq, u64 data_beg_seq, s32 so_nbio)
: lv2_socket_p2p(SYS_NET_AF_INET, SYS_NET_SOCK_STREAM_P2P, SYS_NET_IPPROTO_IP) : lv2_socket_p2p(SYS_NET_AF_INET, SYS_NET_SOCK_STREAM_P2P, SYS_NET_IPPROTO_IP)
{ {
this->socket = socket; this->socket = socket;
@ -236,6 +238,7 @@ lv2_socket_p2ps::lv2_socket_p2ps(socket_type socket, u16 port, u16 vport, u32 op
this->op_vport = op_vport; this->op_vport = op_vport;
this->cur_seq = cur_seq; this->cur_seq = cur_seq;
this->data_beg_seq = data_beg_seq; this->data_beg_seq = data_beg_seq;
this->so_nbio = so_nbio;
status = p2ps_stream_status::stream_connected; status = p2ps_stream_status::stream_connected;
} }
@ -410,7 +413,7 @@ bool lv2_socket_p2ps::handle_listening(p2ps_encapsulated_tcp* tcp_header, [[mayb
const u16 new_op_vport = tcp_header->src_port; const u16 new_op_vport = tcp_header->src_port;
const u64 new_cur_seq = send_hdr.seq + 1; const u64 new_cur_seq = send_hdr.seq + 1;
const u64 new_data_beg_seq = send_hdr.ack; const u64 new_data_beg_seq = send_hdr.ack;
auto sock_lv2 = std::make_shared<lv2_socket_p2ps>(socket, port, vport, new_op_addr, new_op_port, new_op_vport, new_cur_seq, new_data_beg_seq); auto sock_lv2 = std::make_shared<lv2_socket_p2ps>(socket, port, vport, new_op_addr, new_op_port, new_op_vport, new_cur_seq, new_data_beg_seq, so_nbio);
const s32 new_sock_id = idm::import_existing<lv2_socket>(sock_lv2); const s32 new_sock_id = idm::import_existing<lv2_socket>(sock_lv2);
sock_lv2->set_lv2_id(new_sock_id); sock_lv2->set_lv2_id(new_sock_id);
const u64 key_connected = (reinterpret_cast<struct sockaddr_in*>(op_addr)->sin_addr.s_addr) | (static_cast<u64>(tcp_header->src_port) << 48) | (static_cast<u64>(tcp_header->dst_port) << 32); const u64 key_connected = (reinterpret_cast<struct sockaddr_in*>(op_addr)->sin_addr.s_addr) | (static_cast<u64>(tcp_header->src_port) << 48) | (static_cast<u64>(tcp_header->dst_port) << 32);
@ -494,6 +497,27 @@ void lv2_socket_p2ps::set_status(p2ps_stream_status new_status)
status = new_status; status = new_status;
} }
std::pair<s32, sys_net_sockaddr> lv2_socket_p2ps::getpeername()
{
std::lock_guard lock(mutex);
if (!op_addr || !op_port || !op_vport)
{
return {-SYS_NET_ENOTCONN, {}};
}
sys_net_sockaddr res{};
sys_net_sockaddr_in_p2p* p2p_addr = reinterpret_cast<sys_net_sockaddr_in_p2p*>(&res);
p2p_addr->sin_len = sizeof(sys_net_sockaddr_in_p2p);
p2p_addr->sin_family = SYS_NET_AF_INET;
p2p_addr->sin_addr = std::bit_cast<be_t<u32>, u32>(op_addr);
p2p_addr->sin_port = op_vport;
p2p_addr->sin_vport = op_port;
return {CELL_OK, res};
}
std::tuple<bool, s32, std::shared_ptr<lv2_socket>, sys_net_sockaddr> lv2_socket_p2ps::accept(bool is_lock) std::tuple<bool, s32, std::shared_ptr<lv2_socket>, sys_net_sockaddr> lv2_socket_p2ps::accept(bool is_lock)
{ {
std::unique_lock<shared_mutex> lock(mutex, std::defer_lock); std::unique_lock<shared_mutex> lock(mutex, std::defer_lock);
@ -572,26 +596,36 @@ s32 lv2_socket_p2ps::bind(const sys_net_sockaddr& addr)
if (p2p_vport == 0) if (p2p_vport == 0)
{ {
p2p_vport = 30000; p2p_vport = 30000;
while (pport.bound_p2p_streams.contains((static_cast<u64>(p2p_vport) << 32))) while (pport.bound_p2ps_vports.contains(p2p_vport))
{ {
p2p_vport++; p2p_vport++;
} }
pport.bound_p2p_streams.emplace((static_cast<u64>(p2p_vport) << 32), lv2_id); std::set<s32> bound_ports{lv2_id};
pport.bound_p2ps_vports.insert(std::make_pair(p2p_vport, std::move(bound_ports)));
} }
else else
{ {
const u64 key = (static_cast<u64>(p2p_vport) << 32); if (pport.bound_p2ps_vports.contains(p2p_vport))
if (pport.bound_p2p_streams.contains(key)) {
auto& bound_sockets = ::at32(pport.bound_p2ps_vports, p2p_vport);
if (!sys_net_helpers::all_reusable(bound_sockets))
{ {
return -SYS_NET_EADDRINUSE; return -SYS_NET_EADDRINUSE;
} }
pport.bound_p2p_streams.emplace(key, lv2_id);
bound_sockets.insert(lv2_id);
}
else
{
std::set<s32> bound_ports{lv2_id};
pport.bound_p2ps_vports.insert(std::make_pair(p2p_vport, std::move(bound_ports)));
}
} }
port = p2p_port; port = p2p_port;
vport = p2p_vport; vport = p2p_vport;
socket = real_socket; socket = real_socket;
last_bound_addr = addr; bound_addr = psa_in_p2p->sin_addr;
} }
} }
@ -817,13 +851,24 @@ void lv2_socket_p2ps::close()
std::lock_guard lock(p2p_port.bound_p2p_vports_mutex); std::lock_guard lock(p2p_port.bound_p2p_vports_mutex);
for (auto it = p2p_port.bound_p2p_streams.begin(); it != p2p_port.bound_p2p_streams.end();) for (auto it = p2p_port.bound_p2p_streams.begin(); it != p2p_port.bound_p2p_streams.end();)
{ {
if (static_cast<u32>(it->second) == lv2_id) if (it->second == lv2_id)
{ {
it = p2p_port.bound_p2p_streams.erase(it); it = p2p_port.bound_p2p_streams.erase(it);
continue; continue;
} }
it++; it++;
} }
if (p2p_port.bound_p2ps_vports.contains(vport))
{
auto& bound_ports = ::at32(p2p_port.bound_p2ps_vports, vport);
bound_ports.erase(lv2_id);
if (bound_ports.empty())
{
p2p_port.bound_p2ps_vports.erase(vport);
}
}
} }
} }
} }

View File

@ -51,14 +51,14 @@ enum p2ps_tcp_flags : u8
}; };
void initialize_tcp_timeout_monitor(); void initialize_tcp_timeout_monitor();
u16 u2s_tcp_checksum(const u16* buffer, usz size); u16 u2s_tcp_checksum(const le_t<u16>* buffer, usz size);
std::vector<u8> generate_u2s_packet(const p2ps_encapsulated_tcp& header, const u8* data, const u32 datasize); std::vector<u8> generate_u2s_packet(const p2ps_encapsulated_tcp& header, const u8* data, const u32 datasize);
class lv2_socket_p2ps final : public lv2_socket_p2p class lv2_socket_p2ps final : public lv2_socket_p2p
{ {
public: public:
lv2_socket_p2ps(lv2_socket_family family, lv2_socket_type type, lv2_ip_protocol protocol); lv2_socket_p2ps(lv2_socket_family family, lv2_socket_type type, lv2_ip_protocol protocol);
lv2_socket_p2ps(socket_type socket, u16 port, u16 vport, u32 op_addr, u16 op_port, u16 op_vport, u64 cur_seq, u64 data_beg_seq); lv2_socket_p2ps(socket_type socket, u16 port, u16 vport, u32 op_addr, u16 op_port, u16 op_vport, u64 cur_seq, u64 data_beg_seq, s32 so_nbio);
lv2_socket_p2ps(utils::serial& ar, lv2_socket_type type); lv2_socket_p2ps(utils::serial& ar, lv2_socket_type type);
void save(utils::serial& ar); void save(utils::serial& ar);
@ -73,6 +73,7 @@ public:
std::optional<s32> connect(const sys_net_sockaddr& addr) override; std::optional<s32> connect(const sys_net_sockaddr& addr) override;
std::pair<s32, sys_net_sockaddr> getpeername() override;
std::pair<s32, sys_net_sockaddr> getsockname() override; std::pair<s32, sys_net_sockaddr> getsockname() override;
s32 listen(s32 backlog) override; s32 listen(s32 backlog) override;

View File

@ -13,6 +13,33 @@
LOG_CHANNEL(sys_net); LOG_CHANNEL(sys_net);
namespace sys_net_helpers
{
bool all_reusable(const std::set<s32>& sock_ids)
{
for (const s32 sock_id : sock_ids)
{
const auto [_, reusable] = idm::check<lv2_socket>(sock_id, [&](lv2_socket& sock) -> bool
{
auto [res_reuseaddr, optval_reuseaddr, optlen_reuseaddr] = sock.getsockopt(SYS_NET_SOL_SOCKET, SYS_NET_SO_REUSEADDR, sizeof(s32));
auto [res_reuseport, optval_reuseport, optlen_reuseport] = sock.getsockopt(SYS_NET_SOL_SOCKET, SYS_NET_SO_REUSEPORT, sizeof(s32));
const bool reuse_addr = optlen_reuseaddr == 4 && !!optval_reuseaddr._int;
const bool reuse_port = optlen_reuseport == 4 && !!optval_reuseport._int;
return (reuse_addr || reuse_port);
});
if (!reusable)
{
return false;
}
}
return true;
}
} // namespace sys_net_helpers
nt_p2p_port::nt_p2p_port(u16 port) nt_p2p_port::nt_p2p_port(u16 port)
: port(port) : port(port)
{ {
@ -164,6 +191,16 @@ bool nt_p2p_port::recv_data()
} }
} }
if (recv_res < VPORT_P2P_HEADER_SIZE)
{
return true;
}
const u16 vport_flags = *reinterpret_cast<le_t<u16>*>(p2p_recv_data.data() + sizeof(u16));
std::vector<u8> p2p_data(recv_res - VPORT_P2P_HEADER_SIZE);
memcpy(p2p_data.data(), p2p_recv_data.data() + VPORT_P2P_HEADER_SIZE, p2p_data.size());
if (vport_flags & P2P_FLAG_P2P)
{ {
std::lock_guard lock(bound_p2p_vports_mutex); std::lock_guard lock(bound_p2p_vports_mutex);
if (bound_p2p_vports.contains(dst_vport)) if (bound_p2p_vports.contains(dst_vport))
@ -176,37 +213,41 @@ bool nt_p2p_port::recv_data()
p2p_addr.sin_vport = dst_vport; p2p_addr.sin_vport = dst_vport;
p2p_addr.sin_port = std::bit_cast<be_t<u16>, u16>(reinterpret_cast<struct sockaddr_in*>(&native_addr)->sin_port); p2p_addr.sin_port = std::bit_cast<be_t<u16>, u16>(reinterpret_cast<struct sockaddr_in*>(&native_addr)->sin_port);
std::vector<u8> p2p_data(recv_res - sizeof(u16)); auto& bound_sockets = ::at32(bound_p2p_vports, dst_vport);
memcpy(p2p_data.data(), p2p_recv_data.data() + sizeof(u16), recv_res - sizeof(u16));
const auto sock = idm::check<lv2_socket>(::at32(bound_p2p_vports, dst_vport), [&](lv2_socket& sock) for (const auto sock_id : bound_sockets)
{
const auto sock = idm::check<lv2_socket>(sock_id, [&](lv2_socket& sock)
{ {
ensure(sock.get_type() == SYS_NET_SOCK_DGRAM_P2P); ensure(sock.get_type() == SYS_NET_SOCK_DGRAM_P2P);
auto& sock_p2p = reinterpret_cast<lv2_socket_p2p&>(sock); auto& sock_p2p = reinterpret_cast<lv2_socket_p2p&>(sock);
sock_p2p.handle_new_data(std::move(p2p_addr), std::move(p2p_data)); sock_p2p.handle_new_data(p2p_addr, p2p_data);
}); });
// Should not happen in theory
if (!sock) if (!sock)
{
sys_net.error("Socket %d found in bound_p2p_vports didn't exist!", sock_id);
bound_sockets.erase(sock_id);
if (bound_sockets.empty())
{
bound_p2p_vports.erase(dst_vport); bound_p2p_vports.erase(dst_vport);
}
}
}
return true; return true;
} }
} }
else if (vport_flags & P2P_FLAG_P2PS)
// Not directed at a bound DGRAM_P2P vport so check if the packet is a STREAM-P2P packet {
if (p2p_data.size() < sizeof(p2ps_encapsulated_tcp))
const auto sp_size = recv_res - sizeof(u16);
u8* sp_data = p2p_recv_data.data() + sizeof(u16);
if (sp_size < sizeof(p2ps_encapsulated_tcp))
{ {
sys_net.notice("Received P2P packet targeted at unbound vport(likely) or invalid(vport=%d)", dst_vport); sys_net.notice("Received P2P packet targeted at unbound vport(likely) or invalid(vport=%d)", dst_vport);
return true; return true;
} }
auto* tcp_header = reinterpret_cast<p2ps_encapsulated_tcp*>(sp_data); auto* tcp_header = reinterpret_cast<p2ps_encapsulated_tcp*>(p2p_data.data());
// Validate signature & length // Validate signature & length
if (tcp_header->signature != P2PS_U2S_SIG) if (tcp_header->signature != P2PS_U2S_SIG)
@ -215,7 +256,7 @@ bool nt_p2p_port::recv_data()
return true; return true;
} }
if (tcp_header->length != (sp_size - sizeof(p2ps_encapsulated_tcp))) if (tcp_header->length != (p2p_data.size() - sizeof(p2ps_encapsulated_tcp)))
{ {
sys_net.error("Received STREAM-P2P packet tcp length didn't match packet length"); sys_net.error("Received STREAM-P2P packet tcp length didn't match packet length");
return true; return true;
@ -231,7 +272,7 @@ bool nt_p2p_port::recv_data()
// Validate checksum // Validate checksum
u16 given_checksum = tcp_header->checksum; u16 given_checksum = tcp_header->checksum;
tcp_header->checksum = 0; tcp_header->checksum = 0;
if (given_checksum != u2s_tcp_checksum(reinterpret_cast<const u16*>(sp_data), sp_size)) if (given_checksum != u2s_tcp_checksum(reinterpret_cast<const le_t<u16>*>(p2p_data.data()), p2p_data.size()))
{ {
sys_net.error("Checksum is invalid, dropping packet!"); sys_net.error("Checksum is invalid, dropping packet!");
return true; return true;
@ -239,7 +280,6 @@ bool nt_p2p_port::recv_data()
// The packet is valid, check if it's bound // The packet is valid, check if it's bound
const u64 key_connected = (reinterpret_cast<struct sockaddr_in*>(&native_addr)->sin_addr.s_addr) | (static_cast<u64>(tcp_header->src_port) << 48) | (static_cast<u64>(tcp_header->dst_port) << 32); const u64 key_connected = (reinterpret_cast<struct sockaddr_in*>(&native_addr)->sin_addr.s_addr) | (static_cast<u64>(tcp_header->src_port) << 48) | (static_cast<u64>(tcp_header->dst_port) << 32);
const u64 key_listening = (static_cast<u64>(tcp_header->dst_port) << 32);
{ {
std::lock_guard lock(bound_p2p_vports_mutex); std::lock_guard lock(bound_p2p_vports_mutex);
@ -247,18 +287,23 @@ bool nt_p2p_port::recv_data()
{ {
const auto sock_id = ::at32(bound_p2p_streams, key_connected); const auto sock_id = ::at32(bound_p2p_streams, key_connected);
sys_net.trace("Received packet for connected STREAM-P2P socket(s=%d)", sock_id); sys_net.trace("Received packet for connected STREAM-P2P socket(s=%d)", sock_id);
handle_connected(sock_id, tcp_header, sp_data + sizeof(p2ps_encapsulated_tcp), &native_addr); handle_connected(sock_id, tcp_header, p2p_data.data() + sizeof(p2ps_encapsulated_tcp), &native_addr);
return true; return true;
} }
if (bound_p2p_streams.contains(key_listening)) if (bound_p2ps_vports.contains(tcp_header->dst_port))
{
const auto& bound_sockets = ::at32(bound_p2ps_vports, tcp_header->dst_port);
for (const auto sock_id : bound_sockets)
{ {
const auto sock_id = ::at32(bound_p2p_streams, key_listening);
sys_net.trace("Received packet for listening STREAM-P2P socket(s=%d)", sock_id); sys_net.trace("Received packet for listening STREAM-P2P socket(s=%d)", sock_id);
handle_listening(sock_id, tcp_header, sp_data + sizeof(p2ps_encapsulated_tcp), &native_addr); handle_listening(sock_id, tcp_header, p2p_data.data() + sizeof(p2ps_encapsulated_tcp), &native_addr);
}
return true; return true;
} }
} }
}
sys_net.notice("Received a STREAM-P2P packet with no bound target"); sys_net.notice("Received a STREAM-P2P packet with no bound target");
return true; return true;

View File

@ -1,5 +1,7 @@
#pragma once #pragma once
#include <set>
#include "lv2_socket_p2ps.h" #include "lv2_socket_p2ps.h"
#ifdef _WIN32 #ifdef _WIN32
@ -18,6 +20,14 @@
#endif #endif
#endif #endif
constexpr s32 VPORT_P2P_HEADER_SIZE = sizeof(u16) + sizeof(u16);
enum VPORT_P2P_FLAGS
{
P2P_FLAG_P2P = 1,
P2P_FLAG_P2PS = 1 << 1,
};
struct signaling_message struct signaling_message
{ {
u32 src_addr = 0; u32 src_addr = 0;
@ -26,6 +36,11 @@ struct signaling_message
std::vector<u8> data; std::vector<u8> data;
}; };
namespace sys_net_helpers
{
bool all_reusable(const std::set<s32>& sock_ids);
}
struct nt_p2p_port struct nt_p2p_port
{ {
// Real socket where P2P packets are received/sent // Real socket where P2P packets are received/sent
@ -33,9 +48,11 @@ struct nt_p2p_port
u16 port = 0; u16 port = 0;
shared_mutex bound_p2p_vports_mutex; shared_mutex bound_p2p_vports_mutex;
// For DGRAM_P2P sockets(vport, sock_id) // For DGRAM_P2P sockets (vport, sock_ids)
std::map<u16, s32> bound_p2p_vports{}; std::map<u16, std::set<s32>> bound_p2p_vports{};
// For STREAM_P2P sockets(key, sock_id) // For STREAM_P2P sockets (vport, sock_ids)
std::map<u16, std::set<s32>> bound_p2ps_vports{};
// List of active(either from a connect or an accept) P2PS sockets (key, sock_id)
// key is ( (src_vport) << 48 | (dst_vport) << 32 | addr ) with src_vport and addr being 0 for listening sockets // key is ( (src_vport) << 48 | (dst_vport) << 32 | addr ) with src_vport and addr being 0 for listening sockets
std::map<u64, s32> bound_p2p_streams{}; std::map<u64, s32> bound_p2p_streams{};

View File

@ -71,7 +71,7 @@ namespace rpcn
return get_localized_string(rpcn_state_to_localized_string_id(state)); return get_localized_string(rpcn_state_to_localized_string_id(state));
} }
constexpr u32 RPCN_PROTOCOL_VERSION = 17; constexpr u32 RPCN_PROTOCOL_VERSION = 18;
constexpr usz RPCN_HEADER_SIZE = 15; constexpr usz RPCN_HEADER_SIZE = 15;
constexpr usz COMMUNICATION_ID_SIZE = 9; constexpr usz COMMUNICATION_ID_SIZE = 9;