/*  $Header: /proj/software/pub/CVSROOT/uClinux/uC-src/intersil-tools/prismoids.c,v 1.1 2003/02/23 23:31:23 mrustad Exp $
 *
 *  Copyright (C) 2002 Intersil Americas Inc
 *
 *  This library is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU Lesser General Public
 *  License as published by the Free Software Foundation; either
 *  version 2.1 of the License.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *  Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public
 *  License along with this library; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 */
#include <errno.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#include <net/if.h>
#include <netinet/in.h>

#include "prismoids.h"

/* Uncomment this to see debug output */
//#define DEBUG

/* Handle for the shared OID get/set netlink socket used by prismoid_open(),
 * prismoid_close(), get_prismoid() and set_prismoid(). Throughout this file
 * an fd <= 0 means "socket not open".
 */
static struct islnetlink_handle getsetoid_h = { .fd = 0 };

/* Forward declarations of the file-local (static) helpers, which are defined
 * after the public prismoid / prismtraps functions that use them. */
static int netlink_send(struct islnetlink_handle *islnetlink_h, struct nlmsghdr *nlmsg);
static int netlink_recv(int fd, struct nlmsghdr *nlmsg,	int max_nlmsglen, int groups);
static int pimfor_encode_header(struct pimfor_hdr *header, int operation,
				unsigned long oid, int device_id, int flags, int length);
static int pimfor_decode_header(struct pimfor_hdr *header, int *operation, unsigned long *oid,
				int *device_id, int *flags, int *length);
static int isl_netlink_open(struct islnetlink_handle *islnetlink_h, __u32 trapgroup);
static int isl_netlink_close(struct islnetlink_handle *islnetlink_h);
static int build_pimfor_packet(struct nlmsghdr **ret_packet, int operation,
				const char *ifname, __u32 oid, void *data, __u32 data_len);
static int pimfor_send_rcv(struct islnetlink_handle *islnetlink_h, const char *ifname, int operation,
			__u32 oid, void *data, __u32 data_len, struct nlmsghdr **rcv_packet);

/* Return the descriptor of the shared socket used to get/set OIDs, opening
 * it first if it is not open yet (fd <= 0). If opening fails, the negative
 * value reported by isl_netlink_open() is returned instead.
 */
int prismoid_open(void)
{
	if(getsetoid_h.fd <= 0) {
		int err = isl_netlink_open(&getsetoid_h, 0);

		if(err < 0)
			return err;
	}

	return getsetoid_h.fd;
}

/* Close the shared OID get/set socket; the result of isl_netlink_close()
 * (0 on success, -1 on failure) is passed through. */
int prismoid_close(void)
{
	struct islnetlink_handle *handle = &getsetoid_h;
	int status = isl_netlink_close(handle);

	return status;
}


/* Query OID `oid` on interface `ifname` (PIMFOR GET request).
 *
 * The request carries the first data_len bytes of `data` as its payload.
 * On success the reply payload (at most data_len bytes) is copied back into
 * `data`. The shared OID socket is opened on first use.
 *
 * Returns 0 on success, a negative value on failure.
 */
