/*
 * Sparse list implementation
 * Copyright (c) 2009 Steven Fuerst
 *
 * This file may be licensed under the terms of the
 * GNU General Public License Version 2 or any later version (the ``GPL'').
 *
 * Software distributed under the License is distributed
 * on an ``AS IS'' basis, WITHOUT WARRANTY OF ANY KIND, either
 * express or implied. See the GPL for the specific language
 * governing rights and limitations.
 *
 * You should have received a copy of the GPL along with this
 * program. If not, go to http://www.gnu.org/licenses/gpl.html
 * or write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 */


#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <err.h>

/*
 * Link-word layout.  struct splist is declared 16-byte aligned (see below),
 * so the low four bits of a stored successor address are zero and are
 * reused as tags:
 *
 *   bit 3     (SP_PTR) set on the first node of an encoded back-pointer;
 *   bits 0-2           one three-bit slice ("triplet") of that pointer.
 *
 * One encoded pointer occupies SP_COUNT consecutive nodes; only the pointer
 * bits from SP_OFFSET upwards are stored, three per node.
 */
#define SP_MASK	(-(uintptr_t) 16)	/* portable way of writing ~15 pointer mask */
#define SP_PTR	0x8		/* Bitflag to denote a pointer start */

/*
 * SP_SHIFT moves a node's triplet into the top three bits of a word.
 * SP_COUNT evaluates to 20 for 8-byte pointers and 10 for 4-byte pointers,
 * giving SP_OFFSET = 4 or 2: the low address bits that are not stored.
 * Those bits are zero because of the 16-byte alignment (which requires
 * SP_OFFSET <= 4).
 */
#define SP_SHIFT ((int) (8*sizeof(void *) - 3))	/* Three bits at a time */
#define SP_COUNT ((int) ((sizeof(void *) / 2) * 5)) /* Number of triplets */
#define SP_OFFSET ((int) (8*sizeof(void *) - 3 * SP_COUNT)) /* Remainder */

/* Uncomment to enable the internal consistency checks (errx() on failure). */
//#define DEBUG_SPLIST

typedef struct splist splist;

/*
 * A list node, meant to be embedded in a larger structure (test_struct
 * below does this).  snext holds the successor's address plus the tag bits
 * described above; sp_next() strips the tags.
 */
struct splist
{
	uintptr_t snext;
} __attribute__((aligned(16)));

/* Non-zero when s starts an encoded back-pointer (its SP_PTR tag is set). */
static int is_marked(splist *s)
{
	uintptr_t tag = s->snext & SP_PTR;

	return tag;
}

/*
 * Initialise s as an empty list head: a one-node ring whose link word is
 * the address of s itself.
 */
static void sp_init(splist *s)
{
	const uintptr_t self = (uintptr_t) s;

	s->snext = self;

#ifdef DEBUG_SPLIST
	/* Tags live in the low four address bits, so they must start out zero */
	if (self & 15) errx(1, "Incorrectly aligned node\n");
#endif
}

/* Successor of s: its link word with the four low tag bits cleared. */
splist *sp_next(splist *s)
{
	uintptr_t link = s->snext;

	link &= SP_MASK;
	return (splist *) link;
}


/*
 * Decode the pointer stored in the SP_COUNT-node run that begins at s.
 *
 * Every node of the run contributes the three low bits of its link word.
 * The first node holds the least significant stored triplet.  The result is
 * the marked node that s points back to.
 */
static splist *read_ptr(splist *s)
{
	uintptr_t acc = 0;
	splist *cur = s;
	int i;

	for (i = 0; i < SP_COUNT; i++)
	{
#ifdef DEBUG_SPLIST
		if (i && is_marked(cur)) errx(1, "Start set in middle of pointer\n");
#endif
		/* Slide what we have down and drop this node's triplet in on top */
		acc = (acc >> 3) + (cur->snext << SP_SHIFT);
		cur = sp_next(cur);
	}

#ifdef DEBUG_SPLIST
	if (!is_marked((splist *) acc)) errx(1, "Back link broken\n");
#endif

	return (splist *) acc;
}

