/*
 * key.c
 *
 * Copyright (c) 2001 Dug Song <dugsong@arbor.net>
 * Copyright (c) 2001 Arbor Networks, Inc.
 *
 * $Id: key.c,v 1.1.1.1 2001/12/15 00:20:46 dirt Exp $
 */

#include <sys/types.h>
#include <sys/stat.h>
#include <sys/uio.h>

#include <openssl/ssl.h>

#include <errno.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

#include "key.h"
#include "kn.h"
#include "ssh.h"
#include "util.h"
#include "x509.h"

/*
 * A key loader tries to parse the raw file contents described by the
 * iovec into the key.  Returning 0 means the loader accepted the data;
 * any other return value makes the caller move on to the next loader.
 */
typedef int (*key_loader)(struct key *, struct iovec *);

/* Loaders tried in order by key_load_public(); NULL-terminated. */
static const key_loader pubkey_loaders[] = {
	kn_load_public,
	ssh_load_public,
	x509_load_public,
	NULL
};

/* Loaders tried in order by key_load_private(); NULL-terminated. */
static const key_loader privkey_loaders[] = {
	kn_load_private,
	ssh_load_private,
	x509_load_private,
	NULL
};

/*
 * Read the whole of filename into a freshly malloc'd buffer.
 *
 * On success returns 0, and iov->iov_base points to iov->iov_len bytes
 * of file data followed by a terminating NUL (not counted in iov_len).
 * The caller owns the buffer.
 *
 * On failure returns -1 with errno set:
 *   - EINVAL for an empty file
 *   - EFBIG if the size does not fit in memory
 *   - EIO if the file shrank while being read
 *   - otherwise errno from open/fstat/read/malloc
 * In the failure case *iov is left untouched, and no buffer or file
 * descriptor is leaked.
 */
static int
load_file(struct iovec *iov, char *filename)
{
	struct stat st;
	unsigned char *buf = NULL;
	size_t len, off;
	ssize_t n;
	int fd, save_errno;

	if ((fd = open(filename, O_RDONLY)) < 0)
		return (-1);

	if (fstat(fd, &st) < 0)
		goto fail;

	/* An empty file cannot hold a key. */
	if (st.st_size <= 0) {
		errno = EINVAL;
		goto fail;
	}
	/* The size must fit in size_t with room left for the trailing NUL. */
	len = (size_t)st.st_size;
	if ((off_t)len != st.st_size || len + 1 == 0) {
		errno = EFBIG;
		goto fail;
	}
	if ((buf = malloc(len + 1)) == NULL)
		goto fail;

	/* read() may legitimately return fewer bytes than requested. */
	for (off = 0; off < len; off += (size_t)n) {
		n = read(fd, buf + off, len - off);
		if (n < 0) {
			if (errno == EINTR) {
				n = 0;
				continue;
			}
			goto fail;
		}
		if (n == 0) {
			/* EOF before st_size bytes: the file shrank. */
			errno = EIO;
			goto fail;
		}
	}
	buf[len] = '\0';
	close(fd);

	iov->iov_base = buf;
	iov->iov_len = len;
	return (0);

fail:
	/* Keep the errno of the real failure across free()/close(). */
	save_errno = errno;
	free(buf);
	close(fd);
	errno = save_errno;
	return (-1);
}

/*
 * Allocate a new, zero-filled key structure.
 *
 * Returns the key, or NULL if the allocation fails.
 */
struct key *
key_new(void)
{
	struct key *k;

	/* calloc(nmemb, size): one element of sizeof(*k) bytes. */
	if ((k = calloc(1, sizeof(*k))) == NULL)
		return (NULL);

	return (k);
}

/*
 * Load a private key from filename into k.
 *
 * The file contents are offered to each entry of privkey_loaders in
 * order; the first loader that returns 0 wins.
 *
 * Returns 0 on success, or -1 if the file cannot be read or no loader
 * accepts it.
 */
int
key_load_private(struct key *k, char *filename)
{
	struct iovec iov, view;
	int i;

	if (load_file(&iov, filename) < 0)
		return (-1);

	for (i = 0; privkey_loaders[i] != NULL; i++) {
		/* Fresh copy so a failed loader cannot disturb the next one. */
		view = iov;
		if (privkey_loaders[i](k, &view) == 0) {
			/*
			 * NOTE(review): the file buffer is not freed here.
			 * key_free() free()s k->data for non-RSA/DSA keys,
			 * so a loader may keep the buffer. Confirm against
			 * the loaders before releasing it.
			 */
			return (0);
		}
	}
	/* Every loader rejected the data, so release the file buffer. */
	free(iov.iov_base);
	return (-1);
}

/*
 * Load a public key from filename into k.
 *
 * The file contents are offered to each entry of pubkey_loaders in
 * order; the first loader that returns 0 wins.
 *
 * Returns 0 on success, or -1 if the file cannot be read or no loader
 * accepts it.
 */