int get_prismoid(char *ifname, unsigned int oid, void *data, int data_len)
{
	struct nlmsghdr *rcv_packet;
	struct pimfor_hdr *pimfor_packet;
	int ret = -1;
	int ret_oper, ret_devid, ret_flags, ret_length;
	unsigned long ret_oid;

	/* data_len is passed on as an unsigned size; a negative value would
	 * turn into a huge request. */
	if(data_len < 0) {
		fprintf(stderr, "Get: invalid data length %d\n", data_len);
		return -1;
	}

	if(getsetoid_h.fd <= 0) {
		/* Keep the open result out of `ret`: `ret` must stay -1 until the
		 * reply has been validated, otherwise the failure paths below
		 * would report success. */
		int open_ret = isl_netlink_open(&getsetoid_h, 0);

		if(open_ret < 0)
			return open_ret;
	}

	if(pimfor_send_rcv(&getsetoid_h, ifname, PIMFOR_OP_GET,
			oid, data, data_len, &rcv_packet) < 0) {
		fprintf(stderr, "Get: send_rcv failed\n");
		return -1;
	}

	pimfor_packet = NLMSG_DATA(rcv_packet);

	if(pimfor_decode_header(pimfor_packet, &ret_oper,
		&ret_oid, &ret_devid, &ret_flags, &ret_length) < 0) {
		fprintf(stderr, "Get: invalid PIMFOR header\n");
		goto out;
	}

#ifdef DEBUG
        printf("GET: PIMFOR reply op %d oid %lx devid %d flgs %d len %d\n", ret_oper, ret_oid, ret_devid, ret_flags, ret_length);
#endif

	/* Accept only a RESPONSE whose payload length is sane and fits the
	 * caller's buffer (the receive buffer was also sized from data_len). */
	if((ret_oper == PIMFOR_OP_RESPONSE) &&
	   /* Should we check this as well??? (ret_devid == devid) && */
	   (ret_length >= 0) && (ret_length <= data_len)) {
		if(ret_length > 0)
			memcpy(data, PIMFOR_DATA(pimfor_packet), ret_length);
		ret = 0;
	} else {
		fprintf(stderr, "GET: Error in PIMFOR reply %d %lx %d %d %d\n", ret_oper,
		ret_oid, ret_devid, ret_flags, ret_length);
	}

out:
	free(rcv_packet);
	return ret;
}


/* Write OID `oid` on interface `ifname` (PIMFOR SET request) using the first
 * data_len bytes of `data` as the value.
 *
 * On success the reply payload (at most data_len bytes) is copied back into
 * `data`. The shared OID socket is opened on first use.
 *
 * Returns 0 on success, a negative value on failure.
 */
int set_prismoid(char *ifname, unsigned int oid, void *data, int data_len)
{
	struct nlmsghdr *rcv_packet;
	struct pimfor_hdr *pimfor_packet;
	int ret = -1;
	int ret_oper, ret_devid, ret_flags, ret_length;
	unsigned long ret_oid;

	/* data_len is passed on as an unsigned size; a negative value would
	 * turn into a huge request. */
	if(data_len < 0) {
		fprintf(stderr, "Set: invalid data length %d\n", data_len);
		return -1;
	}

	if(getsetoid_h.fd <= 0) {
		/* Keep the open result out of `ret`: `ret` must stay -1 until the
		 * reply has been validated, otherwise the failure paths below
		 * would report success. */
		int open_ret = isl_netlink_open(&getsetoid_h, 0);

		if(open_ret < 0)
			return open_ret;
	}

	if(pimfor_send_rcv(&getsetoid_h, ifname, PIMFOR_OP_SET,
			oid, data, data_len, &rcv_packet) < 0) {
		fprintf(stderr, "Set: send_rcv failed\n");
		return -1;
	}

	pimfor_packet = NLMSG_DATA(rcv_packet);

	if(pimfor_decode_header(pimfor_packet, &ret_oper,
		&ret_oid, &ret_devid, &ret_flags, &ret_length) < 0) {
		fprintf(stderr, "Set: invalid PIMFOR header\n");
		goto out;
	}
#ifdef DEBUG
        printf("SET: PIMFOR reply op %d oid %lx devid %d flgs %d len %d\n", ret_oper, ret_oid, ret_devid, ret_flags, ret_length);
#endif

	/* Accept only a RESPONSE whose payload length is sane and fits the
	 * caller's buffer (the receive buffer was also sized from data_len). */
	if((ret_oper == PIMFOR_OP_RESPONSE) &&
	   /* Should we check this as well??? (ret_devid == devid) && */
	   (ret_length >= 0) && (ret_length <= data_len)) {
		if(ret_length > 0)
			memcpy(data, PIMFOR_DATA(pimfor_packet), ret_length);
		ret = 0;
	} else {
		fprintf(stderr, "SET: Error in PIMFOR reply %d %lx %d %d %d\n", ret_oper,
		ret_oid, ret_devid, ret_flags, ret_length);
	}

out:
	free(rcv_packet);
	return ret;
}


/* Open and bind a netlink socket subscribed to the trap group(s) in
 * `trapgroup`, storing its state in *islnetlink_h. Returns the socket
 * descriptor on success, or the negative result of isl_netlink_open().
 */