/*
 * Encode address p into the SP_COUNT-node run starting at s, then mark s as
 * the start of that run.
 *
 * Each node's three low link bits are replaced by the next triplet of
 * (p >> SP_OFFSET), least significant first.  Any SP_PTR tag inside the run
 * is cleared as a side effect.  Returns the node just past the run.
 */
static splist *write_ptr(splist *s, splist *p)
{
	uintptr_t bits = (uintptr_t) p >> SP_OFFSET;
	splist *cur = s;
	int left;

	for (left = SP_COUNT; left > 0; left--)
	{
		splist *succ = sp_next(cur);

		cur->snext = (uintptr_t) succ + (bits & 7);
		bits >>= 3;
		cur = succ;
	}

	/* Set the marker last: the loop above rewrote s's link word */
	s->snext += SP_PTR;

	return cur;
}

/*
 * Return the node whose successor is node.
 *
 * Instead of walking the whole ring, this walks forward to the next marked
 * node, follows its encoded back-pointer to the previous marked node, and
 * walks forward from there.
 */
splist *sp_prev(splist *node)
{
	splist *cur = node;
	splist *nxt;

#ifdef DEBUG_SPLIST
	int steps = 0;
#endif

	/*
	 * Phase 1: advance to the first marked node.  If we arrive back at node
	 * without meeting one, the node we are standing on is the predecessor.
	 */
	while (!is_marked(cur))
	{
		nxt = sp_next(cur);
		if (nxt == node) return cur;

#ifdef DEBUG_SPLIST
		if (++steps >= SP_COUNT * 2) errx(1, "Count too big\n");
#endif
		cur = nxt;
	}

	/*
	 * Phase 2: the marked node's back-pointer names the previous marked
	 * node, which lies at or before the predecessor.  Walk forward until
	 * the next hop would land on node.
	 */
	nxt = read_ptr(cur);
	for (;;)
	{
		cur = nxt;
		nxt = sp_next(cur);

#ifdef DEBUG_SPLIST
		if (++steps >= SP_COUNT * 2) errx(1, "Count too big\n");
#endif
		if (nxt == node) return cur;
	}
}

#ifdef DEBUG_SPLIST
/*
 * Debug-only consistency check of the ring that contains s.
 *
 * It verifies two things:
 * - sp_prev() returns the real predecessor of every node.
 * - Every marked node's back-pointer names the previous marked node in ring
 *   order.  This includes the wrap-around link of the first marked node,
 *   and s itself when s is marked.
 *
 * Aborts via errx() on the first inconsistency, or after a million steps
 * (a broken ring that never returns to s).
 */
static void test_splist(splist *s)
{
	splist *n = sp_next(s);
	splist *p;

	splist *prev = NULL;	/* last marked node seen so far */
	splist *first = NULL;	/* first marked node seen */

	int count = 0;

	while (n != s)
	{
		p = n;
		n = sp_next(n);
		if (sp_prev(n) != p)  errx(1, "next - prev missmatch\n");

		if (is_marked(p))
		{
			if (prev && (read_ptr(p) != prev)) errx(1, "prev pointer broken\n");
			if (!first) first = p;
			prev = p;
		}

		count++;

		if (count > 1000000) errx(1, "loop found\n");
	}

	/* s is never visited as p in the loop above */
	if (is_marked(s))
	{
		if (prev && (read_ptr(s) != prev)) errx(1, "prev pointer broken\n");
		if (!first) first = s;
		prev = s;
	}

	/* The first marked node must point back, around the ring, to the last */
	if (first && (read_ptr(first) != prev)) errx(1, "prev pointer broken\n");
}
#else
#define test_splist(S) ((void) (S))
#endif