int
key_load_public(struct key *k, char *filename)
{
	struct iovec iov, view;
	int i;

	if (load_file(&iov, filename) < 0)
		return (-1);

	for (i = 0; pubkey_loaders[i] != NULL; i++) {
		/* Fresh copy so a failed loader cannot disturb the next one. */
		view = iov;
		if (pubkey_loaders[i](k, &view) == 0) {
			/*
			 * NOTE(review): the file buffer is not freed here.
			 * key_free() free()s k->data for non-RSA/DSA keys,
			 * so a loader may keep the buffer. Confirm against
			 * the loaders before releasing it.
			 */
			return (0);
		}
	}
	/* Every loader rejected the data, so release the file buffer. */
	free(iov.iov_base);
	return (-1);
}

/*
 * Sign msg (mlen bytes) with key k, writing the signature into sig,
 * which has room for slen bytes.
 *
 * msg is passed to RSA_sign()/DSA_sign() tagged as NID_sha1.
 * Presumably callers hand in a SHA-1 digest rather than the raw
 * message -- TODO confirm against callers.
 *
 * Returns the signature length in bytes, or -1 on error (a diagnostic
 * is printed to stderr).
 */
int
key_sign(struct key *k, u_char *msg, int mlen, u_char *sig, int slen)
{
	unsigned int siglen = 0;	/* set by OpenSSL on success */

	/* OpenSSL takes unsigned lengths; a negative int would wrap. */
	if (mlen < 0 || slen < 0) {
		fprintf(stderr, "Invalid message or signature length\n");
		return (-1);
	}

	switch (k->type) {
	case KEY_RSA:
		/* RSA_size() is the signature size in bytes. */
		if (RSA_size((RSA *)k->data) > slen) {
			fprintf(stderr, "RSA modulus too large: %d bytes\n",
			    RSA_size((RSA *)k->data));
			return (-1);
		}
		if (RSA_sign(NID_sha1, msg, (unsigned int)mlen, sig, &siglen,
		    (RSA *)k->data) <= 0) {
			fprintf(stderr, "RSA signing failed\n");
			return (-1);
		}
		break;

	case KEY_DSA:
		/* DSA_size() is the maximum signature size in bytes. */
		if (DSA_size((DSA *)k->data) > slen) {
			fprintf(stderr, "DSA signature size too large: "
			    "%d bytes\n", DSA_size((DSA *)k->data));
			return (-1);
		}
		if (DSA_sign(NID_sha1, msg, mlen, sig, &siglen,
		    (DSA *)k->data) <= 0) {
			fprintf(stderr, "DSA signing failed\n");
			return (-1);
		}
		break;

	default:
		fprintf(stderr, "Unknown key type: %d\n", k->type);
		return (-1);
	}
	/* siglen <= the *_size() value checked above, which is <= slen. */
	return ((int)siglen);
}

/*
 * Check that sig (slen bytes) is a valid signature by key k over msg
 * (mlen bytes).  Uses the same NID_sha1 tag as key_sign().
 *
 * Returns slen on success, or -1 on error (a diagnostic is printed to
 * stderr).  Errors are:
 *   - the signature does not verify
 *   - a length is negative
 *   - the key type is unknown
 */
int
key_verify(struct key *k, u_char *msg, int mlen, u_char *sig, int slen)
{
	/*
	 * RSA_verify() takes unsigned lengths, so a negative value
	 * (e.g. from untrusted input) would become huge.  Reject it early.
	 */
	if (mlen < 0 || slen < 0) {
		fprintf(stderr, "Invalid message or signature length\n");
		return (-1);
	}

	switch (k->type) {
	case KEY_RSA:
		if (RSA_verify(NID_sha1, msg, (unsigned int)mlen,
		    sig, (unsigned int)slen, (RSA *)k->data) <= 0) {
			fprintf(stderr, "RSA verification failed\n");
			return (-1);
		}
		break;

	case KEY_DSA:
		if (DSA_verify(NID_sha1, msg, mlen,
		    sig, slen, (DSA *)k->data) <= 0) {
			fprintf(stderr, "DSA verification failed\n");
			return (-1);
		}
		break;

	default:
		fprintf(stderr, "Unknown key type: %d\n", k->type);
		return (-1);
	}
	return (slen);
}

/*
 * Release key k and the key material it holds.
 *
 * RSA and DSA keys are freed with the matching OpenSSL destructor.
 * Any other type's data is released with free().  Like free(), a NULL
 * argument is accepted and ignored.
 */
void
key_free(struct key *k)
{
	if (k == NULL)
		return;

	switch (k->type) {
	case KEY_RSA:
		RSA_free((RSA *)k->data);
		break;
	case KEY_DSA:
		DSA_free((DSA *)k->data);
		break;
	default:
		/* free(NULL) is a no-op, so no guard is needed. */
		free(k->data);
		break;
	}
	free(k);
}