int prismtraps_open(struct islnetlink_handle *islnetlink_h, __u32 trapgroup)
{
	int status = isl_netlink_open(islnetlink_h, trapgroup);

	return (status < 0) ? status : islnetlink_h->fd;
}

/* Close a socket opened with prismtraps_open(); passes through the result
 * of isl_netlink_close(). */
int prismtraps_close(struct islnetlink_handle *islnetlink_h)
{
	int status;

	status = isl_netlink_close(islnetlink_h);
	return status;
}


/* Receive and decode one PIMFOR message from a socket opened with
 * prismtraps_open().
 *
 * The netlink and PIMFOR headers are peeked first to learn the payload
 * length; then the whole message is read. Once a valid header has been seen,
 * *trapoid, *devid and *data_len describe the message. If it is a
 * PIMFOR_OP_TRAP, its payload is copied into a newly malloc()ed buffer
 * stored in *trapdata, which the caller must free(). On every other path
 * *trapdata is set to NULL.
 *
 * Returns 0 when a trap payload was returned, -1 otherwise.
 */
int prismtraps_decode(struct islnetlink_handle *islnetlink_h, __u32 *trapoid,
			__u32 *devid, char **trapdata, __u32 *data_len)
{
	int ret_oper, ret_devid, ret_flags, ret_length, status;
	unsigned long ret_oid;
	struct nlmsghdr *trap_packet;
	struct pimfor_hdr *pimfor_packet;
	int err = -1;
	int packetsize;
	size_t needed;
	/* Overlaying a struct nlmsghdr gives the peek buffer the alignment the
	 * header structs need; a bare char array has no such guarantee. */
	union {
		struct nlmsghdr nlh;
		char raw[NLMSG_ALIGN(sizeof(struct nlmsghdr)) + sizeof(struct pimfor_hdr)];
	} peek;

	/* Give the caller a well-defined value on every failure path. */
	*trapdata = NULL;

	/* We don't know the size of the trap data yet: peek the headers to find it out. */
	do {
		status = recvfrom(islnetlink_h->fd, peek.raw, sizeof(peek.raw), MSG_PEEK, NULL, NULL);
	} while(status < 0 && errno == EINTR);
	if(status < 0) {
		fprintf(stderr,"recvfrom failed");
		return -1;
	}

	/* A shorter message would leave part of the headers uninitialised. */
	if(status < (int)sizeof(peek.raw)) {
		fprintf(stderr, "Trap: short netlink message\n");
		goto discard;
	}

	pimfor_packet = (struct pimfor_hdr *)NLMSG_DATA(&peek.nlh);

	if(pimfor_decode_header(pimfor_packet, &ret_oper,
		&ret_oid, &ret_devid, &ret_flags, &ret_length) < 0) {
		fprintf(stderr, "Trap: invalid PIMFOR header\n");
		goto discard;
	}

	/* The length comes off the wire: reject values for which the receive
	 * buffer size would be negative or would not fit in an int. The size_t
	 * computation cannot wrap once ret_length is known to be non-negative. */
	if(ret_length < 0) {
		fprintf(stderr, "Trap: invalid PIMFOR data length %d\n", ret_length);
		goto discard;
	}
	needed = NLMSG_SPACE((size_t)ret_length + sizeof(struct pimfor_hdr));
	if(needed > (size_t)INT_MAX) {
		fprintf(stderr, "Trap: invalid PIMFOR data length %d\n", ret_length);
		goto discard;
	}

	*trapoid  = ret_oid;
	*data_len = ret_length;
	*devid = ret_devid;

#ifdef DEBUG
	printf("PIMFOR CHECK TRAP oper %d oid %lx devid %d flgs %d len %d\n", ret_oper, ret_oid, ret_devid, ret_flags, ret_length);
#endif

	/* Okay, now we know the data size, get the whole message. */
	packetsize = (int)needed;

	trap_packet = calloc(1, packetsize);
	if(!trap_packet) {
		fprintf(stderr, "Cannot allocate NETLINK packet");
		return -1;
	}

	status = netlink_recv(islnetlink_h->fd, trap_packet, packetsize, 0);
	if(status < 0) {
		fprintf(stderr,"netlink_recvfrom failed");
		goto out;
	}

	if(ret_oper == PIMFOR_OP_TRAP) {
		/* The payload buffer belongs to the caller, who must free() it.
		 * Allocate at least one byte: malloc(0) may legitimately return
		 * NULL, which would be mistaken for an allocation failure. */
		*trapdata = malloc(ret_length > 0 ? (size_t)ret_length : 1);
		if(!*trapdata) {
			fprintf(stderr, "Couldn't malloc data... \n");
			goto out;
		}

		pimfor_packet = NLMSG_DATA(trap_packet);
		memcpy(*trapdata, PIMFOR_DATA(pimfor_packet), ret_length);

		err = 0;
	}

out:
	free(trap_packet);
	return err;

discard:
	/* Consume the offending message so the next call does not peek the
	 * same bytes again; on a netlink socket the part of a message that
	 * does not fit the buffer is dropped. */
	recvfrom(islnetlink_h->fd, peek.raw, sizeof(peek.raw), 0, NULL, NULL);
	return -1;
}


