From: Gang Yan <yangang@kylinos.cn>
This patch adds a '__u32 proto' parameter to the 'send_query' and 'recv_nlmsg' functions (and, through 'recv_nlmsg', to 'parse_nlmsg') so that these functions can be extended further in the future.
In 'send_query', the new parameter makes the structure clearer and more readable: the INET_DIAG_REQ_PROTOCOL attribute is now appended to the request only when 'proto' is IPPROTO_MPTCP.
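In 'recv_nlmsg', the 'proto' parameter is forwarded to 'parse_nlmsg', which now interprets the INET_DIAG_INFO ('diag_info') attribute as 'mptcp_info' only when 'proto' is IPPROTO_MPTCP. This leaves the diag info of IPPROTO_TCP replies untouched and prevents it from being unintentionally interpreted in the 'mptcp_info' format.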
While at it, increment iovlen directly each time an item is added, to simplify this portion of the code and improve its readability.
Co-developed-by: Geliang Tang <geliang@kernel.org>
Signed-off-by: Geliang Tang <geliang@kernel.org>
Signed-off-by: Gang Yan <yangang@kylinos.cn>
Reviewed-by: Matthieu Baerts (NGI0) <matttbe@kernel.org>
Signed-off-by: Matthieu Baerts (NGI0) <matttbe@kernel.org>
---
 tools/testing/selftests/net/mptcp/mptcp_diag.c | 38 ++++++++++++++------------
 1 file changed, 20 insertions(+), 18 deletions(-)
diff --git a/tools/testing/selftests/net/mptcp/mptcp_diag.c b/tools/testing/selftests/net/mptcp/mptcp_diag.c index 76135aba71ad24c25c7babb6875e8a6dd7636b21..cc0326548e4ec44060da83f1f77e498bcedc82a9 100644 --- a/tools/testing/selftests/net/mptcp/mptcp_diag.c +++ b/tools/testing/selftests/net/mptcp/mptcp_diag.c @@ -62,7 +62,7 @@ static void die_usage(int r) exit(r); }
-static void send_query(int fd, struct inet_diag_req_v2 *r) +static void send_query(int fd, struct inet_diag_req_v2 *r, __u32 proto) { struct sockaddr_nl nladdr = { .nl_family = AF_NETLINK @@ -80,21 +80,22 @@ static void send_query(int fd, struct inet_diag_req_v2 *r) }; struct rtattr rta_proto; struct iovec iov[6]; - int iovlen = 1; - __u32 proto; + int iovlen = 0;
- proto = IPPROTO_MPTCP; - rta_proto.rta_type = INET_DIAG_REQ_PROTOCOL; - rta_proto.rta_len = RTA_LENGTH(sizeof(proto)); - - iov[0] = (struct iovec) { + iov[iovlen++] = (struct iovec) { .iov_base = &req, .iov_len = sizeof(req) }; - iov[iovlen] = (struct iovec){ &rta_proto, sizeof(rta_proto)}; - iov[iovlen + 1] = (struct iovec){ &proto, sizeof(proto)}; - req.nlh.nlmsg_len += RTA_LENGTH(sizeof(proto)); - iovlen += 2; + + if (proto == IPPROTO_MPTCP) { + rta_proto.rta_type = INET_DIAG_REQ_PROTOCOL; + rta_proto.rta_len = RTA_LENGTH(sizeof(proto)); + + iov[iovlen++] = (struct iovec){ &rta_proto, sizeof(rta_proto)}; + iov[iovlen++] = (struct iovec){ &proto, sizeof(proto)}; + req.nlh.nlmsg_len += RTA_LENGTH(sizeof(proto)); + } + struct msghdr msg = { .msg_name = &nladdr, .msg_namelen = sizeof(nladdr), @@ -158,7 +159,7 @@ static void print_info_msg(struct mptcp_info *info) printf("bytes_acked: %llu\n", info->mptcpi_bytes_acked); }
-static void parse_nlmsg(struct nlmsghdr *nlh) +static void parse_nlmsg(struct nlmsghdr *nlh, __u32 proto) { struct inet_diag_msg *r = NLMSG_DATA(nlh); struct rtattr *tb[INET_DIAG_MAX + 1]; @@ -167,7 +168,7 @@ static void parse_nlmsg(struct nlmsghdr *nlh) nlh->nlmsg_len - NLMSG_LENGTH(sizeof(*r)), NLA_F_NESTED);
- if (tb[INET_DIAG_INFO]) { + if (proto == IPPROTO_MPTCP && tb[INET_DIAG_INFO]) { int len = RTA_PAYLOAD(tb[INET_DIAG_INFO]); struct mptcp_info *info;
@@ -183,7 +184,7 @@ static void parse_nlmsg(struct nlmsghdr *nlh) } }
-static void recv_nlmsg(int fd) +static void recv_nlmsg(int fd, __u32 proto) { char rcv_buff[8192]; struct nlmsghdr *nlh = (struct nlmsghdr *)rcv_buff; @@ -216,7 +217,7 @@ static void recv_nlmsg(int fd) -(err->error), strerror(-(err->error))); break; } - parse_nlmsg(nlh); + parse_nlmsg(nlh, proto); nlh = NLMSG_NEXT(nlh, len); } } @@ -230,14 +231,15 @@ static void get_mptcpinfo(__u32 token) .idiag_ext = 1 << (INET_DIAG_INFO - 1), .id.idiag_cookie[0] = token, }; + __u32 proto = IPPROTO_MPTCP; int fd;
fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG); if (fd < 0) die_perror("Netlink socket");
- send_query(fd, &r); - recv_nlmsg(fd); + send_query(fd, &r, proto); + recv_nlmsg(fd, proto);
close(fd); }