--- trunk/mkinitrd-magellan/klibc/usr/kinit/nfsmount/dummypmap.c 2009/04/24 18:09:28 814 +++ trunk/mkinitrd-magellan/klibc/usr/kinit/nfsmount/dummypmap.c 2009/04/24 18:32:46 815 @@ -15,18 +15,23 @@ #include #include +#include "dummypmap.h" #include "sunrpc.h" extern const char *progname; -struct portmap_call { - struct rpc_call rpc; +struct portmap_args { uint32_t program; uint32_t version; uint32_t proto; uint32_t port; }; +struct portmap_call { + struct rpc_call rpc; + struct portmap_args args; +}; + struct portmap_reply { struct rpc_reply rpc; uint32_t port; @@ -66,19 +71,94 @@ } } +static void * get_auth(struct rpc_auth *auth) +{ + switch (ntohl(auth->flavor)) { + case AUTH_NULL: + /* Fallthrough */ + case AUTH_UNIX: + return (char *)&auth->body + ntohl(auth->len); + default: + return NULL; + } +} + +static int check_unix_cred(struct rpc_auth *cred) +{ + uint32_t len; + int quad_len; + uint32_t node_name_len; + int quad_name_len; + uint32_t *base; + uint32_t *pos; + int ret = -1; + + len = ntohl(cred->len); + quad_len = (len + 3) >> 2; + if (quad_len < 6) + /* Malformed creds */ + goto out; + + base = pos = cred->body; + + /* Skip timestamp */ + pos++; + + /* Skip node name: only localhost can succeed. */ + node_name_len = ntohl(*pos++); + quad_name_len = (node_name_len + 3) >> 2; + if (pos + quad_name_len + 3 > base + quad_len) + /* Malformed creds */ + goto out; + pos += quad_name_len; + + /* uid must be 0 */ + if (*pos++ != 0) + goto out; + + /* gid must be 0 */ + if (*pos++ != 0) + goto out; + + /* Skip remaining gids */ + ret = 0; + +out: + return ret; +} + +static int check_cred(struct rpc_auth *cred) +{ + switch (ntohl(cred->flavor)) { + case AUTH_NULL: + return 0; + case AUTH_UNIX: + return check_unix_cred(cred); + default: + return -1; + } +} + +static int check_vrf(struct rpc_auth *vrf) +{ + return (vrf->flavor == htonl(AUTH_NULL)) ? 0 : -1; +} + static int dummy_portmap(int sock, FILE *portmap_file) { struct sockaddr_in sin; int pktlen, addrlen; - union { - struct portmap_call c; - unsigned char b[65536]; /* Max UDP packet size */ - } pkt; + unsigned char pkt[65536]; /* Max UDP packet size */ + /* RPC UDP packets do not include TCP fragment size */ + struct rpc_call *rpc = (struct rpc_call *) &pkt[-4]; + struct rpc_auth *cred; + struct rpc_auth *vrf; + struct portmap_args *args; struct portmap_reply rply; for (;;) { addrlen = sizeof sin; - pktlen = recvfrom(sock, &pkt.c.rpc.hdr.udp, sizeof pkt, 0, + pktlen = recvfrom(sock, &pkt, sizeof pkt, 0, (struct sockaddr *)&sin, &addrlen); if (pktlen < 0) { @@ -92,39 +172,44 @@ if (pktlen + 4 < sizeof(struct portmap_call)) continue; /* Bad packet */ - if (pkt.c.rpc.hdr.udp.msg_type != htonl(RPC_CALL)) + if (rpc->hdr.udp.msg_type != htonl(RPC_CALL)) continue; /* Bad packet */ memset(&rply, 0, sizeof rply); - rply.rpc.hdr.udp.xid = pkt.c.rpc.hdr.udp.xid; + rply.rpc.hdr.udp.xid = rpc->hdr.udp.xid; rply.rpc.hdr.udp.msg_type = htonl(RPC_REPLY); - if (pkt.c.rpc.rpc_vers != htonl(2)) { + cred = (struct rpc_auth *) &rpc->cred_flavor; + if (rpc->rpc_vers != htonl(2)) { rply.rpc.reply_state = htonl(REPLY_DENIED); /* state <- RPC_MISMATCH == 0 */ - } else if (pkt.c.rpc.program != htonl(PORTMAP_PROGRAM)) { + } else if (rpc->program != htonl(PORTMAP_PROGRAM)) { rply.rpc.reply_state = htonl(PROG_UNAVAIL); - } else if (pkt.c.rpc.prog_vers != htonl(2)) { + } else if (rpc->prog_vers != htonl(2)) { rply.rpc.reply_state = htonl(PROG_MISMATCH); - } else if (pkt.c.rpc.cred_len != 0 || pkt.c.rpc.vrf_len != 0) { + } else if (!(vrf = get_auth(cred)) || + (char *)vrf > (char *)pkt + pktlen - 8 - sizeof(*args) || + !(args = get_auth(vrf)) || + (char *)args > (char *)pkt + pktlen - sizeof(*args) || + check_cred(cred) || check_vrf(vrf)) { /* Can't deal with credentials data; the kernel won't send them */ rply.rpc.reply_state = htonl(SYSTEM_ERR); } else { - switch (ntohl(pkt.c.rpc.proc)) { + switch (ntohl(rpc->proc)) { case PMAP_PROC_NULL: break; case PMAP_PROC_SET: - if (pkt.c.proto == htonl(IPPROTO_TCP) || - pkt.c.proto == htonl(IPPROTO_UDP)) { + if (args->proto == htonl(IPPROTO_TCP) || + args->proto == htonl(IPPROTO_UDP)) { if (portmap_file) fprintf(portmap_file, "%u %u %s %u\n", - ntohl(pkt.c.program), - ntohl(pkt.c.version), - protoname(pkt.c.proto), - ntohl(pkt.c.port)); + ntohl(args->program), + ntohl(args->version), + protoname(args->proto), + ntohl(args->port)); rply.port = htonl(1); /* TRUE = success */ } break;