package main import ( "encoding/json" "fmt" "io" "log" "net/http" "net/netip" "net/url" "os" "path" "path/filepath" "text/template" "time" "github.com/bitly/go-simplejson" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/gorilla/sessions" "github.com/gorilla/websocket" "github.com/julienschmidt/httprouter" "github.com/markbates/goth" "github.com/markbates/goth/gothic" "github.com/markbates/goth/providers/discord" "golang.org/x/exp/slices" ) const COOKIE_NAME = "_dndmusicbot" var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, } func init() { key := []byte(os.Getenv("SESSION_SECRET")) maxAge := 86400 * 30 // 30 days isProd := true // Set to true when serving over https store := sessions.NewCookieStore([]byte(key)) store.MaxAge(maxAge) store.Options.Path = "/" store.Options.HttpOnly = true // HttpOnly should always be enabled store.Options.Secure = isProd gothic.Store = store goth.UseProviders( discord.New(config.GetString("discord.id"), config.GetString("discord.secret"), config.GetString("discord.callback"), discord.ScopeIdentify, discord.ScopeEmail, discord.ScopeGuilds, discord.ScopeReadGuilds), ) app.router = httprouter.New() app.router.GET("/", auth(app.Index)) app.router.GET("/playlists", auth(app.Web_Playlists)) app.router.GET("/ambiance", auth(app.Web_Ambiance)) app.router.GET("/play/:playlist", auth(app.Play)) app.router.GET("/reset", auth(app.Reset)) app.router.GET("/public/*js", auth(app.ServeFiles)) app.router.GET("/css/*css", auth(app.ServeFiles)) app.router.GET("/auth/callback", app.AuthHandler) app.router.GET("/youtube/:id", local(ProxyTube)) app.router.HandlerFunc("GET", "/ws", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Printf("WS connection from %v\n", r.RemoteAddr) conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Println(err) return } err = handleWS(conn) if err != nil { log.Printf("WS connection closed, %v\n", r.RemoteAddr) } })) go func() { log.Fatal(http.ListenAndServe(":"+config.GetString("web.port"), app.router)) }() } type IndexData struct { Playlists []Playlist Ambiance []Ambiance } func (app App) AuthHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { user, err := gothic.CompleteUserAuth(w, r) if err != nil { fmt.Fprintln(w, err) return } profile_url := &url.URL{Scheme: "https", Host: "discord.com", Path: "/api/users/@me", } member_url := fmt.Sprintf("%s/guilds/%s/member", profile_url.String(), config.GetString("discord.guild")) member_req, err := http.NewRequest("GET", member_url, nil) if err != nil { fmt.Fprintln(w, err) return } member_req.Header.Add("Authorization", "Bearer "+user.AccessToken) member_resp, err := http.DefaultClient.Do(member_req) if err != nil { fmt.Fprintln(w, err) return } member_body, _ := io.ReadAll(member_resp.Body) defer member_resp.Body.Close() member, err := simplejson.NewJson(member_body) if err != nil { fmt.Fprintln(w, err) return } groups, err := member.GetPath("roles").StringArray() if err != nil { fmt.Fprintln(w, err) return } ok := false for _, group := range config.GetStringSlice("discord.groups") { if slices.Contains(groups, group) { ok = true } } if ok { token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "nbf": time.Now().Unix(), "exp": time.Now().Add(time.Hour * 720).Unix(), }) tokenString, err := token.SignedString([]byte(os.Getenv("SESSION_SECRET"))) if err != nil { fmt.Fprintln(w, err) return } cookie := new(http.Cookie) cookie.Name = COOKIE_NAME cookie.Value = tokenString cookie.Path = "/" cookie.Secure = true cookie.HttpOnly = true cookie.MaxAge = 86400 * 30 http.SetCookie(w, cookie) http.RedirectHandler("/", http.StatusFound).ServeHTTP(w, r) } } // Middleware to check that the connection is local. func local(n httprouter.Handle) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { addr, err := netip.ParseAddrPort(r.RemoteAddr) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } switch { case addr.Addr().IsLoopback(): fallthrough case addr.Addr().IsPrivate(): n(w, r, ps) default: http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) } } } func auth(n httprouter.Handle) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { if os.Getenv("APP_ENV") == "test" { n(w, r, ps) return } values := r.URL.Query() values.Add("provider", "discord") r.URL.RawQuery = values.Encode() auth_cookie, err := r.Cookie(COOKIE_NAME) if err == nil { token, err := jwt.Parse(auth_cookie.Value, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(os.Getenv("SESSION_SECRET")), nil }) if err != nil { fmt.Fprintln(w, err) return } if token.Valid { n(w, r, ps) } else { gothic.BeginAuthHandler(w, r) return } } else if err == http.ErrNoCookie { gothic.BeginAuthHandler(w, r) return } else if err != nil { fmt.Fprintln(w, err) return } } } func (app App) ServeFiles(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { filePath := filepath.Join(".", r.URL.Path) file, err := os.Open(filePath) if err != nil { log.Println(err) http.Error(w, "no such file", http.StatusNotFound) return } defer file.Close() fileStat, err := os.Stat(filePath) if err != nil { log.Println(err) http.Error(w, "unable to get file stat", http.StatusInternalServerError) } _, filename := path.Split(filePath) t := fileStat.ModTime() http.ServeContent(w, r, filename, t, file) } func (app App) Index(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { playlists, err := app.GetPlaylists() if err != nil { http.Error(w, "Unable to get playlists. "+err.Error(), http.StatusInternalServerError) } amblist, err := GetAmbiances() if err != nil { log.Println(err) return } data := IndexData{playlists, amblist} t := template.Must(template.New("index.tmpl").ParseFiles("tmpl/index.tmpl")) err = t.Execute(w, data) if err != nil { http.Error(w, "Unable to load template. "+err.Error(), http.StatusInternalServerError) } } func (app App) Web_Playlists(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { playlists, err := app.GetPlaylists() if err != nil { http.Error(w, "Unable to get playlists. "+err.Error(), http.StatusInternalServerError) } err = json.NewEncoder(w).Encode(playlists) if err != nil { http.Error(w, "Unable to get playlists. "+err.Error(), http.StatusInternalServerError) } } func (app App) Web_Ambiance(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { ambiance, err := GetAmbiances() if err != nil { http.Error(w, "Unable to get ambiance. "+err.Error(), http.StatusInternalServerError) } err = json.NewEncoder(w).Encode(ambiance) if err != nil { http.Error(w, "Unable to get playlists. "+err.Error(), http.StatusInternalServerError) } } func (app *App) Play(w http.ResponseWriter, r *http.Request, p httprouter.Params) { plname := p.ByName("playlist") if plname == "reset" { app.events.Emit("stop", nil) return } plid, err := uuid.ParseBytes([]byte(plname)) if err != nil { http.Error(w, "Unable to parse uuid. "+err.Error(), http.StatusInternalServerError) } app.events.Emit("new_playlist", plid) } func (app *App) Add(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { r.ParseForm() } func (app *App) Reset(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { app.events.Emit("stop", nil) }