/* helper functions */


/* Reverse the byte order of a 32-bit value (big <-> little endian).
 *
 * A static inline function instead of a GNU statement-expression macro:
 * the argument is type-checked and evaluated exactly once, and no reserved
 * double-underscore identifier is needed. Wider arguments (e.g. an
 * unsigned long OID) are truncated to 32 bits, as the macro did.
 */
static inline uint32_t swab32(uint32_t x)
{
	return ((x & 0x000000ffU) << 24) |
	       ((x & 0x0000ff00U) <<  8) |
	       ((x & 0x00ff0000U) >>  8) |
	       ((x & 0xff000000U) >> 24);
}


/* Send one netlink message to the kernel (netlink port id 0).
 *
 * The message is stamped with the handle's next sequence number
 * (islnetlink_h->seq is incremented first) so that pimfor_send_rcv() can
 * match the reply. Interrupted sends are retried.
 *
 * Returns the sendmsg() result: the number of bytes sent, or -1 on error.
 */
static int netlink_send(struct islnetlink_handle *islnetlink_h, struct nlmsghdr *nlmsg)
{
	int status;
	struct sockaddr_nl nladdr;
	struct iovec iov;
	struct msghdr msg;

	memset(&nladdr, 0, sizeof(nladdr));
	nladdr.nl_family = AF_NETLINK;
	nladdr.nl_pid = 0;	/* destination: the kernel */
	nladdr.nl_groups = 0;	/* unicast, no multicast groups */

	nlmsg->nlmsg_seq = ++islnetlink_h->seq;

	iov.iov_base = nlmsg;
	iov.iov_len = nlmsg->nlmsg_len;

	/* Assign msghdr members by name: their order (and padding members)
	 * differs between C libraries, so a positional initializer is not
	 * portable. */
	memset(&msg, 0, sizeof(msg));
	msg.msg_name = &nladdr;
	msg.msg_namelen = sizeof(nladdr);
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	do {
		status = sendmsg(islnetlink_h->fd, &msg, 0);
	} while (status < 0 && errno == EINTR);

	if (status < 0) {
		perror("Failed to send netlink msg");
	}

	return status;
}

/*
 * Receive one netlink message on `fd` into the caller's buffer `nlmsg`,
 * which holds max_nlmsglen bytes and is zeroed before the read.
 * Interrupted reads are retried.
 *
 * `groups` is only stored in the address structure handed to recvmsg(),
 * which recvmsg() fills in with the sender's address.
 *
 * On success the message length (nlmsg_len) is returned. An NLMSG_ERROR
 * message carrying error 0 (an acknowledgement) also counts as success.
 * Otherwise -1 is returned; for an NLMSG_ERROR reply, errno holds the
 * reported error.
 */