/*
 * Insert node directly after s.
 *
 * s:    a node already in the ring (for example the list head).
 * node: the node to link in.  It must be 16-byte aligned and not currently
 *       linked.  Its previous snext value is overwritten.
 *
 * Linking node in copies s's low tag bits into node and clears them in s.
 * If s lies inside the SP_COUNT-node run of the preceding marked node,
 * that run's triplets shift, so the encoded pointer is re-written.  When the
 * region between two marked nodes reaches 2*SP_COUNT nodes, a new marker is
 * added SP_COUNT nodes after the first one.
 */
void sp_insert(splist *s, splist *node)
{
	int count_before = 0;
	int count_after = 1;
	splist *ptr_before;
	splist *ptr_after;
	splist *back = NULL;
	int split;

	splist *n = sp_next(s);

#ifdef DEBUG_SPLIST
	/* Check the node's address, not its (stale) link word */
	if ((uintptr_t) node & 15) errx(1, "Incorrectly aligned node\n");
#endif

	/* Find the first marked node after s */
	while (!is_marked(n))
	{
		count_after++;
		n = sp_next(n);

		/* Full loop: no marked node other than possibly s itself */
		if (n == s)
		{
			/* Insert it */
			node->snext = s->snext;
			s->snext = (uintptr_t) node;

			/*
			 * Enough for two pointers?  The ring now has at least
			 * 2*SP_COUNT nodes, so two non-overlapping runs fit.
			 */
			if (count_after >= SP_COUNT * 2 - 1)
			{
				int count;

				/* Skip past the pointer, so we know where to point to */
				for (count = 0; count < SP_COUNT; count++)
				{
					n = sp_next(n);
				}

				write_ptr(s, n);
				write_ptr(n, s);
				test_splist(s);
			}

			/* Enough room for a single pointer? */
			else if (count_after >= SP_COUNT)
			{
				/* Create a self-linked pointer */
				write_ptr(s, s);
				test_splist(s);
			}

			return;
		}
	}

	ptr_after = n;
	ptr_before = read_ptr(ptr_after);

	/* Distance from the previous marked node to s */
	for (n = ptr_before; n != s; n = sp_next(n))
	{
		count_before++;
	}

	/* After the insert the region will hold 2*SP_COUNT nodes or more */
	split = (count_before + count_after >= SP_COUNT * 2 - 1);

	/*
	 * ptr_before's run is re-written in two cases: on a split, or when s
	 * lies inside that run (the insert below clobbers it).  In both cases,
	 * decode the pointer before the links change.
	 */
	if (split || count_before < SP_COUNT)
	{
		back = read_ptr(ptr_before);
	}

	/* Insert it */
	node->snext = s->snext;
	s->snext = (uintptr_t) node;

	if (split)
	{
		/* Re-encode ptr_before, then add a marker right after its run */
		n = write_ptr(ptr_before, back);
		write_ptr(n, ptr_before);
		write_ptr(ptr_after, n);
	}
	else if (count_before < SP_COUNT)
	{
		write_ptr(ptr_before, back);
	}

	test_splist(s);
}

/*
 * Re-encode the region that starts at the marked node start and ends just
 * before the next marked node ("end").  Markers are added so that no piece
 * is oversized.
 *
 * Let d be the number of nodes from start up to (not including) end:
 *   d >= 3*SP_COUNT : new markers at start+SP_COUNT and start+2*SP_COUNT;
 *   d >= 2*SP_COUNT : a new marker at start+SP_COUNT;
 *   otherwise       : no new marker.
 * In every case end's encoded pointer is rewritten to name the marker just
 * before it.  If start is the only marked node, end is start itself.
 */
