#!/usr/local/bin/pike

// The program itself acts as the listening port: bind(), accept() and
// errno() used further down are inherited from Stdio.Port.
inherit Stdio.Port;

// Queue shared between the accept loop and the per-backend worker threads.
// An entry is either a freshly accepted client connection, or an array
// ({ encoded_request, client_connection }) that a worker put back after it
// failed to write the request to its backend.
Thread.Fifo jobs=Thread.Fifo();

// Worker loop for one backend; never returns. main() starts one thread
// running this function per "host:port" argument.
//
// Each iteration connects to the backend, takes one job from the shared
// `jobs` queue, sends the client's command to the backend and relays the
// backend's reply back to the client.
//
//   host, port - address of the backend this worker talks to.
void server(string host, int port)
{
  int count;   // number of commands this worker has decoded and logged
  while(1)
  {
    object(Stdio.File) server=Stdio.File();
    if(!server->connect(host,port))
    {
      werror("Connection to %s:%d failed!\n",host,port);
      // Wait in 5 second steps, and do not retry until there is queued work.
      do {
        sleep(5);
      } while( ! jobs->size() );
      continue;
    }

    object(Stdio.File) io;   // the client connection
    string ret;              // the encoded request to send to the backend

    // Taken only after the backend connection succeeded, so a worker whose
    // backend is down does not consume jobs.
    mixed foo=jobs->read();

    if(arrayp(foo))
    {
      // A job another attempt already decoded and re-queued.
      [ ret, io ] = foo;
    } else {
      io=foo;
      // Request framing as read here: a 4-byte argument count, then for
      // each argument a 4-byte length followed by that many bytes.
      sscanf(io->read(4),"%4c",int args);
      array(string) cmd=({});
      for(int e=0;e<args;e++)
      {
        sscanf(io->read(4),"%4c",int len);
        string tmp2=io->read(len);
        cmd+=({tmp2});
      }

      werror("%4d %s:%d %s\n",++count,host,port,cmd*" ");

      // Re-encode the command in the same framing for the backend.
      ret=sprintf("%4c",sizeof(cmd));
      foreach(cmd, string x)
        ret+=sprintf("%4c%s",strlen(x),x);
    }

    if(server->write(ret) != strlen(ret))
    {
      // Put the already encoded request back so it can be retried on a
      // fresh backend connection (by this or another worker).
      jobs->write( ({ ret, io }) );
      werror("%s:%d Write failed!\n",host,port);
      destruct(server);
      continue;
    }

    // Forward whatever the client sends after the command to the backend,
    // concurrently with relaying the reply below.
    if(server->proxy)
      server->proxy(io);
    else
      thread_create(lambda(object server, object io)
                    {
                      string s;
                      // read() yields "" at end of file, and "" is true in
                      // Pike, so the length has to be tested as well;
                      // otherwise this thread spins once the client has
                      // closed its sending side.
                      while((s=io->read(1000,1)) && strlen(s))
                        server->write(s);
                    },server,io);

    // Relay length-prefixed chunks from the backend to the client until a
    // chunk with length zero has been forwarded. A short read leaves len at
    // 0 and therefore also ends the loop.
    int len;
    do
    {
      string tmp;
      len=0;
      sscanf(tmp=server->read(4),"%4c",len);
      io->write(tmp+server->read(len));
    }while(len);
    // Four trailing bytes follow the terminating chunk and are passed on
    // verbatim (presumably an exit status -- confirm against the protocol).
    io->write(server->read(4));
    io->close("rw");
    server->close("rw");
    destruct(io);
    destruct(server);
  }
}


// Accept loop; never returns. Every accepted connection whose remote IP is
// listed in `hosts` is queued on `jobs` for a worker thread; all others are
// dropped immediately.
//
//   hosts - IP addresses (as strings) that are allowed to connect.
void handle_connections(string *hosts)
{
  while(1)
  {
    if(object io=accept())
    {
      // query_address() returns 0 when the peer is already gone. Feeding
      // that to sscanf() would throw and take down the whole accept loop,
      // so such connections are simply discarded.
      string addr=io->query_address();
      if(!addr)
      {
        destruct(io);
        continue;
      }

      // The address has the form "ip port"; only the ip part is compared.
      sscanf(addr,"%s ",string ip);
      if(search(hosts, ip)==-1)
      {
        destruct(io);
        continue;
      }
      jobs->write(io);
    }else{
      werror("Accept failed "+errno()+"\n");
    }
  }
}

// Entry point: sprshd <port> <hosts>
//
// argv[1] is the port to listen on. Every following argument is one of:
//   host:port  - a backend; a worker thread is started for it,
//   a.b...     - a numeric address that is allowed to connect,
//   name       - a host name; all addresses it resolves to are allowed.
//
// Exits with status 1 on usage errors, if the port cannot be bound, or if
// a host name cannot be resolved. Otherwise runs the accept loop.
int main(int argc, string *argv)
{
  if(argc<2)
  {
    werror("Usage: sprshd <port> <hosts>\n");
    exit(1);
  }
  if(!bind((int)argv[1]))
  {
    werror("Failed to bind port.\n");
    exit(1);
  }

  string *hosts=({});
  for(int e=2;e<sizeof(argv);e++)
  {
    // Both the host and the port must have been parsed. Testing only for a
    // non-zero result would accept "host:" or "host:abc" and start a worker
    // for port 0.
    if(sscanf(argv[e],"%s:%d\n",string host, int port)==2)
    {
      werror("Starting server for %s:%d\n",host,port);
      thread_create(server,host,port);
      continue;
    }
    // Anything that starts like "<digits>.<digits>" is taken as a literal
    // IP address.
    if(sscanf(argv[e],"%*d.%*d")==2)
    {
      hosts+=({argv[e]});
      continue;
    }
    mixed tmp=gethostbyname(argv[e]);
    if(!tmp)
    {
      werror("Gethostbyname("+argv[e]+") failed.\n");
      exit(1);
    }
    // Element 1 of the gethostbyname() result is the list of addresses.
    hosts+=tmp[1];
  }

  write("Sprshd load balancer ready. ("+version()+").\n");

  handle_connections(hosts);
  return 0;
}