static int netlink_recv(int fd, struct nlmsghdr *nlmsg,
			int max_nlmsglen, int groups)
{
	int status;
	int nlmsglen;
	int payload_len;
	struct sockaddr_nl nladdr;
	struct iovec iov;
	struct msghdr msg;

	/* The header fields are read below, so the buffer must hold one. */
	if (max_nlmsglen < (int)sizeof(struct nlmsghdr)) {
		fprintf(stderr, "Netlink receive buffer too small\n");
		return -1;
	}

	memset(&nladdr, 0, sizeof(nladdr));
	nladdr.nl_family = AF_NETLINK;
	nladdr.nl_pid = 0;
	nladdr.nl_groups = groups;

	iov.iov_base = nlmsg;
	iov.iov_len = max_nlmsglen;

	/* Assign msghdr members by name: their order (and padding members)
	 * differs between C libraries, so a positional initializer is not
	 * portable. */
	memset(&msg, 0, sizeof(msg));
	msg.msg_name = &nladdr;
	msg.msg_namelen = sizeof(nladdr);
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	memset(nlmsg, 0, max_nlmsglen);

	do {
		status = recvmsg(fd, &msg, 0);
	} while (status < 0 && errno == EINTR);

	if (status < 0) {
		perror("Failed to receive netlink msg");
		return -1;
	}
	if (status == 0) {
		fprintf(stderr, "EOF on netlink\n");
		return -1;
	}

	if (msg.msg_flags & MSG_TRUNC) {
		fprintf(stderr, "Message truncated\n");
		return -1;
	}

	/* Validate nlmsg_len while it is still unsigned: it must cover the
	 * header and must not claim more bytes than were received. Checking
	 * first also keeps the int arithmetic below from overflowing. */
	if (status < (int)sizeof(struct nlmsghdr) ||
	    nlmsg->nlmsg_len < sizeof(struct nlmsghdr) ||
	    nlmsg->nlmsg_len > (unsigned int)status) {
		fprintf(stderr, "Invalid netlink message\n");
		return -1;
	}

	nlmsglen = (int)nlmsg->nlmsg_len;
	payload_len = nlmsglen - (int)sizeof(struct nlmsghdr);

#ifdef DEBUG
        printf("Netlink packet received of length %d, %d status = %d, seq = %d\n", nlmsglen, payload_len, status, nlmsg->nlmsg_seq);
#endif

	if (nlmsg->nlmsg_type == NLMSG_ERROR) {
		struct nlmsgerr *nlerr = (struct nlmsgerr *)NLMSG_DATA(nlmsg);

		if (payload_len < (int)sizeof(struct nlmsgerr)) {
			fprintf(stderr, "Netlink error truncated\n");
			return -1;
		}
		/* error == 0 is an acknowledgement, not a failure. */
		errno = -nlerr->error;
		if (errno == 0) {
			return nlmsglen;
		}
		perror("netlink error");
		return -1;
	}

	return nlmsglen;
}


/* Fill in a PIMFOR header in wire format.
 *
 * The version is always PIMFOR_VERSION_1. The byte-sized fields (operation,
 * device_id, flags) are stored as-is. The 32-bit oid and length fields are
 * laid out little-endian when PIMFOR_FLAG_LITTLE_ENDIAN is set in `flags`,
 * and big-endian (network order) otherwise.
 *
 * Returns 0, or -1 if header is NULL.
 */
static int pimfor_encode_header(struct pimfor_hdr *header, int operation,
				unsigned long oid, int device_id, int flags, int length)
{
	if(!header)
		return -1;

	/* byte oriented members */
	header->version = PIMFOR_VERSION_1;
	header->operation = operation;
	header->device_id = device_id;
	header->flags = flags;

	/* Word oriented members. This is host -> wire, so htonl() is the right
	 * conversion (ntohl() performs the same swap but reads as the opposite
	 * direction). */
	if (flags & PIMFOR_FLAG_LITTLE_ENDIAN) {
		/* htonl(swab32(x)) lays x out least-significant byte first on
		 * both big- and little-endian hosts. */
		header->oid = htonl(swab32(oid));
		header->length = htonl(swab32(length));
	} else {
		header->oid = htonl(oid);
		header->length = htonl(length);
	}
	return 0;
}

