
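# XXX: Add a test that checks the sample values returned after audio seeking
# (the existing seek tests only check the shape of the result).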
import errno
import os
import unittest

import numm
import numpy

import util

# WAV fixture shared by every test in this module (presumably util.test_file
# resolves the name to a path in the test-data directory -- confirm in util).
test_sound = util.test_file('test.wav')

class NummSoundTest(util.NummTestCase):
    """Tests for numm's sound reading, writing and chunking functions.

    The reading tests compare against the contents expected of the
    ``test.wav`` fixture: 128 two-channel int16 frames in which both
    channels of frame ``i`` hold the value ``i``.
    """

    def test_sound2np(self):
        """Reading the whole fixture yields the expected int16 ramp."""
        a = numm.sound2np(test_sound)
        ref = numpy.arange(128).repeat(2).reshape((128, 2))
        self.assertEqual(numpy.int16, a.dtype)
        self.assertArrayEqual(ref, a)

    def test_sound2np_n_frames(self):
        """``n_frames=64`` returns only the first 64 frames of the ramp."""
        a = numm.sound2np(test_sound, n_frames=64)
        ref = numpy.arange(64).repeat(2).reshape((64, 2))
        self.assertEqual(numpy.int16, a.dtype)
        self.assertArrayEqual(ref, a)

    def test_np2sound(self):
        """An array written with np2sound reads back unchanged."""
        a = numpy.arange(128, dtype=numpy.int16).repeat(2).reshape((128, 2))

        # Read the file back while the temporary path is still held open
        # by the context manager.
        with util.Tmp(suffix=".wav") as path:
            numm.np2sound(a, path)
            b = numm.sound2np(path)

        self.assertEqual(a.dtype, b.dtype)
        self.assertArrayEqual(a, b)

    def test_sound2np_seek_past_eof(self):
        """Starting beyond the end of the file gives an empty 2-column array."""
        a = numm.sound2np(test_sound, start=8192)
        self.assertEqual((0, 2), a.shape)

    def test_sound2np_seek(self):
        """Starting at frame 5 returns as many frames as ``whole[5:]``."""
        whole = numm.sound2np(test_sound)
        part = numm.sound2np(test_sound, start=5)
        # Only the shape is compared here, not the sample values.
        self.assertEqual(whole[5:].shape, part.shape)

    def test_sound_chunks(self):
        """Default chunks concatenate to the full file and start at time 0."""
        frames = list(numm.sound_chunks(test_sound))
        self.assertEqual((128, 2), numpy.concatenate(frames).shape)
        self.assertEqual(frames[0].timestamp, 0)

    def test_precise_sound_chunks(self):
        """``chunk_size=32`` splits the 128 frames into four equal chunks."""
        frames = list(numm.sound.sound_chunks(test_sound, chunk_size=32))
        self.assertEqual((128, 2), numpy.concatenate(frames).shape)
        # Each frame has 32 samples.  A real list is built because map()
        # returns an iterator on Python 3, which never equals a list.
        self.assertEqual([32, 32, 32, 32], [len(f) for f in frames])
        self.assertEqual(frames[0].timestamp, 0)

    def test_precise_sound_chunks_padding(self):
        """A chunk size that does not divide 128 pads the final chunk."""
        frames = list(numm.sound.sound_chunks(test_sound, chunk_size=40))
        # Length is longer because of chunk padding.
        self.assertEqual((160, 2), numpy.concatenate(frames).shape)
        # Each frame has 40 samples. Last frame is padded.
        self.assertEqual([40, 40, 40, 40], [len(f) for f in frames])
        # The tail of the last chunk must be silent (zero) padding.
        self.assertEqual(0, frames[-1][-12:].sum())
        self.assertEqual(frames[0].timestamp, 0)

    @unittest.skipUnless(os.name == "posix", "requires unix")
    def test_fd_cleanup(self):
        """Reading a sound must not leave additional file descriptors open."""
        import resource  # module requires unix

        def get_fds():
            """Return the set of descriptor numbers currently open."""
            fds = set()
            # Probe every descriptor number below the soft RLIMIT_NOFILE.
            for i in range(resource.getrlimit(resource.RLIMIT_NOFILE)[0]):
                try:
                    os.fstat(i)
                except OSError as e:
                    # EBADF just means "not open"; anything else is a real
                    # error and is re-raised with its original traceback.
                    if e.errno != errno.EBADF:
                        raise
                else:
                    fds.add(i)
            return fds

        # gstreamer seems to eat some fds on the first pipeline, so run once first
        numm.sound2np(test_sound)
        old_fds = get_fds()
        numm.sound2np(test_sound)
        # On sets, <= is the subset test: nothing may be open now that was
        # not already open before the second read.
        self.assertLessEqual(get_fds(), old_fds)