static void nice_split(splist *start)
{
	splist *cut1 = NULL;
	splist *cut2 = NULL;
	splist *end = sp_next(start);
	int dist;

	for (dist = 1; !is_marked(end); dist++)
	{
		if (dist == SP_COUNT)
			cut1 = end;
		else if (dist == SP_COUNT * 2)
			cut2 = end;
		end = sp_next(end);
	}

	if (dist >= SP_COUNT * 3)
	{
		write_ptr(cut1, start);
		write_ptr(cut2, cut1);
		write_ptr(end, cut2);
	}
	else if (dist >= SP_COUNT * 2)
	{
		write_ptr(cut1, start);
		write_ptr(end, cut1);
	}
	else
	{
		write_ptr(end, start);
	}

	test_splist(start);
}

/*
 * Delete node which is directly after p.
 *
 * p:    the predecessor of node in the ring.
 * node: the node to unlink.  Its own snext is not modified here.
 *
 * Any encoded back-pointer affected by the removal is handled in one of
 * three ways: re-encoded, dropped, or merged into the previous region (via
 * nice_split()).  Afterwards every remaining marked node again points back
 * to the previous marked node.
 */
void sp_del_node(splist *p, splist *node)
{
	int count_before = 0;	/* distance from ptr_before to node */
	int count_after = 1;	/* distance from node to ptr_after */
	splist *ptr_before;	/* previous marked node */
	splist *ptr_after;	/* first marked node after node */
	
	splist *n;
	
#ifdef DEBUG_SPLIST
	if (sp_next(p) != node) errx(1, "Forward link doesn't point to node\n");
#endif
	
	if (is_marked(node))
	{
		/* node starts a run: fetch the previous marked node it encodes */
		n = read_ptr(node);
		
		/*
		 * Disconnect the node.  Keep p's own tag bits: p may hold a
		 * triplet of n's encoded pointer.
		 */
		p->snext = (uintptr_t) sp_next(node) + (p->snext & ~SP_MASK);
		
		/* Self-linked: node was the only marked node, none remain */
		if (n == node)
		{
			test_splist(p);
			return;
		}
		
		/*
		 * node's region now belongs to n's.  Re-encode up to the next
		 * marked node, whose pointer still names node, splitting if the
		 * merged region is too large.
		 */
		nice_split(n);
		
		return;
	}
	
	/* Not marked, need to find how large the region is */
	n = sp_next(node);
	while (!is_marked(n))
	{
		count_after++;
		n = sp_next(n);
		
		/* Full loop? No marked node anywhere: a plain unlink is enough */
		if (n == node)
		{
			/* Just delete it */
			p->snext = node->snext;
			test_splist(p);
			return;
		}
	}
	
	ptr_after = n;
	ptr_before = read_ptr(n);
	
	/* Count nodes before us */
	for (n = ptr_before; n != node; n = sp_next(n))
	{
		count_before++;
	}
	
	/*
	 * Doesn't overlap pointer?  Both node and p lie past the SP_COUNT-node
	 * run of ptr_before, so a plain unlink does not disturb any encoding.
	 */
	if (count_before > SP_COUNT)
	{
		/* Just delete it */
		p->snext = node->snext;
		test_splist(p);
		return;
	}
	
	/*
	 * node is inside, or just after, ptr_before's run.  Save what the run
	 * encodes before the unlink shifts its triplets.
	 */
	n = read_ptr(ptr_before);
	p->snext = (uintptr_t) sp_next(node) + (p->snext & ~SP_MASK);
	
	/* Will fit: the region still has at least SP_COUNT nodes, re-encode */
	if (count_before + count_after > SP_COUNT)
	{
		write_ptr(ptr_before, n);
		test_splist(p);
		return;
	}
	
	/* Remove pointer: the region is now too short to hold one */
	ptr_before->snext &= ~SP_PTR;
	
	if (ptr_before == ptr_after)
	{
		/* Pointer no longer fits; it was the only one, none remain */
		test_splist(p);
		return;
	}
	
	/* Jump over removed pointer: merge into n's region, re-encode ptr_after */
	nice_split(n);
}