static int pimfor_decode_header(struct pimfor_hdr *header, int *operation,
				 unsigned long *oid, int *device_id,
				 int *flags, int *length)
{
	if(!header)
		return -1;

	// byte oriented members
	*operation = header->operation;
	*device_id = header->device_id;
	*flags = header->flags;

	// word oriented members with byte order depending on the flags
	if (*flags & PIMFOR_FLAG_LITTLE_ENDIAN) {
		// use little endian coding
		*oid = ntohl(swab32(header->oid));
		*length = ntohl(swab32(header->length));
	} else {
		// use big endian coding
		*oid = ntohl(header->oid);
		*length = ntohl(header->length);
	}

	return 0;
}

/* Reset *islnetlink_h, then open a NETLINK_ISIL raw socket and bind it with
 * nl_groups = trapgroup (0 subscribes to no multicast group).
 *
 * Returns 0 on success, with the descriptor stored in islnetlink_h->fd.
 * Returns -1 on failure; fd is then left negative (socket() failed) or
 * reset to 0 (bind() failed, socket closed).
 */
static int isl_netlink_open(struct islnetlink_handle *islnetlink_h, __u32 trapgroup)
{
	int sock;

	memset(islnetlink_h, 0, sizeof(*islnetlink_h));

	sock = socket(AF_NETLINK, SOCK_RAW, NETLINK_ISIL);
	islnetlink_h->fd = sock;
	if (sock < 0) {
		perror("Cannot open prism traps netlink socket");
		return -1;
	}

	islnetlink_h->local.nl_family = AF_NETLINK;
	islnetlink_h->local.nl_groups = trapgroup;

	if (bind(sock, (struct sockaddr *)&islnetlink_h->local,
		 sizeof(islnetlink_h->local)) >= 0)
		return 0;

	/* Report before close() so perror() sees bind()'s errno. */
	perror("Cannot bind prism netlink socket");
	close(sock);
	islnetlink_h->fd = 0;
	return -1;
}

/* Close the socket held by *islnetlink_h and mark the handle as not open
 * (fd = 0).
 *
 * A handle whose fd is <= 0 is treated as already closed and 0 is returned
 * without calling close(). Otherwise returns 0 on success and -1 if close()
 * fails.
 */
static int isl_netlink_close(struct islnetlink_handle *islnetlink_h)
{
	int ret = 0;

	/* fd <= 0 is this file's "not open" marker. Calling close() on it would
	 * close descriptor 0 (stdin) when the handle was never opened. */
	if(islnetlink_h->fd <= 0)
		return 0;

	if(close(islnetlink_h->fd) < 0) {
		perror("Could not close prismoid netlink socket");
		ret = -1;
	}
	/* Forget the descriptor even if close() failed: it must not be closed
	 * again, since the number may already belong to another open file. */
	islnetlink_h->fd = 0;

	return ret;
}

/* Build a netlink message carrying a PIMFOR GET or SET request for OID `oid`
 * on interface `ifname` (the interface index becomes the PIMFOR device id).
 * The first data_len bytes of `data` are copied in as the payload.
 *
 * The buffer is sized for netlink header + PIMFOR header + data_len bytes,
 * rounded up with NLMSG_SPACE. It is large enough to be reused for a reply
 * of the same payload size.
 *
 * On success returns the buffer size and stores the malloc()ed message in
 * *ret_packet; the caller must free() it. On failure returns -1 and sets
 * *ret_packet to NULL.
 */