/*
 * Delete node, previous node calculation folded in.
 *
 * Same algorithm as sp_del_node(), but the predecessor p is found while
 * scanning for the neighbouring marked nodes, instead of being passed in.
 * node's own snext is not modified here.
 */
void sp_del(splist *node)
{
	int count_before = 0;	/* distance from ptr_before to node */
	int count_after = 1;	/* distance from node to ptr_after */
	splist *ptr_before;	/* previous marked node */
	splist *ptr_after;	/* first marked node after node */
	
	splist *n, *p = NULL;	/* p ends up as node's predecessor */

	if (is_marked(node))
	{
		/* Jump along backwards link */
		splist *t = read_ptr(node);
		
		/* Look for the node previous to the one we wanted */
		n = t;
		do
		{
			p = n;
			n = sp_next(n);
		}
		while (n != node);
		
		/*
		 * Disconnect the node.  Keep p's own tag bits: p may hold a
		 * triplet of t's encoded pointer.
		 */
		p->snext = (uintptr_t) sp_next(node) + (p->snext & ~SP_MASK);
		
		/* Self-linked: node was the only marked node, none remain */
		if (t == node)
		{
			test_splist(p);
			return;
		}
		
		/* Merge node's region into t's and re-encode the next marker */
		nice_split(t);
		
		return;
	}
	
	/* Not marked, need to find how large the region is */
	n = sp_next(node);
	while (!is_marked(n))
	{
		count_after++;
		p = n;
		n = sp_next(n);
		
		/* Full loop? No marked node anywhere; p is node's predecessor */
		if (n == node)
		{
			/* Just delete it */
			p->snext = node->snext;
			test_splist(p);
			return;
		}
	}
	
	ptr_after = n;
	ptr_before = read_ptr(n);
	
	/* Count nodes before us; this also leaves p as node's predecessor */
	for (n = ptr_before; n != node; n = sp_next(n))
	{
		p = n;
		count_before++;
	}
	
	/* Doesn't overlap pointer?  node and p lie past ptr_before's run */
	if (count_before > SP_COUNT)
	{
		/* Just delete it */
		p->snext = node->snext;
		test_splist(p);
		return;
	}
	
	/* Save ptr_before's encoded pointer before the unlink shifts it */
	n = read_ptr(ptr_before);
	p->snext = (uintptr_t) sp_next(node) + (p->snext & ~SP_MASK);
	
	/* Will fit: the region still has at least SP_COUNT nodes, re-encode */
	if (count_before + count_after > SP_COUNT)
	{
		write_ptr(ptr_before, n);
		test_splist(p);
		return;
	}
	
	/* Remove pointer bit-mark (ptr_before is still marked at this point) */
	ptr_before->snext -= SP_PTR;
	
	if (ptr_before == ptr_after)
	{
		/* Pointer no longer fits; it was the only one, none remain */
		test_splist(p);
		return;
	}
	
	/* Jump over removed pointer: merge into n's region, re-encode ptr_after */
	nice_split(n);
}

/* Pool size for the stress test; main() uses TEST_SIZE - 1 as a mask. */
enum { TEST_SIZE = 1 << 10 };

/* Stand-in for a user structure that embeds a list node. */
struct test_struct
{
	splist s;	/* embedded link; zero while the slot is not in the list */
};

/* Static storage, so every slot starts out zeroed (unlinked). */
static struct test_struct test_array[TEST_SIZE];

int main(void)
{
	int i;
	int loc;
		
	splist s;
	
	sp_init(&s);
	
	/* Randomly insert and delete nodes from the list */
	for (i = 0; i < (1 << 25); i++)
	{
		loc = random() & (TEST_SIZE - 1);
		
		if (test_array[loc].s.snext)
		{
			sp_del(&test_array[loc].s);
			
			/* Make errors easier to find */
			test_array[loc].s.snext = 0;
		}
		else
		{
			sp_insert(&s, &test_array[loc].s);
		}
	}

	return 0;
}