static int build_pimfor_packet(struct nlmsghdr **ret_packet, int operation,
				const char *ifname, __u32 oid, void *data, __u32 data_len)
{
	struct nlmsghdr *packet;
	struct pimfor_hdr *pimfor_packet;
	struct ifreq ifr;
	int packetsize;
	int name_len;
	int sockfd, devid;

	/* Every failure path leaves a defined (NULL) result for the caller. */
	*ret_packet = NULL;

	if(operation != PIMFOR_OP_GET && operation != PIMFOR_OP_SET) {
		fprintf(stderr, "PIMFOR operation %d not supported\n", operation);
		return -1;
	}

	/* packetsize is an int: reject payloads for which the rounded-up size
	 * would not fit. Once data_len <= INT_MAX, the size_t computation
	 * cannot wrap. */
	if(data_len > (__u32)INT_MAX ||
	   NLMSG_SPACE((size_t)data_len + sizeof(struct pimfor_hdr)) > (size_t)INT_MAX) {
		fprintf(stderr, "PIMFOR payload too large: %u bytes\n", (unsigned)data_len);
		return -1;
	}
	packetsize = NLMSG_SPACE(data_len + sizeof(struct pimfor_hdr));

	/* The ioctl needs a NUL-terminated name; strncpy() would not guarantee
	 * one. A name that does not fit cannot be a valid interface. */
	memset(&ifr, 0, sizeof(ifr));
	name_len = snprintf(ifr.ifr_name, sizeof(ifr.ifr_name), "%s", ifname);
	if(name_len < 0 || (size_t)name_len >= sizeof(ifr.ifr_name)) {
		fprintf(stderr, "Invalid interface name %s\n", ifname);
		return -1;
	}

	packet = malloc(packetsize);
	if(!packet) {
		fprintf(stderr, "Cannot allocate PIMFOR packet\n");
		return -1;
	}
	memset(packet, 0, packetsize);
	packet->nlmsg_type = NETLINK_TYPE_PIMFOR;
	packet->nlmsg_len = packetsize;

	/* Get the ifindex of the interface. This is the devid in the PIMFOR message */
	sockfd = socket(AF_INET, SOCK_STREAM, 0);
	if (sockfd < 0) {
		perror("Cannot open socket");
		goto err_free;
	}
	if (ioctl(sockfd, SIOCGIFINDEX, &ifr) < 0) {
		fprintf(stderr, "Cannot get ifindex of interface %s: ", ifname);
		perror("");
		close(sockfd);
		goto err_free;
	}
	close(sockfd);
	devid = ifr.ifr_ifindex;

	/* Make the PIMFOR packet. A GET also carries the caller's data, so both
	 * request types are built the same way. */
	pimfor_packet = (struct pimfor_hdr *) NLMSG_DATA(packet);
	pimfor_encode_header(pimfor_packet, operation, oid, devid, 0, data_len);
	if(data_len > 0)
		memcpy(PIMFOR_DATA(pimfor_packet), data, data_len);

	*ret_packet = packet;
	return packetsize;

err_free:
	free(packet);
	return -1;
}

/* Send a PIMFOR GET/SET request on the handle's socket and wait for the
 * reply.
 *
 * The request buffer is reused to receive the reply, and the reply must
 * carry the sequence number of the request. On success returns 0 and stores
 * the malloc()ed reply in *rcv_packet; the caller must free() it. On failure
 * returns -1 and sets *rcv_packet to NULL.
 */
static int pimfor_send_rcv(struct islnetlink_handle *islnetlink_h, const char *ifname, int operation,
			__u32 oid, void *data, __u32 data_len, struct nlmsghdr **rcv_packet)
{
	/* Start from NULL so err_out can always free() it, even if
	 * build_pimfor_packet() fails without storing a pointer. */
	struct nlmsghdr *packet = NULL;
	int packetsize;

	*rcv_packet = NULL;

	if(islnetlink_h->fd <= 0) {
		/* No system call failed here, so errno is meaningless: no perror(). */
		fprintf(stderr, "Netlink socket not open\n");
		return -1;
	}

	packetsize = build_pimfor_packet(&packet, operation, ifname, oid, data, data_len);
	if(packetsize < 0) {
		fprintf(stderr, "Failed to build PIMFOR packet\n");
		goto err_out;
	}
	if(netlink_send(islnetlink_h, packet) < 0) {
		perror("Get: netlink send failed ");
		goto err_out;
	}

	/* Now wait for the reply message and return the data */
	memset(packet, 0, packetsize);
	if(netlink_recv(islnetlink_h->fd, packet, packetsize, 0) < 0) {
		fprintf(stderr, "netlink recv failed ");
		goto err_out;
	}

	/* netlink_send() stamped the request with islnetlink_h->seq. */
	if(packet->nlmsg_seq != islnetlink_h->seq) {
		fprintf(stderr, "Get: invalid netlink sequence %d %d\n",
			(int)packet->nlmsg_seq, (int)islnetlink_h->seq);
		goto err_out;
	}

	*rcv_packet = packet;
	return 0;

err_out:
	free(packet);
	*rcv_packet = NULL;
	return -1;
}
